From c4cd9c80889d2f92f32e730bbf0d17237c792192 Mon Sep 17 00:00:00 2001 From: Steven Roussey Date: Sun, 11 Jan 2026 03:52:29 +0000 Subject: [PATCH 01/14] [chore] Update dependencies and improve testing documentation - Bumped versions of several dependencies including `caniuse-lite`, `@typescript-eslint/eslint-plugin`, `@typescript-eslint/parser`, `globals`, and `turbo` for enhanced functionality and compatibility. - Added `@sroussey/json-schema-library` as a new dependency in the project. - Updated testing documentation to clarify the command for running specific tests. --- .cursor/rules/testing.mdc | 2 +- bun.lock | 62 +++++++++++++++++++++------------------ package.json | 10 +++---- 3 files changed, 39 insertions(+), 35 deletions(-) diff --git a/.cursor/rules/testing.mdc b/.cursor/rules/testing.mdc index 0809cd15..2f373b87 100644 --- a/.cursor/rules/testing.mdc +++ b/.cursor/rules/testing.mdc @@ -4,4 +4,4 @@ alwaysApply: true Use bun, not jest, not npm, not node -To run tests: `bun test`. To run specific ones: `bun run `. +To run tests: `bun test`. To run specific ones: `bun test `. diff --git a/bun.lock b/bun.lock index 491d0497..8db3757a 100644 --- a/bun.lock +++ b/bun.lock @@ -5,22 +5,22 @@ "": { "name": "workglow", "dependencies": { - "caniuse-lite": "^1.0.30001761", + "caniuse-lite": "^1.0.30001763", }, "devDependencies": { "@sroussey/changesets-cli": "^2.29.7", "@types/bun": "^1.3.5", - "@typescript-eslint/eslint-plugin": "^8.50.1", - "@typescript-eslint/parser": "^8.50.1", + "@typescript-eslint/eslint-plugin": "^8.52.0", + "@typescript-eslint/parser": "^8.52.0", "concurrently": "^9.2.1", "eslint": "^9.39.2", "eslint-plugin-jsx-a11y": "^6.10.2", "eslint-plugin-react": "^7.37.5", "eslint-plugin-react-hooks": "^7.0.1", "eslint-plugin-regexp": "^2.10.0", - "globals": "^16.5.0", + "globals": "^17.0.0", "prettier": "^3.7.4", - "turbo": "^2.7.2", + "turbo": "^2.7.3", "typescript": "5.9.3", "vitest": "^4.0.16", }, @@ -257,8 +257,8 @@ "name": "@workglow/util", "version": "0.0.85", "dependencies": { + "@sroussey/json-schema-library": "^10.5.3", "@sroussey/json-schema-to-ts": "3.1.3", - "json-schema-library": "^10.5.1", }, }, }, @@ -613,7 +613,7 @@ "@rollup/rollup-win32-x64-msvc": ["@rollup/rollup-win32-x64-msvc@4.53.3", "", { "os": "win32", "cpu": "x64" }, "sha512-UhTd8u31dXadv0MopwGgNOBpUVROFKWVQgAg5N1ESyCz8AuBcMqm4AuTjrwgQKGDfoFuz02EuMRHQIw/frmYKQ=="], - "@sagold/json-pointer": ["@sagold/json-pointer@7.2.0", "", {}, "sha512-RZpwGl1yhNuzQVKOADJx65TrWL7T6HTGs2Rpv7KlbFY0CfbFWNAKsisvC/uGfchknCGJEnoxz9uPAdmgoAE3IA=="], + "@sagold/json-pointer": ["@sagold/json-pointer@7.2.1", "", {}, "sha512-8EX4r5Royl5M3qNPTh5W5njdOtRqbWgQfVv26DbzjGj2/55b60EqvbiqUIglzw7fADfY/Io6jDJxGNJHmC+g8g=="], "@sagold/json-query": ["@sagold/json-query@6.2.0", "", { "dependencies": { "@sagold/json-pointer": "^5.1.2", "ebnf": "^1.9.1" } }, "sha512-7bOIdUE6eHeoWtFm8TvHQHfTVSZuCs+3RpOKmZCDBIOrxpvF/rNFTeuvIyjHva/RR0yVS3kQtr+9TW72LQEZjA=="], @@ -621,6 +621,8 @@ "@sroussey/changesets-cli": ["@sroussey/changesets-cli@2.29.7", "", { "dependencies": { "@changesets/apply-release-plan": "^7.0.13", "@changesets/assemble-release-plan": "^6.0.9", "@changesets/changelog-git": "^0.2.1", "@changesets/config": "^3.1.1", "@changesets/errors": "^0.2.0", "@changesets/get-dependents-graph": "^2.1.3", "@changesets/get-release-plan": "^4.0.13", "@changesets/git": "^3.0.4", "@changesets/logger": "^0.1.1", "@changesets/pre": "^2.0.2", "@changesets/read": "^0.6.5", "@changesets/should-skip-package": "^0.1.2", "@changesets/types": "^6.1.0", 
"@changesets/write": "^0.4.0", "@inquirer/external-editor": "^1.0.0", "@manypkg/get-packages": "^1.1.3", "ansi-colors": "^4.1.3", "ci-info": "^3.7.0", "enquirer": "^2.4.1", "fs-extra": "^7.0.1", "mri": "^1.2.0", "p-limit": "^2.2.0", "package-manager-detector": "^0.2.0", "picocolors": "^1.1.0", "resolve-from": "^5.0.0", "semver": "^7.7.3", "spawndamnit": "^3.0.1", "term-size": "^2.1.0" }, "bin": { "changeset": "bin.js" } }, "sha512-y47qkQTbei/sDIk0//S1bBXnOxPkrmeXjHxljubE9xqS+k8h+RAt+J0RyzJxk3SUYgY7D2vYJMQw4s9865F34w=="], + "@sroussey/json-schema-library": ["@sroussey/json-schema-library@10.5.3", "", { "dependencies": { "@sagold/json-pointer": "^7.2.1", "@sagold/json-query": "^6.2.0", "deepmerge": "^4.3.1", "fast-copy": "^3.0.2", "fast-deep-equal": "^3.1.3", "smtp-address-parser": "1.0.10", "uri-js": "^4.4.1", "valid-url": "^1.0.9" } }, "sha512-B+4Q84gJk56qAuM/4UAOm/pmmy2p+YtLChl0JBhgz8Qk4J0lDxgJh0JtT/CX5mew7Eq5nv6t9zuqdFLM6yQoDQ=="], + "@sroussey/json-schema-to-ts": ["@sroussey/json-schema-to-ts@3.1.3", "", { "dependencies": { "ts-algebra": "^2.0.0" } }, "sha512-N4j/Mz1YkZHvQfStIvtS4DiQLltzzU84jFt6qoo0DsUHV+n3UDfduWlYQSwov8gS9iJliIJ4L4Vb15k5HVdLwg=="], "@sroussey/transformers": ["@sroussey/transformers@3.8.2", "", { "dependencies": { "@huggingface/jinja": "^0.5.3", "onnxruntime-node": "1.23.2", "onnxruntime-web": "1.23.2", "sharp": "^0.34.5" } }, "sha512-K9g7aGnUZ8xdBBhrt6rZIB1rnFY1H4VggGCsrJoEL6tvhm5/Z+VpAHGduvdNS/s2a3lTsJt+scxfTU4DQ0T5JA=="], @@ -685,25 +687,25 @@ "@types/ws": ["@types/ws@8.18.1", "", { "dependencies": { "@types/node": "*" } }, "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg=="], - "@typescript-eslint/eslint-plugin": ["@typescript-eslint/eslint-plugin@8.50.1", "", { "dependencies": { "@eslint-community/regexpp": "^4.10.0", "@typescript-eslint/scope-manager": "8.50.1", "@typescript-eslint/type-utils": "8.50.1", "@typescript-eslint/utils": "8.50.1", "@typescript-eslint/visitor-keys": "8.50.1", "ignore": "^7.0.0", "natural-compare": "^1.4.0", "ts-api-utils": "^2.1.0" }, "peerDependencies": { "@typescript-eslint/parser": "^8.50.1", "eslint": "^8.57.0 || ^9.0.0", "typescript": ">=4.8.4 <6.0.0" } }, "sha512-PKhLGDq3JAg0Jk/aK890knnqduuI/Qj+udH7wCf0217IGi4gt+acgCyPVe79qoT+qKUvHMDQkwJeKW9fwl8Cyw=="], + "@typescript-eslint/eslint-plugin": ["@typescript-eslint/eslint-plugin@8.52.0", "", { "dependencies": { "@eslint-community/regexpp": "^4.12.2", "@typescript-eslint/scope-manager": "8.52.0", "@typescript-eslint/type-utils": "8.52.0", "@typescript-eslint/utils": "8.52.0", "@typescript-eslint/visitor-keys": "8.52.0", "ignore": "^7.0.5", "natural-compare": "^1.4.0", "ts-api-utils": "^2.4.0" }, "peerDependencies": { "@typescript-eslint/parser": "^8.52.0", "eslint": "^8.57.0 || ^9.0.0", "typescript": ">=4.8.4 <6.0.0" } }, "sha512-okqtOgqu2qmZJ5iN4TWlgfF171dZmx2FzdOv2K/ixL2LZWDStL8+JgQerI2sa8eAEfoydG9+0V96m7V+P8yE1Q=="], - "@typescript-eslint/parser": ["@typescript-eslint/parser@8.50.1", "", { "dependencies": { "@typescript-eslint/scope-manager": "8.50.1", "@typescript-eslint/types": "8.50.1", "@typescript-eslint/typescript-estree": "8.50.1", "@typescript-eslint/visitor-keys": "8.50.1", "debug": "^4.3.4" }, "peerDependencies": { "eslint": "^8.57.0 || ^9.0.0", "typescript": ">=4.8.4 <6.0.0" } }, "sha512-hM5faZwg7aVNa819m/5r7D0h0c9yC4DUlWAOvHAtISdFTc8xB86VmX5Xqabrama3wIPJ/q9RbGS1worb6JfnMg=="], + "@typescript-eslint/parser": ["@typescript-eslint/parser@8.52.0", "", { "dependencies": { "@typescript-eslint/scope-manager": "8.52.0", 
"@typescript-eslint/types": "8.52.0", "@typescript-eslint/typescript-estree": "8.52.0", "@typescript-eslint/visitor-keys": "8.52.0", "debug": "^4.4.3" }, "peerDependencies": { "eslint": "^8.57.0 || ^9.0.0", "typescript": ">=4.8.4 <6.0.0" } }, "sha512-iIACsx8pxRnguSYhHiMn2PvhvfpopO9FXHyn1mG5txZIsAaB6F0KwbFnUQN3KCiG3Jcuad/Cao2FAs1Wp7vAyg=="], - "@typescript-eslint/project-service": ["@typescript-eslint/project-service@8.50.1", "", { "dependencies": { "@typescript-eslint/tsconfig-utils": "^8.50.1", "@typescript-eslint/types": "^8.50.1", "debug": "^4.3.4" }, "peerDependencies": { "typescript": ">=4.8.4 <6.0.0" } }, "sha512-E1ur1MCVf+YiP89+o4Les/oBAVzmSbeRB0MQLfSlYtbWU17HPxZ6Bhs5iYmKZRALvEuBoXIZMOIRRc/P++Ortg=="], + "@typescript-eslint/project-service": ["@typescript-eslint/project-service@8.52.0", "", { "dependencies": { "@typescript-eslint/tsconfig-utils": "^8.52.0", "@typescript-eslint/types": "^8.52.0", "debug": "^4.4.3" }, "peerDependencies": { "typescript": ">=4.8.4 <6.0.0" } }, "sha512-xD0MfdSdEmeFa3OmVqonHi+Cciab96ls1UhIF/qX/O/gPu5KXD0bY9lu33jj04fjzrXHcuvjBcBC+D3SNSadaw=="], - "@typescript-eslint/scope-manager": ["@typescript-eslint/scope-manager@8.50.1", "", { "dependencies": { "@typescript-eslint/types": "8.50.1", "@typescript-eslint/visitor-keys": "8.50.1" } }, "sha512-mfRx06Myt3T4vuoHaKi8ZWNTPdzKPNBhiblze5N50//TSHOAQQevl/aolqA/BcqqbJ88GUnLqjjcBc8EWdBcVw=="], + "@typescript-eslint/scope-manager": ["@typescript-eslint/scope-manager@8.52.0", "", { "dependencies": { "@typescript-eslint/types": "8.52.0", "@typescript-eslint/visitor-keys": "8.52.0" } }, "sha512-ixxqmmCcc1Nf8S0mS0TkJ/3LKcC8mruYJPOU6Ia2F/zUUR4pApW7LzrpU3JmtePbRUTes9bEqRc1Gg4iyRnDzA=="], - "@typescript-eslint/tsconfig-utils": ["@typescript-eslint/tsconfig-utils@8.50.1", "", { "peerDependencies": { "typescript": ">=4.8.4 <6.0.0" } }, "sha512-ooHmotT/lCWLXi55G4mvaUF60aJa012QzvLK0Y+Mp4WdSt17QhMhWOaBWeGTFVkb2gDgBe19Cxy1elPXylslDw=="], + "@typescript-eslint/tsconfig-utils": ["@typescript-eslint/tsconfig-utils@8.52.0", "", { "peerDependencies": { "typescript": ">=4.8.4 <6.0.0" } }, "sha512-jl+8fzr/SdzdxWJznq5nvoI7qn2tNYV/ZBAEcaFMVXf+K6jmXvAFrgo/+5rxgnL152f//pDEAYAhhBAZGrVfwg=="], - "@typescript-eslint/type-utils": ["@typescript-eslint/type-utils@8.50.1", "", { "dependencies": { "@typescript-eslint/types": "8.50.1", "@typescript-eslint/typescript-estree": "8.50.1", "@typescript-eslint/utils": "8.50.1", "debug": "^4.3.4", "ts-api-utils": "^2.1.0" }, "peerDependencies": { "eslint": "^8.57.0 || ^9.0.0", "typescript": ">=4.8.4 <6.0.0" } }, "sha512-7J3bf022QZE42tYMO6SL+6lTPKFk/WphhRPe9Tw/el+cEwzLz1Jjz2PX3GtGQVxooLDKeMVmMt7fWpYRdG5Etg=="], + "@typescript-eslint/type-utils": ["@typescript-eslint/type-utils@8.52.0", "", { "dependencies": { "@typescript-eslint/types": "8.52.0", "@typescript-eslint/typescript-estree": "8.52.0", "@typescript-eslint/utils": "8.52.0", "debug": "^4.4.3", "ts-api-utils": "^2.4.0" }, "peerDependencies": { "eslint": "^8.57.0 || ^9.0.0", "typescript": ">=4.8.4 <6.0.0" } }, "sha512-JD3wKBRWglYRQkAtsyGz1AewDu3mTc7NtRjR/ceTyGoPqmdS5oCdx/oZMWD5Zuqmo6/MpsYs0wp6axNt88/2EQ=="], - "@typescript-eslint/types": ["@typescript-eslint/types@8.50.1", "", {}, "sha512-v5lFIS2feTkNyMhd7AucE/9j/4V9v5iIbpVRncjk/K0sQ6Sb+Np9fgYS/63n6nwqahHQvbmujeBL7mp07Q9mlA=="], + "@typescript-eslint/types": ["@typescript-eslint/types@8.52.0", "", {}, "sha512-LWQV1V4q9V4cT4H5JCIx3481iIFxH1UkVk+ZkGGAV1ZGcjGI9IoFOfg3O6ywz8QqCDEp7Inlg6kovMofsNRaGg=="], - "@typescript-eslint/typescript-estree": ["@typescript-eslint/typescript-estree@8.50.1", "", { 
"dependencies": { "@typescript-eslint/project-service": "8.50.1", "@typescript-eslint/tsconfig-utils": "8.50.1", "@typescript-eslint/types": "8.50.1", "@typescript-eslint/visitor-keys": "8.50.1", "debug": "^4.3.4", "minimatch": "^9.0.4", "semver": "^7.6.0", "tinyglobby": "^0.2.15", "ts-api-utils": "^2.1.0" }, "peerDependencies": { "typescript": ">=4.8.4 <6.0.0" } }, "sha512-woHPdW+0gj53aM+cxchymJCrh0cyS7BTIdcDxWUNsclr9VDkOSbqC13juHzxOmQ22dDkMZEpZB+3X1WpUvzgVQ=="], + "@typescript-eslint/typescript-estree": ["@typescript-eslint/typescript-estree@8.52.0", "", { "dependencies": { "@typescript-eslint/project-service": "8.52.0", "@typescript-eslint/tsconfig-utils": "8.52.0", "@typescript-eslint/types": "8.52.0", "@typescript-eslint/visitor-keys": "8.52.0", "debug": "^4.4.3", "minimatch": "^9.0.5", "semver": "^7.7.3", "tinyglobby": "^0.2.15", "ts-api-utils": "^2.4.0" }, "peerDependencies": { "typescript": ">=4.8.4 <6.0.0" } }, "sha512-XP3LClsCc0FsTK5/frGjolyADTh3QmsLp6nKd476xNI9CsSsLnmn4f0jrzNoAulmxlmNIpeXuHYeEQv61Q6qeQ=="], - "@typescript-eslint/utils": ["@typescript-eslint/utils@8.50.1", "", { "dependencies": { "@eslint-community/eslint-utils": "^4.7.0", "@typescript-eslint/scope-manager": "8.50.1", "@typescript-eslint/types": "8.50.1", "@typescript-eslint/typescript-estree": "8.50.1" }, "peerDependencies": { "eslint": "^8.57.0 || ^9.0.0", "typescript": ">=4.8.4 <6.0.0" } }, "sha512-lCLp8H1T9T7gPbEuJSnHwnSuO9mDf8mfK/Nion5mZmiEaQD9sWf9W4dfeFqRyqRjF06/kBuTmAqcs9sewM2NbQ=="], + "@typescript-eslint/utils": ["@typescript-eslint/utils@8.52.0", "", { "dependencies": { "@eslint-community/eslint-utils": "^4.9.1", "@typescript-eslint/scope-manager": "8.52.0", "@typescript-eslint/types": "8.52.0", "@typescript-eslint/typescript-estree": "8.52.0" }, "peerDependencies": { "eslint": "^8.57.0 || ^9.0.0", "typescript": ">=4.8.4 <6.0.0" } }, "sha512-wYndVMWkweqHpEpwPhwqE2lnD2DxC6WVLupU/DOt/0/v+/+iQbbzO3jOHjmBMnhu0DgLULvOaU4h4pwHYi2oRQ=="], - "@typescript-eslint/visitor-keys": ["@typescript-eslint/visitor-keys@8.50.1", "", { "dependencies": { "@typescript-eslint/types": "8.50.1", "eslint-visitor-keys": "^4.2.1" } }, "sha512-IrDKrw7pCRUR94zeuCSUWQ+w8JEf5ZX5jl/e6AHGSLi1/zIr0lgutfn/7JpfCey+urpgQEdrZVYzCaVVKiTwhQ=="], + "@typescript-eslint/visitor-keys": ["@typescript-eslint/visitor-keys@8.52.0", "", { "dependencies": { "@typescript-eslint/types": "8.52.0", "eslint-visitor-keys": "^4.2.1" } }, "sha512-ink3/Zofus34nmBsPjow63FP5M7IGff0RKAgqR6+CFpdk22M7aLwC9gOcLGYqr7MczLPzZVERW9hRog3O4n1sQ=="], "@uiw/codemirror-extensions-basic-setup": ["@uiw/codemirror-extensions-basic-setup@4.25.3", "", { "dependencies": { "@codemirror/autocomplete": "^6.0.0", "@codemirror/commands": "^6.0.0", "@codemirror/language": "^6.0.0", "@codemirror/lint": "^6.0.0", "@codemirror/search": "^6.0.0", "@codemirror/state": "^6.0.0", "@codemirror/view": "^6.0.0" } }, "sha512-F1doRyD50CWScwGHG2bBUtUpwnOv/zqSnzkZqJcX5YAHQx6Z1CuX8jdnFMH6qktRrPU1tfpNYftTWu3QIoHiMA=="], @@ -853,7 +855,7 @@ "camelcase-css": ["camelcase-css@2.0.1", "", {}, "sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA=="], - "caniuse-lite": ["caniuse-lite@1.0.30001761", "", {}, "sha512-JF9ptu1vP2coz98+5051jZ4PwQgd2ni8A+gYSN7EA7dPKIMf0pDlSUxhdmVOaV3/fYK5uWBkgSXJaRLr4+3A6g=="], + "caniuse-lite": ["caniuse-lite@1.0.30001763", "", {}, "sha512-mh/dGtq56uN98LlNX9qdbKnzINhX0QzhiWBFEkFfsFO4QyCvL8YegrJAazCwXIeqkIob8BlZPGM3xdnY+sgmvQ=="], "chai": ["chai@6.2.1", "", {}, 
"sha512-p4Z49OGG5W/WBCPSS/dH3jQ73kD6tiMmUM+bckNK6Jr5JHMG3k9bg/BvKR8lKmtVBKmOiuVaV2ws8s9oSbwysg=="], @@ -1107,7 +1109,7 @@ "global-agent": ["global-agent@3.0.0", "", { "dependencies": { "boolean": "^3.0.1", "es6-error": "^4.1.1", "matcher": "^3.0.0", "roarr": "^2.15.3", "semver": "^7.3.2", "serialize-error": "^7.0.1" } }, "sha512-PT6XReJ+D07JvGoxQMkT6qji/jVNfX/h364XHZOWeRzy64sSFr+xJ5OX7LI3b4MPQzdL4H8Y8M0xzPpsVMwA8Q=="], - "globals": ["globals@16.5.0", "", {}, "sha512-c/c15i26VrJ4IRt5Z89DnIzCGDn9EcebibhAOjw5ibqEHsE1wLUgkPn9RDmNcUKyU87GeaL633nyJ+pplFR2ZQ=="], + "globals": ["globals@17.0.0", "", {}, "sha512-gv5BeD2EssA793rlFWVPMMCqefTlpusw6/2TbAVMy0FzcG8wKJn4O+NqJ4+XWmmwrayJgw5TzrmWjFgmz1XPqw=="], "globalthis": ["globalthis@1.0.4", "", { "dependencies": { "define-properties": "^1.2.1", "gopd": "^1.0.1" } }, "sha512-DpLKbNU4WylpxJykQujfCcwYWiV/Jhm50Goo0wrVILAv5jOr9d+H+UR3PhSCD2rCCEIg0uc+G+muBTwD54JhDQ=="], @@ -1241,8 +1243,6 @@ "json-buffer": ["json-buffer@3.0.1", "", {}, "sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ=="], - "json-schema-library": ["json-schema-library@10.5.1", "", { "dependencies": { "@sagold/json-pointer": "^7.2.0", "@sagold/json-query": "^6.2.0", "deepmerge": "^4.3.1", "fast-copy": "^3.0.2", "fast-deep-equal": "^3.1.3", "smtp-address-parser": "1.0.10", "uri-js": "^4.4.1", "valid-url": "^1.0.9" } }, "sha512-QDKmtWbgHoxzZEBZ3XESZBQprpgfSlOezQC+wKukZJzNOlBc8nomWZxYBY4qFGKawmtWkLRmZUDW34WlKVhAug=="], - "json-schema-traverse": ["json-schema-traverse@0.4.1", "", {}, "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg=="], "json-stable-stringify-without-jsonify": ["json-stable-stringify-without-jsonify@1.0.1", "", {}, "sha512-Bdboy+l7tA3OGW6FjyFHWkP5LuByj1Tk33Ljyq0axyzdk9//JSi2u3fP1QSmd1KNwq6VOKYGlAu87CisVir6Pw=="], @@ -1643,7 +1643,7 @@ "ts-algebra": ["ts-algebra@2.0.0", "", {}, "sha512-FPAhNPFMrkwz76P7cdjdmiShwMynZYN6SgOujD1urY4oNm80Ou9oMdmbR45LotcKOXoy7wSmHkRFE6Mxbrhefw=="], - "ts-api-utils": ["ts-api-utils@2.1.0", "", { "peerDependencies": { "typescript": ">=4.8.4" } }, "sha512-CUgTZL1irw8u29bzrOD/nH85jqyc74D6SshFgujOIA7osm2Rz7dYH77agkx7H4FBNxDq7Cjf+IjaX/8zwFW+ZQ=="], + "ts-api-utils": ["ts-api-utils@2.4.0", "", { "peerDependencies": { "typescript": ">=4.8.4" } }, "sha512-3TaVTaAv2gTiMB35i3FiGJaRfwb3Pyn/j3m/bfAvGe8FB7CF6u+LMYqYlDh7reQf7UNvoTvdfAqHGmPGOSsPmA=="], "ts-interface-checker": ["ts-interface-checker@0.1.13", "", {}, "sha512-Y/arvbn+rrz3JCKl9C4kVNfTfSm2/mEp5FSz5EsZSANGPSlQrpRI5M4PKF+mJnE52jOO90PnPSc3Ur3bTQw0gA=="], @@ -1651,19 +1651,19 @@ "tunnel-agent": ["tunnel-agent@0.6.0", "", { "dependencies": { "safe-buffer": "^5.0.1" } }, "sha512-McnNiV1l8RYeY8tBgEpuodCC1mLUdbSN+CYBL7kJsJNInOP8UjDDEwdk6Mw60vdLLrr5NHKZhMAOSrR2NZuQ+w=="], - "turbo": ["turbo@2.7.2", "", { "optionalDependencies": { "turbo-darwin-64": "2.7.2", "turbo-darwin-arm64": "2.7.2", "turbo-linux-64": "2.7.2", "turbo-linux-arm64": "2.7.2", "turbo-windows-64": "2.7.2", "turbo-windows-arm64": "2.7.2" }, "bin": { "turbo": "bin/turbo" } }, "sha512-5JIA5aYBAJSAhrhbyag1ZuMSgUZnHtI+Sq3H8D3an4fL8PeF+L1yYvbEJg47akP1PFfATMf5ehkqFnxfkmuwZQ=="], + "turbo": ["turbo@2.7.3", "", { "optionalDependencies": { "turbo-darwin-64": "2.7.3", "turbo-darwin-arm64": "2.7.3", "turbo-linux-64": "2.7.3", "turbo-linux-arm64": "2.7.3", "turbo-windows-64": "2.7.3", "turbo-windows-arm64": "2.7.3" }, "bin": { "turbo": "bin/turbo" } }, "sha512-+HjKlP4OfYk+qzvWNETA3cUO5UuK6b5MSc2UJOKyvBceKucQoQGb2g7HlC2H1GHdkfKrk4YF1VPvROkhVZDDLQ=="], - "turbo-darwin-64": 
["turbo-darwin-64@2.7.2", "", { "os": "darwin", "cpu": "x64" }, "sha512-dxY3X6ezcT5vm3coK6VGixbrhplbQMwgNsCsvZamS/+/6JiebqW9DKt4NwpgYXhDY2HdH00I7FWs3wkVuan4rA=="], + "turbo-darwin-64": ["turbo-darwin-64@2.7.3", "", { "os": "darwin", "cpu": "x64" }, "sha512-aZHhvRiRHXbJw1EcEAq4aws1hsVVUZ9DPuSFaq9VVFAKCup7niIEwc22glxb7240yYEr1vLafdQ2U294Vcwz+w=="], - "turbo-darwin-arm64": ["turbo-darwin-arm64@2.7.2", "", { "os": "darwin", "cpu": "arm64" }, "sha512-1bXmuwPLqNFt3mzrtYcVx1sdJ8UYb124Bf48nIgcpMCGZy3kDhgxNv1503kmuK/37OGOZbsWSQFU4I08feIuSg=="], + "turbo-darwin-arm64": ["turbo-darwin-arm64@2.7.3", "", { "os": "darwin", "cpu": "arm64" }, "sha512-CkVrHSq+Bnhl9sX2LQgqQYVfLTWC2gvI74C4758OmU0djfrssDKU9d4YQF0AYXXhIIRZipSXfxClQziIMD+EAg=="], - "turbo-linux-64": ["turbo-linux-64@2.7.2", "", { "os": "linux", "cpu": "x64" }, "sha512-kP+TiiMaiPugbRlv57VGLfcjFNsFbo8H64wMBCPV2270Or2TpDCBULMzZrvEsvWFjT3pBFvToYbdp8/Kw0jAQg=="], + "turbo-linux-64": ["turbo-linux-64@2.7.3", "", { "os": "linux", "cpu": "x64" }, "sha512-GqDsCNnzzr89kMaLGpRALyigUklzgxIrSy2pHZVXyifgczvYPnLglex78Aj3T2gu+T3trPPH2iJ+pWucVOCC2Q=="], - "turbo-linux-arm64": ["turbo-linux-arm64@2.7.2", "", { "os": "linux", "cpu": "arm64" }, "sha512-VDJwQ0+8zjAfbyY6boNaWfP6RIez4ypKHxwkuB6SrWbOSk+vxTyW5/hEjytTwK8w/TsbKVcMDyvpora8tEsRFw=="], + "turbo-linux-arm64": ["turbo-linux-arm64@2.7.3", "", { "os": "linux", "cpu": "arm64" }, "sha512-NdCDTfIcIo3dWjsiaAHlxu5gW61Ed/8maah1IAF/9E3EtX0aAHNiBMbuYLZaR4vRJ7BeVkYB6xKWRtdFLZ0y3g=="], - "turbo-windows-64": ["turbo-windows-64@2.7.2", "", { "os": "win32", "cpu": "x64" }, "sha512-rPjqQXVnI6A6oxgzNEE8DNb6Vdj2Wwyhfv3oDc+YM3U9P7CAcBIlKv/868mKl4vsBtz4ouWpTQNXG8vljgJO+w=="], + "turbo-windows-64": ["turbo-windows-64@2.7.3", "", { "os": "win32", "cpu": "x64" }, "sha512-7bVvO987daXGSJVYBoG8R4Q+csT1pKIgLJYZevXRQ0Hqw0Vv4mKme/TOjYXs9Qb1xMKh51Tb3bXKDbd8/4G08g=="], - "turbo-windows-arm64": ["turbo-windows-arm64@2.7.2", "", { "os": "win32", "cpu": "arm64" }, "sha512-tcnHvBhO515OheIFWdxA+qUvZzNqqcHbLVFc1+n+TJ1rrp8prYicQtbtmsiKgMvr/54jb9jOabU62URAobnB7g=="], + "turbo-windows-arm64": ["turbo-windows-arm64@2.7.3", "", { "os": "win32", "cpu": "arm64" }, "sha512-nTodweTbPmkvwMu/a55XvjMsPtuyUSC+sV7f/SR57K36rB2I0YG21qNETN+00LOTUW9B3omd8XkiXJkt4kx/cw=="], "type-check": ["type-check@0.4.0", "", { "dependencies": { "prelude-ls": "^1.2.1" } }, "sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew=="], @@ -1777,6 +1777,8 @@ "@typescript-eslint/typescript-estree/minimatch": ["minimatch@9.0.5", "", { "dependencies": { "brace-expansion": "^2.0.1" } }, "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow=="], + "@typescript-eslint/utils/@eslint-community/eslint-utils": ["@eslint-community/eslint-utils@4.9.1", "", { "dependencies": { "eslint-visitor-keys": "^3.4.3" }, "peerDependencies": { "eslint": "^6.0.0 || ^7.0.0 || >=8.0.0" } }, "sha512-phrYmNiYppR7znFEdqgfWHXR6NCkZEK7hwWDHZUjit/2/U0r6XvkDl0SYnoM51Hq7FhCGdLDT6zxCCOY1hexsQ=="], + "@workglow/web/@types/react": ["@types/react@19.2.7", "", { "dependencies": { "csstype": "^3.2.2" } }, "sha512-MWtvHrGZLFttgeEj28VXHxpmwYbor/ATPYbBfSFZEIRK0ecCFLl2Qo55z52Hss+UV9CRN7trSeq1zbgx7YDWWg=="], "@workglow/web/react": ["react@19.2.1", "", {}, "sha512-DGrYcCWK7tvYMnWh79yrPHt+vdx9tY+1gPZa7nJQtO/p8bLTDaHp4dzwEhQB7pZ4Xe3ok4XKuEPrVuc+wlpkmw=="], @@ -1859,6 +1861,8 @@ "@typescript-eslint/typescript-estree/minimatch/brace-expansion": ["brace-expansion@2.0.2", "", { "dependencies": { "balanced-match": "^1.0.0" } }, 
"sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ=="], + "@typescript-eslint/utils/@eslint-community/eslint-utils/eslint-visitor-keys": ["eslint-visitor-keys@3.4.3", "", {}, "sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag=="], + "cliui/string-width/emoji-regex": ["emoji-regex@8.0.0", "", {}, "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A=="], "cliui/string-width/is-fullwidth-code-point": ["is-fullwidth-code-point@3.0.0", "", {}, "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg=="], diff --git a/package.json b/package.json index 225eed16..e3b41125 100644 --- a/package.json +++ b/package.json @@ -32,7 +32,7 @@ "publish": "bun ./scripts/publish-workspaces.ts" }, "dependencies": { - "caniuse-lite": "^1.0.30001761" + "caniuse-lite": "^1.0.30001763" }, "catalog": { "@sroussey/transformers": "3.8.2", @@ -44,17 +44,17 @@ "devDependencies": { "@sroussey/changesets-cli": "^2.29.7", "@types/bun": "^1.3.5", - "@typescript-eslint/eslint-plugin": "^8.50.1", - "@typescript-eslint/parser": "^8.50.1", + "@typescript-eslint/eslint-plugin": "^8.52.0", + "@typescript-eslint/parser": "^8.52.0", "concurrently": "^9.2.1", "eslint": "^9.39.2", "eslint-plugin-jsx-a11y": "^6.10.2", "eslint-plugin-react": "^7.37.5", "eslint-plugin-react-hooks": "^7.0.1", "eslint-plugin-regexp": "^2.10.0", - "globals": "^16.5.0", + "globals": "^17.0.0", "prettier": "^3.7.4", - "turbo": "^2.7.2", + "turbo": "^2.7.3", "typescript": "5.9.3", "vitest": "^4.0.16" }, From 74d0847c041eb264650dd7e943821f5f5c522b7f Mon Sep 17 00:00:00 2001 From: Steven Roussey Date: Sun, 11 Jan 2026 04:32:14 +0000 Subject: [PATCH 02/14] [refactor] Remove provenance tracking from tasks and dataflows, simplify to remove ArrayTask from JobQueueTask parentage - Eliminated provenance tracking from the TaskGraphRunner, Task, and Dataflow classes to simplify the architecture. - Updated related documentation to reflect the removal of provenance references. - Adjusted task execution methods to no longer require provenance input, enhancing clarity and reducing complexity in task management. - Refactored tests to align with the updated task structure and removed any assertions related to provenance. 
--- docs/background/06_run_graph_orchestration.md | 2 +- docs/background/07_intrumentation.md | 2 +- docs/developers/02_architecture.md | 4 +- packages/ai-provider/README.md | 2 +- .../hf-transformers/common/HFT_JobRunFns.ts | 198 +++++---- .../src/tf-mediapipe/common/TFMP_JobRunFns.ts | 133 +++--- packages/ai/src/task/BackgroundRemovalTask.ts | 25 +- packages/ai/src/task/DownloadModelTask.ts | 19 +- packages/ai/src/task/FaceDetectorTask.ts | 15 +- packages/ai/src/task/FaceLandmarkerTask.ts | 19 +- packages/ai/src/task/GestureRecognizerTask.ts | 19 +- packages/ai/src/task/HandLandmarkerTask.ts | 19 +- .../ai/src/task/ImageClassificationTask.ts | 20 +- packages/ai/src/task/ImageEmbeddingTask.ts | 35 +- packages/ai/src/task/ImageSegmentationTask.ts | 19 +- packages/ai/src/task/ImageToTextTask.ts | 15 +- packages/ai/src/task/ObjectDetectionTask.ts | 20 +- packages/ai/src/task/PoseLandmarkerTask.ts | 19 +- .../ai/src/task/TextClassificationTask.ts | 16 +- packages/ai/src/task/TextEmbeddingTask.ts | 23 +- packages/ai/src/task/TextFillMaskTask.ts | 12 +- packages/ai/src/task/TextGenerationTask.ts | 22 +- .../ai/src/task/TextLanguageDetectionTask.ts | 16 +- .../task/TextNamedEntityRecognitionTask.ts | 18 +- .../ai/src/task/TextQuestionAnswerTask.ts | 22 +- packages/ai/src/task/TextRewriterTask.ts | 25 +- packages/ai/src/task/TextSummaryTask.ts | 21 +- packages/ai/src/task/TextTranslationTask.ts | 60 ++- packages/ai/src/task/UnloadModelTask.ts | 17 +- packages/task-graph/README.md | 2 +- .../task-graph/src/task-graph/Dataflow.ts | 7 +- .../task-graph/src/task-graph/ITaskGraph.ts | 2 +- packages/task-graph/src/task-graph/README.md | 2 - .../task-graph/src/task-graph/TaskGraph.ts | 5 +- .../src/task-graph/TaskGraphRunner.ts | 61 +-- .../task-graph/src/task-graph/Workflow.ts | 322 ++++++++++++-- packages/task-graph/src/task/ArrayTask.ts | 50 ++- .../task-graph/src/task/GraphAsTaskRunner.ts | 34 +- packages/task-graph/src/task/ITask.ts | 8 +- packages/task-graph/src/task/JobQueueTask.ts | 6 +- packages/task-graph/src/task/README.md | 27 +- packages/task-graph/src/task/Task.ts | 108 ++++- packages/task-graph/src/task/TaskEvents.ts | 2 +- packages/task-graph/src/task/TaskJSON.test.ts | 20 +- packages/task-graph/src/task/TaskJSON.ts | 11 +- packages/task-graph/src/task/TaskRunner.ts | 14 +- packages/task-graph/src/task/TaskTypes.ts | 4 - packages/tasks/src/task/DebugLogTask.ts | 4 +- packages/tasks/src/task/DelayTask.ts | 4 +- packages/tasks/src/task/FetchUrlTask.ts | 4 +- .../tasks/src/task/FileLoaderTask.server.ts | 2 +- packages/tasks/src/task/FileLoaderTask.ts | 2 +- packages/tasks/src/task/JavaScriptTask.ts | 2 +- packages/tasks/src/task/JsonTask.ts | 2 +- packages/tasks/src/task/MergeTask.ts | 4 +- packages/tasks/src/task/SplitTask.ts | 4 +- packages/test/src/samples/ONNXModelSamples.ts | 12 - .../test/src/test/task-graph/Workflow.test.ts | 317 ++++++++++++++ .../src/test/task/Task.smartClone.test.ts | 203 +++++++++ packages/test/src/test/task/TestTasks.ts | 393 +++++++++++++++++- packages/util/package.json | 2 +- .../util/src/json-schema/SchemaValidation.ts | 4 +- 62 files changed, 1761 insertions(+), 720 deletions(-) create mode 100644 packages/test/src/test/task/Task.smartClone.test.ts diff --git a/docs/background/06_run_graph_orchestration.md b/docs/background/06_run_graph_orchestration.md index 200f9da6..ac809fb6 100644 --- a/docs/background/06_run_graph_orchestration.md +++ b/docs/background/06_run_graph_orchestration.md @@ -10,7 +10,7 @@ The editor DAG is defined by the end user and saved in 
the database (tasks and d ## Graph -The graph is a DAG. It is a list of tasks and a list of dataflows. The tasks are the nodes and the dataflows are the connections between task outputs and inputs, plus status and provenance. +The graph is a DAG. It is a list of tasks and a list of dataflows. The tasks are the nodes and the dataflows are the connections between task outputs and inputs, plus status. We expose events for graphs, tasks, and dataflows. A suspend/resume could be added for bulk creation. This helps keep a UI in sync as the graph runs. diff --git a/docs/background/07_intrumentation.md b/docs/background/07_intrumentation.md index 6579d775..4b3ace64 100644 --- a/docs/background/07_intrumentation.md +++ b/docs/background/07_intrumentation.md @@ -7,5 +7,5 @@ Instrumentation is the process of adding code to a program to collect data about Some of these tools cost money, so we need to track and estimate costs. - Tasks emit status/progress events (`TaskStatus`, progress percent) -- Dataflows emit start/complete/error events and carry provenance +- Dataflows emit start/complete/error events - Task graphs emit start/progress/complete/error events diff --git a/docs/developers/02_architecture.md b/docs/developers/02_architecture.md index 680605c4..49fb3bd0 100644 --- a/docs/developers/02_architecture.md +++ b/docs/developers/02_architecture.md @@ -242,11 +242,10 @@ classDiagram class TaskGraphRunner{ Map layers - Map provenanceInput TaskGraph dag TaskOutputRepository repository assignLayers(Task[] sortedNodes) - runGraph(TaskInput parentProvenance) TaskOutput + runGraph(TaskInput input) TaskOutput runGraphReactive() TaskOutput } @@ -255,7 +254,6 @@ classDiagram The TaskGraphRunner is responsible for executing tasks in a task graph. Key features include: - **Layer-based Execution**: Tasks are organized into layers based on dependencies, allowing parallel execution of independent tasks -- **Provenance Tracking**: Tracks the lineage and input data that led to each task's output - **Caching Support**: Can use a TaskOutputRepository to cache task outputs and avoid re-running tasks - **Reactive Mode**: Supports reactive execution where tasks can respond to input changes without full re-execution - **Smart Task Scheduling**: Automatically determines task execution order based on dependencies diff --git a/packages/ai-provider/README.md b/packages/ai-provider/README.md index b652c88b..87d559a6 100644 --- a/packages/ai-provider/README.md +++ b/packages/ai-provider/README.md @@ -138,7 +138,7 @@ const task = new TextEmbeddingTask({ }); const result = await task.execute(); -// result.vector: TypedArray - Vector embedding +// result.vector: Vector - Vector embedding ``` **Text Translation:** diff --git a/packages/ai-provider/src/hf-transformers/common/HFT_JobRunFns.ts b/packages/ai-provider/src/hf-transformers/common/HFT_JobRunFns.ts index 3f48d98c..a1327d6d 100644 --- a/packages/ai-provider/src/hf-transformers/common/HFT_JobRunFns.ts +++ b/packages/ai-provider/src/hf-transformers/common/HFT_JobRunFns.ts @@ -37,44 +37,45 @@ import { } from "@sroussey/transformers"; import type { AiProviderRunFn, - BackgroundRemovalTaskExecuteInput, - BackgroundRemovalTaskExecuteOutput, - DownloadModelTaskExecuteInput, - DownloadModelTaskExecuteOutput, - ImageClassificationTaskExecuteInput, - ImageClassificationTaskExecuteOutput, - ImageEmbeddingTaskExecuteInput, - ImageEmbeddingTaskExecuteOutput, - ImageSegmentationTaskExecuteInput, - ImageSegmentationTaskExecuteOutput, - ImageToTextTaskExecuteInput, - 
ImageToTextTaskExecuteOutput, - ObjectDetectionTaskExecuteInput, - ObjectDetectionTaskExecuteOutput, - TextClassificationTaskExecuteInput, - TextClassificationTaskExecuteOutput, - TextEmbeddingTaskExecuteInput, - TextEmbeddingTaskExecuteOutput, - TextFillMaskTaskExecuteInput, - TextFillMaskTaskExecuteOutput, - TextGenerationTaskExecuteInput, - TextGenerationTaskExecuteOutput, - TextLanguageDetectionTaskExecuteInput, - TextLanguageDetectionTaskExecuteOutput, - TextNamedEntityRecognitionTaskExecuteInput, - TextNamedEntityRecognitionTaskExecuteOutput, - TextQuestionAnswerTaskExecuteInput, - TextQuestionAnswerTaskExecuteOutput, - TextRewriterTaskExecuteInput, - TextRewriterTaskExecuteOutput, - TextSummaryTaskExecuteInput, - TextSummaryTaskExecuteOutput, - TextTranslationTaskExecuteInput, - TextTranslationTaskExecuteOutput, - TypedArray, - UnloadModelTaskExecuteInput, - UnloadModelTaskExecuteOutput, + BackgroundRemovalTaskInput, + BackgroundRemovalTaskOutput, + DownloadModelTaskRunInput, + DownloadModelTaskRunOutput, + ImageClassificationTaskInput, + ImageClassificationTaskOutput, + ImageEmbeddingTaskInput, + ImageEmbeddingTaskOutput, + ImageSegmentationTaskInput, + ImageSegmentationTaskOutput, + ImageToTextTaskInput, + ImageToTextTaskOutput, + ObjectDetectionTaskInput, + ObjectDetectionTaskOutput, + TextClassificationTaskInput, + TextClassificationTaskOutput, + TextEmbeddingTaskInput, + TextEmbeddingTaskOutput, + TextFillMaskTaskInput, + TextFillMaskTaskOutput, + TextGenerationTaskInput, + TextGenerationTaskOutput, + TextLanguageDetectionTaskInput, + TextLanguageDetectionTaskOutput, + TextNamedEntityRecognitionTaskInput, + TextNamedEntityRecognitionTaskOutput, + TextQuestionAnswerTaskInput, + TextQuestionAnswerTaskOutput, + TextRewriterTaskInput, + TextRewriterTaskOutput, + TextSummaryTaskInput, + TextSummaryTaskOutput, + TextTranslationTaskInput, + TextTranslationTaskOutput, + UnloadModelTaskRunInput, + UnloadModelTaskRunOutput, } from "@workglow/ai"; + +import { TypedArray } from "@workglow/util"; import { CallbackStatus } from "./HFT_CallbackStatus"; import { HTF_CACHE_NAME } from "./HFT_Constants"; import { HfTransformersOnnxModelConfig } from "./HFT_ModelSchema"; @@ -441,8 +442,8 @@ const getPipeline = async ( * This is shared between inline and worker implementations. */ export const HFT_Download: AiProviderRunFn< - DownloadModelTaskExecuteInput, - DownloadModelTaskExecuteOutput, + DownloadModelTaskRunInput, + DownloadModelTaskRunOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { // Download the model by creating a pipeline @@ -459,8 +460,8 @@ export const HFT_Download: AiProviderRunFn< * This is shared between inline and worker implementations. */ export const HFT_Unload: AiProviderRunFn< - UnloadModelTaskExecuteInput, - UnloadModelTaskExecuteOutput, + UnloadModelTaskRunInput, + UnloadModelTaskRunOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { // Delete the pipeline from the in-memory map @@ -524,8 +525,8 @@ const deleteModelCache = async (model_path: string): Promise => { */ export const HFT_TextEmbedding: AiProviderRunFn< - TextEmbeddingTaskExecuteInput, - TextEmbeddingTaskExecuteOutput, + TextEmbeddingTaskInput, + TextEmbeddingTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const generateEmbedding: FeatureExtractionPipeline = await getPipeline(model!, onProgress, { @@ -539,15 +540,47 @@ export const HFT_TextEmbedding: AiProviderRunFn< ...(signal ? 
{ abort_signal: signal } : {}), }); - // Validate the embedding dimensions - if (hfVector.size !== model?.provider_config.native_dimensions) { + const isArrayInput = Array.isArray(input.text); + const embeddingDim = model?.provider_config.native_dimensions; + + // If the input is an array, the tensor will have multiple dimensions (e.g., [10, 384]) + // We need to split it into separate vectors for each input text + if (isArrayInput && hfVector.dims.length > 1) { + const [numTexts, vectorDim] = hfVector.dims; + + // Validate that the number of texts matches + if (numTexts !== input.text.length) { + throw new Error( + `HuggingFace Embedding tensor batch size does not match input array length: ${numTexts} != ${input.text.length}` + ); + } + + // Validate dimensions + if (vectorDim !== embeddingDim) { + throw new Error( + `HuggingFace Embedding vector dimension does not match model dimensions: ${vectorDim} != ${embeddingDim}` + ); + } + + // Extract each embedding vector using tensor indexing + // hfVector[i] returns a sub-tensor for the i-th text + const vectors: TypedArray[] = Array.from( + { length: numTexts }, + (_, i) => (hfVector as any)[i].data as TypedArray + ); + + return { vector: vectors }; + } + + // Single text input - validate dimensions + if (hfVector.size !== embeddingDim) { console.warn( - `HuggingFace Embedding vector length does not match model dimensions v${hfVector.size} != m${model?.provider_config.native_dimensions}`, + `HuggingFace Embedding vector length does not match model dimensions v${hfVector.size} != m${embeddingDim}`, input, hfVector ); throw new Error( - `HuggingFace Embedding vector length does not match model dimensions v${hfVector.size} != m${model?.provider_config.native_dimensions}` + `HuggingFace Embedding vector length does not match model dimensions v${hfVector.size} != m${embeddingDim}` ); } @@ -555,8 +588,8 @@ export const HFT_TextEmbedding: AiProviderRunFn< }; export const HFT_TextClassification: AiProviderRunFn< - TextClassificationTaskExecuteInput, - TextClassificationTaskExecuteOutput, + TextClassificationTaskInput, + TextClassificationTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { if (model?.provider_config?.pipeline === "zero-shot-classification") { @@ -611,8 +644,8 @@ export const HFT_TextClassification: AiProviderRunFn< }; export const HFT_TextLanguageDetection: AiProviderRunFn< - TextLanguageDetectionTaskExecuteInput, - TextLanguageDetectionTaskExecuteOutput, + TextLanguageDetectionTaskInput, + TextLanguageDetectionTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const TextClassification: TextClassificationPipeline = await getPipeline(model!, onProgress, { @@ -641,8 +674,8 @@ export const HFT_TextLanguageDetection: AiProviderRunFn< }; export const HFT_TextNamedEntityRecognition: AiProviderRunFn< - TextNamedEntityRecognitionTaskExecuteInput, - TextNamedEntityRecognitionTaskExecuteOutput, + TextNamedEntityRecognitionTaskInput, + TextNamedEntityRecognitionTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const textNamedEntityRecognition: TokenClassificationPipeline = await getPipeline( @@ -672,8 +705,8 @@ export const HFT_TextNamedEntityRecognition: AiProviderRunFn< }; export const HFT_TextFillMask: AiProviderRunFn< - TextFillMaskTaskExecuteInput, - TextFillMaskTaskExecuteOutput, + TextFillMaskTaskInput, + TextFillMaskTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const 
unmasker: FillMaskPipeline = await getPipeline(model!, onProgress, { @@ -700,8 +733,8 @@ export const HFT_TextFillMask: AiProviderRunFn< * This is shared between inline and worker implementations. */ export const HFT_TextGeneration: AiProviderRunFn< - TextGenerationTaskExecuteInput, - TextGenerationTaskExecuteOutput, + TextGenerationTaskInput, + TextGenerationTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const generateText: TextGenerationPipeline = await getPipeline(model!, onProgress, { @@ -733,8 +766,8 @@ export const HFT_TextGeneration: AiProviderRunFn< * This is shared between inline and worker implementations. */ export const HFT_TextTranslation: AiProviderRunFn< - TextTranslationTaskExecuteInput, - TextTranslationTaskExecuteOutput, + TextTranslationTaskInput, + TextTranslationTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const translate: TranslationPipeline = await getPipeline(model!, onProgress, { @@ -749,12 +782,9 @@ export const HFT_TextTranslation: AiProviderRunFn< ...(signal ? { abort_signal: signal } : {}), } as any); - let translatedText: string | string[] = ""; - if (Array.isArray(result)) { - translatedText = result.map((r) => (r as TranslationSingle)?.translation_text || ""); - } else { - translatedText = (result as TranslationSingle)?.translation_text || ""; - } + const translatedText = Array.isArray(result) + ? (result[0] as TranslationSingle)?.translation_text || "" + : (result as TranslationSingle)?.translation_text || ""; return { text: translatedText, @@ -767,8 +797,8 @@ export const HFT_TextTranslation: AiProviderRunFn< * This is shared between inline and worker implementations. */ export const HFT_TextRewriter: AiProviderRunFn< - TextRewriterTaskExecuteInput, - TextRewriterTaskExecuteOutput, + TextRewriterTaskInput, + TextRewriterTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const generateText: TextGenerationPipeline = await getPipeline(model!, onProgress, { @@ -807,8 +837,8 @@ export const HFT_TextRewriter: AiProviderRunFn< * This is shared between inline and worker implementations. */ export const HFT_TextSummary: AiProviderRunFn< - TextSummaryTaskExecuteInput, - TextSummaryTaskExecuteOutput, + TextSummaryTaskInput, + TextSummaryTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const generateSummary: SummarizationPipeline = await getPipeline(model!, onProgress, { @@ -838,8 +868,8 @@ export const HFT_TextSummary: AiProviderRunFn< * This is shared between inline and worker implementations. */ export const HFT_TextQuestionAnswer: AiProviderRunFn< - TextQuestionAnswerTaskExecuteInput, - TextQuestionAnswerTaskExecuteOutput, + TextQuestionAnswerTaskInput, + TextQuestionAnswerTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { // Get the question answering pipeline @@ -869,8 +899,8 @@ export const HFT_TextQuestionAnswer: AiProviderRunFn< * Core implementation for image segmentation using Hugging Face Transformers. 
*/ export const HFT_ImageSegmentation: AiProviderRunFn< - ImageSegmentationTaskExecuteInput, - ImageSegmentationTaskExecuteOutput, + ImageSegmentationTaskInput, + ImageSegmentationTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const segmenter: ImageSegmentationPipeline = await getPipeline(model!, onProgress, { @@ -902,8 +932,8 @@ export const HFT_ImageSegmentation: AiProviderRunFn< * Core implementation for image to text using Hugging Face Transformers. */ export const HFT_ImageToText: AiProviderRunFn< - ImageToTextTaskExecuteInput, - ImageToTextTaskExecuteOutput, + ImageToTextTaskInput, + ImageToTextTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const captioner: ImageToTextPipeline = await getPipeline(model!, onProgress, { @@ -926,8 +956,8 @@ export const HFT_ImageToText: AiProviderRunFn< * Core implementation for background removal using Hugging Face Transformers. */ export const HFT_BackgroundRemoval: AiProviderRunFn< - BackgroundRemovalTaskExecuteInput, - BackgroundRemovalTaskExecuteOutput, + BackgroundRemovalTaskInput, + BackgroundRemovalTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const remover: BackgroundRemovalPipeline = await getPipeline(model!, onProgress, { @@ -949,8 +979,8 @@ export const HFT_BackgroundRemoval: AiProviderRunFn< * Core implementation for image embedding using Hugging Face Transformers. */ export const HFT_ImageEmbedding: AiProviderRunFn< - ImageEmbeddingTaskExecuteInput, - ImageEmbeddingTaskExecuteOutput, + ImageEmbeddingTaskInput, + ImageEmbeddingTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const embedder: ImageFeatureExtractionPipeline = await getPipeline(model!, onProgress, { @@ -961,7 +991,7 @@ export const HFT_ImageEmbedding: AiProviderRunFn< return { vector: result.data as TypedArray, - }; + } as ImageEmbeddingTaskOutput; }; /** @@ -969,8 +999,8 @@ export const HFT_ImageEmbedding: AiProviderRunFn< * Auto-selects between regular and zero-shot classification. */ export const HFT_ImageClassification: AiProviderRunFn< - ImageClassificationTaskExecuteInput, - ImageClassificationTaskExecuteOutput, + ImageClassificationTaskInput, + ImageClassificationTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { if (model?.provider_config?.pipeline === "zero-shot-image-classification") { @@ -1024,8 +1054,8 @@ export const HFT_ImageClassification: AiProviderRunFn< * Auto-selects between regular and zero-shot detection. 
*/ export const HFT_ObjectDetection: AiProviderRunFn< - ObjectDetectionTaskExecuteInput, - ObjectDetectionTaskExecuteOutput, + ObjectDetectionTaskInput, + ObjectDetectionTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { if (model?.provider_config?.pipeline === "zero-shot-object-detection") { diff --git a/packages/ai-provider/src/tf-mediapipe/common/TFMP_JobRunFns.ts b/packages/ai-provider/src/tf-mediapipe/common/TFMP_JobRunFns.ts index d01ff205..0f82e636 100644 --- a/packages/ai-provider/src/tf-mediapipe/common/TFMP_JobRunFns.ts +++ b/packages/ai-provider/src/tf-mediapipe/common/TFMP_JobRunFns.ts @@ -23,34 +23,34 @@ import { } from "@mediapipe/tasks-vision"; import type { AiProviderRunFn, - DownloadModelTaskExecuteInput, - DownloadModelTaskExecuteOutput, - FaceDetectorTaskExecuteInput, - FaceDetectorTaskExecuteOutput, - FaceLandmarkerTaskExecuteInput, - FaceLandmarkerTaskExecuteOutput, - GestureRecognizerTaskExecuteInput, - GestureRecognizerTaskExecuteOutput, - HandLandmarkerTaskExecuteInput, - HandLandmarkerTaskExecuteOutput, - ImageClassificationTaskExecuteInput, - ImageClassificationTaskExecuteOutput, - ImageEmbeddingTaskExecuteInput, - ImageEmbeddingTaskExecuteOutput, - ImageSegmentationTaskExecuteInput, - ImageSegmentationTaskExecuteOutput, - ObjectDetectionTaskExecuteInput, - ObjectDetectionTaskExecuteOutput, - PoseLandmarkerTaskExecuteInput, - PoseLandmarkerTaskExecuteOutput, - TextClassificationTaskExecuteInput, - TextClassificationTaskExecuteOutput, - TextEmbeddingTaskExecuteInput, - TextEmbeddingTaskExecuteOutput, - TextLanguageDetectionTaskExecuteInput, - TextLanguageDetectionTaskExecuteOutput, - UnloadModelTaskExecuteInput, - UnloadModelTaskExecuteOutput, + DownloadModelTaskRunInput, + DownloadModelTaskRunOutput, + FaceDetectorTaskInput, + FaceDetectorTaskOutput, + FaceLandmarkerTaskInput, + FaceLandmarkerTaskOutput, + GestureRecognizerTaskInput, + GestureRecognizerTaskOutput, + HandLandmarkerTaskInput, + HandLandmarkerTaskOutput, + ImageClassificationTaskInput, + ImageClassificationTaskOutput, + ImageEmbeddingTaskInput, + ImageEmbeddingTaskOutput, + ImageSegmentationTaskInput, + ImageSegmentationTaskOutput, + ObjectDetectionTaskInput, + ObjectDetectionTaskOutput, + PoseLandmarkerTaskInput, + PoseLandmarkerTaskOutput, + TextClassificationTaskInput, + TextClassificationTaskOutput, + TextEmbeddingTaskInput, + TextEmbeddingTaskOutput, + TextLanguageDetectionTaskInput, + TextLanguageDetectionTaskOutput, + UnloadModelTaskRunInput, + UnloadModelTaskRunOutput, } from "@workglow/ai"; import { PermanentJobError } from "@workglow/job-queue"; import { TFMPModelConfig } from "./TFMP_ModelSchema"; @@ -262,8 +262,8 @@ const getModelTask = async ( * This is shared between inline and worker implementations. */ export const TFMP_Download: AiProviderRunFn< - DownloadModelTaskExecuteInput, - DownloadModelTaskExecuteOutput, + DownloadModelTaskRunInput, + DownloadModelTaskRunOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { let task: TaskInstance; @@ -327,11 +327,30 @@ export const TFMP_Download: AiProviderRunFn< * This is shared between inline and worker implementations. 
*/ export const TFMP_TextEmbedding: AiProviderRunFn< - TextEmbeddingTaskExecuteInput, - TextEmbeddingTaskExecuteOutput, + TextEmbeddingTaskInput, + TextEmbeddingTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const textEmbedder = await getModelTask(model!, {}, onProgress, signal, TextEmbedder); + + // Handle array of texts + if (Array.isArray(input.text)) { + const embeddings = input.text.map((text) => { + const result = textEmbedder.embed(text); + + if (!result.embeddings?.[0]?.floatEmbedding) { + throw new PermanentJobError("Failed to generate embedding: Empty result"); + } + + return Float32Array.from(result.embeddings[0].floatEmbedding); + }); + + return { + vector: embeddings, + }; + } + + // Handle single text const result = textEmbedder.embed(input.text); if (!result.embeddings?.[0]?.floatEmbedding) { @@ -350,8 +369,8 @@ export const TFMP_TextEmbedding: AiProviderRunFn< * This is shared between inline and worker implementations. */ export const TFMP_TextClassification: AiProviderRunFn< - TextClassificationTaskExecuteInput, - TextClassificationTaskExecuteOutput, + TextClassificationTaskInput, + TextClassificationTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const TextClassification = await getModelTask( @@ -387,8 +406,8 @@ export const TFMP_TextClassification: AiProviderRunFn< * This is shared between inline and worker implementations. */ export const TFMP_TextLanguageDetection: AiProviderRunFn< - TextLanguageDetectionTaskExecuteInput, - TextLanguageDetectionTaskExecuteOutput, + TextLanguageDetectionTaskInput, + TextLanguageDetectionTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const maxLanguages = input.maxLanguages === 0 ? -1 : input.maxLanguages; @@ -431,8 +450,8 @@ export const TFMP_TextLanguageDetection: AiProviderRunFn< * 3. If no other models are using the WASM fileset (count reaches 0), unloads the WASM */ export const TFMP_Unload: AiProviderRunFn< - UnloadModelTaskExecuteInput, - UnloadModelTaskExecuteOutput, + UnloadModelTaskRunInput, + UnloadModelTaskRunOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const model_path = model!.provider_config.model_path; @@ -471,8 +490,8 @@ export const TFMP_Unload: AiProviderRunFn< * Core implementation for image segmentation using MediaPipe. */ export const TFMP_ImageSegmentation: AiProviderRunFn< - ImageSegmentationTaskExecuteInput, - ImageSegmentationTaskExecuteOutput, + ImageSegmentationTaskInput, + ImageSegmentationTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const imageSegmenter = await getModelTask(model!, {}, onProgress, signal, ImageSegmenter); @@ -504,8 +523,8 @@ export const TFMP_ImageSegmentation: AiProviderRunFn< * Core implementation for image embedding using MediaPipe. */ export const TFMP_ImageEmbedding: AiProviderRunFn< - ImageEmbeddingTaskExecuteInput, - ImageEmbeddingTaskExecuteOutput, + ImageEmbeddingTaskInput, + ImageEmbeddingTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const imageEmbedder = await getModelTask(model!, {}, onProgress, signal, ImageEmbedder); @@ -519,15 +538,15 @@ export const TFMP_ImageEmbedding: AiProviderRunFn< return { vector: embedding, - }; + } as ImageEmbeddingTaskOutput; }; /** * Core implementation for image classification using MediaPipe. 
*/ export const TFMP_ImageClassification: AiProviderRunFn< - ImageClassificationTaskExecuteInput, - ImageClassificationTaskExecuteOutput, + ImageClassificationTaskInput, + ImageClassificationTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const imageClassifier = await getModelTask( @@ -559,8 +578,8 @@ export const TFMP_ImageClassification: AiProviderRunFn< * Core implementation for object detection using MediaPipe. */ export const TFMP_ObjectDetection: AiProviderRunFn< - ObjectDetectionTaskExecuteInput, - ObjectDetectionTaskExecuteOutput, + ObjectDetectionTaskInput, + ObjectDetectionTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const objectDetector = await getModelTask( @@ -598,8 +617,8 @@ export const TFMP_ObjectDetection: AiProviderRunFn< * Core implementation for gesture recognition using MediaPipe. */ export const TFMP_GestureRecognizer: AiProviderRunFn< - GestureRecognizerTaskExecuteInput, - GestureRecognizerTaskExecuteOutput, + GestureRecognizerTaskInput, + GestureRecognizerTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const gestureRecognizer = await getModelTask( @@ -650,8 +669,8 @@ export const TFMP_GestureRecognizer: AiProviderRunFn< * Core implementation for hand landmark detection using MediaPipe. */ export const TFMP_HandLandmarker: AiProviderRunFn< - HandLandmarkerTaskExecuteInput, - HandLandmarkerTaskExecuteOutput, + HandLandmarkerTaskInput, + HandLandmarkerTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const handLandmarker = await getModelTask( @@ -698,8 +717,8 @@ export const TFMP_HandLandmarker: AiProviderRunFn< * Core implementation for face detection using MediaPipe. */ export const TFMP_FaceDetector: AiProviderRunFn< - FaceDetectorTaskExecuteInput, - FaceDetectorTaskExecuteOutput, + FaceDetectorTaskInput, + FaceDetectorTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const faceDetector = await getModelTask( @@ -743,8 +762,8 @@ export const TFMP_FaceDetector: AiProviderRunFn< * Core implementation for face landmark detection using MediaPipe. */ export const TFMP_FaceLandmarker: AiProviderRunFn< - FaceLandmarkerTaskExecuteInput, - FaceLandmarkerTaskExecuteOutput, + FaceLandmarkerTaskInput, + FaceLandmarkerTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const faceLandmarker = await getModelTask( @@ -799,8 +818,8 @@ export const TFMP_FaceLandmarker: AiProviderRunFn< * Core implementation for pose landmark detection using MediaPipe. 
*/ export const TFMP_PoseLandmarker: AiProviderRunFn< - PoseLandmarkerTaskExecuteInput, - PoseLandmarkerTaskExecuteOutput, + PoseLandmarkerTaskInput, + PoseLandmarkerTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const poseLandmarker = await getModelTask( diff --git a/packages/ai/src/task/BackgroundRemovalTask.ts b/packages/ai/src/task/BackgroundRemovalTask.ts index 5be33802..b3e1a81f 100644 --- a/packages/ai/src/task/BackgroundRemovalTask.ts +++ b/packages/ai/src/task/BackgroundRemovalTask.ts @@ -6,15 +6,10 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:BackgroundRemovalTask")); +const modelSchema = TypeModel("model:BackgroundRemovalTask"); const processedImageSchema = { type: "string", @@ -27,7 +22,7 @@ const processedImageSchema = { export const BackgroundRemovalInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, }, required: ["image", "model"], @@ -37,11 +32,7 @@ export const BackgroundRemovalInputSchema = { export const BackgroundRemovalOutputSchema = { type: "object", properties: { - image: { - oneOf: [processedImageSchema, { type: "array", items: processedImageSchema }], - title: processedImageSchema.title, - description: processedImageSchema.description, - }, + image: processedImageSchema, }, required: ["image"], additionalProperties: false, @@ -49,12 +40,6 @@ export const BackgroundRemovalOutputSchema = { export type BackgroundRemovalTaskInput = FromSchema; export type BackgroundRemovalTaskOutput = FromSchema; -export type BackgroundRemovalTaskExecuteInput = DeReplicateFromSchema< - typeof BackgroundRemovalInputSchema ->; -export type BackgroundRemovalTaskExecuteOutput = DeReplicateFromSchema< - typeof BackgroundRemovalOutputSchema ->; /** * Removes backgrounds from images using computer vision models @@ -88,7 +73,7 @@ export const backgroundRemoval = ( input: BackgroundRemovalTaskInput, config?: JobQueueTaskConfig ) => { - return new BackgroundRemovalTask(input, config).run(); + return new BackgroundRemovalTask({} as BackgroundRemovalTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/DownloadModelTask.ts b/packages/ai/src/task/DownloadModelTask.ts index 6586f9b9..1fff2a26 100644 --- a/packages/ai/src/task/DownloadModelTask.ts +++ b/packages/ai/src/task/DownloadModelTask.ts @@ -4,12 +4,19 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { + CreateWorkflow, + DeReplicateFromSchema, + JobQueueTaskConfig, + TaskRegistry, + TypeReplicateArray, + Workflow, +} from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model")); +const modelSchema = TypeModel("model"); const DownloadModelInputSchema = { type: "object", @@ -31,10 +38,6 @@ const 
DownloadModelOutputSchema = { export type DownloadModelTaskRunInput = FromSchema; export type DownloadModelTaskRunOutput = FromSchema; -export type DownloadModelTaskExecuteInput = DeReplicateFromSchema; -export type DownloadModelTaskExecuteOutput = DeReplicateFromSchema< - typeof DownloadModelOutputSchema ->; /** * Download a model from a remote source and cache it locally. @@ -103,7 +106,7 @@ TaskRegistry.registerTask(DownloadModelTask); * @returns Promise resolving to the downloaded model(s) */ export const downloadModel = (input: DownloadModelTaskRunInput, config?: JobQueueTaskConfig) => { - return new DownloadModelTask(input, config).run(); + return new DownloadModelTask({} as DownloadModelTaskRunInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/FaceDetectorTask.ts b/packages/ai/src/task/FaceDetectorTask.ts index 7989a46f..465d8717 100644 --- a/packages/ai/src/task/FaceDetectorTask.ts +++ b/packages/ai/src/task/FaceDetectorTask.ts @@ -6,15 +6,10 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:FaceDetectorTask")); +const modelSchema = TypeModel("model:FaceDetectorTask"); /** * A bounding box for face detection. @@ -99,7 +94,7 @@ const TypeFaceDetection = { export const FaceDetectorInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, minDetectionConfidence: { type: "number", @@ -142,8 +137,6 @@ export const FaceDetectorOutputSchema = { export type FaceDetectorTaskInput = FromSchema; export type FaceDetectorTaskOutput = FromSchema; -export type FaceDetectorTaskExecuteInput = DeReplicateFromSchema; -export type FaceDetectorTaskExecuteOutput = DeReplicateFromSchema; /** * Detects faces in images using MediaPipe Face Detector. @@ -176,7 +169,7 @@ TaskRegistry.registerTask(FaceDetectorTask); * @returns Promise resolving to the detected faces with bounding boxes and keypoints */ export const faceDetector = (input: FaceDetectorTaskInput, config?: JobQueueTaskConfig) => { - return new FaceDetectorTask(input, config).run(); + return new FaceDetectorTask({} as FaceDetectorTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/FaceLandmarkerTask.ts b/packages/ai/src/task/FaceLandmarkerTask.ts index 2dc151bd..961bc436 100644 --- a/packages/ai/src/task/FaceLandmarkerTask.ts +++ b/packages/ai/src/task/FaceLandmarkerTask.ts @@ -6,15 +6,10 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:FaceLandmarkerTask")); +const modelSchema = TypeModel("model:FaceLandmarkerTask"); /** * A landmark point with x, y, z coordinates. 
@@ -102,7 +97,7 @@ const TypeFaceLandmarkerDetection = { export const FaceLandmarkerInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, numFaces: { type: "number", @@ -177,12 +172,6 @@ export const FaceLandmarkerOutputSchema = { export type FaceLandmarkerTaskInput = FromSchema; export type FaceLandmarkerTaskOutput = FromSchema; -export type FaceLandmarkerTaskExecuteInput = DeReplicateFromSchema< - typeof FaceLandmarkerInputSchema ->; -export type FaceLandmarkerTaskExecuteOutput = DeReplicateFromSchema< - typeof FaceLandmarkerOutputSchema ->; /** * Detects facial landmarks and expressions in images using MediaPipe Face Landmarker. @@ -216,7 +205,7 @@ TaskRegistry.registerTask(FaceLandmarkerTask); * @returns Promise resolving to the detected facial landmarks, blendshapes, and transformation matrices */ export const faceLandmarker = (input: FaceLandmarkerTaskInput, config?: JobQueueTaskConfig) => { - return new FaceLandmarkerTask(input, config).run(); + return new FaceLandmarkerTask({} as FaceLandmarkerTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/GestureRecognizerTask.ts b/packages/ai/src/task/GestureRecognizerTask.ts index 64868453..706b6c44 100644 --- a/packages/ai/src/task/GestureRecognizerTask.ts +++ b/packages/ai/src/task/GestureRecognizerTask.ts @@ -6,15 +6,10 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:GestureRecognizerTask")); +const modelSchema = TypeModel("model:GestureRecognizerTask"); /** * A landmark point with x, y, z coordinates. @@ -122,7 +117,7 @@ const TypeHandGestureDetection = { export const GestureRecognizerInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, numHands: { type: "number", @@ -183,12 +178,6 @@ export const GestureRecognizerOutputSchema = { export type GestureRecognizerTaskInput = FromSchema; export type GestureRecognizerTaskOutput = FromSchema; -export type GestureRecognizerTaskExecuteInput = DeReplicateFromSchema< - typeof GestureRecognizerInputSchema ->; -export type GestureRecognizerTaskExecuteOutput = DeReplicateFromSchema< - typeof GestureRecognizerOutputSchema ->; /** * Recognizes hand gestures in images using MediaPipe Gesture Recognizer. 
@@ -225,7 +214,7 @@ export const gestureRecognizer = ( input: GestureRecognizerTaskInput, config?: JobQueueTaskConfig ) => { - return new GestureRecognizerTask(input, config).run(); + return new GestureRecognizerTask({} as GestureRecognizerTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/HandLandmarkerTask.ts b/packages/ai/src/task/HandLandmarkerTask.ts index 1d0beec8..739e92a1 100644 --- a/packages/ai/src/task/HandLandmarkerTask.ts +++ b/packages/ai/src/task/HandLandmarkerTask.ts @@ -6,15 +6,10 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:HandLandmarkerTask")); +const modelSchema = TypeModel("model:HandLandmarkerTask"); /** * A landmark point with x, y, z coordinates. @@ -95,7 +90,7 @@ const TypeHandDetection = { export const HandLandmarkerInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, numHands: { type: "number", @@ -156,12 +151,6 @@ export const HandLandmarkerOutputSchema = { export type HandLandmarkerTaskInput = FromSchema; export type HandLandmarkerTaskOutput = FromSchema; -export type HandLandmarkerTaskExecuteInput = DeReplicateFromSchema< - typeof HandLandmarkerInputSchema ->; -export type HandLandmarkerTaskExecuteOutput = DeReplicateFromSchema< - typeof HandLandmarkerOutputSchema ->; /** * Detects hand landmarks in images using MediaPipe Hand Landmarker. 
@@ -194,7 +183,7 @@ TaskRegistry.registerTask(HandLandmarkerTask); * @returns Promise resolving to the detected hand landmarks and handedness */ export const handLandmarker = (input: HandLandmarkerTaskInput, config?: JobQueueTaskConfig) => { - return new HandLandmarkerTask(input, config).run(); + return new HandLandmarkerTask({} as HandLandmarkerTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/ImageClassificationTask.ts b/packages/ai/src/task/ImageClassificationTask.ts index 9dd75954..858c40f3 100644 --- a/packages/ai/src/task/ImageClassificationTask.ts +++ b/packages/ai/src/task/ImageClassificationTask.ts @@ -6,21 +6,15 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeCategory, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeCategory, TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:ImageClassificationTask")); +const modelSchema = TypeModel("model:ImageClassificationTask"); export const ImageClassificationInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, categories: { type: "array", @@ -64,12 +58,6 @@ export const ImageClassificationOutputSchema = { export type ImageClassificationTaskInput = FromSchema; export type ImageClassificationTaskOutput = FromSchema; -export type ImageClassificationTaskExecuteInput = DeReplicateFromSchema< - typeof ImageClassificationInputSchema ->; -export type ImageClassificationTaskExecuteOutput = DeReplicateFromSchema< - typeof ImageClassificationOutputSchema ->; /** * Classifies images into categories using vision models. 
@@ -105,7 +93,7 @@ export const imageClassification = ( input: ImageClassificationTaskInput, config?: JobQueueTaskConfig ) => { - return new ImageClassificationTask(input, config).run(); + return new ImageClassificationTask({} as ImageClassificationTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/ImageEmbeddingTask.ts b/packages/ai/src/task/ImageEmbeddingTask.ts index 80194f52..94e0219c 100644 --- a/packages/ai/src/task/ImageEmbeddingTask.ts +++ b/packages/ai/src/task/ImageEmbeddingTask.ts @@ -5,17 +5,16 @@ */ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; -import { DataPortSchema, FromSchema } from "@workglow/util"; import { - DeReplicateFromSchema, + DataPortSchema, + FromSchema, TypedArraySchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; + TypedArraySchemaOptions, +} from "@workglow/util"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:ImageEmbeddingTask")); +const modelSchema = TypeModel("model:ImageEmbeddingTask"); const embeddingSchema = TypedArraySchema({ title: "Embedding", @@ -25,7 +24,7 @@ const embeddingSchema = TypedArraySchema({ export const ImageEmbeddingInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, }, required: ["image", "model"], @@ -35,23 +34,19 @@ export const ImageEmbeddingInputSchema = { export const ImageEmbeddingOutputSchema = { type: "object", properties: { - vector: { - oneOf: [embeddingSchema, { type: "array", items: embeddingSchema }], - title: "Embedding", - description: "The image embedding vector", - }, + vector: embeddingSchema, }, required: ["vector"], additionalProperties: false, } as const satisfies DataPortSchema; -export type ImageEmbeddingTaskInput = FromSchema; -export type ImageEmbeddingTaskOutput = FromSchema; -export type ImageEmbeddingTaskExecuteInput = DeReplicateFromSchema< - typeof ImageEmbeddingInputSchema +export type ImageEmbeddingTaskInput = FromSchema< + typeof ImageEmbeddingInputSchema, + TypedArraySchemaOptions >; -export type ImageEmbeddingTaskExecuteOutput = DeReplicateFromSchema< - typeof ImageEmbeddingOutputSchema +export type ImageEmbeddingTaskOutput = FromSchema< + typeof ImageEmbeddingOutputSchema, + TypedArraySchemaOptions >; /** @@ -83,7 +78,7 @@ TaskRegistry.registerTask(ImageEmbeddingTask); * @returns Promise resolving to the image embedding vector */ export const imageEmbedding = (input: ImageEmbeddingTaskInput, config?: JobQueueTaskConfig) => { - return new ImageEmbeddingTask(input, config).run(); + return new ImageEmbeddingTask({} as ImageEmbeddingTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/ImageSegmentationTask.ts b/packages/ai/src/task/ImageSegmentationTask.ts index 17204dbd..4c4e96ed 100644 --- a/packages/ai/src/task/ImageSegmentationTask.ts +++ b/packages/ai/src/task/ImageSegmentationTask.ts @@ -6,20 +6,15 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from 
"./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:ImageSegmentationTask")); +const modelSchema = TypeModel("model:ImageSegmentationTask"); export const ImageSegmentationInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, threshold: { type: "number", @@ -88,12 +83,6 @@ export const ImageSegmentationOutputSchema = { export type ImageSegmentationTaskInput = FromSchema; export type ImageSegmentationTaskOutput = FromSchema; -export type ImageSegmentationTaskExecuteInput = DeReplicateFromSchema< - typeof ImageSegmentationInputSchema ->; -export type ImageSegmentationTaskExecuteOutput = DeReplicateFromSchema< - typeof ImageSegmentationOutputSchema ->; /** * Segments images into regions using computer vision models @@ -128,7 +117,7 @@ export const imageSegmentation = ( input: ImageSegmentationTaskInput, config?: JobQueueTaskConfig ) => { - return new ImageSegmentationTask(input, config).run(); + return new ImageSegmentationTask({} as ImageSegmentationTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/ImageToTextTask.ts b/packages/ai/src/task/ImageToTextTask.ts index c7fff8bd..2cf6b919 100644 --- a/packages/ai/src/task/ImageToTextTask.ts +++ b/packages/ai/src/task/ImageToTextTask.ts @@ -6,15 +6,10 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:ImageToTextTask")); +const modelSchema = TypeModel("model:ImageToTextTask"); const generatedTextSchema = { type: "string", @@ -25,7 +20,7 @@ const generatedTextSchema = { export const ImageToTextInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, maxTokens: { type: "number", @@ -55,8 +50,6 @@ export const ImageToTextOutputSchema = { export type ImageToTextTaskInput = FromSchema; export type ImageToTextTaskOutput = FromSchema; -export type ImageToTextTaskExecuteInput = DeReplicateFromSchema; -export type ImageToTextTaskExecuteOutput = DeReplicateFromSchema; /** * Generates text descriptions from images using vision-language models @@ -88,7 +81,7 @@ TaskRegistry.registerTask(ImageToTextTask); * @returns Promise resolving to the generated text description */ export const imageToText = (input: ImageToTextTaskInput, config?: JobQueueTaskConfig) => { - return new ImageToTextTask(input, config).run(); + return new ImageToTextTask({} as ImageToTextTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/ObjectDetectionTask.ts b/packages/ai/src/task/ObjectDetectionTask.ts index 78244e31..6132d1cd 100644 --- a/packages/ai/src/task/ObjectDetectionTask.ts +++ b/packages/ai/src/task/ObjectDetectionTask.ts @@ -6,16 +6,10 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeBoundingBox, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeBoundingBox, TypeImageInput, TypeModel 
} from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:ObjectDetectionTask")); +const modelSchema = TypeModel("model:ObjectDetectionTask"); const detectionSchema = { type: "object", @@ -41,7 +35,7 @@ const detectionSchema = { export const ObjectDetectionInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, labels: { type: "array", @@ -85,12 +79,6 @@ export const ObjectDetectionOutputSchema = { export type ObjectDetectionTaskInput = FromSchema; export type ObjectDetectionTaskOutput = FromSchema; -export type ObjectDetectionTaskExecuteInput = DeReplicateFromSchema< - typeof ObjectDetectionInputSchema ->; -export type ObjectDetectionTaskExecuteOutput = DeReplicateFromSchema< - typeof ObjectDetectionOutputSchema ->; /** * Detects objects in images using vision models. @@ -123,7 +111,7 @@ TaskRegistry.registerTask(ObjectDetectionTask); * @returns Promise resolving to the detected objects with labels, scores, and bounding boxes */ export const objectDetection = (input: ObjectDetectionTaskInput, config?: JobQueueTaskConfig) => { - return new ObjectDetectionTask(input, config).run(); + return new ObjectDetectionTask({} as ObjectDetectionTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/PoseLandmarkerTask.ts b/packages/ai/src/task/PoseLandmarkerTask.ts index 1c596e47..8f0a45f3 100644 --- a/packages/ai/src/task/PoseLandmarkerTask.ts +++ b/packages/ai/src/task/PoseLandmarkerTask.ts @@ -6,15 +6,10 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:PoseLandmarkerTask")); +const modelSchema = TypeModel("model:PoseLandmarkerTask"); /** * A landmark point with x, y, z coordinates and visibility/presence scores. @@ -105,7 +100,7 @@ const TypePoseDetection = { export const PoseLandmarkerInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, numPoses: { type: "number", @@ -173,12 +168,6 @@ export const PoseLandmarkerOutputSchema = { export type PoseLandmarkerTaskInput = FromSchema; export type PoseLandmarkerTaskOutput = FromSchema; -export type PoseLandmarkerTaskExecuteInput = DeReplicateFromSchema< - typeof PoseLandmarkerInputSchema ->; -export type PoseLandmarkerTaskExecuteOutput = DeReplicateFromSchema< - typeof PoseLandmarkerOutputSchema ->; /** * Detects pose landmarks in images using MediaPipe Pose Landmarker. 
@@ -211,7 +200,7 @@ TaskRegistry.registerTask(PoseLandmarkerTask); * @returns Promise resolving to the detected pose landmarks and optional segmentation masks */ export const poseLandmarker = (input: PoseLandmarkerTaskInput, config?: JobQueueTaskConfig) => { - return new PoseLandmarkerTask(input, config).run(); + return new PoseLandmarkerTask({} as PoseLandmarkerTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextClassificationTask.ts b/packages/ai/src/task/TextClassificationTask.ts index 170d8686..0b7c6215 100644 --- a/packages/ai/src/task/TextClassificationTask.ts +++ b/packages/ai/src/task/TextClassificationTask.ts @@ -7,18 +7,18 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model:TextClassificationTask")); +const modelSchema = TypeModel("model:TextClassificationTask"); export const TextClassificationInputSchema = { type: "object", properties: { - text: TypeReplicateArray({ + text: { type: "string", title: "Text", description: "The text to classify", - }), + }, candidateLabels: { type: "array", items: { @@ -75,12 +75,6 @@ export const TextClassificationOutputSchema = { export type TextClassificationTaskInput = FromSchema; export type TextClassificationTaskOutput = FromSchema; -export type TextClassificationTaskExecuteInput = DeReplicateFromSchema< - typeof TextClassificationInputSchema ->; -export type TextClassificationTaskExecuteOutput = DeReplicateFromSchema< - typeof TextClassificationOutputSchema ->; /** * Classifies text into categories using language models. 
@@ -115,7 +109,7 @@ export const textClassification = ( input: TextClassificationTaskInput, config?: JobQueueTaskConfig ) => { - return new TextClassificationTask(input, config).run(); + return new TextClassificationTask({} as TextClassificationTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextEmbeddingTask.ts b/packages/ai/src/task/TextEmbeddingTask.ts index 9f50e9a7..e85f3ded 100644 --- a/packages/ai/src/task/TextEmbeddingTask.ts +++ b/packages/ai/src/task/TextEmbeddingTask.ts @@ -5,22 +5,21 @@ */ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; -import { DataPortSchema, FromSchema } from "@workglow/util"; -import { AiTask } from "./base/AiTask"; import { - DeReplicateFromSchema, + DataPortSchema, + FromSchema, TypedArraySchema, TypedArraySchemaOptions, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +} from "@workglow/util"; +import { AiTask } from "./base/AiTask"; +import { TypeModel, TypeSingleOrArray } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model:TextEmbeddingTask")); +const modelSchema = TypeModel("model:TextEmbeddingTask"); export const TextEmbeddingInputSchema = { type: "object", properties: { - text: TypeReplicateArray({ + text: TypeSingleOrArray({ type: "string", title: "Text", description: "The text to embed", @@ -34,7 +33,7 @@ export const TextEmbeddingInputSchema = { export const TextEmbeddingOutputSchema = { type: "object", properties: { - vector: TypeReplicateArray( + vector: TypeSingleOrArray( TypedArraySchema({ title: "Vector", description: "The vector embedding of the text", @@ -53,10 +52,6 @@ export type TextEmbeddingTaskOutput = FromSchema< typeof TextEmbeddingOutputSchema, TypedArraySchemaOptions >; -export type TextEmbeddingTaskExecuteInput = DeReplicateFromSchema; -export type TextEmbeddingTaskExecuteOutput = DeReplicateFromSchema< - typeof TextEmbeddingOutputSchema ->; /** * A task that generates vector embeddings for text using a specified embedding model. 
@@ -86,7 +81,7 @@ TaskRegistry.registerTask(TextEmbeddingTask); * @returns Promise resolving to the generated embeddings */ export const textEmbedding = async (input: TextEmbeddingTaskInput, config?: JobQueueTaskConfig) => { - return new TextEmbeddingTask(input, config).run(); + return new TextEmbeddingTask({} as TextEmbeddingTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextFillMaskTask.ts b/packages/ai/src/task/TextFillMaskTask.ts index dbec9052..a308c0c9 100644 --- a/packages/ai/src/task/TextFillMaskTask.ts +++ b/packages/ai/src/task/TextFillMaskTask.ts @@ -7,18 +7,18 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model:TextFillMaskTask")); +const modelSchema = TypeModel("model:TextFillMaskTask"); export const TextFillMaskInputSchema = { type: "object", properties: { - text: TypeReplicateArray({ + text: { type: "string", title: "Text", description: "The text with a mask token to fill", - }), + }, model: modelSchema, }, required: ["text", "model"], @@ -62,8 +62,6 @@ export const TextFillMaskOutputSchema = { export type TextFillMaskTaskInput = FromSchema; export type TextFillMaskTaskOutput = FromSchema; -export type TextFillMaskTaskExecuteInput = DeReplicateFromSchema; -export type TextFillMaskTaskExecuteOutput = DeReplicateFromSchema; /** * Fills masked tokens in text using language models @@ -90,7 +88,7 @@ TaskRegistry.registerTask(TextFillMaskTask); * @returns Promise resolving to the predicted tokens with scores and complete sequences */ export const textFillMask = (input: TextFillMaskTaskInput, config?: JobQueueTaskConfig) => { - return new TextFillMaskTask(input, config).run(); + return new TextFillMaskTask({} as TextFillMaskTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextGenerationTask.ts b/packages/ai/src/task/TextGenerationTask.ts index 03f1eb00..0bea2a0c 100644 --- a/packages/ai/src/task/TextGenerationTask.ts +++ b/packages/ai/src/task/TextGenerationTask.ts @@ -7,7 +7,7 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; const generatedTextSchema = { type: "string", @@ -15,17 +15,17 @@ const generatedTextSchema = { description: "The generated text", } as const; -const modelSchema = TypeReplicateArray(TypeModel("model:TextGenerationTask")); +const modelSchema = TypeModel("model:TextGenerationTask"); export const TextGenerationInputSchema = { type: "object", properties: { model: modelSchema, - prompt: TypeReplicateArray({ + prompt: { type: "string", title: "Prompt", description: "The prompt to generate text from", - }), + }, maxTokens: { type: "number", title: "Max Tokens", @@ -74,11 +74,7 @@ export const TextGenerationInputSchema = { export const TextGenerationOutputSchema = { type: "object", properties: { - text: { - oneOf: [generatedTextSchema, { type: "array", items: generatedTextSchema }], - title: 
generatedTextSchema.title, - description: generatedTextSchema.description, - }, + text: generatedTextSchema, }, required: ["text"], additionalProperties: false, @@ -86,12 +82,6 @@ export const TextGenerationOutputSchema = { export type TextGenerationTaskInput = FromSchema; export type TextGenerationTaskOutput = FromSchema; -export type TextGenerationTaskExecuteInput = DeReplicateFromSchema< - typeof TextGenerationInputSchema ->; -export type TextGenerationTaskExecuteOutput = DeReplicateFromSchema< - typeof TextGenerationOutputSchema ->; export class TextGenerationTask extends AiTask< TextGenerationTaskInput, @@ -116,7 +106,7 @@ TaskRegistry.registerTask(TextGenerationTask); * Task for generating text using a language model */ export const textGeneration = (input: TextGenerationTaskInput, config?: JobQueueTaskConfig) => { - return new TextGenerationTask(input, config).run(); + return new TextGenerationTask({} as TextGenerationTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextLanguageDetectionTask.ts b/packages/ai/src/task/TextLanguageDetectionTask.ts index 59cdb8b9..c12c6c34 100644 --- a/packages/ai/src/task/TextLanguageDetectionTask.ts +++ b/packages/ai/src/task/TextLanguageDetectionTask.ts @@ -7,18 +7,18 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model:TextLanguageDetectionTask")); +const modelSchema = TypeModel("model:TextLanguageDetectionTask"); export const TextLanguageDetectionInputSchema = { type: "object", properties: { - text: TypeReplicateArray({ + text: { type: "string", title: "Text", description: "The text to detect the language of", - }), + }, maxLanguages: { type: "number", minimum: 0, @@ -100,12 +100,6 @@ export const TextLanguageDetectionOutputSchema = { export type TextLanguageDetectionTaskInput = FromSchema; export type TextLanguageDetectionTaskOutput = FromSchema; -export type TextLanguageDetectionTaskExecuteInput = DeReplicateFromSchema< - typeof TextLanguageDetectionInputSchema ->; -export type TextLanguageDetectionTaskExecuteOutput = DeReplicateFromSchema< - typeof TextLanguageDetectionOutputSchema ->; /** * Detects the language of text using language models @@ -138,7 +132,7 @@ export const textLanguageDetection = ( input: TextLanguageDetectionTaskInput, config?: JobQueueTaskConfig ) => { - return new TextLanguageDetectionTask(input, config).run(); + return new TextLanguageDetectionTask({} as TextLanguageDetectionTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextNamedEntityRecognitionTask.ts b/packages/ai/src/task/TextNamedEntityRecognitionTask.ts index 1a91ba95..4550712c 100644 --- a/packages/ai/src/task/TextNamedEntityRecognitionTask.ts +++ b/packages/ai/src/task/TextNamedEntityRecognitionTask.ts @@ -7,18 +7,18 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = 
TypeReplicateArray(TypeModel("model:TextNamedEntityRecognitionTask")); +const modelSchema = TypeModel("model:TextNamedEntityRecognitionTask"); export const TextNamedEntityRecognitionInputSchema = { type: "object", properties: { - text: TypeReplicateArray({ + text: { type: "string", title: "Text", description: "The text to extract named entities from", - }), + }, blockList: { type: "array", items: { @@ -76,12 +76,6 @@ export type TextNamedEntityRecognitionTaskInput = FromSchema< export type TextNamedEntityRecognitionTaskOutput = FromSchema< typeof TextNamedEntityRecognitionOutputSchema >; -export type TextNamedEntityRecognitionTaskExecuteInput = DeReplicateFromSchema< - typeof TextNamedEntityRecognitionInputSchema ->; -export type TextNamedEntityRecognitionTaskExecuteOutput = DeReplicateFromSchema< - typeof TextNamedEntityRecognitionOutputSchema ->; /** * Extracts named entities from text using language models @@ -114,7 +108,9 @@ export const textNamedEntityRecognition = ( input: TextNamedEntityRecognitionTaskInput, config?: JobQueueTaskConfig ) => { - return new TextNamedEntityRecognitionTask(input, config).run(); + return new TextNamedEntityRecognitionTask({} as TextNamedEntityRecognitionTaskInput, config).run( + input + ); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextQuestionAnswerTask.ts b/packages/ai/src/task/TextQuestionAnswerTask.ts index c2100ee2..f928eb98 100644 --- a/packages/ai/src/task/TextQuestionAnswerTask.ts +++ b/packages/ai/src/task/TextQuestionAnswerTask.ts @@ -7,7 +7,7 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; const contextSchema = { type: "string", @@ -27,13 +27,13 @@ const textSchema = { description: "The generated text", } as const; -const modelSchema = TypeReplicateArray(TypeModel("model:TextQuestionAnswerTask")); +const modelSchema = TypeModel("model:TextQuestionAnswerTask"); export const TextQuestionAnswerInputSchema = { type: "object", properties: { - context: TypeReplicateArray(contextSchema), - question: TypeReplicateArray(questionSchema), + context: contextSchema, + question: questionSchema, model: modelSchema, }, required: ["context", "question", "model"], @@ -43,11 +43,7 @@ export const TextQuestionAnswerInputSchema = { export const TextQuestionAnswerOutputSchema = { type: "object", properties: { - text: { - oneOf: [textSchema, { type: "array", items: textSchema }], - title: textSchema.title, - description: textSchema.description, - }, + text: textSchema, }, required: ["text"], additionalProperties: false, @@ -55,12 +51,6 @@ export const TextQuestionAnswerOutputSchema = { export type TextQuestionAnswerTaskInput = FromSchema; export type TextQuestionAnswerTaskOutput = FromSchema; -export type TextQuestionAnswerTaskExecuteInput = DeReplicateFromSchema< - typeof TextQuestionAnswerInputSchema ->; -export type TextQuestionAnswerTaskExecuteOutput = DeReplicateFromSchema< - typeof TextQuestionAnswerOutputSchema ->; /** * This is a special case of text generation that takes a context and a question @@ -94,7 +84,7 @@ export const textQuestionAnswer = ( input: TextQuestionAnswerTaskInput, config?: JobQueueTaskConfig ) => { - return new TextQuestionAnswerTask(input, config).run(); + return new TextQuestionAnswerTask({} as 
TextQuestionAnswerTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextRewriterTask.ts b/packages/ai/src/task/TextRewriterTask.ts index 7d031d91..60d92bf4 100644 --- a/packages/ai/src/task/TextRewriterTask.ts +++ b/packages/ai/src/task/TextRewriterTask.ts @@ -4,26 +4,33 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { + CreateWorkflow, + DeReplicateFromSchema, + JobQueueTaskConfig, + TaskRegistry, + TypeReplicateArray, + Workflow, +} from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model:TextRewriterTask")); +const modelSchema = TypeModel("model:TextRewriterTask"); export const TextRewriterInputSchema = { type: "object", properties: { - text: TypeReplicateArray({ + text: { type: "string", title: "Text", description: "The text to rewrite", - }), - prompt: TypeReplicateArray({ + }, + prompt: { type: "string", title: "Prompt", description: "The prompt to direct the rewriting", - }), + }, model: modelSchema, }, required: ["text", "prompt", "model"], @@ -45,8 +52,6 @@ export const TextRewriterOutputSchema = { export type TextRewriterTaskInput = FromSchema; export type TextRewriterTaskOutput = FromSchema; -export type TextRewriterTaskExecuteInput = DeReplicateFromSchema; -export type TextRewriterTaskExecuteOutput = DeReplicateFromSchema; /** * This is a special case of text generation that takes a prompt and text to rewrite @@ -73,7 +78,7 @@ TaskRegistry.registerTask(TextRewriterTask); * @returns Promise resolving to the rewritten text output(s) */ export const textRewriter = (input: TextRewriterTaskInput, config?: JobQueueTaskConfig) => { - return new TextRewriterTask(input, config).run(); + return new TextRewriterTask({} as TextRewriterTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextSummaryTask.ts b/packages/ai/src/task/TextSummaryTask.ts index 675643a1..471a925b 100644 --- a/packages/ai/src/task/TextSummaryTask.ts +++ b/packages/ai/src/task/TextSummaryTask.ts @@ -4,21 +4,28 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { + CreateWorkflow, + DeReplicateFromSchema, + JobQueueTaskConfig, + TaskRegistry, + TypeReplicateArray, + Workflow, +} from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model:TextSummaryTask")); +const modelSchema = TypeModel("model:TextSummaryTask"); export const TextSummaryInputSchema = { type: "object", properties: { - text: TypeReplicateArray({ + text: { type: "string", title: "Text", description: "The text to summarize", - }), + }, model: modelSchema, }, required: ["text", "model"], @@ -40,8 +47,6 @@ export const TextSummaryOutputSchema = { export type TextSummaryTaskInput = FromSchema; export type TextSummaryTaskOutput = FromSchema; -export type TextSummaryTaskExecuteInput = DeReplicateFromSchema; 
-export type TextSummaryTaskExecuteOutput = DeReplicateFromSchema; /** * This summarizes a piece of text @@ -70,7 +75,7 @@ TaskRegistry.registerTask(TextSummaryTask); * @returns Promise resolving to the summarized text output(s) */ export const textSummary = async (input: TextSummaryTaskInput, config?: JobQueueTaskConfig) => { - return new TextSummaryTask(input, config).run(); + return new TextSummaryTask({} as TextSummaryTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextTranslationTask.ts b/packages/ai/src/task/TextTranslationTask.ts index 988c682e..208396dc 100644 --- a/packages/ai/src/task/TextTranslationTask.ts +++ b/packages/ai/src/task/TextTranslationTask.ts @@ -4,17 +4,19 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; -import { DataPortSchema, FromSchema } from "@workglow/util"; -import { AiTask } from "./base/AiTask"; import { + CreateWorkflow, DeReplicateFromSchema, - TypeLanguage, - TypeModel, + JobQueueTaskConfig, + TaskRegistry, TypeReplicateArray, -} from "./base/AiTaskSchemas"; + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; +import { AiTask } from "./base/AiTask"; +import { TypeLanguage, TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model:TextTranslationTask")); +const modelSchema = TypeModel("model:TextTranslationTask"); const translationTextSchema = { type: "string", @@ -25,27 +27,23 @@ const translationTextSchema = { export const TextTranslationInputSchema = { type: "object", properties: { - text: TypeReplicateArray({ + text: { type: "string", title: "Text", description: "The text to translate", + }, + source_lang: TypeLanguage({ + title: "Source Language", + description: "The source language", + minLength: 2, + maxLength: 2, + }), + target_lang: TypeLanguage({ + title: "Target Language", + description: "The target language", + minLength: 2, + maxLength: 2, }), - source_lang: TypeReplicateArray( - TypeLanguage({ - title: "Source Language", - description: "The source language", - minLength: 2, - maxLength: 2, - }) - ), - target_lang: TypeReplicateArray( - TypeLanguage({ - title: "Target Language", - description: "The target language", - minLength: 2, - maxLength: 2, - }) - ), model: modelSchema, }, required: ["text", "source_lang", "target_lang", "model"], @@ -55,11 +53,7 @@ export const TextTranslationInputSchema = { export const TextTranslationOutputSchema = { type: "object", properties: { - text: { - oneOf: [translationTextSchema, { type: "array", items: translationTextSchema }], - title: translationTextSchema.title, - description: translationTextSchema.description, - }, + text: translationTextSchema, target_lang: TypeLanguage({ title: "Output Language", description: "The output language", @@ -73,12 +67,6 @@ export const TextTranslationOutputSchema = { export type TextTranslationTaskInput = FromSchema; export type TextTranslationTaskOutput = FromSchema; -export type TextTranslationTaskExecuteInput = DeReplicateFromSchema< - typeof TextTranslationInputSchema ->; -export type TextTranslationTaskExecuteOutput = DeReplicateFromSchema< - typeof TextTranslationOutputSchema ->; /** * This translates text from one language to another @@ -108,7 +96,7 @@ TaskRegistry.registerTask(TextTranslationTask); * @returns Promise resolving to the translated text output(s) */ export const textTranslation = (input: 
TextTranslationTaskInput, config?: JobQueueTaskConfig) => { - return new TextTranslationTask(input, config).run(); + return new TextTranslationTask({} as TextTranslationTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/UnloadModelTask.ts b/packages/ai/src/task/UnloadModelTask.ts index 8a027d7b..21b8dbc6 100644 --- a/packages/ai/src/task/UnloadModelTask.ts +++ b/packages/ai/src/task/UnloadModelTask.ts @@ -4,12 +4,19 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { + CreateWorkflow, + DeReplicateFromSchema, + JobQueueTaskConfig, + TaskRegistry, + TypeReplicateArray, + Workflow, +} from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model")); +const modelSchema = TypeModel("model"); const UnloadModelInputSchema = { type: "object", @@ -31,8 +38,6 @@ const UnloadModelOutputSchema = { export type UnloadModelTaskRunInput = FromSchema; export type UnloadModelTaskRunOutput = FromSchema; -export type UnloadModelTaskExecuteInput = DeReplicateFromSchema; -export type UnloadModelTaskExecuteOutput = DeReplicateFromSchema; /** * Unload a model from memory and clear its cache. @@ -67,7 +72,7 @@ TaskRegistry.registerTask(UnloadModelTask); * @returns Promise resolving to the unloaded model(s) */ export const unloadModel = (input: UnloadModelTaskRunInput, config?: JobQueueTaskConfig) => { - return new UnloadModelTask(input, config).run(); + return new UnloadModelTask({} as UnloadModelTaskRunInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/task-graph/README.md b/packages/task-graph/README.md index 203c9ce5..73824243 100644 --- a/packages/task-graph/README.md +++ b/packages/task-graph/README.md @@ -120,7 +120,7 @@ console.log(result); // { result: 60 } // 2.3 Create a helper function export const MultiplyBy2 = (input: { value: number }) => { - return new MultiplyBy2Task(input).run(); + return new MultiplyBy2Task().run(input); }; const first = await MultiplyBy2({ value: 15 }); const second = await MultiplyBy2({ value: first.result }); diff --git a/packages/task-graph/src/task-graph/Dataflow.ts b/packages/task-graph/src/task-graph/Dataflow.ts index bac428a6..c8e58258 100644 --- a/packages/task-graph/src/task-graph/Dataflow.ts +++ b/packages/task-graph/src/task-graph/Dataflow.ts @@ -7,7 +7,7 @@ import { areSemanticallyCompatible, EventEmitter } from "@workglow/util"; import { TaskError } from "../task/TaskError"; import { DataflowJson } from "../task/TaskJSON"; -import { Provenance, TaskIdType, TaskOutput, TaskStatus } from "../task/TaskTypes"; +import { TaskIdType, TaskOutput, TaskStatus } from "../task/TaskTypes"; import { DataflowEventListener, DataflowEventListeners, @@ -48,7 +48,6 @@ export class Dataflow { ); } public value: any = undefined; - public provenance: Provenance = {}; public status: TaskStatus = TaskStatus.PENDING; public error: TaskError | undefined; @@ -56,7 +55,6 @@ export class Dataflow { this.status = TaskStatus.PENDING; this.error = undefined; this.value = undefined; - this.provenance = {}; this.emit("reset"); this.emit("status", this.status); } @@ -87,7 +85,7 @@ export class Dataflow { this.emit("status", this.status); 
} - setPortData(entireDataBlock: any, nodeProvenance: any) { + setPortData(entireDataBlock: any) { if (this.sourceTaskPortId === DATAFLOW_ALL_PORTS) { this.value = entireDataBlock; } else if (this.sourceTaskPortId === DATAFLOW_ERROR_PORT) { @@ -95,7 +93,6 @@ export class Dataflow { } else { this.value = entireDataBlock[this.sourceTaskPortId]; } - if (nodeProvenance) this.provenance = nodeProvenance; } getPortData(): TaskOutput { diff --git a/packages/task-graph/src/task-graph/ITaskGraph.ts b/packages/task-graph/src/task-graph/ITaskGraph.ts index e2f6dc84..ba115d2a 100644 --- a/packages/task-graph/src/task-graph/ITaskGraph.ts +++ b/packages/task-graph/src/task-graph/ITaskGraph.ts @@ -6,7 +6,7 @@ import { ITask } from "../task/ITask"; import { JsonTaskItem, TaskGraphJson } from "../task/TaskJSON"; -import { TaskIdType, TaskInput, TaskOutput, TaskStatus } from "../task/TaskTypes"; +import type { TaskIdType, TaskInput, TaskOutput, TaskStatus } from "../task/TaskTypes"; import { Dataflow, DataflowIdType } from "./Dataflow"; import type { TaskGraphRunConfig } from "./TaskGraph"; import type { TaskGraphEventListener, TaskGraphEvents } from "./TaskGraphEvents"; diff --git a/packages/task-graph/src/task-graph/README.md b/packages/task-graph/src/task-graph/README.md index e9673f4e..abdf6629 100644 --- a/packages/task-graph/src/task-graph/README.md +++ b/packages/task-graph/src/task-graph/README.md @@ -25,7 +25,6 @@ A robust TypeScript library for creating and managing task graphs with dependenc - Directed Acyclic Graph (DAG) structure for task dependencies - Data flow management between task inputs/outputs - Workflow builder API with fluent interface -- Provenance tracking - Caching of task results (same run on same input returns cached result) - Error handling and abortion support - Serial and parallel execution patterns @@ -87,7 +86,6 @@ const output = await workflow.run(); - Connects task outputs to inputs - Value propagation -- Provenance tracking ### TaskGraphRunner diff --git a/packages/task-graph/src/task-graph/TaskGraph.ts b/packages/task-graph/src/task-graph/TaskGraph.ts index 0f7e5e83..88e71ab6 100644 --- a/packages/task-graph/src/task-graph/TaskGraph.ts +++ b/packages/task-graph/src/task-graph/TaskGraph.ts @@ -8,7 +8,7 @@ import { DirectedAcyclicGraph, EventEmitter, uuid4 } from "@workglow/util"; import { TaskOutputRepository } from "../storage/TaskOutputRepository"; import type { ITask } from "../task/ITask"; import { JsonTaskItem, TaskGraphJson } from "../task/TaskJSON"; -import type { Provenance, TaskIdType, TaskInput, TaskOutput, TaskStatus } from "../task/TaskTypes"; +import type { TaskIdType, TaskInput, TaskOutput, TaskStatus } from "../task/TaskTypes"; import { ensureTask, type PipeFunction } from "./Conversions"; import { Dataflow, type DataflowIdType } from "./Dataflow"; import type { ITaskGraph } from "./ITaskGraph"; @@ -37,8 +37,6 @@ export interface TaskGraphRunConfig { outputCache?: TaskOutputRepository | boolean; /** Optional signal to abort the task graph */ parentSignal?: AbortSignal; - /** Optional provenance to use for this task graph */ - parentProvenance?: Provenance; } class TaskGraphDAG extends DirectedAcyclicGraph< @@ -102,7 +100,6 @@ export class TaskGraph implements ITaskGraph { ): Promise> { return this.runner.runGraph(input, { outputCache: config?.outputCache || this.outputCache, - parentProvenance: config?.parentProvenance || {}, parentSignal: config?.parentSignal || undefined, }); } diff --git a/packages/task-graph/src/task-graph/TaskGraphRunner.ts 
b/packages/task-graph/src/task-graph/TaskGraphRunner.ts index fc2af972..94bd630d 100644 --- a/packages/task-graph/src/task-graph/TaskGraphRunner.ts +++ b/packages/task-graph/src/task-graph/TaskGraphRunner.ts @@ -50,7 +50,7 @@ export type GraphResult< /** * Class for running a task graph - * Manages the execution of tasks in a task graph, including provenance tracking and caching + * Manages the execution of tasks in a task graph, including caching */ export class TaskGraphRunner { /** @@ -59,11 +59,6 @@ export class TaskGraphRunner { protected running = false; protected reactiveRunning = false; - /** - * Map of provenance input for each task - */ - protected provenanceInput: Map; - /** * The task graph to run */ @@ -99,7 +94,6 @@ export class TaskGraphRunner { protected reactiveScheduler = new TopologicalScheduler(graph) ) { this.graph = graph; - this.provenanceInput = new Map(); graph.outputCache = outputCache; this.handleProgress = this.handleProgress.bind(this); } @@ -136,10 +130,9 @@ export class TaskGraphRunner { // Only filter input for non-root tasks; root tasks get the full input const taskInput = isRootTask ? input : this.filterInputForTask(task, input); - const taskPromise = this.runTaskWithProvenance( + const taskPromise = this.runTask( task, - taskInput, - config?.parentProvenance || {} + taskInput ); this.inProgressTasks!.set(task.config.id, taskPromise); const taskResult = await taskPromise; @@ -332,40 +325,22 @@ export class TaskGraphRunner { } } - /** - * Retrieves the provenance input for a task - * @param node The task to retrieve provenance input for - * @returns The provenance input for the task - */ - protected getInputProvenance(node: ITask): TaskInput { - const nodeProvenance: Provenance = {}; - this.graph.getSourceDataflows(node.config.id).forEach((dataflow) => { - Object.assign(nodeProvenance, dataflow.provenance); - }); - return nodeProvenance; - } - /** * Pushes the output of a task to its target tasks * @param node The task that produced the output * @param results The output of the task - * @param nodeProvenance The provenance input for the task */ - protected async pushOutputFromNodeToEdges( - node: ITask, - results: TaskOutput, - nodeProvenance?: Provenance - ) { + protected async pushOutputFromNodeToEdges(node: ITask, results: TaskOutput) { const dataflows = this.graph.getTargetDataflows(node.config.id); for (const dataflow of dataflows) { const compatibility = dataflow.semanticallyCompatible(this.graph, dataflow); // console.log("pushOutputFromNodeToEdges", dataflow.id, compatibility, Object.keys(results)); if (compatibility === "static") { - dataflow.setPortData(results, nodeProvenance); + dataflow.setPortData(results); } else if (compatibility === "runtime") { const task = this.graph.getTask(dataflow.targetTaskId)!; - const narrowed = await task.narrowInput({ ...results }); - dataflow.setPortData(narrowed, nodeProvenance); + const narrowed = await task.narrowInput({ ...results }, this.registry); + dataflow.setPortData(narrowed); } else { // don't push incompatible data } @@ -494,33 +469,21 @@ export class TaskGraphRunner { } /** - * Runs a task with provenance input + * Runs a task * @param task The task to run - * @param parentProvenance The provenance input for the task + * @param input The input for the task * @returns The output of the task */ - protected async runTaskWithProvenance( - task: ITask, - input: TaskInput, - parentProvenance: Provenance - ): Promise> { - // Update provenance for the current task - const nodeProvenance = { - 
...parentProvenance, - ...this.getInputProvenance(task), - ...task.getProvenance(), - }; - this.provenanceInput.set(task.config.id, nodeProvenance); + protected async runTask(task: ITask, input: TaskInput): Promise> { this.copyInputFromEdgesToNode(task); const results = await task.runner.run(input, { - nodeProvenance, outputCache: this.outputCache, updateProgress: async (task: ITask, progress: number, message?: string, ...args: any[]) => await this.handleProgress(task, progress, message, ...args), }); - await this.pushOutputFromNodeToEdges(task, results, nodeProvenance); + await this.pushOutputFromNodeToEdges(task, results); return { id: task.config.id, @@ -706,7 +669,7 @@ export class TaskGraphRunner { progress = Math.round(completed / total); } this.pushStatusFromNodeToEdges(this.graph, task); - await this.pushOutputFromNodeToEdges(task, task.runOutputData, task.getProvenance()); + await this.pushOutputFromNodeToEdges(task, task.runOutputData); this.graph.emit("graph_progress", progress, message, args); } } diff --git a/packages/task-graph/src/task-graph/Workflow.ts b/packages/task-graph/src/task-graph/Workflow.ts index 440e6d52..dd216884 100644 --- a/packages/task-graph/src/task-graph/Workflow.ts +++ b/packages/task-graph/src/task-graph/Workflow.ts @@ -23,10 +23,10 @@ import { } from "./TaskGraphRunner"; // Type definitions for the workflow -export type CreateWorkflow = ( +export type CreateWorkflow = ( input?: Partial, config?: Partial -) => Workflow; +) => Workflow; // Event types export type WorkflowEventListeners = { @@ -57,9 +57,10 @@ let taskIdCounter = 0; * Class for building and managing a task graph * Provides methods for adding tasks, connecting outputs to inputs, and running the task graph */ -export class Workflow - implements IWorkflow -{ +export class Workflow< + Input extends DataPorts = DataPorts, + Output extends DataPorts = DataPorts, +> implements IWorkflow { /** * Creates a new Workflow * @@ -99,10 +100,10 @@ export class Workflow(taskClass: ITaskConstructor): CreateWorkflow { const helper = function ( - this: Workflow, + this: Workflow, input: Partial = {}, config: Partial = {} - ): Workflow { + ) { this._error = ""; const parent = getLastTask(this); @@ -150,7 +151,19 @@ export class Workflow boolean ): Map => { - // If either schema is true (accepts everything), skip auto-matching + if (typeof sourceSchema === "object") { + if ( + targetSchema === true || + (typeof targetSchema === "object" && targetSchema.additionalProperties === true) + ) { + for (const fromOutputPortId of Object.keys(sourceSchema.properties || {})) { + matches.set(fromOutputPortId, fromOutputPortId); + this.connect(parent.config.id, fromOutputPortId, task.config.id, fromOutputPortId); + } + return matches; + } + } + // If either schema is true or false, skip auto-matching // as we cannot determine the appropriate connections if (typeof sourceSchema === "boolean" || typeof targetSchema === "boolean") { return matches; @@ -177,51 +190,277 @@ export class Workflow { - // Skip if either schema is boolean - if ( - typeof fromPortOutputSchema === "boolean" || - typeof toPortInputSchema === "boolean" - ) { - if (fromPortOutputSchema === true && toPortInputSchema === true) { - return true; + /** + * Extracts specific type identifiers (format, $id) from a schema, + * looking inside oneOf/anyOf wrappers if needed. 
+ */ + const getSpecificTypeIdentifiers = ( + schema: JsonSchema + ): { formats: Set; ids: Set } => { + const formats = new Set(); + const ids = new Set(); + + if (typeof schema === "boolean") { + return { formats, ids }; + } + + // Helper to extract from a single schema object + const extractFromSchema = (s: any): void => { + if (!s || typeof s !== "object" || Array.isArray(s)) return; + if (s.format) formats.add(s.format); + if (s.$id) ids.add(s.$id); + }; + + // Check top-level format/$id + extractFromSchema(schema); + + // Check inside oneOf/anyOf + const checkUnion = (schemas: JsonSchema[] | undefined): void => { + if (!schemas) return; + for (const s of schemas) { + if (typeof s === "boolean") continue; + extractFromSchema(s); + // Also check nested items for array types + if (s.items && typeof s.items === "object" && !Array.isArray(s.items)) { + extractFromSchema(s.items); } - return false; } - // $id matches - const idTypeMatch = - fromPortOutputSchema.$id !== undefined && - fromPortOutputSchema.$id === toPortInputSchema.$id; - // $id both blank - const idTypeBlank = - fromPortOutputSchema.$id === undefined && undefined === toPortInputSchema.$id; - const typeMatch = - idTypeBlank && - (fromPortOutputSchema.type === toPortInputSchema.type || - (toPortInputSchema.oneOf?.some((i: any) => i.type == fromPortOutputSchema.type) ?? - false)); + }; + + checkUnion(schema.oneOf as JsonSchema[] | undefined); + checkUnion(schema.anyOf as JsonSchema[] | undefined); + + // Check items for array types (single schema, not tuple) + if (schema.items && typeof schema.items === "object" && !Array.isArray(schema.items)) { + extractFromSchema(schema.items); + } + + return { formats, ids }; + }; + + /** + * Checks if output schema type is compatible with input schema type. + * Handles $id matching, format matching, and oneOf/anyOf unions. + */ + const isTypeCompatible = ( + fromPortOutputSchema: JsonSchema, + toPortInputSchema: JsonSchema, + requireSpecificType: boolean = false + ): boolean => { + if (typeof fromPortOutputSchema === "boolean" || typeof toPortInputSchema === "boolean") { + return fromPortOutputSchema === true && toPortInputSchema === true; + } + + // Extract specific type identifiers from both schemas + const outputIds = getSpecificTypeIdentifiers(fromPortOutputSchema); + const inputIds = getSpecificTypeIdentifiers(toPortInputSchema); + + // Check if any format matches + for (const format of outputIds.formats) { + if (inputIds.formats.has(format)) { + return true; + } + } + + // Check if any $id matches + for (const id of outputIds.ids) { + if (inputIds.ids.has(id)) { + return true; + } + } + + // For type-only fallback, we require specific types (not primitives) + // to avoid over-matching strings, numbers, etc. + if (requireSpecificType) { + return false; + } + + // $id both blank at top level - check type directly (only for name-matched ports) + const idTypeBlank = + fromPortOutputSchema.$id === undefined && toPortInputSchema.$id === undefined; + if (!idTypeBlank) return false; + + // Direct type match (for primitives, only when names also match) + if (fromPortOutputSchema.type === toPortInputSchema.type) return true; + + // Check if output type matches any option in oneOf/anyOf + const matchesOneOf = + toPortInputSchema.oneOf?.some((schema: any) => { + if (typeof schema === "boolean") return schema; + return schema.type === fromPortOutputSchema.type; + }) ?? 
false; + + const matchesAnyOf = + toPortInputSchema.anyOf?.some((schema: any) => { + if (typeof schema === "boolean") return schema; + return schema.type === fromPortOutputSchema.type; + }) ?? false; + + return matchesOneOf || matchesAnyOf; + }; + + // Strategy 1: Match by type AND port name (highest priority) + makeMatch( + ([fromOutputPortId, fromPortOutputSchema], [toInputPortId, toPortInputSchema]) => { const outputPortIdMatch = fromOutputPortId === toInputPortId; const outputPortIdOutputInput = fromOutputPortId === "output" && toInputPortId === "input"; const portIdsCompatible = outputPortIdMatch || outputPortIdOutputInput; - return (idTypeMatch || typeMatch) && portIdsCompatible; + + return ( + portIdsCompatible && isTypeCompatible(fromPortOutputSchema, toPortInputSchema, false) + ); } ); - // If no matches were found, remove the task and report an error - if (matches.size === 0) { + // Strategy 2: Match by specific type only (fallback for unmatched ports) + // Only matches specific types like TypedArray (with format), not primitives + // This allows connecting ports with different names but compatible specific types + makeMatch( + ([_fromOutputPortId, fromPortOutputSchema], [_toInputPortId, toPortInputSchema]) => { + return isTypeCompatible(fromPortOutputSchema, toPortInputSchema, true); + } + ); + + // Strategy 3: Look back through earlier tasks for unmatched required inputs + // Extract required inputs from target schema + const requiredInputs = new Set( + typeof targetSchema === "object" ? (targetSchema.required as string[]) || [] : [] + ); + + // Filter out required inputs that are already provided in the input parameter + // These don't need to be connected from previous tasks + const providedInputKeys = new Set(Object.keys(input || {})); + const requiredInputsNeedingConnection = [...requiredInputs].filter( + (r) => !providedInputKeys.has(r) + ); + + // Compute unmatched required inputs (that aren't already provided) + let unmatchedRequired = requiredInputsNeedingConnection.filter((r) => !matches.has(r)); + + // If there are unmatched required inputs, iterate backwards through earlier tasks + if (unmatchedRequired.length > 0) { + const nodes = this._graph.getTasks(); + const parentIndex = nodes.findIndex((n) => n.config.id === parent.config.id); + + // Iterate backwards from task before parent + for (let i = parentIndex - 1; i >= 0 && unmatchedRequired.length > 0; i--) { + const earlierTask = nodes[i]; + const earlierOutputSchema = earlierTask.outputSchema(); + + // Helper function to match from an earlier task (only for unmatched required inputs) + const makeMatchFromEarlier = ( + comparator: ( + [fromOutputPortId, fromPortOutputSchema]: [string, JsonSchema], + [toInputPortId, toPortInputSchema]: [string, JsonSchema] + ) => boolean + ): void => { + if (typeof earlierOutputSchema === "boolean" || typeof targetSchema === "boolean") { + return; + } + + for (const [fromOutputPortId, fromPortOutputSchema] of Object.entries( + earlierOutputSchema.properties || {} + )) { + for (const requiredInputId of unmatchedRequired) { + const toPortInputSchema = (targetSchema.properties as any)?.[requiredInputId]; + if ( + !matches.has(requiredInputId) && + toPortInputSchema && + comparator( + [fromOutputPortId, fromPortOutputSchema], + [requiredInputId, toPortInputSchema] + ) + ) { + matches.set(requiredInputId, fromOutputPortId); + this.connect( + earlierTask.config.id, + fromOutputPortId, + task.config.id, + requiredInputId + ); + } + } + } + }; + + // Try both matching strategies for earlier 
tasks + // Strategy 1: Match by type AND port name + makeMatchFromEarlier( + ([fromOutputPortId, fromPortOutputSchema], [toInputPortId, toPortInputSchema]) => { + const outputPortIdMatch = fromOutputPortId === toInputPortId; + const outputPortIdOutputInput = + fromOutputPortId === "output" && toInputPortId === "input"; + const portIdsCompatible = outputPortIdMatch || outputPortIdOutputInput; + + return ( + portIdsCompatible && + isTypeCompatible(fromPortOutputSchema, toPortInputSchema, false) + ); + } + ); + + // Strategy 2: Match by specific type only + makeMatchFromEarlier( + ([_fromOutputPortId, fromPortOutputSchema], [_toInputPortId, toPortInputSchema]) => { + return isTypeCompatible(fromPortOutputSchema, toPortInputSchema, true); + } + ); + + // Update unmatched required inputs + unmatchedRequired = unmatchedRequired.filter((r) => !matches.has(r)); + } + } + + // Updated failure condition: only fail when required inputs (that need connection) remain unmatched + const stillUnmatchedRequired = requiredInputsNeedingConnection.filter( + (r) => !matches.has(r) + ); + if (stillUnmatchedRequired.length > 0) { this._error = - `Could not find a match between the outputs of ${parent.type} and the inputs of ${task.type}. ` + - `You now need to connect the outputs to the inputs via connect() manually before adding this task. Task not added.`; + `Could not find matches for required inputs [${stillUnmatchedRequired.join(", ")}] of ${task.type}. ` + + `Attempted to match from ${parent.type} and earlier tasks. Task not added.`; console.error(this._error); this.graph.removeTask(task.config.id); + } else if (matches.size === 0 && requiredInputsNeedingConnection.length === 0) { + // No matches were made AND no required inputs need connection + // This happens in two cases: + // 1. Task has required inputs, but they were all provided as parameters + // 2. Task has no required inputs (all optional) + + // If task has required inputs that were all provided as parameters, allow the task + const hasRequiredInputs = requiredInputs.size > 0; + const allRequiredInputsProvided = + hasRequiredInputs && [...requiredInputs].every((r) => providedInputKeys.has(r)); + + // If no required inputs (all optional), check if there are defaults + const hasInputsWithDefaults = + typeof targetSchema === "object" && + targetSchema.properties && + Object.values(targetSchema.properties).some( + (prop: any) => prop && typeof prop === "object" && "default" in prop + ); + + // Allow if: + // - All required inputs were provided as parameters, OR + // - No required inputs and task has defaults + // Otherwise fail (no required inputs, no defaults, no matches) + if (!allRequiredInputsProvided && !hasInputsWithDefaults) { + this._error = + `Could not find a match between the outputs of ${parent.type} and the inputs of ${task.type}. ` + + `You now need to connect the outputs to the inputs via connect() manually before adding this task. 
Task not added.`; + + console.error(this._error); + this.graph.removeTask(task.config.id); + } } } - return this; + // Preserve input type from the start of the chain + // If this is the first task, set both input and output types + // Otherwise, only update the output type (input type is preserved from 'this') + return this as any; }; // Copy metadata from the task class @@ -233,7 +472,7 @@ export class Workflow; } /** @@ -296,7 +535,6 @@ export class Workflow(input, { parentSignal: this._abortController.signal, - parentProvenance: {}, outputCache: this._repository, }); const results = this.graph.mergeExecuteOutputsToRunOutput( @@ -597,7 +835,11 @@ export class Workflow( + type: T, + annotations: Record = {} +) => + ({ + oneOf: [type, { type: "array", items: type }], + title: type.title, + description: type.description, + ...(type.format ? { format: type.format } : {}), + ...annotations, + "x-replicate": true, + }) as const; + +/** + * Removes array types from a union, leaving only non-array types. + * For example, `string | string[]` becomes `string`. + * Used to extract the single-value type from schemas with x-replicate annotation. + * Uses distributive conditional types to filter out arrays from unions. + * Checks for both array types and types with numeric index signatures (FromSchema array output). + * Preserves Vector types like Float64Array which also have numeric indices. + */ +type UnwrapArrayUnion = T extends readonly any[] + ? T extends TypedArray + ? T + : never + : number extends keyof T + ? "push" extends keyof T + ? never + : T + : T; + +/** + * Transforms a schema by removing array variants from properties marked with x-replicate. + * Properties with x-replicate use {@link TypeReplicateArray} which creates a union of + * `T | T[]`, and this type extracts just `T`. + */ +export type DeReplicateFromSchema }> = { + [K in keyof S["properties"]]: S["properties"][K] extends { "x-replicate": true } + ? UnwrapArrayUnion> + : VectorFromSchema; +}; + /** * ArrayTask is a compound task that either: * 1. 
Executes directly if all inputs are non-arrays diff --git a/packages/task-graph/src/task/GraphAsTaskRunner.ts b/packages/task-graph/src/task/GraphAsTaskRunner.ts index d097392a..31b88005 100644 --- a/packages/task-graph/src/task/GraphAsTaskRunner.ts +++ b/packages/task-graph/src/task/GraphAsTaskRunner.ts @@ -27,7 +27,6 @@ export class GraphAsTaskRunner< } ); const results = await this.task.subGraph!.run(input, { - parentProvenance: this.nodeProvenance || {}, parentSignal: this.abortController?.signal, outputCache: this.outputCache, }); @@ -53,35 +52,6 @@ export class GraphAsTaskRunner< super.handleDisable(); } - // ======================================================================== - // Utility methods - // ======================================================================== - - private fixInput(input: Input): Input { - // inputs has turned each property into an array, so we need to flatten the input - // but only for properties marked with x-replicate in the schema - const inputSchema = this.task.inputSchema(); - if (typeof inputSchema === "boolean") { - return input; - } - - const flattenedInput = Object.entries(input).reduce((acc, [key, value]) => { - const inputDef = inputSchema.properties?.[key]; - const shouldFlatten = - Array.isArray(value) && - typeof inputDef === "object" && - inputDef !== null && - "x-replicate" in inputDef && - (inputDef as any)["x-replicate"] === true; - - if (shouldFlatten) { - return { ...acc, [key]: value[0] }; - } - return { ...acc, [key]: value }; - }, {}); - return flattenedInput as Input; - } - // ======================================================================== // TaskRunner method overrides and helpers // ======================================================================== @@ -97,7 +67,7 @@ export class GraphAsTaskRunner< this.task.compoundMerge ); } else { - const result = await super.executeTask(this.fixInput(input)); + const result = await super.executeTask(input); this.task.runOutputData = result ?? ({} as Output); } return this.task.runOutputData as Output; @@ -114,7 +84,7 @@ export class GraphAsTaskRunner< this.task.compoundMerge ); } else { - const reactiveResults = await super.executeTaskReactive(this.fixInput(input), output); + const reactiveResults = await super.executeTaskReactive(input, output); this.task.runOutputData = Object.assign({}, output, reactiveResults ?? 
{}) as Output; } return this.task.runOutputData as Output; diff --git a/packages/task-graph/src/task/ITask.ts b/packages/task-graph/src/task/ITask.ts index 2a528fd6..10056be6 100644 --- a/packages/task-graph/src/task/ITask.ts +++ b/packages/task-graph/src/task/ITask.ts @@ -19,14 +19,13 @@ import type { } from "./TaskEvents"; import type { JsonTaskItem, TaskGraphItemJson } from "./TaskJSON"; import { TaskRunner } from "./TaskRunner"; -import type { Provenance, TaskConfig, TaskInput, TaskOutput, TaskStatus } from "./TaskTypes"; +import type { TaskConfig, TaskInput, TaskOutput, TaskStatus } from "./TaskTypes"; /** * Context for task execution */ export interface IExecuteContext { signal: AbortSignal; - nodeProvenance: Provenance; updateProgress: (progress: number, message?: string, ...args: any[]) => Promise; own: (i: T) => T; } @@ -142,7 +141,6 @@ export interface ITaskEvents { * Interface for task serialization */ export interface ITaskSerialization { - getProvenance(): Provenance; toJSON(): JsonTaskItem | TaskGraphItemJson; toDependencyJSON(): JsonTaskItem; id(): unknown; @@ -168,7 +166,9 @@ export interface ITask< Input extends TaskInput = TaskInput, Output extends TaskOutput = TaskOutput, Config extends TaskConfig = TaskConfig, -> extends ITaskState, +> + extends + ITaskState, ITaskIO, ITaskEvents, ITaskLifecycle, diff --git a/packages/task-graph/src/task/JobQueueTask.ts b/packages/task-graph/src/task/JobQueueTask.ts index 65c3b337..fb809b56 100644 --- a/packages/task-graph/src/task/JobQueueTask.ts +++ b/packages/task-graph/src/task/JobQueueTask.ts @@ -5,7 +5,7 @@ */ import { Job, JobConstructorParam } from "@workglow/job-queue"; -import { ArrayTask } from "./ArrayTask"; +import { GraphAsTask } from "./GraphAsTask"; import { IExecuteContext } from "./ITask"; import { getJobQueueFactory } from "./JobQueueFactory"; import { JobTaskFailedError, TaskConfigurationError } from "./TaskError"; @@ -47,7 +47,7 @@ export abstract class JobQueueTask< Input extends TaskInput = TaskInput, Output extends TaskOutput = TaskOutput, Config extends JobQueueTaskConfig = JobQueueTaskConfig, -> extends ArrayTask { +> extends GraphAsTask { static readonly type: string = "JobQueueTask"; static canRunDirectly = true; @@ -60,7 +60,7 @@ export abstract class JobQueueTask< public jobClass: new (config: JobConstructorParam) => Job; - constructor(input: Input = {} as Input, config: Config = {} as Config) { + constructor(input: Partial = {} as Input, config: Config = {} as Config) { config.queue ??= true; super(input, config); this.jobClass = Job as unknown as new ( diff --git a/packages/task-graph/src/task/README.md b/packages/task-graph/src/task/README.md index d04b4267..18b17812 100644 --- a/packages/task-graph/src/task/README.md +++ b/packages/task-graph/src/task/README.md @@ -30,6 +30,9 @@ This module provides a flexible task processing system with support for various ### A Simple Task ```typescript +import { Task, type DataPortSchema } from "@workglow/task-graph"; +import { Type } from "@sinclair/typebox"; + interface MyTaskInput { input: number; } @@ -178,6 +181,15 @@ static outputSchema = () => { }), }) satisfies DataPortSchema; }; + +type MyInput = FromSchema; +type MyOutput = FromSchema; + +class MyTask extends Task { + static readonly type = "MyTask"; + static inputSchema = () => MyInputSchema; + static outputSchema = () => MyOutputSchema; +} ``` ### Using Zod @@ -201,13 +213,16 @@ const outputSchemaZod = z.object({ type MyInput = z.infer; type MyOutput = z.infer; -static inputSchema = () => { - return 
inputSchemaZod.toJSONSchema() as DataPortSchema; -}; +class MyTask extends Task { + static readonly type = "MyTask"; + static inputSchema = () => { + return inputSchemaZod.toJSONSchema() as DataPortSchema; + }; -static outputSchema = () => { - return outputSchemaZod.toJSONSchema() as DataPortSchema; -}; + static outputSchema = () => { + return outputSchemaZod.toJSONSchema() as DataPortSchema; + }; +} ``` ## Registry & Queues diff --git a/packages/task-graph/src/task/Task.ts b/packages/task-graph/src/task/Task.ts index 2bc7dd3d..bed88b1a 100644 --- a/packages/task-graph/src/task/Task.ts +++ b/packages/task-graph/src/task/Task.ts @@ -26,7 +26,6 @@ import type { JsonTaskItem, TaskGraphItemJson } from "./TaskJSON"; import { TaskRunner } from "./TaskRunner"; import { TaskStatus, - type Provenance, type TaskConfig, type TaskIdType, type TaskInput, @@ -307,11 +306,6 @@ export class Task< } protected _events: EventEmitter | undefined; - /** - * Provenance information for the task - */ - protected nodeProvenance: Provenance = {}; - /** * Creates a new task instance * @@ -380,11 +374,86 @@ export class Task< * Resets input data to defaults */ public resetInputData(): void { - // Use deep clone to avoid state leakage + this.runInputData = this.smartClone(this.defaults) as Record; + } + + /** + * Smart clone that deep-clones plain objects and arrays while preserving + * class instances (objects with non-Object prototype) by reference. + * Detects and throws an error on circular references. + * + * This is necessary because: + * - structuredClone cannot clone class instances (methods are lost) + * - JSON.parse/stringify loses methods and fails on circular references + * - Class instances like repositories should be passed by reference + * + * This breaks the idea of everything being json serializable, but it allows + * more efficient use cases. Do be careful with this though! Use sparingly. + * + * @param obj The object to clone + * @param visited Set of objects in the current cloning path (for circular reference detection) + * @returns A cloned object with class instances preserved by reference + */ + private smartClone(obj: any, visited: WeakSet = new WeakSet()): any { + if (obj === null || obj === undefined) { + return obj; + } + + // Primitives (string, number, boolean, symbol, bigint) are returned as-is + if (typeof obj !== "object") { + return obj; + } + + // Check for circular references + if (visited.has(obj)) { + throw new Error( + "Circular reference detected in input data. " + + "Cannot clone objects with circular references." + ); + } + + // Clone TypedArrays (Float32Array, Int8Array, etc.) to avoid shared-mutation + // between defaults and runInputData, while preserving DataView by reference. + if (ArrayBuffer.isView(obj)) { + // Preserve DataView instances by reference (constructor signature differs) + if (typeof DataView !== "undefined" && obj instanceof DataView) { + return obj; + } + // For TypedArrays, create a new instance with the same data + const typedArray = obj as any; + return new (typedArray.constructor as any)(typedArray); + } + + // Preserve class instances (objects with non-Object/non-Array prototype) + // This includes repository instances, custom classes, etc. 
+ if (!Array.isArray(obj)) { + const proto = Object.getPrototypeOf(obj); + if (proto !== Object.prototype && proto !== null) { + return obj; // Pass by reference + } + } + + // Add object to visited set before recursing + visited.add(obj); + try { - this.runInputData = structuredClone(this.defaults) as Record; - } catch (err) { - this.runInputData = JSON.parse(JSON.stringify(this.defaults)) as Record; + // Deep clone arrays, preserving class instances within + if (Array.isArray(obj)) { + return obj.map((item) => this.smartClone(item, visited)); + } + + // Deep clone plain objects + const result: Record = {}; + for (const key in obj) { + if (Object.prototype.hasOwnProperty.call(obj, key)) { + result[key] = this.smartClone(obj[key], visited); + } + } + return result; + } finally { + // Remove from visited set after processing to allow the same object + // in different branches (non-circular references) + visited.delete(obj); } } @@ -429,7 +498,7 @@ export class Task< // If additionalProperties is true, also copy any additional input properties if (schema.additionalProperties === true) { for (const [inputId, value] of Object.entries(input)) { - if (value !== undefined && !(inputId in properties)) { + if (!(inputId in properties)) { this.runInputData[inputId] = value; } } @@ -506,7 +575,7 @@ export class Task< // If additionalProperties is true, also accept any additional input properties if (inputSchema.additionalProperties === true) { for (const [inputId, value] of Object.entries(overrides)) { - if (value !== undefined && !(inputId in properties)) { + if (!(inputId in properties)) { if (!deepEqual(this.runInputData[inputId], value)) { this.runInputData[inputId] = value; changed = true; @@ -650,7 +719,7 @@ export class Task< return `${e.message}${path ? ` (${path})` : ""}`; }); throw new TaskInvalidInputError( - `Input ${JSON.stringify(input)} does not match schema: ${errorMessages.join(", ")}` + `Input ${JSON.stringify(Object.keys(input))} does not match schema: ${errorMessages.join(", ")}` ); } @@ -664,13 +733,6 @@ export class Task< return this.config.id; } - /** - * Gets provenance information for the task - */ - public getProvenance(): Provenance { - return this.config.provenance ?? {}; - } - // ======================================================================== // Serialization methods // ======================================================================== @@ -684,6 +746,10 @@ export class Task< if (obj === null || obj === undefined) { return obj; } + // Preserve TypedArrays (Float32Array, Int8Array, etc.) + if (ArrayBuffer.isView(obj)) { + return obj; + } if (Array.isArray(obj)) { return obj.map((item) => this.stripSymbols(item)); } @@ -704,14 +770,12 @@ export class Task< * @returns The serialized task and subtasks */ public toJSON(): TaskGraphItemJson { - const provenance = this.getProvenance(); const extras = this.config.extras; let json: TaskGraphItemJson = this.stripSymbols({ id: this.config.id, type: this.type, ...(this.config.name ? { name: this.config.name } : {}), defaults: this.defaults, - ...(Object.keys(provenance).length ? { provenance } : {}), ...(extras && Object.keys(extras).length ? 
{ extras } : {}), }); return json as TaskGraphItemJson; diff --git a/packages/task-graph/src/task/TaskEvents.ts b/packages/task-graph/src/task/TaskEvents.ts index 6eee2280..383389ec 100644 --- a/packages/task-graph/src/task/TaskEvents.ts +++ b/packages/task-graph/src/task/TaskEvents.ts @@ -5,8 +5,8 @@ */ import { EventParameters, type DataPortSchema } from "@workglow/util"; -import { TaskStatus } from "../common"; import { TaskAbortedError, TaskError } from "./TaskError"; +import { TaskStatus } from "./TaskTypes"; // ======================================================================== // Event Handling Types diff --git a/packages/task-graph/src/task/TaskJSON.test.ts b/packages/task-graph/src/task/TaskJSON.test.ts index a7e06874..4e5be6fe 100644 --- a/packages/task-graph/src/task/TaskJSON.test.ts +++ b/packages/task-graph/src/task/TaskJSON.test.ts @@ -130,7 +130,6 @@ describe("TaskJSON", () => { expect(json.type).toBe("TestTask"); expect(json.name).toBe("My Task"); expect(json.defaults).toEqual({ value: 42 }); - expect(json.provenance).toBeUndefined(); expect(json.extras).toBeUndefined(); }); @@ -141,27 +140,18 @@ describe("TaskJSON", () => { expect(json.defaults).toEqual({ value: 10, multiplier: 5 }); }); - test("should serialize task with provenance and extras", () => { + test("should serialize task with extras", () => { const task = new TestTask( { value: 100 }, { id: "task3", - provenance: { source: "test", version: "1.0" }, extras: { metadata: { key: "value" } }, } ); const json = task.toJSON(); - expect(json.provenance).toEqual({ source: "test", version: "1.0" }); expect(json.extras).toEqual({ metadata: { key: "value" } }); }); - - test("should not include empty provenance in JSON", () => { - const task = new TestTask({ value: 50 }, { id: "task4", provenance: {} }); - const json = task.toJSON(); - - expect(json.provenance).toBeUndefined(); - }); }); describe("TaskGraph.toJSON()", () => { @@ -234,18 +224,16 @@ describe("TaskJSON", () => { expect(task.defaults).toEqual({ value: 10, multiplier: 5 }); }); - test("should create a task with provenance and extras", () => { + test("should create a task with extras", () => { const json: TaskGraphItemJson = { id: "task3", type: "TestTask", defaults: { value: 100 }, - provenance: { source: "test", version: "1.0" }, extras: { metadata: { key: "value" } }, }; const task = createTaskFromGraphJSON(json); - expect(task.config.provenance).toEqual({ source: "test", version: "1.0" }); expect(task.config.extras).toEqual({ metadata: { key: "value" } }); }); @@ -386,14 +374,13 @@ describe("TaskJSON", () => { expect(restoredDataflows[0].targetTaskId).toBe(originalDataflows[0].targetTaskId); }); - test("should round-trip a task graph with defaults, provenance, and extras", () => { + test("should round-trip a task graph with defaults and extras", () => { const originalGraph = new TaskGraph(); const task1 = new TestTaskWithDefaults( { value: 10, multiplier: 3 }, { id: "task1", name: "Task with Defaults", - provenance: { source: "test", version: "1.0" }, extras: { metadata: { key: "value" } }, } ); @@ -404,7 +391,6 @@ describe("TaskJSON", () => { const restoredTask = restoredGraph.getTasks()[0]; expect(restoredTask.defaults).toEqual({ value: 10, multiplier: 3 }); - expect(restoredTask.config.provenance).toEqual({ source: "test", version: "1.0" }); expect(restoredTask.config.extras).toEqual({ metadata: { key: "value" } }); }); diff --git a/packages/task-graph/src/task/TaskJSON.ts b/packages/task-graph/src/task/TaskJSON.ts index 529e8b14..abd1ecb9 100644 --- 
a/packages/task-graph/src/task/TaskJSON.ts +++ b/packages/task-graph/src/task/TaskJSON.ts @@ -10,7 +10,7 @@ import { CompoundMergeStrategy } from "../task-graph/TaskGraphRunner"; import { TaskConfigurationError, TaskJSONError } from "../task/TaskError"; import { TaskRegistry } from "../task/TaskRegistry"; import { GraphAsTask } from "./GraphAsTask"; -import { DataPorts, Provenance, TaskConfig, TaskInput } from "./TaskTypes"; +import { DataPorts, TaskConfig, TaskInput } from "./TaskTypes"; // ======================================================================== // JSON Serialization Types @@ -53,9 +53,6 @@ export type JsonTaskItem = { /** Optional user data to use for this task, not used by the task framework except it will be exported as part of the task JSON*/ extras?: DataPorts; - /** Optional metadata about task origin */ - provenance?: Provenance; - /** Nested tasks for compound operations */ subtasks?: JsonTaskItem[]; }; /** @@ -67,7 +64,6 @@ export type TaskGraphItemJson = { type: string; name?: string; defaults?: TaskInput; - provenance?: Provenance; extras?: DataPorts; subgraph?: TaskGraphJson; merge?: CompoundMergeStrategy; @@ -88,10 +84,8 @@ export type DataflowJson = { const createSingleTaskFromJSON = (item: JsonTaskItem | TaskGraphItemJson) => { if (!item.id) throw new TaskJSONError("Task id required"); if (!item.type) throw new TaskJSONError("Task type required"); - if (item.defaults && (Array.isArray(item.defaults) || Array.isArray(item.provenance))) + if (item.defaults && Array.isArray(item.defaults)) throw new TaskJSONError("Task defaults must be an object"); - if (item.provenance && (Array.isArray(item.provenance) || typeof item.provenance !== "object")) - throw new TaskJSONError("Task provenance must be an object"); const taskClass = TaskRegistry.all.get(item.type); if (!taskClass) @@ -100,7 +94,6 @@ const createSingleTaskFromJSON = (item: JsonTaskItem | TaskGraphItemJson) => { const taskConfig: TaskConfig = { id: item.id, name: item.name, - provenance: item.provenance ?? {}, extras: item.extras, }; const task = new taskClass(item.defaults ?? 
{}, taskConfig); diff --git a/packages/task-graph/src/task/TaskRunner.ts b/packages/task-graph/src/task/TaskRunner.ts index b5a8bad8..a5621904 100644 --- a/packages/task-graph/src/task/TaskRunner.ts +++ b/packages/task-graph/src/task/TaskRunner.ts @@ -10,7 +10,7 @@ import { ensureTask, type Taskish } from "../task-graph/Conversions"; import { IRunConfig, ITask } from "./ITask"; import { ITaskRunner } from "./ITaskRunner"; import { TaskAbortedError, TaskError, TaskFailedError, TaskInvalidInputError } from "./TaskError"; -import { Provenance, TaskConfig, TaskInput, TaskOutput, TaskStatus } from "./TaskTypes"; +import { TaskConfig, TaskInput, TaskOutput, TaskStatus } from "./TaskTypes"; /** * Responsible for running tasks @@ -27,11 +27,6 @@ export class TaskRunner< protected running = false; protected reactiveRunning = false; - /** - * Provenance information for the task - */ - protected nodeProvenance: Provenance = {}; - /** * The task to run */ @@ -168,7 +163,6 @@ export class TaskRunner< const result = await this.task.execute(input, { signal: this.abortController!.signal, updateProgress: this.handleProgress.bind(this), - nodeProvenance: this.nodeProvenance, own: this.own, }); return await this.executeTaskReactive(input, result || ({} as Output)); @@ -192,7 +186,6 @@ export class TaskRunner< protected async handleStart(config: IRunConfig = {}): Promise { if (this.task.status === TaskStatus.PROCESSING) return; - this.nodeProvenance = {}; this.running = true; this.task.startedAt = new Date(); @@ -204,8 +197,6 @@ export class TaskRunner< this.handleAbort(); }); - this.nodeProvenance = config.nodeProvenance ?? {}; - const cache = this.task.config.outputCache ?? config.outputCache; if (cache === true) { let instance = globalServiceRegistry.get(TASK_OUTPUT_REPOSITORY); @@ -260,7 +251,6 @@ export class TaskRunner< this.task.progress = 100; this.task.status = TaskStatus.COMPLETED; this.abortController = undefined; - this.nodeProvenance = {}; this.task.emit("complete"); this.task.emit("status", this.task.status); @@ -276,7 +266,6 @@ export class TaskRunner< this.task.progress = 100; this.task.completedAt = new Date(); this.abortController = undefined; - this.nodeProvenance = {}; this.task.emit("disabled"); this.task.emit("status", this.task.status); } @@ -303,7 +292,6 @@ export class TaskRunner< this.task.error = err instanceof TaskError ? 
err : new TaskFailedError(err?.message || "Task failed"); this.abortController = undefined; - this.nodeProvenance = {}; this.task.emit("error", this.task.error); this.task.emit("status", this.task.status); } diff --git a/packages/task-graph/src/task/TaskTypes.ts b/packages/task-graph/src/task/TaskTypes.ts index 3d844459..6ac1a678 100644 --- a/packages/task-graph/src/task/TaskTypes.ts +++ b/packages/task-graph/src/task/TaskTypes.ts @@ -61,8 +61,6 @@ export type CompoundTaskOutput = [key: string]: unknown | unknown[] | undefined; }; -/** Type for task provenance metadata */ -export type Provenance = DataPorts; /** Type for task type names */ export type TaskTypeName = string; @@ -81,8 +79,6 @@ export interface IConfig { /** Optional display name for the task */ name?: string; - /** Optional metadata about task origin */ - provenance?: Provenance; /** Optional ID of the runner to use for this task */ runnerId?: string; diff --git a/packages/tasks/src/task/DebugLogTask.ts b/packages/tasks/src/task/DebugLogTask.ts index 803c7e9e..e2785ccf 100644 --- a/packages/tasks/src/task/DebugLogTask.ts +++ b/packages/tasks/src/task/DebugLogTask.ts @@ -89,8 +89,8 @@ export class DebugLogTask< TaskRegistry.registerTask(DebugLogTask); export const debugLog = (input: DebugLogTaskInput, config: TaskConfig = {}) => { - const task = new DebugLogTask(input, config); - return task.run(); + const task = new DebugLogTask({}, config); + return task.run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/tasks/src/task/DelayTask.ts b/packages/tasks/src/task/DelayTask.ts index b1ffedb6..6e75b0e0 100644 --- a/packages/tasks/src/task/DelayTask.ts +++ b/packages/tasks/src/task/DelayTask.ts @@ -88,8 +88,8 @@ TaskRegistry.registerTask(DelayTask); * @param {delay} - The delay in milliseconds */ export const delay = (input: DelayTaskInput, config: TaskConfig = {}) => { - const task = new DelayTask(input, config); - return task.run(); + const task = new DelayTask({}, config); + return task.run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/tasks/src/task/FetchUrlTask.ts b/packages/tasks/src/task/FetchUrlTask.ts index 64208b13..c2879110 100644 --- a/packages/tasks/src/task/FetchUrlTask.ts +++ b/packages/tasks/src/task/FetchUrlTask.ts @@ -356,7 +356,7 @@ export class FetchUrlTask< } as const satisfies DataPortSchema; } - constructor(input: Input = {} as Input, config: Config = {} as Config) { + constructor(input: Partial = {} as Input, config: Config = {} as Config) { config.queue = input?.queue ?? 
config.queue; if (config.queue === undefined) { config.queue = false; // change default to false to run directly @@ -421,7 +421,7 @@ export const fetchUrl = async ( input: FetchUrlTaskInput, config: FetchUrlTaskConfig = {} ): Promise => { - const result = await new FetchUrlTask(input, config).run(); + const result = await new FetchUrlTask({}, config).run(input); return result as FetchUrlTaskOutput; }; diff --git a/packages/tasks/src/task/FileLoaderTask.server.ts b/packages/tasks/src/task/FileLoaderTask.server.ts index 761b813e..0cad8e5f 100644 --- a/packages/tasks/src/task/FileLoaderTask.server.ts +++ b/packages/tasks/src/task/FileLoaderTask.server.ts @@ -216,7 +216,7 @@ export class FileLoaderTask extends BaseFileLoaderTask { TaskRegistry.registerTask(FileLoaderTask); export const fileLoader = (input: FileLoaderTaskInput, config?: JobQueueTaskConfig) => { - return new FileLoaderTask(input, config).run(); + return new FileLoaderTask({}, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/tasks/src/task/FileLoaderTask.ts b/packages/tasks/src/task/FileLoaderTask.ts index 9ecf0030..92559a4a 100644 --- a/packages/tasks/src/task/FileLoaderTask.ts +++ b/packages/tasks/src/task/FileLoaderTask.ts @@ -408,7 +408,7 @@ export class FileLoaderTask extends Task< TaskRegistry.registerTask(FileLoaderTask); export const fileLoader = (input: FileLoaderTaskInput, config?: JobQueueTaskConfig) => { - return new FileLoaderTask(input, config).run(); + return new FileLoaderTask({}, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/tasks/src/task/JavaScriptTask.ts b/packages/tasks/src/task/JavaScriptTask.ts index 12320213..fdf278f4 100644 --- a/packages/tasks/src/task/JavaScriptTask.ts +++ b/packages/tasks/src/task/JavaScriptTask.ts @@ -74,7 +74,7 @@ export class JavaScriptTask extends Task { - return new JavaScriptTask(input, config).run(); + return new JavaScriptTask({}, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/tasks/src/task/JsonTask.ts b/packages/tasks/src/task/JsonTask.ts index b62c7cf2..8994326b 100644 --- a/packages/tasks/src/task/JsonTask.ts +++ b/packages/tasks/src/task/JsonTask.ts @@ -103,7 +103,7 @@ TaskRegistry.registerTask(JsonTask); * Convenience function to create and run a JsonTask */ export const json = (input: JsonTaskInput, config: TaskConfig = {}) => { - return new JsonTask(input, config).run(); + return new JsonTask({}, config).run(input); }; // Add Json task workflow to Workflow interface diff --git a/packages/tasks/src/task/MergeTask.ts b/packages/tasks/src/task/MergeTask.ts index 60d72a1b..c6029459 100644 --- a/packages/tasks/src/task/MergeTask.ts +++ b/packages/tasks/src/task/MergeTask.ts @@ -85,8 +85,8 @@ export class MergeTask< TaskRegistry.registerTask(MergeTask); export const merge = (input: MergeTaskInput, config: TaskConfig = {}) => { - const task = new MergeTask(input, config); - return task.run(); + const task = new MergeTask({}, config); + return task.run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/tasks/src/task/SplitTask.ts b/packages/tasks/src/task/SplitTask.ts index ada6290b..4e7e0c06 100644 --- a/packages/tasks/src/task/SplitTask.ts +++ b/packages/tasks/src/task/SplitTask.ts @@ -88,8 +88,8 @@ export class SplitTask< TaskRegistry.registerTask(SplitTask); export const split = (input: SplitTaskInput, config: TaskConfig = {}) => { - const task = new SplitTask(input, config); - return task.run(); + const task = new SplitTask({}, config); + 
return task.run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/test/src/samples/ONNXModelSamples.ts b/packages/test/src/samples/ONNXModelSamples.ts index 1d98e048..aae837f7 100644 --- a/packages/test/src/samples/ONNXModelSamples.ts +++ b/packages/test/src/samples/ONNXModelSamples.ts @@ -120,18 +120,6 @@ export async function registerHuggingfaceLocalModels(): Promise { }, metadata: {}, }, - { - model_id: "onnx:Xenova/LaMini-Flan-T5-783M:q8", - title: "LaMini-Flan-T5-783M", - description: "Xenova/LaMini-Flan-T5-783M quantized to 8bit", - tasks: ["TextGenerationTask", "TextRewriterTask"], - provider: HF_TRANSFORMERS_ONNX, - provider_config: { - pipeline: "text2text-generation", - model_path: "Xenova/LaMini-Flan-T5-783M", - }, - metadata: {}, - }, { model_id: "onnx:Xenova/LaMini-Flan-T5-783M:q8", title: "LaMini-Flan-T5-783M", diff --git a/packages/test/src/test/task-graph/Workflow.test.ts b/packages/test/src/test/task-graph/Workflow.test.ts index b9e48c2e..3bcc825d 100644 --- a/packages/test/src/test/task-graph/Workflow.test.ts +++ b/packages/test/src/test/task-graph/Workflow.test.ts @@ -23,6 +23,8 @@ import { TestOutputTask, TestSimpleTask, } from "../task/TestTasks"; +// Import to register vector test tasks with the workflow system +import "../task/TestTasks"; const spyOn = vi.spyOn; @@ -369,6 +371,321 @@ describe("Workflow", () => { expect(workflow.error).toContain("Could not find a match"); expect(workflow.graph.getTasks()).toHaveLength(1); // Second task not added }); + + it("should auto-connect TypedArray ports with different names by format", () => { + // VectorOutputTask outputs 'vector', VectorsInputTask expects 'vectors' + // They should match because both have format: "TypedArray" + workflow = workflow.vectorOutput({ text: "test" }).vectorsInput(); + + expect(workflow.error).toBe(""); + expect(workflow.graph.getTasks()).toHaveLength(2); + + const edges = workflow.graph.getDataflows(); + expect(edges).toHaveLength(1); + expect(edges[0].sourceTaskPortId).toBe("vector"); + expect(edges[0].targetTaskPortId).toBe("vectors"); + }); + + it("should auto-connect TypedArray with oneOf wrapper to anyOf input", () => { + // VectorOneOfOutputTask outputs 'embedding' (oneOf wrapped TypedArray) + // VectorAnyOfInputTask expects 'data' (anyOf wrapped TypedArray) + // They should match because both contain format: "TypedArray" inside the wrappers + workflow = workflow.vectorOneOfOutput({ text: "test" }).vectorAnyOfInput(); + + expect(workflow.error).toBe(""); + expect(workflow.graph.getTasks()).toHaveLength(2); + + const edges = workflow.graph.getDataflows(); + expect(edges).toHaveLength(1); + expect(edges[0].sourceTaskPortId).toBe("embedding"); + expect(edges[0].targetTaskPortId).toBe("data"); + }); + + it("should not match primitive types (string) with different port names", () => { + // StringTask outputs 'output', TestInputTask expects 'customInput' + // These should NOT match because strings are primitive types + // and we only do type-only matching for specific types (like TypedArray) + workflow = workflow.string({ input: "test" }); + workflow = workflow.testInput(); + + expect(workflow.error).toContain("Could not find a match"); + expect(workflow.graph.getTasks()).toHaveLength(1); + }); + }); + + describe("multi-source input matching", () => { + it("should match required inputs from multiple earlier tasks (grandparent + parent)", () => { + // TextOutputTask outputs { text } + // VectorOutputOnlyTask outputs { vector } + // TextVectorInputTask requires both { text, 
vector } + // Should successfully connect text from grandparent and vector from parent + workflow = workflow + .textOutput({ input: "hello" }) + .vectorOutputOnly({ size: 5 }) + .textVectorInput(); + + expect(workflow.error).toBe(""); + expect(workflow.graph.getTasks()).toHaveLength(3); + + const dataflows = workflow.graph.getDataflows(); + expect(dataflows).toHaveLength(2); + + const nodes = workflow.graph.getTasks(); + + // Check connections - text should come from first task (TextOutputTask) + const textConnection = dataflows.find( + (df) => df.targetTaskId === nodes[2].config.id && df.targetTaskPortId === "text" + ); + expect(textConnection).toBeDefined(); + expect(textConnection?.sourceTaskId).toBe(nodes[0].config.id); + expect(textConnection?.sourceTaskPortId).toBe("text"); + + // Check connections - vector should come from second task (VectorOutputOnlyTask) + const vectorConnection = dataflows.find( + (df) => df.targetTaskId === nodes[2].config.id && df.targetTaskPortId === "vector" + ); + expect(vectorConnection).toBeDefined(); + expect(vectorConnection?.sourceTaskId).toBe(nodes[1].config.id); + expect(vectorConnection?.sourceTaskPortId).toBe("vector"); + }); + + it("should fail when required inputs cannot be satisfied by any previous task", () => { + // VectorOutputOnlyTask only outputs { vector } + // TextVectorInputTask requires both { text, vector } + // Should fail because no previous task provides text + workflow = workflow.vectorOutputOnly({ size: 3 }).textVectorInput(); + + expect(workflow.error).toContain("Could not find matches for required inputs"); + expect(workflow.error).toContain("text"); + expect(workflow.graph.getTasks()).toHaveLength(1); // Second task not added + }); + + it("should match required inputs looking back multiple tasks (2+ hops)", () => { + // TextOutputTask outputs { text } + // VectorOutputOnlyTask outputs { vector } + // PassthroughVectorTask outputs { vector } (passes through) + // TextVectorInputTask requires both { text, vector } + // Should connect text from 2 tasks back and vector from parent + workflow = workflow + .textOutput({ input: "test" }) + .vectorOutputOnly({ size: 4 }) + .passthroughVector() + .textVectorInput(); + + expect(workflow.error).toBe(""); + expect(workflow.graph.getTasks()).toHaveLength(4); + + const dataflows = workflow.graph.getDataflows(); + const nodes = workflow.graph.getTasks(); + + // Should have connections: + // 1. vector: VectorOutputOnlyTask -> PassthroughVectorTask + // 2. vector: PassthroughVectorTask -> TextVectorInputTask + // 3. 
text: TextOutputTask -> TextVectorInputTask + expect(dataflows.length).toBeGreaterThanOrEqual(3); + + // Verify the text connection comes from the first task (looking back 2 tasks) + const textConnection = dataflows.find( + (df) => df.targetTaskId === nodes[3].config.id && df.targetTaskPortId === "text" + ); + expect(textConnection).toBeDefined(); + expect(textConnection?.sourceTaskId).toBe(nodes[0].config.id); + expect(textConnection?.sourceTaskPortId).toBe("text"); + + // Verify the vector connection comes from the passthrough task (parent) + const vectorConnection = dataflows.find( + (df) => df.targetTaskId === nodes[3].config.id && df.targetTaskPortId === "vector" + ); + expect(vectorConnection).toBeDefined(); + expect(vectorConnection?.sourceTaskId).toBe(nodes[2].config.id); + }); + + it("should handle partial match where parent provides some required inputs", () => { + // Test that we successfully find text from earlier when parent only provides vector + + // TextOutputTask outputs { text } + // PassthroughVectorTask just to add another task in between + // VectorOutputOnlyTask outputs { vector } + // TextVectorInputTask requires both { text, vector } + workflow = workflow + .textOutput({ input: "partial" }) + .vectorOutputOnly({ size: 3 }) // First vectorOutputOnly + .passthroughVector() // Passes vector through (doesn't provide text) + .vectorOutputOnly({ size: 2 }) // Second vectorOutputOnly (overwrites parent's vector) + .textVectorInput(); + + expect(workflow.error).toBe(""); + expect(workflow.graph.getTasks()).toHaveLength(5); + + const dataflows = workflow.graph.getDataflows(); + const nodes = workflow.graph.getTasks(); + + // Verify text comes from the first task (4 tasks back) + const textConnection = dataflows.find( + (df) => df.targetTaskId === nodes[4].config.id && df.targetTaskPortId === "text" + ); + expect(textConnection).toBeDefined(); + expect(textConnection?.sourceTaskId).toBe(nodes[0].config.id); + + // Verify vector comes from the parent (last vectorOutputOnly) + const vectorConnection = dataflows.find( + (df) => df.targetTaskId === nodes[4].config.id && df.targetTaskPortId === "vector" + ); + expect(vectorConnection).toBeDefined(); + expect(vectorConnection?.sourceTaskId).toBe(nodes[3].config.id); + }); + + it("should successfully match when all required inputs come from parent", () => { + // Special case: if parent already provides all required inputs, + // we shouldn't need to look back (standard auto-connection) + + // TestSimpleTask outputs { output: string } + // Another TestSimpleTask requires { input: string } + // Should auto-match because "output" -> "input" is a special case + workflow = workflow.testSimple({ input: "test" }).testSimple(); + + expect(workflow.error).toBe(""); + expect(workflow.graph.getTasks()).toHaveLength(2); + + const dataflows = workflow.graph.getDataflows(); + const nodes = workflow.graph.getTasks(); + + // Verify connection from parent only (not looking back) + expect(dataflows).toHaveLength(1); + const connection = dataflows[0]; + expect(connection.sourceTaskId).toBe(nodes[0].config.id); + expect(connection.targetTaskId).toBe(nodes[1].config.id); + expect(connection.sourceTaskPortId).toBe("output"); + expect(connection.targetTaskPortId).toBe("input"); + }); + + it("should NOT match when types are incompatible even when looking back", () => { + // TestSimpleTask outputs { output: string } + // TestSimpleTask outputs { output: string } + // TestSimpleTask outputs { output: string } + // TextVectorInputTask requires { text: string, 
vector: Float32Array } + // Should fail because: + // - All tasks provide string, but port name is "output" not "text" or "vector" + // - None provide Float32Array + // - Primitive type string with name "output" won't match different names "text" or "vector" + workflow = workflow + .testSimple({ input: "first" }) + .testSimple({ input: "second" }) + .testSimple({ input: "third" }) + .textVectorInput(); + + expect(workflow.error).toContain("Could not find matches for required inputs"); + // Should fail because no task provides the required ports + expect(workflow.graph.getTasks()).toHaveLength(3); // Fourth task not added + }); + + it("should NOT connect optional (non-required) inputs from earlier tasks", () => { + // TestInputTask has { customInput: string } but it's NOT in the required array + // TestSimpleTask provides { output: string } + // TestOutputTask provides { customOutput: string } + // Should fail because backward matching only considers required inputs, + // and customInput is optional (not required) + workflow = workflow.testSimple({ input: "test" }).testSimple({ input: "middle" }).testInput(); + + // Should fail because customInput doesn't match "output" (different names, primitive type) + expect(workflow.error).toContain("Could not find a match"); + expect(workflow.graph.getTasks()).toHaveLength(2); + }); + + it("should NOT match primitive string ports with different names", () => { + // TestSimpleTask outputs { output: string } + // TestSimpleTask outputs { output: string } + // TestInputTask requires { customInput: string } + // Primitive types only match if names are the same or output->input special case + // "output" vs "customInput" are different names, so won't match + workflow = workflow + .testSimple({ input: "first" }) + .testSimple({ input: "second" }) + .testInput(); + + expect(workflow.error).toContain("Could not find a match"); + expect(workflow.graph.getTasks()).toHaveLength(2); + }); + + it("should allow first task with required inputs if no parent exists", () => { + // When adding a task as the first task in workflow, it's allowed + // even if it has required inputs (they can be provided at runtime) + // TextVectorInputTask has required inputs [text, vector] + workflow = workflow.textVectorInput(); + + // Should succeed - no parent means no auto-connection check + expect(workflow.error).toBe(""); + expect(workflow.graph.getTasks()).toHaveLength(1); + + // But dataflows should be empty (no connections) + expect(workflow.graph.getDataflows()).toHaveLength(0); + }); + + it("should NOT require connections for required inputs that are provided as parameters", () => { + // TextVectorInputTask has required inputs [text, vector] + // But if we provide them as parameters, they don't need connections + workflow = workflow.testSimple({ input: "test" }).textVectorInput({ + text: "provided text", + vector: new Float32Array([1, 2, 3]), + }); + + // Should succeed - required inputs are provided as parameters + expect(workflow.error).toBe(""); + expect(workflow.graph.getTasks()).toHaveLength(2); + + // No connections should be made since inputs are provided directly + expect(workflow.graph.getDataflows()).toHaveLength(0); + }); + + it("should only look for connections for required inputs NOT provided as parameters", () => { + // TextVectorInputTask has required inputs [text, vector] + // Provide text as parameter, but not vector + // Should look back only for vector connection + workflow = workflow + .testSimple({ input: "test" }) + .vectorOutputOnly({ size: 3 }) + 
.textVectorInput({ + text: "provided text", + // vector is NOT provided, should be connected from parent + }); + + // Should succeed - text is provided, vector is connected from parent + expect(workflow.error).toBe(""); + expect(workflow.graph.getTasks()).toHaveLength(3); + + const dataflows = workflow.graph.getDataflows(); + const nodes = workflow.graph.getTasks(); + + // Should have 1 connection: vector from VectorOutputOnlyTask + expect(dataflows.length).toBeGreaterThanOrEqual(1); + + const vectorConnection = dataflows.find( + (df) => df.targetTaskId === nodes[2].config.id && df.targetTaskPortId === "vector" + ); + expect(vectorConnection).toBeDefined(); + expect(vectorConnection?.sourceTaskId).toBe(nodes[1].config.id); + + // No connection for text (it was provided as parameter) + const textConnection = dataflows.find( + (df) => df.targetTaskId === nodes[2].config.id && df.targetTaskPortId === "text" + ); + expect(textConnection).toBeUndefined(); + }); + + it("should NOT match when no earlier task provides the required type", () => { + // TestSimpleTask outputs { output: string } + // TestSimpleTask outputs { output: string } + // TextVectorInputTask requires { text: string, vector: Float32Array } + // Should fail because no task provides Float32Array + workflow = workflow + .testSimple({ input: "test" }) + .testSimple({ input: "hello" }) + .textVectorInput(); + + expect(workflow.error).toContain("Could not find matches for required inputs"); + expect(workflow.error).toContain("vector"); // Missing vector type + expect(workflow.graph.getTasks()).toHaveLength(2); + }); }); describe("static methods", () => { diff --git a/packages/test/src/test/task/Task.smartClone.test.ts b/packages/test/src/test/task/Task.smartClone.test.ts new file mode 100644 index 00000000..ac926ce7 --- /dev/null +++ b/packages/test/src/test/task/Task.smartClone.test.ts @@ -0,0 +1,203 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { IExecuteContext } from "@workglow/task-graph"; +import { Task } from "@workglow/task-graph"; +import type { DataPortSchema } from "@workglow/util"; +import { beforeEach, describe, expect, test } from "vitest"; + +// Test task class to access private smartClone method +class TestSmartCloneTask extends Task<{ data: any }, { result: any }> { + static readonly type = "TestSmartCloneTask"; + static readonly category = "Test"; + static readonly title = "Test Smart Clone Task"; + static readonly description = "A task for testing smartClone"; + declare runInputData: { data: any }; + declare runOutputData: { result: any }; + + static inputSchema(): DataPortSchema { + return { + type: "object", + properties: { + data: {}, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + static outputSchema(): DataPortSchema { + return { + type: "object", + properties: { + result: {}, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + async execute(input: { data: any }, context: IExecuteContext): Promise<{ result: any }> { + return { result: input.data }; + } + + // Expose smartClone for testing + public testSmartClone(obj: any): any { + return (this as any).smartClone(obj); + } +} + +describe("Task.smartClone circular reference detection", () => { + let task: TestSmartCloneTask; + + beforeEach(() => { + task = new TestSmartCloneTask({ data: {} }, { id: "test-task" }); + }); + + test("should handle simple objects without circular references", () => { + const obj = { a: 1, b: { c: 2 } 
}; + const cloned = task.testSmartClone(obj); + + expect(cloned).toEqual(obj); + expect(cloned).not.toBe(obj); + expect(cloned.b).not.toBe(obj.b); + }); + + test("should handle arrays without circular references", () => { + const arr = [1, 2, [3, 4]]; + const cloned = task.testSmartClone(arr); + + expect(cloned).toEqual(arr); + expect(cloned).not.toBe(arr); + expect(cloned[2]).not.toBe(arr[2]); + }); + + test("should throw error on object with circular self-reference", () => { + const obj: any = { a: 1 }; + obj.self = obj; + + expect(() => task.testSmartClone(obj)).toThrow("Circular reference detected in input data"); + }); + + test("should throw error on nested circular reference", () => { + const obj: any = { a: 1, b: { c: 2 } }; + obj.b.parent = obj; + + expect(() => task.testSmartClone(obj)).toThrow("Circular reference detected in input data"); + }); + + test("should throw error on array with circular reference", () => { + const arr: any = [1, 2, 3]; + arr.push(arr); + + expect(() => task.testSmartClone(arr)).toThrow("Circular reference detected in input data"); + }); + + test("should throw error on complex circular reference chain", () => { + const obj1: any = { name: "obj1" }; + const obj2: any = { name: "obj2", ref: obj1 }; + const obj3: any = { name: "obj3", ref: obj2 }; + obj1.ref = obj3; // Create circular chain + + expect(() => task.testSmartClone(obj1)).toThrow("Circular reference detected in input data"); + }); + + test("should handle same object referenced multiple times (not circular)", () => { + const shared = { value: 42 }; + const obj = { a: shared, b: shared }; + + // This should work - same object referenced multiple times is not circular + // Each reference gets cloned independently + const cloned = task.testSmartClone(obj); + + expect(cloned).toEqual(obj); + expect(cloned.a).toEqual(shared); + expect(cloned.b).toEqual(shared); + // The cloned references should be different objects (deep clone) + expect(cloned.a).not.toBe(shared); + expect(cloned.b).not.toBe(shared); + expect(cloned.a).not.toBe(cloned.b); + }); + + test("should preserve class instances by reference (no circular check needed)", () => { + class CustomClass { + constructor(public value: number) {} + } + + const instance = new CustomClass(42); + const obj = { data: instance }; + + const cloned = task.testSmartClone(obj); + + expect(cloned.data).toBe(instance); // Should be same reference + expect(cloned.data.value).toBe(42); + }); + + test("should clone TypedArrays to avoid shared mutation", () => { + const typedArray = new Float32Array([1.0, 2.0, 3.0]); + const obj = { data: typedArray }; + + const cloned = task.testSmartClone(obj); + + expect(cloned.data).not.toBe(typedArray); // Should be a new instance + expect(cloned.data).toEqual(typedArray); // But with the same values + expect(cloned.data).toBeInstanceOf(Float32Array); + }); + + test("should handle null and undefined", () => { + expect(task.testSmartClone(null)).toBe(null); + expect(task.testSmartClone(undefined)).toBe(undefined); + expect(task.testSmartClone({ a: null, b: undefined })).toEqual({ a: null, b: undefined }); + }); + + test("should handle primitives", () => { + expect(task.testSmartClone(42)).toBe(42); + expect(task.testSmartClone("hello")).toBe("hello"); + expect(task.testSmartClone(true)).toBe(true); + expect(task.testSmartClone(false)).toBe(false); + }); + + test("should clone nested structures without circular references", () => { + const obj = { + level1: { + level2: { + level3: { + value: "deep", + }, + }, + array: [1, 2, { 
nested: true }], + }, + }; + + const cloned = task.testSmartClone(obj); + + expect(cloned).toEqual(obj); + expect(cloned).not.toBe(obj); + expect(cloned.level1).not.toBe(obj.level1); + expect(cloned.level1.level2).not.toBe(obj.level1.level2); + expect(cloned.level1.array).not.toBe(obj.level1.array); + expect(cloned.level1.array[2]).not.toBe(obj.level1.array[2]); + }); + + test("should handle mixed object and array structures", () => { + const obj = { + users: [ + { id: 1, name: "Alice" }, + { id: 2, name: "Bob" }, + ], + settings: { + theme: "dark", + features: ["feature1", "feature2"], + }, + }; + + const cloned = task.testSmartClone(obj); + + expect(cloned).toEqual(obj); + expect(cloned.users).not.toBe(obj.users); + expect(cloned.users[0]).not.toBe(obj.users[0]); + expect(cloned.settings).not.toBe(obj.settings); + expect(cloned.settings.features).not.toBe(obj.settings.features); + }); +}); diff --git a/packages/test/src/test/task/TestTasks.ts b/packages/test/src/test/task/TestTasks.ts index a803799a..297a48e4 100644 --- a/packages/test/src/test/task/TestTasks.ts +++ b/packages/test/src/test/task/TestTasks.ts @@ -97,7 +97,7 @@ export class TestIOTask extends Task { /** * Implementation of full run mode - returns complete results */ - async execute(): Promise { + async execute(_input: TestIOTaskInput, _context: IExecuteContext): Promise { return { all: true, key: "full", reactiveOnly: false }; } } @@ -680,8 +680,11 @@ export class StringTask extends Task<{ input: string }, { output: string }, Task /** * Returns the input string as output */ - async execute() { - return { output: this.runInputData.input }; + async executeReactive( + input: { input: string }, + _output: { output: string } + ): Promise<{ output: string }> { + return { output: input.input }; } } @@ -719,8 +722,8 @@ export class NumberToStringTask extends Task<{ input: number }, { output: string /** * Returns the input string as output */ - async execute() { - return { output: String(this.runInputData.input) }; + async execute(input: { input: number }, _context: IExecuteContext): Promise<{ output: string }> { + return { output: String(input.input) }; } } @@ -759,8 +762,8 @@ export class NumberTask extends Task<{ input: number }, { output: number }, Task /** * Returns the input number as output */ - async execute() { - return { output: this.runInputData.input }; + async execute(input: { input: number }, _context: IExecuteContext): Promise<{ output: number }> { + return { output: input.input }; } } @@ -825,6 +828,190 @@ export class TestAddTask extends Task { } } +/** + * Task that outputs a TypedArray with port name "vector" (singular) + * Used for testing type-only matching with different port names + */ +export class VectorOutputTask extends Task<{ text: string }, { vector: Float32Array }> { + static type = "VectorOutputTask"; + + static inputSchema(): DataPortSchema { + return { + type: "object", + properties: { + text: { + type: "string", + description: "Input text", + }, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + static outputSchema(): DataPortSchema { + return { + type: "object", + properties: { + vector: { + type: "array", + format: "TypedArray", + title: "Vector", + description: "Output vector (singular name)", + }, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + async execute(input: { text: string }): Promise<{ vector: Float32Array }> { + return { vector: new Float32Array([0.1, 0.2, 0.3]) }; + } +} + +/** + * Task that accepts a TypedArray with 
port name "vectors" (plural) + * Used for testing type-only matching with different port names + */ +export class VectorsInputTask extends Task<{ vectors: Float32Array }, { count: number }> { + static type = "VectorsInputTask"; + + static inputSchema(): DataPortSchema { + return { + type: "object", + properties: { + vectors: { + type: "array", + format: "TypedArray", + title: "Vectors", + description: "Input vectors (plural name)", + }, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + static outputSchema(): DataPortSchema { + return { + type: "object", + properties: { + count: { + type: "number", + description: "Length of the vector", + }, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + async execute(input: { vectors: Float32Array }): Promise<{ count: number }> { + return { count: input.vectors.length }; + } +} + +/** + * Task that outputs a TypedArray wrapped in oneOf (like TypeSingleOrArray) + */ +export class VectorOneOfOutputTask extends Task<{ text: string }, { embedding: Float32Array }> { + static type = "VectorOneOfOutputTask"; + + static inputSchema(): DataPortSchema { + return { + type: "object", + properties: { + text: { + type: "string", + description: "Input text", + }, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + static outputSchema(): DataPortSchema { + return { + type: "object", + properties: { + embedding: { + oneOf: [ + { + type: "array", + format: "TypedArray", + title: "Single Embedding", + }, + { + type: "array", + items: { + type: "array", + format: "TypedArray", + }, + title: "Multiple Embeddings", + }, + ], + title: "Embedding", + description: "Output embedding (oneOf wrapper)", + }, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + async execute(input: { text: string }): Promise<{ embedding: Float32Array }> { + return { embedding: new Float32Array([0.4, 0.5, 0.6]) }; + } +} + +/** + * Task that accepts a TypedArray wrapped in anyOf + */ +export class VectorAnyOfInputTask extends Task<{ data: Float32Array }, { sum: number }> { + static type = "VectorAnyOfInputTask"; + + static inputSchema(): DataPortSchema { + return { + type: "object", + properties: { + data: { + anyOf: [ + { + type: "array", + format: "TypedArray", + title: "Single Vector", + }, + { + type: "array", + items: { + type: "array", + format: "TypedArray", + }, + title: "Multiple Vectors", + }, + ], + title: "Data", + description: "Input data (anyOf wrapper)", + }, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + static outputSchema(): DataPortSchema { + return { + type: "object", + properties: { + sum: { + type: "number", + description: "Sum of vector elements", + }, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + async execute(input: { data: Float32Array }): Promise<{ sum: number }> { + return { sum: Array.from(input.data).reduce((a, b) => a + b, 0) }; + } +} + /** * Module augmentation to register test task types in the workflow system */ @@ -839,6 +1026,22 @@ declare module "@workglow/task-graph" { numberToString: CreateWorkflow<{ input: number }, { output: string }, TaskConfig>; number: CreateWorkflow<{ input: number }, { output: number }, TaskConfig>; testAdd: CreateWorkflow; + vectorOutput: CreateWorkflow<{ text: string }, { vector: Float32Array }, TaskConfig>; + vectorsInput: CreateWorkflow<{ vectors: Float32Array }, { count: number }, TaskConfig>; + vectorOneOfOutput: CreateWorkflow<{ text: 
string }, { embedding: Float32Array }, TaskConfig>; + vectorAnyOfInput: CreateWorkflow<{ data: Float32Array }, { sum: number }, TaskConfig>; + textOutput: CreateWorkflow<{ input: string }, { text: string }, TaskConfig>; + vectorOutputOnly: CreateWorkflow<{ size: number }, { vector: Float32Array }, TaskConfig>; + textVectorInput: CreateWorkflow< + { text: string; vector: Float32Array }, + { result: string }, + TaskConfig + >; + passthroughVector: CreateWorkflow< + { vector: Float32Array }, + { vector: Float32Array }, + TaskConfig + >; } } @@ -852,3 +1055,179 @@ Workflow.prototype.string = CreateWorkflow(StringTask); Workflow.prototype.numberToString = CreateWorkflow(NumberToStringTask); Workflow.prototype.number = CreateWorkflow(NumberTask); Workflow.prototype.testAdd = CreateWorkflow(TestAddTask); +Workflow.prototype.vectorOutput = CreateWorkflow(VectorOutputTask); +Workflow.prototype.vectorsInput = CreateWorkflow(VectorsInputTask); +Workflow.prototype.vectorOneOfOutput = CreateWorkflow(VectorOneOfOutputTask); +Workflow.prototype.vectorAnyOfInput = CreateWorkflow(VectorAnyOfInputTask); +/** + * Task that outputs only text - for testing multi-source matching + */ +export class TextOutputTask extends Task<{ input: string }, { text: string }> { + static type = "TextOutputTask"; + + static inputSchema(): DataPortSchema { + return { + type: "object", + properties: { + input: { + type: "string", + description: "Input string", + }, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + static outputSchema(): DataPortSchema { + return { + type: "object", + properties: { + text: { + type: "string", + description: "Output text", + }, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + async execute(input: { input: string }): Promise<{ text: string }> { + return { text: input.input }; + } +} + +/** + * Task that outputs only a vector - for testing multi-source matching + */ +export class VectorOutputOnlyTask extends Task<{ size: number }, { vector: Float32Array }> { + static type = "VectorOutputOnlyTask"; + + static inputSchema(): DataPortSchema { + return { + type: "object", + properties: { + size: { + type: "number", + description: "Vector size", + default: 3, + }, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + static outputSchema(): DataPortSchema { + return { + type: "object", + properties: { + vector: { + type: "array", + format: "TypedArray", + title: "Vector", + description: "Output vector", + }, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + async execute(input: { size: number }): Promise<{ vector: Float32Array }> { + return { vector: new Float32Array(input.size || 3).fill(1.0) }; + } +} + +/** + * Task that requires both text and vector inputs - for testing multi-source matching + */ +export class TextVectorInputTask extends Task< + { text: string; vector: Float32Array }, + { result: string } +> { + static type = "TextVectorInputTask"; + + static inputSchema(): DataPortSchema { + return { + type: "object", + properties: { + text: { + type: "string", + description: "Input text", + }, + vector: { + type: "array", + items: { type: "number" }, + format: "TypedArray", + title: "Vector", + description: "Input vector", + }, + }, + required: ["text", "vector"], + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + static outputSchema(): DataPortSchema { + return { + type: "object", + properties: { + result: { + type: "string", + description: 
"Combined result", + }, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + async execute(input: { text: string; vector: Float32Array }): Promise<{ result: string }> { + return { result: `${input.text} with vector of length ${input.vector.length}` }; + } +} + +/** + * Task that passes through a vector - for testing multi-hop matching + */ +export class PassthroughVectorTask extends Task< + { vector: Float32Array }, + { vector: Float32Array } +> { + static type = "PassthroughVectorTask"; + + static inputSchema(): DataPortSchema { + return { + type: "object", + properties: { + vector: { + type: "array", + format: "TypedArray", + title: "Vector", + description: "Input vector", + }, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + static outputSchema(): DataPortSchema { + return { + type: "object", + properties: { + vector: { + type: "array", + format: "TypedArray", + title: "Vector", + description: "Output vector (passthrough)", + }, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + async execute(input: { vector: Float32Array }): Promise<{ vector: Float32Array }> { + return { vector: input.vector }; + } +} +Workflow.prototype.textOutput = CreateWorkflow(TextOutputTask); +Workflow.prototype.vectorOutputOnly = CreateWorkflow(VectorOutputOnlyTask); +Workflow.prototype.textVectorInput = CreateWorkflow(TextVectorInputTask); +Workflow.prototype.passthroughVector = CreateWorkflow(PassthroughVectorTask); diff --git a/packages/util/package.json b/packages/util/package.json index 07d311be..bb841966 100644 --- a/packages/util/package.json +++ b/packages/util/package.json @@ -36,7 +36,7 @@ "access": "public" }, "dependencies": { - "json-schema-library": "^10.5.1", + "@sroussey/json-schema-library": "^10.5.3", "@sroussey/json-schema-to-ts": "3.1.3" } } \ No newline at end of file diff --git a/packages/util/src/json-schema/SchemaValidation.ts b/packages/util/src/json-schema/SchemaValidation.ts index 4a6b2d5f..7e1585cc 100644 --- a/packages/util/src/json-schema/SchemaValidation.ts +++ b/packages/util/src/json-schema/SchemaValidation.ts @@ -4,5 +4,5 @@ * SPDX-License-Identifier: Apache-2.0 */ -export { compileSchema } from "json-schema-library"; -export type { SchemaNode } from "json-schema-library"; +export { compileSchema } from "@sroussey/json-schema-library"; +export type { SchemaNode } from "@sroussey/json-schema-library"; From 43347b4e7f62dd113a16900ff4e3b3c94b3483ff Mon Sep 17 00:00:00 2001 From: Steven Roussey Date: Sun, 11 Jan 2026 07:17:47 +0000 Subject: [PATCH 03/14] [feat] Implement Input Resolver System for Schema-Based Resolution - Introduced an input resolver registry to automatically resolve string identifiers to object instances based on JSON Schema format annotations. - Enhanced the TaskRunner to utilize the input resolver for resolving model names and repository IDs before task execution. - Registered custom resolvers for various formats, improving flexibility in task configuration. - Updated documentation to reflect the new input resolution capabilities and usage examples. 
--- docs/developers/03_extending.md | 241 ++++++- packages/ai/README.md | 199 +++++- packages/ai/src/common.ts | 6 +- packages/ai/src/model/ModelRegistry.ts | 42 +- packages/ai/src/model/ModelRepository.ts | 9 +- packages/ai/src/source/Document.ts | 173 ----- packages/ai/src/source/DocumentConverter.ts | 18 - .../src/source/DocumentConverterMarkdown.ts | 120 ---- .../ai/src/source/DocumentConverterText.ts | 20 - packages/ai/src/source/MasterDocument.ts | 50 -- packages/ai/src/source/index.ts | 10 - packages/ai/src/task/DocumentSplitterTask.ts | 98 --- packages/ai/src/task/VectorSimilarityTask.ts | 81 +-- packages/ai/src/task/base/AiTask.ts | 222 ++---- packages/ai/src/task/base/AiTaskSchemas.ts | 221 +----- packages/ai/src/task/base/AiVisionTask.ts | 6 +- .../debug/src/console/ConsoleFormatters.ts | 12 +- packages/storage/README.md | 90 +++ packages/storage/src/common.ts | 12 +- packages/storage/src/document/Document.ts | 81 +++ packages/storage/src/document/DocumentNode.ts | 134 ++++ .../src/document/DocumentRepository.ts | 222 ++++++ .../document/DocumentRepositoryRegistry.ts | 79 +++ .../storage/src/document/DocumentSchema.ts | 630 ++++++++++++++++++ .../src/document/DocumentStorageSchema.ts | 43 ++ .../storage/src/document/StructuralParser.ts | 254 +++++++ .../storage/src/tabular/ITabularRepository.ts | 26 +- packages/storage/src/tabular/README.md | 2 +- .../src/tabular/TabularRepositoryRegistry.ts | 79 +++ packages/storage/src/util/RepositorySchema.ts | 96 +++ packages/task-graph/src/common.ts | 1 + .../task-graph/src/task-graph/TaskGraph.ts | 4 +- .../src/task-graph/TaskGraphRunner.ts | 23 +- packages/task-graph/src/task/ITask.ts | 7 +- packages/task-graph/src/task/InputResolver.ts | 113 ++++ packages/task-graph/src/task/README.md | 70 ++ packages/task-graph/src/task/Task.ts | 7 +- packages/task-graph/src/task/TaskRunner.ts | 31 +- packages/test/src/test/rag/Document.test.ts | 52 ++ .../rag/DocumentNodeRetrievalTask.test.ts | 295 ++++++++ .../rag/DocumentNodeVectorSearchTask.test.ts | 254 +++++++ .../DocumentNodeVectorStoreUpsertTask.test.ts | 228 +++++++ .../src/test/rag/DocumentRepository.test.ts | 484 ++++++++++++++ packages/util/README.md | 32 + packages/util/src/common.ts | 4 + packages/util/src/di/InputResolverRegistry.ts | 83 +++ packages/util/src/di/ServiceRegistry.ts | 2 +- packages/util/src/di/index.ts | 1 + packages/util/src/vector/Tensor.ts | 62 ++ packages/util/src/vector/TypedArray.ts | 95 +++ .../util/src/vector/VectorSimilarityUtils.ts | 92 +++ packages/util/src/vector/VectorUtils.ts | 95 +++ 52 files changed, 4327 insertions(+), 984 deletions(-) delete mode 100644 packages/ai/src/source/Document.ts delete mode 100644 packages/ai/src/source/DocumentConverter.ts delete mode 100644 packages/ai/src/source/DocumentConverterMarkdown.ts delete mode 100644 packages/ai/src/source/DocumentConverterText.ts delete mode 100644 packages/ai/src/source/MasterDocument.ts delete mode 100644 packages/ai/src/source/index.ts delete mode 100644 packages/ai/src/task/DocumentSplitterTask.ts create mode 100644 packages/storage/src/document/Document.ts create mode 100644 packages/storage/src/document/DocumentNode.ts create mode 100644 packages/storage/src/document/DocumentRepository.ts create mode 100644 packages/storage/src/document/DocumentRepositoryRegistry.ts create mode 100644 packages/storage/src/document/DocumentSchema.ts create mode 100644 packages/storage/src/document/DocumentStorageSchema.ts create mode 100644 packages/storage/src/document/StructuralParser.ts create mode 100644 
packages/storage/src/tabular/TabularRepositoryRegistry.ts create mode 100644 packages/storage/src/util/RepositorySchema.ts create mode 100644 packages/task-graph/src/task/InputResolver.ts create mode 100644 packages/test/src/test/rag/Document.test.ts create mode 100644 packages/test/src/test/rag/DocumentNodeRetrievalTask.test.ts create mode 100644 packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts create mode 100644 packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts create mode 100644 packages/test/src/test/rag/DocumentRepository.test.ts create mode 100644 packages/util/src/di/InputResolverRegistry.ts create mode 100644 packages/util/src/vector/Tensor.ts create mode 100644 packages/util/src/vector/TypedArray.ts create mode 100644 packages/util/src/vector/VectorSimilarityUtils.ts create mode 100644 packages/util/src/vector/VectorUtils.ts diff --git a/docs/developers/03_extending.md b/docs/developers/03_extending.md index 8f740c5d..add9693a 100644 --- a/docs/developers/03_extending.md +++ b/docs/developers/03_extending.md @@ -6,6 +6,7 @@ This document covers how to write your own tasks. For a more practical guide to - [Tasks must have a `run()` method](#tasks-must-have-a-run-method) - [Define Inputs and Outputs](#define-inputs-and-outputs) - [Register the Task](#register-the-task) +- [Schema Format Annotations](#schema-format-annotations) - [Job Queues and LLM tasks](#job-queues-and-llm-tasks) - [Write a new Compound Task](#write-a-new-compound-task) - [Reactive Task UIs](#reactive-task-uis) @@ -117,7 +118,7 @@ To use the Task in Workflow, there are a few steps: ```ts export const simpleDebug = (input: DebugLogTaskInput) => { - return new SimpleDebugTask(input).run(); + return new SimpleDebugTask({} as DebugLogTaskInput, {}).run(input); }; declare module "@workglow/task-graph" { @@ -129,6 +130,103 @@ declare module "@workglow/task-graph" { Workflow.prototype.simpleDebug = CreateWorkflow(SimpleDebugTask); ``` +## Schema Format Annotations + +When defining task input schemas, you can use `format` annotations to enable automatic resolution of string identifiers to object instances. The TaskRunner inspects input schemas and resolves annotated string values before task execution. 
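+
+Conceptually, the resolution pass walks the input schema's `properties`: for each property whose `format` matches a registered resolver prefix and whose runtime value is a string, the value is replaced by the resolver's result. The sketch below illustrates that idea only; it is not the actual `InputResolver` implementation (which also handles arrays and `oneOf`/`anyOf` wrappers), and `resolveInputs` plus its local resolver map are hypothetical names.
+
+```typescript
+import type { ServiceRegistry } from "@workglow/util";
+
+type Resolver = (id: string, format: string, registry: ServiceRegistry) => Promise<unknown>;
+// Resolvers are keyed by format prefix, e.g. "model" or "repository".
+const resolvers = new Map<string, Resolver>();
+
+async function resolveInputs(
+  schema: { properties?: Record<string, { format?: string }> },
+  input: Record<string, unknown>,
+  registry: ServiceRegistry
+): Promise<Record<string, unknown>> {
+  const resolved = { ...input };
+  for (const [key, propSchema] of Object.entries(schema.properties ?? {})) {
+    const format = propSchema.format; // e.g. "model:TextGenerationTask"
+    const value = resolved[key];
+    if (!format || typeof value !== "string") continue; // only annotated strings are resolved
+    const prefix = format.split(":")[0]; // match a resolver by its prefix
+    const resolver = resolvers.get(prefix);
+    if (resolver) resolved[key] = await resolver(value, format, registry);
+  }
+  return resolved;
+}
+```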
+ +### Built-in Format Annotations + +The system supports several format annotations out of the box: + +| Format | Description | Helper Function | +| --------------------------------- | ----------------------------------- | ------------------------------------ | +| `model` | Any AI model configuration | `TypeModel()` | +| `model:TaskName` | Model compatible with specific task | — | +| `repository:tabular` | Tabular data repository | `TypeTabularRepository()` | +| `repository:document-node-vector` | Vector storage repository | `TypeDocumentNodeVectorRepository()` | +| `repository:document` | Document repository | `TypeDocumentRepository()` | + +### Example: Using Format Annotations + +```typescript +import { Task, type DataPortSchema } from "@workglow/task-graph"; +import { TypeTabularRepository } from "@workglow/storage"; +import { FromSchema } from "@workglow/util"; + +const MyTaskInputSchema = { + type: "object", + properties: { + // Model input - accepts string ID or ModelConfig object + model: { + title: "AI Model", + description: "Model for text generation", + format: "model:TextGenerationTask", + oneOf: [ + { type: "string", title: "Model ID" }, + { type: "object", title: "Model Config" }, + ], + }, + // Repository input - uses helper function + dataSource: TypeTabularRepository({ + title: "Data Source", + description: "Repository containing source data", + }), + // Regular string input (no resolution) + prompt: { type: "string", title: "Prompt" }, + }, + required: ["model", "dataSource", "prompt"], +} as const satisfies DataPortSchema; + +type MyTaskInput = FromSchema; + +export class MyTask extends Task { + static readonly type = "MyTask"; + static inputSchema = () => MyTaskInputSchema; + + async executeReactive(input: MyTaskInput) { + // By the time execute runs, model is a ModelConfig object + // and dataSource is an ITabularRepository instance + const { model, dataSource, prompt } = input; + // ... + } +} +``` + +### Creating Custom Format Resolvers + +You can extend the resolution system by registering custom resolvers: + +```typescript +import { registerInputResolver } from "@workglow/util"; + +// Register a resolver for "template:*" formats +registerInputResolver("template", async (id, format, registry) => { + const templateRepo = registry.get(TEMPLATE_REPOSITORY); + const template = await templateRepo.findById(id); + if (!template) { + throw new Error(`Template "${id}" not found`); + } + return template; +}); +``` + +Then use it in your schemas: + +```typescript +const inputSchema = { + type: "object", + properties: { + emailTemplate: { + type: "string", + format: "template:email", + title: "Email Template", + }, + }, +}; +``` + +When a task runs with `{ emailTemplate: "welcome-email" }`, the resolver automatically converts it to the template object before execution. + ## Job Queues and LLM tasks We separate any long running tasks as Jobs. Jobs could potentially be run anywhere, either locally in the same thread, in separate threads, or on a remote server. A job queue will manage these for a single provider (like OpenAI, or a local Transformers.js ONNX runtime), and handle backoff, retries, etc. @@ -148,3 +246,144 @@ Compound Tasks are not cached (though any or all of their children may be). ## Reactive Task UIs Tasks can be reactive at a certain level. This means that they can be triggered by changes in the data they depend on, without "running" the expensive job based task runs. This is useful for a UI node editor. 
For example, you change a color in one task and it is propagated downstream without incurring costs for re-running the entire graph. It is like a spreadsheet where changing a cell can trigger a recalculation of other cells. This is implemented via a `runReactive()` method that is called when the data changes. Typically, the `run()` will call `runReactive()` on itself at the end of the method. + +## AI and RAG Tasks + +The `@workglow/ai` package provides a comprehensive set of tasks for building RAG (Retrieval-Augmented Generation) pipelines. These tasks are designed to chain together in workflows without requiring external loops. + +### Document Processing Tasks + +| Task | Description | +| ------------------------- | ----------------------------------------------------- | +| `StructuralParserTask` | Parses markdown/text into hierarchical document trees | +| `TextChunkerTask` | Splits text into chunks with configurable strategies | +| `HierarchicalChunkerTask` | Token-aware chunking that respects document structure | +| `TopicSegmenterTask` | Segments text by topic using heuristics or embeddings | +| `DocumentEnricherTask` | Adds summaries and entities to document nodes | + +### Vector and Embedding Tasks + +| Task | Description | +| ------------------------------ | ---------------------------------------------- | +| `TextEmbeddingTask` | Generates embeddings using configurable models | +| `ChunkToVectorTask` | Transforms chunks to vector store format | +| `DocumentNodeVectorUpsertTask` | Stores vectors in a repository | +| `DocumentNodeVectorSearchTask` | Searches vectors by similarity | +| `VectorQuantizeTask` | Quantizes vectors for storage efficiency | + +### Retrieval and Generation Tasks + +| Task | Description | +| ------------------------------------ | --------------------------------------------- | +| `QueryExpanderTask` | Expands queries for better retrieval coverage | +| `DocumentNodeVectorHybridSearchTask` | Combines vector and full-text search | +| `RerankerTask` | Reranks search results for relevance | +| `HierarchyJoinTask` | Enriches results with parent context | +| `ContextBuilderTask` | Builds context for LLM prompts | +| `DocumentNodeRetrievalTask` | Orchestrates end-to-end retrieval | +| `TextQuestionAnswerTask` | Generates answers from context | +| `TextGenerationTask` | General text generation | + +### Chainable RAG Pipeline Example + +Tasks chain together through compatible input/output schemas: + +```typescript +import { Workflow } from "@workglow/task-graph"; +import { InMemoryVectorRepository } from "@workglow/storage"; + +const vectorRepo = new InMemoryVectorRepository(); +await vectorRepo.setupDatabase(); + +// Document ingestion pipeline +await new Workflow() + .structuralParser({ + text: markdownContent, + title: "My Document", + format: "markdown", + }) + .documentEnricher({ + generateSummaries: true, + extractEntities: true, + }) + .hierarchicalChunker({ + maxTokens: 512, + overlap: 50, + strategy: "hierarchical", + }) + .textEmbedding({ + model: "Xenova/all-MiniLM-L6-v2", + }) + .chunkToVector() + .vectorStoreUpsert({ + repository: vectorRepo, + }) + .run(); +``` + +### Retrieval Pipeline Example + +```typescript +const answer = await new Workflow() + .textEmbedding({ + text: query, + model: "Xenova/all-MiniLM-L6-v2", + }) + .vectorStoreSearch({ + repository: vectorRepo, + topK: 10, + }) + .reranker({ + query, + topK: 5, + }) + .contextBuilder({ + format: "markdown", + maxLength: 2000, + }) + .textQuestionAnswer({ + question: query, + model: 
"Xenova/LaMini-Flan-T5-783M", + }) + .run(); +``` + +### Hierarchical Document Structure + +Documents are represented as trees with typed nodes: + +```typescript +type DocumentNode = + | DocumentRootNode // Root of document + | SectionNode // Headers, structural sections + | ParagraphNode // Text blocks + | SentenceNode // Fine-grained (optional) + | TopicNode; // Detected topic segments + +// Each node contains: +interface BaseNode { + nodeId: string; // Deterministic content-based ID + range: { start: number; end: number }; + text: string; + enrichment?: { + summary?: string; + entities?: Entity[]; + keywords?: string[]; + }; +} +``` + +### Task Data Flow + +Each task passes through what the next task needs: + +| Task | Passes Through | Adds | +| --------------------- | ------------------------ | ------------------------------------- | +| `structuralParser` | - | `doc_id`, `documentTree`, `nodeCount` | +| `documentEnricher` | `doc_id`, `documentTree` | `summaryCount`, `entityCount` | +| `hierarchicalChunker` | `doc_id` | `chunks`, `text[]`, `count` | +| `textEmbedding` | (implicit) | `vector[]` | +| `chunkToVector` | - | `ids[]`, `vectors[]`, `metadata[]` | +| `vectorStoreUpsert` | - | `count`, `ids` | + +This design eliminates the need for external loops - the entire pipeline chains together naturally. diff --git a/packages/ai/README.md b/packages/ai/README.md index d8118516..a1bf267b 100644 --- a/packages/ai/README.md +++ b/packages/ai/README.md @@ -216,25 +216,6 @@ const result = await task.run(); // Output: { similarity: 0.85 } ``` -### Document Processing Tasks - -#### DocumentSplitterTask - -Splits documents into smaller chunks for processing. - -```typescript -import { DocumentSplitterTask } from "@workglow/ai"; - -const task = new DocumentSplitterTask({ - document: "Very long document content...", - chunkSize: 1000, - chunkOverlap: 200, -}); - -const result = await task.run(); -// Output: { chunks: ["chunk1...", "chunk2...", "chunk3..."] } -``` - ### Model Management Tasks #### DownloadModelTask @@ -415,30 +396,140 @@ const result = await workflow console.log("Final similarity score:", result.similarity); ``` -## Document Processing +## RAG (Retrieval-Augmented Generation) Pipelines + +The AI package provides a comprehensive set of tasks for building RAG pipelines. These tasks chain together in workflows without requiring external loops. 
+ +### Document Processing Tasks + +| Task | Description | +| ------------------------- | ----------------------------------------------------- | +| `StructuralParserTask` | Parses markdown/text into hierarchical document trees | +| `TextChunkerTask` | Splits text into chunks with configurable strategies | +| `HierarchicalChunkerTask` | Token-aware chunking that respects document structure | +| `TopicSegmenterTask` | Segments text by topic using heuristics or embeddings | +| `DocumentEnricherTask` | Adds summaries and entities to document nodes | + +### Vector and Storage Tasks -The package includes document processing capabilities: +| Task | Description | +| ------------------------------ | ---------------------------------------- | +| `ChunkToVectorTask` | Transforms chunks to vector store format | +| `DocumentNodeVectorUpsertTask` | Stores vectors in a repository | +| `DocumentNodeVectorSearchTask` | Searches vectors by similarity | +| `VectorQuantizeTask` | Quantizes vectors for storage efficiency | + +### Retrieval and Generation Tasks + +| Task | Description | +| ------------------------------------ | --------------------------------------------- | +| `QueryExpanderTask` | Expands queries for better retrieval coverage | +| `DocumentNodeVectorHybridSearchTask` | Combines vector and full-text search | +| `RerankerTask` | Reranks search results for relevance | +| `HierarchyJoinTask` | Enriches results with parent context | +| `ContextBuilderTask` | Builds context for LLM prompts | +| `DocumentNodeRetrievalTask` | Orchestrates end-to-end retrieval | + +### Complete RAG Workflow Example ```typescript -import { Document, DocumentConverterMarkdown } from "@workglow/ai"; +import { Workflow } from "@workglow/task-graph"; +import { InMemoryVectorRepository } from "@workglow/storage"; -// Create a document -const doc = new Document("# My Document\n\nThis is content...", { title: "Sample Doc" }); +const vectorRepo = new InMemoryVectorRepository(); +await vectorRepo.setupDatabase(); -// Convert markdown to structured format -const converter = new DocumentConverterMarkdown(); -const processedDoc = await converter.convert(doc); +// Document ingestion - fully chainable, no loops required +await new Workflow() + .structuralParser({ + text: markdownContent, + title: "Documentation", + format: "markdown", + }) + .documentEnricher({ + generateSummaries: true, + extractEntities: true, + }) + .hierarchicalChunker({ + maxTokens: 512, + overlap: 50, + strategy: "hierarchical", + }) + .textEmbedding({ + model: "Xenova/all-MiniLM-L6-v2", + }) + .chunkToVector() + .vectorStoreUpsert({ + repository: vectorRepo, + }) + .run(); -// Use with document splitter -const splitter = new DocumentSplitterTask({ - document: processedDoc.content, - chunkSize: 500, - chunkOverlap: 50, -}); +// Query pipeline +const answer = await new Workflow() + .queryExpander({ + query: "What is transfer learning?", + method: "multi-query", + numVariations: 3, + }) + .textEmbedding({ + model: "Xenova/all-MiniLM-L6-v2", + }) + .vectorStoreSearch({ + repository: vectorRepo, + topK: 10, + scoreThreshold: 0.5, + }) + .reranker({ + query: "What is transfer learning?", + topK: 5, + }) + .contextBuilder({ + format: "markdown", + maxLength: 2000, + }) + .textQuestionAnswer({ + question: "What is transfer learning?", + model: "Xenova/LaMini-Flan-T5-783M", + }) + .run(); +``` + +### Hierarchical Document Structure -const chunks = await splitter.run(); +Documents are represented as trees with typed nodes: + +```typescript +type DocumentNode = + | 
DocumentRootNode // Root of document + | SectionNode // Headers, structural sections + | ParagraphNode // Text blocks + | SentenceNode // Fine-grained (optional) + | TopicNode; // Detected topic segments ``` +Each node contains: + +- `nodeId` - Deterministic content-based ID +- `range` - Source character offsets +- `text` - Content +- `enrichment` - Summaries, entities, keywords (optional) +- `children` - Child nodes (for parent nodes) + +### Task Data Flow + +Each task passes through what the next task needs: + +| Task | Passes Through | Adds | +| --------------------- | ------------------------ | ------------------------------------- | +| `structuralParser` | - | `doc_id`, `documentTree`, `nodeCount` | +| `documentEnricher` | `doc_id`, `documentTree` | `summaryCount`, `entityCount` | +| `hierarchicalChunker` | `doc_id` | `chunks`, `text[]`, `count` | +| `textEmbedding` | (implicit) | `vector[]` | +| `chunkToVector` | - | `ids[]`, `vectors[]`, `metadata[]` | +| `vectorStoreUpsert` | - | `count`, `ids` | + +This design eliminates the need for external loops - the entire pipeline chains together naturally. + ## Error Handling AI tasks include comprehensive error handling: @@ -466,6 +557,46 @@ try { ## Advanced Configuration +### Model Input Resolution + +AI tasks accept model inputs as either string identifiers or direct `ModelConfig` objects. When a string is provided, the TaskRunner automatically resolves it to a `ModelConfig` before task execution using the `ModelRepository`. + +```typescript +import { TextGenerationTask } from "@workglow/ai"; + +// Using a model ID (resolved from ModelRepository) +const task1 = new TextGenerationTask({ + model: "onnx:Xenova/gpt2:q8", + prompt: "Generate text", +}); + +// Using a direct ModelConfig object +const task2 = new TextGenerationTask({ + model: { + model_id: "onnx:Xenova/gpt2:q8", + provider: "hf-transformers-onnx", + tasks: ["TextGenerationTask"], + title: "GPT-2", + provider_config: { pipeline: "text-generation" }, + }, + prompt: "Generate text", +}); + +// Both approaches work identically +``` + +This resolution is handled by the input resolver system, which inspects schema `format` annotations (like `"model"` or `"model:TextGenerationTask"`) to determine how string values should be resolved. 
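+
+Conceptually, the `model` resolver is a thin wrapper around the model repository. The sketch below approximates what happens for a string input; the actual resolver also accepts arrays of model IDs and will use a repository registered on the task's service registry before falling back to the global one.
+
+```typescript
+import { getGlobalModelRepository } from "@workglow/ai";
+
+// Roughly what the "model" format resolver does with a string input:
+async function resolveModelId(id: string) {
+  const model = await getGlobalModelRepository().findByName(id);
+  if (!model) {
+    throw new Error(`Model "${id}" not found in repository`);
+  }
+  return model; // a ModelConfig object, passed to the task in place of the string
+}
+```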
+ +### Supported Format Annotations + +| Format | Description | Resolver | +| --------------------------------- | ---------------------------------------- | -------------------------- | +| `model` | Any AI model configuration | ModelRepository | +| `model:TaskName` | Model compatible with specific task type | ModelRepository | +| `repository:tabular` | Tabular data repository | TabularRepositoryRegistry | +| `repository:document-node-vector` | Vector storage repository | VectorRepositoryRegistry | +| `repository:document` | Document repository | DocumentRepositoryRegistry | + ### Custom Model Validation Tasks automatically validate that specified models exist and are compatible: diff --git a/packages/ai/src/common.ts b/packages/ai/src/common.ts index 12dbd7b9..d86b50c9 100644 --- a/packages/ai/src/common.ts +++ b/packages/ai/src/common.ts @@ -5,12 +5,12 @@ */ export * from "./job/AiJob"; + export * from "./model/InMemoryModelRepository"; export * from "./model/ModelRegistry"; export * from "./model/ModelRepository"; export * from "./model/ModelSchema"; + export * from "./provider/AiProviderRegistry"; -export * from "./source/Document"; -export * from "./source/DocumentConverterMarkdown"; -export * from "./source/DocumentConverterText"; + export * from "./task"; diff --git a/packages/ai/src/model/ModelRegistry.ts b/packages/ai/src/model/ModelRegistry.ts index 0d162955..2a7a3deb 100644 --- a/packages/ai/src/model/ModelRegistry.ts +++ b/packages/ai/src/model/ModelRegistry.ts @@ -4,9 +4,15 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { createServiceToken, globalServiceRegistry } from "@workglow/util"; +import { + createServiceToken, + globalServiceRegistry, + registerInputResolver, + ServiceRegistry, +} from "@workglow/util"; import { InMemoryModelRepository } from "./InMemoryModelRepository"; import { ModelRepository } from "./ModelRepository"; +import type { ModelConfig } from "./ModelSchema"; /** * Service token for the global model repository @@ -32,8 +38,36 @@ export function getGlobalModelRepository(): ModelRepository { /** * Sets the global model repository instance - * @param pr The model repository instance to register + * @param repository The model repository instance to register */ -export function setGlobalModelRepository(pr: ModelRepository): void { - globalServiceRegistry.registerInstance(MODEL_REPOSITORY, pr); +export function setGlobalModelRepository(repository: ModelRepository): void { + globalServiceRegistry.registerInstance(MODEL_REPOSITORY, repository); } + +/** + * Resolves a model ID to a ModelConfig from the repository. + * Used by the input resolver system. + */ +async function resolveModelFromRegistry( + id: string, + format: string, + registry: ServiceRegistry +): Promise { + const modelRepo = registry.has(MODEL_REPOSITORY) + ? 
registry.get(MODEL_REPOSITORY) + : getGlobalModelRepository(); + + if (Array.isArray(id)) { + const results = await Promise.all(id.map((i) => modelRepo.findByName(i))); + return results.filter((model) => model !== undefined) as ModelConfig[]; + } + + const model = await modelRepo.findByName(id); + if (!model) { + throw new Error(`Model "${id}" not found in repository`); + } + return model; +} + +// Register the model resolver for format: "model" and "model:*" +registerInputResolver("model", resolveModelFromRegistry); diff --git a/packages/ai/src/model/ModelRepository.ts b/packages/ai/src/model/ModelRepository.ts index 9f0b7ae4..234fb168 100644 --- a/packages/ai/src/model/ModelRepository.ts +++ b/packages/ai/src/model/ModelRepository.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { type TabularRepository } from "@workglow/storage"; +import { type BaseTabularRepository } from "@workglow/storage"; import { EventEmitter, EventParameters } from "@workglow/util"; import { ModelPrimaryKeyNames, ModelRecord, ModelRecordSchema } from "./ModelSchema"; @@ -37,12 +37,15 @@ export class ModelRepository { /** * Repository for storing and managing Model instances */ - protected readonly modelTabularRepository: TabularRepository< + protected readonly modelTabularRepository: BaseTabularRepository< typeof ModelRecordSchema, typeof ModelPrimaryKeyNames >; constructor( - modelTabularRepository: TabularRepository + modelTabularRepository: BaseTabularRepository< + typeof ModelRecordSchema, + typeof ModelPrimaryKeyNames + > ) { this.modelTabularRepository = modelTabularRepository; } diff --git a/packages/ai/src/source/Document.ts b/packages/ai/src/source/Document.ts deleted file mode 100644 index a58a1c0a..00000000 --- a/packages/ai/src/source/Document.ts +++ /dev/null @@ -1,173 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -enum DocumentType { - DOCUMENT = "document", - SECTION = "section", - TEXT = "text", - IMAGE = "image", - TABLE = "table", -} - -const doc_variants = [ - "tree", - "flat", - "tree-paragraphs", - "flat-paragraphs", - "tree-sentences", - "flat-sentences", -] as const; -type DocVariant = (typeof doc_variants)[number]; -const doc_parsers = ["txt", "md"] as const; // | "html" | "pdf" | "csv"; -type DocParser = (typeof doc_parsers)[number]; - -export interface DocumentMetadata { - title: string; -} - -export interface DocumentSectionMetadata { - title: string; -} - -/** - * Represents a document with its content and metadata. 
- */ -export class Document { - public metadata: DocumentMetadata; - - constructor(content?: ContentType, metadata: DocumentMetadata = { title: "" }) { - this.metadata = metadata; - if (content) { - if (Array.isArray(content)) { - for (const line of content) { - this.addContent(line); - } - } else { - this.addContent(content); - } - } - } - - public addContent(content: ContentTypeItem) { - if (typeof content === "string") { - this.addText(content); - } else if (content instanceof DocumentBaseFragment || content instanceof DocumentSection) { - this.fragments.push(content); - } else { - throw new Error("Unknown content type"); - } - } - - public addSection(content?: ContentType, metadata?: DocumentSectionMetadata): DocumentSection { - const section = new DocumentSection(this, content, metadata); - this.fragments.push(section); - return section; - } - - public addText(content: string): TextFragment { - const f = new TextFragment(content); - this.fragments.push(f); - return f; - } - public addImage(content: unknown): ImageFragment { - const f = new ImageFragment(content); - this.fragments.push(f); - return f; - } - public addTable(content: unknown): TableFragment { - const f = new TableFragment(content); - this.fragments.push(f); - return f; - } - - public fragments: Array = []; - - toJSON(): unknown { - return { - type: DocumentType.DOCUMENT, - metadata: this.metadata, - fragments: this.fragments.map((f) => f.toJSON()), - }; - } -} - -export class DocumentSection extends Document { - constructor( - public parent: Document, - content?: ContentType, - metadata?: DocumentSectionMetadata - ) { - super(content, metadata); - this.parent = parent; - } - - toJSON(): unknown { - return { - type: DocumentType.SECTION, - metadata: this.metadata, - fragments: this.fragments.map((f) => f.toJSON()), - }; - } -} - -interface DocumentFragmentMetadata {} - -export class DocumentBaseFragment { - metadata?: DocumentFragmentMetadata; - constructor(metadata?: DocumentFragmentMetadata) { - this.metadata = metadata; - } -} - -export class TextFragment extends DocumentBaseFragment { - content: string; - constructor(content: string, metadata?: DocumentFragmentMetadata) { - super(metadata); - this.content = content; - } - toJSON(): unknown { - return { - type: DocumentType.TEXT, - metadata: this.metadata, - content: this.content, - }; - } -} - -export class TableFragment extends DocumentBaseFragment { - content: any; - constructor(content: any, metadata?: DocumentFragmentMetadata) { - super(metadata); - this.content = content; - } - toJSON(): unknown { - return { - type: DocumentType.TABLE, - metadata: this.metadata, - content: this.content, - }; - } -} - -export class ImageFragment extends DocumentBaseFragment { - content: any; - constructor(content: any, metadata?: DocumentFragmentMetadata) { - super(metadata); - this.content = content; - } - toJSON(): unknown { - return { - type: DocumentType.IMAGE, - metadata: this.metadata, - content: this.content, - }; - } -} - -export type DocumentFragment = TextFragment | TableFragment | ImageFragment; - -export type ContentTypeItem = string | DocumentFragment | DocumentSection; -export type ContentType = ContentTypeItem | ContentTypeItem[]; diff --git a/packages/ai/src/source/DocumentConverter.ts b/packages/ai/src/source/DocumentConverter.ts deleted file mode 100644 index b89ba6ca..00000000 --- a/packages/ai/src/source/DocumentConverter.ts +++ /dev/null @@ -1,18 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import 
{ Document, DocumentMetadata } from "./Document"; - -/** - * Abstract class for converting different types of content into a Document. - */ -export abstract class DocumentConverter { - public metadata: DocumentMetadata; - constructor(metadata: DocumentMetadata) { - this.metadata = metadata; - } - public abstract convert(): Document; -} diff --git a/packages/ai/src/source/DocumentConverterMarkdown.ts b/packages/ai/src/source/DocumentConverterMarkdown.ts deleted file mode 100644 index 55e88330..00000000 --- a/packages/ai/src/source/DocumentConverterMarkdown.ts +++ /dev/null @@ -1,120 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import { Document, type DocumentMetadata, type DocumentSection } from "./Document"; -import { DocumentConverter } from "./DocumentConverter"; - -export class DocumentConverterMarkdown extends DocumentConverter { - constructor( - metadata: DocumentMetadata, - public markdown: string - ) { - super(metadata); - } - public convert(): Document { - const parser = new MarkdownParser(this.metadata.title); - const document = parser.parse(this.markdown); - return document; - } -} - -class MarkdownParser { - private document: Document; - private currentSection: Document | DocumentSection; - private textBuffer: string[] = []; // Buffer to accumulate text lines - - constructor(title: string) { - this.document = new Document(title); - this.currentSection = this.document; - } - - parse(markdown: string): Document { - const lines = markdown.split("\n"); - - lines.forEach((line, index) => { - if (this.isHeader(line)) { - this.flushTextBuffer(); - const { level, content } = this.parseHeader(line); - this.currentSection = - level === 1 ? this.document.addSection(content) : this.currentSection.addSection(content); - } else if (this.isTableStart(line)) { - this.flushTextBuffer(); - const tableLines = this.collectTableLines(lines, index); - this.currentSection.addTable(tableLines.join("\n")); - } else if (this.isImageInline(line)) { - this.parseLineWithPossibleImages(line); - } else { - this.textBuffer.push(line); // Accumulate text lines in the buffer - } - }); - - this.flushTextBuffer(); // Flush any remaining text in the buffer - return this.document; - } - - private flushTextBuffer() { - if (this.textBuffer.length > 0) { - const textContent = this.textBuffer.join("\n").trim(); - if (textContent) { - this.currentSection.addText(textContent); - } - this.textBuffer = []; // Clear the buffer after flushing - } - } - - private parseLineWithPossibleImages(line: string) { - // Split the line by image markdown, keeping the delimiter (image markdown) - const parts = line.split(/(!\[.*?\]\(.*?\))/).filter((part) => part !== ""); - parts.forEach((part) => { - if (this.isImage(part)) { - const { alt, src } = this.parseImage(part); - this.flushTextBuffer(); - this.currentSection.addImage({ alt, src }); - } else { - this.textBuffer.push(part); - } - }); - this.flushTextBuffer(); - } - - private isHeader(line: string): boolean { - return /^#{1,6}\s/.test(line); - } - - private parseHeader(line: string): { level: number; content: string } { - const match = line.match(/^(#{1,6})\s+(.*)$/); - return match ? 
{ level: match[1].length, content: match[2] } : { level: 0, content: "" }; - } - - private isTableStart(line: string): boolean { - return line.trim().startsWith("|") && line.includes("|", line.indexOf("|") + 1); - } - - private collectTableLines(lines: string[], startIndex: number): string[] { - const tableLines = []; - for (let i = startIndex; i < lines.length && this.isTableLine(lines[i]); i++) { - tableLines.push(lines[i]); - } - return tableLines; - } - - private isTableLine(line: string): boolean { - return line.includes("|"); - } - - private isImageInline(line: string): boolean { - return line.includes("![") && line.includes("]("); - } - - private isImage(part: string): boolean { - return /^!\[.*\]\(.*\)$/.test(part); - } - - private parseImage(markdown: string): { alt: string; src: string } { - const match = markdown.match(/^!\[(.*)\]\((.*)\)$/); - return match ? { alt: match[1], src: match[2] } : { alt: "", src: "" }; - } -} diff --git a/packages/ai/src/source/DocumentConverterText.ts b/packages/ai/src/source/DocumentConverterText.ts deleted file mode 100644 index f9337bda..00000000 --- a/packages/ai/src/source/DocumentConverterText.ts +++ /dev/null @@ -1,20 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import { Document, DocumentMetadata } from "./Document"; -import { DocumentConverter } from "./DocumentConverter"; - -export class DocumentConverterText extends DocumentConverter { - constructor( - metadata: DocumentMetadata, - public text: string - ) { - super(metadata); - } - public convert(): Document { - return new Document(this.text, this.metadata); - } -} diff --git a/packages/ai/src/source/MasterDocument.ts b/packages/ai/src/source/MasterDocument.ts deleted file mode 100644 index 4d2aef33..00000000 --- a/packages/ai/src/source/MasterDocument.ts +++ /dev/null @@ -1,50 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import { Document, DocumentMetadata, TextFragment } from "./Document"; -import { DocumentConverter } from "./DocumentConverter"; - -/** - * MasterDocument represents a container for managing multiple versions/variants of a document. - * It maintains the original document and its transformed variants for different use cases. - * - * Key features: - * - Stores original document and metadata - * - Maintains a master version and variants - * - Automatically creates paragraph-split variant - * - * The paragraph variant splits text fragments by newlines while preserving other fragment types, - * which is useful for more granular text processing. 
- */ - -export class MasterDocument { - public metadata: DocumentMetadata; - public original: DocumentConverter; - public master: Document; - public variants: Document[] = []; - constructor(original: DocumentConverter, metadata: DocumentMetadata) { - this.metadata = Object.assign(original.metadata, metadata); - this.original = original; - this.master = original.convert(); - this.variants.push(paragraphVariant(this.master)); - } -} - -function paragraphVariant(doc: Document): Document { - const newdoc = new Document("", doc.metadata); - for (const node of doc.fragments) { - if (node instanceof TextFragment) { - const newnodes = node.content - .split("\n") - .filter((t) => t) - .map((paragraph) => new TextFragment(paragraph)); - newdoc.fragments.push(...newnodes); - } else { - newdoc.fragments.push(node); - } - } - return newdoc; -} diff --git a/packages/ai/src/source/index.ts b/packages/ai/src/source/index.ts deleted file mode 100644 index 8b9992fd..00000000 --- a/packages/ai/src/source/index.ts +++ /dev/null @@ -1,10 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -export * from "./Document"; -export * from "./DocumentConverterMarkdown"; -export * from "./DocumentConverterText"; -export * from "./MasterDocument"; diff --git a/packages/ai/src/task/DocumentSplitterTask.ts b/packages/ai/src/task/DocumentSplitterTask.ts deleted file mode 100644 index 864f7e1d..00000000 --- a/packages/ai/src/task/DocumentSplitterTask.ts +++ /dev/null @@ -1,98 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import { - CreateWorkflow, - JobQueueTaskConfig, - Task, - TaskRegistry, - Workflow, -} from "@workglow/task-graph"; -import { DataPortSchema, FromSchema } from "@workglow/util"; -import { Document, DocumentFragment } from "../source/Document"; - -const inputSchema = { - type: "object", - properties: { - parser: { - type: "string", - enum: ["txt", "md"], - title: "Document Kind", - description: "The kind of document (txt or md)", - }, - // file: Type.Instance(Document), - }, - required: ["parser"], - additionalProperties: false, -} as const satisfies DataPortSchema; - -const outputSchema = { - type: "object", - properties: { - texts: { - type: "array", - items: { type: "string" }, - title: "Text Chunks", - description: "The text chunks of the document", - }, - }, - required: ["texts"], - additionalProperties: false, -} as const satisfies DataPortSchema; - -export type DocumentSplitterTaskInput = FromSchema; -export type DocumentSplitterTaskOutput = FromSchema; - -export class DocumentSplitterTask extends Task< - DocumentSplitterTaskInput, - DocumentSplitterTaskOutput, - JobQueueTaskConfig -> { - public static type = "DocumentSplitterTask"; - public static category = "Document"; - public static title = "Document Splitter"; - public static description = "Splits documents into text chunks for processing"; - public static inputSchema(): DataPortSchema { - return inputSchema as DataPortSchema; - } - public static outputSchema(): DataPortSchema { - return outputSchema as DataPortSchema; - } - - flattenFragmentsToTexts(item: DocumentFragment | Document): string[] { - if (item instanceof Document) { - const texts: string[] = []; - item.fragments.forEach((fragment) => { - texts.push(...this.flattenFragmentsToTexts(fragment)); - }); - return texts; - } else { - return [item.content]; - } - } - - async executeReactive(): Promise { - return { texts: this.flattenFragmentsToTexts(this.runInputData.file) }; 
- } -} - -TaskRegistry.registerTask(DocumentSplitterTask); - -export const documentSplitter = (input: DocumentSplitterTaskInput) => { - return new DocumentSplitterTask(input).run(); -}; - -declare module "@workglow/task-graph" { - interface Workflow { - documentSplitter: CreateWorkflow< - DocumentSplitterTaskInput, - DocumentSplitterTaskOutput, - JobQueueTaskConfig - >; - } -} - -Workflow.prototype.documentSplitter = CreateWorkflow(DocumentSplitterTask); diff --git a/packages/ai/src/task/VectorSimilarityTask.ts b/packages/ai/src/task/VectorSimilarityTask.ts index 48898ba7..d5714716 100644 --- a/packages/ai/src/task/VectorSimilarityTask.ts +++ b/packages/ai/src/task/VectorSimilarityTask.ts @@ -5,15 +5,21 @@ */ import { - ArrayTask, CreateWorkflow, + GraphAsTask, JobQueueTaskConfig, - TaskError, TaskRegistry, Workflow, } from "@workglow/task-graph"; -import { DataPortSchema, FromSchema } from "@workglow/util"; -import { TypedArray, TypedArraySchema, TypedArraySchemaOptions } from "./base/AiTaskSchemas"; +import { + cosineSimilarity, + DataPortSchema, + FromSchema, + hammingSimilarity, + jaccardSimilarity, + TypedArraySchema, + TypedArraySchemaOptions, +} from "@workglow/util"; export const SimilarityFn = { COSINE: "cosine", @@ -21,6 +27,12 @@ export const SimilarityFn = { HAMMING: "hamming", } as const; +const similarityFunctions = { + cosine: cosineSimilarity, + jaccard: jaccardSimilarity, + hamming: hammingSimilarity, +} as const; + export type SimilarityFn = (typeof SimilarityFn)[keyof typeof SimilarityFn]; const SimilarityInputSchema = { @@ -30,7 +42,7 @@ const SimilarityInputSchema = { title: "Query", description: "Query vector to compare against", }), - input: { + vectors: { type: "array", items: TypedArraySchema({ title: "Input", @@ -44,7 +56,7 @@ const SimilarityInputSchema = { minimum: 1, default: 10, }, - similarity: { + method: { type: "string", enum: Object.values(SimilarityFn), title: "Similarity 𝑓", @@ -52,7 +64,7 @@ const SimilarityInputSchema = { default: SimilarityFn.COSINE, }, }, - required: ["query", "input", "similarity"], + required: ["query", "vectors", "method"], additionalProperties: false, } as const satisfies DataPortSchema; @@ -88,7 +100,7 @@ export type VectorSimilarityTaskOutput = FromSchema< TypedArraySchemaOptions >; -export class VectorSimilarityTask extends ArrayTask< +export class VectorSimilarityTask extends GraphAsTask< VectorSimilarityTaskInput, VectorSimilarityTaskOutput, JobQueueTaskConfig @@ -107,17 +119,12 @@ export class VectorSimilarityTask extends ArrayTask< return SimilarityOutputSchema as DataPortSchema; } - // @ts-ignore (TODO: fix this) - async executeReactive( - { query, input, similarity, topK }: VectorSimilarityTaskInput, - oldOutput: VectorSimilarityTaskOutput - ) { + async executeReactive({ query, vectors, method, topK }: VectorSimilarityTaskInput) { let similarities = []; - const fns = { cosine }; - const fnName = similarity as keyof typeof fns; - const fn = fns[fnName]; + const fnName = method as keyof typeof similarityFunctions; + const fn = similarityFunctions[fnName]; - for (const embedding of input) { + for (const embedding of vectors) { similarities.push({ similarity: fn(embedding, query), embedding, @@ -137,7 +144,7 @@ export class VectorSimilarityTask extends ArrayTask< TaskRegistry.registerTask(VectorSimilarityTask); export const similarity = (input: VectorSimilarityTaskInput, config?: JobQueueTaskConfig) => { - return new VectorSimilarityTask(input, config).run(); + return new VectorSimilarityTask({} as 
VectorSimilarityTaskInput, config).run(input); }; declare module "@workglow/task-graph" { @@ -151,41 +158,3 @@ declare module "@workglow/task-graph" { } Workflow.prototype.similarity = CreateWorkflow(VectorSimilarityTask); - -// =============================================================================== - -export function inner(arr1: TypedArray, arr2: TypedArray): number { - // @ts-ignore - return 1 - arr1.reduce((acc, val, i) => acc + val * arr2[i], 0); -} - -export function magnitude(arr: TypedArray) { - // @ts-ignore - return Math.sqrt(arr.reduce((acc, val) => acc + val * val, 0)); -} - -function cosine(arr1: TypedArray, arr2: TypedArray) { - const dotProduct = inner(arr1, arr2); - const magnitude1 = magnitude(arr1); - const magnitude2 = magnitude(arr2); - return 1 - dotProduct / (magnitude1 * magnitude2); -} - -export function normalize(vector: TypedArray): TypedArray { - const mag = magnitude(vector); - - if (mag === 0) { - throw new TaskError("Cannot normalize a zero vector."); - } - - const normalized = vector.map((val) => Number(val) / mag); - - if (vector instanceof Float64Array) { - return new Float64Array(normalized); - } - if (vector instanceof Float32Array) { - return new Float32Array(normalized); - } - // For integer arrays and bigint[], use Float32Array since normalization produces floats - return new Float32Array(normalized); -} diff --git a/packages/ai/src/task/base/AiTask.ts b/packages/ai/src/task/base/AiTask.ts index 08e2c301..053466b5 100644 --- a/packages/ai/src/task/base/AiTask.ts +++ b/packages/ai/src/task/base/AiTask.ts @@ -16,11 +16,12 @@ import { TaskInput, type TaskOutput, } from "@workglow/task-graph"; -import { type JsonSchema } from "@workglow/util"; +import { type JsonSchema, type ServiceRegistry } from "@workglow/util"; import { AiJob, AiJobInput } from "../../job/AiJob"; -import { getGlobalModelRepository } from "../../model/ModelRegistry"; -import type { ModelConfig, ModelRecord } from "../../model/ModelSchema"; +import { MODEL_REPOSITORY } from "../../model/ModelRegistry"; +import type { ModelRepository } from "../../model/ModelRepository"; +import type { ModelConfig } from "../../model/ModelSchema"; function schemaFormat(schema: JsonSchema): string | undefined { return typeof schema === "object" && schema !== null && "format" in schema @@ -32,35 +33,31 @@ export interface AiSingleTaskInput extends TaskInput { model: string | ModelConfig; } -export interface AiArrayTaskInput extends TaskInput { - model: string | ModelConfig | (string | ModelConfig)[]; -} - /** * A base class for AI related tasks that run in a job queue. * Extends the JobQueueTask class to provide LLM-specific functionality. + * + * Model resolution is handled automatically by the TaskRunner before execution. + * By the time execute() is called, input.model is always a ModelConfig object. */ export class AiTask< - Input extends AiArrayTaskInput = AiArrayTaskInput, + Input extends AiSingleTaskInput = AiSingleTaskInput, Output extends TaskOutput = TaskOutput, Config extends JobQueueTaskConfig = JobQueueTaskConfig, > extends JobQueueTask { public static type: string = "AiTask"; - private modelCache?: { name: string; model: ModelRecord }; /** * Creates a new AiTask instance * @param config - Configuration object for the task */ - constructor(input: Input = {} as Input, config: Config = {} as Config) { + constructor(input: Partial = {}, config: Config = {} as Config) { const modelLabel = typeof input.model === "string" ? input.model - : Array.isArray(input.model) - ? 
undefined - : typeof input.model === "object" && input.model - ? input.model.model_id || input.model.title || input.model.provider - : undefined; + : typeof input.model === "object" && input.model + ? input.model.model_id || input.model.title || input.model.provider + : undefined; config.name ||= `${new.target.type || new.target.name}${ modelLabel ? " with model " + modelLabel : "" }`; @@ -74,58 +71,31 @@ export class AiTask< /** * Get the input to submit to the job queue. * Transforms the task input to AiJobInput format. - * @param input - The task input + * + * Note: By the time this is called, input.model has already been resolved + * to a ModelConfig by the TaskRunner's input resolution system. + * + * @param input - The task input (with resolved model) * @returns The AiJobInput to submit to the queue */ protected override async getJobInput(input: Input): Promise> { - if (Array.isArray(input.model)) { - console.error("AiTask: Model is an array", input); + // Model is guaranteed to be resolved by TaskRunner before this is called + const model = input.model as ModelConfig; + if (!model || typeof model !== "object") { throw new TaskConfigurationError( - "AiTask: Model is an array, only create job for single model tasks" + "AiTask: Model was not resolved to ModelConfig - this indicates a bug in the resolution system" ); } - const runtype = (this.constructor as any).runtype ?? (this.constructor as any).type; - const model = await this.getModelConfigForInput(input as AiSingleTaskInput); - // TODO: if the queue is not memory based, we need to convert to something that can structure clone to the queue - // const registeredQueue = await this.resolveQueue(input); - // const queueName = registeredQueue?.server.queueName; + const runtype = (this.constructor as any).runtype ?? (this.constructor as any).type; return { taskType: runtype, aiProvider: model.provider, - taskInput: { ...(input as any), model } as Input & { model: ModelConfig }, + taskInput: input as Input & { model: ModelConfig }, }; } - /** - * Resolves a model configuration for the given input. - * - * @remarks - * - If `input.model` is a string, it is resolved via the global model repository. - * - If `input.model` is already a config object, it is used directly. - */ - protected async getModelConfigForInput(input: AiSingleTaskInput): Promise { - const modelValue = input.model; - if (!modelValue) throw new TaskConfigurationError("AiTask: No model found"); - if (typeof modelValue === "string") { - const modelname = modelValue; - if (this.modelCache && this.modelCache.name === modelname) { - return this.modelCache.model; - } - const model = await getGlobalModelRepository().findByName(modelname); - if (!model) { - throw new TaskConfigurationError(`AiTask: No model ${modelname} found`); - } - this.modelCache = { name: modelname, model }; - return model; - } - if (typeof modelValue === "object") { - return modelValue; - } - throw new TaskConfigurationError("AiTask: Invalid model value"); - } - /** * Creates a new Job instance for direct execution (without a queue). 
* @param input - The task input @@ -149,42 +119,25 @@ export class AiTask< return job; } - protected async getModelForInput(input: AiSingleTaskInput): Promise { - const modelname = input.model; - if (!modelname) throw new TaskConfigurationError("AiTask: No model name found"); - if (typeof modelname !== "string") { - throw new TaskConfigurationError("AiTask: Model name is not a string"); - } - if (this.modelCache && this.modelCache.name === modelname) { - return this.modelCache.model; - } - const model = await getGlobalModelRepository().findByName(modelname); - if (!model) { - throw new TaskConfigurationError(`JobQueueTask: No model ${modelname} found`); - } - this.modelCache = { name: modelname, model }; - return model; - } - + /** + * Gets the default queue name based on the model's provider. + * After TaskRunner resolution, input.model is a ModelConfig. + */ protected override async getDefaultQueueName(input: Input): Promise { - if (typeof input.model === "string") { - const model = await this.getModelForInput(input as AiSingleTaskInput); - return model.provider; - } - if (typeof input.model === "object" && input.model !== null && !Array.isArray(input.model)) { - return (input.model as ModelConfig).provider; - } - return undefined; + const model = input.model as ModelConfig; + return model?.provider; } /** - * Validates that a model name really exists - * @param schema The schema to validate against - * @param item The item to validate - * @returns True if the item is valid, false otherwise + * Validates that model inputs are valid ModelConfig objects. + * + * Note: By the time this is called, string model IDs have already been + * resolved to ModelConfig objects by the TaskRunner's input resolution system. + * + * @param input The input to validate + * @returns True if the input is valid */ async validateInput(input: Input): Promise { - // TODO(str): this is very inefficient, we should cache the results, including intermediate results const inputSchema = this.inputSchema(); if (typeof inputSchema === "boolean") { if (inputSchema === false) { @@ -192,59 +145,41 @@ export class AiTask< } return true; } + + // Find properties with model:TaskName format - need task compatibility check const modelTaskProperties = Object.entries( (inputSchema.properties || {}) as Record ).filter(([key, schema]) => schemaFormat(schema)?.startsWith("model:")); - if (modelTaskProperties.length > 0) { - const taskModels = await getGlobalModelRepository().findModelsByTask(this.type); - for (const [key, propSchema] of modelTaskProperties) { - let requestedModels = Array.isArray(input[key]) ? input[key] : [input[key]]; - for (const model of requestedModels) { - if (typeof model === "string") { - const foundModel = taskModels?.find((m) => m.model_id === model); - if (!foundModel) { - throw new TaskConfigurationError( - `AiTask: Missing model for '${key}' named '${model}' for task '${this.type}'` - ); - } - } else if (typeof model === "object" && model !== null) { - // Inline configs are accepted without requiring repository access. - // If 'tasks' is provided, do a best-effort compatibility check. 
- const tasks = (model as ModelConfig).tasks; - if (Array.isArray(tasks) && tasks.length > 0 && !tasks.includes(this.type)) { - throw new TaskConfigurationError( - `AiTask: Inline model for '${key}' is not compatible with task '${this.type}'` - ); - } - } else { - throw new TaskConfigurationError(`AiTask: Invalid model for '${key}'`); - } + for (const [key] of modelTaskProperties) { + const model = input[key]; + if (typeof model === "object" && model !== null) { + // Check task compatibility if tasks array is specified + const tasks = (model as ModelConfig).tasks; + if (Array.isArray(tasks) && tasks.length > 0 && !tasks.includes(this.type)) { + throw new TaskConfigurationError( + `AiTask: Model for '${key}' is not compatible with task '${this.type}'` + ); } + } else if (model !== undefined && model !== null) { + // Should be a ModelConfig object after resolution + throw new TaskConfigurationError( + `AiTask: Invalid model for '${key}' - expected ModelConfig object` + ); } } + // Find properties with plain model format - just ensure they're objects const modelPlainProperties = Object.entries( (inputSchema.properties || {}) as Record ).filter(([key, schema]) => schemaFormat(schema) === "model"); - if (modelPlainProperties.length > 0) { - for (const [key, propSchema] of modelPlainProperties) { - let requestedModels = Array.isArray(input[key]) ? input[key] : [input[key]]; - for (const model of requestedModels) { - if (typeof model === "string") { - const foundModel = await getGlobalModelRepository().findByName(model); - if (!foundModel) { - throw new TaskConfigurationError( - `AiTask: Missing model for "${key}" named "${model}"` - ); - } - } else if (typeof model === "object" && model !== null) { - // Inline configs are accepted without requiring repository access. - } else { - throw new TaskConfigurationError(`AiTask: Invalid model for "${key}"`); - } - } + for (const [key] of modelPlainProperties) { + const model = input[key]; + if (model !== undefined && model !== null && typeof model !== "object") { + throw new TaskConfigurationError( + `AiTask: Invalid model for '${key}' - expected ModelConfig object` + ); } } @@ -253,7 +188,7 @@ export class AiTask< // dataflows can strip some models that are incompatible with the target task // if all of them are stripped, then the task will fail in validateInput - async narrowInput(input: Input): Promise { + async narrowInput(input: Input, registry: ServiceRegistry): Promise { // TODO(str): this is very inefficient, we should cache the results, including intermediate results const inputSchema = this.inputSchema(); if (typeof inputSchema === "boolean") { @@ -266,34 +201,25 @@ export class AiTask< (inputSchema.properties || {}) as Record ).filter(([key, schema]) => schemaFormat(schema)?.startsWith("model:")); if (modelTaskProperties.length > 0) { - const taskModels = await getGlobalModelRepository().findModelsByTask(this.type); + const modelRepo = registry.get(MODEL_REPOSITORY); + const taskModels = await modelRepo.findModelsByTask(this.type); for (const [key, propSchema] of modelTaskProperties) { - let requestedModels = Array.isArray(input[key]) ? 
input[key] : [input[key]]; - const requestedStrings = requestedModels.filter( - (m: unknown): m is string => typeof m === "string" - ); - const requestedInline = requestedModels.filter( - (m: unknown): m is ModelConfig => typeof m === "object" && m !== null - ); + const requestedModel = input[key]; - const usingStrings = requestedStrings.filter((model: string) => - taskModels?.find((m) => m.model_id === model) - ); - - const usingInline = requestedInline.filter((model: ModelConfig) => { + if (typeof requestedModel === "string") { + // Verify string model ID is compatible + const found = taskModels?.find((m) => m.model_id === requestedModel); + if (!found) { + (input as any)[key] = undefined; + } + } else if (typeof requestedModel === "object" && requestedModel !== null) { + // Verify inline config is compatible + const model = requestedModel as ModelConfig; const tasks = model.tasks; - // Filter out inline configs with explicit incompatible tasks arrays - // This matches the validation logic in validateInput if (Array.isArray(tasks) && tasks.length > 0 && !tasks.includes(this.type)) { - return false; + (input as any)[key] = undefined; } - return true; - }); - - const combined: (string | ModelConfig)[] = [...usingInline, ...usingStrings]; - - // we alter input to be the models that were found for this kind of input - (input as any)[key] = combined.length > 1 ? combined : combined[0]; + } } } return input; diff --git a/packages/ai/src/task/base/AiTaskSchemas.ts b/packages/ai/src/task/base/AiTaskSchemas.ts index 67d2fcaf..06a3bc43 100644 --- a/packages/ai/src/task/base/AiTaskSchemas.ts +++ b/packages/ai/src/task/base/AiTaskSchemas.ts @@ -4,177 +4,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { - DataPortSchemaNonBoolean, - FromSchema, - FromSchemaDefaultOptions, - FromSchemaOptions, - JsonSchema, -} from "@workglow/util"; +import { DataPortSchemaNonBoolean, JsonSchema } from "@workglow/util"; import { ModelConfigSchema } from "../../model/ModelSchema"; -export type TypedArray = - | Float64Array - | Float32Array - | Int32Array - | Int16Array - | Int8Array - | Uint32Array - | Uint16Array - | Uint8Array - | Uint8ClampedArray; - -// Type-only value for use in deserialize patterns -const TypedArrayType = null as any as TypedArray; - -const TypedArraySchemaOptions = { - ...FromSchemaDefaultOptions, - deserialize: [ - // { - // pattern: { - // type: "number"; - // "format": "BigInt" | "Float64"; - // }; - // output: bigint; - // }, - // { - // pattern: { - // type: "number"; - // "format": "Float64Array"; - // }; - // output: Float64Array; - // }, - // { - // pattern: { - // type: "number"; - // "format": "Float32Array"; - // }; - // output: Float32Array; - // }, - // { - // pattern: { - // type: "number"; - // "format": "Int32Array"; - // }; - // output: Int32Array; - // }, - // { - // pattern: { - // type: "number"; - // "format": "Int16Array"; - // }; - // output: Int16Array; - // }, - // { - // pattern: { - // type: "number"; - // "format": "Int8Array"; - // }; - // output: Int8Array; - // }, - // { - // pattern: { - // type: "number"; - // "format": "Uint8Array"; - // }; - // output: Uint8Array; - // }, - // { - // pattern: { - // type: "number"; - // "format": "Uint16Array"; - // }; - // output: Uint16Array; - // }, - // { - // pattern: { - // type: "number"; - // "format": "Uint32Array"; - // }; - // output: Uint32Array; - // }, - // { - // pattern: { type: "array"; items: { type: "number" }; "format": "Uint8ClampedArray" }; - // output: Uint8ClampedArray; - // }, - { - pattern: { 
format: "TypedArray" }, - output: TypedArrayType, - }, - ], -} as const satisfies FromSchemaOptions; - -export type TypedArraySchemaOptions = typeof TypedArraySchemaOptions; - -export const TypedArraySchema = (annotations: Record = {}) => - ({ - oneOf: [ - { - type: "array", - items: { type: "number", format: "Float64" }, - title: "Float64Array", - description: "A 64-bit floating point array", - format: "Float64Array", - }, - { - type: "array", - items: { type: "number", format: "Float32" }, - title: "Float32Array", - description: "A 32-bit floating point array", - format: "Float32Array", - }, - { - type: "array", - items: { type: "number", format: "Int32" }, - title: "Int32Array", - description: "A 32-bit integer array", - format: "Int32Array", - }, - { - type: "array", - items: { type: "number", format: "Int16" }, - title: "Int16Array", - description: "A 16-bit integer array", - format: "Int16Array", - }, - { - type: "array", - items: { type: "number", format: "Int8" }, - title: "Int8Array", - }, - { - type: "array", - items: { type: "number", format: "Uint8" }, - title: "Uint8Array", - description: "A 8-bit unsigned integer array", - format: "Uint8Array", - }, - { - type: "array", - items: { type: "number", format: "Uint16" }, - title: "Uint16Array", - description: "A 16-bit unsigned integer array", - format: "Uint16Array", - }, - { - type: "array", - items: { type: "number", format: "Uint32" }, - title: "Uint32Array", - description: "A 32-bit unsigned integer array", - format: "Uint32Array", - }, - { - type: "array", - items: { type: "number", format: "Uint8Clamped" }, - title: "Uint8ClampedArray", - description: "A 8-bit unsigned integer array with values clamped to 0-255", - format: "Uint8ClampedArray", - }, - ], - format: "TypedArray", - ...annotations, - }) as const satisfies JsonSchema; - export const TypeLanguage = (annotations: Record = {}) => ({ type: "string", @@ -240,52 +72,11 @@ export function TypeModel< } as const satisfies JsonSchema; } -export const TypeReplicateArray = ( - type: T, - annotations: Record = {} -) => - ({ - oneOf: [type, { type: "array", items: type }], - title: type.title, - description: type.description, - ...(type.format ? { format: type.format } : {}), - ...annotations, - "x-replicate": true, - }) as const; - -export type TypedArrayFromSchema = FromSchema< - SCHEMA, - TypedArraySchemaOptions ->; - -/** - * Removes array types from a union, leaving only non-array types. - * For example, `string | string[]` becomes `string`. - * Used to extract the single-value type from schemas with x-replicate annotation. - * Uses distributive conditional types to filter out arrays from unions. - * Checks for both array types and types with numeric index signatures (FromSchema array output). - * Preserves TypedArray types like Float64Array which also have numeric indices. - */ -type UnwrapArrayUnion = T extends readonly any[] - ? T extends TypedArray - ? T - : never - : number extends keyof T - ? "push" extends keyof T - ? never - : T - : T; - -/** - * Transforms a schema by removing array variants from properties marked with x-replicate. - * Properties with x-replicate use {@link TypeReplicateArray} which creates a union of - * `T | T[]`, and this type extracts just `T`. - */ -export type DeReplicateFromSchema }> = { - [K in keyof S["properties"]]: S["properties"][K] extends { "x-replicate": true } - ? 
UnwrapArrayUnion> - : TypedArrayFromSchema; -}; +export function TypeSingleOrArray(type: T) { + return { + anyOf: [type, { type: "array", items: type }], + } as const satisfies JsonSchema; +} export type ImageSource = ImageBitmap | OffscreenCanvas | VideoFrame; diff --git a/packages/ai/src/task/base/AiVisionTask.ts b/packages/ai/src/task/base/AiVisionTask.ts index 16fbb7b4..52b5f2e6 100644 --- a/packages/ai/src/task/base/AiVisionTask.ts +++ b/packages/ai/src/task/base/AiVisionTask.ts @@ -19,16 +19,12 @@ export interface AiVisionTaskSingleInput extends TaskInput { model: string | ModelConfig; } -export interface AiVisionArrayTaskInput extends TaskInput { - model: string | ModelConfig | (string | ModelConfig)[]; -} - /** * A base class for AI related tasks that run in a job queue. * Extends the JobQueueTask class to provide LLM-specific functionality. */ export class AiVisionTask< - Input extends AiVisionArrayTaskInput = AiVisionArrayTaskInput, + Input extends AiVisionTaskSingleInput = AiVisionTaskSingleInput, Output extends TaskOutput = TaskOutput, Config extends JobQueueTaskConfig = JobQueueTaskConfig, > extends AiTask { diff --git a/packages/debug/src/console/ConsoleFormatters.ts b/packages/debug/src/console/ConsoleFormatters.ts index 1e55021d..86c53ba9 100644 --- a/packages/debug/src/console/ConsoleFormatters.ts +++ b/packages/debug/src/console/ConsoleFormatters.ts @@ -73,8 +73,7 @@ class WorkflowConsoleFormatter extends ConsoleFormatter { body(obj: unknown, config?: Config): JsonMLElementDef { const body = new JsonMLElement("div"); - const graph: TaskGraph = - obj instanceof TaskGraph ? obj : (obj as Workflow).graph; + const graph: TaskGraph = obj instanceof TaskGraph ? obj : (obj as Workflow).graph; const nodes = body.createStyledList(); const tasks = graph.getTasks(); if (tasks.length) { @@ -314,7 +313,7 @@ class TaskConsoleFormatter extends ConsoleFormatter { const body = new JsonMLElement("div").setStyle("padding-left: 10px;"); const inputs = body.createStyledList("Inputs:"); - const allInboundDataflows = ((config as { graph?: TaskGraph })?.graph)?.getSourceDataflows( + const allInboundDataflows = (config as { graph?: TaskGraph })?.graph?.getSourceDataflows( task.config.id ); @@ -382,7 +381,6 @@ class TaskConsoleFormatter extends ConsoleFormatter { const taskConfig = body.createStyledList("Config:"); for (const [key, value] of Object.entries(task.config)) { if (value === undefined) continue; - if (key == "provenance") continue; const li = taskConfig.createListItem("", "padding-left: 20px;"); li.inputText(`${key}: `); li.createValueObject(value); @@ -750,7 +748,10 @@ interface NodeWithConfig { function computeLayout( graph: DirectedAcyclicGraph, canvasWidth: number -): { readonly positions: { readonly [id: string]: { readonly x: number; readonly y: number } }; readonly requiredHeight: number } { +): { + readonly positions: { readonly [id: string]: { readonly x: number; readonly y: number } }; + readonly requiredHeight: number; +} { const positions: { [id: string]: { x: number; y: number } } = {}; const layers: Map = new Map(); const depths: { [id: string]: number } = {}; @@ -873,4 +874,3 @@ export function installDevToolsFormatters(): void { new DAGConsoleFormatter() ); } - diff --git a/packages/storage/README.md b/packages/storage/README.md index 0b670966..254f73fd 100644 --- a/packages/storage/README.md +++ b/packages/storage/README.md @@ -28,6 +28,7 @@ Modular storage solutions for Workglow.AI platform with multiple backend impleme - [Node.js Environment](#nodejs-environment) - 
[Bun Environment](#bun-environment) - [Advanced Features](#advanced-features) + - [Repository Registry](#repository-registry) - [Event-Driven Architecture](#event-driven-architecture) - [Compound Primary Keys](#compound-primary-keys) - [Custom File Layout (KV on filesystem)](#custom-file-layout-kv-on-filesystem) @@ -521,6 +522,95 @@ const cloudData = new SupabaseTabularRepository(supabase, "items", ItemSchema, [ ## Advanced Features +### Repository Registry + +Repositories can be registered globally by ID, allowing tasks to reference them by name rather than passing direct instances. This is useful for configuring repositories once at application startup and referencing them throughout your task graphs. + +#### Registering Repositories + +```typescript +import { + registerTabularRepository, + getTabularRepository, + InMemoryTabularRepository, +} from "@workglow/storage"; + +// Define your schema +const userSchema = { + type: "object", + properties: { + id: { type: "string" }, + name: { type: "string" }, + email: { type: "string" }, + }, + required: ["id", "name", "email"], + additionalProperties: false, +} as const; + +// Create and register a repository +const userRepo = new InMemoryTabularRepository(userSchema, ["id"] as const); +registerTabularRepository("users", userRepo); + +// Later, retrieve the repository by ID +const repo = getTabularRepository("users"); +``` + +#### Using Repositories in Tasks + +When using repositories with tasks, you can pass either the repository ID or a direct instance. The TaskRunner automatically resolves string IDs using the registry. + +```typescript +import { TypeTabularRepository } from "@workglow/storage"; + +// In your task's input schema, use TypeTabularRepository +static inputSchema() { + return { + type: "object", + properties: { + dataSource: TypeTabularRepository({ + title: "User Repository", + description: "Repository containing user records", + }), + }, + required: ["dataSource"], + }; +} + +// Both approaches work: +await task.run({ dataSource: "users" }); // Resolved from registry +await task.run({ dataSource: userRepoInstance }); // Direct instance +``` + +#### Schema Helper Functions + +The package provides schema helper functions for defining repository inputs with proper format annotations: + +```typescript +import { + TypeTabularRepository, + TypeVectorRepository, + TypeDocumentRepository, +} from "@workglow/storage"; + +// Tabular repository (format: "repository:tabular") +const tabularSchema = TypeTabularRepository({ + title: "Data Source", + description: "Tabular data repository", +}); + +// Vector repository (format: "repository:document-node-vector") +const vectorSchema = TypeVectorRepository({ + title: "Embeddings Store", + description: "Vector embeddings repository", +}); + +// Document repository (format: "repository:document") +const docSchema = TypeDocumentRepository({ + title: "Document Store", + description: "Document storage repository", +}); +``` + ### Event-Driven Architecture All storage implementations support event emission for monitoring and reactive programming: diff --git a/packages/storage/src/common.ts b/packages/storage/src/common.ts index 21bdb392..6ed0fc5e 100644 --- a/packages/storage/src/common.ts +++ b/packages/storage/src/common.ts @@ -7,7 +7,9 @@ export * from "./tabular/CachedTabularRepository"; export * from "./tabular/InMemoryTabularRepository"; export * from "./tabular/ITabularRepository"; -export * from "./tabular/TabularRepository"; +export * from "./tabular/TabularRepositoryRegistry"; + +export * 
from "./util/RepositorySchema"; export * from "./kv/IKvRepository"; export * from "./kv/InMemoryKvRepository"; @@ -22,3 +24,11 @@ export * from "./limiter/IRateLimiterStorage"; export * from "./util/HybridSubscriptionManager"; export * from "./util/PollingSubscriptionManager"; + +export * from "./document/Document"; +export * from "./document/DocumentNode"; +export * from "./document/DocumentRepository"; +export * from "./document/DocumentRepositoryRegistry"; +export * from "./document/DocumentSchema"; +export * from "./document/DocumentStorageSchema"; +export * from "./document/StructuralParser"; diff --git a/packages/storage/src/document/Document.ts b/packages/storage/src/document/Document.ts new file mode 100644 index 00000000..d0351f4d --- /dev/null +++ b/packages/storage/src/document/Document.ts @@ -0,0 +1,81 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { ChunkNode, DocumentMetadata, DocumentNode } from "./DocumentSchema"; + +/** + * Document represents a hierarchical document with chunks + * + * Key features: + * - Single source-of-truth tree structure (root node) + * - Single set of chunks + * - Separate persistence for document structure vs vectors + */ +export class Document { + public readonly doc_id: string; + public readonly metadata: DocumentMetadata; + public readonly root: DocumentNode; + private chunks: ChunkNode[]; + + constructor( + doc_id: string, + root: DocumentNode, + metadata: DocumentMetadata, + chunks: ChunkNode[] = [] + ) { + this.doc_id = doc_id; + this.root = root; + this.metadata = metadata; + this.chunks = chunks || []; + } + + /** + * Set chunks for the document + */ + setChunks(chunks: ChunkNode[]): void { + this.chunks = chunks; + } + + /** + * Get all chunks + */ + getChunks(): ChunkNode[] { + return this.chunks; + } + + /** + * Find chunks by nodeId + */ + findChunksByNodeId(nodeId: string): ChunkNode[] { + return this.chunks.filter((chunk) => chunk.nodePath.includes(nodeId)); + } + + /** + * Serialize to JSON + */ + toJSON(): { + doc_id: string; + metadata: DocumentMetadata; + root: DocumentNode; + chunks: ChunkNode[]; + } { + return { + doc_id: this.doc_id, + metadata: this.metadata, + root: this.root, + chunks: this.chunks, + }; + } + + /** + * Deserialize from JSON + */ + static fromJSON(json: string): Document { + const obj = JSON.parse(json); + const doc = new Document(obj.doc_id, obj.root, obj.metadata, obj.chunks); + return doc; + } +} diff --git a/packages/storage/src/document/DocumentNode.ts b/packages/storage/src/document/DocumentNode.ts new file mode 100644 index 00000000..a769fdbe --- /dev/null +++ b/packages/storage/src/document/DocumentNode.ts @@ -0,0 +1,134 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { sha256 } from "@workglow/util"; + +import { + NodeKind, + type DocumentNode, + type DocumentRootNode, + type NodeKind as NodeKindType, + type NodeRange, + type SectionNode, + type TopicNode, +} from "./DocumentSchema"; + +/** + * Utility functions for ID generation + */ +export class NodeIdGenerator { + /** + * Generate doc_id from source URI and content hash + */ + static async generateDocId(sourceUri: string, content: string): Promise { + const contentHash = await sha256(content); + const combined = `${sourceUri}|${contentHash}`; + const hash = await sha256(combined); + return `doc_${hash.substring(0, 16)}`; + } + + /** + * Generate nodeId for structural nodes (document, section) + */ + static async 
generateStructuralNodeId( + doc_id: string, + kind: NodeKindType, + range: NodeRange + ): Promise { + const combined = `${doc_id}|${kind}|${range.startOffset}:${range.endOffset}`; + const hash = await sha256(combined); + return `node_${hash.substring(0, 16)}`; + } + + /** + * Generate nodeId for child nodes (paragraph, topic) + */ + static async generateChildNodeId(parentNodeId: string, ordinal: number): Promise { + const combined = `${parentNodeId}|${ordinal}`; + const hash = await sha256(combined); + return `node_${hash.substring(0, 16)}`; + } + + /** + * Generate chunkId + */ + static async generateChunkId( + doc_id: string, + leafNodeId: string, + chunkOrdinal: number + ): Promise { + const combined = `${doc_id}|${leafNodeId}|${chunkOrdinal}`; + const hash = await sha256(combined); + return `chunk_${hash.substring(0, 16)}`; + } +} + +/** + * Approximate token counting (v1) + */ +export function estimateTokens(text: string): number { + return Math.ceil(text.length / 4); +} + +/** + * Helper to check if a node has children + */ +export function hasChildren( + node: DocumentNode +): node is DocumentRootNode | SectionNode | TopicNode { + return ( + node.kind === NodeKind.DOCUMENT || + node.kind === NodeKind.SECTION || + node.kind === NodeKind.TOPIC + ); +} + +/** + * Helper to get all children of a node + */ +export function getChildren(node: DocumentNode): DocumentNode[] { + if (hasChildren(node)) { + return node.children; + } + return []; +} + +/** + * Traverse document tree depth-first + */ +export function* traverseDepthFirst(node: DocumentNode): Generator { + yield node; + if (hasChildren(node)) { + for (const child of node.children) { + yield* traverseDepthFirst(child); + } + } +} + +/** + * Get node path from root to target node + */ +export function getNodePath(root: DocumentNode, targetNodeId: string): string[] | undefined { + const path: string[] = []; + + function search(node: DocumentNode): boolean { + path.push(node.nodeId); + if (node.nodeId === targetNodeId) { + return true; + } + if (hasChildren(node)) { + for (const child of node.children) { + if (search(child)) { + return true; + } + } + } + path.pop(); + return false; + } + + return search(root) ? path : undefined; +} diff --git a/packages/storage/src/document/DocumentRepository.ts b/packages/storage/src/document/DocumentRepository.ts new file mode 100644 index 00000000..8f0cafc7 --- /dev/null +++ b/packages/storage/src/document/DocumentRepository.ts @@ -0,0 +1,222 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { TypedArray } from "@workglow/util"; +import { DocumentNodeVector } from "../document-node-vector/DocumentNodeVectorSchema"; +import type { + AnyDocumentNodeVectorRepository, + VectorSearchOptions, +} from "../document-node-vector/IDocumentNodeVectorRepository"; +import type { ITabularRepository } from "../tabular/ITabularRepository"; +import { Document } from "./Document"; +import { ChunkNode, DocumentNode } from "./DocumentSchema"; +import { + DocumentStorageEntity, + DocumentStorageKey, + DocumentStorageSchema, +} from "./DocumentStorageSchema"; +/** + * Document repository that uses TabularStorage for persistence and VectorStorage for search. + * This is a unified implementation that composes storage backends rather than using + * inheritance/interface patterns. 
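+ *
+ * A minimal usage sketch (the in-memory backend comes from this package; `rootNode`
+ * and the string IDs below are placeholders):
+ * ```typescript
+ * const tabular = new InMemoryTabularRepository(DocumentStorageSchema, ["doc_id"]);
+ * await tabular.setupDatabase();
+ * const docs = new DocumentRepository(tabular);
+ * await docs.upsert(new Document("doc_1", rootNode, { title: "Example" }));
+ * const loaded = await docs.get("doc_1"); // Document | undefined
+ * const path = await docs.getAncestors("doc_1", "node_abc123"); // root -> ... -> node
+ * ```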
+ */ +export class DocumentRepository { + private tabularStorage: ITabularRepository< + DocumentStorageSchema, + DocumentStorageKey, + DocumentStorageEntity + >; + private vectorStorage?: AnyDocumentNodeVectorRepository; + + /** + * Creates a new DocumentRepository instance. + * + * @param tabularStorage - Pre-initialized tabular storage for document persistence + * @param vectorStorage - Pre-initialized vector storage for chunk similarity search + * + * @example + * ```typescript + * const tabularStorage = new InMemoryTabularRepository(DocumentStorageSchema, ["doc_id"]); + * await tabularStorage.setupDatabase(); + * + * const vectorStorage = new InMemoryVectorRepository(); + * await vectorStorage.setupDatabase(); + * + * const docRepo = new DocumentRepository(tabularStorage, vectorStorage); + * ``` + */ + constructor( + tabularStorage: ITabularRepository< + typeof DocumentStorageSchema, + ["doc_id"], + DocumentStorageEntity + >, + vectorStorage?: AnyDocumentNodeVectorRepository + ) { + this.tabularStorage = tabularStorage; + this.vectorStorage = vectorStorage; + } + + /** + * Upsert a document + */ + async upsert(document: Document): Promise { + const serialized = JSON.stringify(document.toJSON ? document.toJSON() : document); + await this.tabularStorage.put({ + doc_id: document.doc_id, + data: serialized, + }); + } + + /** + * Get a document by ID + */ + async get(doc_id: string): Promise { + const entity = await this.tabularStorage.get({ doc_id: doc_id }); + if (!entity) { + return undefined; + } + return Document.fromJSON(entity.data); + } + + /** + * Delete a document + */ + async delete(doc_id: string): Promise { + await this.tabularStorage.delete({ doc_id: doc_id }); + } + + /** + * Get a specific node by ID + */ + async getNode(doc_id: string, nodeId: string): Promise { + const doc = await this.get(doc_id); + if (!doc) { + return undefined; + } + + // Traverse tree to find node + const traverse = (node: any): any => { + if (node.nodeId === nodeId) { + return node; + } + if (node.children && Array.isArray(node.children)) { + for (const child of node.children) { + const found = traverse(child); + if (found) return found; + } + } + return undefined; + }; + + return traverse(doc.root); + } + + /** + * Get ancestors of a node (from root to node) + */ + async getAncestors(doc_id: string, nodeId: string): Promise { + const doc = await this.get(doc_id); + if (!doc) { + return []; + } + + // Get path from root to target node + const path: string[] = []; + const findPath = (node: any): boolean => { + path.push(node.nodeId); + if (node.nodeId === nodeId) { + return true; + } + if (node.children && Array.isArray(node.children)) { + for (const child of node.children) { + if (findPath(child)) { + return true; + } + } + } + path.pop(); + return false; + }; + + if (!findPath(doc.root)) { + return []; + } + + // Collect nodes along the path + const ancestors: any[] = []; + let currentNode: any = doc.root; + ancestors.push(currentNode); + + for (let i = 1; i < path.length; i++) { + const targetId = path[i]; + if (currentNode.children && Array.isArray(currentNode.children)) { + const found = currentNode.children.find((child: any) => child.nodeId === targetId); + if (found) { + currentNode = found; + ancestors.push(currentNode); + } else { + break; + } + } else { + break; + } + } + + return ancestors; + } + + /** + * Get chunks for a document + */ + async getChunks(doc_id: string): Promise { + const doc = await this.get(doc_id); + if (!doc) { + return []; + } + return doc.getChunks(); + } + + /** + * 
Find chunks that contain a specific nodeId in their path + */ + async findChunksByNodeId(doc_id: string, nodeId: string): Promise { + const doc = await this.get(doc_id); + if (!doc) { + return []; + } + if (doc.findChunksByNodeId) { + return doc.findChunksByNodeId(nodeId); + } + // Fallback implementation + const chunks = doc.getChunks(); + return chunks.filter((chunk) => chunk.nodePath && chunk.nodePath.includes(nodeId)); + } + + /** + * List all document IDs + */ + async list(): Promise { + const entities = await this.tabularStorage.getAll(); + if (!entities) { + return []; + } + return entities.map((e) => e.doc_id); + } + + /** + * Search for similar vectors using the vector storage + * @param query - Query vector to search for + * @param options - Search options (topK, filter, scoreThreshold) + * @returns Array of search results sorted by similarity + */ + async search( + query: TypedArray, + options?: VectorSearchOptions> + ): Promise, TypedArray>>> { + return this.vectorStorage?.similaritySearch(query, options) || []; + } +} diff --git a/packages/storage/src/document/DocumentRepositoryRegistry.ts b/packages/storage/src/document/DocumentRepositoryRegistry.ts new file mode 100644 index 00000000..0f011539 --- /dev/null +++ b/packages/storage/src/document/DocumentRepositoryRegistry.ts @@ -0,0 +1,79 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + createServiceToken, + globalServiceRegistry, + registerInputResolver, + ServiceRegistry, +} from "@workglow/util"; +import type { DocumentRepository } from "./DocumentRepository"; + +/** + * Service token for the document repository registry + * Maps repository IDs to DocumentRepository instances + */ +export const DOCUMENT_REPOSITORIES = + createServiceToken>("document.repositories"); + +// Register default factory if not already registered +if (!globalServiceRegistry.has(DOCUMENT_REPOSITORIES)) { + globalServiceRegistry.register( + DOCUMENT_REPOSITORIES, + (): Map => new Map(), + true + ); +} + +/** + * Gets the global document repository registry + * @returns Map of document repository ID to instance + */ +export function getGlobalDocumentRepositories(): Map { + return globalServiceRegistry.get(DOCUMENT_REPOSITORIES); +} + +/** + * Registers a document repository globally by ID + * @param id The unique identifier for this repository + * @param repository The repository instance to register + */ +export function registerDocumentRepository(id: string, repository: DocumentRepository): void { + const repos = getGlobalDocumentRepositories(); + repos.set(id, repository); +} + +/** + * Gets a document repository by ID from the global registry + * @param id The repository identifier + * @returns The repository instance or undefined if not found + */ +export function getDocumentRepository(id: string): DocumentRepository | undefined { + return getGlobalDocumentRepositories().get(id); +} + +/** + * Resolves a repository ID to a DocumentRepository from the registry. + * Used by the input resolver system. + */ +async function resolveDocumentRepositoryFromRegistry( + id: string, + format: string, + registry: ServiceRegistry +): Promise { + const repos = registry.has(DOCUMENT_REPOSITORIES) + ? 
registry.get>(DOCUMENT_REPOSITORIES) + : getGlobalDocumentRepositories(); + + const repo = repos.get(id); + if (!repo) { + throw new Error(`Document repository "${id}" not found in registry`); + } + return repo; +} + +// Register the repository resolver for format: "repository:document" +registerInputResolver("repository:document", resolveDocumentRepositoryFromRegistry); diff --git a/packages/storage/src/document/DocumentSchema.ts b/packages/storage/src/document/DocumentSchema.ts new file mode 100644 index 00000000..f7b52458 --- /dev/null +++ b/packages/storage/src/document/DocumentSchema.ts @@ -0,0 +1,630 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { DataPortSchema, FromSchema, JsonSchema } from "@workglow/util"; + +/** + * Node kind discriminator for hierarchical document structure + */ +export const NodeKind = { + DOCUMENT: "document", + SECTION: "section", + PARAGRAPH: "paragraph", + SENTENCE: "sentence", + TOPIC: "topic", +} as const; + +export type NodeKind = (typeof NodeKind)[keyof typeof NodeKind]; + +// ============================================================================= +// Schema Definitions +// ============================================================================= + +/** + * Schema for source range of a node (character offsets) + */ +export const NodeRangeSchema = { + type: "object", + properties: { + startOffset: { + type: "integer", + title: "Start Offset", + description: "Starting character offset", + }, + endOffset: { + type: "integer", + title: "End Offset", + description: "Ending character offset", + }, + }, + required: ["startOffset", "endOffset"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type NodeRange = FromSchema; + +/** + * Schema for named entity extracted from text + */ +export const EntitySchema = { + type: "object", + properties: { + text: { + type: "string", + title: "Text", + description: "Entity text", + }, + type: { + type: "string", + title: "Type", + description: "Entity type (e.g., PERSON, ORG, LOC)", + }, + score: { + type: "number", + title: "Score", + description: "Confidence score", + }, + }, + required: ["text", "type", "score"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type Entity = FromSchema; + +/** + * Schema for enrichment data attached to a node + */ +export const NodeEnrichmentSchema = { + type: "object", + properties: { + summary: { + type: "string", + title: "Summary", + description: "Summary of the node content", + }, + entities: { + type: "array", + items: EntitySchema, + title: "Entities", + description: "Named entities extracted from the node", + }, + keywords: { + type: "array", + items: { type: "string" }, + title: "Keywords", + description: "Keywords associated with the node", + }, + }, + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type NodeEnrichment = FromSchema; + +/** + * Schema for base document node fields (used for runtime validation) + * Note: Individual node types and DocumentNode union are defined as interfaces + * below because FromSchema cannot properly infer recursive discriminated unions. 
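+ *
+ * For illustration, a paragraph node satisfying this base shape looks like the
+ * following (all values are hypothetical):
+ * ```typescript
+ * const node: ParagraphNode = {
+ *   nodeId: "node_1a2b3c4d5e6f7a8b",
+ *   kind: NodeKind.PARAGRAPH,
+ *   range: { startOffset: 0, endOffset: 34 },
+ *   text: "A short paragraph of source text.",
+ * };
+ * ```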
+ */ +export const DocumentNodeBaseSchema = { + type: "object", + properties: { + nodeId: { + type: "string", + title: "Node ID", + description: "Unique identifier for this node", + }, + kind: { + type: "string", + enum: Object.values(NodeKind), + title: "Kind", + description: "Node type discriminator", + }, + range: NodeRangeSchema, + text: { + type: "string", + title: "Text", + description: "Text content of the node", + }, + enrichment: NodeEnrichmentSchema, + }, + required: ["nodeId", "kind", "range", "text"], + additionalProperties: true, +} as const satisfies DataPortSchema; + +/** + * Schema for document node (generic, for runtime validation) + * This is a simplified schema for task input/output validation. + * The actual TypeScript types use a proper discriminated union. + */ +export const DocumentNodeSchema = { + type: "object", + title: "Document Node", + description: "A node in the hierarchical document tree", + properties: { + ...DocumentNodeBaseSchema.properties, + level: { + type: "integer", + title: "Level", + description: "Header level for section nodes", + }, + title: { + type: "string", + title: "Title", + description: "Section title", + }, + children: { + type: "array", + title: "Children", + description: "Child nodes", + }, + }, + required: [...DocumentNodeBaseSchema.required], + additionalProperties: false, +} as const satisfies DataPortSchema; + +/** + * Schema for paragraph node + */ +export const ParagraphNodeSchema = { + type: "object", + properties: { + ...DocumentNodeBaseSchema.properties, + kind: { + type: "string", + const: NodeKind.PARAGRAPH, + title: "Kind", + description: "Node type discriminator", + }, + }, + required: [...DocumentNodeBaseSchema.required], + additionalProperties: false, +} as const satisfies DataPortSchema; + +/** + * Schema for sentence node + */ +export const SentenceNodeSchema = { + type: "object", + properties: { + ...DocumentNodeBaseSchema.properties, + kind: { + type: "string", + const: NodeKind.SENTENCE, + title: "Kind", + description: "Node type discriminator", + }, + }, + required: [...DocumentNodeBaseSchema.required], + additionalProperties: false, +} as const satisfies DataPortSchema; + +/** + * Schema for section node + */ +export const SectionNodeSchema = { + type: "object", + properties: { + ...DocumentNodeBaseSchema.properties, + kind: { + type: "string", + const: NodeKind.SECTION, + title: "Kind", + description: "Node type discriminator", + }, + level: { + type: "integer", + minimum: 1, + maximum: 6, + title: "Level", + description: "Header level (1-6 for markdown)", + }, + title: { + type: "string", + title: "Title", + description: "Section title", + }, + children: { + type: "array", + items: DocumentNodeSchema, + title: "Children", + description: "Child nodes", + }, + }, + required: [...DocumentNodeBaseSchema.required, "level", "title", "children"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +/** + * Schema for topic node + */ +export const TopicNodeSchema = { + type: "object", + properties: { + ...DocumentNodeBaseSchema.properties, + kind: { + type: "string", + const: NodeKind.TOPIC, + title: "Kind", + description: "Node type discriminator", + }, + children: { + type: "array", + items: DocumentNodeSchema, + title: "Children", + description: "Child nodes", + }, + }, + required: [...DocumentNodeBaseSchema.required, "children"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +/** + * Schema for document root node + */ +export const DocumentRootNodeSchema = { + type: 
"object", + properties: { + ...DocumentNodeBaseSchema.properties, + kind: { + type: "string", + const: NodeKind.DOCUMENT, + title: "Kind", + description: "Node type discriminator", + }, + title: { + type: "string", + title: "Title", + description: "Document title", + }, + children: { + type: "array", + items: DocumentNodeSchema, + title: "Children", + description: "Child nodes", + }, + }, + required: [...DocumentNodeBaseSchema.required, "title", "children"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +// ============================================================================= +// Manually-defined interfaces for recursive discriminated union types +// These provide better TypeScript inference than FromSchema for recursive types +// ============================================================================= + +/** + * Base document node fields + */ +interface DocumentNodeBase { + readonly nodeId: string; + readonly kind: NodeKind; + readonly range: NodeRange; + readonly text: string; + readonly enrichment?: NodeEnrichment; +} + +/** + * Document root node + */ +export interface DocumentRootNode extends DocumentNodeBase { + readonly kind: typeof NodeKind.DOCUMENT; + readonly title: string; + readonly children: DocumentNode[]; +} + +/** + * Section node (from markdown headers or structural divisions) + */ +export interface SectionNode extends DocumentNodeBase { + readonly kind: typeof NodeKind.SECTION; + readonly level: number; + readonly title: string; + readonly children: DocumentNode[]; +} + +/** + * Paragraph node + */ +export interface ParagraphNode extends DocumentNodeBase { + readonly kind: typeof NodeKind.PARAGRAPH; +} + +/** + * Sentence node (optional fine-grained segmentation) + */ +export interface SentenceNode extends DocumentNodeBase { + readonly kind: typeof NodeKind.SENTENCE; +} + +/** + * Topic segment node (from TopicSegmenter) + */ +export interface TopicNode extends DocumentNodeBase { + readonly kind: typeof NodeKind.TOPIC; + readonly children: DocumentNode[]; +} + +/** + * Discriminated union of all document node types + */ +export type DocumentNode = + | DocumentRootNode + | SectionNode + | ParagraphNode + | SentenceNode + | TopicNode; + +// ============================================================================= +// Token Budget and Chunk Schemas +// ============================================================================= + +/** + * Schema for token budget configuration + */ +export const TokenBudgetSchema = { + type: "object", + properties: { + maxTokensPerChunk: { + type: "integer", + title: "Max Tokens Per Chunk", + description: "Maximum tokens allowed per chunk", + }, + overlapTokens: { + type: "integer", + title: "Overlap Tokens", + description: "Number of tokens to overlap between chunks", + }, + reservedTokens: { + type: "integer", + title: "Reserved Tokens", + description: "Tokens reserved for metadata or context", + }, + }, + required: ["maxTokensPerChunk", "overlapTokens", "reservedTokens"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type TokenBudget = FromSchema; + +/** + * Schema for chunk enrichment + */ +export const ChunkEnrichmentSchema = { + type: "object", + properties: { + summary: { + type: "string", + title: "Summary", + description: "Summary of the chunk content", + }, + entities: { + type: "array", + items: EntitySchema, + title: "Entities", + description: "Named entities extracted from the chunk", + }, + }, + additionalProperties: false, +} as const satisfies 
DataPortSchema; + +export type ChunkEnrichment = FromSchema; + +/** + * Schema for chunk node (output of HierarchicalChunker) + */ +export const ChunkNodeSchema = () => + ({ + type: "object", + properties: { + chunkId: { + type: "string", + title: "Chunk ID", + description: "Unique identifier for this chunk", + }, + doc_id: { + type: "string", + title: "Document ID", + description: "ID of the parent document", + }, + text: { + type: "string", + title: "Text", + description: "Text content of the chunk", + }, + nodePath: { + type: "array", + items: { type: "string" }, + title: "Node Path", + description: "Node IDs from root to leaf", + }, + depth: { + type: "integer", + title: "Depth", + description: "Depth in the document tree", + }, + enrichment: ChunkEnrichmentSchema, + }, + required: ["chunkId", "doc_id", "text", "nodePath", "depth"], + additionalProperties: false, + }) as const satisfies DataPortSchema; + +export type ChunkNode = FromSchema>; + +// ============================================================================= +// Chunk Metadata Schemas (for vector store) +// ============================================================================= + +/** + * Schema for chunk metadata stored in vector database + * This is the metadata output from ChunkToVectorTask + */ +export const ChunkMetadataSchema = { + type: "object", + properties: { + doc_id: { + type: "string", + title: "Document ID", + description: "ID of the parent document", + }, + chunkId: { + type: "string", + title: "Chunk ID", + description: "Unique identifier for this chunk", + }, + leafNodeId: { + type: "string", + title: "Leaf Node ID", + description: "ID of the leaf node this chunk belongs to", + }, + depth: { + type: "integer", + title: "Depth", + description: "Depth in the document tree", + }, + text: { + type: "string", + title: "Text", + description: "Text content of the chunk", + }, + nodePath: { + type: "array", + items: { type: "string" }, + title: "Node Path", + description: "Node IDs from root to leaf", + }, + summary: { + type: "string", + title: "Summary", + description: "Summary of the chunk content", + }, + entities: { + type: "array", + items: EntitySchema, + title: "Entities", + description: "Named entities extracted from the chunk", + }, + }, + required: ["doc_id", "chunkId", "leafNodeId", "depth", "text", "nodePath"], + additionalProperties: true, +} as const satisfies DataPortSchema; + +export type ChunkMetadata = FromSchema; + +/** + * Schema for chunk metadata array (for use in task schemas) + */ +export const ChunkMetadataArraySchema = { + type: "array", + items: ChunkMetadataSchema, + title: "Chunk Metadata", + description: "Metadata for each chunk", +} as const satisfies JsonSchema; + +/** + * Schema for enriched chunk metadata (after HierarchyJoinTask) + * Extends ChunkMetadata with hierarchy information from document repository + */ +export const EnrichedChunkMetadataSchema = { + type: "object", + properties: { + doc_id: { + type: "string", + title: "Document ID", + description: "ID of the parent document", + }, + chunkId: { + type: "string", + title: "Chunk ID", + description: "Unique identifier for this chunk", + }, + leafNodeId: { + type: "string", + title: "Leaf Node ID", + description: "ID of the leaf node this chunk belongs to", + }, + depth: { + type: "integer", + title: "Depth", + description: "Depth in the document tree", + }, + text: { + type: "string", + title: "Text", + description: "Text content of the chunk", + }, + nodePath: { + type: "array", + items: { type: "string" }, + 
title: "Node Path", + description: "Node IDs from root to leaf", + }, + summary: { + type: "string", + title: "Summary", + description: "Summary of the chunk content", + }, + entities: { + type: "array", + items: EntitySchema, + title: "Entities", + description: "Named entities (rolled up from hierarchy)", + }, + parentSummaries: { + type: "array", + items: { type: "string" }, + title: "Parent Summaries", + description: "Summaries from ancestor nodes", + }, + sectionTitles: { + type: "array", + items: { type: "string" }, + title: "Section Titles", + description: "Titles of ancestor section nodes", + }, + }, + required: ["doc_id", "chunkId", "leafNodeId", "depth", "text", "nodePath"], + additionalProperties: true, +} as const satisfies DataPortSchema; + +export type EnrichedChunkMetadata = FromSchema; + +/** + * Schema for enriched chunk metadata array (for use in task schemas) + */ +export const EnrichedChunkMetadataArraySchema = { + type: "array", + items: EnrichedChunkMetadataSchema, + title: "Enriched Metadata", + description: "Metadata enriched with hierarchy information", +} as const satisfies JsonSchema; + +/** + * Schema for document metadata + */ +export const DocumentMetadataSchema = { + type: "object", + properties: { + title: { + type: "string", + title: "Title", + description: "Document title", + }, + sourceUri: { + type: "string", + title: "Source URI", + description: "Original source URI of the document", + }, + createdAt: { + type: "string", + title: "Created At", + description: "ISO timestamp of creation", + }, + }, + required: ["title"], + additionalProperties: true, +} as const satisfies DataPortSchema; + +export type DocumentMetadata = FromSchema; diff --git a/packages/storage/src/document/DocumentStorageSchema.ts b/packages/storage/src/document/DocumentStorageSchema.ts new file mode 100644 index 00000000..40e518c6 --- /dev/null +++ b/packages/storage/src/document/DocumentStorageSchema.ts @@ -0,0 +1,43 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + TypedArraySchemaOptions, + type DataPortSchemaObject, + type FromSchema, +} from "@workglow/util"; + +/** + * Schema for storing documents in tabular storage + */ +export const DocumentStorageSchema = { + type: "object", + properties: { + doc_id: { + type: "string", + title: "Document ID", + description: "Unique identifier for the document", + }, + data: { + type: "string", + title: "Document Data", + description: "JSON-serialized document", + }, + metadata: { + type: "object", + title: "Metadata", + description: "Metadata of the document", + }, + }, + required: ["doc_id", "data"], + additionalProperties: true, +} as const satisfies DataPortSchemaObject; +export type DocumentStorageSchema = typeof DocumentStorageSchema; + +export const DocumentStorageKey = ["doc_id"] as const; +export type DocumentStorageKey = typeof DocumentStorageKey; + +export type DocumentStorageEntity = FromSchema; diff --git a/packages/storage/src/document/StructuralParser.ts b/packages/storage/src/document/StructuralParser.ts new file mode 100644 index 00000000..3f66033b --- /dev/null +++ b/packages/storage/src/document/StructuralParser.ts @@ -0,0 +1,254 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { NodeIdGenerator } from "./DocumentNode"; +import { + type DocumentRootNode, + NodeKind, + type ParagraphNode, + type SectionNode, +} from "./DocumentSchema"; + +/** + * Parse markdown into a hierarchical DocumentNode tree + */ 
+export class StructuralParser { + /** + * Parse markdown text into a hierarchical document tree + */ + static async parseMarkdown( + doc_id: string, + text: string, + title: string + ): Promise { + const lines = text.split("\n"); + let currentOffset = 0; + + const root: DocumentRootNode = { + nodeId: await NodeIdGenerator.generateStructuralNodeId(doc_id, NodeKind.DOCUMENT, { + startOffset: 0, + endOffset: text.length, + }), + kind: NodeKind.DOCUMENT, + range: { startOffset: 0, endOffset: text.length }, + text: title, + title, + children: [], + }; + + let currentParentStack: Array = [root]; + let textBuffer: string[] = []; + let textBufferStartOffset = 0; + + const flushTextBuffer = async () => { + if (textBuffer.length > 0) { + const content = textBuffer.join("\n").trim(); + if (content) { + const paragraphStartOffset = textBufferStartOffset; + const paragraphEndOffset = currentOffset; + + const paragraph: ParagraphNode = { + nodeId: await NodeIdGenerator.generateChildNodeId( + currentParentStack[currentParentStack.length - 1].nodeId, + currentParentStack[currentParentStack.length - 1].children.length + ), + kind: NodeKind.PARAGRAPH, + range: { + startOffset: paragraphStartOffset, + endOffset: paragraphEndOffset, + }, + text: content, + }; + + currentParentStack[currentParentStack.length - 1].children.push(paragraph); + } + textBuffer = []; + } + }; + + for (const line of lines) { + const lineLength = line.length + 1; // +1 for newline + + // Check if line is a header + const headerMatch = line.match(/^(#{1,6})\s+(.*)$/); + if (headerMatch) { + await flushTextBuffer(); + + const level = headerMatch[1].length; + const headerTitle = headerMatch[2]; + + // Pop stack until we find appropriate parent + while ( + currentParentStack.length > 1 && + currentParentStack[currentParentStack.length - 1].kind === NodeKind.SECTION && + (currentParentStack[currentParentStack.length - 1] as SectionNode).level >= level + ) { + const poppedSection = currentParentStack.pop() as SectionNode; + // Update endOffset of popped section + const updatedSection: SectionNode = { + ...poppedSection, + range: { + ...poppedSection.range, + endOffset: currentOffset, + }, + }; + // Replace in parent's children + const parent = currentParentStack[currentParentStack.length - 1]; + parent.children[parent.children.length - 1] = updatedSection; + } + + const sectionStartOffset = currentOffset; + const section: SectionNode = { + nodeId: await NodeIdGenerator.generateStructuralNodeId(doc_id, NodeKind.SECTION, { + startOffset: sectionStartOffset, + endOffset: text.length, // Will be updated when section closes + }), + kind: NodeKind.SECTION, + level, + title: headerTitle, + range: { + startOffset: sectionStartOffset, + endOffset: text.length, + }, + text: headerTitle, + children: [], + }; + + currentParentStack[currentParentStack.length - 1].children.push(section); + currentParentStack.push(section); + } else { + // Accumulate text + if (textBuffer.length === 0) { + textBufferStartOffset = currentOffset; + } + textBuffer.push(line); + } + + currentOffset += lineLength; + } + + await flushTextBuffer(); + + // Close any remaining sections + while (currentParentStack.length > 1) { + const section = currentParentStack.pop() as SectionNode; + const updatedSection: SectionNode = { + ...section, + range: { + ...section.range, + endOffset: text.length, + }, + }; + const parent = currentParentStack[currentParentStack.length - 1]; + parent.children[parent.children.length - 1] = updatedSection; + } + + return root; + } + + /** + * Parse 
plain text into a hierarchical document tree + * Splits by double newlines to create paragraphs + */ + static async parsePlainText( + doc_id: string, + text: string, + title: string + ): Promise { + const root: DocumentRootNode = { + nodeId: await NodeIdGenerator.generateStructuralNodeId(doc_id, NodeKind.DOCUMENT, { + startOffset: 0, + endOffset: text.length, + }), + kind: NodeKind.DOCUMENT, + range: { startOffset: 0, endOffset: text.length }, + text: title, + title, + children: [], + }; + + // Split by double newlines to get paragraphs while tracking offsets + const paragraphRegex = /\n\s*\n/g; + let lastIndex = 0; + let paragraphIndex = 0; + let match: RegExpExecArray | null; + + while ((match = paragraphRegex.exec(text)) !== null) { + const rawParagraph = text.slice(lastIndex, match.index); + const paragraphText = rawParagraph.trim(); + + if (paragraphText.length > 0) { + const trimmedRelativeStart = rawParagraph.indexOf(paragraphText); + const startOffset = lastIndex + trimmedRelativeStart; + const endOffset = startOffset + paragraphText.length; + + const paragraph: ParagraphNode = { + nodeId: await NodeIdGenerator.generateChildNodeId(root.nodeId, paragraphIndex), + kind: NodeKind.PARAGRAPH, + range: { + startOffset, + endOffset, + }, + text: paragraphText, + }; + + root.children.push(paragraph); + paragraphIndex++; + } + + lastIndex = paragraphRegex.lastIndex; + } + + // Handle trailing paragraph after the last double newline, if any + if (lastIndex < text.length) { + const rawParagraph = text.slice(lastIndex); + const paragraphText = rawParagraph.trim(); + + if (paragraphText.length > 0) { + const trimmedRelativeStart = rawParagraph.indexOf(paragraphText); + const startOffset = lastIndex + trimmedRelativeStart; + const endOffset = startOffset + paragraphText.length; + + const paragraph: ParagraphNode = { + nodeId: await NodeIdGenerator.generateChildNodeId(root.nodeId, paragraphIndex), + kind: NodeKind.PARAGRAPH, + range: { + startOffset, + endOffset, + }, + text: paragraphText, + }; + + root.children.push(paragraph); + } + } + return root; + } + + /** + * Auto-detect format and parse + */ + static parse( + doc_id: string, + text: string, + title: string, + format?: "markdown" | "text" + ): Promise { + if (format === "markdown" || (!format && this.looksLikeMarkdown(text))) { + return this.parseMarkdown(doc_id, text, title); + } + return this.parsePlainText(doc_id, text, title); + } + + /** + * Check if text contains markdown header patterns + * Looks for lines starting with 1-6 hash symbols followed by whitespace + */ + private static looksLikeMarkdown(text: string): boolean { + // Check for markdown header patterns: line starting with # followed by space + return /^#{1,6}\s/m.test(text); + } +} diff --git a/packages/storage/src/tabular/ITabularRepository.ts b/packages/storage/src/tabular/ITabularRepository.ts index d7552e5c..1064ed7c 100644 --- a/packages/storage/src/tabular/ITabularRepository.ts +++ b/packages/storage/src/tabular/ITabularRepository.ts @@ -4,7 +4,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { DataPortSchemaObject, EventParameters, FromSchema } from "@workglow/util"; +import { + DataPortSchemaObject, + EventParameters, + FromSchema, + TypedArraySchemaOptions, +} from "@workglow/util"; // Generic type for possible value types in the repository export type ValueOptionType = string | number | bigint | boolean | null | Uint8Array; @@ -56,7 +61,7 @@ export interface TabularSubscribeOptions { } // Type definitions for specialized string types -export type uuid4 = 
string; +export type uuid4 = string & { readonly __brand: "uuid4" }; export type JSONValue = | string | number @@ -109,6 +114,16 @@ export function isSearchCondition(value: unknown): value is SearchCondition, +> = Entity extends any ? Pick> : never; + /** * Interface defining the contract for tabular storage repositories. * Provides a flexible interface for storing and retrieving data with typed @@ -121,9 +136,8 @@ export interface ITabularRepository< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types - Entity = FromSchema, - PrimaryKey = Pick, - Value = Omit, + Entity = FromSchema, + PrimaryKey = SimplifyPrimaryKey, > { // Core methods put(value: Entity): Promise; @@ -197,3 +211,5 @@ export interface ITabularRepository< [Symbol.dispose](): void; [Symbol.asyncDispose](): Promise; } + +export type AnyTabularRepository = ITabularRepository; diff --git a/packages/storage/src/tabular/README.md b/packages/storage/src/tabular/README.md index f4f8ec89..df169cda 100644 --- a/packages/storage/src/tabular/README.md +++ b/packages/storage/src/tabular/README.md @@ -77,7 +77,7 @@ TypeBox schemas are JSON Schema compatible and can be used directly: ```typescript import { InMemoryTabularRepository } from "@workglow/storage/tabular"; import { Type, Static } from "@sinclair/typebox"; -import { DataPortSchemaObject, FromSchema, IncludeProps, ExcludeProps } from "@workglow/util"; +import { DataPortSchemaObject, FromSchema } from "@workglow/util"; // Define schema using TypeBox const userSchema = Type.Object({ diff --git a/packages/storage/src/tabular/TabularRepositoryRegistry.ts b/packages/storage/src/tabular/TabularRepositoryRegistry.ts new file mode 100644 index 00000000..96db1972 --- /dev/null +++ b/packages/storage/src/tabular/TabularRepositoryRegistry.ts @@ -0,0 +1,79 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + createServiceToken, + globalServiceRegistry, + registerInputResolver, + ServiceRegistry, +} from "@workglow/util"; +import { AnyTabularRepository } from "./ITabularRepository"; + +/** + * Service token for the tabular repository registry + * Maps repository IDs to ITabularRepository instances + */ +export const TABULAR_REPOSITORIES = createServiceToken>( + "storage.tabular.repositories" +); + +// Register default factory if not already registered +if (!globalServiceRegistry.has(TABULAR_REPOSITORIES)) { + globalServiceRegistry.register( + TABULAR_REPOSITORIES, + (): Map => new Map(), + true + ); +} + +/** + * Gets the global tabular repository registry + * @returns Map of tabular repository ID to instance + */ +export function getGlobalTabularRepositories(): Map { + return globalServiceRegistry.get(TABULAR_REPOSITORIES); +} + +/** + * Registers a tabular repository globally by ID + * @param id The unique identifier for this repository + * @param repository The repository instance to register + */ +export function registerTabularRepository(id: string, repository: AnyTabularRepository): void { + const repos = getGlobalTabularRepositories(); + repos.set(id, repository); +} + +/** + * Gets a tabular repository by ID from the global registry + * @param id The repository identifier + * @returns The repository instance or undefined if not found + */ +export function getTabularRepository(id: string): AnyTabularRepository | undefined { + return getGlobalTabularRepositories().get(id); +} + +/** + * Resolves a repository ID to an instance from the registry. + * Used by the input resolver system. 
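+ *
+ * Conceptually (the registry key and schema below are illustrative):
+ * ```typescript
+ * registerTabularRepository("users", new InMemoryTabularRepository(userSchema, ["id"]));
+ * // A task input declared with TypeTabularRepository(...) and given the value "users"
+ * // is looked up here and replaced with the registered instance before the task runs.
+ * ```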
+ */ +function resolveRepositoryFromRegistry( + id: string, + format: string, + registry: ServiceRegistry +): AnyTabularRepository { + const repos = registry.has(TABULAR_REPOSITORIES) + ? registry.get(TABULAR_REPOSITORIES) + : getGlobalTabularRepositories(); + const repo = repos.get(id); + if (!repo) { + throw new Error(`Tabular repository "${id}" not found in registry`); + } + return repo; +} + +// Register the repository resolver for format: "repository:tabular" +registerInputResolver("repository:tabular", resolveRepositoryFromRegistry); diff --git a/packages/storage/src/util/RepositorySchema.ts b/packages/storage/src/util/RepositorySchema.ts new file mode 100644 index 00000000..9d4805b2 --- /dev/null +++ b/packages/storage/src/util/RepositorySchema.ts @@ -0,0 +1,96 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { JsonSchema } from "@workglow/util"; + +/** + * Semantic format types for repository schema annotations. + * These are used by the InputResolver to determine how to resolve string IDs. + */ +export type RepositorySemantic = + | "repository:tabular" + | "repository:document-node-vector" + | "repository:document"; + +/** + * Creates a JSON schema for a tabular repository input. + * The schema accepts either a string ID (resolved from registry) or a direct repository instance. + * + * @param options Additional schema options to merge + * @returns JSON schema for tabular repository input + * + * @example + * ```typescript + * const inputSchema = { + * type: "object", + * properties: { + * dataSource: TypeTabularRepository({ + * title: "User Database", + * description: "Repository containing user records", + * }), + * }, + * required: ["dataSource"], + * } as const; + * ``` + */ +export function TypeTabularRepository = {}>( + options: O = {} as O +) { + return { + title: "Tabular Repository", + description: "Repository ID or instance for tabular data storage", + ...options, + format: "repository:tabular" as const, + oneOf: [ + { type: "string" as const, title: "Repository ID" }, + { title: "Repository Instance", additionalProperties: true }, + ], + } as const satisfies JsonSchema; +} + +/** + * Creates a JSON schema for a vector repository input. + * The schema accepts either a string ID (resolved from registry) or a direct repository instance. + * + * @param options Additional schema options to merge + * @returns JSON schema for vector repository input + */ +export function TypeDocumentNodeVectorRepository = {}>( + options: O = {} as O +) { + return { + title: "Document Chunk Vector Repository", + description: "Repository ID or instance for document chunk vector data storage", + ...options, + format: "repository:document-node-vector" as const, + anyOf: [ + { type: "string" as const, title: "Repository ID" }, + { title: "Repository Instance", additionalProperties: true }, + ], + } as const satisfies JsonSchema; +} + +/** + * Creates a JSON schema for a document repository input. + * The schema accepts either a string ID (resolved from registry) or a direct repository instance. 
+ * + * @param options Additional schema options to merge + * @returns JSON schema for document repository input + */ +export function TypeDocumentRepository = {}>( + options: O = {} as O +) { + return { + title: "Document Repository", + description: "Repository ID or instance for document data storage", + ...options, + format: "repository:document" as const, + anyOf: [ + { type: "string" as const, title: "Repository ID" }, + { title: "Repository Instance", additionalProperties: true }, + ], + } as const satisfies JsonSchema; +} diff --git a/packages/task-graph/src/common.ts b/packages/task-graph/src/common.ts index 4184e91e..776cc7a5 100644 --- a/packages/task-graph/src/common.ts +++ b/packages/task-graph/src/common.ts @@ -8,6 +8,7 @@ export * from "./task/ArrayTask"; export * from "./task/ConditionalTask"; export * from "./task/GraphAsTask"; export * from "./task/GraphAsTaskRunner"; +export * from "./task/InputResolver"; export * from "./task/InputTask"; export * from "./task/ITask"; export * from "./task/JobQueueFactory"; diff --git a/packages/task-graph/src/task-graph/TaskGraph.ts b/packages/task-graph/src/task-graph/TaskGraph.ts index 88e71ab6..5a5cbe61 100644 --- a/packages/task-graph/src/task-graph/TaskGraph.ts +++ b/packages/task-graph/src/task-graph/TaskGraph.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { DirectedAcyclicGraph, EventEmitter, uuid4 } from "@workglow/util"; +import { DirectedAcyclicGraph, EventEmitter, ServiceRegistry, uuid4 } from "@workglow/util"; import { TaskOutputRepository } from "../storage/TaskOutputRepository"; import type { ITask } from "../task/ITask"; import { JsonTaskItem, TaskGraphJson } from "../task/TaskJSON"; @@ -37,6 +37,8 @@ export interface TaskGraphRunConfig { outputCache?: TaskOutputRepository | boolean; /** Optional signal to abort the task graph */ parentSignal?: AbortSignal; + /** Optional service registry to use for this task graph (creates child from global if not provided) */ + registry?: ServiceRegistry; } class TaskGraphDAG extends DirectedAcyclicGraph< diff --git a/packages/task-graph/src/task-graph/TaskGraphRunner.ts b/packages/task-graph/src/task-graph/TaskGraphRunner.ts index 94bd630d..ddbd8845 100644 --- a/packages/task-graph/src/task-graph/TaskGraphRunner.ts +++ b/packages/task-graph/src/task-graph/TaskGraphRunner.ts @@ -8,13 +8,14 @@ import { collectPropertyValues, ConvertAllToOptionalArray, globalServiceRegistry, + ServiceRegistry, uuid4, } from "@workglow/util"; import { TASK_OUTPUT_REPOSITORY, TaskOutputRepository } from "../storage/TaskOutputRepository"; import { ConditionalTask } from "../task/ConditionalTask"; import { ITask } from "../task/ITask"; import { TaskAbortedError, TaskConfigurationError, TaskError } from "../task/TaskError"; -import { Provenance, TaskInput, TaskOutput, TaskStatus } from "../task/TaskTypes"; +import { TaskInput, TaskOutput, TaskStatus } from "../task/TaskTypes"; import { DATAFLOW_ALL_PORTS } from "./Dataflow"; import { TaskGraph, TaskGraphRunConfig } from "./TaskGraph"; import { DependencyBasedScheduler, TopologicalScheduler } from "./TaskGraphScheduler"; @@ -68,6 +69,10 @@ export class TaskGraphRunner { * Output cache repository */ protected outputCache?: TaskOutputRepository; + /** + * Service registry for this graph run + */ + protected registry: ServiceRegistry = globalServiceRegistry; /** * AbortController for cancelling graph execution */ @@ -130,10 +135,7 @@ export class TaskGraphRunner { // Only filter input for non-root tasks; root tasks get the full input const 
taskInput = isRootTask ? input : this.filterInputForTask(task, input); - const taskPromise = this.runTask( - task, - taskInput - ); + const taskPromise = this.runTask(task, taskInput); this.inProgressTasks!.set(task.config.id, taskPromise); const taskResult = await taskPromise; @@ -481,6 +483,7 @@ export class TaskGraphRunner { outputCache: this.outputCache, updateProgress: async (task: ITask, progress: number, message?: string, ...args: any[]) => await this.handleProgress(task, progress, message, ...args), + registry: this.registry, }); await this.pushOutputFromNodeToEdges(task, results); @@ -535,10 +538,18 @@ export class TaskGraphRunner { * @param parentSignal Optional abort signal from parent */ protected async handleStart(config?: TaskGraphRunConfig): Promise { + // Setup registry - create child from global if not provided + if (config?.registry !== undefined) { + this.registry = config.registry; + } else { + // Create a child container that inherits from global but allows overrides + this.registry = new ServiceRegistry(globalServiceRegistry.container.createChildContainer()); + } + if (config?.outputCache !== undefined) { if (typeof config.outputCache === "boolean") { if (config.outputCache === true) { - this.outputCache = globalServiceRegistry.get(TASK_OUTPUT_REPOSITORY); + this.outputCache = this.registry.get(TASK_OUTPUT_REPOSITORY); } else { this.outputCache = undefined; } diff --git a/packages/task-graph/src/task/ITask.ts b/packages/task-graph/src/task/ITask.ts index 10056be6..b24be409 100644 --- a/packages/task-graph/src/task/ITask.ts +++ b/packages/task-graph/src/task/ITask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { DataPortSchema, EventEmitter } from "@workglow/util"; +import type { DataPortSchema, EventEmitter, ServiceRegistry } from "@workglow/util"; import { TaskOutputRepository } from "../storage/TaskOutputRepository"; import { ITaskGraph } from "../task-graph/ITaskGraph"; import { IWorkflow } from "../task-graph/IWorkflow"; @@ -28,6 +28,7 @@ export interface IExecuteContext { signal: AbortSignal; updateProgress: (progress: number, message?: string, ...args: any[]) => Promise; own: (i: T) => T; + registry: ServiceRegistry; } export type IExecuteReactiveContext = Pick; @@ -36,7 +37,6 @@ export type IExecuteReactiveContext = Pick; * Configuration for running a task */ export interface IRunConfig { - nodeProvenance?: Provenance; outputCache?: TaskOutputRepository | boolean; updateProgress?: ( task: ITask, @@ -44,6 +44,7 @@ export interface IRunConfig { message?: string, ...args: any[] ) => Promise; + registry?: ServiceRegistry; } /** @@ -114,7 +115,7 @@ export interface ITaskIO { addInput(overrides: Record | undefined): boolean; validateInput(input: Record): Promise; get cacheable(): boolean; - narrowInput(input: Record): Promise>; + narrowInput(input: Record, registry: ServiceRegistry): Promise>; } export interface ITaskInternalGraph { diff --git a/packages/task-graph/src/task/InputResolver.ts b/packages/task-graph/src/task/InputResolver.ts new file mode 100644 index 00000000..282a519a --- /dev/null +++ b/packages/task-graph/src/task/InputResolver.ts @@ -0,0 +1,113 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { DataPortSchema, ServiceRegistry } from "@workglow/util"; +import { getInputResolvers } from "@workglow/util"; + +/** + * Configuration for the input resolver + */ +export interface InputResolverConfig { + readonly registry: ServiceRegistry; +} + +/** + * Extracts the 
format string from a schema, handling oneOf/anyOf wrappers. + */ +function getSchemaFormat(schema: unknown): string | undefined { + if (typeof schema !== "object" || schema === null) return undefined; + + const s = schema as Record; + + // Direct format + if (typeof s.format === "string") return s.format; + + // Check oneOf/anyOf for format + const variants = (s.oneOf ?? s.anyOf) as unknown[] | undefined; + if (Array.isArray(variants)) { + for (const variant of variants) { + if (typeof variant === "object" && variant !== null) { + const v = variant as Record; + if (typeof v.format === "string") return v.format; + } + } + } + + return undefined; +} + +/** + * Gets the format prefix from a format string. + * For "model:TextEmbedding" returns "model" + * For "repository:tabular" returns "repository" + */ +function getFormatPrefix(format: string): string { + const colonIndex = format.indexOf(":"); + return colonIndex >= 0 ? format.substring(0, colonIndex) : format; +} + +/** + * Resolves schema-annotated inputs by looking up string IDs from registries. + * String values with matching format annotations are resolved to their instances. + * Non-string values (objects/instances) are passed through unchanged. + * + * @param input The task input object + * @param schema The task's input schema + * @param config Configuration including the service registry + * @returns The input with resolved values + * + * @example + * ```typescript + * // In TaskRunner.run() + * const resolvedInput = await resolveSchemaInputs( + * this.task.runInputData, + * (this.task.constructor as typeof Task).inputSchema(), + * { registry: this.registry } + * ); + * ``` + */ +export async function resolveSchemaInputs>( + input: T, + schema: DataPortSchema, + config: InputResolverConfig +): Promise { + if (typeof schema === "boolean") return input; + + const properties = schema.properties; + if (!properties || typeof properties !== "object") return input; + + const resolvers = getInputResolvers(); + const resolved: Record = { ...input }; + + for (const [key, propSchema] of Object.entries(properties)) { + const value = resolved[key]; + + const format = getSchemaFormat(propSchema); + if (!format) continue; + + // Try full format first (e.g., "repository:document-node-vector"), then fall back to prefix (e.g., "repository") + let resolver = resolvers.get(format); + if (!resolver) { + const prefix = getFormatPrefix(format); + resolver = resolvers.get(prefix); + } + + if (!resolver) continue; + + // Handle string values + if (typeof value === "string") { + resolved[key] = await resolver(value, format, config.registry); + } + // Handle arrays of strings - pass the entire array to the resolver + // (resolvers like resolveModelFromRegistry handle arrays even though typed as string) + else if (Array.isArray(value) && value.every((item) => typeof item === "string")) { + resolved[key] = await resolver(value as unknown as string, format, config.registry); + } + // Skip if not a string or array of strings (already resolved or direct instance) + } + + return resolved as T; +} diff --git a/packages/task-graph/src/task/README.md b/packages/task-graph/src/task/README.md index 18b17812..a767d9dd 100644 --- a/packages/task-graph/src/task/README.md +++ b/packages/task-graph/src/task/README.md @@ -13,6 +13,7 @@ This module provides a flexible task processing system with support for various - [Event Handling](#event-handling) - [Input/Output Schemas](#inputoutput-schemas) - [Registry \& Queues](#registry--queues) +- [Input 
Resolution](#input-resolution) - [Error Handling](#error-handling) - [Testing](#testing) - [Installation](#installation) @@ -241,6 +242,75 @@ const queue = getTaskQueueRegistry().getQueue("processing"); queue.add(new MyJobTask()); ``` +## Input Resolution + +The TaskRunner automatically resolves schema-annotated string inputs to their corresponding instances before task execution. This allows tasks to accept either string identifiers (like `"my-model"` or `"my-repository"`) or direct object instances, providing flexibility in how tasks are configured. + +### How It Works + +When a task's input schema includes properties with `format` annotations (such as `"model"`, `"model:TaskName"`, or `"repository:tabular"`), the TaskRunner inspects each input property: + +- **String values** are looked up in the appropriate registry and resolved to instances +- **Object values** (already instances) pass through unchanged + +This resolution happens automatically before `validateInput()` is called, so by the time `execute()` runs, all annotated inputs are guaranteed to be resolved objects. + +### Example: Task with Repository Input + +```typescript +import { Task } from "@workglow/task-graph"; +import { TypeTabularRepository } from "@workglow/storage"; + +class DataProcessingTask extends Task<{ repository: ITabularRepository; query: string }> { + static readonly type = "DataProcessingTask"; + + static inputSchema() { + return { + type: "object", + properties: { + repository: TypeTabularRepository({ + title: "Data Source", + description: "Repository to query", + }), + query: { type: "string", title: "Query" }, + }, + required: ["repository", "query"], + }; + } + + async execute(input: DataProcessingTaskInput, context: IExecuteContext) { + // repository is guaranteed to be an ITabularRepository instance + const data = await input.repository.getAll(); + return { results: data }; + } +} + +// Usage with string ID (resolved automatically) +const task = new DataProcessingTask(); +await task.run({ repository: "my-registered-repo", query: "test" }); + +// Usage with direct instance (passed through) +await task.run({ repository: myRepositoryInstance, query: "test" }); +``` + +### Registering Custom Resolvers + +Extend the input resolution system by registering custom resolvers for new format prefixes: + +```typescript +import { registerInputResolver } from "@workglow/util"; + +// Register a resolver for "config:*" formats +registerInputResolver("config", async (id, format, registry) => { + const configRepo = registry.get(CONFIG_REPOSITORY); + const config = await configRepo.findById(id); + if (!config) { + throw new Error(`Configuration "${id}" not found`); + } + return config; +}); +``` + ## Error Handling ```typescript diff --git a/packages/task-graph/src/task/Task.ts b/packages/task-graph/src/task/Task.ts index bed88b1a..20ba68fc 100644 --- a/packages/task-graph/src/task/Task.ts +++ b/packages/task-graph/src/task/Task.ts @@ -11,6 +11,7 @@ import { SchemaNode, uuid4, type DataPortSchema, + type ServiceRegistry, } from "@workglow/util"; import { DATAFLOW_ALL_PORTS } from "../task-graph/Dataflow"; import { TaskGraph } from "../task-graph/TaskGraph"; @@ -590,9 +591,13 @@ export class Task< /** * Stub for narrowing input. Override in subclasses for custom logic. 
* @param input The input to narrow + * @param _registry Optional service registry for lookups * @returns The (possibly narrowed) input */ - public async narrowInput(input: Record): Promise> { + public async narrowInput( + input: Record, + _registry: ServiceRegistry + ): Promise> { return input; } diff --git a/packages/task-graph/src/task/TaskRunner.ts b/packages/task-graph/src/task/TaskRunner.ts index a5621904..f49295ed 100644 --- a/packages/task-graph/src/task/TaskRunner.ts +++ b/packages/task-graph/src/task/TaskRunner.ts @@ -4,11 +4,13 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { globalServiceRegistry } from "@workglow/util"; +import { globalServiceRegistry, ServiceRegistry } from "@workglow/util"; import { TASK_OUTPUT_REPOSITORY, TaskOutputRepository } from "../storage/TaskOutputRepository"; import { ensureTask, type Taskish } from "../task-graph/Conversions"; +import { resolveSchemaInputs } from "./InputResolver"; import { IRunConfig, ITask } from "./ITask"; import { ITaskRunner } from "./ITaskRunner"; +import { Task } from "./Task"; import { TaskAbortedError, TaskError, TaskFailedError, TaskInvalidInputError } from "./TaskError"; import { TaskConfig, TaskInput, TaskOutput, TaskStatus } from "./TaskTypes"; @@ -42,6 +44,11 @@ export class TaskRunner< */ protected outputCache?: TaskOutputRepository; + /** + * The service registry for the task + */ + protected registry: ServiceRegistry = globalServiceRegistry; + /** * Constructor for TaskRunner * @param task The task to run @@ -67,6 +74,15 @@ export class TaskRunner< try { this.task.setInput(overrides); + + // Resolve schema-annotated inputs (models, repositories) before validation + const schema = (this.task.constructor as typeof Task).inputSchema(); + this.task.runInputData = (await resolveSchemaInputs( + this.task.runInputData as Record, + schema, + { registry: this.registry } + )) as Input; + const isValid = await this.task.validateInput(this.task.runInputData); if (!isValid) { throw new TaskInvalidInputError("Invalid input data"); @@ -113,6 +129,14 @@ export class TaskRunner< } this.task.setInput(overrides); + // Resolve schema-annotated inputs (models, repositories) before validation + const schema = (this.task.constructor as typeof Task).inputSchema(); + this.task.runInputData = (await resolveSchemaInputs( + this.task.runInputData as Record, + schema, + { registry: this.registry } + )) as Input; + await this.handleStartReactive(); try { @@ -164,6 +188,7 @@ export class TaskRunner< signal: this.abortController!.signal, updateProgress: this.handleProgress.bind(this), own: this.own, + registry: this.registry, }); return await this.executeTaskReactive(input, result || ({} as Output)); } @@ -211,6 +236,10 @@ export class TaskRunner< this.updateProgress = config.updateProgress; } + if (config.registry) { + this.registry = config.registry; + } + this.task.emit("start"); this.task.emit("status", this.task.status); } diff --git a/packages/test/src/test/rag/Document.test.ts b/packages/test/src/test/rag/Document.test.ts new file mode 100644 index 00000000..e89195dd --- /dev/null +++ b/packages/test/src/test/rag/Document.test.ts @@ -0,0 +1,52 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { ChunkNode, DocumentNode } from "@workglow/storage"; +import { Document, NodeKind } from "@workglow/storage"; +import { describe, expect, test } from "vitest"; + +describe("Document", () => { + const createTestDocumentNode = (): DocumentNode => ({ + nodeId: "root", + kind: 
NodeKind.DOCUMENT, + range: { startOffset: 0, endOffset: 100 }, + text: "Test document stuff", + title: "Test document", + children: [], + }); + + const createTestChunks = (): ChunkNode[] => [ + { + chunkId: "chunk1", + doc_id: "doc1", + text: "Test chunk", + nodePath: ["root"], + depth: 1, + }, + ]; + + test("setChunks and getChunks", () => { + const doc = new Document("doc1", createTestDocumentNode(), { title: "Test" }); + + doc.setChunks(createTestChunks()); + + const chunks = doc.getChunks(); + expect(chunks).toBeDefined(); + expect(chunks.length).toBe(1); + expect(chunks[0].text).toBe("Test chunk"); + }); + + test("findChunksByNodeId", () => { + const doc = new Document("doc1", createTestDocumentNode(), { title: "Test" }); + + doc.setChunks(createTestChunks()); + + const chunks = doc.findChunksByNodeId("root"); + expect(chunks).toBeDefined(); + expect(chunks.length).toBe(1); + expect(chunks[0].text).toBe("Test chunk"); + }); +}); diff --git a/packages/test/src/test/rag/DocumentNodeRetrievalTask.test.ts b/packages/test/src/test/rag/DocumentNodeRetrievalTask.test.ts new file mode 100644 index 00000000..754f119a --- /dev/null +++ b/packages/test/src/test/rag/DocumentNodeRetrievalTask.test.ts @@ -0,0 +1,295 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { retrieval } from "@workglow/ai"; +import { + InMemoryDocumentNodeVectorRepository, + registerDocumentNodeVectorRepository, +} from "@workglow/storage"; +import { afterEach, beforeEach, describe, expect, test } from "vitest"; + +describe("DocumentNodeRetrievalTask", () => { + let repo: InMemoryDocumentNodeVectorRepository; + + beforeEach(async () => { + repo = new InMemoryDocumentNodeVectorRepository(3); + await repo.setupDatabase(); + + // Populate repository with test data + const vectors = [ + new Float32Array([1.0, 0.0, 0.0]), + new Float32Array([0.8, 0.2, 0.0]), + new Float32Array([0.0, 1.0, 0.0]), + new Float32Array([0.0, 0.0, 1.0]), + new Float32Array([0.9, 0.1, 0.0]), + ]; + + const metadata = [ + { text: "First chunk about AI" }, + { text: "Second chunk about machine learning" }, + { content: "Third chunk about cooking" }, + { chunk: "Fourth chunk about travel" }, + { text: "Fifth chunk about artificial intelligence" }, + ]; + + for (let i = 0; i < vectors.length; i++) { + const doc_id = `doc${i + 1}`; + await repo.put({ + chunk_id: `${doc_id}_0`, + doc_id, + vector: vectors[i], + metadata: metadata[i], + }); + } + }); + + afterEach(() => { + repo.destroy(); + }); + + test("should retrieve chunks with query vector", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 3, + }); + + expect(result.count).toBe(3); + expect(result.chunks).toHaveLength(3); + expect(result.ids).toHaveLength(3); + expect(result.metadata).toHaveLength(3); + expect(result.scores).toHaveLength(3); + + // Chunks should be extracted from metadata + expect(result.chunks[0]).toBeTruthy(); + expect(typeof result.chunks[0]).toBe("string"); + }); + + test("should extract text from metadata.text field", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 5, + }); + + // Find chunks that have text field + const textChunks = result.chunks.filter((chunk, idx) => { + const meta = result.metadata[idx]; + return meta.text !== undefined; + }); + + expect(textChunks.length).toBeGreaterThan(0); + 
textChunks.forEach((chunk, idx) => { + const originalIdx = result.chunks.indexOf(chunk); + expect(chunk).toBe(result.metadata[originalIdx].text); + }); + }); + + test("should extract text from metadata.content field as fallback", async () => { + const queryVector = new Float32Array([0.0, 1.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 5, + }); + + // Find the chunk with content field + const contentChunkIdx = result.metadata.findIndex((meta) => meta.content !== undefined); + if (contentChunkIdx >= 0) { + expect(result.chunks[contentChunkIdx]).toBe(result.metadata[contentChunkIdx].content); + } + }); + + test("should extract text from metadata.chunk field as fallback", async () => { + const queryVector = new Float32Array([0.0, 0.0, 1.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 5, + }); + + // Find the chunk with chunk field + const chunkIdx = result.metadata.findIndex((meta) => meta.chunk !== undefined); + if (chunkIdx >= 0) { + expect(result.chunks[chunkIdx]).toBe(result.metadata[chunkIdx].chunk); + } + }); + + test("should return vectors when returnVectors is true", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 3, + returnVectors: true, + }); + + expect(result.vectors).toBeDefined(); + expect(result.vectors).toHaveLength(3); + expect(result.vectors![0]).toBeInstanceOf(Float32Array); + }); + + test("should not return vectors when returnVectors is false", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 3, + returnVectors: false, + }); + + expect(result.vectors).toBeUndefined(); + }); + + test("should respect topK parameter", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 2, + }); + + expect(result.count).toBe(2); + expect(result.chunks).toHaveLength(2); + }); + + test("should apply metadata filter", async () => { + // Add a document with specific metadata for filtering + await repo.put({ + chunk_id: "filtered_doc_0", + doc_id: "filtered_doc", + vector: new Float32Array([1.0, 0.0, 0.0]), + metadata: { + text: "Filtered document", + category: "test", + }, + }); + + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 10, + filter: { category: "test" }, + }); + + expect(result.count).toBe(1); + expect(result.ids[0]).toBe("filtered_doc_0"); + }); + + test("should apply score threshold", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 10, + scoreThreshold: 0.9, + }); + + result.scores.forEach((score) => { + expect(score).toBeGreaterThanOrEqual(0.9); + }); + }); + + test("should use queryEmbedding when provided", async () => { + const queryEmbedding = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryEmbedding, + topK: 3, + }); + + expect(result.count).toBe(3); + expect(result.chunks).toHaveLength(3); + }); + + test("should throw error when query is string without model", async () => { + await expect( + // @ts-expect-error - query is string but no model is provided + retrieval({ + repository: repo, + query: "test query 
string", + topK: 3, + }) + ).rejects.toThrow("model"); + }); + + test("should handle default topK value", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + }); + + // Default topK is 5 + expect(result.count).toBe(5); + expect(result.count).toBeLessThanOrEqual(5); + }); + + test("should JSON.stringify metadata when no text/content/chunk fields", async () => { + // Add document with only non-standard metadata + await repo.put({ + chunk_id: "json_doc_0", + doc_id: "json_doc", + vector: new Float32Array([1.0, 0.0, 0.0]), + metadata: { + title: "Title only", + author: "Author name", + }, + }); + + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 10, + }); + + // Find the JSON stringified chunk + const jsonChunk = result.chunks.find((chunk) => chunk.includes("title")); + expect(jsonChunk).toBeDefined(); + expect(jsonChunk).toContain("Title only"); + expect(jsonChunk).toContain("Author name"); + }); + + test("should resolve repository from string ID", async () => { + // Register repository by ID + registerDocumentNodeVectorRepository("test-retrieval-repo", repo); + + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + // Pass repository as string ID instead of instance + const result = await retrieval({ + repository: "test-retrieval-repo" as any, + query: queryVector, + topK: 3, + }); + + expect(result.count).toBe(3); + expect(result.chunks).toHaveLength(3); + expect(result.ids).toHaveLength(3); + expect(result.metadata).toHaveLength(3); + expect(result.scores).toHaveLength(3); + + // Chunks should be extracted from metadata + expect(result.chunks[0]).toBeTruthy(); + expect(typeof result.chunks[0]).toBe("string"); + }); +}); diff --git a/packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts b/packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts new file mode 100644 index 00000000..ccf1b37a --- /dev/null +++ b/packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts @@ -0,0 +1,254 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { DocumentNodeVectorSearchTask } from "@workglow/ai"; +import { + InMemoryDocumentNodeVectorRepository, + registerDocumentNodeVectorRepository, +} from "@workglow/storage"; +import { afterEach, beforeEach, describe, expect, test } from "vitest"; + +describe("DocumentNodeVectorSearchTask", () => { + let repo: InMemoryDocumentNodeVectorRepository; + + beforeEach(async () => { + repo = new InMemoryDocumentNodeVectorRepository(3); + await repo.setupDatabase(); + + // Populate repository with test data + const vectors = [ + new Float32Array([1.0, 0.0, 0.0]), // doc1 - similar to query + new Float32Array([0.8, 0.2, 0.0]), // doc2 - somewhat similar + new Float32Array([0.0, 1.0, 0.0]), // doc3 - different + new Float32Array([0.0, 0.0, 1.0]), // doc4 - different + new Float32Array([0.9, 0.1, 0.0]), // doc5 - very similar + ]; + + const metadata = [ + { text: "Document about AI", category: "tech" }, + { text: "Document about machine learning", category: "tech" }, + { text: "Document about cooking", category: "food" }, + { text: "Document about travel", category: "travel" }, + { text: "Document about artificial intelligence", category: "tech" }, + ]; + + for (let i = 0; i < vectors.length; i++) { + const doc_id = `doc${i + 1}`; + await repo.put({ + chunk_id: `${doc_id}_0`, + doc_id, + vector: 
vectors[i], + metadata: metadata[i], + }); + } + }); + + afterEach(() => { + repo.destroy(); + }); + + test("should search and return top K results", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new DocumentNodeVectorSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + topK: 3, + }); + + expect(result.count).toBe(3); + expect(result.ids).toHaveLength(3); + expect(result.vectors).toHaveLength(3); + expect(result.metadata).toHaveLength(3); + expect(result.scores).toHaveLength(3); + + // Scores should be in descending order + for (let i = 1; i < result.scores.length; i++) { + expect(result.scores[i - 1]).toBeGreaterThanOrEqual(result.scores[i]); + } + + // Most similar should be doc1_0 (exact match) + expect(result.ids[0]).toBe("doc1_0"); + }); + + test("should respect topK limit", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new DocumentNodeVectorSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + topK: 2, + }); + + expect(result.count).toBe(2); + expect(result.ids).toHaveLength(2); + }); + + test("should filter by metadata", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new DocumentNodeVectorSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + topK: 10, + filter: { category: "tech" }, + }); + + expect(result.count).toBeGreaterThan(0); + // All results should have category "tech" + result.metadata.forEach((meta) => { + expect(meta).toHaveProperty("category", "tech"); + }); + }); + + test("should apply score threshold", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new DocumentNodeVectorSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + topK: 10, + scoreThreshold: 0.9, + }); + + // All scores should be >= 0.9 + result.scores.forEach((score) => { + expect(score).toBeGreaterThanOrEqual(0.9); + }); + }); + + test("should return empty results when no matches", async () => { + const queryVector = new Float32Array([0.0, 0.0, 1.0]); + + const task = new DocumentNodeVectorSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + topK: 10, + filter: { category: "nonexistent" }, + }); + + expect(result.count).toBe(0); + expect(result.ids).toHaveLength(0); + expect(result.vectors).toHaveLength(0); + expect(result.metadata).toHaveLength(0); + expect(result.scores).toHaveLength(0); + }); + + test("should handle default topK value", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new DocumentNodeVectorSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + }); + + // Default topK is 10, but we only have 5 documents + expect(result.count).toBe(5); + expect(result.count).toBeLessThanOrEqual(10); + }); + + test("should work with quantized query vectors (Int8Array)", async () => { + const queryVector = new Int8Array([127, 0, 0]); + + const task = new DocumentNodeVectorSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + topK: 3, + }); + + expect(result.count).toBeGreaterThan(0); + expect(result.ids).toHaveLength(result.count); + expect(result.scores).toHaveLength(result.count); + }); + + test("should return results sorted by similarity score", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new 
DocumentNodeVectorSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + topK: 5, + }); + + // Verify descending order + for (let i = 1; i < result.scores.length; i++) { + expect(result.scores[i - 1]).toBeGreaterThanOrEqual(result.scores[i]); + } + }); + + test("should handle empty repository", async () => { + const emptyRepo = new InMemoryDocumentNodeVectorRepository(3); + await emptyRepo.setupDatabase(); + + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new DocumentNodeVectorSearchTask(); + const result = await task.run({ + repository: emptyRepo, + query: queryVector, + topK: 10, + }); + + expect(result.count).toBe(0); + expect(result.ids).toHaveLength(0); + expect(result.scores).toHaveLength(0); + + emptyRepo.destroy(); + }); + + test("should combine filter and score threshold", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new DocumentNodeVectorSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + topK: 10, + filter: { category: "tech" }, + scoreThreshold: 0.7, + }); + + // All results should pass both filter and threshold + result.metadata.forEach((meta) => { + expect(meta).toHaveProperty("category", "tech"); + }); + result.scores.forEach((score) => { + expect(score).toBeGreaterThanOrEqual(0.7); + }); + }); + + test("should resolve repository from string ID", async () => { + // Register repository by ID + registerDocumentNodeVectorRepository("test-vector-repo", repo); + + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new DocumentNodeVectorSearchTask(); + // Pass repository as string ID instead of instance + const result = await task.run({ + repository: "test-vector-repo" as any, + query: queryVector, + topK: 3, + }); + + expect(result.count).toBe(3); + expect(result.ids).toHaveLength(3); + expect(result.vectors).toHaveLength(3); + expect(result.metadata).toHaveLength(3); + expect(result.scores).toHaveLength(3); + + // Most similar should be doc1_0 (exact match) + expect(result.ids[0]).toBe("doc1_0"); + }); +}); diff --git a/packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts b/packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts new file mode 100644 index 00000000..97a46304 --- /dev/null +++ b/packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts @@ -0,0 +1,228 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { DocumentNodeVectorUpsertTask } from "@workglow/ai"; +import { + InMemoryDocumentNodeVectorRepository, + registerDocumentNodeVectorRepository, +} from "@workglow/storage"; +import { afterEach, beforeEach, describe, expect, test } from "vitest"; + +describe("DocumentNodeVectorUpsertTask", () => { + let repo: InMemoryDocumentNodeVectorRepository; + + beforeEach(async () => { + repo = new InMemoryDocumentNodeVectorRepository(3); + await repo.setupDatabase(); + }); + + afterEach(() => { + repo.destroy(); + }); + + test("should upsert a single vector", async () => { + const vector = new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5]); + const metadata = { text: "Test document", source: "test.txt" }; + + const task = new DocumentNodeVectorUpsertTask(); + const result = await task.run({ + repository: repo, + doc_id: "doc1", + vectors: vector, + metadata: metadata, + }); + + expect(result.count).toBe(1); + expect(result.doc_id).toBe("doc1"); + expect(result.ids).toHaveLength(1); + + // Verify vector was stored + const 
retrieved = await repo.get({ chunk_id: result.ids[0] }); + expect(retrieved).toBeDefined(); + expect(retrieved?.doc_id).toBe("doc1"); + expect(retrieved!.metadata).toEqual(metadata); + }); + + test("should upsert multiple vectors in bulk", async () => { + const vectors = [ + new Float32Array([0.1, 0.2, 0.3]), + new Float32Array([0.4, 0.5, 0.6]), + new Float32Array([0.7, 0.8, 0.9]), + ]; + const metadata = { text: "Document with multiple vectors", source: "doc.txt" }; + + const task = new DocumentNodeVectorUpsertTask(); + const result = await task.run({ + repository: repo, + doc_id: "doc1", + vectors: vectors, + metadata: metadata, + }); + + expect(result.count).toBe(3); + expect(result.doc_id).toBe("doc1"); + expect(result.ids).toHaveLength(3); + + // Verify all vectors were stored + for (let i = 0; i < 3; i++) { + const retrieved = await repo.get({ chunk_id: result.ids[i] }); + expect(retrieved).toBeDefined(); + expect(retrieved?.doc_id).toBe("doc1"); + expect(retrieved!.metadata).toEqual(metadata); + } + }); + + test("should handle array of single item (normalized to bulk)", async () => { + const vector = [new Float32Array([0.1, 0.2, 0.3])]; + const metadata = { text: "Single item as array" }; + + const task = new DocumentNodeVectorUpsertTask(); + const result = await task.run({ + repository: repo, + doc_id: "doc1", + vectors: vector, + metadata: metadata, + }); + + expect(result.count).toBe(1); + expect(result.doc_id).toBe("doc1"); + + const retrieved = await repo.get({ chunk_id: result.ids[0] }); + expect(retrieved).toBeDefined(); + expect(retrieved!.metadata).toEqual(metadata); + }); + + test("should update existing vector when upserting with same ID", async () => { + const vector1 = new Float32Array([0.1, 0.2, 0.3]); + const vector2 = new Float32Array([0.9, 0.8, 0.7]); + const metadata1 = { text: "Original document" }; + const metadata2 = { text: "Updated document", source: "updated.txt" }; + + // First upsert + const task1 = new DocumentNodeVectorUpsertTask(); + const result1 = await task1.run({ + repository: repo, + doc_id: "doc1", + vectors: vector1, + metadata: metadata1, + }); + + // Update with same ID + const task2 = new DocumentNodeVectorUpsertTask(); + const result2 = await task2.run({ + repository: repo, + doc_id: "doc1", + vectors: vector2, + metadata: metadata2, + }); + + const retrieved = await repo.get({ chunk_id: result2.ids[0] }); + expect(retrieved).toBeDefined(); + expect(retrieved!.metadata).toEqual(metadata2); + }); + + test("should accept multiple vectors with single metadata", async () => { + const vectors = [new Float32Array([0.1, 0.2]), new Float32Array([0.3, 0.4])]; + const metadata = { text: "Shared metadata" }; + + const task = new DocumentNodeVectorUpsertTask(); + const result = await task.run({ + repository: repo, + doc_id: "doc1", + vectors: vectors, + metadata: metadata, + }); + + expect(result.count).toBe(2); + expect(result.doc_id).toBe("doc1"); + }); + + test("should handle quantized vectors (Int8Array)", async () => { + const vector = new Int8Array([127, -128, 64, -64, 0]); + const metadata = { text: "Quantized vector" }; + + const task = new DocumentNodeVectorUpsertTask(); + const result = await task.run({ + repository: repo, + doc_id: "doc1", + vectors: vector, + metadata: metadata, + }); + + expect(result.count).toBe(1); + + const retrieved = await repo.get({ chunk_id: result.ids[0] }); + expect(retrieved).toBeDefined(); + expect(retrieved?.vector).toBeInstanceOf(Int8Array); + }); + + test("should handle metadata without optional fields", async 
() => { + const vector = new Float32Array([0.1, 0.2, 0.3]); + const metadata = { text: "Simple metadata" }; + + const task = new DocumentNodeVectorUpsertTask(); + const result = await task.run({ + repository: repo, + doc_id: "doc1", + vectors: vector, + metadata: metadata, + }); + + expect(result.count).toBe(1); + + const retrieved = await repo.get({ chunk_id: result.ids[0] }); + expect(retrieved!.metadata).toEqual(metadata); + }); + + test("should handle large batch upsert", async () => { + const count = 100; + const vectors = Array.from( + { length: count }, + (_, i) => new Float32Array([i * 0.01, i * 0.02, i * 0.03]) + ); + const metadata = { text: "Batch document" }; + + const task = new DocumentNodeVectorUpsertTask(); + const result = await task.run({ + repository: repo, + doc_id: "batch-doc", + vectors: vectors, + metadata: metadata, + }); + + expect(result.count).toBe(count); + expect(result.ids).toHaveLength(count); + + const size = await repo.size(); + expect(size).toBe(count); + }); + + test("should resolve repository from string ID", async () => { + // Register repository by ID + registerDocumentNodeVectorRepository("test-upsert-repo", repo); + + const vector = new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5]); + const metadata = { text: "Test document", source: "test.txt" }; + + const task = new DocumentNodeVectorUpsertTask(); + // Pass repository as string ID instead of instance + const result = await task.run({ + repository: "test-upsert-repo" as any, + doc_id: "doc1", + vectors: vector, + metadata: metadata, + }); + + expect(result.count).toBe(1); + expect(result.doc_id).toBe("doc1"); + + // Verify vector was stored + const retrieved = await repo.get({ chunk_id: result.ids[0] }); + expect(retrieved).toBeDefined(); + expect(retrieved?.doc_id).toBe("doc1"); + expect(retrieved!.metadata).toEqual(metadata); + }); +}); diff --git a/packages/test/src/test/rag/DocumentRepository.test.ts b/packages/test/src/test/rag/DocumentRepository.test.ts new file mode 100644 index 00000000..5f6bf291 --- /dev/null +++ b/packages/test/src/test/rag/DocumentRepository.test.ts @@ -0,0 +1,484 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + Document, + DocumentRepository, + DocumentStorageKey, + DocumentStorageSchema, + InMemoryDocumentNodeVectorRepository, + InMemoryTabularRepository, + NodeIdGenerator, + NodeKind, + StructuralParser, +} from "@workglow/storage"; +import { beforeEach, describe, expect, it } from "vitest"; + +describe("DocumentRepository", () => { + let repo: DocumentRepository; + let vectorStorage: InMemoryDocumentNodeVectorRepository; + + beforeEach(async () => { + const tabularStorage = new InMemoryTabularRepository( + DocumentStorageSchema, + DocumentStorageKey + ); + await tabularStorage.setupDatabase(); + + vectorStorage = new InMemoryDocumentNodeVectorRepository(3); + await vectorStorage.setupDatabase(); + + repo = new DocumentRepository(tabularStorage, vectorStorage); + }); + + it("should store and retrieve documents", async () => { + const markdown = "# Test\n\nContent."; + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + const doc = new Document(doc_id, root, { title: "Test Document" }); + + await repo.upsert(doc); + const retrieved = await repo.get(doc_id); + + expect(retrieved).toBeDefined(); + expect(retrieved?.doc_id).toBe(doc_id); + expect(retrieved?.metadata.title).toBe("Test Document"); + }); + + it("should 
retrieve nodes by ID", async () => { + const markdown = "# Section\n\nParagraph."; + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + const doc = new Document(doc_id, root, { title: "Test" }); + await repo.upsert(doc); + + // Get a child node + const firstChild = root.children[0]; + const retrieved = await repo.getNode(doc_id, firstChild.nodeId); + + expect(retrieved).toBeDefined(); + expect(retrieved?.nodeId).toBe(firstChild.nodeId); + }); + + it("should get ancestors of a node", async () => { + const markdown = `# Section 1 + +## Subsection 1.1 + +Paragraph.`; + + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + const doc = new Document(doc_id, root, { title: "Test" }); + await repo.upsert(doc); + + // Find a deeply nested node + const section = root.children.find((c) => c.kind === NodeKind.SECTION); + expect(section).toBeDefined(); + + const subsection = (section as any).children.find((c: any) => c.kind === NodeKind.SECTION); + expect(subsection).toBeDefined(); + + const ancestors = await repo.getAncestors(doc_id, subsection.nodeId); + + // Should include root, section, and subsection + expect(ancestors.length).toBeGreaterThanOrEqual(3); + expect(ancestors[0].nodeId).toBe(root.nodeId); + expect(ancestors[1].nodeId).toBe(section!.nodeId); + expect(ancestors[2].nodeId).toBe(subsection.nodeId); + }); + + it("should handle chunks", async () => { + const markdown = "# Test\n\nContent."; + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + const doc = new Document(doc_id, root, { title: "Test" }); + + // Add chunks + const chunks = [ + { + chunkId: "chunk_1", + doc_id, + text: "Test chunk", + nodePath: [root.nodeId], + depth: 1, + }, + ]; + + doc.setChunks(chunks); + + await repo.upsert(doc); + + // Retrieve chunks + const retrievedChunks = await repo.getChunks(doc_id); + expect(retrievedChunks).toBeDefined(); + expect(retrievedChunks.length).toBe(1); + }); + + it("should list all documents", async () => { + const markdown1 = "# Doc 1"; + const markdown2 = "# Doc 2"; + + const id1 = await NodeIdGenerator.generateDocId("test1", markdown1); + const id2 = await NodeIdGenerator.generateDocId("test2", markdown2); + + const root1 = await StructuralParser.parseMarkdown(id1, markdown1, "Doc 1"); + const root2 = await StructuralParser.parseMarkdown(id2, markdown2, "Doc 2"); + + const doc1 = new Document(id1, root1, { title: "Doc 1" }); + const doc2 = new Document(id2, root2, { title: "Doc 2" }); + + await repo.upsert(doc1); + await repo.upsert(doc2); + + const list = await repo.list(); + expect(list.length).toBe(2); + expect(list).toContain(id1); + expect(list).toContain(id2); + }); + + it("should delete documents", async () => { + const markdown = "# Test"; + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + const doc = new Document(doc_id, root, { title: "Test" }); + await repo.upsert(doc); + + expect(await repo.get(doc_id)).toBeDefined(); + + await repo.delete(doc_id); + + expect(await repo.get(doc_id)).toBeUndefined(); + }); + + it("should return undefined for non-existent document", async () => { + const result = await repo.get("non-existent-doc-id"); + expect(result).toBeUndefined(); + }); 
+ + it("should return undefined for node in non-existent document", async () => { + const result = await repo.getNode("non-existent-doc-id", "some-node-id"); + expect(result).toBeUndefined(); + }); + + it("should return undefined for non-existent node", async () => { + const markdown = "# Test"; + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + const doc = new Document(doc_id, root, { title: "Test" }); + await repo.upsert(doc); + + const result = await repo.getNode(doc_id, "non-existent-node-id"); + expect(result).toBeUndefined(); + }); + + it("should return empty array for ancestors of non-existent document", async () => { + const result = await repo.getAncestors("non-existent-doc-id", "some-node-id"); + expect(result).toEqual([]); + }); + + it("should return empty array for ancestors of non-existent node", async () => { + const markdown = "# Test"; + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + const doc = new Document(doc_id, root, { title: "Test" }); + await repo.upsert(doc); + + const result = await repo.getAncestors(doc_id, "non-existent-node-id"); + expect(result).toEqual([]); + }); + + it("should return empty array for chunks of non-existent document", async () => { + const result = await repo.getChunks("non-existent-doc-id"); + expect(result).toEqual([]); + }); + + it("should return empty list for empty repository", async () => { + // Create fresh empty repo + const tabularStorage = new InMemoryTabularRepository( + DocumentStorageSchema, + DocumentStorageKey + ); + await tabularStorage.setupDatabase(); + const emptyRepo = new DocumentRepository(tabularStorage); + + const result = await emptyRepo.list(); + expect(result).toEqual([]); + }); + + it("should not throw when deleting non-existent document", async () => { + // Just verify delete completes without error + await repo.delete("non-existent-doc-id"); + // If we get here, it didn't throw + expect(true).toBe(true); + }); + + it("should update existing document on upsert", async () => { + const markdown = "# Test"; + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + const doc1 = new Document(doc_id, root, { title: "Original Title" }); + await repo.upsert(doc1); + + const doc2 = new Document(doc_id, root, { title: "Updated Title" }); + await repo.upsert(doc2); + + const retrieved = await repo.get(doc_id); + expect(retrieved?.metadata.title).toBe("Updated Title"); + + const list = await repo.list(); + expect(list.length).toBe(1); + }); + + it("should find chunks by node ID", async () => { + const markdown = "# Test\n\nContent."; + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + const doc = new Document(doc_id, root, { title: "Test" }); + + const chunks = [ + { + chunkId: "chunk_1", + doc_id, + text: "First chunk", + nodePath: [root.nodeId, "child-1"], + depth: 2, + }, + { + chunkId: "chunk_2", + doc_id, + text: "Second chunk", + nodePath: [root.nodeId, "child-2"], + depth: 2, + }, + ]; + doc.setChunks(chunks); + await repo.upsert(doc); + + const result = await repo.findChunksByNodeId(doc_id, root.nodeId); + expect(result.length).toBe(2); + }); + + it("should return empty array for findChunksByNodeId with no matches", 
async () => { + const markdown = "# Test"; + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + const doc = new Document(doc_id, root, { title: "Test" }); + doc.setChunks([]); + await repo.upsert(doc); + + const result = await repo.findChunksByNodeId(doc_id, "non-matching-node"); + expect(result).toEqual([]); + }); + + it("should return empty array for findChunksByNodeId with non-existent document", async () => { + const result = await repo.findChunksByNodeId("non-existent-doc", "some-node"); + expect(result).toEqual([]); + }); + + it("should search with vector storage", async () => { + // Add vectors to vector storage + await vectorStorage.put({ + chunk_id: "chunk_1", + doc_id: "doc1", + vector: new Float32Array([1.0, 0.0, 0.0]), + metadata: { text: "First chunk" }, + }); + await vectorStorage.put({ + chunk_id: "chunk_2", + doc_id: "doc1", + vector: new Float32Array([0.8, 0.2, 0.0]), + metadata: { text: "Second chunk" }, + }); + await vectorStorage.put({ + chunk_id: "chunk_3", + doc_id: "doc2", + vector: new Float32Array([0.0, 1.0, 0.0]), + metadata: { text: "Third chunk" }, + }); + + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const results = await repo.search(queryVector, { topK: 2 }); + + expect(results.length).toBe(2); + expect(results[0].chunk_id).toBe("chunk_1"); + }); + + it("should search with score threshold", async () => { + await vectorStorage.put({ + chunk_id: "chunk_1", + doc_id: "doc1", + vector: new Float32Array([1.0, 0.0, 0.0]), + metadata: { text: "Matching chunk" }, + }); + await vectorStorage.put({ + chunk_id: "chunk_2", + doc_id: "doc1", + vector: new Float32Array([0.0, 1.0, 0.0]), + metadata: { text: "Non-matching chunk" }, + }); + + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const results = await repo.search(queryVector, { topK: 10, scoreThreshold: 0.9 }); + + expect(results.length).toBeGreaterThanOrEqual(1); + results.forEach((r: any) => { + expect(r.score).toBeGreaterThanOrEqual(0.9); + }); + }); + + it("should return empty array for search when no vector storage configured", async () => { + const tabularStorage = new InMemoryTabularRepository( + DocumentStorageSchema, + DocumentStorageKey + ); + await tabularStorage.setupDatabase(); + + const repoWithoutVector = new DocumentRepository(tabularStorage); + + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const results = await repoWithoutVector.search(queryVector); + + expect(results).toEqual([]); + }); +}); + +describe("Document", () => { + it("should manage chunks", async () => { + const markdown = "# Test"; + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + const doc = new Document(doc_id, root, { title: "Test" }); + + const chunks = [ + { + chunkId: "chunk_1", + doc_id, + text: "Chunk 1", + nodePath: [root.nodeId], + depth: 1, + }, + ]; + doc.setChunks(chunks); + + const retrievedChunks = doc.getChunks(); + expect(retrievedChunks.length).toBe(1); + expect(retrievedChunks[0].text).toBe("Chunk 1"); + }); + + it("should serialize and deserialize", async () => { + const markdown = "# Test"; + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + const doc = new Document(doc_id, root, { title: "Test" }); + + const chunks = [ + { + chunkId: "chunk_1", + doc_id, + text: "Chunk", + 
nodePath: [root.nodeId], + depth: 1, + }, + ]; + doc.setChunks(chunks); + + // Serialize + const json = doc.toJSON(); + + // Deserialize + const restored = Document.fromJSON(JSON.stringify(json)); + + expect(restored.doc_id).toBe(doc.doc_id); + expect(restored.metadata.title).toBe(doc.metadata.title); + expect(restored.getChunks().length).toBe(1); + }); + + it("should find chunks by nodeId", async () => { + const markdown = "# Test"; + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + const doc = new Document(doc_id, root, { title: "Test" }); + + const chunks = [ + { + chunkId: "chunk_1", + doc_id, + text: "First", + nodePath: ["root", "section-a"], + depth: 2, + }, + { + chunkId: "chunk_2", + doc_id, + text: "Second", + nodePath: ["root", "section-b"], + depth: 2, + }, + { + chunkId: "chunk_3", + doc_id, + text: "Third", + nodePath: ["root", "section-a", "subsection"], + depth: 3, + }, + ]; + doc.setChunks(chunks); + + // Find chunks containing "section-a" + const result = doc.findChunksByNodeId("section-a"); + expect(result.length).toBe(2); + expect(result.map((c) => c.chunkId)).toContain("chunk_1"); + expect(result.map((c) => c.chunkId)).toContain("chunk_3"); + }); + + it("should return empty array when no chunks match nodeId", async () => { + const markdown = "# Test"; + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + const doc = new Document(doc_id, root, { title: "Test" }); + + const chunks = [ + { + chunkId: "chunk_1", + doc_id, + text: "First", + nodePath: ["root", "section-a"], + depth: 2, + }, + ]; + doc.setChunks(chunks); + + const result = doc.findChunksByNodeId("non-existent-node"); + expect(result).toEqual([]); + }); + + it("should handle empty chunks in findChunksByNodeId", async () => { + const markdown = "# Test"; + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + const doc = new Document(doc_id, root, { title: "Test" }); + doc.setChunks([]); + + const result = doc.findChunksByNodeId("any-node"); + expect(result).toEqual([]); + }); +}); diff --git a/packages/util/README.md b/packages/util/README.md index d0ce9501..7acc1bc8 100644 --- a/packages/util/README.md +++ b/packages/util/README.md @@ -108,6 +108,37 @@ container.register("UserService", UserService); const userService = container.resolve("UserService"); ``` +### Input Resolver Registry + +The input resolver registry enables automatic resolution of string identifiers to object instances based on JSON Schema format annotations. This is used by the TaskRunner to resolve inputs like model names or repository IDs before task execution. 
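
As a rough sketch of the consuming side (the `settings` property name below is a placeholder, not something defined by this package; the `"myformat:subtype"` format string mirrors the registration example that follows), a task input schema opts into resolution simply by carrying a `format` annotation on a property whose value may be either a string ID or a direct instance:

```typescript
// Hypothetical input schema for a task. The "settings" property and the
// "myformat:subtype" format string are illustrative placeholders; only the
// overall shape (format annotation plus string-or-instance variants) reflects
// how resolution is triggered.
const inputSchema = {
  type: "object",
  properties: {
    settings: {
      title: "Settings",
      description: "Settings ID (resolved from the registry) or a direct instance",
      format: "myformat:subtype",
      oneOf: [
        { type: "string", title: "Settings ID" },
        { title: "Settings Instance", additionalProperties: true },
      ],
    },
  },
  required: ["settings"],
} as const;

// With a resolver registered for the "myformat" prefix (see the example below),
// running the task with { settings: "my-item-id" } hands execute() the resolved
// instance, while passing an object directly bypasses resolution.
```

The registration side is shown next.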
+ +```typescript +import { + registerInputResolver, + getInputResolvers, + INPUT_RESOLVERS, +} from "@workglow/util"; + +// Register a custom resolver for a format prefix +registerInputResolver("myformat", async (id, format, registry) => { + // id: the string value to resolve (e.g., "my-item-id") + // format: the full format string (e.g., "myformat:subtype") + // registry: ServiceRegistry for accessing other services + + const myRepo = registry.get(MY_REPOSITORY_TOKEN); + const item = await myRepo.findById(id); + if (!item) { + throw new Error(`Item "${id}" not found`); + } + return item; +}); + +// Get all registered resolvers +const resolvers = getInputResolvers(); +``` + +When a task input schema includes a property with `format: "myformat:subtype"`, and the input value is a string, the resolver is called automatically to convert it to the resolved instance. + ### Event System ```typescript @@ -260,6 +291,7 @@ type User = z.infer; - Decorator-based injection - Singleton and transient lifetimes - Circular dependency detection +- Input resolver registry for schema-based resolution ### Event System (`/events`) diff --git a/packages/util/src/common.ts b/packages/util/src/common.ts index fb650c38..e686798a 100644 --- a/packages/util/src/common.ts +++ b/packages/util/src/common.ts @@ -16,5 +16,9 @@ export * from "./utilities/BaseError"; export * from "./utilities/Misc"; export * from "./utilities/objectOfArraysAsArrayOfObjects"; export * from "./utilities/TypeUtilities"; +export * from "./vector/Tensor"; +export * from "./vector/TypedArray"; +export * from "./vector/VectorSimilarityUtils"; +export * from "./vector/VectorUtils"; export * from "./worker/WorkerManager"; export * from "./worker/WorkerServer"; diff --git a/packages/util/src/di/InputResolverRegistry.ts b/packages/util/src/di/InputResolverRegistry.ts new file mode 100644 index 00000000..064fb9f9 --- /dev/null +++ b/packages/util/src/di/InputResolverRegistry.ts @@ -0,0 +1,83 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { createServiceToken, globalServiceRegistry } from "./ServiceRegistry"; +import type { ServiceRegistry } from "./ServiceRegistry"; + +/** + * A resolver function that converts a string ID to an instance. + * Returns undefined if the resolver cannot handle this format. + * Throws an error if the ID is not found. + * + * @param id The string ID to resolve + * @param format The full format string (e.g., "model:TextEmbedding", "repository:tabular") + * @param registry The service registry to use for lookups + */ +export type InputResolverFn = ( + id: string, + format: string, + registry: ServiceRegistry +) => unknown | Promise; + +/** + * Service token for the input resolver registry. + * Maps format prefixes to resolver functions. + */ +export const INPUT_RESOLVERS = createServiceToken>( + "task.input.resolvers" +); + +// Register default factory if not already registered +if (!globalServiceRegistry.has(INPUT_RESOLVERS)) { + globalServiceRegistry.register( + INPUT_RESOLVERS, + (): Map => new Map(), + true + ); +} + +/** + * Gets the global input resolver registry + * @returns Map of format prefix to resolver function + */ +export function getInputResolvers(): Map { + return globalServiceRegistry.get(INPUT_RESOLVERS); +} + +/** + * Registers an input resolver for a format prefix. + * The resolver will be called for any format that starts with this prefix. 
+ * + * @param formatPrefix The format prefix to match (e.g., "model", "repository") + * @param resolver The resolver function + * + * @example + * ```typescript + * // Register model resolver + * registerInputResolver("model", async (id, format, registry) => { + * const modelRepo = registry.get(MODEL_REPOSITORY); + * const model = await modelRepo.findByName(id); + * if (!model) throw new Error(`Model "${id}" not found`); + * return model; + * }); + * + * // Register repository resolver + * registerInputResolver("repository", (id, format, registry) => { + * const repoType = format.split(":")[1]; // "tabular", "vector", etc. + * if (repoType === "tabular") { + * const repos = registry.get(TABULAR_REPOSITORIES); + * const repo = repos.get(id); + * if (!repo) throw new Error(`Repository "${id}" not found`); + * return repo; + * } + * throw new Error(`Unknown repository type: ${repoType}`); + * }); + * ``` + */ +export function registerInputResolver(formatPrefix: string, resolver: InputResolverFn): void { + const resolvers = getInputResolvers(); + resolvers.set(formatPrefix, resolver); +} diff --git a/packages/util/src/di/ServiceRegistry.ts b/packages/util/src/di/ServiceRegistry.ts index 66aa94b6..eeafa954 100644 --- a/packages/util/src/di/ServiceRegistry.ts +++ b/packages/util/src/di/ServiceRegistry.ts @@ -27,7 +27,7 @@ export function createServiceToken(id: string): ServiceToken { * Service registry for managing and accessing services */ export class ServiceRegistry { - private container: Container; + public container: Container; /** * Create a new service registry diff --git a/packages/util/src/di/index.ts b/packages/util/src/di/index.ts index 4163b35f..a221c727 100644 --- a/packages/util/src/di/index.ts +++ b/packages/util/src/di/index.ts @@ -5,4 +5,5 @@ */ export * from "./Container"; +export * from "./InputResolverRegistry"; export * from "./ServiceRegistry"; diff --git a/packages/util/src/vector/Tensor.ts b/packages/util/src/vector/Tensor.ts new file mode 100644 index 00000000..34897fa7 --- /dev/null +++ b/packages/util/src/vector/Tensor.ts @@ -0,0 +1,62 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { FromSchema } from "../json-schema/FromSchema"; +import { JsonSchema } from "../json-schema/JsonSchema"; +import { TypedArraySchema, TypedArraySchemaOptions } from "./TypedArray"; + +export const TensorType = { + FLOAT16: "float16", + FLOAT32: "float32", + FLOAT64: "float64", + INT8: "int8", + UINT8: "uint8", + INT16: "int16", + UINT16: "uint16", +} as const; + +export type TensorType = (typeof TensorType)[keyof typeof TensorType]; + +/** + * Tensor schema for representing tensors as arrays of numbers + * @param annotations - Additional annotations for the schema + * @returns The tensor schema + */ +export const TensorSchema = (annotations: Record = {}) => + ({ + type: "object", + properties: { + type: { + type: "string", + enum: Object.values(TensorType), + title: "Type", + description: "The type of the tensor", + }, + data: TypedArraySchema({ + title: "Data", + description: "The data of the tensor", + }), + shape: { + type: "array", + items: { type: "number" }, + title: "Shape", + description: "The shape of the tensor (dimensions)", + minItems: 1, + default: [1], + }, + normalized: { + type: "boolean", + title: "Normalized", + description: "Whether the tensor data is normalized", + default: false, + }, + }, + required: ["data"], + additionalProperties: false, + ...annotations, + }) as const satisfies JsonSchema; + +export 
type Tensor = FromSchema, TypedArraySchemaOptions>; diff --git a/packages/util/src/vector/TypedArray.ts b/packages/util/src/vector/TypedArray.ts new file mode 100644 index 00000000..1ec512b6 --- /dev/null +++ b/packages/util/src/vector/TypedArray.ts @@ -0,0 +1,95 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { FromSchema, FromSchemaDefaultOptions, FromSchemaOptions } from "../json-schema/FromSchema"; +import { JsonSchema } from "../json-schema/JsonSchema"; + +/** + * Supported typed array types + * - Float16Array: 16-bit floating point (medium precision) + * - Float32Array: Standard 32-bit floating point (most common) + * - Float64Array: 64-bit floating point (high precision) + * - Int8Array: 8-bit signed integer (binary quantization) + * - Uint8Array: 8-bit unsigned integer (quantization) + * - Int16Array: 16-bit signed integer (quantization) + * - Uint16Array: 16-bit unsigned integer (quantization) + */ +export type TypedArray = + | Float32Array + | Float16Array + | Float64Array + | Int8Array + | Uint8Array + | Int16Array + | Uint16Array; + +export type TypedArrayString = + | "TypedArray" + | "TypedArray:Float16Array" + | "TypedArray:Float32Array" + | "TypedArray:Float64Array" + | "TypedArray:Int8Array" + | "TypedArray:Uint8Array" + | "TypedArray:Int16Array" + | "TypedArray:Uint16Array"; + +// Type-only value for use in deserialize patterns +const TypedArrayType = null as any as TypedArray; + +const TypedArraySchemaOptions = { + ...FromSchemaDefaultOptions, + deserialize: [ + { + pattern: { type: "array", format: "TypedArray:Float64Array" }, + output: Float64Array, + }, + { + pattern: { type: "array", format: "TypedArray:Float32Array" }, + output: Float32Array, + }, + { + pattern: { type: "array", format: "TypedArray:Float16Array" }, + output: Float16Array, + }, + { + pattern: { type: "array", format: "TypedArray:Int16Array" }, + output: Int16Array, + }, + { + pattern: { type: "array", format: "TypedArray:Int8Array" }, + output: Int8Array, + }, + { + pattern: { type: "array", format: "TypedArray:Uint8Array" }, + output: Uint8Array, + }, + { + pattern: { type: "array", format: "TypedArray:Uint16Array" }, + output: Uint16Array, + }, + { + pattern: { type: "array", format: "TypedArray" }, + output: TypedArrayType, + }, + ], +} as const satisfies FromSchemaOptions; + +export type TypedArraySchemaOptions = typeof TypedArraySchemaOptions; + +export type VectorFromSchema = FromSchema< + SCHEMA, + TypedArraySchemaOptions +>; + +export const TypedArraySchema = (annotations: Record = {}) => { + return { + type: "array", + format: "TypedArray", + title: "Typed Array", + description: "A typed array (Float32Array, Int8Array, etc.)", + ...annotations, + } as const satisfies JsonSchema; +}; diff --git a/packages/util/src/vector/VectorSimilarityUtils.ts b/packages/util/src/vector/VectorSimilarityUtils.ts new file mode 100644 index 00000000..74151950 --- /dev/null +++ b/packages/util/src/vector/VectorSimilarityUtils.ts @@ -0,0 +1,92 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { TypedArray } from "./TypedArray"; + +/** + * Calculates cosine similarity between two vectors + * Returns a value between -1 and 1, where 1 means identical direction + */ +export function cosineSimilarity(a: TypedArray, b: TypedArray): number { + if (a.length !== b.length) { + throw new Error("Vectors must have the same length"); + } + let dotProduct = 0; + let normA = 0; + let normB = 0; + for 
(let i = 0; i < a.length; i++) { + dotProduct += a[i] * b[i]; + normA += a[i] * a[i]; + normB += b[i] * b[i]; + } + const denominator = Math.sqrt(normA) * Math.sqrt(normB); + if (denominator === 0) { + return 0; + } + return dotProduct / denominator; +} + +/** + * Calculates Jaccard similarity between two vectors + * Uses the formula: sum(min(a[i], b[i])) / sum(max(a[i], b[i])) + * Returns a value between 0 and 1 + * For negative values, normalizes by finding the global min and shifting to non-negative range + */ +export function jaccardSimilarity(a: TypedArray, b: TypedArray): number { + if (a.length !== b.length) { + throw new Error("Vectors must have the same length"); + } + + // Find global min across both vectors to handle negative values + let globalMin = a[0]; + for (let i = 0; i < a.length; i++) { + globalMin = Math.min(globalMin, a[i], b[i]); + } + + // Shift values to non-negative range if needed + const shift = globalMin < 0 ? -globalMin : 0; + + let minSum = 0; + let maxSum = 0; + + for (let i = 0; i < a.length; i++) { + const shiftedA = a[i] + shift; + const shiftedB = b[i] + shift; + minSum += Math.min(shiftedA, shiftedB); + maxSum += Math.max(shiftedA, shiftedB); + } + + return maxSum === 0 ? 0 : minSum / maxSum; +} + +/** + * Calculates Hamming distance between two vectors (normalized) + * Counts the number of positions where vectors differ + * Returns a value between 0 and 1 (0 = identical, 1 = completely different) + */ +export function hammingDistance(a: TypedArray, b: TypedArray): number { + if (a.length !== b.length) { + throw new Error("Vectors must have the same length"); + } + + let differences = 0; + + for (let i = 0; i < a.length; i++) { + if (a[i] !== b[i]) { + differences++; + } + } + + return differences / a.length; +} + +/** + * Calculates Hamming similarity (inverse of distance) + * Returns a value between 0 and 1 (1 = identical, 0 = completely different) + */ +export function hammingSimilarity(a: TypedArray, b: TypedArray): number { + return 1 - hammingDistance(a, b); +} diff --git a/packages/util/src/vector/VectorUtils.ts b/packages/util/src/vector/VectorUtils.ts new file mode 100644 index 00000000..e7044415 --- /dev/null +++ b/packages/util/src/vector/VectorUtils.ts @@ -0,0 +1,95 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { TypedArray } from "./TypedArray"; + +/** + * Calculates the magnitude (L2 norm) of a vector + */ +export function magnitude(arr: TypedArray | number[]): number { + // @ts-ignore - Vector reduce works but TS doesn't recognize it + return Math.sqrt(arr.reduce((acc, val) => acc + val * val, 0)); +} + +/** + * Calculates the inner (dot) product of two vectors + */ +export function inner(arr1: TypedArray, arr2: TypedArray): number { + if (arr1.length !== arr2.length) { + throw new Error("Vectors must have the same length to compute inner product."); + } + // @ts-ignore - Vector reduce works but TS doesn't recognize it + return arr1.reduce((acc, val, i) => acc + val * arr2[i], 0); +} + +/** + * Normalizes a vector to unit length (L2 normalization) + * + * @param vector - The vector to normalize + * @param throwOnZero - If true, throws an error for zero vectors. If false, returns the original vector. 
+ * @returns Normalized vector with the same type as input + */ +export function normalize(vector: TypedArray, throwOnZero = true, float32 = false): TypedArray { + const mag = magnitude(vector); + + if (mag === 0) { + if (throwOnZero) { + throw new Error("Cannot normalize a zero vector."); + } + return vector; + } + + const normalized = Array.from(vector).map((val) => Number(val) / mag); + + if (float32) { + return new Float32Array(normalized); + } + + // Preserve the original Vector type + if (vector instanceof Float64Array) { + return new Float64Array(normalized); + } + if (vector instanceof Float16Array) { + return new Float16Array(normalized); + } + if (vector instanceof Float32Array) { + return new Float32Array(normalized); + } + if (vector instanceof Int8Array) { + return new Int8Array(normalized); + } + if (vector instanceof Uint8Array) { + return new Uint8Array(normalized); + } + if (vector instanceof Int16Array) { + return new Int16Array(normalized); + } + if (vector instanceof Uint16Array) { + return new Uint16Array(normalized); + } + // For other integer arrays, use Float32Array since normalization produces floats + return new Float32Array(normalized); +} + +/** + * Normalizes an array of numbers to unit length (L2 normalization) + * + * @param values - The array of numbers to normalize + * @param throwOnZero - If true, throws an error for zero vectors. If false, returns the original array. + * @returns Normalized array of numbers + */ +export function normalizeNumberArray(values: number[], throwOnZero = false): number[] { + const norm = magnitude(values); + + if (norm === 0) { + if (throwOnZero) { + throw new Error("Cannot normalize a zero vector."); + } + return values; + } + + return values.map((v) => v / norm); +} From 41a03a63911eb85b90690bcdb4baf6db6f3c5c83 Mon Sep 17 00:00:00 2001 From: Steven Roussey Date: Sun, 11 Jan 2026 07:24:07 +0000 Subject: [PATCH 04/14] [feat] Refactor storage structure to rename queue-limiter and add document-node-vector repositories - Reorganized storage exports to include new queue-limiter implementations for Postgres, Sqlite, and IndexedDb. - Added document-node-vector repositories for Postgres and Sqlite, enhancing document storage capabilities. - Updated existing references in common-server and common files to reflect the new structure. 
--- packages/storage/src/browser.ts | 4 +- packages/storage/src/common-server.ts | 11 +- packages/storage/src/common.ts | 10 +- .../DocumentNodeVectorRepositoryRegistry.ts | 91 ++++ .../DocumentNodeVectorSchema.ts | 35 ++ .../IDocumentNodeVectorRepository.ts | 105 +++++ .../InMemoryDocumentNodeVectorRepository.ts | 189 ++++++++ .../PostgresDocumentNodeVectorRepository.ts | 297 ++++++++++++ .../src/document-node-vector/README.md | 446 ++++++++++++++++++ .../SqliteDocumentNodeVectorRepository.ts | 196 ++++++++ .../storage/src/kv/KvViaTabularRepository.ts | 4 +- .../IRateLimiterStorage.ts | 0 .../InMemoryRateLimiterStorage.ts | 0 .../IndexedDbRateLimiterStorage.ts | 0 .../PostgresRateLimiterStorage.ts | 0 .../SqliteRateLimiterStorage.ts | 0 .../SupabaseRateLimiterStorage.ts | 0 .../src/tabular/BaseSqlTabularRepository.ts | 17 +- ...Repository.ts => BaseTabularRepository.ts} | 15 +- .../src/tabular/CachedTabularRepository.ts | 44 +- .../src/tabular/FsFolderTabularRepository.ts | 19 +- .../src/tabular/InMemoryTabularRepository.ts | 19 +- .../src/tabular/IndexedDbTabularRepository.ts | 19 +- .../src/tabular/PostgresTabularRepository.ts | 150 +++++- .../SharedInMemoryTabularRepository.ts | 45 +- .../src/tabular/SqliteTabularRepository.ts | 24 +- .../src/tabular/SupabaseTabularRepository.ts | 25 +- .../src/storage/TaskGraphTabularRepository.ts | 4 +- .../storage/TaskOutputTabularRepository.ts | 4 +- packages/test/src/samples/ONNXModelSamples.ts | 12 + .../storage-kv/SupabaseKvRepository.test.ts | 25 +- .../IndexedDbTabularRepository.test.ts | 7 +- .../SupabaseTabularRepository.test.ts | 25 +- .../src/test/task-graph/InputResolver.test.ts | 265 +++++++++++ .../TaskGraphFormatSemantic.test.ts | 115 ++++- packages/test/src/test/util/Document.test.ts | 52 ++ 36 files changed, 2081 insertions(+), 193 deletions(-) create mode 100644 packages/storage/src/document-node-vector/DocumentNodeVectorRepositoryRegistry.ts create mode 100644 packages/storage/src/document-node-vector/DocumentNodeVectorSchema.ts create mode 100644 packages/storage/src/document-node-vector/IDocumentNodeVectorRepository.ts create mode 100644 packages/storage/src/document-node-vector/InMemoryDocumentNodeVectorRepository.ts create mode 100644 packages/storage/src/document-node-vector/PostgresDocumentNodeVectorRepository.ts create mode 100644 packages/storage/src/document-node-vector/README.md create mode 100644 packages/storage/src/document-node-vector/SqliteDocumentNodeVectorRepository.ts rename packages/storage/src/{limiter => queue-limiter}/IRateLimiterStorage.ts (100%) rename packages/storage/src/{limiter => queue-limiter}/InMemoryRateLimiterStorage.ts (100%) rename packages/storage/src/{limiter => queue-limiter}/IndexedDbRateLimiterStorage.ts (100%) rename packages/storage/src/{limiter => queue-limiter}/PostgresRateLimiterStorage.ts (100%) rename packages/storage/src/{limiter => queue-limiter}/SqliteRateLimiterStorage.ts (100%) rename packages/storage/src/{limiter => queue-limiter}/SupabaseRateLimiterStorage.ts (100%) rename packages/storage/src/tabular/{TabularRepository.ts => BaseTabularRepository.ts} (97%) create mode 100644 packages/test/src/test/task-graph/InputResolver.test.ts create mode 100644 packages/test/src/test/util/Document.test.ts diff --git a/packages/storage/src/browser.ts b/packages/storage/src/browser.ts index 4a0d756c..f960c5ed 100644 --- a/packages/storage/src/browser.ts +++ b/packages/storage/src/browser.ts @@ -16,7 +16,7 @@ export * from "./kv/SupabaseKvRepository"; export * from 
"./queue/IndexedDbQueueStorage"; export * from "./queue/SupabaseQueueStorage"; -export * from "./limiter/IndexedDbRateLimiterStorage"; -export * from "./limiter/SupabaseRateLimiterStorage"; +export * from "./queue-limiter/IndexedDbRateLimiterStorage"; +export * from "./queue-limiter/SupabaseRateLimiterStorage"; export * from "./util/IndexedDbTable"; diff --git a/packages/storage/src/common-server.ts b/packages/storage/src/common-server.ts index ab1bca3f..61b76cf2 100644 --- a/packages/storage/src/common-server.ts +++ b/packages/storage/src/common-server.ts @@ -21,13 +21,16 @@ export * from "./queue/PostgresQueueStorage"; export * from "./queue/SqliteQueueStorage"; export * from "./queue/SupabaseQueueStorage"; -export * from "./limiter/PostgresRateLimiterStorage"; -export * from "./limiter/SqliteRateLimiterStorage"; -export * from "./limiter/SupabaseRateLimiterStorage"; +export * from "./queue-limiter/PostgresRateLimiterStorage"; +export * from "./queue-limiter/SqliteRateLimiterStorage"; +export * from "./queue-limiter/SupabaseRateLimiterStorage"; + +export * from "./document-node-vector/PostgresDocumentNodeVectorRepository"; +export * from "./document-node-vector/SqliteDocumentNodeVectorRepository"; // testing export * from "./kv/IndexedDbKvRepository"; -export * from "./limiter/IndexedDbRateLimiterStorage"; +export * from "./queue-limiter/IndexedDbRateLimiterStorage"; export * from "./queue/IndexedDbQueueStorage"; export * from "./tabular/IndexedDbTabularRepository"; export * from "./util/IndexedDbTable"; diff --git a/packages/storage/src/common.ts b/packages/storage/src/common.ts index 6ed0fc5e..3049d312 100644 --- a/packages/storage/src/common.ts +++ b/packages/storage/src/common.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +export * from "./tabular/BaseTabularRepository"; export * from "./tabular/CachedTabularRepository"; export * from "./tabular/InMemoryTabularRepository"; export * from "./tabular/ITabularRepository"; @@ -19,8 +20,8 @@ export * from "./kv/KvViaTabularRepository"; export * from "./queue/InMemoryQueueStorage"; export * from "./queue/IQueueStorage"; -export * from "./limiter/InMemoryRateLimiterStorage"; -export * from "./limiter/IRateLimiterStorage"; +export * from "./queue-limiter/InMemoryRateLimiterStorage"; +export * from "./queue-limiter/IRateLimiterStorage"; export * from "./util/HybridSubscriptionManager"; export * from "./util/PollingSubscriptionManager"; @@ -32,3 +33,8 @@ export * from "./document/DocumentRepositoryRegistry"; export * from "./document/DocumentSchema"; export * from "./document/DocumentStorageSchema"; export * from "./document/StructuralParser"; + +export * from "./document-node-vector/DocumentNodeVectorRepositoryRegistry"; +export * from "./document-node-vector/DocumentNodeVectorSchema"; +export * from "./document-node-vector/IDocumentNodeVectorRepository"; +export * from "./document-node-vector/InMemoryDocumentNodeVectorRepository"; diff --git a/packages/storage/src/document-node-vector/DocumentNodeVectorRepositoryRegistry.ts b/packages/storage/src/document-node-vector/DocumentNodeVectorRepositoryRegistry.ts new file mode 100644 index 00000000..e2f69781 --- /dev/null +++ b/packages/storage/src/document-node-vector/DocumentNodeVectorRepositoryRegistry.ts @@ -0,0 +1,91 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + createServiceToken, + globalServiceRegistry, + registerInputResolver, + ServiceRegistry, +} from "@workglow/util"; +import { 
AnyDocumentNodeVectorRepository } from "./IDocumentNodeVectorRepository";
+
+/**
+ * Service token for the document chunk vector repository registry
+ * Maps repository IDs to IDocumentNodeVectorRepository instances
+ */
+export const DOCUMENT_CHUNK_VECTOR_REPOSITORIES = createServiceToken<
+  Map<string, AnyDocumentNodeVectorRepository>
+>("storage.document-node-vector.repositories");
+
+// Register default factory if not already registered
+if (!globalServiceRegistry.has(DOCUMENT_CHUNK_VECTOR_REPOSITORIES)) {
+  globalServiceRegistry.register(
+    DOCUMENT_CHUNK_VECTOR_REPOSITORIES,
+    (): Map<string, AnyDocumentNodeVectorRepository> => new Map(),
+    true
+  );
+}
+
+/**
+ * Gets the global document chunk vector repository registry
+ * @returns Map of document chunk vector repository ID to instance
+ */
+export function getGlobalDocumentNodeVectorRepositories(): Map<
+  string,
+  AnyDocumentNodeVectorRepository
+> {
+  return globalServiceRegistry.get(DOCUMENT_CHUNK_VECTOR_REPOSITORIES);
+}
+
+/**
+ * Registers a vector repository globally by ID
+ * @param id The unique identifier for this repository
+ * @param repository The repository instance to register
+ */
+export function registerDocumentNodeVectorRepository(
+  id: string,
+  repository: AnyDocumentNodeVectorRepository
+): void {
+  const repos = getGlobalDocumentNodeVectorRepositories();
+  repos.set(id, repository);
+}
+
+/**
+ * Gets a document chunk vector repository by ID from the global registry
+ * @param id The repository identifier
+ * @returns The repository instance or undefined if not found
+ */
+export function getDocumentNodeVectorRepository(
+  id: string
+): AnyDocumentNodeVectorRepository | undefined {
+  return getGlobalDocumentNodeVectorRepositories().get(id);
+}
+
+/**
+ * Resolves a repository ID to an IDocumentNodeVectorRepository from the registry.
+ * Used by the input resolver system.
+ */
+async function resolveDocumentNodeVectorRepositoryFromRegistry(
+  id: string,
+  format: string,
+  registry: ServiceRegistry
+): Promise<AnyDocumentNodeVectorRepository> {
+  const repos = registry.has(DOCUMENT_CHUNK_VECTOR_REPOSITORIES)
+    ?
registry.get>(DOCUMENT_CHUNK_VECTOR_REPOSITORIES) + : getGlobalDocumentNodeVectorRepositories(); + + const repo = repos.get(id); + if (!repo) { + throw new Error(`Document chunk vector repository "${id}" not found in registry`); + } + return repo; +} + +// Register the repository resolver for format: "repository:document-node-vector" +registerInputResolver( + "repository:document-node-vector", + resolveDocumentNodeVectorRepositoryFromRegistry +); diff --git a/packages/storage/src/document-node-vector/DocumentNodeVectorSchema.ts b/packages/storage/src/document-node-vector/DocumentNodeVectorSchema.ts new file mode 100644 index 00000000..c3bd30b1 --- /dev/null +++ b/packages/storage/src/document-node-vector/DocumentNodeVectorSchema.ts @@ -0,0 +1,35 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { TypedArraySchema, type DataPortSchemaObject, type TypedArray } from "@workglow/util"; + +/** + * Default schema for document chunk storage with vector embeddings + */ +export const DocumentNodeVectorSchema = { + type: "object", + properties: { + chunk_id: { type: "string" }, + doc_id: { type: "string" }, + vector: TypedArraySchema(), + metadata: { type: "object", additionalProperties: true }, + }, + additionalProperties: false, +} as const satisfies DataPortSchemaObject; +export type DocumentNodeVectorSchema = typeof DocumentNodeVectorSchema; + +export const DocumentNodeVectorKey = ["chunk_id"] as const; +export type DocumentNodeVectorKey = typeof DocumentNodeVectorKey; + +export interface DocumentNodeVector< + Metadata extends Record = Record, + Vector extends TypedArray = Float32Array, +> { + chunk_id: string; + doc_id: string; + vector: Vector; + metadata: Metadata; +} diff --git a/packages/storage/src/document-node-vector/IDocumentNodeVectorRepository.ts b/packages/storage/src/document-node-vector/IDocumentNodeVectorRepository.ts new file mode 100644 index 00000000..548c2695 --- /dev/null +++ b/packages/storage/src/document-node-vector/IDocumentNodeVectorRepository.ts @@ -0,0 +1,105 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { + DataPortSchemaObject, + EventParameters, + FromSchema, + TypedArray, + TypedArraySchemaOptions, +} from "@workglow/util"; +import type { ITabularRepository, TabularEventListeners } from "../tabular/ITabularRepository"; + +export type AnyDocumentNodeVectorRepository = IDocumentNodeVectorRepository; + +/** + * Options for vector search operations + */ +export interface VectorSearchOptions> { + readonly topK?: number; + readonly filter?: Partial; + readonly scoreThreshold?: number; +} + +/** + * Options for hybrid search (vector + full-text) + */ +export interface HybridSearchOptions< + Metadata = Record, +> extends VectorSearchOptions { + readonly textQuery: string; + readonly vectorWeight?: number; +} + +/** + * Type definitions for document chunk vector repository events + */ +export interface VectorChunkEventListeners extends TabularEventListeners< + PrimaryKey, + Entity +> { + similaritySearch: (query: TypedArray, results: (Entity & { score: number })[]) => void; + hybridSearch: (query: TypedArray, results: (Entity & { score: number })[]) => void; +} + +export type VectorChunkEventName = keyof VectorChunkEventListeners; +export type VectorChunkEventListener< + Event extends VectorChunkEventName, + PrimaryKey, + Entity, +> = VectorChunkEventListeners[Event]; + +export type VectorChunkEventParameters< + Event extends VectorChunkEventName, 
+  PrimaryKey,
+  Entity,
+> = EventParameters, Event>;
+
+/**
+ * Interface defining the contract for document chunk vector storage repositories.
+ * These repositories store vector embeddings with metadata for document chunks.
+ * Extends ITabularRepository to provide standard storage operations,
+ * plus vector-specific similarity search capabilities.
+ * Supports various vector types including quantized formats.
+ *
+ * @typeParam Schema - The schema definition for the entity using JSON Schema
+ * @typeParam PrimaryKeyNames - Array of property names that form the primary key
+ * @typeParam Entity - The entity type
+ */
+export interface IDocumentNodeVectorRepository<
+  Schema extends DataPortSchemaObject,
+  PrimaryKeyNames extends ReadonlyArray,
+  Entity = FromSchema,
+> extends ITabularRepository {
+  /**
+   * Get the vector dimension
+   * @returns The vector dimension
+   */
+  getVectorDimensions(): number;
+
+  /**
+   * Search for similar vectors using similarity scoring
+   * @param query - Query vector to compare against
+   * @param options - Search options (topK, filter, scoreThreshold)
+   * @returns Array of search results sorted by similarity (highest first)
+   */
+  similaritySearch(
+    query: TypedArray,
+    options?: VectorSearchOptions>
+  ): Promise<(Entity & { score: number })[]>;
+
+  /**
+   * Hybrid search combining vector similarity with full-text search
+   * This is optional and may not be supported by all implementations
+   * @param query - Query vector to compare against
+   * @param options - Hybrid search options including text query
+   * @returns Array of search results sorted by combined relevance
+   */
+  hybridSearch?(
+    query: TypedArray,
+    options: HybridSearchOptions>
+  ): Promise<(Entity & { score: number })[]>;
+}
diff --git a/packages/storage/src/document-node-vector/InMemoryDocumentNodeVectorRepository.ts b/packages/storage/src/document-node-vector/InMemoryDocumentNodeVectorRepository.ts
new file mode 100644
index 00000000..61fb1d8b
--- /dev/null
+++ b/packages/storage/src/document-node-vector/InMemoryDocumentNodeVectorRepository.ts
@@ -0,0 +1,189 @@
+/**
+ * @license
+ * Copyright 2025 Steven Roussey
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import type { TypedArray } from "@workglow/util";
+import { cosineSimilarity } from "@workglow/util";
+import { InMemoryTabularRepository } from "../tabular/InMemoryTabularRepository";
+import {
+  DocumentNodeVector,
+  DocumentNodeVectorKey,
+  DocumentNodeVectorSchema,
+} from "./DocumentNodeVectorSchema";
+import type {
+  HybridSearchOptions,
+  IDocumentNodeVectorRepository,
+  VectorSearchOptions,
+} from "./IDocumentNodeVectorRepository";
+
+/**
+ * Check if metadata matches filter
+ */
+function matchesFilter(metadata: Metadata, filter: Partial): boolean {
+  for (const [key, value] of Object.entries(filter)) {
+    if (metadata[key as keyof Metadata] !== value) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/**
+ * Simple full-text search scoring (keyword matching)
+ */
+function textRelevance(text: string, query: string): number {
+  const textLower = text.toLowerCase();
+  const queryLower = query.toLowerCase();
+  const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0);
+  if (queryWords.length === 0) {
+    return 0;
+  }
+  let matches = 0;
+  for (const word of queryWords) {
+    if (textLower.includes(word)) {
+      matches++;
+    }
+  }
+  return matches / queryWords.length;
+}
+
+/**
+ * In-memory document chunk vector repository implementation.
+ * Extends InMemoryTabularRepository for storage.
+ * Suitable for testing and small-scale browser applications. + * Supports all vector types including quantized formats. + * + * @template Metadata - The metadata type for the document chunk + * @template Vector - The vector type for the document chunk + */ +export class InMemoryDocumentNodeVectorRepository< + Metadata extends Record = Record, + Vector extends TypedArray = Float32Array, +> + extends InMemoryTabularRepository< + typeof DocumentNodeVectorSchema, + typeof DocumentNodeVectorKey, + DocumentNodeVector + > + implements + IDocumentNodeVectorRepository< + typeof DocumentNodeVectorSchema, + typeof DocumentNodeVectorKey, + DocumentNodeVector + > +{ + private vectorDimensions: number; + private VectorType: new (array: number[]) => TypedArray; + + /** + * Creates a new in-memory document chunk vector repository + * @param dimensions - The number of dimensions of the vector + * @param VectorType - The type of vector to use (defaults to Float32Array) + */ + constructor(dimensions: number, VectorType: new (array: number[]) => TypedArray = Float32Array) { + super(DocumentNodeVectorSchema, DocumentNodeVectorKey); + + this.vectorDimensions = dimensions; + this.VectorType = VectorType; + } + + /** + * Get the vector dimensions + * @returns The vector dimensions + */ + getVectorDimensions(): number { + return this.vectorDimensions; + } + + async similaritySearch( + query: TypedArray, + options: VectorSearchOptions> = {} + ) { + const { topK = 10, filter, scoreThreshold = 0 } = options; + const results: Array & { score: number }> = []; + + const allEntities = (await this.getAll()) || []; + + for (const entity of allEntities) { + const vector = entity.vector; + const metadata = entity.metadata; + + // Apply filter if provided + if (filter && !matchesFilter(metadata, filter)) { + continue; + } + + // Calculate similarity + const score = cosineSimilarity(query, vector); + + // Apply threshold + if (score < scoreThreshold) { + continue; + } + + results.push({ + ...entity, + vector, + score, + }); + } + + // Sort by score descending and take top K + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + return topResults; + } + + async hybridSearch(query: TypedArray, options: HybridSearchOptions>) { + const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; + + if (!textQuery || textQuery.trim().length === 0) { + // Fall back to regular vector search if no text query + return this.similaritySearch(query, { topK, filter, scoreThreshold }); + } + + const results: Array & { score: number }> = []; + const allEntities = (await this.getAll()) || []; + + for (const entity of allEntities) { + // In memory, vectors are stored as TypedArrays directly (not serialized) + const vector = entity.vector; + const metadata = entity.metadata; + + // Apply filter if provided + if (filter && !matchesFilter(metadata, filter)) { + continue; + } + + // Calculate vector similarity + const vectorScore = cosineSimilarity(query, vector); + + // Calculate text relevance (simple keyword matching) + const metadataText = Object.values(metadata).join(" ").toLowerCase(); + const textScore = textRelevance(metadataText, textQuery); + + // Combine scores + const combinedScore = vectorWeight * vectorScore + (1 - vectorWeight) * textScore; + + // Apply threshold + if (combinedScore < scoreThreshold) { + continue; + } + + results.push({ + ...entity, + vector, + score: combinedScore, + }); + } + + // Sort by combined score descending and take top K + results.sort((a, b) 
=> b.score - a.score); + const topResults = results.slice(0, topK); + + return topResults; + } +} diff --git a/packages/storage/src/document-node-vector/PostgresDocumentNodeVectorRepository.ts b/packages/storage/src/document-node-vector/PostgresDocumentNodeVectorRepository.ts new file mode 100644 index 00000000..5d067c7c --- /dev/null +++ b/packages/storage/src/document-node-vector/PostgresDocumentNodeVectorRepository.ts @@ -0,0 +1,297 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { cosineSimilarity, type TypedArray } from "@workglow/util"; +import type { Pool } from "pg"; +import { PostgresTabularRepository } from "../tabular/PostgresTabularRepository"; +import { + DocumentNodeVector, + DocumentNodeVectorKey, + DocumentNodeVectorSchema, +} from "./DocumentNodeVectorSchema"; +import type { + HybridSearchOptions, + IDocumentNodeVectorRepository, + VectorSearchOptions, +} from "./IDocumentNodeVectorRepository"; + +/** + * PostgreSQL document chunk vector repository implementation using pgvector extension. + * Extends PostgresTabularRepository for storage. + * Provides efficient vector similarity search with native database support. + * + * Requirements: + * - PostgreSQL database with pgvector extension installed + * - CREATE EXTENSION vector; + * + * @template Metadata - The metadata type for the document chunk + * @template Vector - The vector type for the document chunk + */ +export class PostgresDocumentNodeVectorRepository< + Metadata extends Record = Record, + Vector extends TypedArray = Float32Array, +> + extends PostgresTabularRepository< + typeof DocumentNodeVectorSchema, + typeof DocumentNodeVectorKey, + DocumentNodeVector + > + implements + IDocumentNodeVectorRepository< + typeof DocumentNodeVectorSchema, + typeof DocumentNodeVectorKey, + DocumentNodeVector + > +{ + private vectorDimensions: number; + private VectorType: new (array: number[]) => TypedArray; + /** + * Creates a new PostgreSQL document chunk vector repository + * @param db - PostgreSQL connection pool + * @param table - The name of the table to use for storage + * @param dimensions - The number of dimensions of the vector + * @param VectorType - The type of vector to use (defaults to Float32Array) + */ + constructor( + db: Pool, + table: string, + dimensions: number, + VectorType: new (array: number[]) => TypedArray = Float32Array + ) { + super(db, table, DocumentNodeVectorSchema, DocumentNodeVectorKey); + + this.vectorDimensions = dimensions; + this.VectorType = VectorType; + } + + getVectorDimensions(): number { + return this.vectorDimensions; + } + + async similaritySearch( + query: TypedArray, + options: VectorSearchOptions = {} + ): Promise & { score: number }>> { + const { topK = 10, filter, scoreThreshold = 0 } = options; + + try { + // Try native pgvector search first + const queryVector = `[${Array.from(query).join(",")}]`; + let sql = ` + SELECT + *, + 1 - (vector <=> $1::vector) as score + FROM "${this.table}" + `; + + const params: any[] = [queryVector]; + let paramIndex = 2; + + if (filter && Object.keys(filter).length > 0) { + const conditions: string[] = []; + for (const [key, value] of Object.entries(filter)) { + conditions.push(`metadata->>'${key}' = $${paramIndex}`); + params.push(String(value)); + paramIndex++; + } + sql += ` WHERE ${conditions.join(" AND ")}`; + } + + if (scoreThreshold > 0) { + sql += filter ? 
" AND" : " WHERE"; + sql += ` (1 - (vector <=> $1::vector)) >= $${paramIndex}`; + params.push(scoreThreshold); + paramIndex++; + } + + sql += ` ORDER BY vector <=> $1::vector LIMIT $${paramIndex}`; + params.push(topK); + + const result = await this.db.query(sql, params); + + // Fetch vectors separately for each result + const results: Array & { score: number }> = []; + for (const row of result.rows) { + const vectorResult = await this.db.query( + `SELECT vector::text FROM "${this.table}" WHERE id = $1`, + [row.id] + ); + const vectorStr = vectorResult.rows[0]?.vector || "[]"; + const vectorArray = JSON.parse(vectorStr); + + results.push({ + ...row, + vector: new this.VectorType(vectorArray), + score: parseFloat(row.score), + } as any); + } + + return results; + } catch (error) { + // Fall back to in-memory similarity calculation if pgvector is not available + console.warn("pgvector query failed, falling back to in-memory search:", error); + return this.searchFallback(query, options); + } + } + + async hybridSearch(query: TypedArray, options: HybridSearchOptions) { + const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; + + if (!textQuery || textQuery.trim().length === 0) { + return this.similaritySearch(query, { topK, filter, scoreThreshold }); + } + + try { + // Try native hybrid search with pgvector + full-text + const queryVector = `[${Array.from(query).join(",")}]`; + const tsQuery = textQuery.split(/\s+/).join(" & "); + + let sql = ` + SELECT + *, + ( + $2 * (1 - (vector <=> $1::vector)) + + $3 * ts_rank(to_tsvector('english', metadata::text), to_tsquery('english', $4)) + ) as score + FROM "${this.table}" + `; + + const params: any[] = [queryVector, vectorWeight, 1 - vectorWeight, tsQuery]; + let paramIndex = 5; + + if (filter && Object.keys(filter).length > 0) { + const conditions: string[] = []; + for (const [key, value] of Object.entries(filter)) { + conditions.push(`metadata->>'${key}' = $${paramIndex}`); + params.push(String(value)); + paramIndex++; + } + sql += ` WHERE ${conditions.join(" AND ")}`; + } + + if (scoreThreshold > 0) { + sql += filter ? 
" AND" : " WHERE"; + sql += ` ( + $2 * (1 - (vector <=> $1::vector)) + + $3 * ts_rank(to_tsvector('english', metadata::text), to_tsquery('english', $4)) + ) >= $${paramIndex}`; + params.push(scoreThreshold); + paramIndex++; + } + + sql += ` ORDER BY score DESC LIMIT $${paramIndex}`; + params.push(topK); + + const result = await this.db.query(sql, params); + + // Fetch vectors separately for each result + const results: Array & { score: number }> = []; + for (const row of result.rows) { + const vectorResult = await this.db.query( + `SELECT vector::text FROM "${this.table}" WHERE id = $1`, + [row.id] + ); + const vectorStr = vectorResult.rows[0]?.vector || "[]"; + const vectorArray = JSON.parse(vectorStr); + + results.push({ + ...row, + vector: new this.VectorType(vectorArray), + score: parseFloat(row.score), + } as any); + } + + return results; + } catch (error) { + // Fall back to in-memory hybrid search + console.warn("pgvector hybrid query failed, falling back to in-memory search:", error); + return this.hybridSearchFallback(query, options); + } + } + + /** + * Fallback search using in-memory cosine similarity + */ + private async searchFallback(query: TypedArray, options: VectorSearchOptions) { + const { topK = 10, filter, scoreThreshold = 0 } = options; + const allRows = (await this.getAll()) || []; + const results: Array & { score: number }> = []; + + for (const row of allRows) { + const vector = row.vector; + const metadata = row.metadata; + + if (filter && !this.matchesFilter(metadata, filter)) { + continue; + } + + const score = cosineSimilarity(query, vector); + + if (score >= scoreThreshold) { + results.push({ ...row, vector, score }); + } + } + + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + return topResults; + } + + /** + * Fallback hybrid search + */ + private async hybridSearchFallback(query: TypedArray, options: HybridSearchOptions) { + const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; + + const allRows = (await this.getAll()) || []; + const results: Array & { score: number }> = []; + const queryLower = textQuery.toLowerCase(); + const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); + + for (const row of allRows) { + const vector = row.vector; + const metadata = row.metadata; + + if (filter && !this.matchesFilter(metadata, filter)) { + continue; + } + + const vectorScore = cosineSimilarity(query, vector); + const metadataText = JSON.stringify(metadata).toLowerCase(); + let textScore = 0; + if (queryWords.length > 0) { + let matches = 0; + for (const word of queryWords) { + if (metadataText.includes(word)) { + matches++; + } + } + textScore = matches / queryWords.length; + } + + const combinedScore = vectorWeight * vectorScore + (1 - vectorWeight) * textScore; + + if (combinedScore >= scoreThreshold) { + results.push({ ...row, vector, score: combinedScore }); + } + } + + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + return topResults; + } + + private matchesFilter(metadata: Metadata, filter: Partial): boolean { + for (const [key, value] of Object.entries(filter)) { + if (metadata[key as keyof Metadata] !== value) { + return false; + } + } + return true; + } +} diff --git a/packages/storage/src/document-node-vector/README.md b/packages/storage/src/document-node-vector/README.md new file mode 100644 index 00000000..3e848f59 --- /dev/null +++ b/packages/storage/src/document-node-vector/README.md @@ -0,0 +1,446 @@ +# Vector Storage 
Module + +A flexible vector storage solution with multiple backend implementations for RAG (Retrieval-Augmented Generation) pipelines. Provides a consistent interface for vector CRUD operations with similarity search and hybrid search capabilities. + +## Features + +- **Multiple Storage Backends:** + - 🧠 `InMemoryVectorRepository` - Fast in-memory storage for testing and small datasets + - 📁 `SqliteVectorRepository` - Persistent SQLite storage for local applications + - 🐘 `PostgresVectorRepository` - PostgreSQL with pgvector extension for production + - 🔍 `SeekDbVectorRepository` - SeekDB/OceanBase with native hybrid search + - 📱 `EdgeVecRepository` - Edge/browser deployment with IndexedDB and WebGPU support + +- **Quantized Vector Support:** + - Float32Array (standard 32-bit floating point) + - Float16Array (16-bit floating point) + - Float64Array (64-bit high precision) + - Int8Array (8-bit signed - binary quantization) + - Uint8Array (8-bit unsigned - quantization) + - Int16Array (16-bit signed - quantization) + - Uint16Array (16-bit unsigned - quantization) + +- **Advanced Search Capabilities:** + - Vector similarity search (cosine similarity) + - Hybrid search (vector + full-text) + - Metadata filtering + - Top-K retrieval with score thresholds + +- **Production Ready:** + - Type-safe interfaces + - Event emitters for monitoring + - Bulk operations support + - Efficient indexing strategies + +## Installation + +```bash +bun install @workglow/storage +``` + +## Usage + +### In-Memory Repository (Testing/Browser) + +```typescript +import { InMemoryVectorRepository } from "@workglow/storage"; + +// Standard Float32 vectors +const repo = new InMemoryVectorRepository<{ text: string; source: string }>(); +await repo.setupDatabase(); + +// Upsert vectors +await repo.upsert( + "doc1", + new Float32Array([0.1, 0.2, 0.3, ...]), + { text: "Hello world", source: "example.txt" } +); + +// Search for similar vectors +const results = await repo.similaritySearch( + new Float32Array([0.15, 0.25, 0.35, ...]), + { topK: 5, scoreThreshold: 0.7 } +); +``` + +### Quantized Vectors (Reduced Storage) + +```typescript +import { InMemoryVectorRepository } from "@workglow/storage"; + +// Use Int8Array for 4x smaller storage (binary quantization) +const repo = new InMemoryVectorRepository< + { text: string }, + Int8Array +>(); +await repo.setupDatabase(); + +// Store quantized vectors +await repo.upsert( + "doc1", + new Int8Array([127, -128, 64, ...]), + { text: "Quantized embedding" } +); + +// Search with quantized query +const results = await repo.similaritySearch( + new Int8Array([100, -50, 75, ...]), + { topK: 5 } +); +``` + +### SQLite Repository (Local Persistence) + +```typescript +import { SqliteVectorRepository } from "@workglow/storage"; + +const repo = new SqliteVectorRepository<{ text: string }>( + "./vectors.db", // database path + "embeddings" // table name +); +await repo.setupDatabase(); + +// Bulk upsert +await repo.upsertBulk([ + { id: "1", vector: new Float32Array([...]), metadata: { text: "..." } }, + { id: "2", vector: new Float32Array([...]), metadata: { text: "..." } }, +]); +``` + +### PostgreSQL with pgvector + +```typescript +import { Pool } from "pg"; +import { PostgresVectorRepository } from "@workglow/storage"; + +const pool = new Pool({ connectionString: "postgresql://..." 
}); +const repo = new PostgresVectorRepository<{ text: string; category: string }>( + pool, + "vectors", + 384 // vector dimension +); +await repo.setupDatabase(); + +// Hybrid search (vector + full-text) +const results = await repo.hybridSearch(queryVector, { + textQuery: "machine learning", + topK: 10, + vectorWeight: 0.7, + filter: { category: "ai" }, +}); +``` + +### SeekDB (Hybrid Search Database) + +```typescript +import mysql from "mysql2/promise"; +import { SeekDbVectorRepository } from "@workglow/storage"; + +const pool = mysql.createPool({ host: "...", database: "..." }); +const repo = new SeekDbVectorRepository<{ text: string }>( + pool, + "vectors", + 768 // vector dimension +); +await repo.setupDatabase(); + +// Native hybrid search +const results = await repo.hybridSearch(queryVector, { + textQuery: "neural networks", + topK: 5, + vectorWeight: 0.6, +}); +``` + +### EdgeVec (Browser/Edge Deployment) + +```typescript +import { EdgeVecRepository } from "@workglow/storage"; + +const repo = new EdgeVecRepository<{ text: string }>({ + dbName: "my-vectors", // IndexedDB name + enableWebGPU: true, // Enable GPU acceleration +}); +await repo.setupDatabase(); + +// Works entirely in the browser +await repo.upsert("1", vector, { text: "..." }); +const results = await repo.similaritySearch(queryVector, { topK: 3 }); +``` + +## API Documentation + +### Core Methods + +All repositories implement the `IVectorRepository` interface: + +```typescript +interface IVectorRepository { + // Setup + setupDatabase(): Promise; + + // CRUD Operations + upsert(id: string, vector: Float32Array, metadata: Metadata): Promise; + upsertBulk(items: VectorEntry[]): Promise; + get(id: string): Promise | undefined>; + delete(id: string): Promise; + deleteBulk(ids: string[]): Promise; + deleteByFilter(filter: Partial): Promise; + + // Search + search( + query: Float32Array, + options?: VectorSearchOptions + ): Promise[]>; + hybridSearch?( + query: Float32Array, + options: HybridSearchOptions + ): Promise[]>; + + // Utility + size(): Promise; + clear(): Promise; + destroy(): void; + + // Events + on(event: "upsert" | "delete" | "search", callback: Function): void; +} +``` + +### Search Options + +```typescript +interface VectorSearchOptions { + topK?: number; // Number of results (default: 10) + filter?: Partial; // Filter by metadata + scoreThreshold?: number; // Minimum score (0-1) +} + +interface HybridSearchOptions extends VectorSearchOptions { + textQuery: string; // Full-text query + vectorWeight?: number; // Vector weight 0-1 (default: 0.7) +} +``` + +## Quantization Benefits + +Quantized vectors can significantly reduce storage and improve performance: + +| Vector Type | Bytes/Dim | Storage vs Float32 | Use Case | +| ------------ | --------- | ------------------ | ------------------------------------ | +| Float32Array | 4 | 100% (baseline) | Standard embeddings | +| Float64Array | 8 | 200% | High precision needed | +| Int16Array | 2 | 50% | Good precision/size tradeoff | +| Int8Array | 1 | 25% | Binary quantization, max compression | +| Uint8Array | 1 | 25% | Quantized embeddings [0-255] | + +**Example:** A 768-dimensional embedding: + +- Float32: 3,072 bytes +- Int8: 768 bytes (75% reduction!) 
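+
+The sketch below shows one way such a reduction can be realized with naive symmetric linear quantization. It is illustrative only; the scaling scheme is an assumption, not a quantizer shipped by this package:
+
+```typescript
+// Map a Float32Array embedding onto the int8 range [-127, 127] using a single
+// per-vector scale factor, then compare storage sizes.
+function quantizeToInt8(vector: Float32Array): Int8Array {
+  let maxAbs = 0;
+  for (const v of vector) maxAbs = Math.max(maxAbs, Math.abs(v));
+  const scale = maxAbs === 0 ? 1 : 127 / maxAbs;
+  const out = new Int8Array(vector.length);
+  for (let i = 0; i < vector.length; i++) {
+    out[i] = Math.round(vector[i] * scale);
+  }
+  return out;
+}
+
+const embedding = new Float32Array(768).map(() => Math.random() * 2 - 1);
+const quantized = quantizeToInt8(embedding);
+console.log(embedding.byteLength); // 3072
+console.log(quantized.byteLength); // 768
+```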
+ +## Performance Considerations + +### InMemory + +- **Best for:** Testing, small datasets (<10K vectors), browser apps +- **Pros:** Fastest, no dependencies, supports all vector types +- **Cons:** No persistence, memory limited + +### SQLite + +- **Best for:** Local apps, medium datasets (<100K vectors) +- **Pros:** Persistent, single file, no server +- **Cons:** No native vector indexing, slower for large datasets + +### PostgreSQL + pgvector + +- **Best for:** Production, large datasets (>100K vectors) +- **Pros:** HNSW indexing, efficient, scalable +- **Cons:** Requires PostgreSQL server and pgvector extension + +### SeekDB + +- **Best for:** Hybrid search workloads, production +- **Pros:** Native hybrid search, MySQL-compatible +- **Cons:** Requires SeekDB/OceanBase instance + +### EdgeVec + +- **Best for:** Privacy-sensitive apps, offline-first, edge computing +- **Pros:** No server, IndexedDB persistence, WebGPU acceleration +- **Cons:** Limited by browser storage, smaller datasets + +## Integration with RAG Tasks + +The vector repositories integrate seamlessly with RAG tasks: + +```typescript +import { InMemoryVectorRepository } from "@workglow/storage"; +import { Workflow } from "@workglow/task-graph"; + +const repo = new InMemoryVectorRepository(); +await repo.setupDatabase(); + +const workflow = new Workflow() + // Load and chunk document + .fileLoader({ path: "./doc.md" }) + .textChunker({ chunkSize: 512, chunkOverlap: 50 }) + + // Generate embeddings + .textEmbedding({ model: "Xenova/all-MiniLM-L6-v2" }) + + // Store in vector repository + .vectorStoreUpsert({ repository: repo }); + +await workflow.run(); + +// Later: Search +const searchWorkflow = new Workflow() + .textEmbedding({ text: "What is RAG?", model: "..." }) + .vectorStoreSearch({ repository: repo, topK: 5 }) + .contextBuilder({ format: "markdown" }) + .textQuestionAnswer({ question: "What is RAG?" 
}); + +const result = await searchWorkflow.run(); +``` + +## Hierarchical Document Integration + +For document-level storage and hierarchical context enrichment, use vector repositories alongside document repositories: + +```typescript +import { InMemoryVectorRepository, InMemoryDocumentRepository } from "@workglow/storage"; +import { Workflow } from "@workglow/task-graph"; + +const vectorRepo = new InMemoryVectorRepository(); +const docRepo = new InMemoryDocumentRepository(); +await vectorRepo.setupDatabase(); + +// Ingestion with hierarchical structure +await new Workflow() + .structuralParser({ + text: markdownContent, + title: "Documentation", + format: "markdown", + }) + .hierarchicalChunker({ + maxTokens: 512, + overlap: 50, + strategy: "hierarchical", + }) + .textEmbedding({ model: "Xenova/all-MiniLM-L6-v2" }) + .chunkToVector() + .vectorStoreUpsert({ repository: vectorRepo }) + .run(); + +// Retrieval with parent context +const result = await new Workflow() + .textEmbedding({ text: query, model: "Xenova/all-MiniLM-L6-v2" }) + .vectorStoreSearch({ repository: vectorRepo, topK: 10 }) + .hierarchyJoin({ + documentRepository: docRepo, + includeParentSummaries: true, + includeEntities: true, + }) + .reranker({ query, topK: 5 }) + .contextBuilder({ format: "markdown" }) + .run(); +``` + +### Vector Metadata for Hierarchical Documents + +When using hierarchical chunking, base vector metadata (stored in vector database) includes: + +```typescript +metadata: { + doc_id: string, // Document identifier + chunkId: string, // Chunk identifier + leafNodeId: string, // Reference to document tree node + depth: number, // Hierarchy depth + text: string, // Chunk text content + nodePath: string[], // Node IDs from root to leaf + // From enrichment (optional): + summary?: string, // Summary of the chunk content + entities?: Entity[], // Named entities extracted from the chunk +} +``` + +After `HierarchyJoinTask`, enriched metadata includes additional fields: + +```typescript +enrichedMetadata: { + // ... all base metadata fields above ... + parentSummaries?: string[], // Summaries from ancestor nodes (looked up on-demand) + sectionTitles?: string[], // Titles of ancestor section nodes +} +``` + +Note: `parentSummaries` is not stored in the vector database. It is computed on-demand by `HierarchyJoinTask` using `doc_id` and `leafNodeId` to look up ancestors from the document repository. + +## Document Repository + +The `IDocumentRepository` interface provides storage for hierarchical document structures: + +```typescript +class DocumentRepository { + constructor(tabularStorage: ITabularRepository, vectorStorage: IVectorRepository); + + upsert(document: Document): Promise; + get(doc_id: string): Promise; + getNode(doc_id: string, nodeId: string): Promise; + getAncestors(doc_id: string, nodeId: string): Promise; + getChunks(doc_id: string): Promise; + findChunksByNodeId(doc_id: string, nodeId: string): Promise; + delete(doc_id: string): Promise; + list(): Promise; + search(query: TypedArray, options?: VectorSearchOptions): Promise; +} +``` + +### Document Repository + +The `DocumentRepository` class provides a unified interface for storing hierarchical documents and searching chunks. 
It uses composition of storage backends: + +| Component | Purpose | +| -------------------- | -------------------------------------------- | +| `ITabularRepository` | Stores document structure and metadata | +| `IVectorRepository` | Enables similarity search on document chunks | + +**Example Usage:** + +```typescript +import { + DocumentRepository, + InMemoryTabularRepository, + InMemoryVectorRepository, +} from "@workglow/storage"; + +// Define schema for document storage +const DocumentStorageSchema = { + type: "object", + properties: { + doc_id: { type: "string" }, + data: { type: "string" }, + }, + required: ["doc_id", "data"], +} as const; + +// Initialize storage backends +const tabularStorage = new InMemoryTabularRepository(DocumentStorageSchema, ["doc_id"]); +await tabularStorage.setupDatabase(); + +const vectorStorage = new InMemoryVectorRepository(); +await vectorStorage.setupDatabase(); + +// Create document repository +const docRepo = new DocumentRepository(tabularStorage, vectorStorage); + +// Use the repository +await docRepo.upsert(document); +const results = await docRepo.search(queryVector, { topK: 5 }); +``` + +## License + +Apache 2.0 diff --git a/packages/storage/src/document-node-vector/SqliteDocumentNodeVectorRepository.ts b/packages/storage/src/document-node-vector/SqliteDocumentNodeVectorRepository.ts new file mode 100644 index 00000000..0dd02653 --- /dev/null +++ b/packages/storage/src/document-node-vector/SqliteDocumentNodeVectorRepository.ts @@ -0,0 +1,196 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { Sqlite } from "@workglow/sqlite"; +import type { TypedArray } from "@workglow/util"; +import { cosineSimilarity } from "@workglow/util"; +import { SqliteTabularRepository } from "../tabular/SqliteTabularRepository"; +import { + DocumentNodeVector, + DocumentNodeVectorKey, + DocumentNodeVectorSchema, +} from "./DocumentNodeVectorSchema"; +import type { + HybridSearchOptions, + IDocumentNodeVectorRepository, + VectorSearchOptions, +} from "./IDocumentNodeVectorRepository"; + +/** + * Check if metadata matches filter + */ +function matchesFilter(metadata: Metadata, filter: Partial): boolean { + for (const [key, value] of Object.entries(filter)) { + if (metadata[key as keyof Metadata] !== value) { + return false; + } + } + return true; +} + +/** + * SQLite document chunk vector repository implementation using tabular storage underneath. + * Stores vectors as JSON-encoded arrays with metadata. 
+ * + * @template Metadata - The metadata type for the document chunk + * @template Vector - The vector type for the document chunk + */ +export class SqliteDocumentNodeVectorRepository< + Metadata extends Record = Record, + Vector extends TypedArray = Float32Array, +> + extends SqliteTabularRepository< + typeof DocumentNodeVectorSchema, + typeof DocumentNodeVectorKey, + DocumentNodeVector + > + implements + IDocumentNodeVectorRepository< + typeof DocumentNodeVectorSchema, + typeof DocumentNodeVectorKey, + DocumentNodeVector + > +{ + private vectorDimensions: number; + private VectorType: new (array: number[]) => TypedArray; + + /** + * Creates a new SQLite document chunk vector repository + * @param dbOrPath - Either a Database instance or a path to the SQLite database file + * @param table - The name of the table to use for storage (defaults to 'vectors') + * @param dimensions - The number of dimensions of the vector + * @param VectorType - The type of vector to use (defaults to Float32Array) + */ + constructor( + dbOrPath: string | Sqlite.Database, + table: string = "vectors", + dimensions: number, + VectorType: new (array: number[]) => TypedArray = Float32Array + ) { + super(dbOrPath, table, DocumentNodeVectorSchema, DocumentNodeVectorKey); + + this.vectorDimensions = dimensions; + this.VectorType = VectorType; + } + + getVectorDimensions(): number { + return this.vectorDimensions; + } + + /** + * Deserialize vector from JSON string + * Defaults to Float32Array for compatibility with typical embedding vectors + */ + private deserializeVector(vectorJson: string): TypedArray { + const array = JSON.parse(vectorJson); + // Default to Float32Array for typical use case (embeddings) + return new this.VectorType(array); + } + + async similaritySearch(query: TypedArray, options: VectorSearchOptions = {}) { + const { topK = 10, filter, scoreThreshold = 0 } = options; + const results: Array & { score: number }> = []; + + const allEntities = (await this.getAll()) || []; + + for (const entity of allEntities) { + // SQLite stores vectors as JSON strings, need to deserialize + const vectorRaw = entity.vector as unknown as string; + const vector = this.deserializeVector(vectorRaw); + const metadata = entity.metadata; + + // Apply filter if provided + if (filter && !matchesFilter(metadata, filter)) { + continue; + } + + // Calculate similarity + const score = cosineSimilarity(query, vector); + + // Apply threshold + if (score < scoreThreshold) { + continue; + } + + results.push({ + ...entity, + vector, + score, + } as any); + } + + // Sort by score descending and take top K + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + return topResults; + } + + async hybridSearch(query: TypedArray, options: HybridSearchOptions) { + const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; + + if (!textQuery || textQuery.trim().length === 0) { + // Fall back to regular vector search if no text query + return this.similaritySearch(query, { topK, filter, scoreThreshold }); + } + + const results: Array & { score: number }> = []; + const allEntities = (await this.getAll()) || []; + const queryLower = textQuery.toLowerCase(); + const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); + + for (const entity of allEntities) { + // SQLite stores vectors as JSON strings, need to deserialize + const vectorRaw = entity.vector as unknown as string; + const vector = + typeof vectorRaw === "string" + ? 
this.deserializeVector(vectorRaw) + : (vectorRaw as TypedArray); + const metadata = entity.metadata; + + // Apply filter if provided + if (filter && !matchesFilter(metadata, filter)) { + continue; + } + + // Calculate vector similarity + const vectorScore = cosineSimilarity(query, vector); + + // Calculate text relevance (simple keyword matching) + const metadataText = JSON.stringify(metadata).toLowerCase(); + let textScore = 0; + if (queryWords.length > 0) { + let matches = 0; + for (const word of queryWords) { + if (metadataText.includes(word)) { + matches++; + } + } + textScore = matches / queryWords.length; + } + + // Combine scores + const combinedScore = vectorWeight * vectorScore + (1 - vectorWeight) * textScore; + + // Apply threshold + if (combinedScore < scoreThreshold) { + continue; + } + + results.push({ + ...entity, + vector, + score: combinedScore, + } as any); + } + + // Sort by combined score descending and take top K + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + return topResults; + } +} diff --git a/packages/storage/src/kv/KvViaTabularRepository.ts b/packages/storage/src/kv/KvViaTabularRepository.ts index b0a780ac..bc854394 100644 --- a/packages/storage/src/kv/KvViaTabularRepository.ts +++ b/packages/storage/src/kv/KvViaTabularRepository.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { TabularRepository } from "../tabular/TabularRepository"; +import type { BaseTabularRepository } from "../tabular/BaseTabularRepository"; import { DefaultKeyValueKey, DefaultKeyValueSchema } from "./IKvRepository"; import { KvRepository } from "./KvRepository"; @@ -21,7 +21,7 @@ export abstract class KvViaTabularRepository< Value extends any = any, Combined = { key: Key; value: Value }, > extends KvRepository { - public abstract tabularRepository: TabularRepository< + public abstract tabularRepository: BaseTabularRepository< typeof DefaultKeyValueSchema, typeof DefaultKeyValueKey >; diff --git a/packages/storage/src/limiter/IRateLimiterStorage.ts b/packages/storage/src/queue-limiter/IRateLimiterStorage.ts similarity index 100% rename from packages/storage/src/limiter/IRateLimiterStorage.ts rename to packages/storage/src/queue-limiter/IRateLimiterStorage.ts diff --git a/packages/storage/src/limiter/InMemoryRateLimiterStorage.ts b/packages/storage/src/queue-limiter/InMemoryRateLimiterStorage.ts similarity index 100% rename from packages/storage/src/limiter/InMemoryRateLimiterStorage.ts rename to packages/storage/src/queue-limiter/InMemoryRateLimiterStorage.ts diff --git a/packages/storage/src/limiter/IndexedDbRateLimiterStorage.ts b/packages/storage/src/queue-limiter/IndexedDbRateLimiterStorage.ts similarity index 100% rename from packages/storage/src/limiter/IndexedDbRateLimiterStorage.ts rename to packages/storage/src/queue-limiter/IndexedDbRateLimiterStorage.ts diff --git a/packages/storage/src/limiter/PostgresRateLimiterStorage.ts b/packages/storage/src/queue-limiter/PostgresRateLimiterStorage.ts similarity index 100% rename from packages/storage/src/limiter/PostgresRateLimiterStorage.ts rename to packages/storage/src/queue-limiter/PostgresRateLimiterStorage.ts diff --git a/packages/storage/src/limiter/SqliteRateLimiterStorage.ts b/packages/storage/src/queue-limiter/SqliteRateLimiterStorage.ts similarity index 100% rename from packages/storage/src/limiter/SqliteRateLimiterStorage.ts rename to packages/storage/src/queue-limiter/SqliteRateLimiterStorage.ts diff --git 
a/packages/storage/src/limiter/SupabaseRateLimiterStorage.ts b/packages/storage/src/queue-limiter/SupabaseRateLimiterStorage.ts similarity index 100% rename from packages/storage/src/limiter/SupabaseRateLimiterStorage.ts rename to packages/storage/src/queue-limiter/SupabaseRateLimiterStorage.ts diff --git a/packages/storage/src/tabular/BaseSqlTabularRepository.ts b/packages/storage/src/tabular/BaseSqlTabularRepository.ts index dc965406..4d2ccd8c 100644 --- a/packages/storage/src/tabular/BaseSqlTabularRepository.ts +++ b/packages/storage/src/tabular/BaseSqlTabularRepository.ts @@ -4,9 +4,14 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { DataPortSchemaObject, FromSchema, JsonSchema } from "@workglow/util"; -import { ValueOptionType } from "./ITabularRepository"; -import { TabularRepository } from "./TabularRepository"; +import { + DataPortSchemaObject, + FromSchema, + JsonSchema, + TypedArraySchemaOptions, +} from "@workglow/util"; +import { BaseTabularRepository } from "./BaseTabularRepository"; +import { SimplifyPrimaryKey, ValueOptionType } from "./ITabularRepository"; // BaseTabularRepository is a tabular store that uses SQLite and Postgres use as common code @@ -21,10 +26,10 @@ export abstract class BaseSqlTabularRepository< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types - Entity = FromSchema, - PrimaryKey = Pick, + Entity = FromSchema, + PrimaryKey = SimplifyPrimaryKey, Value = Omit, -> extends TabularRepository { +> extends BaseTabularRepository { /** * Creates a new instance of BaseSqlTabularRepository * @param table - The name of the database table to use for storage diff --git a/packages/storage/src/tabular/TabularRepository.ts b/packages/storage/src/tabular/BaseTabularRepository.ts similarity index 97% rename from packages/storage/src/tabular/TabularRepository.ts rename to packages/storage/src/tabular/BaseTabularRepository.ts index 94384736..aea5dfb9 100644 --- a/packages/storage/src/tabular/TabularRepository.ts +++ b/packages/storage/src/tabular/BaseTabularRepository.ts @@ -10,10 +10,13 @@ import { EventEmitter, FromSchema, makeFingerprint, + TypedArraySchemaOptions, } from "@workglow/util"; import { + AnyTabularRepository, DeleteSearchCriteria, ITabularRepository, + SimplifyPrimaryKey, TabularChangePayload, TabularEventListener, TabularEventListeners, @@ -23,7 +26,7 @@ import { ValueOptionType, } from "./ITabularRepository"; -export const TABULAR_REPOSITORY = createServiceToken>( +export const TABULAR_REPOSITORY = createServiceToken( "storage.tabularRepository" ); @@ -36,14 +39,14 @@ export const TABULAR_REPOSITORY = createServiceToken, // computed types - Entity = FromSchema, - PrimaryKey = Pick, + Entity = FromSchema, + PrimaryKey = SimplifyPrimaryKey, Value = Omit, -> implements ITabularRepository { +> implements ITabularRepository { /** Event emitter for repository events */ protected events = new EventEmitter>(); @@ -52,7 +55,7 @@ export abstract class TabularRepository< protected valueSchema: DataPortSchemaObject; /** - * Creates a new TabularRepository instance + * Creates a new BaseTabularRepository instance * @param schema - Schema defining the structure of the entity * @param primaryKeyNames - Array of property names that form the primary key * @param indexes - Array of columns or column arrays to make searchable. 
Each string or single column creates a single-column index, diff --git a/packages/storage/src/tabular/CachedTabularRepository.ts b/packages/storage/src/tabular/CachedTabularRepository.ts index 9ed259eb..db0d8e23 100644 --- a/packages/storage/src/tabular/CachedTabularRepository.ts +++ b/packages/storage/src/tabular/CachedTabularRepository.ts @@ -4,18 +4,25 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { createServiceToken, DataPortSchemaObject, FromSchema } from "@workglow/util"; +import { + createServiceToken, + DataPortSchemaObject, + FromSchema, + TypedArraySchemaOptions, +} from "@workglow/util"; +import { BaseTabularRepository } from "./BaseTabularRepository"; import { InMemoryTabularRepository } from "./InMemoryTabularRepository"; import { + AnyTabularRepository, DeleteSearchCriteria, ITabularRepository, + SimplifyPrimaryKey, TabularSubscribeOptions, } from "./ITabularRepository"; -import { TabularRepository } from "./TabularRepository"; -export const CACHED_TABULAR_REPOSITORY = createServiceToken< - ITabularRepository ->("storage.tabularRepository.cached"); +export const CACHED_TABULAR_REPOSITORY = createServiceToken( + "storage.tabularRepository.cached" +); /** * A tabular repository wrapper that adds caching layer to a durable repository. @@ -29,12 +36,11 @@ export class CachedTabularRepository< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types - Entity = FromSchema, - PrimaryKey = Pick, - Value = Omit, -> extends TabularRepository { - public readonly cache: ITabularRepository; - private durable: ITabularRepository; + Entity = FromSchema, + PrimaryKey = SimplifyPrimaryKey, +> extends BaseTabularRepository { + public readonly cache: ITabularRepository; + private durable: ITabularRepository; private cacheInitialized = false; /** @@ -48,8 +54,8 @@ export class CachedTabularRepository< * while each array creates a compound index with columns in the specified order. 
*/ constructor( - durable: ITabularRepository, - cache?: ITabularRepository, + durable: ITabularRepository, + cache?: ITabularRepository, schema?: Schema, primaryKeyNames?: PrimaryKeyNames, indexes?: readonly (keyof Entity | readonly (keyof Entity)[])[] @@ -70,13 +76,11 @@ export class CachedTabularRepository< if (cache) { this.cache = cache; } else { - this.cache = new InMemoryTabularRepository< - Schema, - PrimaryKeyNames, - Entity, - PrimaryKey, - Value - >(schema, primaryKeyNames, indexes || []); + this.cache = new InMemoryTabularRepository( + schema, + primaryKeyNames, + indexes || [] + ); } // Forward events from both cache and durable diff --git a/packages/storage/src/tabular/FsFolderTabularRepository.ts b/packages/storage/src/tabular/FsFolderTabularRepository.ts index 6f919ca5..9a4fe9b3 100644 --- a/packages/storage/src/tabular/FsFolderTabularRepository.ts +++ b/packages/storage/src/tabular/FsFolderTabularRepository.ts @@ -10,21 +10,23 @@ import { FromSchema, makeFingerprint, sleep, + TypedArraySchemaOptions, } from "@workglow/util"; import { mkdir, readdir, readFile, rm, writeFile } from "node:fs/promises"; import path from "node:path"; import { PollingSubscriptionManager } from "../util/PollingSubscriptionManager"; +import { BaseTabularRepository } from "./BaseTabularRepository"; import { + AnyTabularRepository, DeleteSearchCriteria, - ITabularRepository, + SimplifyPrimaryKey, TabularChangePayload, TabularSubscribeOptions, } from "./ITabularRepository"; -import { TabularRepository } from "./TabularRepository"; -export const FS_FOLDER_TABULAR_REPOSITORY = createServiceToken< - ITabularRepository ->("storage.tabularRepository.fsFolder"); +export const FS_FOLDER_TABULAR_REPOSITORY = createServiceToken( + "storage.tabularRepository.fsFolder" +); /** * A tabular repository implementation that uses the filesystem for storage. @@ -37,10 +39,9 @@ export class FsFolderTabularRepository< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types - Entity = FromSchema, - PrimaryKey = Pick, - Value = Omit, -> extends TabularRepository { + Entity = FromSchema, + PrimaryKey = SimplifyPrimaryKey, +> extends BaseTabularRepository { private folderPath: string; /** Shared polling subscription manager */ private pollingManager: PollingSubscriptionManager< diff --git a/packages/storage/src/tabular/InMemoryTabularRepository.ts b/packages/storage/src/tabular/InMemoryTabularRepository.ts index b27b6b34..dab87deb 100644 --- a/packages/storage/src/tabular/InMemoryTabularRepository.ts +++ b/packages/storage/src/tabular/InMemoryTabularRepository.ts @@ -9,19 +9,21 @@ import { DataPortSchemaObject, FromSchema, makeFingerprint, + TypedArraySchemaOptions, } from "@workglow/util"; +import { BaseTabularRepository } from "./BaseTabularRepository"; import { + AnyTabularRepository, DeleteSearchCriteria, isSearchCondition, - ITabularRepository, + SimplifyPrimaryKey, TabularChangePayload, TabularSubscribeOptions, } from "./ITabularRepository"; -import { TabularRepository } from "./TabularRepository"; -export const MEMORY_TABULAR_REPOSITORY = createServiceToken< - ITabularRepository ->("storage.tabularRepository.inMemory"); +export const MEMORY_TABULAR_REPOSITORY = createServiceToken( + "storage.tabularRepository.inMemory" +); /** * A generic in-memory key-value repository implementation. 
@@ -34,10 +36,9 @@ export class InMemoryTabularRepository< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types - Entity = FromSchema, - PrimaryKey = Pick, - Value = Omit, -> extends TabularRepository { + Entity = FromSchema, + PrimaryKey = SimplifyPrimaryKey, +> extends BaseTabularRepository { /** Internal storage using a Map with fingerprint strings as keys */ values = new Map(); diff --git a/packages/storage/src/tabular/IndexedDbTabularRepository.ts b/packages/storage/src/tabular/IndexedDbTabularRepository.ts index 6c3603d4..50bc483a 100644 --- a/packages/storage/src/tabular/IndexedDbTabularRepository.ts +++ b/packages/storage/src/tabular/IndexedDbTabularRepository.ts @@ -9,6 +9,7 @@ import { DataPortSchemaObject, FromSchema, makeFingerprint, + TypedArraySchemaOptions, } from "@workglow/util"; import { HybridSubscriptionManager } from "../util/HybridSubscriptionManager"; import { @@ -16,19 +17,20 @@ import { ExpectedIndexDefinition, MigrationOptions, } from "../util/IndexedDbTable"; +import { BaseTabularRepository } from "./BaseTabularRepository"; import { + AnyTabularRepository, DeleteSearchCriteria, isSearchCondition, - ITabularRepository, SearchOperator, + SimplifyPrimaryKey, TabularChangePayload, TabularSubscribeOptions, } from "./ITabularRepository"; -import { TabularRepository } from "./TabularRepository"; -export const IDB_TABULAR_REPOSITORY = createServiceToken< - ITabularRepository ->("storage.tabularRepository.indexedDb"); +export const IDB_TABULAR_REPOSITORY = createServiceToken( + "storage.tabularRepository.indexedDb" +); /** * A tabular repository implementation using IndexedDB for browser-based storage. @@ -40,10 +42,9 @@ export class IndexedDbTabularRepository< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types - Entity = FromSchema, - PrimaryKey = Pick, - Value = Omit, -> extends TabularRepository { + Entity = FromSchema, + PrimaryKey = SimplifyPrimaryKey, +> extends BaseTabularRepository { /** Promise that resolves to the IndexedDB database instance */ private db: IDBDatabase | undefined; /** Promise to track ongoing database setup to prevent concurrent setup calls */ diff --git a/packages/storage/src/tabular/PostgresTabularRepository.ts b/packages/storage/src/tabular/PostgresTabularRepository.ts index 6534ba60..1054229e 100644 --- a/packages/storage/src/tabular/PostgresTabularRepository.ts +++ b/packages/storage/src/tabular/PostgresTabularRepository.ts @@ -4,22 +4,30 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { createServiceToken, DataPortSchemaObject, FromSchema, JsonSchema } from "@workglow/util"; +import { + createServiceToken, + DataPortSchemaObject, + FromSchema, + JsonSchema, + type TypedArray, + TypedArraySchemaOptions, +} from "@workglow/util"; import type { Pool } from "pg"; import { BaseSqlTabularRepository } from "./BaseSqlTabularRepository"; import { + AnyTabularRepository, DeleteSearchCriteria, isSearchCondition, - ITabularRepository, SearchOperator, + SimplifyPrimaryKey, TabularChangePayload, TabularSubscribeOptions, ValueOptionType, } from "./ITabularRepository"; -export const POSTGRES_TABULAR_REPOSITORY = createServiceToken< - ITabularRepository ->("storage.tabularRepository.postgres"); +export const POSTGRES_TABULAR_REPOSITORY = createServiceToken( + "storage.tabularRepository.postgres" +); /** * A PostgreSQL-based tabular repository implementation that extends BaseSqlTabularRepository. 
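+ * Columns whose schema format marks them as TypedArray vectors are created as pgvector
+ * `vector(N)` columns with HNSW indexes when their dimensions are known (see
+ * getVectorDimensions()); otherwise they fall back to the default text mapping.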
@@ -33,11 +41,10 @@ export class PostgresTabularRepository< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types - Entity = FromSchema, - PrimaryKey = Pick, - Value = Omit, -> extends BaseSqlTabularRepository { - private db: Pool; + Entity = FromSchema, + PrimaryKey = SimplifyPrimaryKey, +> extends BaseSqlTabularRepository { + protected db: Pool; /** * Creates a new PostgresTabularRepository instance. @@ -74,6 +81,9 @@ export class PostgresTabularRepository< `; await this.db.query(sql); + // Create vector indexes if there are vector columns + await this.createVectorIndexes(); + // Get primary key columns to avoid creating redundant indexes const pkColumns = this.primaryKeyColumns(); @@ -114,6 +124,15 @@ export class PostgresTabularRepository< } } + protected isVectorFormat(format?: string): boolean { + if (!format) return false; + return format.startsWith("TypedArray:") || format === "TypedArray"; + } + + protected getVectorDimensions(typeDef: JsonSchema): number | undefined { + return undefined; + } + /** * Maps TypeScript/JavaScript types to corresponding PostgreSQL data types. * Uses additional schema information like minimum/maximum values, nullable status, @@ -141,6 +160,14 @@ export class PostgresTabularRepository< if (actualType.format === "uri") return "VARCHAR(2048)"; if (actualType.format === "uuid") return "UUID"; + // Handle vector format (pgvector extension) + if (this.isVectorFormat(actualType.format)) { + const dimension = this.getVectorDimensions(actualType); + if (typeof dimension === "number") { + return `vector(${dimension})`; + } + } + // Use a VARCHAR with maxLength if specified if (typeof actualType.maxLength === "number") { return `VARCHAR(${actualType.maxLength})`; @@ -283,10 +310,34 @@ export class PostgresTabularRepository< } } + /** + * Convert JavaScript values to PostgreSQL values, including TypedArray to vector string + */ + protected override jsToSqlValue(column: string, value: Entity[keyof Entity]): ValueOptionType { + const typeDef = this.schema.properties[column]; + if (typeDef) { + const actualType = this.getNonNullType(typeDef); + + // Handle vector format - convert TypedArray to pgvector string format [1.0, 2.0, ...] + if (typeof actualType !== "boolean" && this.isVectorFormat(actualType.format)) { + if (value && ArrayBuffer.isView(value) && !(value instanceof DataView)) { + // It's a TypedArray + const array = Array.from(value as unknown as TypedArray); + return `[${array.join(",")}]` as any; + } + // If it's already a string (serialized), return as-is + if (typeof value === "string") { + return value; + } + } + } + return super.jsToSqlValue(column, value); + } + /** * Convert PostgreSQL values to JS values. Ensures numeric strings become numbers where schema says number. */ - protected sqlToJsValue(column: string, value: ValueOptionType): Entity[keyof Entity] { + protected override sqlToJsValue(column: string, value: ValueOptionType): Entity[keyof Entity] { const typeDef = this.schema.properties[column as keyof typeof this.schema.properties] as | JsonSchema | undefined; @@ -296,6 +347,23 @@ export class PostgresTabularRepository< } const actualType = this.getNonNullType(typeDef); + // Handle vector format - convert pgvector string to TypedArray + if (typeof actualType !== "boolean" && this.isVectorFormat(actualType.format)) { + if (typeof value === "string") { + try { + // Parse the vector string format [1.0, 2.0, ...] 
to TypedArray + const array = JSON.parse(value); + return new Float32Array(array) as any; + } catch (e) { + console.warn(`Failed to parse vector for column ${column}:`, e); + } + } + // If it's already an object/TypedArray, return as-is + if (value && typeof value === "object") { + return value as any; + } + } + // Handle numeric types - PostgreSQL can return them as strings if ( typeof actualType !== "boolean" && @@ -336,6 +404,66 @@ export class PostgresTabularRepository< return false; } + /** + * Gets information about vector columns in the schema + * @returns Array of objects with column name and dimension + */ + protected getVectorColumns(): Array<{ column: string; dimension: number }> { + const vectorColumns: Array<{ column: string; dimension: number }> = []; + + // Check all properties in the schema + for (const [key, typeDef] of Object.entries(this.schema.properties)) { + const actualType = this.getNonNullType(typeDef); + if (typeof actualType !== "boolean" && this.isVectorFormat(actualType.format)) { + const dimension = this.getVectorDimensions(actualType); + if (typeof dimension === "number") { + vectorColumns.push({ column: key, dimension }); + } else { + console.warn(`Invalid vector format for column ${key}: ${actualType.format}, skipping`); + } + } + } + + return vectorColumns; + } + + /** + * Creates vector-specific indexes (HNSW for pgvector) + * Called after table creation if vector columns exist + */ + protected async createVectorIndexes(): Promise { + const vectorColumns = this.getVectorColumns(); + + if (vectorColumns.length === 0) { + return; // No vector columns, nothing to do + } + + // Try to enable pgvector extension + try { + await this.db.query("CREATE EXTENSION IF NOT EXISTS vector"); + } catch (error) { + console.warn( + "pgvector extension not available, vector columns will use TEXT fallback:", + error + ); + return; + } + + // Create HNSW index for each vector column + for (const { column } of vectorColumns) { + const indexName = `${this.table}_${column}_hnsw_idx`; + try { + await this.db.query(` + CREATE INDEX IF NOT EXISTS "${indexName}" + ON "${this.table}" + USING hnsw ("${column}" vector_cosine_ops) + `); + } catch (error) { + console.warn(`Failed to create HNSW index on ${column}:`, error); + } + } + } + /** * Stores or updates a row in the database. * Uses UPSERT (INSERT ... ON CONFLICT DO UPDATE) for atomic operations. 
diff --git a/packages/storage/src/tabular/SharedInMemoryTabularRepository.ts b/packages/storage/src/tabular/SharedInMemoryTabularRepository.ts index e8579617..41058df6 100644 --- a/packages/storage/src/tabular/SharedInMemoryTabularRepository.ts +++ b/packages/storage/src/tabular/SharedInMemoryTabularRepository.ts @@ -4,18 +4,24 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { createServiceToken, DataPortSchemaObject, FromSchema } from "@workglow/util"; import { + createServiceToken, + DataPortSchemaObject, + FromSchema, + TypedArraySchemaOptions, +} from "@workglow/util"; +import { BaseTabularRepository } from "./BaseTabularRepository"; +import { + AnyTabularRepository, DeleteSearchCriteria, - ITabularRepository, + SimplifyPrimaryKey, TabularSubscribeOptions, } from "./ITabularRepository"; import { InMemoryTabularRepository } from "./InMemoryTabularRepository"; -import { TabularRepository } from "./TabularRepository"; -export const SHARED_IN_MEMORY_TABULAR_REPOSITORY = createServiceToken< - ITabularRepository ->("storage.tabularRepository.sharedInMemory"); +export const SHARED_IN_MEMORY_TABULAR_REPOSITORY = createServiceToken( + "storage.tabularRepository.sharedInMemory" +); /** * Message types for BroadcastChannel communication @@ -41,19 +47,12 @@ export class SharedInMemoryTabularRepository< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types - Entity = FromSchema, - PrimaryKey = Pick, - Value = Omit, -> extends TabularRepository { + Entity = FromSchema, + PrimaryKey = SimplifyPrimaryKey, +> extends BaseTabularRepository { private channel: BroadcastChannel | null = null; private channelName: string; - private inMemoryRepo: InMemoryTabularRepository< - Schema, - PrimaryKeyNames, - Entity, - PrimaryKey, - Value - >; + private inMemoryRepo: InMemoryTabularRepository; private isInitialized = false; private syncInProgress = false; @@ -73,13 +72,11 @@ export class SharedInMemoryTabularRepository< ) { super(schema, primaryKeyNames, indexes); this.channelName = channelName; - this.inMemoryRepo = new InMemoryTabularRepository< - Schema, - PrimaryKeyNames, - Entity, - PrimaryKey, - Value - >(schema, primaryKeyNames, indexes); + this.inMemoryRepo = new InMemoryTabularRepository( + schema, + primaryKeyNames, + indexes + ); // Forward events from the in-memory repository this.setupEventForwarding(); diff --git a/packages/storage/src/tabular/SqliteTabularRepository.ts b/packages/storage/src/tabular/SqliteTabularRepository.ts index 8f1992aa..331c71e8 100644 --- a/packages/storage/src/tabular/SqliteTabularRepository.ts +++ b/packages/storage/src/tabular/SqliteTabularRepository.ts @@ -5,13 +5,20 @@ */ import { Sqlite } from "@workglow/sqlite"; -import { createServiceToken, DataPortSchemaObject, FromSchema, JsonSchema } from "@workglow/util"; +import { + createServiceToken, + DataPortSchemaObject, + FromSchema, + JsonSchema, + TypedArraySchemaOptions, +} from "@workglow/util"; import { BaseSqlTabularRepository } from "./BaseSqlTabularRepository"; import { + AnyTabularRepository, DeleteSearchCriteria, isSearchCondition, - ITabularRepository, SearchOperator, + SimplifyPrimaryKey, TabularChangePayload, TabularSubscribeOptions, ValueOptionType, @@ -20,9 +27,9 @@ import { // Define local type for SQL operations type ExcludeDateKeyOptionType = Exclude; -export const SQLITE_TABULAR_REPOSITORY = createServiceToken< - ITabularRepository ->("storage.tabularRepository.sqlite"); +export const SQLITE_TABULAR_REPOSITORY = createServiceToken( + 
"storage.tabularRepository.sqlite" +); const Database = Sqlite.Database; @@ -38,10 +45,9 @@ export class SqliteTabularRepository< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types - Entity = FromSchema, - PrimaryKey = Pick, - Value = Omit, -> extends BaseSqlTabularRepository { + Entity = FromSchema, + PrimaryKey = SimplifyPrimaryKey, +> extends BaseSqlTabularRepository { /** The SQLite database instance */ private db: Sqlite.Database; diff --git a/packages/storage/src/tabular/SupabaseTabularRepository.ts b/packages/storage/src/tabular/SupabaseTabularRepository.ts index df17ffd4..f4d433dd 100644 --- a/packages/storage/src/tabular/SupabaseTabularRepository.ts +++ b/packages/storage/src/tabular/SupabaseTabularRepository.ts @@ -5,22 +5,29 @@ */ import type { RealtimeChannel, SupabaseClient } from "@supabase/supabase-js"; -import { createServiceToken, DataPortSchemaObject, FromSchema, JsonSchema } from "@workglow/util"; +import { + createServiceToken, + DataPortSchemaObject, + FromSchema, + JsonSchema, + TypedArraySchemaOptions, +} from "@workglow/util"; import { BaseSqlTabularRepository } from "./BaseSqlTabularRepository"; import { + AnyTabularRepository, DeleteSearchCriteria, isSearchCondition, - ITabularRepository, SearchOperator, + SimplifyPrimaryKey, TabularChangePayload, TabularChangeType, TabularSubscribeOptions, ValueOptionType, } from "./ITabularRepository"; -export const SUPABASE_TABULAR_REPOSITORY = createServiceToken< - ITabularRepository ->("storage.tabularRepository.supabase"); +export const SUPABASE_TABULAR_REPOSITORY = createServiceToken( + "storage.tabularRepository.supabase" +); /** * A Supabase-based tabular repository implementation that extends BaseSqlTabularRepository. @@ -34,10 +41,9 @@ export class SupabaseTabularRepository< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types - Entity = FromSchema, - PrimaryKey = Pick, - Value = Omit, -> extends BaseSqlTabularRepository { + Entity = FromSchema, + PrimaryKey = SimplifyPrimaryKey, +> extends BaseSqlTabularRepository { private client: SupabaseClient; private realtimeChannel: RealtimeChannel | null = null; @@ -66,7 +72,6 @@ export class SupabaseTabularRepository< * Initializes the database table with the required schema. * Creates the table if it doesn't exist with primary key and value columns. * Must be called before using any other methods. - * Note: By default, assumes the table already exists (set isSetup in tests). 
*/ public async setupDatabase(): Promise { const sql = ` diff --git a/packages/task-graph/src/storage/TaskGraphTabularRepository.ts b/packages/task-graph/src/storage/TaskGraphTabularRepository.ts index 18f43956..ab56900d 100644 --- a/packages/task-graph/src/storage/TaskGraphTabularRepository.ts +++ b/packages/task-graph/src/storage/TaskGraphTabularRepository.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { TabularRepository } from "@workglow/storage"; +import type { BaseTabularRepository } from "@workglow/storage"; import { DataPortSchemaObject } from "@workglow/util"; import { TaskGraph } from "../task-graph/TaskGraph"; import { createGraphFromGraphJSON } from "../task/TaskJSON"; @@ -24,7 +24,7 @@ export const TaskGraphPrimaryKeyNames = ["key"] as const; /** * Options for the TaskGraphRepository */ -export type TaskGraphRepositoryStorage = TabularRepository< +export type TaskGraphRepositoryStorage = BaseTabularRepository< typeof TaskGraphSchema, typeof TaskGraphPrimaryKeyNames >; diff --git a/packages/task-graph/src/storage/TaskOutputTabularRepository.ts b/packages/task-graph/src/storage/TaskOutputTabularRepository.ts index 2c9c67e2..cf7a5789 100644 --- a/packages/task-graph/src/storage/TaskOutputTabularRepository.ts +++ b/packages/task-graph/src/storage/TaskOutputTabularRepository.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { type TabularRepository } from "@workglow/storage"; +import { type BaseTabularRepository } from "@workglow/storage"; import { compress, DataPortSchemaObject, decompress, makeFingerprint } from "@workglow/util"; import { TaskInput, TaskOutput } from "../task/TaskTypes"; import { TaskOutputRepository } from "./TaskOutputRepository"; @@ -27,7 +27,7 @@ export const TaskOutputSchema = { export const TaskOutputPrimaryKeyNames = ["key", "taskType"] as const; -export type TaskOutputRepositoryStorage = TabularRepository< +export type TaskOutputRepositoryStorage = BaseTabularRepository< typeof TaskOutputSchema, typeof TaskOutputPrimaryKeyNames >; diff --git a/packages/test/src/samples/ONNXModelSamples.ts b/packages/test/src/samples/ONNXModelSamples.ts index aae837f7..373f5831 100644 --- a/packages/test/src/samples/ONNXModelSamples.ts +++ b/packages/test/src/samples/ONNXModelSamples.ts @@ -68,6 +68,18 @@ export async function registerHuggingfaceLocalModels(): Promise { }, metadata: {}, }, + { + model_id: "onnx:onnx-community/NeuroBERT-NER-ONNX:q8", + title: "NeuroBERT NER", + description: "onnx-community/NeuroBERT-NER-ONNX", + tasks: ["TextNamedEntityRecognitionTask"], + provider: HF_TRANSFORMERS_ONNX, + provider_config: { + pipeline: "token-classification", + model_path: "onnx-community/NeuroBERT-NER-ONNX", + }, + metadata: {}, + }, { model_id: "onnx:Xenova/distilbert-base-uncased-distilled-squad:q8", title: "distilbert-base-uncased-distilled-squad", diff --git a/packages/test/src/test/storage-kv/SupabaseKvRepository.test.ts b/packages/test/src/test/storage-kv/SupabaseKvRepository.test.ts index a9dd8587..1f750a58 100644 --- a/packages/test/src/test/storage-kv/SupabaseKvRepository.test.ts +++ b/packages/test/src/test/storage-kv/SupabaseKvRepository.test.ts @@ -10,27 +10,11 @@ import { SupabaseKvRepository, SupabaseTabularRepository, } from "@workglow/storage"; -import { - DataPortSchemaObject, - ExcludeProps, - FromSchema, - IncludeProps, - uuid4, -} from "@workglow/util"; +import { uuid4 } from "@workglow/util"; import { describe } from "vitest"; import { createSupabaseMockClient } from "../helpers/SupabaseMockClient"; import { 
runGenericKvRepositoryTests } from "./genericKvRepositoryTests"; -class SupabaseTabularTestRepository< - Schema extends DataPortSchemaObject, - PrimaryKeyNames extends ReadonlyArray, - PrimaryKey = FromSchema>, - Entity = FromSchema, - Value = FromSchema>, -> extends SupabaseTabularRepository { - protected isSetup = false; // force setup to run, which is not the default -} - describe("SupabaseKvRepository", () => { const client = createSupabaseMockClient(); runGenericKvRepositoryTests(async (keyType, valueType) => { @@ -40,12 +24,7 @@ describe("SupabaseKvRepository", () => { tableName, keyType, valueType, - new SupabaseTabularTestRepository( - client, - tableName, - DefaultKeyValueSchema, - DefaultKeyValueKey - ) + new SupabaseTabularRepository(client, tableName, DefaultKeyValueSchema, DefaultKeyValueKey) ); }); }); diff --git a/packages/test/src/test/storage-tabular/IndexedDbTabularRepository.test.ts b/packages/test/src/test/storage-tabular/IndexedDbTabularRepository.test.ts index b9a48213..6e07ec9d 100644 --- a/packages/test/src/test/storage-tabular/IndexedDbTabularRepository.test.ts +++ b/packages/test/src/test/storage-tabular/IndexedDbTabularRepository.test.ts @@ -102,9 +102,7 @@ describe("IndexedDbTabularRepository", () => { let repo: IndexedDbTabularRepository< typeof RequiredColumnsSchema, typeof RequiredColumnsPK, - RequiredEntity, - Pick, - Omit + RequiredEntity >; beforeEach(async () => { @@ -198,8 +196,7 @@ describe("IndexedDbTabularRepository", () => { typeof OptionalColumnsSchema, typeof OptionalColumnsPK, OptionalEntity, - Pick, - Omit + Pick >; beforeEach(async () => { diff --git a/packages/test/src/test/storage-tabular/SupabaseTabularRepository.test.ts b/packages/test/src/test/storage-tabular/SupabaseTabularRepository.test.ts index c3cc98e0..3a4f507e 100644 --- a/packages/test/src/test/storage-tabular/SupabaseTabularRepository.test.ts +++ b/packages/test/src/test/storage-tabular/SupabaseTabularRepository.test.ts @@ -5,13 +5,7 @@ */ import { SupabaseTabularRepository } from "@workglow/storage"; -import { - DataPortSchemaObject, - ExcludeProps, - FromSchema, - IncludeProps, - uuid4, -} from "@workglow/util"; +import { uuid4 } from "@workglow/util"; import { describe } from "vitest"; import { createSupabaseMockClient } from "../helpers/SupabaseMockClient"; import { @@ -26,28 +20,17 @@ import { const client = createSupabaseMockClient(); -class SupabaseTabularTestRepository< - Schema extends DataPortSchemaObject, - PrimaryKeyNames extends ReadonlyArray, - // computed types - PrimaryKey = FromSchema>, - Entity = FromSchema, - Value = FromSchema>, -> extends SupabaseTabularRepository { - protected isSetup = false; // force setup to run, which is not the default -} - describe("SupabaseTabularRepository", () => { runGenericTabularRepositoryTests( async () => - new SupabaseTabularTestRepository( + new SupabaseTabularRepository( client, `supabase_test_${uuid4().replace(/-/g, "_")}`, CompoundSchema, CompoundPrimaryKeyNames ), async () => - new SupabaseTabularTestRepository( + new SupabaseTabularRepository( client, `supabase_test_${uuid4().replace(/-/g, "_")}`, SearchSchema, @@ -55,7 +38,7 @@ describe("SupabaseTabularRepository", () => { ["category", ["category", "subcategory"], ["subcategory", "category"], "value"] ), async () => { - const repo = new SupabaseTabularTestRepository< + const repo = new SupabaseTabularRepository< typeof AllTypesSchema, typeof AllTypesPrimaryKeyNames >( diff --git a/packages/test/src/test/task-graph/InputResolver.test.ts 
b/packages/test/src/test/task-graph/InputResolver.test.ts new file mode 100644 index 00000000..2d69920a --- /dev/null +++ b/packages/test/src/test/task-graph/InputResolver.test.ts @@ -0,0 +1,265 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + AnyTabularRepository, + getGlobalTabularRepositories, + InMemoryTabularRepository, + registerTabularRepository, + TypeTabularRepository, +} from "@workglow/storage"; +import { IExecuteContext, resolveSchemaInputs, Task, TaskRegistry } from "@workglow/task-graph"; +import { + getInputResolvers, + globalServiceRegistry, + registerInputResolver, + type DataPortSchema, +} from "@workglow/util"; +import { afterEach, beforeEach, describe, expect, test } from "vitest"; + +describe("InputResolver", () => { + // Test schema for tabular repository + const testEntitySchema = { + type: "object", + properties: { + id: { type: "string" }, + name: { type: "string" }, + }, + required: ["id", "name"], + additionalProperties: false, + } as const; + + let testRepo: InMemoryTabularRepository; + + beforeEach(async () => { + // Create and register a test repository + testRepo = new InMemoryTabularRepository(testEntitySchema, ["id"] as const); + await testRepo.setupDatabase(); + registerTabularRepository("test-repo", testRepo); + }); + + afterEach(() => { + // Clean up the registry + getGlobalTabularRepositories().delete("test-repo"); + testRepo.destroy(); + }); + + describe("resolveSchemaInputs", () => { + test("should pass through non-string values unchanged", async () => { + const schema: DataPortSchema = { + type: "object", + properties: { + repository: TypeTabularRepository(), + }, + }; + + const input = { repository: testRepo }; + const resolved = await resolveSchemaInputs(input, schema, { + registry: globalServiceRegistry, + }); + + expect(resolved.repository).toBe(testRepo); + }); + + test("should resolve string repository ID to instance", async () => { + const schema: DataPortSchema = { + type: "object", + properties: { + repository: TypeTabularRepository(), + }, + }; + + const input = { repository: "test-repo" }; + const resolved = await resolveSchemaInputs(input, schema, { + registry: globalServiceRegistry, + }); + + expect(resolved.repository).toBe(testRepo); + }); + + test("should throw error for unknown repository ID", async () => { + const schema: DataPortSchema = { + type: "object", + properties: { + repository: TypeTabularRepository(), + }, + }; + + const input = { repository: "non-existent-repo" }; + + await expect( + resolveSchemaInputs(input, schema, { registry: globalServiceRegistry }) + ).rejects.toThrow('Tabular repository "non-existent-repo" not found'); + }); + + test("should not resolve properties without format annotation", async () => { + const schema: DataPortSchema = { + type: "object", + properties: { + name: { type: "string" }, + }, + }; + + const input = { name: "test-name" }; + const resolved = await resolveSchemaInputs(input, schema, { + registry: globalServiceRegistry, + }); + + expect(resolved.name).toBe("test-name"); + }); + + test("should handle boolean schema", async () => { + const input = { foo: "bar" }; + const resolved = await resolveSchemaInputs(input, true as DataPortSchema, { + registry: globalServiceRegistry, + }); + + expect(resolved).toEqual(input); + }); + + test("should handle schema without properties", async () => { + // @ts-expect-error - schema is not a DataPortSchemaObject + const schema: DataPortSchema = { type: "object" }; + const input = { foo: "bar" 
}; + const resolved = await resolveSchemaInputs(input, schema, { + registry: globalServiceRegistry, + }); + + expect(resolved).toEqual(input); + }); + }); + + describe("registerInputResolver", () => { + test("should register custom resolver", async () => { + // Register a custom resolver for a test format + registerInputResolver("custom", (id, format, registry) => { + return { resolved: true, id, format }; + }); + + const schema: DataPortSchema = { + type: "object", + properties: { + data: { type: "string", format: "custom:test" }, + }, + }; + + const input = { data: "my-id" }; + const resolved = await resolveSchemaInputs(input, schema, { + registry: globalServiceRegistry, + }); + + expect(resolved.data).toEqual({ resolved: true, id: "my-id", format: "custom:test" }); + + // Clean up + getInputResolvers().delete("custom"); + }); + + test("should support async resolvers", async () => { + registerInputResolver("async", async (id, format, registry) => { + await new Promise((resolve) => setTimeout(resolve, 10)); + return { asyncResolved: true, id }; + }); + + const schema: DataPortSchema = { + type: "object", + properties: { + data: { type: "string", format: "async" }, + }, + }; + + const input = { data: "async-id" }; + const resolved = await resolveSchemaInputs(input, schema, { + registry: globalServiceRegistry, + }); + + expect(resolved.data).toEqual({ asyncResolved: true, id: "async-id" }); + + // Clean up + getInputResolvers().delete("async"); + }); + }); + + describe("Integration with Task", () => { + // Define a test task that uses a repository + class RepositoryConsumerTask extends Task< + { repository: any; query: string }, + { results: any[] } + > { + public static type = "RepositoryConsumerTask"; + + public static inputSchema(): DataPortSchema { + return { + type: "object", + properties: { + repository: TypeTabularRepository({ + title: "Data Repository", + description: "Repository to query", + }), + query: { type: "string", title: "Query" }, + }, + required: ["repository", "query"], + additionalProperties: false, + }; + } + + public static outputSchema(): DataPortSchema { + return { + type: "object", + properties: { + results: { type: "array", items: { type: "object" } }, + }, + required: ["results"], + additionalProperties: false, + }; + } + + async execute( + input: { repository: AnyTabularRepository; query: string }, + _context: IExecuteContext + ): Promise<{ results: any[] }> { + const { repository } = input; + // In a real task, we'd search the repository + const results = await repository.getAll(); + return { results: results ?? 
[] }; + } + } + + beforeEach(() => { + TaskRegistry.registerTask(RepositoryConsumerTask); + }); + + afterEach(() => { + TaskRegistry.all.delete(RepositoryConsumerTask.type); + }); + + test("should resolve repository when running task with string ID", async () => { + // Add some test data + await testRepo.put({ id: "1", name: "Test Item" }); + + const task = new RepositoryConsumerTask(); + const result = await task.run({ + repository: "test-repo", + query: "test", + }); + + expect(result.results).toHaveLength(1); + expect(result.results[0]).toEqual({ id: "1", name: "Test Item" }); + }); + + test("should work with direct repository instance", async () => { + await testRepo.put({ id: "2", name: "Direct Item" }); + + const task = new RepositoryConsumerTask(); + const result = await task.run({ + repository: testRepo, + query: "test", + }); + + expect(result.results).toHaveLength(1); + expect(result.results[0]).toEqual({ id: "2", name: "Direct Item" }); + }); + }); +}); diff --git a/packages/test/src/test/task-graph/TaskGraphFormatSemantic.test.ts b/packages/test/src/test/task-graph/TaskGraphFormatSemantic.test.ts index cfd17fcb..554c9cf0 100644 --- a/packages/test/src/test/task-graph/TaskGraphFormatSemantic.test.ts +++ b/packages/test/src/test/task-graph/TaskGraphFormatSemantic.test.ts @@ -5,8 +5,78 @@ */ import { Dataflow, Task, TaskGraph, type TaskInput } from "@workglow/task-graph"; -import type { DataPortSchema } from "@workglow/util"; +import type { DataPortSchema, ServiceRegistry } from "@workglow/util"; import { beforeEach, describe, expect, it } from "vitest"; +import { + MODEL_REPOSITORY, + InMemoryModelRepository, + type ModelRecord, + type ModelRepository, +} from "@workglow/ai"; + +/** + * Test model fixtures for embedding models + */ +const EMBEDDING_MODELS: ModelRecord[] = [ + { + model_id: "text-embedding-ada-002", + tasks: ["EmbeddingTask"], + provider: "openai", + title: "OpenAI Ada Embedding", + description: "OpenAI text embedding model", + provider_config: {}, + metadata: {}, + }, + { + model_id: "all-MiniLM-L6-v2", + tasks: ["EmbeddingTask"], + provider: "local", + title: "MiniLM Embedding", + description: "Local embedding model", + provider_config: {}, + metadata: {}, + }, +]; + +/** + * Test model fixtures for text generation models + */ +const TEXT_GEN_MODELS: ModelRecord[] = [ + { + model_id: "gpt-4", + tasks: ["TextGenerationTask"], + provider: "openai", + title: "GPT-4", + description: "OpenAI GPT-4 text generation model", + provider_config: {}, + metadata: {}, + }, + { + model_id: "claude-3", + tasks: ["TextGenerationTask"], + provider: "anthropic", + title: "Claude 3", + description: "Anthropic Claude 3 model", + provider_config: {}, + metadata: {}, + }, +]; + +/** + * Helper function to create a test-local service registry with a model repository + * @param models - Array of model records to populate the repository with + * @returns Promise resolving to a configured ServiceRegistry + */ +async function createTestRegistry(models: ModelRecord[]): Promise { + const { ServiceRegistry } = await import("@workglow/util"); + const registry = new ServiceRegistry(); + const modelRepo = new InMemoryModelRepository(); + for (const model of models) { + await modelRepo.addModel(model); + } + registry.registerInstance(MODEL_REPOSITORY, modelRepo); + return registry; +} /** * Test task with generic model output (format: "model") @@ -459,16 +529,17 @@ describe("TaskGraph with format annotations", () => { } as const satisfies DataPortSchema; } - // Simulate runtime narrowing of 
models - async narrowInput(input: { - model: string | string[]; - }): Promise<{ model: string | string[] }> { - // In real implementation, this would check ModelRepository for compatible models - // For testing, we simulate filtering - const validEmbeddingModels = ["text-embedding-ada-002", "all-MiniLM-L6-v2"]; + // Runtime narrowing using ModelRepository from the registry + async narrowInput( + input: { model: string | string[] }, + registry: ServiceRegistry + ): Promise<{ model: string | string[] }> { + const modelRepo = registry.get(MODEL_REPOSITORY); + const validModels = await modelRepo.findModelsByTask(this.type); + const validIds = new Set(validModels?.map((m) => m.model_id) ?? []); const models = Array.isArray(input.model) ? input.model : [input.model]; - const narrowedModels = models.filter((m) => validEmbeddingModels.includes(m)); + const narrowedModels = models.filter((m) => validIds.has(m)); return { model: narrowedModels.length === 1 ? narrowedModels[0] : narrowedModels, @@ -482,12 +553,15 @@ describe("TaskGraph with format annotations", () => { const task = new NarrowableModelConsumerTask({}, { id: "consumer" }); + // Create test registry with embedding and text generation models + const registry = await createTestRegistry([...EMBEDDING_MODELS, ...TEXT_GEN_MODELS]); + // Test narrowing with array of models (some compatible, some not) const inputWithMixed = { model: ["text-embedding-ada-002", "gpt-4", "all-MiniLM-L6-v2", "claude-3"], }; - const narrowedResult = await task.narrowInput(inputWithMixed); + const narrowedResult = await task.narrowInput(inputWithMixed, registry); // Should only keep the embedding models expect(narrowedResult.model).toEqual(["text-embedding-ada-002", "all-MiniLM-L6-v2"]); @@ -522,12 +596,16 @@ describe("TaskGraph with format annotations", () => { } as const satisfies DataPortSchema; } - async narrowInput(input: { - model: string | string[]; - }): Promise<{ model: string | string[] }> { - const validModels = ["text-embedding-ada-002"]; + async narrowInput( + input: { model: string | string[] }, + registry: ServiceRegistry + ): Promise<{ model: string | string[] }> { + const modelRepo = registry.get(MODEL_REPOSITORY); + const validModels = await modelRepo.findModelsByTask(this.type); + const validIds = new Set(validModels?.map((m) => m.model_id) ?? []); + const models = Array.isArray(input.model) ? input.model : [input.model]; - const narrowed = models.filter((m) => validModels.includes(m)); + const narrowed = models.filter((m) => validIds.has(m)); return { model: narrowed.length === 1 ? 
narrowed[0] : narrowed }; } @@ -538,12 +616,15 @@ describe("TaskGraph with format annotations", () => { const task = new NarrowableModelTask({}, { id: "task" }); + // Create test registry with only embedding models + const registry = await createTestRegistry(EMBEDDING_MODELS); + // Test with single valid model - const result1 = await task.narrowInput({ model: "text-embedding-ada-002" }); + const result1 = await task.narrowInput({ model: "text-embedding-ada-002" }, registry); expect(result1.model).toBe("text-embedding-ada-002"); // Test with single invalid model (gets filtered out) - const result2 = await task.narrowInput({ model: "gpt-4" }); + const result2 = await task.narrowInput({ model: "gpt-4" }, registry); expect(result2.model).toEqual([]); }); diff --git a/packages/test/src/test/util/Document.test.ts b/packages/test/src/test/util/Document.test.ts new file mode 100644 index 00000000..85cc0635 --- /dev/null +++ b/packages/test/src/test/util/Document.test.ts @@ -0,0 +1,52 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { ChunkNode, DocumentNode } from "@workglow/storage"; +import { Document, NodeKind } from "@workglow/storage"; +import { describe, expect, test } from "vitest"; + +describe("Document", () => { + const createTestDocumentNode = (): DocumentNode => ({ + nodeId: "root", + kind: NodeKind.DOCUMENT, + range: { startOffset: 0, endOffset: 100 }, + text: "Test document", + title: "Test document", + children: [], + }); + + const createTestChunks = (): ChunkNode[] => [ + { + chunkId: "chunk1", + doc_id: "doc1", + text: "Test chunk", + nodePath: ["root"], + depth: 1, + }, + ]; + + test("setChunks and getChunks", () => { + const doc = new Document("doc1", createTestDocumentNode(), { title: "Test" }); + + doc.setChunks(createTestChunks()); + + const chunks = doc.getChunks(); + expect(chunks).toBeDefined(); + expect(chunks.length).toBe(1); + expect(chunks[0].text).toBe("Test chunk"); + }); + + test("findChunksByNodeId", () => { + const doc = new Document("doc1", createTestDocumentNode(), { title: "Test" }); + + doc.setChunks(createTestChunks()); + + const chunks = doc.findChunksByNodeId("root"); + expect(chunks).toBeDefined(); + expect(chunks.length).toBe(1); + expect(chunks[0].text).toBe("Test chunk"); + }); +}); From 344d5b01342eae86bbfec1915fa969d6cc75280b Mon Sep 17 00:00:00 2001 From: Steven Roussey Date: Sun, 11 Jan 2026 07:24:25 +0000 Subject: [PATCH 05/14] [feat] Add new AI tasks for document processing and vector management - Introduced multiple new tasks including ChunkToVectorTask, ContextBuilderTask, DocumentEnricherTask, DocumentNodeRetrievalTask, DocumentNodeVectorHybridSearchTask, DocumentNodeVectorSearchTask, DocumentNodeVectorUpsertTask, HierarchicalChunkerTask, HierarchyJoinTask, QueryExpanderTask, RerankerTask, StructuralParserTask, TextChunkerTask, TopicSegmenterTask, and VectorQuantizeTask. - Enhanced the task registry to support these new tasks, allowing for improved document processing workflows and vector management capabilities. - Updated the index file to export the new tasks for easier access and integration. - Added comprehensive tests for each new task to ensure functionality and reliability in various scenarios. 
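A minimal sketch of the calling convention the new tasks share (each task file also exports a lowercase helper such as `chunkToVector`; the import path and sample values below are illustrative and assume the package re-exports the task helpers):

```typescript
import { chunkToVector } from "@workglow/ai";

// Pair chunk nodes with their embeddings; the task re-emits them in vector-store form.
const { ids, vectors, metadata, texts } = await chunkToVector({
  doc_id: "doc1",
  chunks: [
    { chunkId: "chunk1", doc_id: "doc1", text: "Test chunk", nodePath: ["root"], depth: 1 },
  ],
  vectors: [new Float32Array([0.1, 0.2, 0.3])],
});
// ids, vectors, metadata and texts are index-aligned and ready for a vector store upsert.
```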
--- packages/ai/src/task/ChunkToVectorTask.ts | 179 +++++++ packages/ai/src/task/ContextBuilderTask.ts | 339 ++++++++++++++ packages/ai/src/task/DocumentEnricherTask.ts | 417 +++++++++++++++++ .../ai/src/task/DocumentNodeRetrievalTask.ts | 246 ++++++++++ .../DocumentNodeVectorHybridSearchTask.ts | 235 ++++++++++ .../src/task/DocumentNodeVectorSearchTask.ts | 175 +++++++ .../src/task/DocumentNodeVectorUpsertTask.ts | 175 +++++++ .../ai/src/task/HierarchicalChunkerTask.ts | 305 ++++++++++++ packages/ai/src/task/HierarchyJoinTask.ts | 247 ++++++++++ packages/ai/src/task/QueryExpanderTask.ts | 318 +++++++++++++ packages/ai/src/task/RerankerTask.ts | 341 ++++++++++++++ packages/ai/src/task/StructuralParserTask.ts | 159 +++++++ packages/ai/src/task/TextChunkerTask.ts | 358 ++++++++++++++ packages/ai/src/task/TopicSegmenterTask.ts | 439 ++++++++++++++++++ packages/ai/src/task/VectorQuantizeTask.ts | 257 ++++++++++ packages/ai/src/task/index.ts | 16 +- .../test/src/test/rag/ChunkToVector.test.ts | 124 +++++ .../src/test/rag/ContextBuilderTask.test.ts | 247 ++++++++++ packages/test/src/test/rag/EndToEnd.test.ts | 143 ++++++ packages/test/src/test/rag/FullChain.test.ts | 144 ++++++ .../src/test/rag/HierarchicalChunker.test.ts | 191 ++++++++ .../src/test/rag/HybridSearchTask.test.ts | 278 +++++++++++ .../test/src/test/rag/RagWorkflow.test.ts | 277 +++++++++++ .../src/test/rag/StructuralParser.test.ts | 204 ++++++++ .../test/src/test/rag/TextChunkerTask.test.ts | 226 +++++++++ .../src/test/rag/VectorQuantizeTask.test.ts | 228 +++++++++ .../test/util/VectorSimilarityUtils.test.ts | 390 ++++++++++++++++ .../test/src/test/util/VectorUtils.test.ts | 382 +++++++++++++++ 28 files changed, 7039 insertions(+), 1 deletion(-) create mode 100644 packages/ai/src/task/ChunkToVectorTask.ts create mode 100644 packages/ai/src/task/ContextBuilderTask.ts create mode 100644 packages/ai/src/task/DocumentEnricherTask.ts create mode 100644 packages/ai/src/task/DocumentNodeRetrievalTask.ts create mode 100644 packages/ai/src/task/DocumentNodeVectorHybridSearchTask.ts create mode 100644 packages/ai/src/task/DocumentNodeVectorSearchTask.ts create mode 100644 packages/ai/src/task/DocumentNodeVectorUpsertTask.ts create mode 100644 packages/ai/src/task/HierarchicalChunkerTask.ts create mode 100644 packages/ai/src/task/HierarchyJoinTask.ts create mode 100644 packages/ai/src/task/QueryExpanderTask.ts create mode 100644 packages/ai/src/task/RerankerTask.ts create mode 100644 packages/ai/src/task/StructuralParserTask.ts create mode 100644 packages/ai/src/task/TextChunkerTask.ts create mode 100644 packages/ai/src/task/TopicSegmenterTask.ts create mode 100644 packages/ai/src/task/VectorQuantizeTask.ts create mode 100644 packages/test/src/test/rag/ChunkToVector.test.ts create mode 100644 packages/test/src/test/rag/ContextBuilderTask.test.ts create mode 100644 packages/test/src/test/rag/EndToEnd.test.ts create mode 100644 packages/test/src/test/rag/FullChain.test.ts create mode 100644 packages/test/src/test/rag/HierarchicalChunker.test.ts create mode 100644 packages/test/src/test/rag/HybridSearchTask.test.ts create mode 100644 packages/test/src/test/rag/RagWorkflow.test.ts create mode 100644 packages/test/src/test/rag/StructuralParser.test.ts create mode 100644 packages/test/src/test/rag/TextChunkerTask.test.ts create mode 100644 packages/test/src/test/rag/VectorQuantizeTask.test.ts create mode 100644 packages/test/src/test/util/VectorSimilarityUtils.test.ts create mode 100644 packages/test/src/test/util/VectorUtils.test.ts diff --git 
a/packages/ai/src/task/ChunkToVectorTask.ts b/packages/ai/src/task/ChunkToVectorTask.ts new file mode 100644 index 00000000..e5394f4d --- /dev/null +++ b/packages/ai/src/task/ChunkToVectorTask.ts @@ -0,0 +1,179 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ChunkNodeSchema, type ChunkNode } from "@workglow/storage"; +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { + DataPortSchema, + FromSchema, + TypedArraySchema, + TypedArraySchemaOptions, +} from "@workglow/util"; + +const inputSchema = { + type: "object", + properties: { + doc_id: { + type: "string", + title: "Document ID", + description: "The document ID", + }, + chunks: { + type: "array", + items: ChunkNodeSchema(), + title: "Chunks", + description: "Array of chunk nodes", + }, + vectors: { + type: "array", + items: TypedArraySchema({ + title: "Vector", + description: "Vector embedding", + }), + title: "Vectors", + description: "Embeddings from TextEmbeddingTask", + }, + }, + required: [], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + ids: { + type: "array", + items: { type: "string" }, + title: "IDs", + description: "Chunk IDs for vector store", + }, + vectors: { + type: "array", + items: TypedArraySchema({ + title: "Vector", + description: "Vector embedding", + }), + title: "Vectors", + description: "Vector embeddings", + }, + metadata: { + type: "array", + items: { + type: "object", + title: "Metadata", + description: "Metadata for vector store", + }, + title: "Metadata", + description: "Metadata for each vector", + }, + texts: { + type: "array", + items: { type: "string" }, + title: "Texts", + description: "Chunk texts (for reference)", + }, + }, + required: ["ids", "vectors", "metadata", "texts"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type ChunkToVectorTaskInput = FromSchema; +export type ChunkToVectorTaskOutput = FromSchema; + +/** + * Task to transform chunk nodes and embeddings into vector store format + * Bridges HierarchicalChunker + TextEmbedding → VectorStoreUpsert + */ +export class ChunkToVectorTask extends Task< + ChunkToVectorTaskInput, + ChunkToVectorTaskOutput, + JobQueueTaskConfig +> { + public static type = "ChunkToVectorTask"; + public static category = "Document"; + public static title = "Chunk to Vector Transform"; + public static description = "Transform chunks and embeddings to vector store format"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: ChunkToVectorTaskInput, + context: IExecuteContext + ): Promise { + const { chunks, vectors } = input; + + const chunkArray = chunks as ChunkNode[]; + + if (!chunkArray || !vectors) { + throw new Error("Both chunks and vector are required"); + } + + if (chunkArray.length !== vectors.length) { + throw new Error(`Mismatch: ${chunkArray.length} chunks but ${vectors.length} vectors`); + } + + const ids: string[] = []; + const metadata: any[] = []; + const texts: string[] = []; + + for (let i = 0; i < chunkArray.length; i++) { + const chunk = chunkArray[i]; + + ids.push(chunk.chunkId); + texts.push(chunk.text); + + metadata.push({ + doc_id: chunk.doc_id, + chunkId: chunk.chunkId, + 
leafNodeId: chunk.nodePath[chunk.nodePath.length - 1], + depth: chunk.depth, + text: chunk.text, + nodePath: chunk.nodePath, + // Include enrichment if present + ...(chunk.enrichment || {}), + }); + } + + return { + ids, + vectors, + metadata, + texts, + }; + } +} + +TaskRegistry.registerTask(ChunkToVectorTask); + +export const chunkToVector = (input: ChunkToVectorTaskInput, config?: JobQueueTaskConfig) => { + return new ChunkToVectorTask({} as ChunkToVectorTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + chunkToVector: CreateWorkflow< + ChunkToVectorTaskInput, + ChunkToVectorTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.chunkToVector = CreateWorkflow(ChunkToVectorTask); diff --git a/packages/ai/src/task/ContextBuilderTask.ts b/packages/ai/src/task/ContextBuilderTask.ts new file mode 100644 index 00000000..19dee6dc --- /dev/null +++ b/packages/ai/src/task/ContextBuilderTask.ts @@ -0,0 +1,339 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + CreateWorkflow, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; + +export const ContextFormat = { + SIMPLE: "simple", + NUMBERED: "numbered", + XML: "xml", + MARKDOWN: "markdown", + JSON: "json", +} as const; + +export type ContextFormat = (typeof ContextFormat)[keyof typeof ContextFormat]; + +const inputSchema = { + type: "object", + properties: { + chunks: { + type: "array", + items: { type: "string" }, + title: "Text Chunks", + description: "Retrieved text chunks to format", + }, + metadata: { + type: "array", + items: { + type: "object", + title: "Metadata", + description: "Metadata for each chunk", + }, + title: "Metadata", + description: "Metadata for each chunk (optional)", + }, + scores: { + type: "array", + items: { type: "number" }, + title: "Scores", + description: "Relevance scores for each chunk (optional)", + }, + format: { + type: "string", + enum: Object.values(ContextFormat), + title: "Format", + description: "Format for the context output", + default: ContextFormat.SIMPLE, + }, + maxLength: { + type: "number", + title: "Max Length", + description: "Maximum length of context in characters (0 = unlimited)", + minimum: 0, + default: 0, + }, + includeMetadata: { + type: "boolean", + title: "Include Metadata", + description: "Whether to include metadata in the context", + default: false, + }, + separator: { + type: "string", + title: "Separator", + description: "Separator between chunks", + default: "\n\n", + }, + }, + required: ["chunks"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + context: { + type: "string", + title: "Context", + description: "Formatted context string for LLM", + }, + chunksUsed: { + type: "number", + title: "Chunks Used", + description: "Number of chunks included in context", + }, + totalLength: { + type: "number", + title: "Total Length", + description: "Total length of context in characters", + }, + }, + required: ["context", "chunksUsed", "totalLength"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type ContextBuilderTaskInput = FromSchema; +export type ContextBuilderTaskOutput = FromSchema; + +/** + * Task for formatting retrieved chunks into context for LLM prompts. + * Supports various formatting styles and length constraints. 
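+ *
+ * Illustrative usage (a sketch; assumes `chunks` come from a prior retrieval step):
+ *
+ *   const { context } = await contextBuilder({
+ *     chunks: ["First passage ...", "Second passage ..."],
+ *     format: "numbered",
+ *     maxLength: 2000,
+ *   });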
+ */ +export class ContextBuilderTask extends Task< + ContextBuilderTaskInput, + ContextBuilderTaskOutput, + JobQueueTaskConfig +> { + public static type = "ContextBuilderTask"; + public static category = "RAG"; + public static title = "Context Builder"; + public static description = "Format retrieved chunks into context for LLM prompts"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async executeReactive( + input: ContextBuilderTaskInput, + output: ContextBuilderTaskOutput + ): Promise { + const { + chunks, + metadata = [], + scores = [], + format = ContextFormat.SIMPLE, + maxLength = 0, + includeMetadata = false, + separator = "\n\n", + } = input; + + let context = ""; + let chunksUsed = 0; + + for (let i = 0; i < chunks.length; i++) { + const chunk = chunks[i]; + const meta = metadata[i]; + const score = scores[i]; + + let formattedChunk = this.formatChunk(chunk, meta, score, i, format, includeMetadata); + + // Check length constraint + if (maxLength > 0) { + const potentialLength = context.length + formattedChunk.length + separator.length; + if (potentialLength > maxLength) { + // Try to fit partial chunk if it's the first one + if (chunksUsed === 0) { + const available = maxLength - context.length; + if (available > 100) { + // Only include partial if we have reasonable space + formattedChunk = formattedChunk.substring(0, available - 3) + "..."; + context += formattedChunk; + chunksUsed++; + } + } + break; + } + } + + if (chunksUsed > 0) { + context += separator; + } + context += formattedChunk; + chunksUsed++; + } + + return { + context, + chunksUsed, + totalLength: context.length, + }; + } + + private formatChunk( + chunk: string, + metadata: any, + score: number | undefined, + index: number, + format: ContextFormat, + includeMetadata: boolean + ): string { + switch (format) { + case ContextFormat.NUMBERED: + return this.formatNumbered(chunk, metadata, score, index, includeMetadata); + case ContextFormat.XML: + return this.formatXML(chunk, metadata, score, index, includeMetadata); + case ContextFormat.MARKDOWN: + return this.formatMarkdown(chunk, metadata, score, index, includeMetadata); + case ContextFormat.JSON: + return this.formatJSON(chunk, metadata, score, index, includeMetadata); + case ContextFormat.SIMPLE: + default: + return chunk; + } + } + + private formatNumbered( + chunk: string, + metadata: any, + score: number | undefined, + index: number, + includeMetadata: boolean + ): string { + let result = `[${index + 1}] ${chunk}`; + if (includeMetadata && metadata) { + const metaStr = this.formatMetadataInline(metadata, score); + if (metaStr) { + result += ` ${metaStr}`; + } + } + return result; + } + + private formatXML( + chunk: string, + metadata: any, + score: number | undefined, + index: number, + includeMetadata: boolean + ): string { + let result = ``; + if (includeMetadata && (metadata || score !== undefined)) { + result += "\n "; + if (score !== undefined) { + result += `\n ${score.toFixed(4)}`; + } + if (metadata) { + for (const [key, value] of Object.entries(metadata)) { + result += `\n <${key}>${this.escapeXML(String(value))}`; + } + } + result += "\n "; + result += `\n ${this.escapeXML(chunk)}`; + result += "\n"; + } else { + result += `${this.escapeXML(chunk)}`; + } + return result; + } + + private formatMarkdown( + chunk: string, + metadata: any, + score: number | undefined, + index: 
number, + includeMetadata: boolean + ): string { + let result = `### Chunk ${index + 1}\n\n`; + if (includeMetadata && (metadata || score !== undefined)) { + result += "**Metadata:**\n"; + if (score !== undefined) { + result += `- Score: ${score.toFixed(4)}\n`; + } + if (metadata) { + for (const [key, value] of Object.entries(metadata)) { + result += `- ${key}: ${value}\n`; + } + } + result += "\n"; + } + result += chunk; + return result; + } + + private formatJSON( + chunk: string, + metadata: any, + score: number | undefined, + index: number, + includeMetadata: boolean + ): string { + const obj: any = { + index: index + 1, + content: chunk, + }; + if (includeMetadata) { + if (score !== undefined) { + obj.score = score; + } + if (metadata) { + obj.metadata = metadata; + } + } + return JSON.stringify(obj); + } + + private formatMetadataInline(metadata: any, score: number | undefined): string { + const parts: string[] = []; + if (score !== undefined) { + parts.push(`score=${score.toFixed(4)}`); + } + if (metadata) { + for (const [key, value] of Object.entries(metadata)) { + parts.push(`${key}=${value}`); + } + } + return parts.length > 0 ? `(${parts.join(", ")})` : ""; + } + + private escapeXML(str: string): string { + return str + .replace(/&/g, "&") + .replace(//g, ">") + .replace(/"/g, """) + .replace(/'/g, "'"); + } +} + +TaskRegistry.registerTask(ContextBuilderTask); + +export const contextBuilder = (input: ContextBuilderTaskInput, config?: JobQueueTaskConfig) => { + return new ContextBuilderTask({} as ContextBuilderTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + contextBuilder: CreateWorkflow< + ContextBuilderTaskInput, + ContextBuilderTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.contextBuilder = CreateWorkflow(ContextBuilderTask); diff --git a/packages/ai/src/task/DocumentEnricherTask.ts b/packages/ai/src/task/DocumentEnricherTask.ts new file mode 100644 index 00000000..4fce5aec --- /dev/null +++ b/packages/ai/src/task/DocumentEnricherTask.ts @@ -0,0 +1,417 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + getChildren, + hasChildren, + type DocumentNode, + type Entity, + type NodeEnrichment, +} from "@workglow/storage"; +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; +import { ModelConfig } from "../model/ModelSchema"; +import { TextNamedEntityRecognitionTask } from "./TextNamedEntityRecognitionTask"; +import { TextSummaryTask } from "./TextSummaryTask"; +import { TypeModel } from "./base/AiTaskSchemas"; + +const inputSchema = { + type: "object", + properties: { + doc_id: { + type: "string", + title: "Document ID", + description: "The document ID", + }, + documentTree: { + title: "Document Tree", + description: "The hierarchical document tree to enrich", + }, + generateSummaries: { + type: "boolean", + title: "Generate Summaries", + description: "Whether to generate summaries for sections", + default: true, + }, + extractEntities: { + type: "boolean", + title: "Extract Entities", + description: "Whether to extract named entities", + default: true, + }, + summaryModel: TypeModel("model:TextSummaryTask", { + title: "Summary Model", + description: "Model to use for summary generation (optional)", + }), + summaryThreshold: { + type: "number", + title: "Summary Threshold", + description: "Minimum 
combined text length (node + children) to warrant generating a summary", + default: 500, + }, + nerModel: TypeModel("model:TextNamedEntityRecognitionTask", { + title: "NER Model", + description: "Model to use for named entity recognition (optional)", + }), + }, + required: [], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + doc_id: { + type: "string", + title: "Document ID", + description: "The document ID (passed through)", + }, + documentTree: { + title: "Document Tree", + description: "The enriched document tree", + }, + summaryCount: { + type: "number", + title: "Summary Count", + description: "Number of summaries generated", + }, + entityCount: { + type: "number", + title: "Entity Count", + description: "Number of entities extracted", + }, + }, + required: ["doc_id", "documentTree", "summaryCount", "entityCount"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type DocumentEnricherTaskInput = FromSchema; +export type DocumentEnricherTaskOutput = FromSchema; + +/** + * Task for enriching document nodes with summaries and entities + * Uses bottom-up propagation to roll up child information to parents + */ +export class DocumentEnricherTask extends Task< + DocumentEnricherTaskInput, + DocumentEnricherTaskOutput, + JobQueueTaskConfig +> { + public static type = "DocumentEnricherTask"; + public static category = "Document"; + public static title = "Document Enricher"; + public static description = "Enrich document nodes with summaries and entities"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: DocumentEnricherTaskInput, + context: IExecuteContext + ): Promise { + const { + doc_id, + documentTree, + generateSummaries = true, + extractEntities = true, + summaryModel: summaryModelConfig, + summaryThreshold = 500, + nerModel: nerModelConfig, + } = input; + + const root = documentTree as DocumentNode; + const summaryModel = summaryModelConfig ? (summaryModelConfig as ModelConfig) : undefined; + const nerModel = nerModelConfig ? (nerModelConfig as ModelConfig) : undefined; + let summaryCount = 0; + let entityCount = 0; + + const extract = + extractEntities && nerModel + ? async (text: string) => { + const result = await context + .own(new TextNamedEntityRecognitionTask({ text, model: nerModel })) + .run(); + return result.entities.map((e) => ({ + type: e.entity, + text: e.word, + score: e.score, + })); + } + : undefined; + + // Bottom-up enrichment + const enrichedRoot = await this.enrichNode( + root, + context, + generateSummaries && summaryModel ? 
summaryModel : undefined, + summaryThreshold, + extract, + (count) => (summaryCount += count), + (count) => (entityCount += count) + ); + + return { + doc_id: doc_id as string, + documentTree: enrichedRoot, + summaryCount, + entityCount, + }; + } + + /** + * Enrich a node recursively (bottom-up) + */ + private async enrichNode( + node: DocumentNode, + context: IExecuteContext, + summaryModel: ModelConfig | undefined, + summaryThreshold: number, + extract: ((text: string) => Promise) | undefined, + onSummary: (count: number) => void, + onEntity: (count: number) => void + ): Promise { + // If node has children, enrich them first + let enrichedChildren: DocumentNode[] | undefined; + if (hasChildren(node)) { + const children = getChildren(node); + enrichedChildren = await Promise.all( + children.map((child) => + this.enrichNode( + child, + context, + summaryModel, + summaryThreshold, + extract, + onSummary, + onEntity + ) + ) + ); + } + + // Generate enrichment for this node + const enrichment: NodeEnrichment = {}; + + // Generate summary (for sections and documents) + if (summaryModel && (node.kind === "section" || node.kind === "document")) { + if (enrichedChildren && enrichedChildren.length > 0) { + // Summary of children + enrichment.summary = await this.generateSummary( + node, + enrichedChildren, + context, + summaryModel, + summaryThreshold + ); + } else { + // Leaf section summary + enrichment.summary = await this.generateLeafSummary( + node.text, + context, + summaryModel, + summaryThreshold + ); + } + if (enrichment.summary) { + onSummary(1); + } + } + + // Extract entities + if (extract) { + enrichment.entities = await this.extractEntities(node, enrichedChildren, extract); + if (enrichment.entities) { + onEntity(enrichment.entities.length); + } + } + + // Create enriched node + const enrichedNode: DocumentNode = { + ...node, + enrichment: Object.keys(enrichment).length > 0 ? 
enrichment : undefined, + }; + + if (enrichedChildren) { + (enrichedNode as any).children = enrichedChildren; + } + + return enrichedNode; + } + + /** + * Private method to summarize text using the TextSummaryTask + */ + private async summarize( + text: string, + context: IExecuteContext, + model: ModelConfig + ): Promise { + // TODO: Handle truncation of text if needed, based on model configuration + return (await context.own(new TextSummaryTask()).run({ text, model })).text; + } + + /** + * Generate summary for a node with children + */ + private async generateSummary( + node: DocumentNode, + children: DocumentNode[], + context: IExecuteContext, + model: ModelConfig, + threshold: number + ): Promise { + const textParts: string[] = []; + + // Include the node's own text + const nodeText = node.text?.trim(); + if (nodeText) { + textParts.push(nodeText); + } + + // Include children summaries/texts + const childTexts = children + .map((child) => { + if (child.enrichment?.summary) { + return child.enrichment.summary; + } + return child.text; + }) + .join(" ") + .trim(); + + if (childTexts) { + textParts.push(childTexts); + } + + const combinedText = textParts.join(" ").trim(); + if (!combinedText) { + return undefined; + } + + // Check if summary is warranted based on threshold + if (combinedText.length < threshold) { + return undefined; + } + + const summaryParts: string[] = []; + + // Summarize the node's own text first + if (nodeText) { + const nodeSummary = await this.summarize(nodeText, context, model); + if (nodeSummary) { + summaryParts.push(nodeSummary); + } + } + + // Include children summaries/texts + if (childTexts) { + summaryParts.push(childTexts); + } + + const combinedSummaries = summaryParts.join(" ").trim(); + if (!combinedSummaries) { + return undefined; + } + + const result = await this.summarize(combinedSummaries, context, model); + return result; + } + + /** + * Generate summary for a leaf node + */ + private async generateLeafSummary( + text: string, + context: IExecuteContext, + model: ModelConfig, + threshold: number + ): Promise { + const trimmedText = text.trim(); + if (!trimmedText) { + return undefined; + } + + // Check if summary is warranted based on threshold + if (trimmedText.length < threshold) { + return undefined; + } + + const result = await this.summarize(trimmedText, context, model); + return result; + } + + /** + * Extract and roll up entities from node and children + */ + private async extractEntities( + node: DocumentNode, + children: DocumentNode[] | undefined, + extract: ((text: string) => Promise) | undefined + ): Promise { + const entities: Entity[] = []; + + // Collect from children first + if (children) { + for (const child of children) { + if (child.enrichment?.entities) { + entities.push(...child.enrichment.entities); + } + } + } + + const text = node.text.trim(); + if (text && extract) { + const nodeEntities = await extract(text); + if (nodeEntities?.length) { + entities.push(...nodeEntities); + } + } + + // Deduplicate by text + const unique = new Map(); + for (const entity of entities) { + const key = `${entity.text}::${entity.type}`; + const existing = unique.get(key); + if (!existing || entity.score > existing.score) { + unique.set(key, entity); + } + } + + const result = Array.from(unique.values()); + return result.length > 0 ? 
result : undefined; + } +} + +TaskRegistry.registerTask(DocumentEnricherTask); + +export const documentEnricher = (input: DocumentEnricherTaskInput, config?: JobQueueTaskConfig) => { + return new DocumentEnricherTask({} as DocumentEnricherTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + documentEnricher: CreateWorkflow< + DocumentEnricherTaskInput, + DocumentEnricherTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.documentEnricher = CreateWorkflow(DocumentEnricherTask); diff --git a/packages/ai/src/task/DocumentNodeRetrievalTask.ts b/packages/ai/src/task/DocumentNodeRetrievalTask.ts new file mode 100644 index 00000000..b54565a5 --- /dev/null +++ b/packages/ai/src/task/DocumentNodeRetrievalTask.ts @@ -0,0 +1,246 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + AnyDocumentNodeVectorRepository, + TypeDocumentNodeVectorRepository, +} from "@workglow/storage"; +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { + DataPortSchema, + FromSchema, + TypedArray, + TypedArraySchema, + TypedArraySchemaOptions, +} from "@workglow/util"; +import { TypeModel } from "./base/AiTaskSchemas"; +import { TextEmbeddingTask } from "./TextEmbeddingTask"; + +const inputSchema = { + type: "object", + properties: { + repository: TypeDocumentNodeVectorRepository({ + title: "Document Chunk Vector Repository", + description: "The document chunk vector repository instance to search in", + }), + query: { + oneOf: [ + { type: "string" }, + TypedArraySchema({ + title: "Query Vector", + description: "Pre-computed query vector", + }), + ], + title: "Query", + description: "Query string or pre-computed query vector", + }, + model: TypeModel("model:TextEmbeddingTask", { + title: "Model", + description: + "Text embedding model to use for query embedding (required when query is a string)", + }), + topK: { + type: "number", + title: "Top K", + description: "Number of top results to return", + minimum: 1, + default: 5, + }, + filter: { + type: "object", + title: "Metadata Filter", + description: "Filter results by metadata fields", + }, + scoreThreshold: { + type: "number", + title: "Score Threshold", + description: "Minimum similarity score threshold (0-1)", + minimum: 0, + maximum: 1, + default: 0, + }, + returnVectors: { + type: "boolean", + title: "Return Vectors", + description: "Whether to return vector embeddings in results", + default: false, + }, + }, + required: ["repository", "query"], + if: { + properties: { + query: { type: "string" }, + }, + }, + then: { + required: ["repository", "query", "model"], + }, + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + chunks: { + type: "array", + items: { type: "string" }, + title: "Text Chunks", + description: "Retrieved text chunks", + }, + ids: { + type: "array", + items: { type: "string" }, + title: "IDs", + description: "IDs of retrieved chunks", + }, + metadata: { + type: "array", + items: { + type: "object", + title: "Metadata", + description: "Metadata of retrieved chunk", + }, + title: "Metadata", + description: "Metadata of retrieved chunks", + }, + scores: { + type: "array", + items: { type: "number" }, + title: "Scores", + description: "Similarity scores for each result", + }, + vectors: { + type: "array", + items: TypedArraySchema({ + title: "Vector", + 
description: "Vector embedding", + }), + title: "Vectors", + description: "Vector embeddings (if returnVectors is true)", + }, + count: { + type: "number", + title: "Count", + description: "Number of results returned", + }, + }, + required: ["chunks", "ids", "metadata", "scores", "count"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type RetrievalTaskInput = FromSchema; +export type RetrievalTaskOutput = FromSchema; + +/** + * End-to-end retrieval task that combines embedding generation (if needed) and vector search. + * Simplifies the RAG pipeline by handling the full retrieval process. + */ +export class DocumentNodeRetrievalTask extends Task< + RetrievalTaskInput, + RetrievalTaskOutput, + JobQueueTaskConfig +> { + public static type = "DocumentNodeRetrievalTask"; + public static category = "RAG"; + public static title = "Retrieval"; + public static description = "End-to-end retrieval: embed query and search for similar chunks"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute(input: RetrievalTaskInput, context: IExecuteContext): Promise { + const { + repository, + query, + topK = 5, + filter, + model, + scoreThreshold = 0, + returnVectors = false, + } = input; + + // Repository is resolved by input resolver system before execution + const repo = repository as AnyDocumentNodeVectorRepository; + + // Determine query vector + let queryVector: TypedArray; + if (typeof query === "string") { + // If query is a string, model must be provided (enforced by schema) + if (!model) { + throw new Error( + "Model is required when query is a string. Please provide a model with format 'model:TextEmbeddingTask'." + ); + } + const embeddingTask = context.own(new TextEmbeddingTask({ text: query, model })); + const embeddingResult = await embeddingTask.run(); + queryVector = Array.isArray(embeddingResult.vector) + ? embeddingResult.vector[0] + : embeddingResult.vector; + } else { + // Query is already a vector + queryVector = query as TypedArray; + } + + // Convert to Float32Array for repository search (repo expects Float32Array by default) + const searchVector = + queryVector instanceof Float32Array ? 
queryVector : new Float32Array(queryVector); + + // Search vector repository + const results = await repo.similaritySearch(searchVector, { + topK, + filter, + scoreThreshold, + }); + + // Extract text chunks from metadata + // Assumes metadata has a 'text' or 'content' field + const chunks = results.map((r) => { + const meta = r.metadata as any; + return meta.text || meta.content || meta.chunk || JSON.stringify(meta); + }); + + const output: RetrievalTaskOutput = { + chunks, + ids: results.map((r) => r.chunk_id), + metadata: results.map((r) => r.metadata), + scores: results.map((r) => r.score), + count: results.length, + }; + + if (returnVectors) { + output.vectors = results.map((r) => r.vector); + } + + return output; + } +} + +TaskRegistry.registerTask(DocumentNodeRetrievalTask); + +export const retrieval = (input: RetrievalTaskInput, config?: JobQueueTaskConfig) => { + return new DocumentNodeRetrievalTask({} as RetrievalTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + retrieval: CreateWorkflow; + } +} + +Workflow.prototype.retrieval = CreateWorkflow(DocumentNodeRetrievalTask); diff --git a/packages/ai/src/task/DocumentNodeVectorHybridSearchTask.ts b/packages/ai/src/task/DocumentNodeVectorHybridSearchTask.ts new file mode 100644 index 00000000..d4f4c608 --- /dev/null +++ b/packages/ai/src/task/DocumentNodeVectorHybridSearchTask.ts @@ -0,0 +1,235 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + AnyDocumentNodeVectorRepository, + TypeDocumentNodeVectorRepository, +} from "@workglow/storage"; +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { + DataPortSchema, + FromSchema, + TypedArraySchema, + TypedArraySchemaOptions, +} from "@workglow/util"; + +const inputSchema = { + type: "object", + properties: { + repository: TypeDocumentNodeVectorRepository({ + title: "Document Chunk Vector Repository", + description: + "The document chunk vector repository instance to search in (must support hybridSearch)", + }), + queryVector: TypedArraySchema({ + title: "Query Vector", + description: "The query vector for semantic search", + }), + queryText: { + type: "string", + title: "Query Text", + description: "The query text for full-text search", + }, + topK: { + type: "number", + title: "Top K", + description: "Number of top results to return", + minimum: 1, + default: 10, + }, + filter: { + type: "object", + title: "Metadata Filter", + description: "Filter results by metadata fields", + }, + scoreThreshold: { + type: "number", + title: "Score Threshold", + description: "Minimum combined score threshold (0-1)", + minimum: 0, + maximum: 1, + default: 0, + }, + vectorWeight: { + type: "number", + title: "Vector Weight", + description: "Weight for vector similarity (0-1), remainder goes to text relevance", + minimum: 0, + maximum: 1, + default: 0.7, + }, + returnVectors: { + type: "boolean", + title: "Return Vectors", + description: "Whether to return vector embeddings in results", + default: false, + }, + }, + required: ["repository", "queryVector", "queryText"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + chunks: { + type: "array", + items: { type: "string" }, + title: "Text Chunks", + description: "Retrieved text chunks", + }, + ids: { + type: "array", + items: { type: "string" }, + title: "IDs", + 
description: "IDs of retrieved chunks", + }, + metadata: { + type: "array", + items: { + type: "object", + title: "Metadata", + description: "Metadata of retrieved chunk", + }, + title: "Metadata", + description: "Metadata of retrieved chunks", + }, + scores: { + type: "array", + items: { type: "number" }, + title: "Scores", + description: "Combined relevance scores for each result", + }, + vectors: { + type: "array", + items: TypedArraySchema({ + title: "Vector", + description: "Vector embedding", + }), + title: "Vectors", + description: "Vector embeddings (if returnVectors is true)", + }, + count: { + type: "number", + title: "Count", + description: "Number of results returned", + }, + }, + required: ["chunks", "ids", "metadata", "scores", "count"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type HybridSearchTaskInput = FromSchema; +export type HybridSearchTaskOutput = FromSchema; + +/** + * Task for hybrid search combining vector similarity and full-text search. + * Requires a document chunk vector repository that supports hybridSearch (e.g., Postgres with pgvector). + * + * Hybrid search improves retrieval by combining: + * - Semantic similarity (vector search) - understands meaning + * - Keyword matching (full-text search) - finds exact terms + */ +export class DocumentNodeVectorHybridSearchTask extends Task< + HybridSearchTaskInput, + HybridSearchTaskOutput, + JobQueueTaskConfig +> { + public static type = "DocumentNodeVectorHybridSearchTask"; + public static category = "RAG"; + public static title = "Hybrid Search"; + public static description = "Combined vector + full-text search for improved retrieval"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: HybridSearchTaskInput, + context: IExecuteContext + ): Promise { + const { + repository, + queryVector, + queryText, + topK = 10, + filter, + scoreThreshold = 0, + vectorWeight = 0.7, + returnVectors = false, + } = input; + + // Repository is resolved by input resolver system before execution + const repo = repository as AnyDocumentNodeVectorRepository; + + // Check if repository supports hybrid search + if (!repo.hybridSearch) { + throw new Error("Repository does not support hybrid search."); + } + + // Convert to Float32Array for repository search (repo expects Float32Array by default) + const searchVector = + queryVector instanceof Float32Array ? 
queryVector : new Float32Array(queryVector); + + // Perform hybrid search + const results = await repo.hybridSearch(searchVector, { + textQuery: queryText, + topK, + filter, + scoreThreshold, + vectorWeight, + }); + + // Extract text chunks from metadata + const chunks = results.map((r) => { + const meta = r.metadata as Record; + return meta.text || meta.content || meta.chunk || JSON.stringify(meta); + }); + + const output: HybridSearchTaskOutput = { + chunks, + ids: results.map((r) => r.chunk_id), + metadata: results.map((r) => r.metadata), + scores: results.map((r) => r.score), + count: results.length, + }; + + if (returnVectors) { + output.vectors = results.map((r) => r.vector); + } + + return output; + } +} + +TaskRegistry.registerTask(DocumentNodeVectorHybridSearchTask); + +export const hybridSearch = async ( + input: HybridSearchTaskInput, + config?: JobQueueTaskConfig +): Promise => { + return new DocumentNodeVectorHybridSearchTask({} as HybridSearchTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + hybridSearch: CreateWorkflow; + } +} + +Workflow.prototype.hybridSearch = CreateWorkflow(DocumentNodeVectorHybridSearchTask); diff --git a/packages/ai/src/task/DocumentNodeVectorSearchTask.ts b/packages/ai/src/task/DocumentNodeVectorSearchTask.ts new file mode 100644 index 00000000..63d736f8 --- /dev/null +++ b/packages/ai/src/task/DocumentNodeVectorSearchTask.ts @@ -0,0 +1,175 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + AnyDocumentNodeVectorRepository, + TypeDocumentNodeVectorRepository, +} from "@workglow/storage"; +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { + DataPortSchema, + FromSchema, + TypedArraySchema, + TypedArraySchemaOptions, +} from "@workglow/util"; + +const inputSchema = { + type: "object", + properties: { + repository: TypeDocumentNodeVectorRepository({ + title: "Vector Repository", + description: "The vector repository instance to search in", + }), + query: TypedArraySchema({ + title: "Query Vector", + description: "The query vector to search for similar vectors", + }), + topK: { + type: "number", + title: "Top K", + description: "Number of top results to return", + minimum: 1, + default: 10, + }, + filter: { + type: "object", + title: "Metadata Filter", + description: "Filter results by metadata fields", + }, + scoreThreshold: { + type: "number", + title: "Score Threshold", + description: "Minimum similarity score threshold (0-1)", + minimum: 0, + maximum: 1, + default: 0, + }, + }, + required: ["repository", "query"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + ids: { + type: "array", + items: { type: "string" }, + title: "IDs", + description: "IDs of matching vectors", + }, + vectors: { + type: "array", + items: TypedArraySchema({ + title: "Vector", + description: "Matching vector embedding", + }), + title: "Vectors", + description: "Matching vector embeddings", + }, + metadata: { + type: "array", + items: { + type: "object", + title: "Metadata", + description: "Metadata of matching vector", + }, + title: "Metadata", + description: "Metadata of matching vectors", + }, + scores: { + type: "array", + items: { type: "number" }, + title: "Scores", + description: "Similarity scores for each result", + }, + count: { + type: "number", + title: "Count", + description: 
"Number of results returned", + }, + }, + required: ["ids", "vectors", "metadata", "scores", "count"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type VectorStoreSearchTaskInput = FromSchema; +export type VectorStoreSearchTaskOutput = FromSchema; + +/** + * Task for searching similar vectors in a vector repository. + * Returns top-K most similar vectors with their metadata and scores. + */ +export class DocumentNodeVectorSearchTask extends Task< + VectorStoreSearchTaskInput, + VectorStoreSearchTaskOutput, + JobQueueTaskConfig +> { + public static type = "DocumentNodeVectorSearchTask"; + public static category = "Vector Store"; + public static title = "Vector Store Search"; + public static description = "Search for similar vectors in a vector repository"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: VectorStoreSearchTaskInput, + context: IExecuteContext + ): Promise { + const { repository, query, topK = 10, filter, scoreThreshold = 0 } = input; + + const repo = repository as AnyDocumentNodeVectorRepository; + + const results = await repo.similaritySearch(query, { + topK, + filter, + scoreThreshold, + }); + + return { + ids: results.map((r) => r.chunk_id), + vectors: results.map((r) => r.vector), + metadata: results.map((r) => r.metadata), + scores: results.map((r) => r.score), + count: results.length, + }; + } +} + +TaskRegistry.registerTask(DocumentNodeVectorSearchTask); + +export const vectorStoreSearch = ( + input: VectorStoreSearchTaskInput, + config?: JobQueueTaskConfig +) => { + return new DocumentNodeVectorSearchTask({} as VectorStoreSearchTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + vectorStoreSearch: CreateWorkflow< + VectorStoreSearchTaskInput, + VectorStoreSearchTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.vectorStoreSearch = CreateWorkflow(DocumentNodeVectorSearchTask); diff --git a/packages/ai/src/task/DocumentNodeVectorUpsertTask.ts b/packages/ai/src/task/DocumentNodeVectorUpsertTask.ts new file mode 100644 index 00000000..e5800c30 --- /dev/null +++ b/packages/ai/src/task/DocumentNodeVectorUpsertTask.ts @@ -0,0 +1,175 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + AnyDocumentNodeVectorRepository, + TypeDocumentNodeVectorRepository, +} from "@workglow/storage"; +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { + DataPortSchema, + FromSchema, + TypedArraySchema, + TypedArraySchemaOptions, +} from "@workglow/util"; +import { TypeSingleOrArray } from "./base/AiTaskSchemas"; + +const inputSchema = { + type: "object", + properties: { + doc_id: { + type: "string", + title: "Document ID", + description: "The document ID", + }, + repository: TypeDocumentNodeVectorRepository({ + title: "Document Chunk Vector Repository", + description: "The document chunk vector repository instance to store vectors in", + }), + vectors: TypeSingleOrArray( + TypedArraySchema({ + title: "Vector", + description: "The vector embedding", + }) + ), + metadata: { + type: "object", + title: "Metadata", + description: "Metadata associated with the vector", + }, + }, + required: ["repository", "doc_id", "vectors", "metadata"], + 
additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + count: { + type: "number", + title: "Count", + description: "Number of vectors upserted", + }, + doc_id: { + type: "string", + title: "Document ID", + description: "The document ID", + }, + ids: { + type: "array", + items: { type: "string" }, + title: "IDs", + description: "IDs of upserted vectors", + }, + }, + required: ["count", "ids"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type VectorStoreUpsertTaskInput = FromSchema< + typeof inputSchema, + TypedArraySchemaOptions // & TypeVectorRepositoryOptions +>; +export type VectorStoreUpsertTaskOutput = FromSchema; + +/** + * Task for upserting (insert or update) vectors into a vector repository. + * Supports both single and bulk operations. + */ +export class DocumentNodeVectorUpsertTask extends Task< + VectorStoreUpsertTaskInput, + VectorStoreUpsertTaskOutput, + JobQueueTaskConfig +> { + public static type = "DocumentNodeVectorUpsertTask"; + public static category = "Vector Store"; + public static title = "Vector Store Upsert"; + public static description = "Store vector embeddings with metadata in a vector repository"; + public static cacheable = false; // Has side effects + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: VectorStoreUpsertTaskInput, + context: IExecuteContext + ): Promise { + const { repository, doc_id, vectors, metadata } = input; + + // Normalize inputs to arrays + const vectorArray = Array.isArray(vectors) ? vectors : [vectors]; + + const repo = repository as AnyDocumentNodeVectorRepository; + + await context.updateProgress(1, "Upserting vectors"); + + const idArray: string[] = []; + + // Bulk upsert if multiple items + if (vectorArray.length > 1) { + const entities = vectorArray.map((vector, i) => { + const chunk_id = `${doc_id}_${i}`; + idArray.push(chunk_id); + return { + chunk_id, + doc_id, + vector: vector as any, // Store TypedArray directly (memory) or as string (SQL) + metadata, + }; + }); + await repo.putBulk(entities as any); + } else if (vectorArray.length === 1) { + // Single upsert + const chunk_id = `${doc_id}_0`; + idArray.push(chunk_id); + await repo.put({ + chunk_id, + doc_id, + vector: vectorArray[0] as any, // Store TypedArray directly (memory) or as string (SQL) + metadata, + } as any); + } + + return { + doc_id, + ids: idArray, + count: vectorArray.length, + }; + } +} + +TaskRegistry.registerTask(DocumentNodeVectorUpsertTask); + +export const vectorStoreUpsert = ( + input: VectorStoreUpsertTaskInput, + config?: JobQueueTaskConfig +) => { + return new DocumentNodeVectorUpsertTask({} as VectorStoreUpsertTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + vectorStoreUpsert: CreateWorkflow< + VectorStoreUpsertTaskInput, + VectorStoreUpsertTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.vectorStoreUpsert = CreateWorkflow(DocumentNodeVectorUpsertTask); diff --git a/packages/ai/src/task/HierarchicalChunkerTask.ts b/packages/ai/src/task/HierarchicalChunkerTask.ts new file mode 100644 index 00000000..e35a4f67 --- /dev/null +++ b/packages/ai/src/task/HierarchicalChunkerTask.ts @@ -0,0 +1,305 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + 
ChunkNodeSchema, + estimateTokens, + getChildren, + hasChildren, + NodeIdGenerator, + type ChunkNode, + type DocumentNode, + type TokenBudget, +} from "@workglow/storage"; +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; + +const inputSchema = { + type: "object", + properties: { + doc_id: { + type: "string", + title: "Document ID", + description: "The ID of the document", + }, + documentTree: { + title: "Document Tree", + description: "The hierarchical document tree to chunk", + }, + maxTokens: { + type: "number", + title: "Max Tokens", + description: "Maximum tokens per chunk", + minimum: 50, + default: 512, + }, + overlap: { + type: "number", + title: "Overlap", + description: "Overlap in tokens between chunks", + minimum: 0, + default: 50, + }, + reservedTokens: { + type: "number", + title: "Reserved Tokens", + description: "Reserved tokens for metadata/wrappers", + minimum: 0, + default: 10, + }, + strategy: { + type: "string", + enum: ["hierarchical", "flat", "sentence"], + title: "Chunking Strategy", + description: "Strategy for chunking", + default: "hierarchical", + }, + }, + required: [], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + doc_id: { + type: "string", + title: "Document ID", + description: "The document ID (passed through)", + }, + chunks: { + type: "array", + items: ChunkNodeSchema(), + title: "Chunks", + description: "Array of chunk nodes", + }, + text: { + type: "array", + items: { type: "string" }, + title: "Texts", + description: "Chunk texts (for TextEmbeddingTask)", + }, + count: { + type: "number", + title: "Count", + description: "Number of chunks generated", + }, + }, + required: ["doc_id", "chunks", "text", "count"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type HierarchicalChunkerTaskInput = FromSchema; +export type HierarchicalChunkerTaskOutput = FromSchema; + +/** + * Task for hierarchical chunking that respects token budgets and document structure + */ +export class HierarchicalChunkerTask extends Task< + HierarchicalChunkerTaskInput, + HierarchicalChunkerTaskOutput, + JobQueueTaskConfig +> { + public static type = "HierarchicalChunkerTask"; + public static category = "Document"; + public static title = "Hierarchical Chunker"; + public static description = "Chunk documents hierarchically respecting token budgets"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: HierarchicalChunkerTaskInput, + context: IExecuteContext + ): Promise { + const { + doc_id, + documentTree, + maxTokens = 512, + overlap = 50, + reservedTokens = 10, + strategy = "hierarchical", + } = input; + + if (!doc_id) { + throw new Error("doc_id is required"); + } + if (!documentTree) { + throw new Error("documentTree is required"); + } + + const root = documentTree as DocumentNode; + const tokenBudget: TokenBudget = { + maxTokensPerChunk: maxTokens, + overlapTokens: overlap, + reservedTokens, + }; + + const chunks: ChunkNode[] = []; + + if (strategy === "hierarchical") { + await this.chunkHierarchically(root, [], doc_id, tokenBudget, chunks); + } else { + // Flat chunking: treat entire document as flat text + await 
this.chunkFlat(root, doc_id, tokenBudget, chunks); + } + + return { + doc_id, + chunks, + text: chunks.map((c) => c.text), + count: chunks.length, + }; + } + + /** + * Hierarchical chunking that respects document structure + */ + private async chunkHierarchically( + node: DocumentNode, + nodePath: string[], + doc_id: string, + tokenBudget: TokenBudget, + chunks: ChunkNode[] + ): Promise { + const currentPath = [...nodePath, node.nodeId]; + + // If node has no children, it's a leaf - chunk its text + if (!hasChildren(node)) { + await this.chunkText(node.text, currentPath, doc_id, tokenBudget, chunks, node.nodeId); + return; + } + + // For nodes with children, recursively chunk children + const children = getChildren(node); + for (const child of children) { + await this.chunkHierarchically(child, currentPath, doc_id, tokenBudget, chunks); + } + } + + /** + * Chunk a single text string + */ + private async chunkText( + text: string, + nodePath: string[], + doc_id: string, + tokenBudget: TokenBudget, + chunks: ChunkNode[], + leafNodeId: string + ): Promise { + const maxChars = (tokenBudget.maxTokensPerChunk - tokenBudget.reservedTokens) * 4; + const overlapChars = tokenBudget.overlapTokens * 4; + + if (estimateTokens(text) <= tokenBudget.maxTokensPerChunk - tokenBudget.reservedTokens) { + // Text fits in one chunk + const chunkId = await NodeIdGenerator.generateChunkId(doc_id, leafNodeId, 0); + chunks.push({ + chunkId, + doc_id, + text, + nodePath, + depth: nodePath.length, + }); + return; + } + + // Split into multiple chunks with overlap + let chunkOrdinal = 0; + let startOffset = 0; + + while (startOffset < text.length) { + const endOffset = Math.min(startOffset + maxChars, text.length); + const chunkText = text.substring(startOffset, endOffset); + + const chunkId = await NodeIdGenerator.generateChunkId(doc_id, leafNodeId, chunkOrdinal); + + chunks.push({ + chunkId, + doc_id, + text: chunkText, + nodePath, + depth: nodePath.length, + }); + + chunkOrdinal++; + startOffset += maxChars - overlapChars; + + // Prevent infinite loop + if (overlapChars >= maxChars) { + startOffset = endOffset; + } + } + } + + /** + * Flat chunking (ignores hierarchy) + */ + private async chunkFlat( + root: DocumentNode, + doc_id: string, + tokenBudget: TokenBudget, + chunks: ChunkNode[] + ): Promise { + // Collect all text from the tree + const allText = this.collectAllText(root); + await this.chunkText(allText, [root.nodeId], doc_id, tokenBudget, chunks, root.nodeId); + } + + /** + * Collect all text from a node and its descendants + */ + private collectAllText(node: DocumentNode): string { + const texts: string[] = []; + + const traverse = (n: DocumentNode) => { + if (!hasChildren(n)) { + texts.push(n.text); + } else { + for (const child of getChildren(n)) { + traverse(child); + } + } + }; + + traverse(node); + return texts.join("\n\n"); + } +} + +TaskRegistry.registerTask(HierarchicalChunkerTask); + +export const hierarchicalChunker = ( + input: HierarchicalChunkerTaskInput, + config?: JobQueueTaskConfig +) => { + return new HierarchicalChunkerTask({} as HierarchicalChunkerTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + hierarchicalChunker: CreateWorkflow< + HierarchicalChunkerTaskInput, + HierarchicalChunkerTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.hierarchicalChunker = CreateWorkflow(HierarchicalChunkerTask); diff --git a/packages/ai/src/task/HierarchyJoinTask.ts b/packages/ai/src/task/HierarchyJoinTask.ts new file mode 100644 
index 00000000..3f55eca7 --- /dev/null +++ b/packages/ai/src/task/HierarchyJoinTask.ts @@ -0,0 +1,247 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { DocumentRepository } from "@workglow/storage"; +import { + type ChunkMetadata, + ChunkMetadataArraySchema, + EnrichedChunkMetadataArraySchema, +} from "@workglow/storage"; +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; + +const inputSchema = { + type: "object", + properties: { + documentRepository: { + title: "Document Repository", + description: "The document repository to query for hierarchy", + }, + chunks: { + type: "array", + items: { type: "string" }, + title: "Chunks", + description: "Retrieved text chunks", + }, + ids: { + type: "array", + items: { type: "string" }, + title: "Chunk IDs", + description: "IDs of retrieved chunks", + }, + metadata: ChunkMetadataArraySchema, + scores: { + type: "array", + items: { type: "number" }, + title: "Scores", + description: "Similarity scores for each result", + }, + includeParentSummaries: { + type: "boolean", + title: "Include Parent Summaries", + description: "Whether to include summaries from parent nodes", + default: true, + }, + includeEntities: { + type: "boolean", + title: "Include Entities", + description: "Whether to include entities from the node hierarchy", + default: true, + }, + }, + required: ["documentRepository", "chunks", "ids", "metadata", "scores"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + chunks: { + type: "array", + items: { type: "string" }, + title: "Chunks", + description: "Retrieved text chunks", + }, + ids: { + type: "array", + items: { type: "string" }, + title: "Chunk IDs", + description: "IDs of retrieved chunks", + }, + metadata: EnrichedChunkMetadataArraySchema, + scores: { + type: "array", + items: { type: "number" }, + title: "Scores", + description: "Similarity scores", + }, + count: { + type: "number", + title: "Count", + description: "Number of results", + }, + }, + required: ["chunks", "ids", "metadata", "scores", "count"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type HierarchyJoinTaskInput = FromSchema; +export type HierarchyJoinTaskOutput = FromSchema; + +/** + * Task for enriching search results with hierarchy information + * Joins chunk IDs back to document repository to get parent summaries and entities + */ +export class HierarchyJoinTask extends Task< + HierarchyJoinTaskInput, + HierarchyJoinTaskOutput, + JobQueueTaskConfig +> { + public static type = "HierarchyJoinTask"; + public static category = "RAG"; + public static title = "Hierarchy Join"; + public static description = "Enrich search results with document hierarchy context"; + public static cacheable = false; // Has external dependency + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: HierarchyJoinTaskInput, + context: IExecuteContext + ): Promise { + const { + documentRepository, + chunks, + ids, + metadata, + scores, + includeParentSummaries = true, + includeEntities = true, + } = input; + + const repo = documentRepository as DocumentRepository; + const enrichedMetadata: any[] = []; + + 
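+    // For each retrieved chunk, resolve its (doc_id, leafNodeId) from the metadata,
+    // fetch that node's ancestors from the document repository, and roll parent
+    // summaries, section titles, and deduplicated entities up into the chunk metadata.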
for (let i = 0; i < ids.length; i++) { + const chunkId = ids[i]; + const originalMetadata: ChunkMetadata | undefined = metadata[i]; + + if (!originalMetadata) { + // Skip if metadata is missing + enrichedMetadata.push({} as ChunkMetadata); + continue; + } + + // Extract doc_id and nodeId from metadata + const doc_id = originalMetadata.doc_id; + const leafNodeId = originalMetadata.leafNodeId; + + if (!doc_id || !leafNodeId) { + // Can't enrich without IDs + enrichedMetadata.push(originalMetadata); + continue; + } + + try { + // Get ancestors from document repository + const ancestors = await repo.getAncestors(doc_id, leafNodeId); + + const enriched: any = { ...originalMetadata }; + + // Add parent summaries + if (includeParentSummaries && ancestors.length > 0) { + const parentSummaries: string[] = []; + const sectionTitles: string[] = []; + + for (const ancestor of ancestors) { + if (ancestor.enrichment?.summary) { + parentSummaries.push(ancestor.enrichment.summary); + } + if (ancestor.kind === "section" && (ancestor as any).title) { + sectionTitles.push((ancestor as any).title); + } + } + + if (parentSummaries.length > 0) { + enriched.parentSummaries = parentSummaries; + } + if (sectionTitles.length > 0) { + enriched.sectionTitles = sectionTitles; + } + } + + // Add entities (rolled up from ancestors) + if (includeEntities && ancestors.length > 0) { + const allEntities: any[] = []; + + for (const ancestor of ancestors) { + if (ancestor.enrichment?.entities) { + allEntities.push(...ancestor.enrichment.entities); + } + } + + // Deduplicate entities + const uniqueEntities = new Map(); + for (const entity of allEntities) { + const existing = uniqueEntities.get(entity.text); + if (!existing || entity.score > existing.score) { + uniqueEntities.set(entity.text, entity); + } + } + + if (uniqueEntities.size > 0) { + enriched.entities = Array.from(uniqueEntities.values()); + } + } + + enrichedMetadata.push(enriched); + } catch (error) { + // If join fails, keep original metadata + console.error(`Failed to join hierarchy for chunk ${chunkId}:`, error); + enrichedMetadata.push(originalMetadata); + } + } + + return { + chunks, + ids, + metadata: enrichedMetadata, + scores, + count: chunks.length, + }; + } +} + +TaskRegistry.registerTask(HierarchyJoinTask); + +export const hierarchyJoin = (input: HierarchyJoinTaskInput, config?: JobQueueTaskConfig) => { + return new HierarchyJoinTask({} as HierarchyJoinTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + hierarchyJoin: CreateWorkflow< + HierarchyJoinTaskInput, + HierarchyJoinTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.hierarchyJoin = CreateWorkflow(HierarchyJoinTask); diff --git a/packages/ai/src/task/QueryExpanderTask.ts b/packages/ai/src/task/QueryExpanderTask.ts new file mode 100644 index 00000000..b3804b19 --- /dev/null +++ b/packages/ai/src/task/QueryExpanderTask.ts @@ -0,0 +1,318 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; + +export const QueryExpansionMethod = { + MULTI_QUERY: "multi-query", + HYDE: "hyde", + SYNONYMS: "synonyms", + PARAPHRASE: "paraphrase", +} as const; + +export type QueryExpansionMethod = (typeof QueryExpansionMethod)[keyof typeof QueryExpansionMethod]; + +const inputSchema = { + type: "object", + 
properties: { + query: { + type: "string", + title: "Query", + description: "The original query to expand", + }, + method: { + type: "string", + enum: Object.values(QueryExpansionMethod), + title: "Expansion Method", + description: "Method to use for query expansion", + default: QueryExpansionMethod.MULTI_QUERY, + }, + numVariations: { + type: "number", + title: "Number of Variations", + description: "Number of query variations to generate", + minimum: 1, + maximum: 10, + default: 3, + }, + model: { + type: "string", + title: "Model", + description: "LLM model to use for expansion (for HyDE and paraphrase methods)", + }, + }, + required: ["query"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + queries: { + type: "array", + items: { type: "string" }, + title: "Expanded Queries", + description: "Generated query variations", + }, + originalQuery: { + type: "string", + title: "Original Query", + description: "The original input query", + }, + method: { + type: "string", + title: "Method Used", + description: "The expansion method that was used", + }, + count: { + type: "number", + title: "Count", + description: "Number of queries generated", + }, + }, + required: ["queries", "originalQuery", "method", "count"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type QueryExpanderTaskInput = FromSchema; +export type QueryExpanderTaskOutput = FromSchema; + +/** + * Task for expanding queries to improve retrieval coverage. + * Supports multiple expansion methods including multi-query, HyDE, and paraphrasing. + * + * Note: HyDE and paraphrase methods require an LLM model. + * For now, this implements simple rule-based expansion. + */ +export class QueryExpanderTask extends Task< + QueryExpanderTaskInput, + QueryExpanderTaskOutput, + JobQueueTaskConfig +> { + public static type = "QueryExpanderTask"; + public static category = "RAG"; + public static title = "Query Expander"; + public static description = "Expand queries to improve retrieval coverage"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: QueryExpanderTaskInput, + context: IExecuteContext + ): Promise { + const { query, method = QueryExpansionMethod.MULTI_QUERY, numVariations = 3 } = input; + + let queries: string[]; + + switch (method) { + case QueryExpansionMethod.HYDE: + queries = this.hydeExpansion(query, numVariations); + break; + case QueryExpansionMethod.SYNONYMS: + queries = this.synonymExpansion(query, numVariations); + break; + case QueryExpansionMethod.PARAPHRASE: + queries = this.paraphraseExpansion(query, numVariations); + break; + case QueryExpansionMethod.MULTI_QUERY: + default: + queries = this.multiQueryExpansion(query, numVariations); + break; + } + + // Always include original query + if (!queries.includes(query)) { + queries.unshift(query); + } + + return { + queries, + originalQuery: query, + method, + count: queries.length, + }; + } + + /** + * Multi-query expansion: Generate variations by rephrasing the question + */ + private multiQueryExpansion(query: string, numVariations: number): string[] { + const queries: string[] = [query]; + + // Simple rule-based variations + const variations: string[] = []; + + // Question word variations + if (query.toLowerCase().startsWith("what")) { + 
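// Swap the leading question word for close paraphrases that keep the query's intent.
+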
variations.push(query.replace(/^what/i, "Which")); + variations.push(query.replace(/^what/i, "Can you explain")); + } else if (query.toLowerCase().startsWith("how")) { + variations.push(query.replace(/^how/i, "What is the method to")); + variations.push(query.replace(/^how/i, "In what way")); + } else if (query.toLowerCase().startsWith("why")) { + variations.push(query.replace(/^why/i, "What is the reason")); + variations.push(query.replace(/^why/i, "For what purpose")); + } else if (query.toLowerCase().startsWith("where")) { + variations.push(query.replace(/^where/i, "In which location")); + variations.push(query.replace(/^where/i, "At what place")); + } + + // Add "Tell me about" variation + if (!query.toLowerCase().startsWith("tell me")) { + variations.push(`Tell me about ${query.toLowerCase()}`); + } + + // Add "Explain" variation + if (!query.toLowerCase().startsWith("explain")) { + variations.push(`Explain ${query.toLowerCase()}`); + } + + // Take up to numVariations + for (let i = 0; i < Math.min(numVariations - 1, variations.length); i++) { + if (variations[i] && !queries.includes(variations[i])) { + queries.push(variations[i]); + } + } + + return queries; + } + + /** + * HyDE (Hypothetical Document Embeddings): Generate hypothetical answers + */ + private hydeExpansion(query: string, numVariations: number): string[] { + // TODO: in a real implementation, this would call a model to generate hypothetical answer templates + const queries: string[] = [query]; + + const templates = [ + `The answer to "${query}" is that`, + `Regarding ${query}, it is important to note that`, + `${query} can be explained by the fact that`, + `In response to ${query}, one should consider that`, + ]; + + for (let i = 0; i < Math.min(numVariations - 1, templates.length); i++) { + queries.push(templates[i]); + } + + return queries; + } + + /** + * Synonym expansion: Replace keywords with synonyms + */ + private synonymExpansion(query: string, numVariations: number): string[] { + const queries: string[] = [query]; + + // Simple synonym dictionary (in production, use a proper thesaurus) + const synonyms: Record = { + find: ["locate", "discover", "search for"], + create: ["make", "build", "generate"], + delete: ["remove", "erase", "eliminate"], + update: ["modify", "change", "edit"], + show: ["display", "present", "reveal"], + explain: ["describe", "clarify", "elaborate"], + help: ["assist", "aid", "support"], + problem: ["issue", "challenge", "difficulty"], + solution: ["answer", "resolution", "fix"], + method: ["approach", "technique", "way"], + }; + + const words = query.toLowerCase().split(/\s+/); + let variationsGenerated = 0; + + for (const [word, syns] of Object.entries(synonyms)) { + if (variationsGenerated >= numVariations - 1) break; + + const wordIndex = words.indexOf(word); + if (wordIndex !== -1) { + for (const syn of syns) { + if (variationsGenerated >= numVariations - 1) break; + + const newWords = [...words]; + newWords[wordIndex] = syn; + const newQuery = newWords.join(" "); + + // Preserve original capitalization pattern + const capitalizedQuery = this.preserveCapitalization(query, newQuery); + if (!queries.includes(capitalizedQuery)) { + queries.push(capitalizedQuery); + variationsGenerated++; + } + } + } + } + + return queries; + } + + /** + * Paraphrase expansion: Rephrase the query + * TODO: This should use an LLM for better paraphrasing + */ + private paraphraseExpansion(query: string, numVariations: number): string[] { + const queries: string[] = [query]; + + // Simple paraphrase 
templates + const paraphrases: string[] = []; + + // Add context + paraphrases.push(`I need information about ${query.toLowerCase()}`); + paraphrases.push(`Can you help me understand ${query.toLowerCase()}`); + paraphrases.push(`I'm looking for details on ${query.toLowerCase()}`); + + for (let i = 0; i < Math.min(numVariations - 1, paraphrases.length); i++) { + if (!queries.includes(paraphrases[i])) { + queries.push(paraphrases[i]); + } + } + + return queries; + } + + /** + * Preserve capitalization pattern from original to new query + */ + private preserveCapitalization(original: string, modified: string): string { + if (original[0] === original[0].toUpperCase()) { + return modified.charAt(0).toUpperCase() + modified.slice(1); + } + return modified; + } +} + +TaskRegistry.registerTask(QueryExpanderTask); + +export const queryExpander = (input: QueryExpanderTaskInput, config?: JobQueueTaskConfig) => { + return new QueryExpanderTask({} as QueryExpanderTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + queryExpander: CreateWorkflow< + QueryExpanderTaskInput, + QueryExpanderTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.queryExpander = CreateWorkflow(QueryExpanderTask); diff --git a/packages/ai/src/task/RerankerTask.ts b/packages/ai/src/task/RerankerTask.ts new file mode 100644 index 00000000..bcd5d8c0 --- /dev/null +++ b/packages/ai/src/task/RerankerTask.ts @@ -0,0 +1,341 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; +import { TextClassificationTask } from "./TextClassificationTask"; + +const inputSchema = { + type: "object", + properties: { + query: { + type: "string", + title: "Query", + description: "The query to rerank results against", + }, + chunks: { + type: "array", + items: { type: "string" }, + title: "Text Chunks", + description: "Retrieved text chunks to rerank", + }, + scores: { + type: "array", + items: { type: "number" }, + title: "Initial Scores", + description: "Initial retrieval scores (optional)", + }, + metadata: { + type: "array", + items: { + type: "object", + title: "Metadata", + description: "Metadata for each chunk", + }, + title: "Metadata", + description: "Metadata for each chunk (optional)", + }, + topK: { + type: "number", + title: "Top K", + description: "Number of top results to return after reranking", + minimum: 1, + }, + method: { + type: "string", + enum: ["cross-encoder", "reciprocal-rank-fusion", "simple"], + title: "Reranking Method", + description: "Method to use for reranking", + default: "simple", + }, + model: { + type: "string", + title: "Reranker Model", + description: "Cross-encoder model to use for reranking", + }, + }, + required: ["query", "chunks"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + chunks: { + type: "array", + items: { type: "string" }, + title: "Reranked Chunks", + description: "Chunks reordered by relevance", + }, + scores: { + type: "array", + items: { type: "number" }, + title: "Reranked Scores", + description: "New relevance scores", + }, + metadata: { + type: "array", + items: { + type: "object", + title: "Metadata", + description: "Metadata for each chunk", + }, + title: "Metadata", + description: "Metadata for reranked 
chunks", + }, + originalIndices: { + type: "array", + items: { type: "number" }, + title: "Original Indices", + description: "Original indices of reranked chunks", + }, + count: { + type: "number", + title: "Count", + description: "Number of results returned", + }, + }, + required: ["chunks", "scores", "originalIndices", "count"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type RerankerTaskInput = FromSchema; +export type RerankerTaskOutput = FromSchema; + +interface RankedItem { + chunk: string; + score: number; + metadata?: any; + originalIndex: number; +} + +/** + * Task for reranking retrieved chunks to improve relevance. + * Supports multiple reranking methods including cross-encoder models. + * + * Note: Cross-encoder reranking requires a model to be loaded. + * For now, this implements simple heuristic-based reranking. + */ +export class RerankerTask extends Task { + public static type = "RerankerTask"; + public static category = "RAG"; + public static title = "Reranker"; + public static description = "Rerank retrieved chunks to improve relevance"; + public static cacheable = true; + private resolvedCrossEncoderModel?: string | null; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute(input: RerankerTaskInput, context: IExecuteContext): Promise { + const { query, chunks, scores = [], metadata = [], topK, method = "simple", model } = input; + + let rankedItems: RankedItem[]; + + switch (method) { + case "cross-encoder": + rankedItems = await this.crossEncoderRerank( + query, + chunks, + scores, + metadata, + model, + context + ); + break; + case "reciprocal-rank-fusion": + rankedItems = this.reciprocalRankFusion(chunks, scores, metadata); + break; + case "simple": + default: + rankedItems = this.simpleRerank(query, chunks, scores, metadata); + break; + } + + // Apply topK if specified + if (topK && topK < rankedItems.length) { + rankedItems = rankedItems.slice(0, topK); + } + + return { + chunks: rankedItems.map((item) => item.chunk), + scores: rankedItems.map((item) => item.score), + metadata: rankedItems.map((item) => item.metadata), + originalIndices: rankedItems.map((item) => item.originalIndex), + count: rankedItems.length, + }; + } + + private async crossEncoderRerank( + query: string, + chunks: string[], + scores: number[], + metadata: any[], + model: string | undefined, + context: IExecuteContext + ): Promise { + if (chunks.length === 0) { + return []; + } + + if (!model) { + throw new Error( + "No cross-encoder model found. Please provide a model or register a TextClassificationTask model." + ); + } + + const items = await Promise.all( + chunks.map(async (chunk, index) => { + const pairText = `${query} [SEP] ${chunk}`; + const task = context.own( + new TextClassificationTask({ text: pairText, model: model, maxCategories: 2 }) + ); + const result = await task.run(); + const crossScore = this.extractCrossEncoderScore(result.categories); + return { + chunk, + score: Number.isFinite(crossScore) ? 
crossScore : scores[index] || 0, + metadata: metadata[index], + originalIndex: index, + }; + }) + ); + + items.sort((a, b) => b.score - a.score); + return items; + } + + private extractCrossEncoderScore( + categories: Array<{ label: string; score: number }> | undefined + ): number { + if (!categories || categories.length === 0) { + return 0; + } + const preferred = categories.find((category) => + /^(label_1|positive|relevant|yes|true)$/i.test(category.label) + ); + if (preferred) { + return preferred.score; + } + let best = categories[0].score; + for (let i = 1; i < categories.length; i++) { + if (categories[i].score > best) { + best = categories[i].score; + } + } + return best; + } + + /** + * Simple heuristic-based reranking using keyword matching and position + */ + private simpleRerank( + query: string, + chunks: string[], + scores: number[], + metadata: any[] + ): RankedItem[] { + const queryLower = query.toLowerCase(); + const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); + + const items: RankedItem[] = chunks.map((chunk, index) => { + const chunkLower = chunk.toLowerCase(); + const initialScore = scores[index] || 0; + + // Calculate keyword match score + let keywordScore = 0; + let exactMatchBonus = 0; + + for (const word of queryWords) { + // Count occurrences + const regex = new RegExp(word, "gi"); + const matches = chunkLower.match(regex); + if (matches) { + keywordScore += matches.length; + } + } + + // Bonus for exact query match + if (chunkLower.includes(queryLower)) { + exactMatchBonus = 0.5; + } + + // Normalize keyword score + const normalizedKeywordScore = Math.min(keywordScore / (queryWords.length * 3), 1); + + // Position penalty (prefer earlier results, but not too heavily) + const positionPenalty = Math.log(index + 1) / 10; + + // Combined score + const combinedScore = + initialScore * 0.4 + normalizedKeywordScore * 0.4 + exactMatchBonus * 0.2 - positionPenalty; + + return { + chunk, + score: combinedScore, + metadata: metadata[index], + originalIndex: index, + }; + }); + + // Sort by score descending + items.sort((a, b) => b.score - a.score); + + return items; + } + + /** + * Reciprocal Rank Fusion for combining multiple rankings + * Useful when you have multiple retrieval methods + */ + private reciprocalRankFusion(chunks: string[], scores: number[], metadata: any[]): RankedItem[] { + const k = 60; // RRF constant + + const items: RankedItem[] = chunks.map((chunk, index) => { + // RRF score = 1 / (k + rank) + // Here we use the initial ranking (index) as the rank + const rrfScore = 1 / (k + index + 1); + + return { + chunk, + score: rrfScore, + metadata: metadata[index], + originalIndex: index, + }; + }); + + // Sort by RRF score descending + items.sort((a, b) => b.score - a.score); + + return items; + } +} + +TaskRegistry.registerTask(RerankerTask); + +export const reranker = (input: RerankerTaskInput, config?: JobQueueTaskConfig) => { + return new RerankerTask({} as RerankerTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + reranker: CreateWorkflow; + } +} + +Workflow.prototype.reranker = CreateWorkflow(RerankerTask); diff --git a/packages/ai/src/task/StructuralParserTask.ts b/packages/ai/src/task/StructuralParserTask.ts new file mode 100644 index 00000000..bcfc9617 --- /dev/null +++ b/packages/ai/src/task/StructuralParserTask.ts @@ -0,0 +1,159 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { DocumentNode, NodeIdGenerator, 
StructuralParser } from "@workglow/storage"; +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; + +const inputSchema = { + type: "object", + properties: { + text: { + type: "string", + title: "Text", + description: "The text content to parse", + }, + title: { + type: "string", + title: "Title", + description: "Document title", + }, + format: { + type: "string", + enum: ["markdown", "text", "auto"], + title: "Format", + description: "Document format (auto-detects if not specified)", + default: "auto", + }, + sourceUri: { + type: "string", + title: "Source URI", + description: "Source identifier for document ID generation", + }, + doc_id: { + type: "string", + title: "Document ID", + description: "Pre-generated document ID (optional)", + }, + }, + required: ["text", "title"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + doc_id: { + type: "string", + title: "Document ID", + description: "Generated or provided document ID", + }, + documentTree: { + title: "Document Tree", + description: "Parsed hierarchical document tree", + }, + nodeCount: { + type: "number", + title: "Node Count", + description: "Total number of nodes in the tree", + }, + }, + required: ["doc_id", "documentTree", "nodeCount"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type StructuralParserTaskInput = FromSchema; +export type StructuralParserTaskOutput = FromSchema; + +/** + * Task for parsing documents into hierarchical tree structure + * Supports markdown and plain text with automatic format detection + */ +export class StructuralParserTask extends Task< + StructuralParserTaskInput, + StructuralParserTaskOutput, + JobQueueTaskConfig +> { + public static type = "StructuralParserTask"; + public static category = "Document"; + public static title = "Structural Parser"; + public static description = "Parse documents into hierarchical tree structure"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: StructuralParserTaskInput, + context: IExecuteContext + ): Promise { + const { text, title, format = "auto", sourceUri, doc_id: providedDocId } = input; + + // Generate or use provided doc_id + const doc_id = + providedDocId || (await NodeIdGenerator.generateDocId(sourceUri || "document", text)); + + // Parse based on format + let documentTree: DocumentNode; + if (format === "markdown") { + documentTree = await StructuralParser.parseMarkdown(doc_id, text, title); + } else if (format === "text") { + documentTree = await StructuralParser.parsePlainText(doc_id, text, title); + } else { + // Auto-detect + documentTree = await StructuralParser.parse(doc_id, text, title); + } + + // Count nodes + const nodeCount = this.countNodes(documentTree); + + return { + doc_id, + documentTree, + nodeCount, + }; + } + + private countNodes(node: any): number { + let count = 1; + if (node.children && Array.isArray(node.children)) { + for (const child of node.children) { + count += this.countNodes(child); + } + } + return count; + } +} + +TaskRegistry.registerTask(StructuralParserTask); + +export const structuralParser = (input: StructuralParserTaskInput, config?: JobQueueTaskConfig) => { 
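+  // Convenience helper mirroring the other task exports: the real input is
+  // passed to run(); the constructor only receives an empty placeholder.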
+ return new StructuralParserTask({} as StructuralParserTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + structuralParser: CreateWorkflow< + StructuralParserTaskInput, + StructuralParserTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.structuralParser = CreateWorkflow(StructuralParserTask); diff --git a/packages/ai/src/task/TextChunkerTask.ts b/packages/ai/src/task/TextChunkerTask.ts new file mode 100644 index 00000000..99cce8f9 --- /dev/null +++ b/packages/ai/src/task/TextChunkerTask.ts @@ -0,0 +1,358 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; + +export const ChunkingStrategy = { + FIXED: "fixed", + SENTENCE: "sentence", + PARAGRAPH: "paragraph", + SEMANTIC: "semantic", +} as const; + +export type ChunkingStrategy = (typeof ChunkingStrategy)[keyof typeof ChunkingStrategy]; + +const inputSchema = { + type: "object", + properties: { + text: { + type: "string", + title: "Text", + description: "The text to chunk", + }, + chunkSize: { + type: "number", + title: "Chunk Size", + description: "Maximum size of each chunk in characters", + minimum: 1, + default: 512, + }, + chunkOverlap: { + type: "number", + title: "Chunk Overlap", + description: "Number of characters to overlap between chunks", + minimum: 0, + default: 50, + }, + strategy: { + type: "string", + enum: Object.values(ChunkingStrategy), + title: "Chunking Strategy", + description: "Strategy to use for chunking text", + default: ChunkingStrategy.FIXED, + }, + }, + required: ["text"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + chunks: { + type: "array", + items: { type: "string" }, + title: "Text Chunks", + description: "The chunked text segments", + }, + metadata: { + type: "array", + items: { + type: "object", + properties: { + index: { type: "number" }, + startChar: { type: "number" }, + endChar: { type: "number" }, + length: { type: "number" }, + }, + additionalProperties: false, + }, + title: "Chunk Metadata", + description: "Metadata for each chunk", + }, + }, + required: ["chunks", "metadata"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type TextChunkerTaskInput = FromSchema; +export type TextChunkerTaskOutput = FromSchema; + +interface ChunkMetadata { + index: number; + startChar: number; + endChar: number; + length: number; +} + +/** + * Task for chunking text into smaller segments with configurable strategies. + * Supports fixed-size, sentence-based, paragraph-based, and semantic chunking. 
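+ * Note: the "semantic" strategy currently falls back to sentence-based chunking.
+ *
+ * @example
+ * // Minimal usage sketch; `longDocument` is a placeholder string you supply.
+ * const { chunks, metadata } = await textChunker({
+ *   text: longDocument,
+ *   chunkSize: 512,
+ *   chunkOverlap: 50,
+ *   strategy: ChunkingStrategy.SENTENCE,
+ * });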
+ */ +export class TextChunkerTask extends Task< + TextChunkerTaskInput, + TextChunkerTaskOutput, + JobQueueTaskConfig +> { + public static type = "TextChunkerTask"; + public static category = "Document"; + public static title = "Text Chunker"; + public static description = + "Splits text into chunks using various strategies (fixed, sentence, paragraph)"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: TextChunkerTaskInput, + context: IExecuteContext + ): Promise { + const { text, chunkSize = 512, chunkOverlap = 50, strategy = ChunkingStrategy.FIXED } = input; + + let chunks: string[]; + let metadata: ChunkMetadata[]; + + switch (strategy) { + case ChunkingStrategy.SENTENCE: + ({ chunks, metadata } = this.chunkBySentence(text, chunkSize, chunkOverlap)); + break; + case ChunkingStrategy.PARAGRAPH: + ({ chunks, metadata } = this.chunkByParagraph(text, chunkSize, chunkOverlap)); + break; + case ChunkingStrategy.SEMANTIC: + // For now, semantic is the same as sentence-based + // TODO: Implement true semantic chunking with embeddings + ({ chunks, metadata } = this.chunkBySentence(text, chunkSize, chunkOverlap)); + break; + case ChunkingStrategy.FIXED: + default: + ({ chunks, metadata } = this.chunkFixed(text, chunkSize, chunkOverlap)); + break; + } + + return { chunks, metadata }; + } + + /** + * Fixed-size chunking with overlap + */ + private chunkFixed( + text: string, + chunkSize: number, + chunkOverlap: number + ): { chunks: string[]; metadata: ChunkMetadata[] } { + const chunks: string[] = []; + const metadata: ChunkMetadata[] = []; + let startChar = 0; + let index = 0; + + while (startChar < text.length) { + const endChar = Math.min(startChar + chunkSize, text.length); + const chunk = text.substring(startChar, endChar); + chunks.push(chunk); + metadata.push({ + index, + startChar, + endChar, + length: chunk.length, + }); + + // Move forward by chunkSize - chunkOverlap, but at least 1 character to prevent infinite loop + const step = Math.max(1, chunkSize - chunkOverlap); + startChar += step; + index++; + } + + return { chunks, metadata }; + } + + /** + * Sentence-based chunking that respects sentence boundaries + */ + private chunkBySentence( + text: string, + chunkSize: number, + chunkOverlap: number + ): { chunks: string[]; metadata: ChunkMetadata[] } { + // Split by sentence boundaries (., !, ?, followed by space or newline) + const sentenceRegex = /[.!?]+[\s\n]+/g; + const sentences: string[] = []; + const sentenceStarts: number[] = []; + let lastIndex = 0; + let match: RegExpExecArray | null; + + while ((match = sentenceRegex.exec(text)) !== null) { + const sentence = text.substring(lastIndex, match.index + match[0].length); + sentences.push(sentence); + sentenceStarts.push(lastIndex); + lastIndex = match.index + match[0].length; + } + + // Add remaining text as last sentence + if (lastIndex < text.length) { + sentences.push(text.substring(lastIndex)); + sentenceStarts.push(lastIndex); + } + + // Group sentences into chunks + const chunks: string[] = []; + const metadata: ChunkMetadata[] = []; + let currentChunk = ""; + let currentStartChar = 0; + let index = 0; + + for (let i = 0; i < sentences.length; i++) { + const sentence = sentences[i]; + const sentenceStart = sentenceStarts[i]; + + // If adding this sentence would exceed chunkSize, save current chunk + if (currentChunk.length 
> 0 && currentChunk.length + sentence.length > chunkSize) { + chunks.push(currentChunk.trim()); + metadata.push({ + index, + startChar: currentStartChar, + endChar: currentStartChar + currentChunk.length, + length: currentChunk.trim().length, + }); + index++; + + // Start new chunk with overlap + if (chunkOverlap > 0) { + // Find sentences to include in overlap + let overlapText = ""; + let j = i - 1; + while (j >= 0 && overlapText.length < chunkOverlap) { + overlapText = sentences[j] + overlapText; + j--; + } + currentChunk = overlapText + sentence; + currentStartChar = sentenceStarts[Math.max(0, j + 1)]; + } else { + currentChunk = sentence; + currentStartChar = sentenceStart; + } + } else { + if (currentChunk.length === 0) { + currentStartChar = sentenceStart; + } + currentChunk += sentence; + } + } + + // Add final chunk + if (currentChunk.length > 0) { + chunks.push(currentChunk.trim()); + metadata.push({ + index, + startChar: currentStartChar, + endChar: currentStartChar + currentChunk.length, + length: currentChunk.trim().length, + }); + } + + return { chunks, metadata }; + } + + /** + * Paragraph-based chunking that respects paragraph boundaries + */ + private chunkByParagraph( + text: string, + chunkSize: number, + chunkOverlap: number + ): { chunks: string[]; metadata: ChunkMetadata[] } { + // Split by paragraph boundaries (double newline or more) + const paragraphs = text.split(/\n\s*\n/).filter((p) => p.trim().length > 0); + const chunks: string[] = []; + const metadata: ChunkMetadata[] = []; + let currentChunk = ""; + let currentStartChar = 0; + let index = 0; + let charPosition = 0; + + for (let i = 0; i < paragraphs.length; i++) { + const paragraph = paragraphs[i].trim(); + const paragraphStart = text.indexOf(paragraph, charPosition); + charPosition = paragraphStart + paragraph.length; + + // If adding this paragraph would exceed chunkSize, save current chunk + if (currentChunk.length > 0 && currentChunk.length + paragraph.length + 2 > chunkSize) { + chunks.push(currentChunk.trim()); + metadata.push({ + index, + startChar: currentStartChar, + endChar: currentStartChar + currentChunk.length, + length: currentChunk.trim().length, + }); + index++; + + // Start new chunk with overlap + if (chunkOverlap > 0 && i > 0) { + // Include previous paragraph(s) for overlap + let overlapText = ""; + let j = i - 1; + while (j >= 0 && overlapText.length < chunkOverlap) { + overlapText = paragraphs[j].trim() + "\n\n" + overlapText; + j--; + } + currentChunk = overlapText + paragraph; + currentStartChar = paragraphStart - overlapText.length; + } else { + currentChunk = paragraph; + currentStartChar = paragraphStart; + } + } else { + if (currentChunk.length === 0) { + currentStartChar = paragraphStart; + currentChunk = paragraph; + } else { + currentChunk += "\n\n" + paragraph; + } + } + } + + // Add final chunk + if (currentChunk.length > 0) { + chunks.push(currentChunk.trim()); + metadata.push({ + index, + startChar: currentStartChar, + endChar: currentStartChar + currentChunk.length, + length: currentChunk.trim().length, + }); + } + + return { chunks, metadata }; + } +} + +TaskRegistry.registerTask(TextChunkerTask); + +export const textChunker = (input: TextChunkerTaskInput, config?: JobQueueTaskConfig) => { + return new TextChunkerTask({} as TextChunkerTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + textChunker: CreateWorkflow; + } +} + +Workflow.prototype.textChunker = CreateWorkflow(TextChunkerTask); diff --git 
a/packages/ai/src/task/TopicSegmenterTask.ts b/packages/ai/src/task/TopicSegmenterTask.ts new file mode 100644 index 00000000..415fc55e --- /dev/null +++ b/packages/ai/src/task/TopicSegmenterTask.ts @@ -0,0 +1,439 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; + +export const SegmentationMethod = { + HEURISTIC: "heuristic", + EMBEDDING_SIMILARITY: "embedding-similarity", + HYBRID: "hybrid", +} as const; + +export type SegmentationMethod = (typeof SegmentationMethod)[keyof typeof SegmentationMethod]; + +const inputSchema = { + type: "object", + properties: { + text: { + type: "string", + title: "Text", + description: "The text to segment into topics", + }, + method: { + type: "string", + enum: Object.values(SegmentationMethod), + title: "Segmentation Method", + description: "Method to use for topic segmentation", + default: SegmentationMethod.HEURISTIC, + }, + minSegmentSize: { + type: "number", + title: "Min Segment Size", + description: "Minimum segment size in characters", + minimum: 50, + default: 100, + }, + maxSegmentSize: { + type: "number", + title: "Max Segment Size", + description: "Maximum segment size in characters", + minimum: 100, + default: 2000, + }, + similarityThreshold: { + type: "number", + title: "Similarity Threshold", + description: "Threshold for embedding similarity (0-1, lower = more splits)", + minimum: 0, + maximum: 1, + default: 0.5, + }, + }, + required: ["text"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + segments: { + type: "array", + items: { + type: "object", + properties: { + text: { type: "string" }, + startOffset: { type: "number" }, + endOffset: { type: "number" }, + }, + required: ["text", "startOffset", "endOffset"], + additionalProperties: false, + }, + title: "Segments", + description: "Detected topic segments", + }, + count: { + type: "number", + title: "Count", + description: "Number of segments detected", + }, + }, + required: ["segments", "count"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type TopicSegmenterTaskInput = FromSchema; +export type TopicSegmenterTaskOutput = FromSchema; + +/** + * Task for segmenting text into topic-based sections + * Uses hybrid approach: heuristics + optional embedding similarity + */ +export class TopicSegmenterTask extends Task< + TopicSegmenterTaskInput, + TopicSegmenterTaskOutput, + JobQueueTaskConfig +> { + public static type = "TopicSegmenterTask"; + public static category = "Document"; + public static title = "Topic Segmenter"; + public static description = "Segment text into topic-based sections using hybrid approach"; + public static cacheable = true; + private static readonly EMBEDDING_DIMENSIONS = 256; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: TopicSegmenterTaskInput, + context: IExecuteContext + ): Promise { + const { + text, + method = SegmentationMethod.HEURISTIC, + minSegmentSize = 100, + maxSegmentSize = 2000, + similarityThreshold = 0.5, + } = input; + + let segments: Array<{ text: string; startOffset: number; endOffset: number }>; + + switch (method) { 
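+    // Heuristic segmentation is the default; hybrid starts from the heuristic
+    // pass and leaves embedding-based refinement as a TODO.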
+ case SegmentationMethod.EMBEDDING_SIMILARITY: + segments = this.embeddingSegmentation( + text, + minSegmentSize, + maxSegmentSize, + similarityThreshold + ); + break; + case SegmentationMethod.HYBRID: + // Start with heuristic, optionally refine with embeddings + segments = this.heuristicSegmentation(text, minSegmentSize, maxSegmentSize); + // TODO: Add embedding refinement step + break; + case SegmentationMethod.HEURISTIC: + default: + segments = this.heuristicSegmentation(text, minSegmentSize, maxSegmentSize); + break; + } + + return { + segments, + count: segments.length, + }; + } + + /** + * Embedding-based segmentation using hashed token vectors and cosine similarity + */ + private embeddingSegmentation( + text: string, + minSegmentSize: number, + maxSegmentSize: number, + similarityThreshold: number + ): Array<{ text: string; startOffset: number; endOffset: number }> { + const paragraphs = this.splitIntoParagraphs(text); + if (paragraphs.length === 0) { + return []; + } + + const embeddings = paragraphs.map((p) => + this.embedParagraph(p.text, TopicSegmenterTask.EMBEDDING_DIMENSIONS) + ); + + const segments: Array<{ text: string; startOffset: number; endOffset: number }> = []; + let currentSegmentParagraphs: Array<{ text: string; offset: number }> = []; + let currentSegmentSize = 0; + + for (let i = 0; i < paragraphs.length; i++) { + const paragraph = paragraphs[i]; + const paragraphSize = paragraph.text.length; + const exceedsMax = + currentSegmentSize + paragraphSize > maxSegmentSize && currentSegmentSize >= minSegmentSize; + + let shouldSplit = false; + if (i > 0 && currentSegmentSize >= minSegmentSize) { + const prev = embeddings[i - 1]; + const curr = embeddings[i]; + const similarity = this.cosineSimilarityWithNorms( + prev.vector, + prev.norm, + curr.vector, + curr.norm + ); + shouldSplit = similarity < similarityThreshold; + } + + if ((exceedsMax || shouldSplit) && currentSegmentParagraphs.length > 0) { + segments.push(this.createSegment(currentSegmentParagraphs)); + currentSegmentParagraphs = []; + currentSegmentSize = 0; + } + + currentSegmentParagraphs.push(paragraph); + currentSegmentSize += paragraphSize; + } + + if (currentSegmentParagraphs.length > 0) { + segments.push(this.createSegment(currentSegmentParagraphs)); + } + + return this.mergeSmallSegments(segments, minSegmentSize); + } + + /** + * Heuristic segmentation based on paragraph breaks and transition markers + */ + private heuristicSegmentation( + text: string, + minSegmentSize: number, + maxSegmentSize: number + ): Array<{ text: string; startOffset: number; endOffset: number }> { + const segments: Array<{ text: string; startOffset: number; endOffset: number }> = []; + + // Split by double newlines (paragraph breaks) + const paragraphs = this.splitIntoParagraphs(text); + + let currentSegmentParagraphs: Array<{ text: string; offset: number }> = []; + let currentSegmentSize = 0; + + for (const paragraph of paragraphs) { + const paragraphSize = paragraph.text.length; + + // Check if adding this paragraph would exceed max size + if ( + currentSegmentSize + paragraphSize > maxSegmentSize && + currentSegmentSize >= minSegmentSize + ) { + // Flush current segment + if (currentSegmentParagraphs.length > 0) { + const segment = this.createSegment(currentSegmentParagraphs); + segments.push(segment); + currentSegmentParagraphs = []; + currentSegmentSize = 0; + } + } + + // Check for transition markers + const hasTransition = this.hasTransitionMarker(paragraph.text); + if ( + hasTransition && + currentSegmentSize >= 
minSegmentSize && + currentSegmentParagraphs.length > 0 + ) { + // Flush current segment before transition + const segment = this.createSegment(currentSegmentParagraphs); + segments.push(segment); + currentSegmentParagraphs = []; + currentSegmentSize = 0; + } + + currentSegmentParagraphs.push(paragraph); + currentSegmentSize += paragraphSize; + } + + // Flush remaining segment + if (currentSegmentParagraphs.length > 0) { + const segment = this.createSegment(currentSegmentParagraphs); + segments.push(segment); + } + + // Merge small segments + return this.mergeSmallSegments(segments, minSegmentSize); + } + + /** + * Create a hashed token embedding for fast similarity checks + */ + private embedParagraph(text: string, dimensions: number): { vector: Float32Array; norm: number } { + const vector = new Float32Array(dimensions); + const tokens = text.toLowerCase().match(/[a-z0-9]+/g); + if (!tokens) { + return { vector, norm: 0 }; + } + + for (const token of tokens) { + let hash = 2166136261; + for (let i = 0; i < token.length; i++) { + hash ^= token.charCodeAt(i); + hash = Math.imul(hash, 16777619); + } + const index = (hash >>> 0) % dimensions; + vector[index] += 1; + } + + let sumSquares = 0; + for (let i = 0; i < vector.length; i++) { + const value = vector[i]; + sumSquares += value * value; + } + + return { vector, norm: sumSquares > 0 ? Math.sqrt(sumSquares) : 0 }; + } + + private cosineSimilarityWithNorms( + a: Float32Array, + aNorm: number, + b: Float32Array, + bNorm: number + ): number { + if (aNorm === 0 || bNorm === 0) { + return 0; + } + + let dot = 0; + for (let i = 0; i < a.length; i++) { + dot += a[i] * b[i]; + } + + return dot / (aNorm * bNorm); + } + + /** + * Split text into paragraphs with offsets + */ + private splitIntoParagraphs(text: string): Array<{ text: string; offset: number }> { + const paragraphs: Array<{ text: string; offset: number }> = []; + const splits = text.split(/\n\s*\n/); + + let currentOffset = 0; + for (const split of splits) { + const trimmed = split.trim(); + if (trimmed.length > 0) { + const offset = text.indexOf(trimmed, currentOffset); + paragraphs.push({ text: trimmed, offset }); + currentOffset = offset + trimmed.length; + } + } + + return paragraphs; + } + + /** + * Check if paragraph contains transition markers + */ + private hasTransitionMarker(text: string): boolean { + const transitionMarkers = [ + /^(however|therefore|thus|consequently|in conclusion|in summary|furthermore|moreover|additionally|meanwhile|nevertheless|on the other hand)/i, + /^(first|second|third|finally|lastly)/i, + /^\d+\./, // Numbered list + ]; + + return transitionMarkers.some((pattern) => pattern.test(text)); + } + + /** + * Create a segment from paragraphs + */ + private createSegment(paragraphs: Array<{ text: string; offset: number }>): { + text: string; + startOffset: number; + endOffset: number; + } { + const text = paragraphs.map((p) => p.text).join("\n\n"); + const startOffset = paragraphs[0].offset; + const endOffset = + paragraphs[paragraphs.length - 1].offset + paragraphs[paragraphs.length - 1].text.length; + + return { text, startOffset, endOffset }; + } + + /** + * Merge segments that are too small + */ + private mergeSmallSegments( + segments: Array<{ text: string; startOffset: number; endOffset: number }>, + minSegmentSize: number + ): Array<{ text: string; startOffset: number; endOffset: number }> { + if (segments.length <= 1) { + return segments; + } + + const merged: Array<{ text: string; startOffset: number; endOffset: number }> = []; + let i = 0; + + 
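// Greedily merge undersized segments: fold into the next segment when one exists,
+    // otherwise into the previous one, so emitted segments stay above minSegmentSize where possible.
+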
while (i < segments.length) { + const current = segments[i]; + + if (current.text.length < minSegmentSize && i + 1 < segments.length) { + // Merge with next + const next = segments[i + 1]; + const mergedSegment = { + text: current.text + "\n\n" + next.text, + startOffset: current.startOffset, + endOffset: next.endOffset, + }; + merged.push(mergedSegment); + i += 2; + } else if (current.text.length < minSegmentSize && merged.length > 0) { + // Merge with previous + const previous = merged[merged.length - 1]; + merged[merged.length - 1] = { + text: previous.text + "\n\n" + current.text, + startOffset: previous.startOffset, + endOffset: current.endOffset, + }; + i++; + } else { + merged.push(current); + i++; + } + } + + return merged; + } +} + +TaskRegistry.registerTask(TopicSegmenterTask); + +export const topicSegmenter = (input: TopicSegmenterTaskInput, config?: JobQueueTaskConfig) => { + return new TopicSegmenterTask({} as TopicSegmenterTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + topicSegmenter: CreateWorkflow< + TopicSegmenterTaskInput, + TopicSegmenterTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.topicSegmenter = CreateWorkflow(TopicSegmenterTask); diff --git a/packages/ai/src/task/VectorQuantizeTask.ts b/packages/ai/src/task/VectorQuantizeTask.ts new file mode 100644 index 00000000..9ed102dc --- /dev/null +++ b/packages/ai/src/task/VectorQuantizeTask.ts @@ -0,0 +1,257 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + CreateWorkflow, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { + DataPortSchema, + FromSchema, + normalizeNumberArray, + TensorType, + TypedArray, + TypedArraySchema, + TypedArraySchemaOptions, +} from "@workglow/util"; + +const inputSchema = { + type: "object", + properties: { + vector: { + anyOf: [ + TypedArraySchema({ + title: "Vector", + description: "The vector to quantize", + }), + { + type: "array", + items: TypedArraySchema({ + title: "Vector", + description: "Vector to quantize", + }), + }, + ], + title: "Input Vector(s)", + description: "Vector or array of vectors to quantize", + }, + targetType: { + type: "string", + enum: Object.values(TensorType), + title: "Target Type", + description: "Target quantization type", + default: TensorType.INT8, + }, + normalize: { + type: "boolean", + title: "Normalize", + description: "Normalize vector before quantization", + default: true, + }, + }, + required: ["vector", "targetType"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + vector: { + anyOf: [ + TypedArraySchema({ + title: "Quantized Vector", + description: "The quantized vector", + }), + { + type: "array", + items: TypedArraySchema({ + title: "Quantized Vector", + description: "Quantized vector", + }), + }, + ], + title: "Output Vector(s)", + description: "Quantized vector or array of vectors", + }, + originalType: { + type: "string", + enum: Object.values(TensorType), + title: "Original Type", + description: "Original vector type", + }, + targetType: { + type: "string", + enum: Object.values(TensorType), + title: "Target Type", + description: "Target quantization type", + }, + }, + required: ["vector", "originalType", "targetType"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type VectorQuantizeTaskInput = FromSchema; +export type VectorQuantizeTaskOutput = 
FromSchema; + +/** + * Task for quantizing vectors to reduce storage and improve performance. + * Supports various quantization types including binary, int8, uint8, int16, uint16. + */ +export class VectorQuantizeTask extends Task< + VectorQuantizeTaskInput, + VectorQuantizeTaskOutput, + JobQueueTaskConfig +> { + public static type = "VectorQuantizeTask"; + public static category = "Vector Processing"; + public static title = "Quantize Vector"; + public static description = "Quantize vectors to reduce storage and improve performance"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async executeReactive(input: VectorQuantizeTaskInput): Promise { + const { vector, targetType, normalize = true } = input; + const isArray = Array.isArray(vector); + const vectors = isArray ? vector : [vector]; + const originalType = this.getVectorType(vectors[0]); + + const quantized = vectors.map((v) => this.vectorQuantize(v, targetType, normalize)); + + return { + vector: isArray ? quantized : quantized[0], + originalType, + targetType, + }; + } + + private getVectorType(vector: TypedArray): TensorType { + if (vector instanceof Float16Array) return TensorType.FLOAT16; + if (vector instanceof Float32Array) return TensorType.FLOAT32; + if (vector instanceof Float64Array) return TensorType.FLOAT64; + if (vector instanceof Int8Array) return TensorType.INT8; + if (vector instanceof Uint8Array) return TensorType.UINT8; + if (vector instanceof Int16Array) return TensorType.INT16; + if (vector instanceof Uint16Array) return TensorType.UINT16; + throw new Error(`Unknown vector type: ${typeof vector}`); + } + + private vectorQuantize( + vector: TypedArray, + targetType: TensorType, + normalize: boolean + ): TypedArray { + let values = Array.from(vector) as number[]; + + // Normalize if requested + if (normalize) { + values = normalizeNumberArray(values, false); + } + + switch (targetType) { + case TensorType.FLOAT16: + return new Float16Array(values); + + case TensorType.FLOAT32: + return new Float32Array(values); + + case TensorType.FLOAT64: + return new Float64Array(values); + + case TensorType.INT8: + return this.quantizeToInt8(values); + + case TensorType.UINT8: + return this.quantizeToUint8(values); + + case TensorType.INT16: + return this.quantizeToInt16(values); + + case TensorType.UINT16: + return this.quantizeToUint16(values); + + default: + return new Float32Array(values); + } + } + + /** + * Find min and max values in a single pass for better performance + */ + private findMinMax(values: number[]): { min: number; max: number } { + if (values.length === 0) { + return { min: 0, max: 1 }; + } + + let min = values[0]; + let max = values[0]; + + for (let i = 1; i < values.length; i++) { + const val = values[i]; + if (val < min) min = val; + if (val > max) max = val; + } + + return { min, max }; + } + + private quantizeToInt8(values: number[]): Int8Array { + // Assume values are in [-1, 1] range after normalization + // Scale to [-127, 127] to avoid overflow at -128 + return new Int8Array(values.map((v) => Math.round(Math.max(-1, Math.min(1, v)) * 127))); + } + + private quantizeToUint8(values: number[]): Uint8Array { + // Find min/max for scaling in a single pass + const { min, max } = this.findMinMax(values); + const range = max - min || 1; + + // Scale to [0, 255] + return new Uint8Array(values.map((v) => Math.round(((v - min) / 
range) * 255))); + } + + private quantizeToInt16(values: number[]): Int16Array { + // Assume values are in [-1, 1] range after normalization + // Scale to [-32767, 32767] + return new Int16Array(values.map((v) => Math.round(Math.max(-1, Math.min(1, v)) * 32767))); + } + + private quantizeToUint16(values: number[]): Uint16Array { + // Find min/max for scaling in a single pass + const { min, max } = this.findMinMax(values); + const range = max - min || 1; + + // Scale to [0, 65535] + return new Uint16Array(values.map((v) => Math.round(((v - min) / range) * 65535))); + } +} + +TaskRegistry.registerTask(VectorQuantizeTask); + +export const vectorQuantize = (input: VectorQuantizeTaskInput, config?: JobQueueTaskConfig) => { + return new VectorQuantizeTask({} as VectorQuantizeTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + vectorQuantize: CreateWorkflow< + VectorQuantizeTaskInput, + VectorQuantizeTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.vectorQuantize = CreateWorkflow(VectorQuantizeTask); diff --git a/packages/ai/src/task/index.ts b/packages/ai/src/task/index.ts index 91dbf459..a4cb0f38 100644 --- a/packages/ai/src/task/index.ts +++ b/packages/ai/src/task/index.ts @@ -7,18 +7,30 @@ export * from "./BackgroundRemovalTask"; export * from "./base/AiTask"; export * from "./base/AiTaskSchemas"; -export * from "./DocumentSplitterTask"; +export * from "./ChunkToVectorTask"; +export * from "./ContextBuilderTask"; +export * from "./DocumentEnricherTask"; +export * from "./DocumentNodeRetrievalTask"; +export * from "./DocumentNodeVectorHybridSearchTask"; +export * from "./DocumentNodeVectorSearchTask"; +export * from "./DocumentNodeVectorUpsertTask"; export * from "./DownloadModelTask"; export * from "./FaceDetectorTask"; export * from "./FaceLandmarkerTask"; export * from "./GestureRecognizerTask"; export * from "./HandLandmarkerTask"; +export * from "./HierarchicalChunkerTask"; +export * from "./HierarchyJoinTask"; export * from "./ImageClassificationTask"; export * from "./ImageEmbeddingTask"; export * from "./ImageSegmentationTask"; export * from "./ImageToTextTask"; export * from "./ObjectDetectionTask"; export * from "./PoseLandmarkerTask"; +export * from "./QueryExpanderTask"; +export * from "./RerankerTask"; +export * from "./StructuralParserTask"; +export * from "./TextChunkerTask"; export * from "./TextClassificationTask"; export * from "./TextEmbeddingTask"; export * from "./TextFillMaskTask"; @@ -29,5 +41,7 @@ export * from "./TextQuestionAnswerTask"; export * from "./TextRewriterTask"; export * from "./TextSummaryTask"; export * from "./TextTranslationTask"; +export * from "./TopicSegmenterTask"; export * from "./UnloadModelTask"; +export * from "./VectorQuantizeTask"; export * from "./VectorSimilarityTask"; diff --git a/packages/test/src/test/rag/ChunkToVector.test.ts b/packages/test/src/test/rag/ChunkToVector.test.ts new file mode 100644 index 00000000..7a236125 --- /dev/null +++ b/packages/test/src/test/rag/ChunkToVector.test.ts @@ -0,0 +1,124 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { ChunkToVectorTaskOutput, HierarchicalChunkerTaskOutput } from "@workglow/ai"; +import { type ChunkNode, NodeIdGenerator, StructuralParser } from "@workglow/storage"; +import { Workflow } from "@workglow/task-graph"; +import { describe, expect, it } from "vitest"; + +describe("ChunkToVectorTask", () => { + it("should transform chunks and vectors to vector 
store format", async () => { + const markdown = "# Test\n\nContent."; + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + // Generate chunks using workflow + const chunkResult = (await new Workflow() + .hierarchicalChunker({ + doc_id, + documentTree: root, + maxTokens: 512, + overlap: 50, + strategy: "hierarchical", + }) + .run()) as HierarchicalChunkerTaskOutput; + + // Mock vectors (would normally come from TextEmbeddingTask) + const mockVectors = chunkResult.chunks.map(() => new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5])); + + // Transform to vector store format using workflow + const result = (await new Workflow() + .chunkToVector({ + chunks: chunkResult.chunks as ChunkNode[], + vectors: mockVectors, + }) + .run()) as ChunkToVectorTaskOutput; + + // Verify output format + expect(result.ids).toBeDefined(); + expect(result.vectors).toBeDefined(); + expect(result.metadata).toBeDefined(); + expect(result.texts).toBeDefined(); + + expect(result.ids.length).toBe(chunkResult.count); + expect(result.vectors.length).toBe(chunkResult.count); + expect(result.metadata.length).toBe(chunkResult.count); + expect(result.texts.length).toBe(chunkResult.count); + + // Check metadata structure + for (let i = 0; i < result.metadata.length; i++) { + const meta = result.metadata[i]; + expect(meta.doc_id).toBe(doc_id); + expect(meta.chunkId).toBeDefined(); + expect(meta.leafNodeId).toBeDefined(); + expect(meta.depth).toBeDefined(); + expect(meta.text).toBeDefined(); + expect(meta.nodePath).toBeDefined(); + } + + // Verify IDs match chunks + for (let i = 0; i < result.ids.length; i++) { + expect(result.ids[i]).toBe(chunkResult.chunks[i].chunkId); + } + }); + + it("should throw error on length mismatch", async () => { + const chunks = [ + { + chunkId: "chunk_1", + doc_id: "doc_1", + text: "Test", + nodePath: ["node_1"], + depth: 1, + }, + { + chunkId: "chunk_2", + doc_id: "doc_1", + text: "Test 2", + nodePath: ["node_1"], + depth: 1, + }, + ]; + + const vectors = [new Float32Array([1, 2, 3])]; // Only 1 vector for 2 chunks + + // Using workflow + await expect(new Workflow().chunkToVector({ chunks, vectors }).run()).rejects.toThrow( + "Mismatch" + ); + }); + + it("should include enrichment in metadata if present", async () => { + const chunks = [ + { + chunkId: "chunk_1", + doc_id: "doc_1", + text: "Test", + nodePath: ["node_1"], + depth: 1, + enrichment: { + summary: "Test summary", + entities: [{ text: "Entity", type: "TEST", score: 0.9 }], + }, + }, + ]; + + const vectors = [new Float32Array([1, 2, 3])]; + + const result = (await new Workflow() + .chunkToVector({ chunks, vectors }) + .run()) as ChunkToVectorTaskOutput; + + const metadata = result.metadata as Array<{ + summary?: string; + entities?: Array<{ text: string; type: string; score: number }>; + [key: string]: unknown; + }>; + expect(metadata[0].summary).toBe("Test summary"); + expect(metadata[0].entities).toBeDefined(); + expect(metadata[0].entities!.length).toBe(1); + }); +}); diff --git a/packages/test/src/test/rag/ContextBuilderTask.test.ts b/packages/test/src/test/rag/ContextBuilderTask.test.ts new file mode 100644 index 00000000..b7282c31 --- /dev/null +++ b/packages/test/src/test/rag/ContextBuilderTask.test.ts @@ -0,0 +1,247 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ContextFormat, contextBuilder } from "@workglow/ai"; +import { describe, expect, test } from "vitest"; + 
+describe("ContextBuilderTask", () => { + const testChunks = [ + "First chunk of text about artificial intelligence.", + "Second chunk discussing machine learning algorithms.", + "Third chunk covering neural networks and deep learning.", + ]; + + const testMetadata = [ + { source: "doc1.txt", page: 1 }, + { source: "doc2.txt", page: 2 }, + { source: "doc3.txt", page: 3 }, + ]; + + const testScores = [0.95, 0.87, 0.82]; + + test("should format chunks with SIMPLE format", async () => { + const result = await contextBuilder({ + chunks: testChunks, + }); + + expect(result.context).toBeDefined(); + expect(result.chunksUsed).toBe(3); + expect(result.totalLength).toBeGreaterThan(0); + expect(result.context).toContain(testChunks[0]); + expect(result.context).toContain(testChunks[1]); + expect(result.context).toContain(testChunks[2]); + }); + + test("should format chunks with NUMBERED format", async () => { + const result = await contextBuilder({ + chunks: testChunks, + format: ContextFormat.NUMBERED, + }); + + expect(result.context).toContain("[1]"); + expect(result.context).toContain("[2]"); + expect(result.context).toContain("[3]"); + expect(result.context).toContain(testChunks[0]); + }); + + test("should format chunks with XML format", async () => { + const result = await contextBuilder({ + chunks: testChunks, + format: ContextFormat.XML, + }); + + expect(result.context).toContain(""); + expect(result.context).toContain('id="1"'); + expect(result.context).toContain(testChunks[0]); + }); + + test("should format chunks with MARKDOWN format", async () => { + const result = await contextBuilder({ + chunks: testChunks, + format: ContextFormat.MARKDOWN, + }); + + expect(result.context).toContain("### Chunk"); + expect(result.context).toContain("### Chunk 1"); + expect(result.context).toContain("### Chunk 2"); + expect(result.context).toContain(testChunks[0]); + }); + + test("should format chunks with JSON format", async () => { + const result = await contextBuilder({ + chunks: testChunks, + format: ContextFormat.JSON, + }); + + // Should contain JSON objects + expect(result.context).toContain('"index"'); + expect(result.context).toContain('"content"'); + expect(result.context).toContain(testChunks[0]); + }); + + test("should include metadata when includeMetadata is true", async () => { + const result = await contextBuilder({ + chunks: testChunks, + metadata: testMetadata, + includeMetadata: true, + format: ContextFormat.NUMBERED, + }); + + expect(result.context).toContain("doc1.txt"); + expect(result.context).toContain("page"); + }); + + test("should include scores when provided and includeMetadata is true", async () => { + const result = await contextBuilder({ + chunks: testChunks, + metadata: testMetadata, + scores: testScores, + includeMetadata: true, + format: ContextFormat.NUMBERED, + }); + + // NUMBERED format includes scores in the formatNumbered method when includeMetadata is true + // The formatNumbered method uses formatMetadataInline which includes scores + expect(result.context).toContain("score="); + expect(result.context).toContain("0.95"); + }); + + test("should respect maxLength constraint", async () => { + const result = await contextBuilder({ + chunks: testChunks, + maxLength: 100, + }); + + expect(result.totalLength).toBeLessThanOrEqual(100); + expect(result.chunksUsed).toBeLessThanOrEqual(testChunks.length); + }); + + test("should use custom separator", async () => { + const separator = "---"; + const result = await contextBuilder({ + chunks: testChunks, + separator: separator, + 
}); + + // Should contain separator between chunks + const separatorCount = (result.context.match(new RegExp(separator, "g")) || []).length; + expect(separatorCount).toBeGreaterThan(0); + }); + + test("should handle empty chunks array", async () => { + const result = await contextBuilder({ + chunks: [], + }); + + expect(result.context).toBe(""); + expect(result.chunksUsed).toBe(0); + expect(result.totalLength).toBe(0); + }); + + test("should handle single chunk", async () => { + const singleChunk = ["Only one chunk"]; + const result = await contextBuilder({ + chunks: singleChunk, + }); + + expect(result.context).toBe(singleChunk[0]); + expect(result.chunksUsed).toBe(1); + expect(result.totalLength).toBe(singleChunk[0].length); + }); + + test("should handle chunks with mismatched metadata length", async () => { + const result = await contextBuilder({ + chunks: testChunks, + metadata: [testMetadata[0]], // Only one metadata entry + includeMetadata: true, + }); + + // Should handle gracefully, only include metadata where available + expect(result.chunksUsed).toBe(3); + expect(result.context).toBeDefined(); + }); + + test("should handle chunks with mismatched scores length", async () => { + const result = await contextBuilder({ + chunks: testChunks, + scores: [testScores[0]], // Only one score + includeMetadata: true, + }); + + expect(result.chunksUsed).toBe(3); + expect(result.context).toBeDefined(); + }); + + test("should truncate first chunk if maxLength is very small", async () => { + const result = await contextBuilder({ + chunks: testChunks, + maxLength: 50, + }); + + expect(result.totalLength).toBeLessThanOrEqual(50); + expect(result.context.length).toBeLessThanOrEqual(50); + if (result.chunksUsed > 0) { + expect(result.context).toContain("..."); + } + }); + + test("should use default separator when not specified", async () => { + const result = await contextBuilder({ + chunks: testChunks, + }); + + // Default separator is "\n\n" + expect(result.context).toContain("\n\n"); + }); + + test("should escape XML special characters in XML format", async () => { + const chunksWithSpecialChars = ['Text with <tag> & "quotes"']; + const result = await contextBuilder({ + chunks: chunksWithSpecialChars, + format: ContextFormat.XML, + }); + + // Should escape XML characters + expect(result.context).not.toContain("<tag>"); + expect(result.context).toContain("&lt;tag&gt;"); + expect(result.context).toContain("&amp;"); + expect(result.context).toContain("&quot;quotes&quot;"); + }); + + test("should format metadata correctly in different formats", async () => { + // Test MARKDOWN format with metadata + const markdownResult = await contextBuilder({ + chunks: testChunks, + metadata: testMetadata, + includeMetadata: true, + format: ContextFormat.MARKDOWN, + }); + + expect(markdownResult.context).toContain("**Metadata:**"); + expect(markdownResult.context).toContain("- source:"); + + // Test JSON format with metadata + const jsonResult = await contextBuilder({ + chunks: testChunks, + metadata: testMetadata, + includeMetadata: true, + format: ContextFormat.JSON, + }); + + expect(jsonResult.context).toContain('"metadata"'); + }); + + test("should handle very long chunks", async () => { + const longChunk = "A".repeat(10000); + const result = await contextBuilder({ + chunks: [longChunk], + maxLength: 5000, + }); + + expect(result.totalLength).toBeLessThanOrEqual(5000); + }); +}); diff --git a/packages/test/src/test/rag/EndToEnd.test.ts b/packages/test/src/test/rag/EndToEnd.test.ts new file mode 100644 index 00000000..1f8f277d --- /dev/null
+++ b/packages/test/src/test/rag/EndToEnd.test.ts @@ -0,0 +1,143 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { hierarchicalChunker } from "@workglow/ai"; +import { + Document, + DocumentRepository, + DocumentStorageKey, + DocumentStorageSchema, + InMemoryDocumentNodeVectorRepository, + InMemoryTabularRepository, + NodeIdGenerator, + StructuralParser, +} from "@workglow/storage"; +import { describe, expect, it } from "vitest"; + +describe("End-to-end hierarchical RAG", () => { + it("should demonstrate chainable design (chunks → text array)", async () => { + // Sample markdown document + const markdown = `# Machine Learning + +Machine learning is AI. + +## Supervised Learning + +Uses labeled data. + +## Unsupervised Learning + +Finds patterns in data.`; + + // Parse into hierarchical tree + const doc_id = await NodeIdGenerator.generateDocId("ml-guide", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "ML Guide"); + + // CHAINABLE DESIGN TEST - Use workflow to verify chaining + const chunkResult = await hierarchicalChunker({ + doc_id, + documentTree: root, + maxTokens: 256, + overlap: 25, + strategy: "hierarchical", + }); + + // Verify outputs are ready for next task in chain + expect(chunkResult.chunks).toBeDefined(); + expect(chunkResult.text).toBeDefined(); + expect(chunkResult.count).toBe(chunkResult.text.length); + expect(chunkResult.count).toBe(chunkResult.chunks.length); + + // The text array can be directly consumed by TextEmbeddingTask + expect(Array.isArray(chunkResult.text)).toBe(true); + expect(chunkResult.text.every((t) => typeof t === "string")).toBe(true); + }); + + it("should manage document chunks", async () => { + const markdown = "# Test Document\n\nThis is test content."; + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + const doc = new Document(doc_id, root, { title: "Test" }); + + const chunks = [ + { + chunkId: "chunk_1", + doc_id: doc_id, + text: "Test chunk 1", + nodePath: [root.nodeId], + depth: 1, + }, + ]; + + doc.setChunks(chunks); + + // Verify chunks are stored + const retrievedChunks = doc.getChunks(); + expect(retrievedChunks.length).toBe(1); + expect(retrievedChunks[0].text).toBe("Test chunk 1"); + }); + + it("should demonstrate document repository integration", async () => { + const tabularStorage = new InMemoryTabularRepository( + DocumentStorageSchema, + DocumentStorageKey + ); + await tabularStorage.setupDatabase(); + + const vectorStorage = new InMemoryDocumentNodeVectorRepository(3); + await vectorStorage.setupDatabase(); + + const docRepo = new DocumentRepository(tabularStorage, vectorStorage); + + // Create document with enriched hierarchy + const markdown = `# Guide + +## Section 1 + +Content about topic A. 
+ +## Section 2 + +Content about topic B.`; + + const doc_id = await NodeIdGenerator.generateDocId("guide", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Guide"); + + const doc = new Document(doc_id, root, { title: "Guide" }); + + // Enrich (in real workflow this would use DocumentEnricherTask) + // For test, manually add enrichment + const enrichedRoot = { + ...root, + enrichment: { + summary: "A guide covering two sections", + }, + }; + + const enrichedDoc = new Document(doc_id, enrichedRoot as any, doc.metadata); + await docRepo.upsert(enrichedDoc); + + // Generate chunks using workflow (without embedding to avoid model requirement) + const chunkResult = await hierarchicalChunker({ + doc_id, + documentTree: enrichedRoot, + maxTokens: 256, + overlap: 25, + strategy: "hierarchical", + }); + expect(chunkResult.count).toBeGreaterThan(0); + + // Add chunks to document + enrichedDoc.setChunks(chunkResult.chunks); + await docRepo.upsert(enrichedDoc); + + // Verify chunks were stored + const retrieved = await docRepo.getChunks(doc_id); + expect(retrieved).toBeDefined(); + expect(retrieved.length).toBe(chunkResult.count); + }); +}); diff --git a/packages/test/src/test/rag/FullChain.test.ts b/packages/test/src/test/rag/FullChain.test.ts new file mode 100644 index 00000000..5460e71b --- /dev/null +++ b/packages/test/src/test/rag/FullChain.test.ts @@ -0,0 +1,144 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { HierarchicalChunkerTaskOutput } from "@workglow/ai"; +import { + ChunkNode, + InMemoryDocumentNodeVectorRepository, + NodeIdGenerator, +} from "@workglow/storage"; +import { Workflow } from "@workglow/task-graph"; +import { describe, expect, it } from "vitest"; + +describe("Complete chainable workflow", () => { + it("should chain from parsing to storage without loops", async () => { + const vectorRepo = new InMemoryDocumentNodeVectorRepository(3); + await vectorRepo.setupDatabase(); + + const markdown = `# Test Document + +## Section 1 + +This is the first section with some content. 
+ +## Section 2 + +This is the second section with more content.`; + + // Parse → Enrich → Chunk + const result = await new Workflow() + .structuralParser({ + text: markdown, + title: "Test Doc", + format: "markdown", + sourceUri: "test.md", + }) + .documentEnricher({ + generateSummaries: true, + extractEntities: true, + }) + .hierarchicalChunker({ + maxTokens: 256, + overlap: 25, + strategy: "hierarchical", + }) + .run(); + + // Verify the chain worked - final output from hierarchicalChunker + expect(result.doc_id).toBeDefined(); + expect(result.doc_id).toMatch(/^doc_[0-9a-f]{16}$/); + expect(result.chunks).toBeDefined(); + expect(result.text).toBeDefined(); + expect(result.count).toBeGreaterThan(0); + + // Verify output structure matches expectations + expect(result.chunks.length).toBe(result.count); + expect(result.text.length).toBe(result.count); + }); + + it("should demonstrate data flow through chain", async () => { + const markdown = "# Title\n\nParagraph content."; + + const result = await new Workflow() + .structuralParser({ + text: markdown, + title: "Test", + format: "markdown", + }) + .hierarchicalChunker({ + maxTokens: 512, + overlap: 50, + strategy: "hierarchical", + }) + .run(); + + // Verify data flows correctly (final output from hierarchicalChunker) + expect(result.doc_id).toBeDefined(); + expect(result.chunks).toBeDefined(); + expect(result.text).toBeDefined(); + + // doc_id should flow through the chain to all chunks + // PropertyArrayGraphResult makes chunks potentially an array of arrays + const chunks = ( + Array.isArray(result.chunks) && result.chunks.length > 0 + ? Array.isArray(result.chunks[0]) + ? result.chunks.flat() + : result.chunks + : [] + ) as ChunkNode[]; + for (const chunk of chunks) { + expect(chunk.doc_id).toBe(result.doc_id); + } + }); + + it("should generate consistent doc_id across chains", async () => { + const markdown = "# Test\n\nContent."; + + // Run twice with same content + const result1 = await new Workflow() + .structuralParser({ + text: markdown, + title: "Test", + sourceUri: "test.md", + }) + .run(); + + const result2 = await new Workflow() + .structuralParser({ + text: markdown, + title: "Test", + sourceUri: "test.md", + }) + .run(); + + // Should generate same doc_id (deterministic) + expect(result1.doc_id).toBe(result2.doc_id); + }); + + it("should allow doc_id override for variant creation", async () => { + const markdown = "# Test\n\nContent."; + const customId = await NodeIdGenerator.generateDocId("custom", markdown); + + const result = (await new Workflow() + .structuralParser({ + text: markdown, + title: "Test", + doc_id: customId, // Override with custom ID + }) + .hierarchicalChunker({ + maxTokens: 512, + }) + .run()) as HierarchicalChunkerTaskOutput; + + // Should use the provided ID + expect(result.doc_id).toBe(customId); + + // All chunks should reference it + for (const chunk of result.chunks) { + expect(chunk.doc_id).toBe(customId); + } + }); +}); diff --git a/packages/test/src/test/rag/HierarchicalChunker.test.ts b/packages/test/src/test/rag/HierarchicalChunker.test.ts new file mode 100644 index 00000000..c7f943dc --- /dev/null +++ b/packages/test/src/test/rag/HierarchicalChunker.test.ts @@ -0,0 +1,191 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { hierarchicalChunker } from "@workglow/ai"; +import { estimateTokens, NodeIdGenerator, StructuralParser } from "@workglow/storage"; +import { Workflow } from "@workglow/task-graph"; +import { describe, expect, it } 
from "vitest"; + +describe("HierarchicalChunkerTask", () => { + it("should chunk a simple document hierarchically", async () => { + const markdown = `# Section 1 + +This is a paragraph that should fit in one chunk. + +# Section 2 + +This is another paragraph.`; + + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + const result = await hierarchicalChunker({ + doc_id: doc_id, + documentTree: root, + maxTokens: 512, + overlap: 50, + strategy: "hierarchical", + }); + + expect(result.chunks).toBeDefined(); + expect(result.text).toBeDefined(); + expect(result.count).toBeGreaterThan(0); + expect(result.chunks.length).toBe(result.count); + expect(result.text.length).toBe(result.count); + + // Each chunk should have required fields + for (const chunk of result.chunks) { + expect(chunk.chunkId).toBeDefined(); + expect(chunk.doc_id).toBe(doc_id); + expect(chunk.text).toBeDefined(); + expect(chunk.nodePath).toBeDefined(); + expect(chunk.nodePath.length).toBeGreaterThan(0); + expect(chunk.depth).toBeGreaterThanOrEqual(0); + } + }); + + it("should respect token budgets", async () => { + // Create a long text that requires splitting + const longText = "Lorem ipsum dolor sit amet. ".repeat(100); + const markdown = `# Section\n\n${longText}`; + + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Long"); + + const maxTokens = 100; + const result = await hierarchicalChunker({ + doc_id, + documentTree: root, + maxTokens, + overlap: 10, + strategy: "hierarchical", + }); + + // Should create multiple chunks + expect(result.count).toBeGreaterThan(1); + + // Each chunk should respect token budget + for (const chunk of result.chunks) { + const tokens = estimateTokens(chunk.text); + expect(tokens).toBeLessThanOrEqual(maxTokens); + } + }); + + it("should create overlapping chunks", async () => { + const text = "Word ".repeat(200); + const markdown = `# Section\n\n${text}`; + + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Overlap"); + + const maxTokens = 50; + const overlap = 10; + const result = await hierarchicalChunker({ + doc_id, + documentTree: root, + maxTokens, + overlap, + strategy: "hierarchical", + }); + + // Should have multiple chunks + expect(result.count).toBeGreaterThan(1); + + // Check for overlap in text content + if (result.chunks.length > 1) { + const chunk0 = result.chunks[0].text; + const chunk1 = result.chunks[1].text; + + // Extract end of first chunk + const chunk0End = chunk0.substring(Math.max(0, chunk0.length - 50)); + // Check if beginning of second chunk overlaps + const hasOverlap = chunk1.includes(chunk0End.substring(0, 20)); + + expect(hasOverlap).toBe(true); + } + }); + + it("should handle flat strategy", async () => { + const markdown = `# Section 1 + +Paragraph 1. 
+ +# Section 2 + +Paragraph 2.`; + + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Flat"); + + const result = await new Workflow() + .hierarchicalChunker({ + doc_id, + documentTree: root, + maxTokens: 512, + overlap: 50, + strategy: "flat", + }) + .run(); + + // Flat strategy should still produce chunks + expect(result.count).toBeGreaterThan(0); + }); + + it("should maintain node paths in chunks", async () => { + const markdown = `# Section 1 + +## Subsection 1.1 + +Paragraph content.`; + + const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Paths"); + + const result = await hierarchicalChunker({ + doc_id, + documentTree: root, + maxTokens: 512, + overlap: 50, + strategy: "hierarchical", + }); + + // Check that chunks have node paths + for (const chunk of result.chunks) { + expect(chunk.nodePath).toBeDefined(); + expect(Array.isArray(chunk.nodePath)).toBe(true); + expect(chunk.nodePath.length).toBeGreaterThan(0); + + // First element should be root node ID + expect(chunk.nodePath[0]).toBe(root.nodeId); + } + }); +}); + +describe("Token estimation", () => { + it("should estimate tokens approximately", () => { + const text = "This is a test string"; + const tokens = estimateTokens(text); + + // Rough approximation: 1 token ~= 4 characters + const expected = Math.ceil(text.length / 4); + expect(tokens).toBe(expected); + }); + + it("should handle empty strings", () => { + const tokens = estimateTokens(""); + expect(tokens).toBe(0); + }); + + it("should increase token count with text length", () => { + const shortText = "Hello"; + const longText = "Hello world this is a much longer text"; + + const shortTokens = estimateTokens(shortText); + const longTokens = estimateTokens(longText); + + expect(longTokens).toBeGreaterThan(shortTokens); + }); +}); diff --git a/packages/test/src/test/rag/HybridSearchTask.test.ts b/packages/test/src/test/rag/HybridSearchTask.test.ts new file mode 100644 index 00000000..03ad874b --- /dev/null +++ b/packages/test/src/test/rag/HybridSearchTask.test.ts @@ -0,0 +1,278 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { hybridSearch } from "@workglow/ai"; +import { + InMemoryDocumentNodeVectorRepository, + registerDocumentNodeVectorRepository, +} from "@workglow/storage"; +import { afterEach, beforeEach, describe, expect, test } from "vitest"; + +describe("DocumentNodeVectorHybridSearchTask", () => { + let repo: InMemoryDocumentNodeVectorRepository; + + beforeEach(async () => { + repo = new InMemoryDocumentNodeVectorRepository(3); + await repo.setupDatabase(); + + // Populate repository with test data + const vectors = [ + new Float32Array([1.0, 0.0, 0.0]), // Similar vector, contains "machine" + new Float32Array([0.8, 0.2, 0.0]), // Somewhat similar, contains "learning" + new Float32Array([0.0, 1.0, 0.0]), // Different vector, contains "cooking" + new Float32Array([0.0, 0.0, 1.0]), // Different vector, contains "travel" + new Float32Array([0.9, 0.1, 0.0]), // Very similar, contains "artificial" + ]; + + const metadata = [ + { text: "Document about machine learning", category: "tech" }, + { text: "Document about deep learning algorithms", category: "tech" }, + { text: "Document about cooking recipes", category: "food" }, + { text: "Document about travel destinations", category: "travel" }, + { text: "Document about artificial 
intelligence", category: "tech" }, + ]; + + for (let i = 0; i < vectors.length; i++) { + const doc_id = `doc${i + 1}`; + await repo.put({ + id: `${doc_id}_0`, + doc_id, + vector: vectors[i] as any, + metadata: metadata[i], + } as any); + } + }); + + afterEach(() => { + repo.destroy(); + }); + + test("should perform hybrid search with vector and text query", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "machine learning"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 3, + }); + + expect(result.count).toBeGreaterThan(0); + expect(result.chunks).toHaveLength(result.count); + expect(result.ids).toHaveLength(result.count); + expect(result.metadata).toHaveLength(result.count); + expect(result.scores).toHaveLength(result.count); + + // Scores should be in descending order + for (let i = 1; i < result.scores.length; i++) { + expect(result.scores[i - 1]).toBeGreaterThanOrEqual(result.scores[i]); + } + }); + + test("should combine vector and text scores", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "machine"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 5, + }); + + // Results should be ranked by combined score + expect(result.scores.length).toBeGreaterThan(0); + result.scores.forEach((score) => { + expect(score).toBeGreaterThanOrEqual(0); + expect(score).toBeLessThanOrEqual(1); + }); + }); + + test("should respect vectorWeight parameter", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "learning"; + + // Test with high vector weight + const resultHighVector = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 5, + vectorWeight: 0.9, + }); + + // Test with low vector weight (high text weight) + const resultHighText = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 5, + vectorWeight: 0.1, + }); + + // Results might differ based on weight + expect(resultHighVector.count).toBeGreaterThan(0); + expect(resultHighText.count).toBeGreaterThan(0); + }); + + test("should return vectors when returnVectors is true", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "machine"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 3, + returnVectors: true, + }); + + expect(result.vectors).toBeDefined(); + expect(result.vectors).toHaveLength(result.count); + expect(result.vectors![0]).toBeInstanceOf(Float32Array); + }); + + test("should not return vectors when returnVectors is false", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "machine"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 3, + returnVectors: false, + }); + + expect(result.vectors).toBeUndefined(); + }); + + test("should apply metadata filter", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "learning"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 10, + filter: { category: "tech" }, + }); + + // All results should have category "tech" + result.metadata.forEach((meta) => { + expect(meta).toHaveProperty("category", "tech"); + }); + }); 
+ + test("should apply score threshold", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "machine"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 10, + scoreThreshold: 0.5, + }); + + // All scores should be >= threshold + result.scores.forEach((score) => { + expect(score).toBeGreaterThanOrEqual(0.5); + }); + }); + + test("should respect topK limit", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "document"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 2, + }); + + expect(result.count).toBeLessThanOrEqual(2); + expect(result.chunks).toHaveLength(result.count); + }); + + test("should handle default parameters", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "learning"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + }); + + // Default topK is 10, vectorWeight is 0.7 + expect(result.count).toBeGreaterThan(0); + expect(result.count).toBeLessThanOrEqual(10); + }); + + test("should extract chunks from metadata", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "machine"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 5, + }); + + // Chunks should match metadata text + result.chunks.forEach((chunk, idx) => { + expect(chunk).toBe(result.metadata[idx].text); + }); + }); + + test("should work with quantized query vectors", async () => { + const queryVector = new Int8Array([127, 0, 0]); + const queryText = "machine"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 3, + }); + + expect(result.count).toBeGreaterThan(0); + expect(result.chunks).toHaveLength(result.count); + }); + + test("should resolve repository from string ID", async () => { + // Register repository by ID + registerDocumentNodeVectorRepository("test-hybrid-repo", repo); + + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "machine learning"; + + // Pass repository as string ID instead of instance + const result = await hybridSearch({ + repository: "test-hybrid-repo" as any, + queryVector: queryVector, + queryText: queryText, + topK: 3, + }); + + expect(result.count).toBeGreaterThan(0); + expect(result.chunks).toHaveLength(result.count); + expect(result.ids).toHaveLength(result.count); + expect(result.metadata).toHaveLength(result.count); + expect(result.scores).toHaveLength(result.count); + + // Scores should be in descending order + for (let i = 1; i < result.scores.length; i++) { + expect(result.scores[i - 1]).toBeGreaterThanOrEqual(result.scores[i]); + } + }); +}); diff --git a/packages/test/src/test/rag/RagWorkflow.test.ts b/packages/test/src/test/rag/RagWorkflow.test.ts new file mode 100644 index 00000000..327ac19a --- /dev/null +++ b/packages/test/src/test/rag/RagWorkflow.test.ts @@ -0,0 +1,277 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * RAG (Retrieval Augmented Generation) Workflow End-to-End Test + * + * This test demonstrates a complete RAG pipeline using the Workflow API + * in a way that's compatible with visual node editors. 
+ * + * Node Editor Mapping: + * ==================== + * Each workflow step below represents a node in a visual editor with + * dataflow connections between them: + * + * 1. Document Ingestion Pipeline (per file): + * FileLoader → StructuralParser → DocumentEnricher → HierarchicalChunker + * → [Array Processing] → TextEmbedding (multiple) → ChunkToVector → VectorStoreUpsert + * + * Note: The array processing step (embedding multiple chunks) would use: + * - A "ForEach" or "Map" control node in the visual editor + * - Or an ArrayTask wrapper that replicates TextEmbedding nodes + * - Or a batch TextEmbedding node that accepts arrays + * + * 2. Semantic Search Pipeline: + * Query (input) → DocumentNodeRetrievalTask → Results (output) + * + * 3. Question Answering Pipeline: + * Question → DocumentNodeRetrievalTask → ContextBuilder → TextQuestionAnswerTask → Answer + * + * Models Used: + * - Xenova/all-MiniLM-L6-v2 (Text Embedding - 384D) + * - onnx-community/NeuroBERT-NER-ONNX (Named Entity Recognition) + * - Xenova/distilbert-base-uncased-distilled-squad (Question Answering) + */ + +import { + InMemoryModelRepository, + RetrievalTaskOutput, + setGlobalModelRepository, + TextQuestionAnswerTaskOutput, + VectorStoreUpsertTaskOutput, +} from "@workglow/ai"; +import { register_HFT_InlineJobFns } from "@workglow/ai-provider"; +import { + DocumentRepository, + DocumentStorageKey, + DocumentStorageSchema, + InMemoryDocumentNodeVectorRepository, + InMemoryTabularRepository, + registerDocumentNodeVectorRepository, +} from "@workglow/storage"; +import { getTaskQueueRegistry, setTaskQueueRegistry, Workflow } from "@workglow/task-graph"; +import { readdirSync } from "fs"; +import { join } from "path"; +import { afterAll, beforeAll, describe, expect, it } from "vitest"; +import { registerHuggingfaceLocalModels } from "../../samples"; +export { FileLoaderTask } from "@workglow/tasks"; + +describe("RAG Workflow End-to-End", () => { + let vectorRepo: InMemoryDocumentNodeVectorRepository; + let docRepo: DocumentRepository; + const vectorRepoName = "rag-test-vector-repo"; + const embeddingModel = "onnx:Xenova/all-MiniLM-L6-v2:q8"; + const summaryModel = "onnx:Falconsai/text_summarization:fp32"; + const nerModel = "onnx:onnx-community/NeuroBERT-NER-ONNX:q8"; + const qaModel = "onnx:Xenova/distilbert-base-uncased-distilled-squad:q8"; + + beforeAll(async () => { + // Setup task queue and model repository + setTaskQueueRegistry(null); + setGlobalModelRepository(new InMemoryModelRepository()); + await register_HFT_InlineJobFns(); + + await registerHuggingfaceLocalModels(); + + // Setup repositories + vectorRepo = new InMemoryDocumentNodeVectorRepository(3); + await vectorRepo.setupDatabase(); + + // Register vector repository for use in workflows + registerDocumentNodeVectorRepository(vectorRepoName, vectorRepo); + + const tabularRepo = new InMemoryTabularRepository(DocumentStorageSchema, DocumentStorageKey); + await tabularRepo.setupDatabase(); + + docRepo = new DocumentRepository(tabularRepo, vectorRepo); + }); + + afterAll(async () => { + getTaskQueueRegistry().stopQueues().clearQueues(); + setTaskQueueRegistry(null); + }); + + it.only("should ingest markdown documents with NER enrichment", async () => { + // Find markdown files in docs folder + const docsPath = join(process.cwd(), "docs", "background"); + const files = readdirSync(docsPath).filter((f) => f.endsWith(".md")); + + console.log(`Found ${files.length} markdown files to process`); + + let totalVectors = 0; + + for (const file of files) { + const 
filePath = join(docsPath, file); + console.log(`Processing: ${file}`); + + const ingestionWorkflow = new Workflow(); + + ingestionWorkflow + .fileLoader({ url: `file://${filePath}`, format: "markdown" }) + .structuralParser({ + title: filePath.split("/").pop()?.split(".")[0] || "", + format: "markdown", + sourceUri: filePath, + }) + .documentEnricher({ + generateSummaries: true, + extractEntities: true, + summaryModel, + nerModel, + }) + .hierarchicalChunker({ + maxTokens: 512, + overlap: 50, + strategy: "hierarchical", + }) + .textEmbedding({ + model: embeddingModel, + }) + .vectorStoreUpsert({ + repository: vectorRepoName, + }); + + const result = (await ingestionWorkflow.run()) as VectorStoreUpsertTaskOutput; + + console.log(` → Stored ${result.count} vectors`); + totalVectors += result.count; + } + + // Verify vectors were stored + expect(totalVectors).toBeGreaterThan(0); + console.log(`Total vectors in repository: ${totalVectors}`); + }, 360000); // 6 minute timeout for model downloads + + it("should search for relevant content", async () => { + const query = "What is retrieval augmented generation?"; + + console.log(`\nSearching for: "${query}"`); + + // Create search workflow + const searchWorkflow = new Workflow(); + + searchWorkflow.retrieval({ + repository: vectorRepoName, + query, + model: embeddingModel, + topK: 5, + scoreThreshold: 0.3, + }); + + const searchResult = (await searchWorkflow.run()) as RetrievalTaskOutput; + + // Verify search results + expect(searchResult.chunks).toBeDefined(); + expect(Array.isArray(searchResult.chunks)).toBe(true); + expect(searchResult.chunks.length).toBeGreaterThan(0); + expect(searchResult.chunks.length).toBeLessThanOrEqual(5); + expect(searchResult.scores).toBeDefined(); + expect(searchResult.scores!.length).toBe(searchResult.chunks.length); + + console.log(`Found ${searchResult.chunks.length} relevant chunks:`); + for (let i = 0; i < searchResult.chunks.length; i++) { + const chunk = searchResult.chunks[i]; + const score = searchResult.scores![i]; + console.log(` ${i + 1}. 
Score: ${score.toFixed(3)} - ${chunk.substring(0, 80)}...`); + } + + // Verify scores are in descending order + for (let i = 1; i < searchResult.scores!.length; i++) { + expect(searchResult.scores![i]).toBeLessThanOrEqual(searchResult.scores![i - 1]); + } + }, 60000); // 1 minute timeout + + it("should answer questions using retrieved context", async () => { + const question = "What is RAG?"; + + console.log(`\nAnswering question: "${question}"`); + + // Step 1: Retrieve relevant context + const retrievalWorkflow = new Workflow(); + + retrievalWorkflow.retrieval({ + repository: vectorRepoName, + query: question, + model: embeddingModel, + topK: 3, + scoreThreshold: 0.2, // Lower threshold to find results + }); + + const retrievalResult = (await retrievalWorkflow.run()) as RetrievalTaskOutput; + + expect(retrievalResult.chunks).toBeDefined(); + + if (retrievalResult.chunks.length === 0) { + console.log("No relevant chunks found, skipping QA"); + return; // Skip QA if no relevant context found + } + + console.log(`Retrieved ${retrievalResult.chunks.length} context chunks`); + + // Step 2: Build context from retrieved chunks + const context = retrievalResult.chunks.join("\n\n"); + + console.log(`Context length: ${context.length} characters`); + + // Step 3: Answer question using context + const qaWorkflow = new Workflow(); + + qaWorkflow.textQuestionAnswer({ + context, + question, + model: qaModel, + }); + + const answer = (await qaWorkflow.run()) as TextQuestionAnswerTaskOutput; + + // Verify answer + expect(answer.text).toBeDefined(); + expect(typeof answer.text).toBe("string"); + expect(answer.text.length).toBeGreaterThan(0); + + console.log(`\nAnswer: ${answer.text}`); + }, 60000); // 1 minute timeout + + it("should handle complex multi-step RAG pipeline", async () => { + const question = "How does vector search work?"; + + console.log(`\nComplex RAG pipeline for: "${question}"`); + + // Step 1: Retrieve context + const retrievalWorkflow = new Workflow(); + retrievalWorkflow.retrieval({ + repository: vectorRepoName, + query: question, + model: embeddingModel, + topK: 3, + scoreThreshold: 0.2, + }); + + const retrievalResult = (await retrievalWorkflow.run()) as RetrievalTaskOutput; + + if (retrievalResult.chunks.length === 0) { + console.log("No chunks found, skipping QA step"); + return; + } + + // Step 2: Answer question with retrieved context + const context = retrievalResult.chunks.join("\n\n"); + const qaWorkflow = new Workflow(); + qaWorkflow.textQuestionAnswer({ + context, + question, + model: qaModel, + }); + + const result = (await qaWorkflow.run()) as TextQuestionAnswerTaskOutput; + + expect(result.text).toBeDefined(); + expect(typeof result.text).toBe("string"); + expect(result.text.length).toBeGreaterThan(0); + + console.log(`Answer: ${result.text}`); + }, 60000); // 1 minute timeout +}); diff --git a/packages/test/src/test/rag/StructuralParser.test.ts b/packages/test/src/test/rag/StructuralParser.test.ts new file mode 100644 index 00000000..51d3b8e9 --- /dev/null +++ b/packages/test/src/test/rag/StructuralParser.test.ts @@ -0,0 +1,204 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { NodeIdGenerator, NodeKind, StructuralParser } from "@workglow/storage"; +import { describe, expect, it } from "vitest"; + +describe("StructuralParser", () => { + describe("Markdown parsing", () => { + it("should parse markdown with headers into hierarchical tree", async () => { + const markdown = `# Main Title + +This is the intro. 
+ +## Section 1 + +Content for section 1. + +## Section 2 + +Content for section 2. + +### Subsection 2.1 + +Nested content.`; + + const doc_id = "doc_test123"; + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test Document"); + + expect(root.kind).toBe(NodeKind.DOCUMENT); + expect(root.children.length).toBeGreaterThan(0); + + // Find sections - parser should create sections for headers + const sections = root.children.filter((child) => child.kind === NodeKind.SECTION); + expect(sections.length).toBeGreaterThan(0); + + // Should have some children (sections or paragraphs) + expect(root.children.length).toBeGreaterThanOrEqual(1); + }); + + it("should preserve source offsets", async () => { + const markdown = `# Title + +Paragraph one. + +Paragraph two.`; + + const doc_id = "doc_test456"; + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); + + expect(root.range.startOffset).toBe(0); + expect(root.range.endOffset).toBe(markdown.length); + + // Check children have valid offsets + for (const child of root.children) { + expect(child.range.startOffset).toBeGreaterThanOrEqual(0); + expect(child.range.endOffset).toBeLessThanOrEqual(markdown.length); + expect(child.range.endOffset).toBeGreaterThan(child.range.startOffset); + } + }); + + it("should handle nested sections correctly", async () => { + const markdown = `# Level 1 + +Content. + +## Level 2 + +More content. + +### Level 3 + +Deep content.`; + + const doc_id = "doc_test789"; + const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Nested Test"); + + // Find first section (Level 1) + const level1 = root.children.find( + (c) => c.kind === NodeKind.SECTION && (c as any).level === 1 + ); + expect(level1).toBeDefined(); + + // It should have children including level 2 + const level2 = (level1 as any).children.find( + (c: any) => c.kind === NodeKind.SECTION && c.level === 2 + ); + expect(level2).toBeDefined(); + + // Level 2 should have level 3 + const level3 = (level2 as any).children.find( + (c: any) => c.kind === NodeKind.SECTION && c.level === 3 + ); + expect(level3).toBeDefined(); + }); + }); + + describe("Plain text parsing", () => { + it("should parse plain text into paragraphs", async () => { + const text = `First paragraph here. + +Second paragraph here. 
+ +Third paragraph here.`; + + const doc_id = "doc_plain123"; + const root = await StructuralParser.parsePlainText(doc_id, text, "Plain Text"); + + expect(root.kind).toBe(NodeKind.DOCUMENT); + expect(root.children.length).toBe(3); + + for (const child of root.children) { + expect(child.kind).toBe(NodeKind.PARAGRAPH); + } + }); + + it("should handle single paragraph", async () => { + const text = "Just one paragraph."; + + const doc_id = "doc_plain456"; + const root = await StructuralParser.parsePlainText(doc_id, text, "Single"); + + expect(root.children.length).toBe(1); + expect(root.children[0].kind).toBe(NodeKind.PARAGRAPH); + expect(root.children[0].text).toBe(text); + }); + }); + + describe("Auto-detect", () => { + it("should auto-detect markdown", async () => { + const markdown = "# Header\n\nParagraph."; + const doc_id = "doc_auto123"; + + const root = await StructuralParser.parse(doc_id, markdown, "Auto"); + + // Should have detected markdown and created sections + const hasSection = root.children.some((c) => c.kind === NodeKind.SECTION); + expect(hasSection).toBe(true); + }); + + it("should default to plain text when no markdown markers", async () => { + const text = "Just plain text here."; + const doc_id = "doc_auto456"; + + const root = await StructuralParser.parse(doc_id, text, "Plain"); + + // Should be plain paragraph + expect(root.children[0].kind).toBe(NodeKind.PARAGRAPH); + }); + }); + + describe("NodeIdGenerator", () => { + it("should generate consistent docIds", async () => { + const id1 = await NodeIdGenerator.generateDocId("source1", "content"); + const id2 = await NodeIdGenerator.generateDocId("source1", "content"); + + expect(id1).toBe(id2); + expect(id1).toMatch(/^doc_[0-9a-f]{16}$/); + }); + + it("should generate different IDs for different content", async () => { + const id1 = await NodeIdGenerator.generateDocId("source", "content1"); + const id2 = await NodeIdGenerator.generateDocId("source", "content2"); + + expect(id1).not.toBe(id2); + }); + + it("should generate consistent structural node IDs", async () => { + const doc_id = "doc_test"; + const range = { startOffset: 0, endOffset: 100 }; + + const id1 = await NodeIdGenerator.generateStructuralNodeId(doc_id, NodeKind.SECTION, range); + const id2 = await NodeIdGenerator.generateStructuralNodeId(doc_id, NodeKind.SECTION, range); + + expect(id1).toBe(id2); + expect(id1).toMatch(/^node_[0-9a-f]{16}$/); + }); + + it("should generate consistent child node IDs", async () => { + const parentId = "node_parent"; + const ordinal = 2; + + const id1 = await NodeIdGenerator.generateChildNodeId(parentId, ordinal); + const id2 = await NodeIdGenerator.generateChildNodeId(parentId, ordinal); + + expect(id1).toBe(id2); + expect(id1).toMatch(/^node_[0-9a-f]{16}$/); + }); + + it("should generate consistent chunk IDs", async () => { + const doc_id = "doc_test"; + const leafNodeId = "node_leaf"; + const ordinal = 0; + + const id1 = await NodeIdGenerator.generateChunkId(doc_id, leafNodeId, ordinal); + const id2 = await NodeIdGenerator.generateChunkId(doc_id, leafNodeId, ordinal); + + expect(id1).toBe(id2); + expect(id1).toMatch(/^chunk_[0-9a-f]{16}$/); + }); + }); +}); diff --git a/packages/test/src/test/rag/TextChunkerTask.test.ts b/packages/test/src/test/rag/TextChunkerTask.test.ts new file mode 100644 index 00000000..425677b1 --- /dev/null +++ b/packages/test/src/test/rag/TextChunkerTask.test.ts @@ -0,0 +1,226 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ChunkingStrategy, textChunker 
} from "@workglow/ai"; +import { describe, expect, test } from "vitest"; + +describe("TextChunkerTask", () => { + const testText = + "This is the first sentence. This is the second sentence! This is the third sentence? " + + "This is the fourth sentence. This is the fifth sentence."; + + test("should chunk text with FIXED strategy", async () => { + const result = await textChunker({ + text: testText, + chunkSize: 50, + chunkOverlap: 10, + strategy: ChunkingStrategy.FIXED, + }); + + expect(result.chunks).toBeDefined(); + expect(result.chunks.length).toBeGreaterThan(0); + expect(result.metadata).toHaveLength(result.chunks.length); + + // Verify metadata structure + result.metadata.forEach((meta, idx) => { + expect(meta).toHaveProperty("index"); + expect(meta).toHaveProperty("startChar"); + expect(meta).toHaveProperty("endChar"); + expect(meta).toHaveProperty("length"); + expect(meta.index).toBe(idx); + }); + }); + + test("should chunk with SENTENCE strategy", async () => { + const result = await textChunker({ + text: testText, + chunkSize: 80, + chunkOverlap: 20, + strategy: ChunkingStrategy.SENTENCE, + }); + + expect(result.chunks.length).toBeGreaterThan(0); + expect(result.metadata).toHaveLength(result.chunks.length); + + // Chunks should respect sentence boundaries + result.chunks.forEach((chunk) => { + expect(chunk.length).toBeGreaterThan(0); + }); + }); + + test("should chunk with PARAGRAPH strategy", async () => { + const paragraphText = + "First paragraph with multiple sentences. It has more content.\n\n" + + "Second paragraph with different content. It also has sentences.\n\n" + + "Third paragraph is here. With more text."; + + const result = await textChunker({ + text: paragraphText, + chunkSize: 100, + chunkOverlap: 20, + strategy: ChunkingStrategy.PARAGRAPH, + }); + + expect(result.chunks.length).toBeGreaterThan(0); + expect(result.metadata).toHaveLength(result.chunks.length); + }); + + test("should handle default parameters", async () => { + const result = await textChunker({ + text: testText, + }); + + // Default: chunkSize=512, chunkOverlap=50, strategy=FIXED + expect(result.chunks).toBeDefined(); + expect(result.chunks.length).toBeGreaterThan(0); + expect(result.metadata).toHaveLength(result.chunks.length); + }); + + test("should handle chunkOverlap correctly", async () => { + const shortText = "A".repeat(100); // 100 characters + const result = await textChunker({ + text: shortText, + chunkSize: 30, + chunkOverlap: 10, + strategy: ChunkingStrategy.FIXED, + }); + + // With chunkSize=30 and overlap=10, we move forward by 20 each time + // Should have multiple chunks + expect(result.chunks.length).toBeGreaterThan(1); + + // Verify overlap by checking that chunks share content + if (result.chunks.length > 1) { + const firstChunkEnd = result.chunks[0].slice(-10); + const secondChunkStart = result.chunks[1].slice(0, 10); + // There should be some overlap + expect(firstChunkEnd).toBe(secondChunkStart); + } + }); + + test("should handle zero overlap", async () => { + const result = await textChunker({ + text: testText, + chunkSize: 50, + chunkOverlap: 0, + strategy: ChunkingStrategy.FIXED, + }); + + expect(result.chunks.length).toBeGreaterThan(0); + // With zero overlap, chunks should be adjacent + result.metadata.forEach((meta, idx) => { + if (idx > 0) { + const prevMeta = result.metadata[idx - 1]; + expect(meta.startChar).toBe(prevMeta.endChar); + } + }); + }); + + test("should handle text shorter than chunkSize", async () => { + const shortText = "Short text"; + const result = 
await textChunker({ + text: shortText, + chunkSize: 100, + chunkOverlap: 10, + }); + + expect(result.chunks.length).toBe(1); + expect(result.chunks[0]).toBe(shortText); + expect(result.metadata[0].length).toBe(shortText.length); + }); + + test("should handle empty text", async () => { + const result = await textChunker({ + text: "", + chunkSize: 50, + }); + + // Empty text should produce empty chunks or handle gracefully + expect(result.chunks).toBeDefined(); + expect(result.metadata).toBeDefined(); + }); + + test("should include all text in chunks (no loss)", async () => { + const result = await textChunker({ + text: testText, + chunkSize: 50, + chunkOverlap: 10, + strategy: ChunkingStrategy.FIXED, + }); + + // Reconstruct text from chunks (accounting for overlap) + const totalChars = result.chunks.reduce((sum, chunk) => sum + chunk.length, 0); + // With overlap, total should be >= original length + expect(totalChars).toBeGreaterThanOrEqual(testText.length); + }); + + test("should handle SEMANTIC strategy (currently same as sentence)", async () => { + const result = await textChunker({ + text: testText, + chunkSize: 80, + chunkOverlap: 20, + strategy: ChunkingStrategy.SEMANTIC, + }); + + expect(result.chunks.length).toBeGreaterThan(0); + expect(result.metadata).toHaveLength(result.chunks.length); + }); + + test("should preserve chunk order", async () => { + const result = await textChunker({ + text: testText, + chunkSize: 50, + chunkOverlap: 10, + }); + + // Metadata indices should be sequential + result.metadata.forEach((meta, idx) => { + expect(meta.index).toBe(idx); + }); + + // Start positions should be in order + for (let i = 1; i < result.metadata.length; i++) { + expect(result.metadata[i].startChar).toBeGreaterThanOrEqual( + result.metadata[i - 1].startChar! 
+ ); + } + }); + + test("should handle very large chunkSize", async () => { + const result = await textChunker({ + text: testText, + chunkSize: 10000, + chunkOverlap: 0, + }); + + // Should produce single chunk + expect(result.chunks.length).toBe(1); + expect(result.chunks[0]).toBe(testText); + }); + + test("should handle overlap equal to chunkSize (edge case)", async () => { + // This should be handled to prevent infinite loops + const result = await textChunker({ + text: testText, + chunkSize: 50, + chunkOverlap: 50, + }); + + expect(result.chunks.length).toBeGreaterThan(0); + expect(result.metadata.length).toBe(result.chunks.length); + }); + + test("should handle overlap greater than chunkSize (edge case)", async () => { + // Should handle gracefully + const result = await textChunker({ + text: testText, + chunkSize: 30, + chunkOverlap: 50, + }); + + expect(result.chunks.length).toBeGreaterThan(0); + }); +}); diff --git a/packages/test/src/test/rag/VectorQuantizeTask.test.ts b/packages/test/src/test/rag/VectorQuantizeTask.test.ts new file mode 100644 index 00000000..f88dab04 --- /dev/null +++ b/packages/test/src/test/rag/VectorQuantizeTask.test.ts @@ -0,0 +1,228 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { vectorQuantize } from "@workglow/ai"; +import { TensorType } from "@workglow/util"; +import { describe, expect, test } from "vitest"; + +describe("VectorQuantizeTask", () => { + const testVector = new Float32Array([0.5, -0.5, 0.8, -0.3, 0.0, 1.0, -1.0]); + + test("should quantize to INT8", async () => { + const result = await vectorQuantize({ + vector: testVector, + targetType: TensorType.INT8, + normalize: false, + }); + + expect(result).toBeDefined(); + expect(result.vector).toBeInstanceOf(Int8Array); + expect(result.originalType).toBe(TensorType.FLOAT32); + expect(result.targetType).toBe(TensorType.INT8); + + const quantized = result.vector as Int8Array; + expect(quantized.length).toBe(testVector.length); + // Values should be scaled to [-127, 127] + expect(quantized[0]).toBe(64); // 0.5 * 127 ≈ 64 + expect(quantized[1]).toBe(-63); // -0.5 * 127 ≈ -63 (rounded) + }); + + test("should quantize to UINT8", async () => { + const result = await vectorQuantize({ + vector: testVector, + targetType: TensorType.UINT8, + normalize: false, + }); + + expect(result).toBeDefined(); + expect(result.vector).toBeInstanceOf(Uint8Array); + expect(result.targetType).toBe(TensorType.UINT8); + + const quantized = result.vector as Uint8Array; + expect(quantized.length).toBe(testVector.length); + // Values should be scaled to [0, 255] + expect(quantized.every((v) => v >= 0 && v <= 255)).toBe(true); + }); + + test("should quantize to INT16", async () => { + const result = await vectorQuantize({ + vector: testVector, + targetType: TensorType.INT16, + normalize: false, + }); + + expect(result).toBeDefined(); + expect(result.vector).toBeInstanceOf(Int16Array); + expect(result.targetType).toBe(TensorType.INT16); + + const quantized = result.vector as Int16Array; + expect(quantized.length).toBe(testVector.length); + // Values should be scaled to [-32767, 32767] + expect(quantized[0]).toBeCloseTo(16384, -2); // 0.5 * 32767 + }); + + test("should quantize to UINT16", async () => { + const result = await vectorQuantize({ + vector: testVector, + targetType: TensorType.UINT16, + normalize: false, + }); + + expect(result).toBeDefined(); + expect(result.vector).toBeInstanceOf(Uint16Array); + expect(result.targetType).toBe(TensorType.UINT16); + + const 
quantized = result.vector as Uint16Array; + expect(quantized.length).toBe(testVector.length); + // Values should be scaled to [0, 65535] + expect(quantized.every((v) => v >= 0 && v <= 65535)).toBe(true); + }); + + test("should quantize to FLOAT16", async () => { + const result = await vectorQuantize({ + vector: testVector, + targetType: TensorType.FLOAT16, + normalize: false, + }); + + expect(result).toBeDefined(); + expect(result.vector).toBeInstanceOf(Float16Array); + expect(result.targetType).toBe(TensorType.FLOAT16); + + const quantized = result.vector as Float16Array; + expect(quantized.length).toBe(testVector.length); + }); + + test("should quantize to FLOAT64", async () => { + const result = await vectorQuantize({ + vector: testVector, + targetType: TensorType.FLOAT64, + normalize: false, + }); + + expect(result).toBeDefined(); + expect(result.vector).toBeInstanceOf(Float64Array); + expect(result.targetType).toBe(TensorType.FLOAT64); + + const quantized = result.vector as Float64Array; + expect(quantized.length).toBe(testVector.length); + }); + + test("should handle normalization", async () => { + const unnormalizedVector = new Float32Array([1, 2, 3, 4, 5]); + + const result = await vectorQuantize({ + vector: unnormalizedVector, + targetType: TensorType.INT8, + normalize: true, + }); + + expect(result).toBeDefined(); + expect(result.vector).toBeInstanceOf(Int8Array); + + // With normalization, values should be normalized before quantization + const quantized = result.vector as Int8Array; + expect(quantized.length).toBe(unnormalizedVector.length); + }); + + test("should handle array of vectors", async () => { + const vectors = [ + new Float32Array([0.5, -0.5, 0.8]), + new Float32Array([0.1, 0.2, 0.3]), + new Float32Array([-0.4, -0.5, -0.6]), + ]; + + const result = await vectorQuantize({ + vector: vectors, + targetType: TensorType.INT8, + normalize: false, + }); + + expect(result).toBeDefined(); + expect(Array.isArray(result.vector)).toBe(true); + + const quantizedVectors = result.vector as Int8Array[]; + expect(quantizedVectors.length).toBe(3); + quantizedVectors.forEach((v, idx) => { + expect(v).toBeInstanceOf(Int8Array); + expect(v.length).toBe(vectors[idx].length); + }); + }); + + test("should preserve dimensions when quantizing", async () => { + const largeVector = new Float32Array(384).map(() => Math.random() * 2 - 1); + + const result = await vectorQuantize({ + vector: largeVector, + targetType: TensorType.INT8, + normalize: true, + }); + + expect(result).toBeDefined(); + const quantized = result.vector as Int8Array; + expect(quantized.length).toBe(largeVector.length); + }); + + test("should handle edge cases in INT8 quantization", async () => { + const edgeVector = new Float32Array([1.0, -1.0, 1.5, -1.5, 0.0]); + + const result = await vectorQuantize({ + vector: edgeVector, + targetType: TensorType.INT8, + normalize: false, + }); + + const quantized = result.vector as Int8Array; + // Values clamped to [-1, 1] before scaling + expect(quantized[0]).toBe(127); // 1.0 * 127 + expect(quantized[1]).toBe(-127); // -1.0 * 127 + expect(quantized[2]).toBe(127); // 1.5 clamped to 1.0 + expect(quantized[3]).toBe(-127); // -1.5 clamped to -1.0 + expect(quantized[4]).toBe(0); // 0.0 + }); + + test("should detect original vector type", async () => { + const int8Vector = new Int8Array([10, 20, 30, 40]); + + const result = await vectorQuantize({ + vector: int8Vector, + targetType: TensorType.FLOAT32, + normalize: false, + }); + + expect(result.originalType).toBe(TensorType.INT8); + 
expect(result.targetType).toBe(TensorType.FLOAT32); + expect(result.vector).toBeInstanceOf(Float32Array); + }); + + test("should handle different typed arrays as input", async () => { + const testCases = [ + { input: new Float16Array([0.5, -0.5]), expected: TensorType.FLOAT16 }, + { input: new Float32Array([0.5, -0.5]), expected: TensorType.FLOAT32 }, + { input: new Float64Array([0.5, -0.5]), expected: TensorType.FLOAT64 }, + { input: new Int8Array([10, -10]), expected: TensorType.INT8 }, + { input: new Uint8Array([10, 20]), expected: TensorType.UINT8 }, + { input: new Int16Array([100, -100]), expected: TensorType.INT16 }, + { input: new Uint16Array([100, 200]), expected: TensorType.UINT16 }, + ]; + + for (const testCase of testCases) { + const result = await vectorQuantize({ + vector: testCase.input, + targetType: TensorType.FLOAT32, + normalize: false, + }); + expect(result.originalType).toBe(testCase.expected); + } + }); + + test("should use default normalize value of true", async () => { + const result = await vectorQuantize({ vector: testVector, targetType: TensorType.INT8 }); + + expect(result).toBeDefined(); + expect(result.vector).toBeInstanceOf(Int8Array); + }); +}); diff --git a/packages/test/src/test/util/VectorSimilarityUtils.test.ts b/packages/test/src/test/util/VectorSimilarityUtils.test.ts new file mode 100644 index 00000000..01408040 --- /dev/null +++ b/packages/test/src/test/util/VectorSimilarityUtils.test.ts @@ -0,0 +1,390 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + cosineSimilarity, + jaccardSimilarity, + hammingDistance, + hammingSimilarity, +} from "@workglow/util"; +import { describe, expect, test } from "vitest"; + +describe("VectorSimilarityUtils", () => { + describe("cosineSimilarity", () => { + test("should calculate cosine similarity for identical vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([1, 2, 3, 4]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should calculate cosine similarity for orthogonal vectors", () => { + const a = new Float32Array([1, 0, 0]); + const b = new Float32Array([0, 1, 0]); + expect(cosineSimilarity(a, b)).toBeCloseTo(0.0, 5); + }); + + test("should calculate cosine similarity for opposite vectors", () => { + const a = new Float32Array([1, 2, 3]); + const b = new Float32Array([-1, -2, -3]); + expect(cosineSimilarity(a, b)).toBeCloseTo(-1.0, 5); + }); + + test("should handle zero vectors", () => { + const a = new Float32Array([0, 0, 0]); + const b = new Float32Array([1, 2, 3]); + expect(cosineSimilarity(a, b)).toBe(0); + }); + + test("should handle both zero vectors", () => { + const a = new Float32Array([0, 0, 0]); + const b = new Float32Array([0, 0, 0]); + expect(cosineSimilarity(a, b)).toBe(0); + }); + + test("should work with Int8Array", () => { + const a = new Int8Array([10, 20, 30]); + const b = new Int8Array([10, 20, 30]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should work with Uint8Array", () => { + const a = new Uint8Array([10, 20, 30]); + const b = new Uint8Array([10, 20, 30]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should work with Int16Array", () => { + const a = new Int16Array([100, 200, 300]); + const b = new Int16Array([100, 200, 300]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should work with Uint16Array", () => { + const a = new Uint16Array([100, 200, 300]); + const b = new Uint16Array([100, 200, 
300]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should work with Float64Array", () => { + const a = new Float64Array([1.5, 2.5, 3.5]); + const b = new Float64Array([1.5, 2.5, 3.5]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should calculate cosine similarity for partially similar vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([2, 3, 4, 5]); + const result = cosineSimilarity(a, b); + expect(result).toBeGreaterThan(0.9); + expect(result).toBeLessThan(1.0); + }); + + test("should throw error for mismatched vector lengths", () => { + const a = new Float32Array([1, 2, 3]); + const b = new Float32Array([1, 2]); + expect(() => cosineSimilarity(a, b)).toThrow("Vectors must have the same length"); + }); + + test("should handle negative values correctly", () => { + const a = new Float32Array([-1, -2, -3]); + const b = new Float32Array([-1, -2, -3]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should handle mixed positive and negative values", () => { + const a = new Float32Array([1, -2, 3, -4]); + const b = new Float32Array([1, -2, 3, -4]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should handle large vectors", () => { + const size = 1000; + const a = new Float32Array(size).fill(1); + const b = new Float32Array(size).fill(1); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + }); + + describe("jaccardSimilarity", () => { + test("should calculate Jaccard similarity for identical vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([1, 2, 3, 4]); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should calculate Jaccard similarity for completely different vectors", () => { + const a = new Float32Array([5, 5, 5]); + const b = new Float32Array([1, 1, 1]); + const result = jaccardSimilarity(a, b); + expect(result).toBeGreaterThan(0); + expect(result).toBeLessThan(1); + }); + + test("should handle zero vectors", () => { + const a = new Float32Array([0, 0, 0]); + const b = new Float32Array([1, 2, 3]); + expect(jaccardSimilarity(a, b)).toBe(0); + }); + + test("should handle both zero vectors", () => { + const a = new Float32Array([0, 0, 0]); + const b = new Float32Array([0, 0, 0]); + expect(jaccardSimilarity(a, b)).toBe(0); + }); + + test("should work with Int8Array", () => { + const a = new Int8Array([10, 20, 30]); + const b = new Int8Array([10, 20, 30]); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should work with Uint8Array", () => { + const a = new Uint8Array([10, 20, 30]); + const b = new Uint8Array([10, 20, 30]); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should work with Int16Array", () => { + const a = new Int16Array([100, 200, 300]); + const b = new Int16Array([100, 200, 300]); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should work with Uint16Array", () => { + const a = new Uint16Array([100, 200, 300]); + const b = new Uint16Array([100, 200, 300]); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should work with Float64Array", () => { + const a = new Float64Array([1.5, 2.5, 3.5]); + const b = new Float64Array([1.5, 2.5, 3.5]); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should calculate correct similarity for partially overlapping vectors", () => { + const a = new Float32Array([1, 2, 3]); + const b = new Float32Array([2, 3, 4]); + const result = 
jaccardSimilarity(a, b); + expect(result).toBeGreaterThan(0); + expect(result).toBeLessThan(1); + }); + + test("should throw error for mismatched vector lengths", () => { + const a = new Float32Array([1, 2, 3]); + const b = new Float32Array([1, 2]); + expect(() => jaccardSimilarity(a, b)).toThrow("Vectors must have the same length"); + }); + + test("should handle all positive values", () => { + const a = new Float32Array([1, 2, 3]); + const b = new Float32Array([1, 2, 3]); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should handle negative values by using min/max", () => { + const a = new Float32Array([-1, -2, -3]); + const b = new Float32Array([-2, -3, -4]); + const result = jaccardSimilarity(a, b); + expect(result).toBeGreaterThan(0); + expect(result).toBeLessThan(1); + }); + }); + + describe("hammingDistance", () => { + test("should calculate Hamming distance for identical vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([1, 2, 3, 4]); + expect(hammingDistance(a, b)).toBe(0); + }); + + test("should calculate Hamming distance for completely different vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([5, 6, 7, 8]); + expect(hammingDistance(a, b)).toBe(1.0); + }); + + test("should calculate Hamming distance for partially different vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([1, 2, 5, 6]); + expect(hammingDistance(a, b)).toBe(0.5); + }); + + test("should handle zero vectors", () => { + const a = new Float32Array([0, 0, 0]); + const b = new Float32Array([0, 0, 0]); + expect(hammingDistance(a, b)).toBe(0); + }); + + test("should work with Int8Array", () => { + const a = new Int8Array([10, 20, 30]); + const b = new Int8Array([10, 20, 30]); + expect(hammingDistance(a, b)).toBe(0); + }); + + test("should work with Uint8Array", () => { + const a = new Uint8Array([10, 20, 30]); + const b = new Uint8Array([10, 20, 40]); + expect(hammingDistance(a, b)).toBeCloseTo(1 / 3, 5); + }); + + test("should work with Int16Array", () => { + const a = new Int16Array([100, 200, 300]); + const b = new Int16Array([100, 200, 300]); + expect(hammingDistance(a, b)).toBe(0); + }); + + test("should work with Uint16Array", () => { + const a = new Uint16Array([100, 200, 300]); + const b = new Uint16Array([100, 200, 300]); + expect(hammingDistance(a, b)).toBe(0); + }); + + test("should work with Float64Array", () => { + const a = new Float64Array([1.5, 2.5, 3.5]); + const b = new Float64Array([1.5, 2.5, 3.5]); + expect(hammingDistance(a, b)).toBe(0); + }); + + test("should throw error for mismatched vector lengths", () => { + const a = new Float32Array([1, 2, 3]); + const b = new Float32Array([1, 2]); + expect(() => hammingDistance(a, b)).toThrow("Vectors must have the same length"); + }); + + test("should handle negative values", () => { + const a = new Float32Array([-1, -2, -3]); + const b = new Float32Array([-1, -2, -3]); + expect(hammingDistance(a, b)).toBe(0); + }); + + test("should distinguish between close but not equal values", () => { + const a = new Float32Array([1.0, 2.0, 3.0]); + const b = new Float32Array([1.0001, 2.0, 3.0]); + expect(hammingDistance(a, b)).toBeCloseTo(1 / 3, 5); + }); + + test("should normalize distance by vector length", () => { + const a = new Float32Array([1, 2, 3, 4, 5, 6, 7, 8]); + const b = new Float32Array([1, 2, 3, 4, 9, 10, 11, 12]); + expect(hammingDistance(a, b)).toBe(0.5); + }); + }); + + describe("hammingSimilarity", () => { 
+ test("should calculate Hamming similarity for identical vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([1, 2, 3, 4]); + expect(hammingSimilarity(a, b)).toBe(1.0); + }); + + test("should calculate Hamming similarity for completely different vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([5, 6, 7, 8]); + expect(hammingSimilarity(a, b)).toBe(0); + }); + + test("should calculate Hamming similarity for partially different vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([1, 2, 5, 6]); + expect(hammingSimilarity(a, b)).toBe(0.5); + }); + + test("should be inverse of Hamming distance", () => { + const a = new Float32Array([1, 2, 3, 4, 5]); + const b = new Float32Array([1, 6, 3, 8, 5]); + const distance = hammingDistance(a, b); + const similarity = hammingSimilarity(a, b); + expect(similarity).toBeCloseTo(1 - distance, 5); + }); + + test("should work with Int8Array", () => { + const a = new Int8Array([10, 20, 30]); + const b = new Int8Array([10, 20, 30]); + expect(hammingSimilarity(a, b)).toBe(1.0); + }); + + test("should work with Uint8Array", () => { + const a = new Uint8Array([10, 20, 30]); + const b = new Uint8Array([10, 20, 40]); + expect(hammingSimilarity(a, b)).toBeCloseTo(2 / 3, 5); + }); + + test("should work with Int16Array", () => { + const a = new Int16Array([100, 200, 300]); + const b = new Int16Array([100, 200, 300]); + expect(hammingSimilarity(a, b)).toBe(1.0); + }); + + test("should work with Uint16Array", () => { + const a = new Uint16Array([100, 200, 300]); + const b = new Uint16Array([100, 200, 300]); + expect(hammingSimilarity(a, b)).toBe(1.0); + }); + + test("should work with Float64Array", () => { + const a = new Float64Array([1.5, 2.5, 3.5]); + const b = new Float64Array([1.5, 2.5, 3.5]); + expect(hammingSimilarity(a, b)).toBe(1.0); + }); + + test("should throw error for mismatched vector lengths", () => { + const a = new Float32Array([1, 2, 3]); + const b = new Float32Array([1, 2]); + expect(() => hammingSimilarity(a, b)).toThrow("Vectors must have the same length"); + }); + + test("should handle zero vectors", () => { + const a = new Float32Array([0, 0, 0]); + const b = new Float32Array([0, 0, 0]); + expect(hammingSimilarity(a, b)).toBe(1.0); + }); + }); + + describe("Edge cases and cross-function consistency", () => { + test("should handle single element vectors", () => { + const a = new Float32Array([5]); + const b = new Float32Array([5]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + expect(hammingDistance(a, b)).toBe(0); + expect(hammingSimilarity(a, b)).toBe(1.0); + }); + + test("should handle empty vectors", () => { + const a = new Float32Array([]); + const b = new Float32Array([]); + // For empty vectors, the functions should handle them gracefully + expect(hammingDistance(a, b)).toBeNaN(); // 0/0 + expect(hammingSimilarity(a, b)).toBeNaN(); + }); + + test("should handle very small values", () => { + const a = new Float32Array([0.0001, 0.0002, 0.0003]); + const b = new Float32Array([0.0001, 0.0002, 0.0003]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should handle very large values", () => { + const a = new Float32Array([10000, 20000, 30000]); + const b = new Float32Array([10000, 20000, 30000]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + expect(jaccardSimilarity(a, 
b)).toBeCloseTo(1.0, 5); + }); + + test("all functions should throw same error for length mismatch", () => { + const a = new Float32Array([1, 2, 3]); + const b = new Float32Array([1, 2]); + const errorMessage = "Vectors must have the same length"; + + expect(() => cosineSimilarity(a, b)).toThrow(errorMessage); + expect(() => jaccardSimilarity(a, b)).toThrow(errorMessage); + expect(() => hammingDistance(a, b)).toThrow(errorMessage); + expect(() => hammingSimilarity(a, b)).toThrow(errorMessage); + }); + }); +}); diff --git a/packages/test/src/test/util/VectorUtils.test.ts b/packages/test/src/test/util/VectorUtils.test.ts new file mode 100644 index 00000000..135ef740 --- /dev/null +++ b/packages/test/src/test/util/VectorUtils.test.ts @@ -0,0 +1,382 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { inner, magnitude, normalize, normalizeNumberArray } from "@workglow/util"; +import { describe, expect, test } from "vitest"; + +describe("VectorUtils", () => { + describe("magnitude", () => { + test("should calculate magnitude for Float32Array", () => { + const vector = new Float32Array([3, 4]); + const result = magnitude(vector); + expect(result).toBe(5); + }); + + test("should calculate magnitude for Float64Array", () => { + const vector = new Float64Array([1, 2, 2]); + const result = magnitude(vector); + expect(result).toBe(3); + }); + + test("should calculate magnitude for Int8Array", () => { + const vector = new Int8Array([6, 8]); + const result = magnitude(vector); + expect(result).toBe(10); + }); + + test("should calculate magnitude for Uint8Array", () => { + const vector = new Uint8Array([5, 12]); + const result = magnitude(vector); + expect(result).toBe(13); + }); + + test("should calculate magnitude for Int16Array", () => { + const vector = new Int16Array([3, 4]); + const result = magnitude(vector); + expect(result).toBe(5); + }); + + test("should calculate magnitude for Uint16Array", () => { + const vector = new Uint16Array([8, 15]); + const result = magnitude(vector); + expect(result).toBe(17); + }); + + test("should calculate magnitude for Float16Array", () => { + const vector = new Float16Array([3, 4]); + const result = magnitude(vector); + expect(result).toBeCloseTo(5, 1); + }); + + test("should calculate magnitude for number array", () => { + const vector = [3, 4]; + const result = magnitude(vector); + expect(result).toBe(5); + }); + + test("should return 0 for zero vector", () => { + const vector = new Float32Array([0, 0, 0]); + const result = magnitude(vector); + expect(result).toBe(0); + }); + + test("should handle single element vector", () => { + const vector = new Float32Array([5]); + const result = magnitude(vector); + expect(result).toBe(5); + }); + + test("should handle negative values", () => { + const vector = new Float32Array([-3, -4]); + const result = magnitude(vector); + expect(result).toBe(5); + }); + + test("should handle mixed positive and negative values", () => { + const vector = new Float32Array([3, -4]); + const result = magnitude(vector); + expect(result).toBe(5); + }); + + test("should handle large vectors", () => { + const vector = new Float32Array(1000).fill(1); + const result = magnitude(vector); + expect(result).toBeCloseTo(Math.sqrt(1000), 5); + }); + }); + + describe("inner", () => { + test("should calculate dot product for Float32Array", () => { + const arr1 = new Float32Array([1, 2, 3]); + const arr2 = new Float32Array([4, 5, 6]); + const result = inner(arr1, arr2); + 
expect(result).toBe(32); // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32 + }); + + test("should calculate dot product for Float64Array", () => { + const arr1 = new Float64Array([2, 3]); + const arr2 = new Float64Array([4, 5]); + const result = inner(arr1, arr2); + expect(result).toBe(23); // 2*4 + 3*5 = 8 + 15 = 23 + }); + + test("should calculate dot product for Int8Array", () => { + const arr1 = new Int8Array([1, 2, 3]); + const arr2 = new Int8Array([4, 5, 6]); + const result = inner(arr1, arr2); + expect(result).toBe(32); + }); + + test("should calculate dot product for Uint8Array", () => { + const arr1 = new Uint8Array([1, 2, 3]); + const arr2 = new Uint8Array([4, 5, 6]); + const result = inner(arr1, arr2); + expect(result).toBe(32); + }); + + test("should calculate dot product for Int16Array", () => { + const arr1 = new Int16Array([10, 20]); + const arr2 = new Int16Array([5, 3]); + const result = inner(arr1, arr2); + expect(result).toBe(110); // 10*5 + 20*3 = 50 + 60 = 110 + }); + + test("should calculate dot product for Uint16Array", () => { + const arr1 = new Uint16Array([10, 20]); + const arr2 = new Uint16Array([5, 3]); + const result = inner(arr1, arr2); + expect(result).toBe(110); + }); + + test("should calculate dot product for Float16Array", () => { + const arr1 = new Float16Array([1, 2, 3]); + const arr2 = new Float16Array([4, 5, 6]); + const result = inner(arr1, arr2); + expect(result).toBeCloseTo(32, 0); + }); + + test("should return 0 for zero vectors", () => { + const arr1 = new Float32Array([0, 0, 0]); + const arr2 = new Float32Array([1, 2, 3]); + const result = inner(arr1, arr2); + expect(result).toBe(0); + }); + + test("should handle orthogonal vectors", () => { + const arr1 = new Float32Array([1, 0, 0]); + const arr2 = new Float32Array([0, 1, 0]); + const result = inner(arr1, arr2); + expect(result).toBe(0); + }); + + test("should handle negative values", () => { + const arr1 = new Float32Array([-1, -2, -3]); + const arr2 = new Float32Array([4, 5, 6]); + const result = inner(arr1, arr2); + expect(result).toBe(-32); // -1*4 + -2*5 + -3*6 = -4 - 10 - 18 = -32 + }); + + test("should handle single element vectors", () => { + const arr1 = new Float32Array([5]); + const arr2 = new Float32Array([3]); + const result = inner(arr1, arr2); + expect(result).toBe(15); + }); + + test("should handle large vectors", () => { + const size = 1000; + const arr1 = new Float32Array(size).fill(1); + const arr2 = new Float32Array(size).fill(2); + const result = inner(arr1, arr2); + expect(result).toBe(2000); + }); + }); + + describe("normalize", () => { + test("should normalize Float32Array to unit length", () => { + const vector = new Float32Array([3, 4]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Float32Array); + expect(result.length).toBe(2); + expect(result[0]).toBeCloseTo(0.6, 5); + expect(result[1]).toBeCloseTo(0.8, 5); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should normalize Float64Array to unit length", () => { + const vector = new Float64Array([3, 4]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Float64Array); + expect(result[0]).toBeCloseTo(0.6, 5); + expect(result[1]).toBeCloseTo(0.8, 5); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should normalize Int8Array to unit length", () => { + const vector = new Int8Array([3, 4]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Int8Array); + expect(result.length).toBe(2); + // Int8Array will truncate the decimal values (0.6, 0.8 -> 0, 0) + 
// So magnitude will be 0, which is expected behavior for integer arrays + expect(magnitude(result)).toBe(0); + }); + + test("should normalize Uint8Array to unit length", () => { + const vector = new Uint8Array([3, 4]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Uint8Array); + expect(result.length).toBe(2); + // Uint8Array will truncate the decimal values (0.6, 0.8 -> 0, 0) + // So magnitude will be 0, which is expected behavior for integer arrays + expect(magnitude(result)).toBe(0); + }); + + test("should normalize Int16Array to unit length", () => { + const vector = new Int16Array([3, 4]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Int16Array); + expect(result.length).toBe(2); + // Int16Array will truncate the decimal values (0.6, 0.8 -> 0, 0) + // So magnitude will be 0, which is expected behavior for integer arrays + expect(magnitude(result)).toBe(0); + }); + + test("should normalize Uint16Array to unit length", () => { + const vector = new Uint16Array([3, 4]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Uint16Array); + expect(result.length).toBe(2); + // Uint16Array will truncate the decimal values (0.6, 0.8 -> 0, 0) + // So magnitude will be 0, which is expected behavior for integer arrays + expect(magnitude(result)).toBe(0); + }); + + test("should normalize Float16Array and convert to Float32Array", () => { + const vector = new Float16Array([3, 4]); + const result = normalize(vector, true, true); + // For Float16Array, the function should return Float32Array + expect(result).toBeInstanceOf(Float32Array); + expect(result[0]).toBeCloseTo(0.6, 1); + expect(result[1]).toBeCloseTo(0.8, 1); + }); + + test("should throw error for zero vector by default", () => { + const vector = new Float32Array([0, 0, 0]); + expect(() => normalize(vector)).toThrow("Cannot normalize a zero vector."); + }); + + test("should return original zero vector when throwOnZero is false", () => { + const vector = new Float32Array([0, 0, 0]); + const result = normalize(vector, false); + expect(result).toBe(vector); + expect(result[0]).toBe(0); + expect(result[1]).toBe(0); + expect(result[2]).toBe(0); + }); + + test("should handle negative values", () => { + const vector = new Float32Array([-3, -4]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Float32Array); + expect(result[0]).toBeCloseTo(-0.6, 5); + expect(result[1]).toBeCloseTo(-0.8, 5); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should handle mixed positive and negative values", () => { + const vector = new Float32Array([3, -4]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Float32Array); + expect(result[0]).toBeCloseTo(0.6, 5); + expect(result[1]).toBeCloseTo(-0.8, 5); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should handle single element vector", () => { + const vector = new Float32Array([5]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Float32Array); + expect(result[0]).toBe(1); + }); + + test("should handle already normalized vector", () => { + const vector = new Float32Array([0.6, 0.8]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Float32Array); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should handle large vectors", () => { + const vector = new Float32Array(1000).fill(1); + const result = normalize(vector); + expect(result).toBeInstanceOf(Float32Array); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should preserve 
type for other integer arrays", () => { + // Test Int32Array which is not explicitly handled + const vector = new Int32Array([3, 4]); + // @ts-ignore - Int32Array is not explicitly handled by normalize + const result = normalize(vector); + // Should fall through to Float32Array for unhandled types + expect(result).toBeInstanceOf(Float32Array); + }); + }); + + describe("normalizeNumberArray", () => { + test("should normalize number array to unit length", () => { + const values = [3, 4]; + const result = normalizeNumberArray(values); + expect(Array.isArray(result)).toBe(true); + expect(result.length).toBe(2); + expect(result[0]).toBeCloseTo(0.6, 5); + expect(result[1]).toBeCloseTo(0.8, 5); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should return original array for zero vector by default", () => { + const values = [0, 0, 0]; + const result = normalizeNumberArray(values); + expect(result).toBe(values); + expect(result[0]).toBe(0); + expect(result[1]).toBe(0); + expect(result[2]).toBe(0); + }); + + test("should throw error for zero vector when throwOnZero is true", () => { + const values = [0, 0, 0]; + expect(() => normalizeNumberArray(values, true)).toThrow("Cannot normalize a zero vector."); + }); + + test("should handle negative values", () => { + const values = [-3, -4]; + const result = normalizeNumberArray(values); + expect(result[0]).toBeCloseTo(-0.6, 5); + expect(result[1]).toBeCloseTo(-0.8, 5); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should handle mixed positive and negative values", () => { + const values = [3, -4]; + const result = normalizeNumberArray(values); + expect(result[0]).toBeCloseTo(0.6, 5); + expect(result[1]).toBeCloseTo(-0.8, 5); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should handle single element array", () => { + const values = [5]; + const result = normalizeNumberArray(values); + expect(result[0]).toBe(1); + }); + + test("should handle already normalized array", () => { + const values = [0.6, 0.8]; + const result = normalizeNumberArray(values); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should handle large arrays", () => { + const values = new Array(1000).fill(1); + const result = normalizeNumberArray(values); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should handle decimal values", () => { + const values = [0.1, 0.2, 0.3]; + const result = normalizeNumberArray(values); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should not mutate original array", () => { + const values = [3, 4]; + const original = [...values]; + normalizeNumberArray(values); + expect(values).toEqual(original); + }); + }); +}); From cb18634102f2e82bb6f17dfa1e472fc717988160 Mon Sep 17 00:00:00 2001 From: Steven Roussey Date: Sun, 11 Jan 2026 08:26:16 +0000 Subject: [PATCH 06/14] [refactor] Remove TaskRegistry registration from individual AI tasks, add it in index to prevent tree shaking - Eliminated TaskRegistry registration from multiple AI task files to streamline task management. - Centralized task registration in the index file to ensure all tasks are registered in one place, improving maintainability and reducing redundancy. - Updated documentation to reflect the changes in task registration structure. 
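
For illustration, a minimal sketch of the centralized pattern this change applies (FooTask and BarTask are hypothetical placeholder names; the real list lives in packages/ai/src/task/index.ts and uses the same .map(TaskRegistry.registerTask) call shown in the diff below):

    // Before: each task module registered itself as a side effect, so a
    // bundler that tree-shakes an unreferenced module silently drops the
    // registration.
    //   export class FooTask extends Task { /* ... */ }
    //   TaskRegistry.registerTask(FooTask);

    // After: the package index imports every task class and registers them
    // in one place, so importing the package entry point is sufficient.
    import { TaskRegistry } from "@workglow/task-graph";
    import { FooTask } from "./FooTask";
    import { BarTask } from "./BarTask";

    [FooTask, BarTask].map(TaskRegistry.registerTask);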
--- packages/ai/src/task/BackgroundRemovalTask.ts | 3 +- packages/ai/src/task/ChunkToVectorTask.ts | 1 - packages/ai/src/task/ContextBuilderTask.ts | 1 - packages/ai/src/task/DocumentEnricherTask.ts | 1 - .../ai/src/task/DocumentNodeRetrievalTask.ts | 1 - .../DocumentNodeVectorHybridSearchTask.ts | 1 - .../src/task/DocumentNodeVectorSearchTask.ts | 1 - .../src/task/DocumentNodeVectorUpsertTask.ts | 1 - packages/ai/src/task/DownloadModelTask.ts | 1 - packages/ai/src/task/FaceDetectorTask.ts | 3 +- packages/ai/src/task/FaceLandmarkerTask.ts | 3 +- packages/ai/src/task/GestureRecognizerTask.ts | 3 +- packages/ai/src/task/HandLandmarkerTask.ts | 3 +- .../ai/src/task/HierarchicalChunkerTask.ts | 1 - packages/ai/src/task/HierarchyJoinTask.ts | 1 - .../ai/src/task/ImageClassificationTask.ts | 3 +- packages/ai/src/task/ImageEmbeddingTask.ts | 3 +- packages/ai/src/task/ImageSegmentationTask.ts | 3 +- packages/ai/src/task/ImageToTextTask.ts | 3 +- packages/ai/src/task/ObjectDetectionTask.ts | 3 +- packages/ai/src/task/PoseLandmarkerTask.ts | 3 +- packages/ai/src/task/QueryExpanderTask.ts | 1 - packages/ai/src/task/RerankerTask.ts | 1 - packages/ai/src/task/StructuralParserTask.ts | 1 - packages/ai/src/task/TextChunkerTask.ts | 1 - .../ai/src/task/TextClassificationTask.ts | 3 +- packages/ai/src/task/TextEmbeddingTask.ts | 3 +- packages/ai/src/task/TextFillMaskTask.ts | 3 +- packages/ai/src/task/TextGenerationTask.ts | 3 +- .../ai/src/task/TextLanguageDetectionTask.ts | 3 +- .../task/TextNamedEntityRecognitionTask.ts | 3 +- .../ai/src/task/TextQuestionAnswerTask.ts | 3 +- packages/ai/src/task/TextRewriterTask.ts | 1 - packages/ai/src/task/TextSummaryTask.ts | 1 - packages/ai/src/task/TextTranslationTask.ts | 1 - packages/ai/src/task/TopicSegmenterTask.ts | 1 - packages/ai/src/task/UnloadModelTask.ts | 1 - packages/ai/src/task/VectorQuantizeTask.ts | 1 - packages/ai/src/task/VectorSimilarityTask.ts | 1 - packages/ai/src/task/index.ts | 86 +++++++++++++++++++ packages/task-graph/src/common.ts | 5 +- packages/tasks/src/browser.ts | 8 ++ packages/tasks/src/bun.ts | 8 ++ packages/tasks/src/common.ts | 30 +++++++ packages/tasks/src/node.ts | 8 ++ packages/tasks/src/task/DebugLogTask.ts | 4 +- packages/tasks/src/task/DelayTask.ts | 3 - packages/tasks/src/task/FetchUrlTask.ts | 2 - .../tasks/src/task/FileLoaderTask.server.ts | 8 +- packages/tasks/src/task/FileLoaderTask.ts | 2 - .../src/task/InputTask.ts | 23 +++-- packages/tasks/src/task/JavaScriptTask.ts | 4 +- packages/tasks/src/task/JsonTask.ts | 3 - packages/tasks/src/task/LambdaTask.ts | 3 - packages/tasks/src/task/MergeTask.ts | 2 - .../src/task/OutputTask.ts | 22 ++--- packages/tasks/src/task/SplitTask.ts | 4 +- 57 files changed, 183 insertions(+), 117 deletions(-) rename packages/{task-graph => tasks}/src/task/InputTask.ts (75%) rename packages/{task-graph => tasks}/src/task/OutputTask.ts (71%) diff --git a/packages/ai/src/task/BackgroundRemovalTask.ts b/packages/ai/src/task/BackgroundRemovalTask.ts index b3e1a81f..a60515d1 100644 --- a/packages/ai/src/task/BackgroundRemovalTask.ts +++ b/packages/ai/src/task/BackgroundRemovalTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { CreateWorkflow, JobQueueTaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; @@ 
-61,7 +61,6 @@ export class BackgroundRemovalTask extends AiVisionTask< } } -TaskRegistry.registerTask(BackgroundRemovalTask); /** * Convenience function to run background removal tasks. diff --git a/packages/ai/src/task/ChunkToVectorTask.ts b/packages/ai/src/task/ChunkToVectorTask.ts index e5394f4d..6cda3d70 100644 --- a/packages/ai/src/task/ChunkToVectorTask.ts +++ b/packages/ai/src/task/ChunkToVectorTask.ts @@ -160,7 +160,6 @@ export class ChunkToVectorTask extends Task< } } -TaskRegistry.registerTask(ChunkToVectorTask); export const chunkToVector = (input: ChunkToVectorTaskInput, config?: JobQueueTaskConfig) => { return new ChunkToVectorTask({} as ChunkToVectorTaskInput, config).run(input); diff --git a/packages/ai/src/task/ContextBuilderTask.ts b/packages/ai/src/task/ContextBuilderTask.ts index 19dee6dc..2e6e6da1 100644 --- a/packages/ai/src/task/ContextBuilderTask.ts +++ b/packages/ai/src/task/ContextBuilderTask.ts @@ -320,7 +320,6 @@ export class ContextBuilderTask extends Task< } } -TaskRegistry.registerTask(ContextBuilderTask); export const contextBuilder = (input: ContextBuilderTaskInput, config?: JobQueueTaskConfig) => { return new ContextBuilderTask({} as ContextBuilderTaskInput, config).run(input); diff --git a/packages/ai/src/task/DocumentEnricherTask.ts b/packages/ai/src/task/DocumentEnricherTask.ts index 4fce5aec..b805f55d 100644 --- a/packages/ai/src/task/DocumentEnricherTask.ts +++ b/packages/ai/src/task/DocumentEnricherTask.ts @@ -398,7 +398,6 @@ export class DocumentEnricherTask extends Task< } } -TaskRegistry.registerTask(DocumentEnricherTask); export const documentEnricher = (input: DocumentEnricherTaskInput, config?: JobQueueTaskConfig) => { return new DocumentEnricherTask({} as DocumentEnricherTaskInput, config).run(input); diff --git a/packages/ai/src/task/DocumentNodeRetrievalTask.ts b/packages/ai/src/task/DocumentNodeRetrievalTask.ts index b54565a5..8b50bb2a 100644 --- a/packages/ai/src/task/DocumentNodeRetrievalTask.ts +++ b/packages/ai/src/task/DocumentNodeRetrievalTask.ts @@ -231,7 +231,6 @@ export class DocumentNodeRetrievalTask extends Task< } } -TaskRegistry.registerTask(DocumentNodeRetrievalTask); export const retrieval = (input: RetrievalTaskInput, config?: JobQueueTaskConfig) => { return new DocumentNodeRetrievalTask({} as RetrievalTaskInput, config).run(input); diff --git a/packages/ai/src/task/DocumentNodeVectorHybridSearchTask.ts b/packages/ai/src/task/DocumentNodeVectorHybridSearchTask.ts index d4f4c608..96e65da7 100644 --- a/packages/ai/src/task/DocumentNodeVectorHybridSearchTask.ts +++ b/packages/ai/src/task/DocumentNodeVectorHybridSearchTask.ts @@ -217,7 +217,6 @@ export class DocumentNodeVectorHybridSearchTask extends Task< } } -TaskRegistry.registerTask(DocumentNodeVectorHybridSearchTask); export const hybridSearch = async ( input: HybridSearchTaskInput, diff --git a/packages/ai/src/task/DocumentNodeVectorSearchTask.ts b/packages/ai/src/task/DocumentNodeVectorSearchTask.ts index 63d736f8..04455e9f 100644 --- a/packages/ai/src/task/DocumentNodeVectorSearchTask.ts +++ b/packages/ai/src/task/DocumentNodeVectorSearchTask.ts @@ -153,7 +153,6 @@ export class DocumentNodeVectorSearchTask extends Task< } } -TaskRegistry.registerTask(DocumentNodeVectorSearchTask); export const vectorStoreSearch = ( input: VectorStoreSearchTaskInput, diff --git a/packages/ai/src/task/DocumentNodeVectorUpsertTask.ts b/packages/ai/src/task/DocumentNodeVectorUpsertTask.ts index e5800c30..832f7761 100644 --- a/packages/ai/src/task/DocumentNodeVectorUpsertTask.ts +++ 
b/packages/ai/src/task/DocumentNodeVectorUpsertTask.ts @@ -153,7 +153,6 @@ export class DocumentNodeVectorUpsertTask extends Task< } } -TaskRegistry.registerTask(DocumentNodeVectorUpsertTask); export const vectorStoreUpsert = ( input: VectorStoreUpsertTaskInput, diff --git a/packages/ai/src/task/DownloadModelTask.ts b/packages/ai/src/task/DownloadModelTask.ts index 1fff2a26..aaf8c82e 100644 --- a/packages/ai/src/task/DownloadModelTask.ts +++ b/packages/ai/src/task/DownloadModelTask.ts @@ -97,7 +97,6 @@ export class DownloadModelTask extends AiTask< } } -TaskRegistry.registerTask(DownloadModelTask); /** * Download a model from a remote source and cache it locally. diff --git a/packages/ai/src/task/FaceDetectorTask.ts b/packages/ai/src/task/FaceDetectorTask.ts index 465d8717..592bb1d5 100644 --- a/packages/ai/src/task/FaceDetectorTask.ts +++ b/packages/ai/src/task/FaceDetectorTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { CreateWorkflow, JobQueueTaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; @@ -160,7 +160,6 @@ export class FaceDetectorTask extends AiVisionTask< } } -TaskRegistry.registerTask(FaceDetectorTask); /** * Convenience function to run face detection tasks. diff --git a/packages/ai/src/task/FaceLandmarkerTask.ts b/packages/ai/src/task/FaceLandmarkerTask.ts index 961bc436..e001ef6a 100644 --- a/packages/ai/src/task/FaceLandmarkerTask.ts +++ b/packages/ai/src/task/FaceLandmarkerTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { CreateWorkflow, JobQueueTaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; @@ -196,7 +196,6 @@ export class FaceLandmarkerTask extends AiVisionTask< } } -TaskRegistry.registerTask(FaceLandmarkerTask); /** * Convenience function to run face landmark detection tasks. diff --git a/packages/ai/src/task/GestureRecognizerTask.ts b/packages/ai/src/task/GestureRecognizerTask.ts index 706b6c44..70ef9e9c 100644 --- a/packages/ai/src/task/GestureRecognizerTask.ts +++ b/packages/ai/src/task/GestureRecognizerTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { CreateWorkflow, JobQueueTaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; @@ -202,7 +202,6 @@ export class GestureRecognizerTask extends AiVisionTask< } } -TaskRegistry.registerTask(GestureRecognizerTask); /** * Convenience function to run gesture recognition tasks. 
diff --git a/packages/ai/src/task/HandLandmarkerTask.ts b/packages/ai/src/task/HandLandmarkerTask.ts index 739e92a1..c0b18103 100644 --- a/packages/ai/src/task/HandLandmarkerTask.ts +++ b/packages/ai/src/task/HandLandmarkerTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { CreateWorkflow, JobQueueTaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; @@ -174,7 +174,6 @@ export class HandLandmarkerTask extends AiVisionTask< } } -TaskRegistry.registerTask(HandLandmarkerTask); /** * Convenience function to run hand landmark detection tasks. diff --git a/packages/ai/src/task/HierarchicalChunkerTask.ts b/packages/ai/src/task/HierarchicalChunkerTask.ts index e35a4f67..789e4af5 100644 --- a/packages/ai/src/task/HierarchicalChunkerTask.ts +++ b/packages/ai/src/task/HierarchicalChunkerTask.ts @@ -283,7 +283,6 @@ export class HierarchicalChunkerTask extends Task< } } -TaskRegistry.registerTask(HierarchicalChunkerTask); export const hierarchicalChunker = ( input: HierarchicalChunkerTaskInput, diff --git a/packages/ai/src/task/HierarchyJoinTask.ts b/packages/ai/src/task/HierarchyJoinTask.ts index 3f55eca7..b14c511a 100644 --- a/packages/ai/src/task/HierarchyJoinTask.ts +++ b/packages/ai/src/task/HierarchyJoinTask.ts @@ -228,7 +228,6 @@ export class HierarchyJoinTask extends Task< } } -TaskRegistry.registerTask(HierarchyJoinTask); export const hierarchyJoin = (input: HierarchyJoinTaskInput, config?: JobQueueTaskConfig) => { return new HierarchyJoinTask({} as HierarchyJoinTaskInput, config).run(input); diff --git a/packages/ai/src/task/ImageClassificationTask.ts b/packages/ai/src/task/ImageClassificationTask.ts index 858c40f3..452772d7 100644 --- a/packages/ai/src/task/ImageClassificationTask.ts +++ b/packages/ai/src/task/ImageClassificationTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { CreateWorkflow, JobQueueTaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { TypeCategory, TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; @@ -81,7 +81,6 @@ export class ImageClassificationTask extends AiVisionTask< } } -TaskRegistry.registerTask(ImageClassificationTask); /** * Convenience function to run image classification tasks. diff --git a/packages/ai/src/task/ImageEmbeddingTask.ts b/packages/ai/src/task/ImageEmbeddingTask.ts index 94e0219c..689197c3 100644 --- a/packages/ai/src/task/ImageEmbeddingTask.ts +++ b/packages/ai/src/task/ImageEmbeddingTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { CreateWorkflow, JobQueueTaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema, @@ -69,7 +69,6 @@ export class ImageEmbeddingTask extends AiVisionTask< } } -TaskRegistry.registerTask(ImageEmbeddingTask); /** * Convenience function to run image embedding tasks. 
diff --git a/packages/ai/src/task/ImageSegmentationTask.ts b/packages/ai/src/task/ImageSegmentationTask.ts index 4c4e96ed..1a5541ab 100644 --- a/packages/ai/src/task/ImageSegmentationTask.ts +++ b/packages/ai/src/task/ImageSegmentationTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { CreateWorkflow, JobQueueTaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; @@ -105,7 +105,6 @@ export class ImageSegmentationTask extends AiVisionTask< } } -TaskRegistry.registerTask(ImageSegmentationTask); /** * Convenience function to run image segmentation tasks. diff --git a/packages/ai/src/task/ImageToTextTask.ts b/packages/ai/src/task/ImageToTextTask.ts index 2cf6b919..082bf2fb 100644 --- a/packages/ai/src/task/ImageToTextTask.ts +++ b/packages/ai/src/task/ImageToTextTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { CreateWorkflow, JobQueueTaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; @@ -72,7 +72,6 @@ export class ImageToTextTask extends AiVisionTask< } } -TaskRegistry.registerTask(ImageToTextTask); /** * Convenience function to run image to text tasks. diff --git a/packages/ai/src/task/ObjectDetectionTask.ts b/packages/ai/src/task/ObjectDetectionTask.ts index 6132d1cd..38d11680 100644 --- a/packages/ai/src/task/ObjectDetectionTask.ts +++ b/packages/ai/src/task/ObjectDetectionTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { CreateWorkflow, JobQueueTaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { TypeBoundingBox, TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; @@ -102,7 +102,6 @@ export class ObjectDetectionTask extends AiVisionTask< } } -TaskRegistry.registerTask(ObjectDetectionTask); /** * Convenience function to run object detection tasks. diff --git a/packages/ai/src/task/PoseLandmarkerTask.ts b/packages/ai/src/task/PoseLandmarkerTask.ts index 8f0a45f3..3d4b4404 100644 --- a/packages/ai/src/task/PoseLandmarkerTask.ts +++ b/packages/ai/src/task/PoseLandmarkerTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { CreateWorkflow, JobQueueTaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; @@ -191,7 +191,6 @@ export class PoseLandmarkerTask extends AiVisionTask< } } -TaskRegistry.registerTask(PoseLandmarkerTask); /** * Convenience function to run pose landmark detection tasks. 
diff --git a/packages/ai/src/task/QueryExpanderTask.ts b/packages/ai/src/task/QueryExpanderTask.ts index b3804b19..c72bc3fd 100644 --- a/packages/ai/src/task/QueryExpanderTask.ts +++ b/packages/ai/src/task/QueryExpanderTask.ts @@ -299,7 +299,6 @@ export class QueryExpanderTask extends Task< } } -TaskRegistry.registerTask(QueryExpanderTask); export const queryExpander = (input: QueryExpanderTaskInput, config?: JobQueueTaskConfig) => { return new QueryExpanderTask({} as QueryExpanderTaskInput, config).run(input); diff --git a/packages/ai/src/task/RerankerTask.ts b/packages/ai/src/task/RerankerTask.ts index bcd5d8c0..5baead4d 100644 --- a/packages/ai/src/task/RerankerTask.ts +++ b/packages/ai/src/task/RerankerTask.ts @@ -326,7 +326,6 @@ export class RerankerTask extends Task { return new RerankerTask({} as RerankerTaskInput, config).run(input); diff --git a/packages/ai/src/task/StructuralParserTask.ts b/packages/ai/src/task/StructuralParserTask.ts index bcfc9617..bdda6839 100644 --- a/packages/ai/src/task/StructuralParserTask.ts +++ b/packages/ai/src/task/StructuralParserTask.ts @@ -140,7 +140,6 @@ export class StructuralParserTask extends Task< } } -TaskRegistry.registerTask(StructuralParserTask); export const structuralParser = (input: StructuralParserTaskInput, config?: JobQueueTaskConfig) => { return new StructuralParserTask({} as StructuralParserTaskInput, config).run(input); diff --git a/packages/ai/src/task/TextChunkerTask.ts b/packages/ai/src/task/TextChunkerTask.ts index 99cce8f9..c3c2a274 100644 --- a/packages/ai/src/task/TextChunkerTask.ts +++ b/packages/ai/src/task/TextChunkerTask.ts @@ -343,7 +343,6 @@ export class TextChunkerTask extends Task< } } -TaskRegistry.registerTask(TextChunkerTask); export const textChunker = (input: TextChunkerTaskInput, config?: JobQueueTaskConfig) => { return new TextChunkerTask({} as TextChunkerTaskInput, config).run(input); diff --git a/packages/ai/src/task/TextClassificationTask.ts b/packages/ai/src/task/TextClassificationTask.ts index 0b7c6215..1f1cb598 100644 --- a/packages/ai/src/task/TextClassificationTask.ts +++ b/packages/ai/src/task/TextClassificationTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { CreateWorkflow, JobQueueTaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; import { TypeModel } from "./base/AiTaskSchemas"; @@ -97,7 +97,6 @@ export class TextClassificationTask extends AiTask< } } -TaskRegistry.registerTask(TextClassificationTask); /** * Convenience function to run text classifier tasks. 
diff --git a/packages/ai/src/task/TextEmbeddingTask.ts b/packages/ai/src/task/TextEmbeddingTask.ts index e85f3ded..c296a4f7 100644 --- a/packages/ai/src/task/TextEmbeddingTask.ts +++ b/packages/ai/src/task/TextEmbeddingTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { CreateWorkflow, JobQueueTaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema, @@ -73,7 +73,6 @@ export class TextEmbeddingTask extends AiTask { return new TopicSegmenterTask({} as TopicSegmenterTaskInput, config).run(input); diff --git a/packages/ai/src/task/UnloadModelTask.ts b/packages/ai/src/task/UnloadModelTask.ts index 21b8dbc6..1e5c3c7c 100644 --- a/packages/ai/src/task/UnloadModelTask.ts +++ b/packages/ai/src/task/UnloadModelTask.ts @@ -63,7 +63,6 @@ export class UnloadModelTask extends AiTask< public static cacheable = false; } -TaskRegistry.registerTask(UnloadModelTask); /** * Unload a model from memory and clear its cache. diff --git a/packages/ai/src/task/VectorQuantizeTask.ts b/packages/ai/src/task/VectorQuantizeTask.ts index 9ed102dc..0fe2bc5d 100644 --- a/packages/ai/src/task/VectorQuantizeTask.ts +++ b/packages/ai/src/task/VectorQuantizeTask.ts @@ -238,7 +238,6 @@ export class VectorQuantizeTask extends Task< } } -TaskRegistry.registerTask(VectorQuantizeTask); export const vectorQuantize = (input: VectorQuantizeTaskInput, config?: JobQueueTaskConfig) => { return new VectorQuantizeTask({} as VectorQuantizeTaskInput, config).run(input); diff --git a/packages/ai/src/task/VectorSimilarityTask.ts b/packages/ai/src/task/VectorSimilarityTask.ts index d5714716..f70b690e 100644 --- a/packages/ai/src/task/VectorSimilarityTask.ts +++ b/packages/ai/src/task/VectorSimilarityTask.ts @@ -141,7 +141,6 @@ export class VectorSimilarityTask extends GraphAsTask< } } -TaskRegistry.registerTask(VectorSimilarityTask); export const similarity = (input: VectorSimilarityTaskInput, config?: JobQueueTaskConfig) => { return new VectorSimilarityTask({} as VectorSimilarityTaskInput, config).run(input); diff --git a/packages/ai/src/task/index.ts b/packages/ai/src/task/index.ts index a4cb0f38..419b9318 100644 --- a/packages/ai/src/task/index.ts +++ b/packages/ai/src/task/index.ts @@ -4,6 +4,92 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { TaskRegistry } from "@workglow/task-graph"; +import { BackgroundRemovalTask } from "./BackgroundRemovalTask"; +import { ChunkToVectorTask } from "./ChunkToVectorTask"; +import { ContextBuilderTask } from "./ContextBuilderTask"; +import { DocumentEnricherTask } from "./DocumentEnricherTask"; +import { DocumentNodeRetrievalTask } from "./DocumentNodeRetrievalTask"; +import { DocumentNodeVectorHybridSearchTask } from "./DocumentNodeVectorHybridSearchTask"; +import { DocumentNodeVectorSearchTask } from "./DocumentNodeVectorSearchTask"; +import { DocumentNodeVectorUpsertTask } from "./DocumentNodeVectorUpsertTask"; +import { DownloadModelTask } from "./DownloadModelTask"; +import { FaceDetectorTask } from "./FaceDetectorTask"; +import { FaceLandmarkerTask } from "./FaceLandmarkerTask"; +import { GestureRecognizerTask } from "./GestureRecognizerTask"; +import { HandLandmarkerTask } from "./HandLandmarkerTask"; +import { HierarchicalChunkerTask } from "./HierarchicalChunkerTask"; +import { HierarchyJoinTask } from "./HierarchyJoinTask"; +import { ImageClassificationTask } from "./ImageClassificationTask"; +import { ImageEmbeddingTask } from 
"./ImageEmbeddingTask"; +import { ImageSegmentationTask } from "./ImageSegmentationTask"; +import { ImageToTextTask } from "./ImageToTextTask"; +import { ObjectDetectionTask } from "./ObjectDetectionTask"; +import { PoseLandmarkerTask } from "./PoseLandmarkerTask"; +import { QueryExpanderTask } from "./QueryExpanderTask"; +import { RerankerTask } from "./RerankerTask"; +import { StructuralParserTask } from "./StructuralParserTask"; +import { TextChunkerTask } from "./TextChunkerTask"; +import { TextClassificationTask } from "./TextClassificationTask"; +import { TextEmbeddingTask } from "./TextEmbeddingTask"; +import { TextFillMaskTask } from "./TextFillMaskTask"; +import { TextGenerationTask } from "./TextGenerationTask"; +import { TextLanguageDetectionTask } from "./TextLanguageDetectionTask"; +import { TextNamedEntityRecognitionTask } from "./TextNamedEntityRecognitionTask"; +import { TextQuestionAnswerTask } from "./TextQuestionAnswerTask"; +import { TextRewriterTask } from "./TextRewriterTask"; +import { TextSummaryTask } from "./TextSummaryTask"; +import { TextTranslationTask } from "./TextTranslationTask"; +import { TopicSegmenterTask } from "./TopicSegmenterTask"; +import { UnloadModelTask } from "./UnloadModelTask"; +import { VectorQuantizeTask } from "./VectorQuantizeTask"; +import { VectorSimilarityTask } from "./VectorSimilarityTask"; + +// Register all AI tasks with the TaskRegistry. +// Centralized registration ensures tasks are available for JSON deserialization +// and prevents tree-shaking issues. +[ + BackgroundRemovalTask, + ChunkToVectorTask, + ContextBuilderTask, + DocumentEnricherTask, + DocumentNodeRetrievalTask, + DocumentNodeVectorHybridSearchTask, + DocumentNodeVectorSearchTask, + DocumentNodeVectorUpsertTask, + DownloadModelTask, + FaceDetectorTask, + FaceLandmarkerTask, + GestureRecognizerTask, + HandLandmarkerTask, + HierarchicalChunkerTask, + HierarchyJoinTask, + ImageClassificationTask, + ImageEmbeddingTask, + ImageSegmentationTask, + ImageToTextTask, + ObjectDetectionTask, + PoseLandmarkerTask, + QueryExpanderTask, + RerankerTask, + StructuralParserTask, + TextChunkerTask, + TextClassificationTask, + TextEmbeddingTask, + TextFillMaskTask, + TextGenerationTask, + TextLanguageDetectionTask, + TextNamedEntityRecognitionTask, + TextQuestionAnswerTask, + TextRewriterTask, + TextSummaryTask, + TextTranslationTask, + TopicSegmenterTask, + UnloadModelTask, + VectorQuantizeTask, + VectorSimilarityTask, +].map(TaskRegistry.registerTask); + export * from "./BackgroundRemovalTask"; export * from "./base/AiTask"; export * from "./base/AiTaskSchemas"; diff --git a/packages/task-graph/src/common.ts b/packages/task-graph/src/common.ts index 776cc7a5..9ac5e134 100644 --- a/packages/task-graph/src/common.ts +++ b/packages/task-graph/src/common.ts @@ -4,17 +4,16 @@ * SPDX-License-Identifier: Apache-2.0 */ +export * from "./task/Task"; + export * from "./task/ArrayTask"; export * from "./task/ConditionalTask"; export * from "./task/GraphAsTask"; export * from "./task/GraphAsTaskRunner"; export * from "./task/InputResolver"; -export * from "./task/InputTask"; export * from "./task/ITask"; export * from "./task/JobQueueFactory"; export * from "./task/JobQueueTask"; -export * from "./task/OutputTask"; -export * from "./task/Task"; export * from "./task/TaskError"; export * from "./task/TaskEvents"; export * from "./task/TaskJSON"; diff --git a/packages/tasks/src/browser.ts b/packages/tasks/src/browser.ts index 2e087fbc..38816e73 100644 --- a/packages/tasks/src/browser.ts +++ 
b/packages/tasks/src/browser.ts @@ -4,5 +4,13 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { TaskRegistry } from "@workglow/task-graph"; +import { FileLoaderTask } from "./task/FileLoaderTask"; + +// Register browser-specific tasks with the TaskRegistry. +// Centralized registration ensures tasks are available for JSON deserialization +// and prevents tree-shaking issues. +[FileLoaderTask].map(TaskRegistry.registerTask); + export * from "./common"; export * from "./task/FileLoaderTask"; diff --git a/packages/tasks/src/bun.ts b/packages/tasks/src/bun.ts index f0293640..a820eb56 100644 --- a/packages/tasks/src/bun.ts +++ b/packages/tasks/src/bun.ts @@ -4,5 +4,13 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { TaskRegistry } from "@workglow/task-graph"; +import { FileLoaderTask } from "./task/FileLoaderTask.server"; + +// Register bun-specific tasks with the TaskRegistry. +// Centralized registration ensures tasks are available for JSON deserialization +// and prevents tree-shaking issues. +[FileLoaderTask].map(TaskRegistry.registerTask); + export * from "./common"; export * from "./task/FileLoaderTask.server"; diff --git a/packages/tasks/src/common.ts b/packages/tasks/src/common.ts index 64e1bb6e..84e96928 100644 --- a/packages/tasks/src/common.ts +++ b/packages/tasks/src/common.ts @@ -4,11 +4,41 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { TaskRegistry } from "@workglow/task-graph"; +import { DebugLogTask } from "./task/DebugLogTask"; +import { DelayTask } from "./task/DelayTask"; +import { FetchUrlTask } from "./task/FetchUrlTask"; +import { InputTask } from "./task/InputTask"; +import { JavaScriptTask } from "./task/JavaScriptTask"; +import { JsonTask } from "./task/JsonTask"; +import { LambdaTask } from "./task/LambdaTask"; +import { MergeTask } from "./task/MergeTask"; +import { OutputTask } from "./task/OutputTask"; +import { SplitTask } from "./task/SplitTask"; + +// Register all common tasks with the TaskRegistry. +// Centralized registration ensures tasks are available for JSON deserialization +// and prevents tree-shaking issues. +[ + DebugLogTask, + DelayTask, + FetchUrlTask, + InputTask, + JavaScriptTask, + JsonTask, + LambdaTask, + MergeTask, + OutputTask, + SplitTask, +].map(TaskRegistry.registerTask); + export * from "./task/DebugLogTask"; export * from "./task/DelayTask"; export * from "./task/FetchUrlTask"; +export * from "./task/InputTask"; export * from "./task/JavaScriptTask"; export * from "./task/JsonTask"; export * from "./task/LambdaTask"; export * from "./task/MergeTask"; +export * from "./task/OutputTask"; export * from "./task/SplitTask"; diff --git a/packages/tasks/src/node.ts b/packages/tasks/src/node.ts index f0293640..d506e158 100644 --- a/packages/tasks/src/node.ts +++ b/packages/tasks/src/node.ts @@ -4,5 +4,13 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { TaskRegistry } from "@workglow/task-graph"; +import { FileLoaderTask } from "./task/FileLoaderTask.server"; + +// Register server-specific tasks with the TaskRegistry. +// Centralized registration ensures tasks are available for JSON deserialization +// and prevents tree-shaking issues. 
+[FileLoaderTask].map(TaskRegistry.registerTask); + export * from "./common"; export * from "./task/FileLoaderTask.server"; diff --git a/packages/tasks/src/task/DebugLogTask.ts b/packages/tasks/src/task/DebugLogTask.ts index e2785ccf..4e5cff46 100644 --- a/packages/tasks/src/task/DebugLogTask.ts +++ b/packages/tasks/src/task/DebugLogTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, Task, TaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { CreateWorkflow, Task, TaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; const log_levels = ["dir", "log", "debug", "info", "warn", "error"] as const; @@ -86,8 +86,6 @@ export class DebugLogTask< } } -TaskRegistry.registerTask(DebugLogTask); - export const debugLog = (input: DebugLogTaskInput, config: TaskConfig = {}) => { const task = new DebugLogTask({}, config); return task.run(input); diff --git a/packages/tasks/src/task/DelayTask.ts b/packages/tasks/src/task/DelayTask.ts index 6e75b0e0..8b94b6dc 100644 --- a/packages/tasks/src/task/DelayTask.ts +++ b/packages/tasks/src/task/DelayTask.ts @@ -77,9 +77,6 @@ export class DelayTask< } } -// Register DelayTask with the task registry -TaskRegistry.registerTask(DelayTask); - /** * DelayTask * diff --git a/packages/tasks/src/task/FetchUrlTask.ts b/packages/tasks/src/task/FetchUrlTask.ts index c2879110..82115e8f 100644 --- a/packages/tasks/src/task/FetchUrlTask.ts +++ b/packages/tasks/src/task/FetchUrlTask.ts @@ -415,8 +415,6 @@ export class FetchUrlTask< } } -TaskRegistry.registerTask(FetchUrlTask); - export const fetchUrl = async ( input: FetchUrlTaskInput, config: FetchUrlTaskConfig = {} diff --git a/packages/tasks/src/task/FileLoaderTask.server.ts b/packages/tasks/src/task/FileLoaderTask.server.ts index 0cad8e5f..1f6d54d2 100644 --- a/packages/tasks/src/task/FileLoaderTask.server.ts +++ b/packages/tasks/src/task/FileLoaderTask.server.ts @@ -212,17 +212,15 @@ export class FileLoaderTask extends BaseFileLoaderTask { } } -// override the base registration -TaskRegistry.registerTask(FileLoaderTask); - export const fileLoader = (input: FileLoaderTaskInput, config?: JobQueueTaskConfig) => { return new FileLoaderTask({}, config).run(input); }; declare module "@workglow/task-graph" { interface Workflow { - fileLoaderServer: CreateWorkflow; + fileLoader: CreateWorkflow; } } -Workflow.prototype.fileLoaderServer = CreateWorkflow(FileLoaderTask); +// Override fileLoader to use the server version that handles file:// URLs +Workflow.prototype.fileLoader = CreateWorkflow(FileLoaderTask); diff --git a/packages/tasks/src/task/FileLoaderTask.ts b/packages/tasks/src/task/FileLoaderTask.ts index 92559a4a..aea490a5 100644 --- a/packages/tasks/src/task/FileLoaderTask.ts +++ b/packages/tasks/src/task/FileLoaderTask.ts @@ -405,8 +405,6 @@ export class FileLoaderTask extends Task< } } -TaskRegistry.registerTask(FileLoaderTask); - export const fileLoader = (input: FileLoaderTaskInput, config?: JobQueueTaskConfig) => { return new FileLoaderTask({}, config).run(input); }; diff --git a/packages/task-graph/src/task/InputTask.ts b/packages/tasks/src/task/InputTask.ts similarity index 75% rename from packages/task-graph/src/task/InputTask.ts rename to packages/tasks/src/task/InputTask.ts index 45c830e9..2fedd42b 100644 --- a/packages/task-graph/src/task/InputTask.ts +++ b/packages/tasks/src/task/InputTask.ts @@ -4,16 +4,13 @@ * All Rights Reserved */ +import { CreateWorkflow, Task, TaskConfig, Workflow } from 
"@workglow/task-graph"; import type { DataPortSchema } from "@workglow/util"; -import { IExecuteContext, IExecuteReactiveContext } from "./ITask"; -import { Task } from "./Task"; -import { TaskRegistry } from "./TaskRegistry"; -import { TaskConfig } from "./TaskTypes"; export type InputTaskInput = Record; export type InputTaskOutput = Record; export type InputTaskConfig = TaskConfig & { - schema: DataPortSchema; + readonly schema: DataPortSchema; }; export class InputTask extends Task { @@ -54,17 +51,19 @@ export class InputTask extends Task; + } +} + +Workflow.prototype.input = CreateWorkflow(InputTask); diff --git a/packages/tasks/src/task/JavaScriptTask.ts b/packages/tasks/src/task/JavaScriptTask.ts index fdf278f4..6efb28af 100644 --- a/packages/tasks/src/task/JavaScriptTask.ts +++ b/packages/tasks/src/task/JavaScriptTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CreateWorkflow, Task, TaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; +import { CreateWorkflow, Task, TaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { Interpreter } from "../util/interpreter"; @@ -71,8 +71,6 @@ export class JavaScriptTask extends Task { return new JavaScriptTask({}, config).run(input); }; diff --git a/packages/tasks/src/task/JsonTask.ts b/packages/tasks/src/task/JsonTask.ts index 8994326b..2a7f2aec 100644 --- a/packages/tasks/src/task/JsonTask.ts +++ b/packages/tasks/src/task/JsonTask.ts @@ -96,9 +96,6 @@ export class JsonTask< } } -// Register JsonTask with the task registry -TaskRegistry.registerTask(JsonTask); - /** * Convenience function to create and run a JsonTask */ diff --git a/packages/tasks/src/task/LambdaTask.ts b/packages/tasks/src/task/LambdaTask.ts index 64c65732..9990941a 100644 --- a/packages/tasks/src/task/LambdaTask.ts +++ b/packages/tasks/src/task/LambdaTask.ts @@ -103,9 +103,6 @@ export class LambdaTask< } } -// Register LambdaTask with the task registry -TaskRegistry.registerTask(LambdaTask); - export function process(value: string): string; export function process(value: number): number; export function process(value: boolean): string; diff --git a/packages/tasks/src/task/MergeTask.ts b/packages/tasks/src/task/MergeTask.ts index c6029459..3726633d 100644 --- a/packages/tasks/src/task/MergeTask.ts +++ b/packages/tasks/src/task/MergeTask.ts @@ -82,8 +82,6 @@ export class MergeTask< } } -TaskRegistry.registerTask(MergeTask); - export const merge = (input: MergeTaskInput, config: TaskConfig = {}) => { const task = new MergeTask({}, config); return task.run(input); diff --git a/packages/task-graph/src/task/OutputTask.ts b/packages/tasks/src/task/OutputTask.ts similarity index 71% rename from packages/task-graph/src/task/OutputTask.ts rename to packages/tasks/src/task/OutputTask.ts index b4600127..5e0bcbb7 100644 --- a/packages/task-graph/src/task/OutputTask.ts +++ b/packages/tasks/src/task/OutputTask.ts @@ -4,18 +4,14 @@ * All Rights Reserved */ -import { DataPortSchema } from "@workglow/util"; -import { CreateWorkflow, Workflow } from "../task-graph/Workflow"; -import { IExecuteContext, IExecuteReactiveContext } from "./ITask"; -import { Task } from "./Task"; -import { TaskRegistry } from "./TaskRegistry"; -import { TaskConfig } from "./TaskTypes"; +import { CreateWorkflow, Task, TaskConfig, Workflow } from "@workglow/task-graph"; +import type { DataPortSchema } from "@workglow/util"; export type OutputTaskInput = Record; export type OutputTaskOutput = Record; export type 
OutputTaskConfig = TaskConfig & { - schema: DataPortSchema; + readonly schema: DataPortSchema; }; export class OutputTask extends Task { @@ -56,23 +52,17 @@ export class OutputTask extends Task { const task = new SplitTask({}, config); return task.run(input); From caed5b0252ffd0161d776abc2cbcda3a751eefa2 Mon Sep 17 00:00:00 2001 From: Steven Roussey Date: Sun, 11 Jan 2026 18:57:48 +0000 Subject: [PATCH 07/14] [refactor] Update task imports - Refactored imports in TaskUI and TaskNode components to streamline dependencies. - Expanded TODO list with new items related to chunk and node handling, model improvements, and documentation updates. - Removed unnecessary whitespace in several task files for cleaner code. - Centralized task registration in the index file to improve maintainability and reduce redundancy. --- TODO.md | 9 +- examples/cli/src/components/TaskUI.tsx | 3 +- examples/web/src/graph/TaskNode.tsx | 3 +- packages/ai/src/task/BackgroundRemovalTask.ts | 1 - packages/ai/src/task/ChunkToVectorTask.ts | 2 - packages/ai/src/task/ContextBuilderTask.ts | 9 +- packages/ai/src/task/DocumentEnricherTask.ts | 2 - .../ai/src/task/DocumentNodeRetrievalTask.ts | 2 - .../DocumentNodeVectorHybridSearchTask.ts | 2 - .../src/task/DocumentNodeVectorSearchTask.ts | 2 - .../src/task/DocumentNodeVectorUpsertTask.ts | 2 - packages/ai/src/task/DownloadModelTask.ts | 10 +-- packages/ai/src/task/FaceDetectorTask.ts | 1 - packages/ai/src/task/FaceLandmarkerTask.ts | 1 - packages/ai/src/task/GestureRecognizerTask.ts | 1 - packages/ai/src/task/HandLandmarkerTask.ts | 1 - .../ai/src/task/HierarchicalChunkerTask.ts | 3 - packages/ai/src/task/HierarchyJoinTask.ts | 2 - .../ai/src/task/ImageClassificationTask.ts | 1 - packages/ai/src/task/ImageEmbeddingTask.ts | 1 - packages/ai/src/task/ImageSegmentationTask.ts | 1 - packages/ai/src/task/ImageToTextTask.ts | 1 - packages/ai/src/task/ObjectDetectionTask.ts | 1 - packages/ai/src/task/PoseLandmarkerTask.ts | 1 - packages/ai/src/task/QueryExpanderTask.ts | 2 - packages/ai/src/task/RerankerTask.ts | 2 - packages/ai/src/task/StructuralParserTask.ts | 2 - packages/ai/src/task/TextChunkerTask.ts | 2 - .../ai/src/task/TextClassificationTask.ts | 1 - packages/ai/src/task/TextEmbeddingTask.ts | 1 - packages/ai/src/task/TextFillMaskTask.ts | 1 - .../ai/src/task/TextLanguageDetectionTask.ts | 1 - .../task/TextNamedEntityRecognitionTask.ts | 1 - .../ai/src/task/TextQuestionAnswerTask.ts | 1 - packages/ai/src/task/TextRewriterTask.ts | 10 +-- packages/ai/src/task/TextSummaryTask.ts | 10 +-- packages/ai/src/task/TextTranslationTask.ts | 10 +-- packages/ai/src/task/TopicSegmenterTask.ts | 2 - packages/ai/src/task/UnloadModelTask.ts | 10 +-- packages/ai/src/task/VectorQuantizeTask.ts | 15 +--- packages/ai/src/task/VectorSimilarityTask.ts | 9 +- packages/ai/src/task/index.ts | 84 ++++++++++--------- packages/task-graph/README.md | 52 ------------ packages/task-graph/src/common.ts | 19 +---- packages/task-graph/src/task/README.md | 9 -- packages/task-graph/src/task/index.ts | 28 +++++++ packages/tasks/README.md | 73 ++++++++++++++++ packages/tasks/src/common.ts | 51 ++++++----- .../src/task/ArrayTask.ts | 18 ++-- packages/tasks/src/task/DelayTask.ts | 1 - packages/tasks/src/task/FetchUrlTask.ts | 1 - .../tasks/src/task/FileLoaderTask.server.ts | 1 - packages/tasks/src/task/FileLoaderTask.ts | 1 - packages/tasks/src/task/JsonTask.ts | 1 - packages/tasks/src/task/LambdaTask.ts | 1 - packages/tasks/src/task/MergeTask.ts | 9 +- packages/test/src/binding/RegisterTasks.ts | 15 ++++ 
packages/test/src/samples/ONNXModelSamples.ts | 15 +++- .../test/src/test/rag/ChunkToVector.test.ts | 1 + .../test/src/test/rag/RagWorkflow.test.ts | 26 ++---- packages/test/src/test/task/ArrayTask.test.ts | 2 +- 61 files changed, 250 insertions(+), 300 deletions(-) create mode 100644 packages/task-graph/src/task/index.ts rename packages/{task-graph => tasks}/src/task/ArrayTask.ts (95%) create mode 100644 packages/test/src/binding/RegisterTasks.ts diff --git a/TODO.md b/TODO.md index 6d2fb670..406d7c1c 100644 --- a/TODO.md +++ b/TODO.md @@ -1,8 +1,15 @@ TODO.md +- [ ] Chunks and nodes are not the same. + - [ ] We need to rename the files related to embedding. + - [ ] And we may need to save the chunk's node path. Or paths? +- [ ] Get a better model for question answering. +- [ ] Get a better model for named entity recognition, the current one recognized everything as a token, not helpful. +- [ ] Titles are not making it into the chunks. +- [ ] Tests for CLI commands. + - [ ] Add ability for queues to specify if inputs should be converted to text, binary blob, a transferable object, structured clone, or just passed as is. - [ ] Add specialized versions of the task queues for hugging face transformers and tensorflow mediapipe. - [ ] Audio conversion like the image conversion -- [ ] Make all the workflow name start with a lower case letter and change the users of them, particularly in the docs and examples - [ ] rename the registration stuff to not look ugly: registerHuggingfaceTransformers() and registerHuggingfaceTransformersUsingWorkers() and registerHuggingfaceTransformersInsideWorker() - [ ] fix image transferables diff --git a/examples/cli/src/components/TaskUI.tsx b/examples/cli/src/components/TaskUI.tsx index 84c13fd6..3045c8aa 100644 --- a/examples/cli/src/components/TaskUI.tsx +++ b/examples/cli/src/components/TaskUI.tsx @@ -5,7 +5,8 @@ */ import { DownloadModelTask } from "@workglow/ai"; -import { ArrayTask, ITask, ITaskGraph, TaskStatus } from "@workglow/task-graph"; +import { ITask, ITaskGraph, TaskStatus } from "@workglow/task-graph"; +import { ArrayTask } from "@workglow/tasks"; import type { FC } from "react"; import { memo, useEffect, useState } from "react"; import { Box, Text } from "retuink"; diff --git a/examples/web/src/graph/TaskNode.tsx b/examples/web/src/graph/TaskNode.tsx index db26fd93..49c2277a 100644 --- a/examples/web/src/graph/TaskNode.tsx +++ b/examples/web/src/graph/TaskNode.tsx @@ -4,7 +4,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { ArrayTask, ITask, TaskStatus } from "@workglow/task-graph"; +import { ITask, TaskStatus } from "@workglow/task-graph"; +import { ArrayTask } from "@workglow/tasks"; import { Node, NodeProps } from "@xyflow/react"; import { useEffect, useState } from "react"; import { FiCloud, FiCloudLightning } from "react-icons/fi"; diff --git a/packages/ai/src/task/BackgroundRemovalTask.ts b/packages/ai/src/task/BackgroundRemovalTask.ts index a60515d1..19ddb597 100644 --- a/packages/ai/src/task/BackgroundRemovalTask.ts +++ b/packages/ai/src/task/BackgroundRemovalTask.ts @@ -61,7 +61,6 @@ export class BackgroundRemovalTask extends AiVisionTask< } } - /** * Convenience function to run background removal tasks. * Creates and executes a BackgroundRemovalTask with the provided input. 
diff --git a/packages/ai/src/task/ChunkToVectorTask.ts b/packages/ai/src/task/ChunkToVectorTask.ts index 6cda3d70..c6a149c1 100644 --- a/packages/ai/src/task/ChunkToVectorTask.ts +++ b/packages/ai/src/task/ChunkToVectorTask.ts @@ -10,7 +10,6 @@ import { IExecuteContext, JobQueueTaskConfig, Task, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { @@ -160,7 +159,6 @@ export class ChunkToVectorTask extends Task< } } - export const chunkToVector = (input: ChunkToVectorTaskInput, config?: JobQueueTaskConfig) => { return new ChunkToVectorTask({} as ChunkToVectorTaskInput, config).run(input); }; diff --git a/packages/ai/src/task/ContextBuilderTask.ts b/packages/ai/src/task/ContextBuilderTask.ts index 2e6e6da1..5f6935ef 100644 --- a/packages/ai/src/task/ContextBuilderTask.ts +++ b/packages/ai/src/task/ContextBuilderTask.ts @@ -4,13 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { - CreateWorkflow, - JobQueueTaskConfig, - Task, - TaskRegistry, - Workflow, -} from "@workglow/task-graph"; +import { CreateWorkflow, JobQueueTaskConfig, Task, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; export const ContextFormat = { @@ -320,7 +314,6 @@ export class ContextBuilderTask extends Task< } } - export const contextBuilder = (input: ContextBuilderTaskInput, config?: JobQueueTaskConfig) => { return new ContextBuilderTask({} as ContextBuilderTaskInput, config).run(input); }; diff --git a/packages/ai/src/task/DocumentEnricherTask.ts b/packages/ai/src/task/DocumentEnricherTask.ts index b805f55d..12937f4d 100644 --- a/packages/ai/src/task/DocumentEnricherTask.ts +++ b/packages/ai/src/task/DocumentEnricherTask.ts @@ -16,7 +16,6 @@ import { IExecuteContext, JobQueueTaskConfig, Task, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; @@ -398,7 +397,6 @@ export class DocumentEnricherTask extends Task< } } - export const documentEnricher = (input: DocumentEnricherTaskInput, config?: JobQueueTaskConfig) => { return new DocumentEnricherTask({} as DocumentEnricherTaskInput, config).run(input); }; diff --git a/packages/ai/src/task/DocumentNodeRetrievalTask.ts b/packages/ai/src/task/DocumentNodeRetrievalTask.ts index 8b50bb2a..b9b2bd92 100644 --- a/packages/ai/src/task/DocumentNodeRetrievalTask.ts +++ b/packages/ai/src/task/DocumentNodeRetrievalTask.ts @@ -13,7 +13,6 @@ import { IExecuteContext, JobQueueTaskConfig, Task, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { @@ -231,7 +230,6 @@ export class DocumentNodeRetrievalTask extends Task< } } - export const retrieval = (input: RetrievalTaskInput, config?: JobQueueTaskConfig) => { return new DocumentNodeRetrievalTask({} as RetrievalTaskInput, config).run(input); }; diff --git a/packages/ai/src/task/DocumentNodeVectorHybridSearchTask.ts b/packages/ai/src/task/DocumentNodeVectorHybridSearchTask.ts index 96e65da7..90941c0a 100644 --- a/packages/ai/src/task/DocumentNodeVectorHybridSearchTask.ts +++ b/packages/ai/src/task/DocumentNodeVectorHybridSearchTask.ts @@ -13,7 +13,6 @@ import { IExecuteContext, JobQueueTaskConfig, Task, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { @@ -217,7 +216,6 @@ export class DocumentNodeVectorHybridSearchTask extends Task< } } - export const hybridSearch = async ( input: HybridSearchTaskInput, config?: JobQueueTaskConfig diff --git a/packages/ai/src/task/DocumentNodeVectorSearchTask.ts b/packages/ai/src/task/DocumentNodeVectorSearchTask.ts index 04455e9f..37d481a6 
100644 --- a/packages/ai/src/task/DocumentNodeVectorSearchTask.ts +++ b/packages/ai/src/task/DocumentNodeVectorSearchTask.ts @@ -13,7 +13,6 @@ import { IExecuteContext, JobQueueTaskConfig, Task, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { @@ -153,7 +152,6 @@ export class DocumentNodeVectorSearchTask extends Task< } } - export const vectorStoreSearch = ( input: VectorStoreSearchTaskInput, config?: JobQueueTaskConfig diff --git a/packages/ai/src/task/DocumentNodeVectorUpsertTask.ts b/packages/ai/src/task/DocumentNodeVectorUpsertTask.ts index 832f7761..efc62a9b 100644 --- a/packages/ai/src/task/DocumentNodeVectorUpsertTask.ts +++ b/packages/ai/src/task/DocumentNodeVectorUpsertTask.ts @@ -13,7 +13,6 @@ import { IExecuteContext, JobQueueTaskConfig, Task, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { @@ -153,7 +152,6 @@ export class DocumentNodeVectorUpsertTask extends Task< } } - export const vectorStoreUpsert = ( input: VectorStoreUpsertTaskInput, config?: JobQueueTaskConfig diff --git a/packages/ai/src/task/DownloadModelTask.ts b/packages/ai/src/task/DownloadModelTask.ts index aaf8c82e..a806b977 100644 --- a/packages/ai/src/task/DownloadModelTask.ts +++ b/packages/ai/src/task/DownloadModelTask.ts @@ -4,14 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { - CreateWorkflow, - DeReplicateFromSchema, - JobQueueTaskConfig, - TaskRegistry, - TypeReplicateArray, - Workflow, -} from "@workglow/task-graph"; +import { CreateWorkflow, JobQueueTaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; import { TypeModel } from "./base/AiTaskSchemas"; @@ -97,7 +90,6 @@ export class DownloadModelTask extends AiTask< } } - /** * Download a model from a remote source and cache it locally. * diff --git a/packages/ai/src/task/FaceDetectorTask.ts b/packages/ai/src/task/FaceDetectorTask.ts index 592bb1d5..3cc9bc2b 100644 --- a/packages/ai/src/task/FaceDetectorTask.ts +++ b/packages/ai/src/task/FaceDetectorTask.ts @@ -160,7 +160,6 @@ export class FaceDetectorTask extends AiVisionTask< } } - /** * Convenience function to run face detection tasks. * Creates and executes a FaceDetectorTask with the provided input. diff --git a/packages/ai/src/task/FaceLandmarkerTask.ts b/packages/ai/src/task/FaceLandmarkerTask.ts index e001ef6a..c5e9107b 100644 --- a/packages/ai/src/task/FaceLandmarkerTask.ts +++ b/packages/ai/src/task/FaceLandmarkerTask.ts @@ -196,7 +196,6 @@ export class FaceLandmarkerTask extends AiVisionTask< } } - /** * Convenience function to run face landmark detection tasks. * Creates and executes a FaceLandmarkerTask with the provided input. diff --git a/packages/ai/src/task/GestureRecognizerTask.ts b/packages/ai/src/task/GestureRecognizerTask.ts index 70ef9e9c..632139f6 100644 --- a/packages/ai/src/task/GestureRecognizerTask.ts +++ b/packages/ai/src/task/GestureRecognizerTask.ts @@ -202,7 +202,6 @@ export class GestureRecognizerTask extends AiVisionTask< } } - /** * Convenience function to run gesture recognition tasks. * Creates and executes a GestureRecognizerTask with the provided input. diff --git a/packages/ai/src/task/HandLandmarkerTask.ts b/packages/ai/src/task/HandLandmarkerTask.ts index c0b18103..f5a4b1f6 100644 --- a/packages/ai/src/task/HandLandmarkerTask.ts +++ b/packages/ai/src/task/HandLandmarkerTask.ts @@ -174,7 +174,6 @@ export class HandLandmarkerTask extends AiVisionTask< } } - /** * Convenience function to run hand landmark detection tasks. 
* Creates and executes a HandLandmarkerTask with the provided input. diff --git a/packages/ai/src/task/HierarchicalChunkerTask.ts b/packages/ai/src/task/HierarchicalChunkerTask.ts index 789e4af5..1a3d70fa 100644 --- a/packages/ai/src/task/HierarchicalChunkerTask.ts +++ b/packages/ai/src/task/HierarchicalChunkerTask.ts @@ -19,7 +19,6 @@ import { IExecuteContext, JobQueueTaskConfig, Task, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; @@ -159,7 +158,6 @@ export class HierarchicalChunkerTask extends Task< // Flat chunking: treat entire document as flat text await this.chunkFlat(root, doc_id, tokenBudget, chunks); } - return { doc_id, chunks, @@ -283,7 +281,6 @@ export class HierarchicalChunkerTask extends Task< } } - export const hierarchicalChunker = ( input: HierarchicalChunkerTaskInput, config?: JobQueueTaskConfig diff --git a/packages/ai/src/task/HierarchyJoinTask.ts b/packages/ai/src/task/HierarchyJoinTask.ts index b14c511a..09f0ccb5 100644 --- a/packages/ai/src/task/HierarchyJoinTask.ts +++ b/packages/ai/src/task/HierarchyJoinTask.ts @@ -15,7 +15,6 @@ import { IExecuteContext, JobQueueTaskConfig, Task, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; @@ -228,7 +227,6 @@ export class HierarchyJoinTask extends Task< } } - export const hierarchyJoin = (input: HierarchyJoinTaskInput, config?: JobQueueTaskConfig) => { return new HierarchyJoinTask({} as HierarchyJoinTaskInput, config).run(input); }; diff --git a/packages/ai/src/task/ImageClassificationTask.ts b/packages/ai/src/task/ImageClassificationTask.ts index 452772d7..aa5e8510 100644 --- a/packages/ai/src/task/ImageClassificationTask.ts +++ b/packages/ai/src/task/ImageClassificationTask.ts @@ -81,7 +81,6 @@ export class ImageClassificationTask extends AiVisionTask< } } - /** * Convenience function to run image classification tasks. * Creates and executes an ImageClassificationTask with the provided input. diff --git a/packages/ai/src/task/ImageEmbeddingTask.ts b/packages/ai/src/task/ImageEmbeddingTask.ts index 689197c3..8213443e 100644 --- a/packages/ai/src/task/ImageEmbeddingTask.ts +++ b/packages/ai/src/task/ImageEmbeddingTask.ts @@ -69,7 +69,6 @@ export class ImageEmbeddingTask extends AiVisionTask< } } - /** * Convenience function to run image embedding tasks. * Creates and executes an ImageEmbeddingTask with the provided input. diff --git a/packages/ai/src/task/ImageSegmentationTask.ts b/packages/ai/src/task/ImageSegmentationTask.ts index 1a5541ab..775558e2 100644 --- a/packages/ai/src/task/ImageSegmentationTask.ts +++ b/packages/ai/src/task/ImageSegmentationTask.ts @@ -105,7 +105,6 @@ export class ImageSegmentationTask extends AiVisionTask< } } - /** * Convenience function to run image segmentation tasks. * Creates and executes an ImageSegmentationTask with the provided input. diff --git a/packages/ai/src/task/ImageToTextTask.ts b/packages/ai/src/task/ImageToTextTask.ts index 082bf2fb..76a8cc2f 100644 --- a/packages/ai/src/task/ImageToTextTask.ts +++ b/packages/ai/src/task/ImageToTextTask.ts @@ -72,7 +72,6 @@ export class ImageToTextTask extends AiVisionTask< } } - /** * Convenience function to run image to text tasks. * Creates and executes an ImageToTextTask with the provided input. 
diff --git a/packages/ai/src/task/ObjectDetectionTask.ts b/packages/ai/src/task/ObjectDetectionTask.ts index 38d11680..e891029c 100644 --- a/packages/ai/src/task/ObjectDetectionTask.ts +++ b/packages/ai/src/task/ObjectDetectionTask.ts @@ -102,7 +102,6 @@ export class ObjectDetectionTask extends AiVisionTask< } } - /** * Convenience function to run object detection tasks. * Creates and executes an ObjectDetectionTask with the provided input. diff --git a/packages/ai/src/task/PoseLandmarkerTask.ts b/packages/ai/src/task/PoseLandmarkerTask.ts index 3d4b4404..955ddeb2 100644 --- a/packages/ai/src/task/PoseLandmarkerTask.ts +++ b/packages/ai/src/task/PoseLandmarkerTask.ts @@ -191,7 +191,6 @@ export class PoseLandmarkerTask extends AiVisionTask< } } - /** * Convenience function to run pose landmark detection tasks. * Creates and executes a PoseLandmarkerTask with the provided input. diff --git a/packages/ai/src/task/QueryExpanderTask.ts b/packages/ai/src/task/QueryExpanderTask.ts index c72bc3fd..135969c0 100644 --- a/packages/ai/src/task/QueryExpanderTask.ts +++ b/packages/ai/src/task/QueryExpanderTask.ts @@ -9,7 +9,6 @@ import { IExecuteContext, JobQueueTaskConfig, Task, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; @@ -299,7 +298,6 @@ export class QueryExpanderTask extends Task< } } - export const queryExpander = (input: QueryExpanderTaskInput, config?: JobQueueTaskConfig) => { return new QueryExpanderTask({} as QueryExpanderTaskInput, config).run(input); }; diff --git a/packages/ai/src/task/RerankerTask.ts b/packages/ai/src/task/RerankerTask.ts index 5baead4d..e92545fc 100644 --- a/packages/ai/src/task/RerankerTask.ts +++ b/packages/ai/src/task/RerankerTask.ts @@ -9,7 +9,6 @@ import { IExecuteContext, JobQueueTaskConfig, Task, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; @@ -326,7 +325,6 @@ export class RerankerTask extends Task { return new RerankerTask({} as RerankerTaskInput, config).run(input); }; diff --git a/packages/ai/src/task/StructuralParserTask.ts b/packages/ai/src/task/StructuralParserTask.ts index bdda6839..afb77abc 100644 --- a/packages/ai/src/task/StructuralParserTask.ts +++ b/packages/ai/src/task/StructuralParserTask.ts @@ -10,7 +10,6 @@ import { IExecuteContext, JobQueueTaskConfig, Task, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; @@ -140,7 +139,6 @@ export class StructuralParserTask extends Task< } } - export const structuralParser = (input: StructuralParserTaskInput, config?: JobQueueTaskConfig) => { return new StructuralParserTask({} as StructuralParserTaskInput, config).run(input); }; diff --git a/packages/ai/src/task/TextChunkerTask.ts b/packages/ai/src/task/TextChunkerTask.ts index c3c2a274..50375484 100644 --- a/packages/ai/src/task/TextChunkerTask.ts +++ b/packages/ai/src/task/TextChunkerTask.ts @@ -9,7 +9,6 @@ import { IExecuteContext, JobQueueTaskConfig, Task, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; @@ -343,7 +342,6 @@ export class TextChunkerTask extends Task< } } - export const textChunker = (input: TextChunkerTaskInput, config?: JobQueueTaskConfig) => { return new TextChunkerTask({} as TextChunkerTaskInput, config).run(input); }; diff --git a/packages/ai/src/task/TextClassificationTask.ts b/packages/ai/src/task/TextClassificationTask.ts index 1f1cb598..009e336d 100644 --- 
a/packages/ai/src/task/TextClassificationTask.ts +++ b/packages/ai/src/task/TextClassificationTask.ts @@ -97,7 +97,6 @@ export class TextClassificationTask extends AiTask< } } - /** * Convenience function to run text classifier tasks. * Creates and executes a TextClassificationTask with the provided input. diff --git a/packages/ai/src/task/TextEmbeddingTask.ts b/packages/ai/src/task/TextEmbeddingTask.ts index c296a4f7..27d60d60 100644 --- a/packages/ai/src/task/TextEmbeddingTask.ts +++ b/packages/ai/src/task/TextEmbeddingTask.ts @@ -73,7 +73,6 @@ export class TextEmbeddingTask extends AiTask { return new TopicSegmenterTask({} as TopicSegmenterTaskInput, config).run(input); }; diff --git a/packages/ai/src/task/UnloadModelTask.ts b/packages/ai/src/task/UnloadModelTask.ts index 1e5c3c7c..5f3b929c 100644 --- a/packages/ai/src/task/UnloadModelTask.ts +++ b/packages/ai/src/task/UnloadModelTask.ts @@ -4,14 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { - CreateWorkflow, - DeReplicateFromSchema, - JobQueueTaskConfig, - TaskRegistry, - TypeReplicateArray, - Workflow, -} from "@workglow/task-graph"; +import { CreateWorkflow, JobQueueTaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; import { TypeModel } from "./base/AiTaskSchemas"; @@ -63,7 +56,6 @@ export class UnloadModelTask extends AiTask< public static cacheable = false; } - /** * Unload a model from memory and clear its cache. * diff --git a/packages/ai/src/task/VectorQuantizeTask.ts b/packages/ai/src/task/VectorQuantizeTask.ts index 0fe2bc5d..c84ef750 100644 --- a/packages/ai/src/task/VectorQuantizeTask.ts +++ b/packages/ai/src/task/VectorQuantizeTask.ts @@ -4,13 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { - CreateWorkflow, - JobQueueTaskConfig, - Task, - TaskRegistry, - Workflow, -} from "@workglow/task-graph"; +import { CreateWorkflow, JobQueueTaskConfig, Task, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema, @@ -194,16 +188,16 @@ export class VectorQuantizeTask extends Task< if (values.length === 0) { return { min: 0, max: 1 }; } - + let min = values[0]; let max = values[0]; - + for (let i = 1; i < values.length; i++) { const val = values[i]; if (val < min) min = val; if (val > max) max = val; } - + return { min, max }; } @@ -238,7 +232,6 @@ export class VectorQuantizeTask extends Task< } } - export const vectorQuantize = (input: VectorQuantizeTaskInput, config?: JobQueueTaskConfig) => { return new VectorQuantizeTask({} as VectorQuantizeTaskInput, config).run(input); }; diff --git a/packages/ai/src/task/VectorSimilarityTask.ts b/packages/ai/src/task/VectorSimilarityTask.ts index f70b690e..81cc5919 100644 --- a/packages/ai/src/task/VectorSimilarityTask.ts +++ b/packages/ai/src/task/VectorSimilarityTask.ts @@ -4,13 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { - CreateWorkflow, - GraphAsTask, - JobQueueTaskConfig, - TaskRegistry, - Workflow, -} from "@workglow/task-graph"; +import { CreateWorkflow, GraphAsTask, JobQueueTaskConfig, Workflow } from "@workglow/task-graph"; import { cosineSimilarity, DataPortSchema, @@ -141,7 +135,6 @@ export class VectorSimilarityTask extends GraphAsTask< } } - export const similarity = (input: VectorSimilarityTaskInput, config?: JobQueueTaskConfig) => { return new VectorSimilarityTask({} as VectorSimilarityTaskInput, config).run(input); }; diff --git a/packages/ai/src/task/index.ts b/packages/ai/src/task/index.ts index 419b9318..00436895 
100644 --- a/packages/ai/src/task/index.ts +++ b/packages/ai/src/task/index.ts @@ -48,47 +48,49 @@ import { VectorSimilarityTask } from "./VectorSimilarityTask"; // Register all AI tasks with the TaskRegistry. // Centralized registration ensures tasks are available for JSON deserialization // and prevents tree-shaking issues. -[ - BackgroundRemovalTask, - ChunkToVectorTask, - ContextBuilderTask, - DocumentEnricherTask, - DocumentNodeRetrievalTask, - DocumentNodeVectorHybridSearchTask, - DocumentNodeVectorSearchTask, - DocumentNodeVectorUpsertTask, - DownloadModelTask, - FaceDetectorTask, - FaceLandmarkerTask, - GestureRecognizerTask, - HandLandmarkerTask, - HierarchicalChunkerTask, - HierarchyJoinTask, - ImageClassificationTask, - ImageEmbeddingTask, - ImageSegmentationTask, - ImageToTextTask, - ObjectDetectionTask, - PoseLandmarkerTask, - QueryExpanderTask, - RerankerTask, - StructuralParserTask, - TextChunkerTask, - TextClassificationTask, - TextEmbeddingTask, - TextFillMaskTask, - TextGenerationTask, - TextLanguageDetectionTask, - TextNamedEntityRecognitionTask, - TextQuestionAnswerTask, - TextRewriterTask, - TextSummaryTask, - TextTranslationTask, - TopicSegmenterTask, - UnloadModelTask, - VectorQuantizeTask, - VectorSimilarityTask, -].map(TaskRegistry.registerTask); +export const registerAiTasks = () => { + [ + BackgroundRemovalTask, + ChunkToVectorTask, + ContextBuilderTask, + DocumentEnricherTask, + DocumentNodeRetrievalTask, + DocumentNodeVectorHybridSearchTask, + DocumentNodeVectorSearchTask, + DocumentNodeVectorUpsertTask, + DownloadModelTask, + FaceDetectorTask, + FaceLandmarkerTask, + GestureRecognizerTask, + HandLandmarkerTask, + HierarchicalChunkerTask, + HierarchyJoinTask, + ImageClassificationTask, + ImageEmbeddingTask, + ImageSegmentationTask, + ImageToTextTask, + ObjectDetectionTask, + PoseLandmarkerTask, + QueryExpanderTask, + RerankerTask, + StructuralParserTask, + TextChunkerTask, + TextClassificationTask, + TextEmbeddingTask, + TextFillMaskTask, + TextGenerationTask, + TextLanguageDetectionTask, + TextNamedEntityRecognitionTask, + TextQuestionAnswerTask, + TextRewriterTask, + TextSummaryTask, + TextTranslationTask, + TopicSegmenterTask, + UnloadModelTask, + VectorQuantizeTask, + VectorSimilarityTask, + ].map(TaskRegistry.registerTask); +}; export * from "./BackgroundRemovalTask"; export * from "./base/AiTask"; diff --git a/packages/task-graph/README.md b/packages/task-graph/README.md index 73824243..66e21e0b 100644 --- a/packages/task-graph/README.md +++ b/packages/task-graph/README.md @@ -922,58 +922,6 @@ try { ## Advanced Patterns -### Array Tasks (Parallel Processing) - -```typescript -class ArrayProcessorTask extends ArrayTask<{ items: string[] }, { results: string[] }> { - static readonly type = "ArrayProcessorTask"; - - static inputSchema() { - return { - type: "object", - properties: { - items: { - type: "array", - items: { - type: "string", - }, - }, - }, - required: ["items"], - additionalProperties: false, - } as const satisfies DataPortSchema; - } - - static outputSchema() { - return { - type: "object", - properties: { - results: { - type: "array", - items: { - type: "string", - }, - }, - }, - required: ["results"], - additionalProperties: false, - } as const satisfies DataPortSchema; - } - - async execute(input: { items: string[] }) { - return { results: input.items.map((item) => item.toUpperCase()) }; - } -} - -// Process array items in parallel -const task = new ArrayProcessorTask({ - items: ["hello", "world", "foo", "bar"], -}); - -const result = await 
task.run(); -// { results: ["HELLO", "WORLD", "FOO", "BAR"] } -``` - ### Job Queue Tasks ```typescript diff --git a/packages/task-graph/src/common.ts b/packages/task-graph/src/common.ts index 9ac5e134..a39a5b03 100644 --- a/packages/task-graph/src/common.ts +++ b/packages/task-graph/src/common.ts @@ -4,23 +4,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -export * from "./task/Task"; - -export * from "./task/ArrayTask"; -export * from "./task/ConditionalTask"; -export * from "./task/GraphAsTask"; -export * from "./task/GraphAsTaskRunner"; -export * from "./task/InputResolver"; -export * from "./task/ITask"; -export * from "./task/JobQueueFactory"; -export * from "./task/JobQueueTask"; -export * from "./task/TaskError"; -export * from "./task/TaskEvents"; -export * from "./task/TaskJSON"; -export * from "./task/TaskQueueRegistry"; -export * from "./task/TaskRegistry"; -export * from "./task/TaskTypes"; - export * from "./task-graph/Dataflow"; export * from "./task-graph/DataflowEvents"; @@ -33,6 +16,8 @@ export * from "./task-graph/Conversions"; export * from "./task-graph/IWorkflow"; export * from "./task-graph/Workflow"; +export * from "./task"; + export * from "./storage/TaskGraphRepository"; export * from "./storage/TaskGraphTabularRepository"; export * from "./storage/TaskOutputRepository"; diff --git a/packages/task-graph/src/task/README.md b/packages/task-graph/src/task/README.md index a767d9dd..8c5489da 100644 --- a/packages/task-graph/src/task/README.md +++ b/packages/task-graph/src/task/README.md @@ -7,7 +7,6 @@ This module provides a flexible task processing system with support for various - [Task Types](#task-types) - [A Simple Task](#a-simple-task) - [GraphAsTask](#graphastask) - - [ArrayTask](#arraytask) - [Job Queue Tasks](#job-queue-tasks) - [Task Lifecycle](#task-lifecycle) - [Event Handling](#event-handling) @@ -23,7 +22,6 @@ This module provides a flexible task processing system with support for various ### Core Classes - `Task`: Base class implementing core task functionality -- `ArrayTask`: Executes a task or a task with multiple inputs in parallel with a subGraph - `JobQueueTask`: Integrates with job queue system for distributed processing ## Task Types @@ -74,13 +72,6 @@ class MyTask extends Task { ### GraphAsTask - GraphAsTask tasks are tasks that contain other tasks. They are represented as an internal TaskGraph. -- A ArrayTask is a compound task that can run a task as normal, or if the inputs are an array and the input definition has x-replicate=true defined for that input, then the task will run parallel copies with a subGraph. - -### ArrayTask - -- ArrayTask is a task that can run a task as normal, or if the inputs are an arryay and the input definition has x-replicate=true, then the task will run parallel copies with a subGraph. -- The subGraph is a TaskGraph that is created from the inputs of the task. -- The results of the subGraph are combined such that the outputs are turned into arrays. 
### Job Queue Tasks diff --git a/packages/task-graph/src/task/index.ts b/packages/task-graph/src/task/index.ts new file mode 100644 index 00000000..cce47aeb --- /dev/null +++ b/packages/task-graph/src/task/index.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +export * from "./ConditionalTask"; +export * from "./GraphAsTask"; +export * from "./GraphAsTaskRunner"; +export * from "./InputResolver"; +export * from "./ITask"; +export * from "./JobQueueFactory"; +export * from "./JobQueueTask"; +export * from "./Task"; +export * from "./TaskError"; +export * from "./TaskEvents"; +export * from "./TaskJSON"; +export * from "./TaskQueueRegistry"; +export * from "./TaskRegistry"; +export * from "./TaskTypes"; + +import { ConditionalTask } from "./ConditionalTask"; +import { GraphAsTask } from "./GraphAsTask"; +import { TaskRegistry } from "./TaskRegistry"; + +export const registerBaseTasks = () => { + [ConditionalTask, GraphAsTask].map(TaskRegistry.registerTask); +}; diff --git a/packages/tasks/README.md b/packages/tasks/README.md index b79ad65f..485a1ea0 100644 --- a/packages/tasks/README.md +++ b/packages/tasks/README.md @@ -14,6 +14,7 @@ A package of task types for common operations, workflow management, and data pro - [JavaScriptTask](#javascripttask) - [LambdaTask](#lambdatask) - [JsonTask](#jsontask) + - [ArrayTask](#arraytask) - [Workflow Integration](#workflow-integration) - [Error Handling](#error-handling) - [Configuration](#configuration) @@ -494,6 +495,78 @@ const dynamicWorkflow = await new JsonTask({ - Automatic data flow between tasks - Enables configuration-driven workflows +### ArrayTask + +A compound task that processes arrays by either executing directly for non-array inputs or creating parallel task instances for array inputs. Supports parallel processing of array elements and combination generation when multiple inputs are arrays. 
+ +**Key Features:** + +- Automatically handles single values or arrays +- Parallel execution for array inputs +- Generates all combinations when multiple inputs are arrays +- Uses `x-replicate` annotation to mark array-capable inputs + +**Examples:** + +```typescript +import { ArrayTask, DataPortSchema } from "@workglow/tasks"; + +class ArrayProcessorTask extends ArrayTask<{ items: string[] }, { results: string[] }> { + static readonly type = "ArrayProcessorTask"; + + static inputSchema() { + return { + type: "object", + properties: { + items: { + type: "array", + items: { + type: "string", + }, + }, + }, + required: ["items"], + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + static outputSchema() { + return { + type: "object", + properties: { + results: { + type: "array", + items: { + type: "string", + }, + }, + }, + required: ["results"], + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + async execute(input: { items: string[] }) { + return { results: input.items.map((item) => item.toUpperCase()) }; + } +} + +// Process array items in parallel +const task = new ArrayProcessorTask({ + items: ["hello", "world", "foo", "bar"], +}); + +const result = await task.run(); +// { results: ["HELLO", "WORLD", "FOO", "BAR"] } +``` + +**Features:** + +- Parallel processing of array elements +- Automatic task instance creation per array element +- Combination generation for multiple array inputs +- Seamless single-value and array handling + ## Workflow Integration All tasks can be used standalone or integrated into workflows: diff --git a/packages/tasks/src/common.ts b/packages/tasks/src/common.ts index 84e96928..19ecc94b 100644 --- a/packages/tasks/src/common.ts +++ b/packages/tasks/src/common.ts @@ -4,7 +4,20 @@ * SPDX-License-Identifier: Apache-2.0 */ +export * from "./task/ArrayTask"; +export * from "./task/DebugLogTask"; +export * from "./task/DelayTask"; +export * from "./task/FetchUrlTask"; +export * from "./task/InputTask"; +export * from "./task/JavaScriptTask"; +export * from "./task/JsonTask"; +export * from "./task/LambdaTask"; +export * from "./task/MergeTask"; +export * from "./task/OutputTask"; +export * from "./task/SplitTask"; + import { TaskRegistry } from "@workglow/task-graph"; +import { ArrayTask } from "./task/ArrayTask"; import { DebugLogTask } from "./task/DebugLogTask"; import { DelayTask } from "./task/DelayTask"; import { FetchUrlTask } from "./task/FetchUrlTask"; @@ -19,26 +32,18 @@ import { SplitTask } from "./task/SplitTask"; // Register all common tasks with the TaskRegistry. // Centralized registration ensures tasks are available for JSON deserialization // and prevents tree-shaking issues. 
-[ - DebugLogTask, - DelayTask, - FetchUrlTask, - InputTask, - JavaScriptTask, - JsonTask, - LambdaTask, - MergeTask, - OutputTask, - SplitTask, -].map(TaskRegistry.registerTask); - -export * from "./task/DebugLogTask"; -export * from "./task/DelayTask"; -export * from "./task/FetchUrlTask"; -export * from "./task/InputTask"; -export * from "./task/JavaScriptTask"; -export * from "./task/JsonTask"; -export * from "./task/LambdaTask"; -export * from "./task/MergeTask"; -export * from "./task/OutputTask"; -export * from "./task/SplitTask"; +export const registerCommonTasks = () => { + [ + ArrayTask, + DebugLogTask, + DelayTask, + FetchUrlTask, + InputTask, + JavaScriptTask, + JsonTask, + LambdaTask, + MergeTask, + OutputTask, + SplitTask, + ].map(TaskRegistry.registerTask); +}; diff --git a/packages/task-graph/src/task/ArrayTask.ts b/packages/tasks/src/task/ArrayTask.ts similarity index 95% rename from packages/task-graph/src/task/ArrayTask.ts rename to packages/tasks/src/task/ArrayTask.ts index d2875d5f..1b10c788 100644 --- a/packages/task-graph/src/task/ArrayTask.ts +++ b/packages/tasks/src/task/ArrayTask.ts @@ -12,12 +12,18 @@ import { type DataPortSchema, } from "@workglow/util"; -import { TaskGraph } from "../task-graph/TaskGraph"; -import { GraphResultArray, PROPERTY_ARRAY } from "../task-graph/TaskGraphRunner"; -import { GraphAsTask } from "./GraphAsTask"; -import { GraphAsTaskRunner } from "./GraphAsTaskRunner"; -import { JsonTaskItem, TaskGraphItemJson } from "./TaskJSON"; -import { TaskConfig, TaskInput, TaskOutput } from "./TaskTypes"; +import { + GraphAsTask, + GraphAsTaskRunner, + GraphResultArray, + JsonTaskItem, + PROPERTY_ARRAY, + TaskConfig, + TaskGraph, + TaskGraphItemJson, + TaskInput, + TaskOutput, +} from "@workglow/task-graph"; export const TypeReplicateArray = ( type: T, diff --git a/packages/tasks/src/task/DelayTask.ts b/packages/tasks/src/task/DelayTask.ts index 8b94b6dc..546474ca 100644 --- a/packages/tasks/src/task/DelayTask.ts +++ b/packages/tasks/src/task/DelayTask.ts @@ -10,7 +10,6 @@ import { Task, TaskAbortedError, TaskConfig, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { DataPortSchema, FromSchema, sleep } from "@workglow/util"; diff --git a/packages/tasks/src/task/FetchUrlTask.ts b/packages/tasks/src/task/FetchUrlTask.ts index 82115e8f..dcc363f6 100644 --- a/packages/tasks/src/task/FetchUrlTask.ts +++ b/packages/tasks/src/task/FetchUrlTask.ts @@ -17,7 +17,6 @@ import { JobQueueTaskConfig, TaskConfigurationError, TaskInvalidInputError, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; diff --git a/packages/tasks/src/task/FileLoaderTask.server.ts b/packages/tasks/src/task/FileLoaderTask.server.ts index 1f6d54d2..32978d61 100644 --- a/packages/tasks/src/task/FileLoaderTask.server.ts +++ b/packages/tasks/src/task/FileLoaderTask.server.ts @@ -9,7 +9,6 @@ import { IExecuteContext, JobQueueTaskConfig, TaskAbortedError, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { readFile } from "node:fs/promises"; diff --git a/packages/tasks/src/task/FileLoaderTask.ts b/packages/tasks/src/task/FileLoaderTask.ts index aea490a5..da0659f3 100644 --- a/packages/tasks/src/task/FileLoaderTask.ts +++ b/packages/tasks/src/task/FileLoaderTask.ts @@ -10,7 +10,6 @@ import { JobQueueTaskConfig, Task, TaskAbortedError, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; diff --git a/packages/tasks/src/task/JsonTask.ts 
b/packages/tasks/src/task/JsonTask.ts index 2a7f2aec..32b7d565 100644 --- a/packages/tasks/src/task/JsonTask.ts +++ b/packages/tasks/src/task/JsonTask.ts @@ -12,7 +12,6 @@ import { JsonTaskItem, TaskConfig, TaskConfigurationError, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; diff --git a/packages/tasks/src/task/LambdaTask.ts b/packages/tasks/src/task/LambdaTask.ts index 9990941a..a2d2d9b2 100644 --- a/packages/tasks/src/task/LambdaTask.ts +++ b/packages/tasks/src/task/LambdaTask.ts @@ -14,7 +14,6 @@ import { TaskConfigurationError, TaskInput, TaskOutput, - TaskRegistry, Workflow, } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; diff --git a/packages/tasks/src/task/MergeTask.ts b/packages/tasks/src/task/MergeTask.ts index 3726633d..396e7f8f 100644 --- a/packages/tasks/src/task/MergeTask.ts +++ b/packages/tasks/src/task/MergeTask.ts @@ -4,14 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { - CreateWorkflow, - IExecuteContext, - Task, - TaskConfig, - TaskRegistry, - Workflow, -} from "@workglow/task-graph"; +import { CreateWorkflow, IExecuteContext, Task, TaskConfig, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; const inputSchema = { diff --git a/packages/test/src/binding/RegisterTasks.ts b/packages/test/src/binding/RegisterTasks.ts new file mode 100644 index 00000000..4d2405ce --- /dev/null +++ b/packages/test/src/binding/RegisterTasks.ts @@ -0,0 +1,15 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { registerAiTasks } from "@workglow/ai"; +import { registerBaseTasks } from "@workglow/task-graph"; +import { registerCommonTasks } from "@workglow/tasks"; + +export const registerTasks = () => { + registerBaseTasks(); + registerCommonTasks(); + registerAiTasks(); +}; diff --git a/packages/test/src/samples/ONNXModelSamples.ts b/packages/test/src/samples/ONNXModelSamples.ts index 373f5831..8178996d 100644 --- a/packages/test/src/samples/ONNXModelSamples.ts +++ b/packages/test/src/samples/ONNXModelSamples.ts @@ -84,7 +84,7 @@ export async function registerHuggingfaceLocalModels(): Promise { model_id: "onnx:Xenova/distilbert-base-uncased-distilled-squad:q8", title: "distilbert-base-uncased-distilled-squad", description: "Xenova/distilbert-base-uncased-distilled-squad quantized to 8bit", - tasks: ["TextQuestionAnsweringTask"], + tasks: ["TextQuestionAnswerTask"], provider: HF_TRANSFORMERS_ONNX, provider_config: { pipeline: "question-answering", @@ -92,6 +92,19 @@ export async function registerHuggingfaceLocalModels(): Promise { }, metadata: {}, }, + { + model_id: "onnx:onnx-community/ModernBERT-finetuned-squad-ONNX", + title: "ModernBERT-finetuned-squad-ONNX", + description: "onnx-community/ModernBERT-finetuned-squad-ONNX quantized to int8", + tasks: ["TextQuestionAnswerTask"], + provider: HF_TRANSFORMERS_ONNX, + provider_config: { + pipeline: "question-answering", + model_path: "onnx-community/ModernBERT-finetuned-squad-ONNX", + dtype: "int8", + }, + metadata: {}, + }, { model_id: "onnx:Xenova/gpt2:q8", title: "gpt2", diff --git a/packages/test/src/test/rag/ChunkToVector.test.ts b/packages/test/src/test/rag/ChunkToVector.test.ts index 7a236125..e1d00ea1 100644 --- a/packages/test/src/test/rag/ChunkToVector.test.ts +++ b/packages/test/src/test/rag/ChunkToVector.test.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import "@workglow/ai"; // Trigger 
Workflow prototype extensions import type { ChunkToVectorTaskOutput, HierarchicalChunkerTaskOutput } from "@workglow/ai"; import { type ChunkNode, NodeIdGenerator, StructuralParser } from "@workglow/storage"; import { Workflow } from "@workglow/task-graph"; diff --git a/packages/test/src/test/rag/RagWorkflow.test.ts b/packages/test/src/test/rag/RagWorkflow.test.ts index 327ac19a..62508381 100644 --- a/packages/test/src/test/rag/RagWorkflow.test.ts +++ b/packages/test/src/test/rag/RagWorkflow.test.ts @@ -38,8 +38,10 @@ import { InMemoryModelRepository, + retrieval, RetrievalTaskOutput, setGlobalModelRepository, + textQuestionAnswer, TextQuestionAnswerTaskOutput, VectorStoreUpsertTaskOutput, } from "@workglow/ai"; @@ -66,7 +68,7 @@ describe("RAG Workflow End-to-End", () => { const embeddingModel = "onnx:Xenova/all-MiniLM-L6-v2:q8"; const summaryModel = "onnx:Falconsai/text_summarization:fp32"; const nerModel = "onnx:onnx-community/NeuroBERT-NER-ONNX:q8"; - const qaModel = "onnx:Xenova/distilbert-base-uncased-distilled-squad:q8"; + const qaModel = "onnx:onnx-community/ModernBERT-finetuned-squad-ONNX"; beforeAll(async () => { // Setup task queue and model repository @@ -94,7 +96,7 @@ describe("RAG Workflow End-to-End", () => { setTaskQueueRegistry(null); }); - it.only("should ingest markdown documents with NER enrichment", async () => { + it("should ingest markdown documents with NER enrichment", async () => { // Find markdown files in docs folder const docsPath = join(process.cwd(), "docs", "background"); const files = readdirSync(docsPath).filter((f) => f.endsWith(".md")); @@ -117,8 +119,8 @@ describe("RAG Workflow End-to-End", () => { sourceUri: filePath, }) .documentEnricher({ - generateSummaries: true, - extractEntities: true, + generateSummaries: false, + extractEntities: false, summaryModel, nerModel, }) @@ -189,19 +191,14 @@ describe("RAG Workflow End-to-End", () => { console.log(`\nAnswering question: "${question}"`); - // Step 1: Retrieve relevant context - const retrievalWorkflow = new Workflow(); - - retrievalWorkflow.retrieval({ + const retrievalResult = await retrieval({ repository: vectorRepoName, query: question, model: embeddingModel, topK: 3, - scoreThreshold: 0.2, // Lower threshold to find results + scoreThreshold: 0.2, }); - const retrievalResult = (await retrievalWorkflow.run()) as RetrievalTaskOutput; - expect(retrievalResult.chunks).toBeDefined(); if (retrievalResult.chunks.length === 0) { @@ -216,17 +213,12 @@ describe("RAG Workflow End-to-End", () => { console.log(`Context length: ${context.length} characters`); - // Step 3: Answer question using context - const qaWorkflow = new Workflow(); - - qaWorkflow.textQuestionAnswer({ + const answer = await textQuestionAnswer({ context, question, model: qaModel, }); - const answer = (await qaWorkflow.run()) as TextQuestionAnswerTaskOutput; - // Verify answer expect(answer.text).toBeDefined(); expect(typeof answer.text).toBe("string"); diff --git a/packages/test/src/test/task/ArrayTask.test.ts b/packages/test/src/test/task/ArrayTask.test.ts index fd4f8318..c8c4614d 100644 --- a/packages/test/src/test/task/ArrayTask.test.ts +++ b/packages/test/src/test/task/ArrayTask.test.ts @@ -5,7 +5,6 @@ */ import { - ArrayTask, Dataflow, IExecuteContext, ITask, @@ -17,6 +16,7 @@ import { TaskOutput, TaskStatus, } from "@workglow/task-graph"; +import { ArrayTask } from "@workglow/tasks"; import { ConvertAllToOptionalArray, DataPortSchema } from "@workglow/util"; import { describe, expect, test, vi } from "vitest"; From 
2db58e82c41207ded9fed17941be03fd73ece3e7 Mon Sep 17 00:00:00 2001 From: Steven Roussey Date: Mon, 12 Jan 2026 00:20:53 +0000 Subject: [PATCH 08/14] [feat] Rename Chunk Vector Management Tasks and Update Documentation - Added new tasks for chunk vector management: ChunkRetrievalTask, ChunkVectorHybridSearchTask, ChunkVectorSearchTask, and ChunkVectorUpsertTask. - Updated existing documentation to reflect the new chunk vector tasks and their functionalities. - Refactored related components to utilize the new chunk vector repository structure, enhancing the overall architecture for vector storage and retrieval. - Improved task registration in the index file for better maintainability and accessibility. --- docs/developers/03_extending.md | 48 +++++++++---------- packages/ai/README.md | 28 +++++------ ...RetrievalTask.ts => ChunkRetrievalTask.ts} | 9 ++-- ...Task.ts => ChunkVectorHybridSearchTask.ts} | 17 +++---- ...SearchTask.ts => ChunkVectorSearchTask.ts} | 17 +++---- ...UpsertTask.ts => ChunkVectorUpsertTask.ts} | 17 +++---- packages/ai/src/task/index.ts | 22 ++++----- .../ChunkVectorRepositoryRegistry.ts} | 36 ++++++-------- .../ChunkVectorSchema.ts} | 10 ++-- .../IChunkVectorRepository.ts} | 4 +- .../InMemoryChunkVectorRepository.ts} | 32 ++++++------- .../PostgresChunkVectorRepository.ts} | 38 +++++++-------- .../README.md | 0 .../SqliteChunkVectorRepository.ts} | 32 ++++++------- packages/storage/src/common-server.ts | 4 +- packages/storage/src/common.ts | 8 ++-- .../src/document/DocumentRepository.ts | 12 ++--- packages/storage/src/util/RepositorySchema.ts | 2 +- .../rag/DocumentNodeRetrievalTask.test.ts | 11 ++--- .../rag/DocumentNodeVectorSearchTask.test.ts | 39 +++++++-------- .../DocumentNodeVectorStoreUpsertTask.test.ts | 35 +++++++------- .../src/test/rag/DocumentRepository.test.ts | 6 +-- packages/test/src/test/rag/EndToEnd.test.ts | 4 +- packages/test/src/test/rag/FullChain.test.ts | 8 +--- .../src/test/rag/HybridSearchTask.test.ts | 13 ++--- .../test/src/test/rag/RagWorkflow.test.ts | 10 ++-- 26 files changed, 207 insertions(+), 255 deletions(-) rename packages/ai/src/task/{DocumentNodeRetrievalTask.ts => ChunkRetrievalTask.ts} (96%) rename packages/ai/src/task/{DocumentNodeVectorHybridSearchTask.ts => ChunkVectorHybridSearchTask.ts} (91%) rename packages/ai/src/task/{DocumentNodeVectorSearchTask.ts => ChunkVectorSearchTask.ts} (88%) rename packages/ai/src/task/{DocumentNodeVectorUpsertTask.ts => ChunkVectorUpsertTask.ts} (88%) rename packages/storage/src/{document-node-vector/DocumentNodeVectorRepositoryRegistry.ts => chunk-vector/ChunkVectorRepositoryRegistry.ts} (64%) rename packages/storage/src/{document-node-vector/DocumentNodeVectorSchema.ts => chunk-vector/ChunkVectorSchema.ts} (73%) rename packages/storage/src/{document-node-vector/IDocumentNodeVectorRepository.ts => chunk-vector/IChunkVectorRepository.ts} (96%) rename packages/storage/src/{document-node-vector/InMemoryDocumentNodeVectorRepository.ts => chunk-vector/InMemoryChunkVectorRepository.ts} (86%) rename packages/storage/src/{document-node-vector/PostgresDocumentNodeVectorRepository.ts => chunk-vector/PostgresChunkVectorRepository.ts} (89%) rename packages/storage/src/{document-node-vector => chunk-vector}/README.md (100%) rename packages/storage/src/{document-node-vector/SqliteDocumentNodeVectorRepository.ts => chunk-vector/SqliteChunkVectorRepository.ts} (87%) diff --git a/docs/developers/03_extending.md b/docs/developers/03_extending.md index add9693a..b9cbddf9 100644 --- a/docs/developers/03_extending.md 
+++ b/docs/developers/03_extending.md @@ -138,13 +138,13 @@ When defining task input schemas, you can use `format` annotations to enable aut The system supports several format annotations out of the box: -| Format | Description | Helper Function | -| --------------------------------- | ----------------------------------- | ------------------------------------ | -| `model` | Any AI model configuration | `TypeModel()` | -| `model:TaskName` | Model compatible with specific task | — | -| `repository:tabular` | Tabular data repository | `TypeTabularRepository()` | -| `repository:document-node-vector` | Vector storage repository | `TypeDocumentNodeVectorRepository()` | -| `repository:document` | Document repository | `TypeDocumentRepository()` | +| Format | Description | Helper Function | +| --------------------------------- | ----------------------------------- | ----------------------------- | +| `model` | Any AI model configuration | `TypeModel()` | +| `model:TaskName` | Model compatible with specific task | — | +| `repository:tabular` | Tabular data repository | `TypeTabularRepository()` | +| `repository:document-node-vector` | Vector storage repository | `TypeChunkVectorRepository()` | +| `repository:document` | Document repository | `TypeDocumentRepository()` | ### Example: Using Format Annotations @@ -263,26 +263,26 @@ The `@workglow/ai` package provides a comprehensive set of tasks for building RA ### Vector and Embedding Tasks -| Task | Description | -| ------------------------------ | ---------------------------------------------- | -| `TextEmbeddingTask` | Generates embeddings using configurable models | -| `ChunkToVectorTask` | Transforms chunks to vector store format | -| `DocumentNodeVectorUpsertTask` | Stores vectors in a repository | -| `DocumentNodeVectorSearchTask` | Searches vectors by similarity | -| `VectorQuantizeTask` | Quantizes vectors for storage efficiency | +| Task | Description | +| ----------------------- | ---------------------------------------------- | +| `TextEmbeddingTask` | Generates embeddings using configurable models | +| `ChunkToVectorTask` | Transforms chunks to vector store format | +| `ChunkVectorUpsertTask` | Stores vectors in a repository | +| `ChunkVectorSearchTask` | Searches vectors by similarity | +| `VectorQuantizeTask` | Quantizes vectors for storage efficiency | ### Retrieval and Generation Tasks -| Task | Description | -| ------------------------------------ | --------------------------------------------- | -| `QueryExpanderTask` | Expands queries for better retrieval coverage | -| `DocumentNodeVectorHybridSearchTask` | Combines vector and full-text search | -| `RerankerTask` | Reranks search results for relevance | -| `HierarchyJoinTask` | Enriches results with parent context | -| `ContextBuilderTask` | Builds context for LLM prompts | -| `DocumentNodeRetrievalTask` | Orchestrates end-to-end retrieval | -| `TextQuestionAnswerTask` | Generates answers from context | -| `TextGenerationTask` | General text generation | +| Task | Description | +| ----------------------------- | --------------------------------------------- | +| `QueryExpanderTask` | Expands queries for better retrieval coverage | +| `ChunkVectorHybridSearchTask` | Combines vector and full-text search | +| `RerankerTask` | Reranks search results for relevance | +| `HierarchyJoinTask` | Enriches results with parent context | +| `ContextBuilderTask` | Builds context for LLM prompts | +| `DocumentNodeRetrievalTask` | Orchestrates end-to-end retrieval | +| `TextQuestionAnswerTask` | 
Generates answers from context | +| `TextGenerationTask` | General text generation | ### Chainable RAG Pipeline Example diff --git a/packages/ai/README.md b/packages/ai/README.md index a1bf267b..fd3fde8b 100644 --- a/packages/ai/README.md +++ b/packages/ai/README.md @@ -412,23 +412,23 @@ The AI package provides a comprehensive set of tasks for building RAG pipelines. ### Vector and Storage Tasks -| Task | Description | -| ------------------------------ | ---------------------------------------- | -| `ChunkToVectorTask` | Transforms chunks to vector store format | -| `DocumentNodeVectorUpsertTask` | Stores vectors in a repository | -| `DocumentNodeVectorSearchTask` | Searches vectors by similarity | -| `VectorQuantizeTask` | Quantizes vectors for storage efficiency | +| Task | Description | +| ----------------------- | ---------------------------------------- | +| `ChunkToVectorTask` | Transforms chunks to vector store format | +| `ChunkVectorUpsertTask` | Stores vectors in a repository | +| `ChunkVectorSearchTask` | Searches vectors by similarity | +| `VectorQuantizeTask` | Quantizes vectors for storage efficiency | ### Retrieval and Generation Tasks -| Task | Description | -| ------------------------------------ | --------------------------------------------- | -| `QueryExpanderTask` | Expands queries for better retrieval coverage | -| `DocumentNodeVectorHybridSearchTask` | Combines vector and full-text search | -| `RerankerTask` | Reranks search results for relevance | -| `HierarchyJoinTask` | Enriches results with parent context | -| `ContextBuilderTask` | Builds context for LLM prompts | -| `DocumentNodeRetrievalTask` | Orchestrates end-to-end retrieval | +| Task | Description | +| ----------------------------- | --------------------------------------------- | +| `QueryExpanderTask` | Expands queries for better retrieval coverage | +| `ChunkVectorHybridSearchTask` | Combines vector and full-text search | +| `RerankerTask` | Reranks search results for relevance | +| `HierarchyJoinTask` | Enriches results with parent context | +| `ContextBuilderTask` | Builds context for LLM prompts | +| `DocumentNodeRetrievalTask` | Orchestrates end-to-end retrieval | ### Complete RAG Workflow Example diff --git a/packages/ai/src/task/DocumentNodeRetrievalTask.ts b/packages/ai/src/task/ChunkRetrievalTask.ts similarity index 96% rename from packages/ai/src/task/DocumentNodeRetrievalTask.ts rename to packages/ai/src/task/ChunkRetrievalTask.ts index b9b2bd92..9f91b457 100644 --- a/packages/ai/src/task/DocumentNodeRetrievalTask.ts +++ b/packages/ai/src/task/ChunkRetrievalTask.ts @@ -4,10 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { - AnyDocumentNodeVectorRepository, - TypeDocumentNodeVectorRepository, -} from "@workglow/storage"; +import { AnyChunkVectorRepository, TypeChunkVectorRepository } from "@workglow/storage"; import { CreateWorkflow, IExecuteContext, @@ -28,7 +25,7 @@ import { TextEmbeddingTask } from "./TextEmbeddingTask"; const inputSchema = { type: "object", properties: { - repository: TypeDocumentNodeVectorRepository({ + repository: TypeChunkVectorRepository({ title: "Document Chunk Vector Repository", description: "The document chunk vector repository instance to search in", }), @@ -175,7 +172,7 @@ export class DocumentNodeRetrievalTask extends Task< } = input; // Repository is resolved by input resolver system before execution - const repo = repository as AnyDocumentNodeVectorRepository; + const repo = repository as AnyChunkVectorRepository; // Determine query vector let 
queryVector: TypedArray; diff --git a/packages/ai/src/task/DocumentNodeVectorHybridSearchTask.ts b/packages/ai/src/task/ChunkVectorHybridSearchTask.ts similarity index 91% rename from packages/ai/src/task/DocumentNodeVectorHybridSearchTask.ts rename to packages/ai/src/task/ChunkVectorHybridSearchTask.ts index 90941c0a..df9b12cf 100644 --- a/packages/ai/src/task/DocumentNodeVectorHybridSearchTask.ts +++ b/packages/ai/src/task/ChunkVectorHybridSearchTask.ts @@ -4,10 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { - AnyDocumentNodeVectorRepository, - TypeDocumentNodeVectorRepository, -} from "@workglow/storage"; +import { AnyChunkVectorRepository, TypeChunkVectorRepository } from "@workglow/storage"; import { CreateWorkflow, IExecuteContext, @@ -25,7 +22,7 @@ import { const inputSchema = { type: "object", properties: { - repository: TypeDocumentNodeVectorRepository({ + repository: TypeChunkVectorRepository({ title: "Document Chunk Vector Repository", description: "The document chunk vector repository instance to search in (must support hybridSearch)", @@ -139,12 +136,12 @@ export type HybridSearchTaskOutput = FromSchema { - public static type = "DocumentNodeVectorHybridSearchTask"; + public static type = "ChunkVectorHybridSearchTask"; public static category = "RAG"; public static title = "Hybrid Search"; public static description = "Combined vector + full-text search for improved retrieval"; @@ -174,7 +171,7 @@ export class DocumentNodeVectorHybridSearchTask extends Task< } = input; // Repository is resolved by input resolver system before execution - const repo = repository as AnyDocumentNodeVectorRepository; + const repo = repository as AnyChunkVectorRepository; // Check if repository supports hybrid search if (!repo.hybridSearch) { @@ -220,7 +217,7 @@ export const hybridSearch = async ( input: HybridSearchTaskInput, config?: JobQueueTaskConfig ): Promise => { - return new DocumentNodeVectorHybridSearchTask({} as HybridSearchTaskInput, config).run(input); + return new ChunkVectorHybridSearchTask({} as HybridSearchTaskInput, config).run(input); }; declare module "@workglow/task-graph" { @@ -229,4 +226,4 @@ declare module "@workglow/task-graph" { } } -Workflow.prototype.hybridSearch = CreateWorkflow(DocumentNodeVectorHybridSearchTask); +Workflow.prototype.hybridSearch = CreateWorkflow(ChunkVectorHybridSearchTask); diff --git a/packages/ai/src/task/DocumentNodeVectorSearchTask.ts b/packages/ai/src/task/ChunkVectorSearchTask.ts similarity index 88% rename from packages/ai/src/task/DocumentNodeVectorSearchTask.ts rename to packages/ai/src/task/ChunkVectorSearchTask.ts index 37d481a6..9e30dd5f 100644 --- a/packages/ai/src/task/DocumentNodeVectorSearchTask.ts +++ b/packages/ai/src/task/ChunkVectorSearchTask.ts @@ -4,10 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { - AnyDocumentNodeVectorRepository, - TypeDocumentNodeVectorRepository, -} from "@workglow/storage"; +import { AnyChunkVectorRepository, TypeChunkVectorRepository } from "@workglow/storage"; import { CreateWorkflow, IExecuteContext, @@ -25,7 +22,7 @@ import { const inputSchema = { type: "object", properties: { - repository: TypeDocumentNodeVectorRepository({ + repository: TypeChunkVectorRepository({ title: "Vector Repository", description: "The vector repository instance to search in", }), @@ -109,12 +106,12 @@ export type VectorStoreSearchTaskOutput = FromSchema { - public static type = "DocumentNodeVectorSearchTask"; + public static type = "ChunkVectorSearchTask"; public static category = "Vector Store"; public 
static title = "Vector Store Search"; public static description = "Search for similar vectors in a vector repository"; @@ -134,7 +131,7 @@ export class DocumentNodeVectorSearchTask extends Task< ): Promise { const { repository, query, topK = 10, filter, scoreThreshold = 0 } = input; - const repo = repository as AnyDocumentNodeVectorRepository; + const repo = repository as AnyChunkVectorRepository; const results = await repo.similaritySearch(query, { topK, @@ -156,7 +153,7 @@ export const vectorStoreSearch = ( input: VectorStoreSearchTaskInput, config?: JobQueueTaskConfig ) => { - return new DocumentNodeVectorSearchTask({} as VectorStoreSearchTaskInput, config).run(input); + return new ChunkVectorSearchTask({} as VectorStoreSearchTaskInput, config).run(input); }; declare module "@workglow/task-graph" { @@ -169,4 +166,4 @@ declare module "@workglow/task-graph" { } } -Workflow.prototype.vectorStoreSearch = CreateWorkflow(DocumentNodeVectorSearchTask); +Workflow.prototype.vectorStoreSearch = CreateWorkflow(ChunkVectorSearchTask); diff --git a/packages/ai/src/task/DocumentNodeVectorUpsertTask.ts b/packages/ai/src/task/ChunkVectorUpsertTask.ts similarity index 88% rename from packages/ai/src/task/DocumentNodeVectorUpsertTask.ts rename to packages/ai/src/task/ChunkVectorUpsertTask.ts index efc62a9b..6d84b3b1 100644 --- a/packages/ai/src/task/DocumentNodeVectorUpsertTask.ts +++ b/packages/ai/src/task/ChunkVectorUpsertTask.ts @@ -4,10 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { - AnyDocumentNodeVectorRepository, - TypeDocumentNodeVectorRepository, -} from "@workglow/storage"; +import { AnyChunkVectorRepository, TypeChunkVectorRepository } from "@workglow/storage"; import { CreateWorkflow, IExecuteContext, @@ -31,7 +28,7 @@ const inputSchema = { title: "Document ID", description: "The document ID", }, - repository: TypeDocumentNodeVectorRepository({ + repository: TypeChunkVectorRepository({ title: "Document Chunk Vector Repository", description: "The document chunk vector repository instance to store vectors in", }), @@ -85,12 +82,12 @@ export type VectorStoreUpsertTaskOutput = FromSchema; * Task for upserting (insert or update) vectors into a vector repository. * Supports both single and bulk operations. */ -export class DocumentNodeVectorUpsertTask extends Task< +export class ChunkVectorUpsertTask extends Task< VectorStoreUpsertTaskInput, VectorStoreUpsertTaskOutput, JobQueueTaskConfig > { - public static type = "DocumentNodeVectorUpsertTask"; + public static type = "ChunkVectorUpsertTask"; public static category = "Vector Store"; public static title = "Vector Store Upsert"; public static description = "Store vector embeddings with metadata in a vector repository"; @@ -113,7 +110,7 @@ export class DocumentNodeVectorUpsertTask extends Task< // Normalize inputs to arrays const vectorArray = Array.isArray(vectors) ? 
vectors : [vectors]; - const repo = repository as AnyDocumentNodeVectorRepository; + const repo = repository as AnyChunkVectorRepository; await context.updateProgress(1, "Upserting vectors"); @@ -156,7 +153,7 @@ export const vectorStoreUpsert = ( input: VectorStoreUpsertTaskInput, config?: JobQueueTaskConfig ) => { - return new DocumentNodeVectorUpsertTask({} as VectorStoreUpsertTaskInput, config).run(input); + return new ChunkVectorUpsertTask({} as VectorStoreUpsertTaskInput, config).run(input); }; declare module "@workglow/task-graph" { @@ -169,4 +166,4 @@ declare module "@workglow/task-graph" { } } -Workflow.prototype.vectorStoreUpsert = CreateWorkflow(DocumentNodeVectorUpsertTask); +Workflow.prototype.vectorStoreUpsert = CreateWorkflow(ChunkVectorUpsertTask); diff --git a/packages/ai/src/task/index.ts b/packages/ai/src/task/index.ts index 00436895..96c0db1b 100644 --- a/packages/ai/src/task/index.ts +++ b/packages/ai/src/task/index.ts @@ -6,13 +6,13 @@ import { TaskRegistry } from "@workglow/task-graph"; import { BackgroundRemovalTask } from "./BackgroundRemovalTask"; +import { DocumentNodeRetrievalTask } from "./ChunkRetrievalTask"; import { ChunkToVectorTask } from "./ChunkToVectorTask"; +import { ChunkVectorHybridSearchTask } from "./ChunkVectorHybridSearchTask"; +import { ChunkVectorSearchTask } from "./ChunkVectorSearchTask"; +import { ChunkVectorUpsertTask } from "./ChunkVectorUpsertTask"; import { ContextBuilderTask } from "./ContextBuilderTask"; import { DocumentEnricherTask } from "./DocumentEnricherTask"; -import { DocumentNodeRetrievalTask } from "./DocumentNodeRetrievalTask"; -import { DocumentNodeVectorHybridSearchTask } from "./DocumentNodeVectorHybridSearchTask"; -import { DocumentNodeVectorSearchTask } from "./DocumentNodeVectorSearchTask"; -import { DocumentNodeVectorUpsertTask } from "./DocumentNodeVectorUpsertTask"; import { DownloadModelTask } from "./DownloadModelTask"; import { FaceDetectorTask } from "./FaceDetectorTask"; import { FaceLandmarkerTask } from "./FaceLandmarkerTask"; @@ -55,9 +55,9 @@ export const registerAiTasks = () => { ContextBuilderTask, DocumentEnricherTask, DocumentNodeRetrievalTask, - DocumentNodeVectorHybridSearchTask, - DocumentNodeVectorSearchTask, - DocumentNodeVectorUpsertTask, + ChunkVectorHybridSearchTask, + ChunkVectorSearchTask, + ChunkVectorUpsertTask, DownloadModelTask, FaceDetectorTask, FaceLandmarkerTask, @@ -95,13 +95,13 @@ export const registerAiTasks = () => { export * from "./BackgroundRemovalTask"; export * from "./base/AiTask"; export * from "./base/AiTaskSchemas"; +export * from "./ChunkRetrievalTask"; export * from "./ChunkToVectorTask"; +export * from "./ChunkVectorHybridSearchTask"; +export * from "./ChunkVectorSearchTask"; +export * from "./ChunkVectorUpsertTask"; export * from "./ContextBuilderTask"; export * from "./DocumentEnricherTask"; -export * from "./DocumentNodeRetrievalTask"; -export * from "./DocumentNodeVectorHybridSearchTask"; -export * from "./DocumentNodeVectorSearchTask"; -export * from "./DocumentNodeVectorUpsertTask"; export * from "./DownloadModelTask"; export * from "./FaceDetectorTask"; export * from "./FaceLandmarkerTask"; diff --git a/packages/storage/src/document-node-vector/DocumentNodeVectorRepositoryRegistry.ts b/packages/storage/src/chunk-vector/ChunkVectorRepositoryRegistry.ts similarity index 64% rename from packages/storage/src/document-node-vector/DocumentNodeVectorRepositoryRegistry.ts rename to packages/storage/src/chunk-vector/ChunkVectorRepositoryRegistry.ts index e2f69781..c50d90f2 
100644 --- a/packages/storage/src/document-node-vector/DocumentNodeVectorRepositoryRegistry.ts +++ b/packages/storage/src/chunk-vector/ChunkVectorRepositoryRegistry.ts @@ -10,21 +10,21 @@ import { registerInputResolver, ServiceRegistry, } from "@workglow/util"; -import { AnyDocumentNodeVectorRepository } from "./IDocumentNodeVectorRepository"; +import { AnyChunkVectorRepository } from "./IChunkVectorRepository"; /** * Service token for the document chunk vector repository registry * Maps repository IDs to IVectorChunkRepository instances */ export const DOCUMENT_CHUNK_VECTOR_REPOSITORIES = createServiceToken< - Map + Map >("storage.document-node-vector.repositories"); // Register default factory if not already registered if (!globalServiceRegistry.has(DOCUMENT_CHUNK_VECTOR_REPOSITORIES)) { globalServiceRegistry.register( DOCUMENT_CHUNK_VECTOR_REPOSITORIES, - (): Map => new Map(), + (): Map => new Map(), true ); } @@ -33,10 +33,7 @@ if (!globalServiceRegistry.has(DOCUMENT_CHUNK_VECTOR_REPOSITORIES)) { * Gets the global document chunk vector repository registry * @returns Map of document chunk vector repository ID to instance */ -export function getGlobalDocumentNodeVectorRepositories(): Map< - string, - AnyDocumentNodeVectorRepository -> { +export function getGlobalChunkVectorRepositories(): Map { return globalServiceRegistry.get(DOCUMENT_CHUNK_VECTOR_REPOSITORIES); } @@ -45,11 +42,11 @@ export function getGlobalDocumentNodeVectorRepositories(): Map< * @param id The unique identifier for this repository * @param repository The repository instance to register */ -export function registerDocumentNodeVectorRepository( +export function registerChunkVectorRepository( id: string, - repository: AnyDocumentNodeVectorRepository + repository: AnyChunkVectorRepository ): void { - const repos = getGlobalDocumentNodeVectorRepositories(); + const repos = getGlobalChunkVectorRepositories(); repos.set(id, repository); } @@ -58,24 +55,22 @@ export function registerDocumentNodeVectorRepository( * @param id The repository identifier * @returns The repository instance or undefined if not found */ -export function getDocumentNodeVectorRepository( - id: string -): AnyDocumentNodeVectorRepository | undefined { - return getGlobalDocumentNodeVectorRepositories().get(id); +export function getChunkVectorRepository(id: string): AnyChunkVectorRepository | undefined { + return getGlobalChunkVectorRepositories().get(id); } /** * Resolves a repository ID to an IVectorChunkRepository from the registry. * Used by the input resolver system. */ -async function resolveDocumentNodeVectorRepositoryFromRegistry( +async function resolveChunkVectorRepositoryFromRegistry( id: string, format: string, registry: ServiceRegistry -): Promise { +): Promise { const repos = registry.has(DOCUMENT_CHUNK_VECTOR_REPOSITORIES) - ?
registry.get>(DOCUMENT_CHUNK_VECTOR_REPOSITORIES) + : getGlobalChunkVectorRepositories(); const repo = repos.get(id); if (!repo) { @@ -85,7 +80,4 @@ async function resolveDocumentNodeVectorRepositoryFromRegistry( } // Register the repository resolver for format: "repository:document-node-vector" -registerInputResolver( - "repository:document-node-vector", - resolveDocumentNodeVectorRepositoryFromRegistry -); +registerInputResolver("repository:document-node-vector", resolveChunkVectorRepositoryFromRegistry); diff --git a/packages/storage/src/document-node-vector/DocumentNodeVectorSchema.ts b/packages/storage/src/chunk-vector/ChunkVectorSchema.ts similarity index 73% rename from packages/storage/src/document-node-vector/DocumentNodeVectorSchema.ts rename to packages/storage/src/chunk-vector/ChunkVectorSchema.ts index c3bd30b1..8ce95438 100644 --- a/packages/storage/src/document-node-vector/DocumentNodeVectorSchema.ts +++ b/packages/storage/src/chunk-vector/ChunkVectorSchema.ts @@ -9,7 +9,7 @@ import { TypedArraySchema, type DataPortSchemaObject, type TypedArray } from "@w /** * Default schema for document chunk storage with vector embeddings */ -export const DocumentNodeVectorSchema = { +export const ChunkVectorSchema = { type: "object", properties: { chunk_id: { type: "string" }, @@ -19,12 +19,12 @@ export const DocumentNodeVectorSchema = { }, additionalProperties: false, } as const satisfies DataPortSchemaObject; -export type DocumentNodeVectorSchema = typeof DocumentNodeVectorSchema; +export type ChunkVectorSchema = typeof ChunkVectorSchema; -export const DocumentNodeVectorKey = ["chunk_id"] as const; -export type DocumentNodeVectorKey = typeof DocumentNodeVectorKey; +export const ChunkVectorKey = ["chunk_id"] as const; +export type ChunkVectorKey = typeof ChunkVectorKey; -export interface DocumentNodeVector< +export interface ChunkVector< Metadata extends Record = Record, Vector extends TypedArray = Float32Array, > { diff --git a/packages/storage/src/document-node-vector/IDocumentNodeVectorRepository.ts b/packages/storage/src/chunk-vector/IChunkVectorRepository.ts similarity index 96% rename from packages/storage/src/document-node-vector/IDocumentNodeVectorRepository.ts rename to packages/storage/src/chunk-vector/IChunkVectorRepository.ts index 548c2695..060d7108 100644 --- a/packages/storage/src/document-node-vector/IDocumentNodeVectorRepository.ts +++ b/packages/storage/src/chunk-vector/IChunkVectorRepository.ts @@ -13,7 +13,7 @@ import type { } from "@workglow/util"; import type { ITabularRepository, TabularEventListeners } from "../tabular/ITabularRepository"; -export type AnyDocumentNodeVectorRepository = IDocumentNodeVectorRepository; +export type AnyChunkVectorRepository = IChunkVectorRepository; /** * Options for vector search operations @@ -69,7 +69,7 @@ export type VectorChunkEventParameters< * @typeParam PrimaryKeyNames - Array of property names that form the primary key * @typeParam Entity - The entity type */ -export interface IDocumentNodeVectorRepository< +export interface IChunkVectorRepository< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, Entity = FromSchema, diff --git a/packages/storage/src/document-node-vector/InMemoryDocumentNodeVectorRepository.ts b/packages/storage/src/chunk-vector/InMemoryChunkVectorRepository.ts similarity index 86% rename from packages/storage/src/document-node-vector/InMemoryDocumentNodeVectorRepository.ts rename to packages/storage/src/chunk-vector/InMemoryChunkVectorRepository.ts index 61fb1d8b..dcadfb3f 100644 
--- a/packages/storage/src/document-node-vector/InMemoryDocumentNodeVectorRepository.ts +++ b/packages/storage/src/chunk-vector/InMemoryChunkVectorRepository.ts @@ -7,16 +7,12 @@ import type { TypedArray } from "@workglow/util"; import { cosineSimilarity } from "@workglow/util"; import { InMemoryTabularRepository } from "../tabular/InMemoryTabularRepository"; -import { - DocumentNodeVector, - DocumentNodeVectorKey, - DocumentNodeVectorSchema, -} from "./DocumentNodeVectorSchema"; +import { ChunkVector, ChunkVectorKey, ChunkVectorSchema } from "./ChunkVectorSchema"; import type { HybridSearchOptions, - IDocumentNodeVectorRepository, + IChunkVectorRepository, VectorSearchOptions, -} from "./IDocumentNodeVectorRepository"; +} from "./IChunkVectorRepository"; /** * Check if metadata matches filter @@ -58,20 +54,20 @@ function textRelevance(text: string, query: string): number { * @template Metadata - The metadata type for the document chunk * @template Vector - The vector type for the document chunk */ -export class InMemoryDocumentNodeVectorRepository< +export class InMemoryChunkVectorRepository< Metadata extends Record = Record, Vector extends TypedArray = Float32Array, > extends InMemoryTabularRepository< - typeof DocumentNodeVectorSchema, - typeof DocumentNodeVectorKey, - DocumentNodeVector + typeof ChunkVectorSchema, + typeof ChunkVectorKey, + ChunkVector > implements - IDocumentNodeVectorRepository< - typeof DocumentNodeVectorSchema, - typeof DocumentNodeVectorKey, - DocumentNodeVector + IChunkVectorRepository< + typeof ChunkVectorSchema, + typeof ChunkVectorKey, + ChunkVector > { private vectorDimensions: number; @@ -83,7 +79,7 @@ export class InMemoryDocumentNodeVectorRepository< * @param VectorType - The type of vector to use (defaults to Float32Array) */ constructor(dimensions: number, VectorType: new (array: number[]) => TypedArray = Float32Array) { - super(DocumentNodeVectorSchema, DocumentNodeVectorKey); + super(ChunkVectorSchema, ChunkVectorKey); this.vectorDimensions = dimensions; this.VectorType = VectorType; @@ -102,7 +98,7 @@ export class InMemoryDocumentNodeVectorRepository< options: VectorSearchOptions> = {} ) { const { topK = 10, filter, scoreThreshold = 0 } = options; - const results: Array & { score: number }> = []; + const results: Array & { score: number }> = []; const allEntities = (await this.getAll()) || []; @@ -145,7 +141,7 @@ export class InMemoryDocumentNodeVectorRepository< return this.similaritySearch(query, { topK, filter, scoreThreshold }); } - const results: Array & { score: number }> = []; + const results: Array & { score: number }> = []; const allEntities = (await this.getAll()) || []; for (const entity of allEntities) { diff --git a/packages/storage/src/document-node-vector/PostgresDocumentNodeVectorRepository.ts b/packages/storage/src/chunk-vector/PostgresChunkVectorRepository.ts similarity index 89% rename from packages/storage/src/document-node-vector/PostgresDocumentNodeVectorRepository.ts rename to packages/storage/src/chunk-vector/PostgresChunkVectorRepository.ts index 5d067c7c..6ca5c27d 100644 --- a/packages/storage/src/document-node-vector/PostgresDocumentNodeVectorRepository.ts +++ b/packages/storage/src/chunk-vector/PostgresChunkVectorRepository.ts @@ -7,16 +7,12 @@ import { cosineSimilarity, type TypedArray } from "@workglow/util"; import type { Pool } from "pg"; import { PostgresTabularRepository } from "../tabular/PostgresTabularRepository"; -import { - DocumentNodeVector, - DocumentNodeVectorKey, - DocumentNodeVectorSchema, -} from 
"./DocumentNodeVectorSchema"; +import { ChunkVector, ChunkVectorKey, ChunkVectorSchema } from "./ChunkVectorSchema"; import type { HybridSearchOptions, - IDocumentNodeVectorRepository, + IChunkVectorRepository, VectorSearchOptions, -} from "./IDocumentNodeVectorRepository"; +} from "./IChunkVectorRepository"; /** * PostgreSQL document chunk vector repository implementation using pgvector extension. @@ -30,20 +26,20 @@ import type { * @template Metadata - The metadata type for the document chunk * @template Vector - The vector type for the document chunk */ -export class PostgresDocumentNodeVectorRepository< +export class PostgresChunkVectorRepository< Metadata extends Record = Record, Vector extends TypedArray = Float32Array, > extends PostgresTabularRepository< - typeof DocumentNodeVectorSchema, - typeof DocumentNodeVectorKey, - DocumentNodeVector + typeof ChunkVectorSchema, + typeof ChunkVectorKey, + ChunkVector > implements - IDocumentNodeVectorRepository< - typeof DocumentNodeVectorSchema, - typeof DocumentNodeVectorKey, - DocumentNodeVector + IChunkVectorRepository< + typeof ChunkVectorSchema, + typeof ChunkVectorKey, + ChunkVector > { private vectorDimensions: number; @@ -61,7 +57,7 @@ export class PostgresDocumentNodeVectorRepository< dimensions: number, VectorType: new (array: number[]) => TypedArray = Float32Array ) { - super(db, table, DocumentNodeVectorSchema, DocumentNodeVectorKey); + super(db, table, ChunkVectorSchema, ChunkVectorKey); this.vectorDimensions = dimensions; this.VectorType = VectorType; @@ -74,7 +70,7 @@ export class PostgresDocumentNodeVectorRepository< async similaritySearch( query: TypedArray, options: VectorSearchOptions = {} - ): Promise & { score: number }>> { + ): Promise & { score: number }>> { const { topK = 10, filter, scoreThreshold = 0 } = options; try { @@ -113,7 +109,7 @@ export class PostgresDocumentNodeVectorRepository< const result = await this.db.query(sql, params); // Fetch vectors separately for each result - const results: Array & { score: number }> = []; + const results: Array & { score: number }> = []; for (const row of result.rows) { const vectorResult = await this.db.query( `SELECT vector::text FROM "${this.table}" WHERE id = $1`, @@ -188,7 +184,7 @@ export class PostgresDocumentNodeVectorRepository< const result = await this.db.query(sql, params); // Fetch vectors separately for each result - const results: Array & { score: number }> = []; + const results: Array & { score: number }> = []; for (const row of result.rows) { const vectorResult = await this.db.query( `SELECT vector::text FROM "${this.table}" WHERE id = $1`, @@ -218,7 +214,7 @@ export class PostgresDocumentNodeVectorRepository< private async searchFallback(query: TypedArray, options: VectorSearchOptions) { const { topK = 10, filter, scoreThreshold = 0 } = options; const allRows = (await this.getAll()) || []; - const results: Array & { score: number }> = []; + const results: Array & { score: number }> = []; for (const row of allRows) { const vector = row.vector; @@ -248,7 +244,7 @@ export class PostgresDocumentNodeVectorRepository< const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; const allRows = (await this.getAll()) || []; - const results: Array & { score: number }> = []; + const results: Array & { score: number }> = []; const queryLower = textQuery.toLowerCase(); const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); diff --git a/packages/storage/src/document-node-vector/README.md 
b/packages/storage/src/chunk-vector/README.md similarity index 100% rename from packages/storage/src/document-node-vector/README.md rename to packages/storage/src/chunk-vector/README.md diff --git a/packages/storage/src/document-node-vector/SqliteDocumentNodeVectorRepository.ts b/packages/storage/src/chunk-vector/SqliteChunkVectorRepository.ts similarity index 87% rename from packages/storage/src/document-node-vector/SqliteDocumentNodeVectorRepository.ts rename to packages/storage/src/chunk-vector/SqliteChunkVectorRepository.ts index 0dd02653..d9609f1f 100644 --- a/packages/storage/src/document-node-vector/SqliteDocumentNodeVectorRepository.ts +++ b/packages/storage/src/chunk-vector/SqliteChunkVectorRepository.ts @@ -8,16 +8,12 @@ import { Sqlite } from "@workglow/sqlite"; import type { TypedArray } from "@workglow/util"; import { cosineSimilarity } from "@workglow/util"; import { SqliteTabularRepository } from "../tabular/SqliteTabularRepository"; -import { - DocumentNodeVector, - DocumentNodeVectorKey, - DocumentNodeVectorSchema, -} from "./DocumentNodeVectorSchema"; +import { ChunkVector, ChunkVectorKey, ChunkVectorSchema } from "./ChunkVectorSchema"; import type { HybridSearchOptions, - IDocumentNodeVectorRepository, + IChunkVectorRepository, VectorSearchOptions, -} from "./IDocumentNodeVectorRepository"; +} from "./IChunkVectorRepository"; /** * Check if metadata matches filter @@ -38,20 +34,20 @@ function matchesFilter(metadata: Metadata, filter: Partial): * @template Metadata - The metadata type for the document chunk * @template Vector - The vector type for the document chunk */ -export class SqliteDocumentNodeVectorRepository< +export class SqliteChunkVectorRepository< Metadata extends Record = Record, Vector extends TypedArray = Float32Array, > extends SqliteTabularRepository< - typeof DocumentNodeVectorSchema, - typeof DocumentNodeVectorKey, - DocumentNodeVector + typeof ChunkVectorSchema, + typeof ChunkVectorKey, + ChunkVector > implements - IDocumentNodeVectorRepository< - typeof DocumentNodeVectorSchema, - typeof DocumentNodeVectorKey, - DocumentNodeVector + IChunkVectorRepository< + typeof ChunkVectorSchema, + typeof ChunkVectorKey, + ChunkVector > { private vectorDimensions: number; @@ -70,7 +66,7 @@ export class SqliteDocumentNodeVectorRepository< dimensions: number, VectorType: new (array: number[]) => TypedArray = Float32Array ) { - super(dbOrPath, table, DocumentNodeVectorSchema, DocumentNodeVectorKey); + super(dbOrPath, table, ChunkVectorSchema, ChunkVectorKey); this.vectorDimensions = dimensions; this.VectorType = VectorType; @@ -92,7 +88,7 @@ export class SqliteDocumentNodeVectorRepository< async similaritySearch(query: TypedArray, options: VectorSearchOptions = {}) { const { topK = 10, filter, scoreThreshold = 0 } = options; - const results: Array & { score: number }> = []; + const results: Array & { score: number }> = []; const allEntities = (await this.getAll()) || []; @@ -137,7 +133,7 @@ export class SqliteDocumentNodeVectorRepository< return this.similaritySearch(query, { topK, filter, scoreThreshold }); } - const results: Array & { score: number }> = []; + const results: Array & { score: number }> = []; const allEntities = (await this.getAll()) || []; const queryLower = textQuery.toLowerCase(); const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); diff --git a/packages/storage/src/common-server.ts b/packages/storage/src/common-server.ts index 61b76cf2..096b9bfe 100644 --- a/packages/storage/src/common-server.ts +++ 
b/packages/storage/src/common-server.ts @@ -25,8 +25,8 @@ export * from "./queue-limiter/PostgresRateLimiterStorage"; export * from "./queue-limiter/SqliteRateLimiterStorage"; export * from "./queue-limiter/SupabaseRateLimiterStorage"; -export * from "./document-node-vector/PostgresDocumentNodeVectorRepository"; -export * from "./document-node-vector/SqliteDocumentNodeVectorRepository"; +export * from "./chunk-vector/PostgresChunkVectorRepository"; +export * from "./chunk-vector/SqliteChunkVectorRepository"; // testing export * from "./kv/IndexedDbKvRepository"; diff --git a/packages/storage/src/common.ts b/packages/storage/src/common.ts index 3049d312..9c0d27e3 100644 --- a/packages/storage/src/common.ts +++ b/packages/storage/src/common.ts @@ -34,7 +34,7 @@ export * from "./document/DocumentSchema"; export * from "./document/DocumentStorageSchema"; export * from "./document/StructuralParser"; -export * from "./document-node-vector/DocumentNodeVectorRepositoryRegistry"; -export * from "./document-node-vector/DocumentNodeVectorSchema"; -export * from "./document-node-vector/IDocumentNodeVectorRepository"; -export * from "./document-node-vector/InMemoryDocumentNodeVectorRepository"; +export * from "./chunk-vector/ChunkVectorRepositoryRegistry"; +export * from "./chunk-vector/ChunkVectorSchema"; +export * from "./chunk-vector/IChunkVectorRepository"; +export * from "./chunk-vector/InMemoryChunkVectorRepository"; diff --git a/packages/storage/src/document/DocumentRepository.ts b/packages/storage/src/document/DocumentRepository.ts index 8f0cafc7..1e61c627 100644 --- a/packages/storage/src/document/DocumentRepository.ts +++ b/packages/storage/src/document/DocumentRepository.ts @@ -5,11 +5,11 @@ */ import type { TypedArray } from "@workglow/util"; -import { DocumentNodeVector } from "../document-node-vector/DocumentNodeVectorSchema"; +import { ChunkVector } from "../chunk-vector/ChunkVectorSchema"; import type { - AnyDocumentNodeVectorRepository, + AnyChunkVectorRepository, VectorSearchOptions, -} from "../document-node-vector/IDocumentNodeVectorRepository"; +} from "../chunk-vector/IChunkVectorRepository"; import type { ITabularRepository } from "../tabular/ITabularRepository"; import { Document } from "./Document"; import { ChunkNode, DocumentNode } from "./DocumentSchema"; @@ -29,7 +29,7 @@ export class DocumentRepository { DocumentStorageKey, DocumentStorageEntity >; - private vectorStorage?: AnyDocumentNodeVectorRepository; + private vectorStorage?: AnyChunkVectorRepository; /** * Creates a new DocumentRepository instance. 
@@ -54,7 +54,7 @@ export class DocumentRepository { ["doc_id"], DocumentStorageEntity >, - vectorStorage?: AnyDocumentNodeVectorRepository + vectorStorage?: AnyChunkVectorRepository ) { this.tabularStorage = tabularStorage; this.vectorStorage = vectorStorage; @@ -216,7 +216,7 @@ export class DocumentRepository { async search( query: TypedArray, options?: VectorSearchOptions> - ): Promise, TypedArray>>> { + ): Promise, TypedArray>>> { return this.vectorStorage?.similaritySearch(query, options) || []; } } diff --git a/packages/storage/src/util/RepositorySchema.ts b/packages/storage/src/util/RepositorySchema.ts index 9d4805b2..b7ccc3f1 100644 --- a/packages/storage/src/util/RepositorySchema.ts +++ b/packages/storage/src/util/RepositorySchema.ts @@ -58,7 +58,7 @@ export function TypeTabularRepository = {}>( * @param options Additional schema options to merge * @returns JSON schema for vector repository input */ -export function TypeDocumentNodeVectorRepository = {}>( +export function TypeChunkVectorRepository = {}>( options: O = {} as O ) { return { diff --git a/packages/test/src/test/rag/DocumentNodeRetrievalTask.test.ts b/packages/test/src/test/rag/DocumentNodeRetrievalTask.test.ts index 754f119a..3ace34d3 100644 --- a/packages/test/src/test/rag/DocumentNodeRetrievalTask.test.ts +++ b/packages/test/src/test/rag/DocumentNodeRetrievalTask.test.ts @@ -5,17 +5,14 @@ */ import { retrieval } from "@workglow/ai"; -import { - InMemoryDocumentNodeVectorRepository, - registerDocumentNodeVectorRepository, -} from "@workglow/storage"; +import { InMemoryChunkVectorRepository, registerChunkVectorRepository } from "@workglow/storage"; import { afterEach, beforeEach, describe, expect, test } from "vitest"; describe("DocumentNodeRetrievalTask", () => { - let repo: InMemoryDocumentNodeVectorRepository; + let repo: InMemoryChunkVectorRepository; beforeEach(async () => { - repo = new InMemoryDocumentNodeVectorRepository(3); + repo = new InMemoryChunkVectorRepository(3); await repo.setupDatabase(); // Populate repository with test data @@ -271,7 +268,7 @@ describe("DocumentNodeRetrievalTask", () => { test("should resolve repository from string ID", async () => { // Register repository by ID - registerDocumentNodeVectorRepository("test-retrieval-repo", repo); + registerChunkVectorRepository("test-retrieval-repo", repo); const queryVector = new Float32Array([1.0, 0.0, 0.0]); diff --git a/packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts b/packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts index ccf1b37a..095827e0 100644 --- a/packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts +++ b/packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts @@ -4,18 +4,15 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { DocumentNodeVectorSearchTask } from "@workglow/ai"; -import { - InMemoryDocumentNodeVectorRepository, - registerDocumentNodeVectorRepository, -} from "@workglow/storage"; +import { ChunkVectorSearchTask } from "@workglow/ai"; +import { InMemoryChunkVectorRepository, registerChunkVectorRepository } from "@workglow/storage"; import { afterEach, beforeEach, describe, expect, test } from "vitest"; -describe("DocumentNodeVectorSearchTask", () => { - let repo: InMemoryDocumentNodeVectorRepository; +describe("ChunkVectorSearchTask", () => { + let repo: InMemoryChunkVectorRepository; beforeEach(async () => { - repo = new InMemoryDocumentNodeVectorRepository(3); + repo = new InMemoryChunkVectorRepository(3); await repo.setupDatabase(); // Populate repository with test data @@ 
-53,7 +50,7 @@ describe("DocumentNodeVectorSearchTask", () => { test("should search and return top K results", async () => { const queryVector = new Float32Array([1.0, 0.0, 0.0]); - const task = new DocumentNodeVectorSearchTask(); + const task = new ChunkVectorSearchTask(); const result = await task.run({ repository: repo, query: queryVector, @@ -78,7 +75,7 @@ describe("DocumentNodeVectorSearchTask", () => { test("should respect topK limit", async () => { const queryVector = new Float32Array([1.0, 0.0, 0.0]); - const task = new DocumentNodeVectorSearchTask(); + const task = new ChunkVectorSearchTask(); const result = await task.run({ repository: repo, query: queryVector, @@ -92,7 +89,7 @@ describe("DocumentNodeVectorSearchTask", () => { test("should filter by metadata", async () => { const queryVector = new Float32Array([1.0, 0.0, 0.0]); - const task = new DocumentNodeVectorSearchTask(); + const task = new ChunkVectorSearchTask(); const result = await task.run({ repository: repo, query: queryVector, @@ -110,7 +107,7 @@ describe("DocumentNodeVectorSearchTask", () => { test("should apply score threshold", async () => { const queryVector = new Float32Array([1.0, 0.0, 0.0]); - const task = new DocumentNodeVectorSearchTask(); + const task = new ChunkVectorSearchTask(); const result = await task.run({ repository: repo, query: queryVector, @@ -127,7 +124,7 @@ describe("DocumentNodeVectorSearchTask", () => { test("should return empty results when no matches", async () => { const queryVector = new Float32Array([0.0, 0.0, 1.0]); - const task = new DocumentNodeVectorSearchTask(); + const task = new ChunkVectorSearchTask(); const result = await task.run({ repository: repo, query: queryVector, @@ -145,7 +142,7 @@ describe("DocumentNodeVectorSearchTask", () => { test("should handle default topK value", async () => { const queryVector = new Float32Array([1.0, 0.0, 0.0]); - const task = new DocumentNodeVectorSearchTask(); + const task = new ChunkVectorSearchTask(); const result = await task.run({ repository: repo, query: queryVector, @@ -159,7 +156,7 @@ describe("DocumentNodeVectorSearchTask", () => { test("should work with quantized query vectors (Int8Array)", async () => { const queryVector = new Int8Array([127, 0, 0]); - const task = new DocumentNodeVectorSearchTask(); + const task = new ChunkVectorSearchTask(); const result = await task.run({ repository: repo, query: queryVector, @@ -174,7 +171,7 @@ describe("DocumentNodeVectorSearchTask", () => { test("should return results sorted by similarity score", async () => { const queryVector = new Float32Array([1.0, 0.0, 0.0]); - const task = new DocumentNodeVectorSearchTask(); + const task = new ChunkVectorSearchTask(); const result = await task.run({ repository: repo, query: queryVector, @@ -188,12 +185,12 @@ describe("DocumentNodeVectorSearchTask", () => { }); test("should handle empty repository", async () => { - const emptyRepo = new InMemoryDocumentNodeVectorRepository(3); + const emptyRepo = new InMemoryChunkVectorRepository(3); await emptyRepo.setupDatabase(); const queryVector = new Float32Array([1.0, 0.0, 0.0]); - const task = new DocumentNodeVectorSearchTask(); + const task = new ChunkVectorSearchTask(); const result = await task.run({ repository: emptyRepo, query: queryVector, @@ -210,7 +207,7 @@ describe("DocumentNodeVectorSearchTask", () => { test("should combine filter and score threshold", async () => { const queryVector = new Float32Array([1.0, 0.0, 0.0]); - const task = new DocumentNodeVectorSearchTask(); + const task = new 
ChunkVectorSearchTask(); const result = await task.run({ repository: repo, query: queryVector, @@ -230,11 +227,11 @@ describe("DocumentNodeVectorSearchTask", () => { test("should resolve repository from string ID", async () => { // Register repository by ID - registerDocumentNodeVectorRepository("test-vector-repo", repo); + registerChunkVectorRepository("test-vector-repo", repo); const queryVector = new Float32Array([1.0, 0.0, 0.0]); - const task = new DocumentNodeVectorSearchTask(); + const task = new ChunkVectorSearchTask(); // Pass repository as string ID instead of instance const result = await task.run({ repository: "test-vector-repo" as any, diff --git a/packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts b/packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts index 97a46304..c42ec1d3 100644 --- a/packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts +++ b/packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts @@ -4,18 +4,15 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { DocumentNodeVectorUpsertTask } from "@workglow/ai"; -import { - InMemoryDocumentNodeVectorRepository, - registerDocumentNodeVectorRepository, -} from "@workglow/storage"; +import { ChunkVectorUpsertTask } from "@workglow/ai"; +import { InMemoryChunkVectorRepository, registerChunkVectorRepository } from "@workglow/storage"; import { afterEach, beforeEach, describe, expect, test } from "vitest"; -describe("DocumentNodeVectorUpsertTask", () => { - let repo: InMemoryDocumentNodeVectorRepository; +describe("ChunkVectorUpsertTask", () => { + let repo: InMemoryChunkVectorRepository; beforeEach(async () => { - repo = new InMemoryDocumentNodeVectorRepository(3); + repo = new InMemoryChunkVectorRepository(3); await repo.setupDatabase(); }); @@ -27,7 +24,7 @@ describe("DocumentNodeVectorUpsertTask", () => { const vector = new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5]); const metadata = { text: "Test document", source: "test.txt" }; - const task = new DocumentNodeVectorUpsertTask(); + const task = new ChunkVectorUpsertTask(); const result = await task.run({ repository: repo, doc_id: "doc1", @@ -54,7 +51,7 @@ describe("DocumentNodeVectorUpsertTask", () => { ]; const metadata = { text: "Document with multiple vectors", source: "doc.txt" }; - const task = new DocumentNodeVectorUpsertTask(); + const task = new ChunkVectorUpsertTask(); const result = await task.run({ repository: repo, doc_id: "doc1", @@ -79,7 +76,7 @@ describe("DocumentNodeVectorUpsertTask", () => { const vector = [new Float32Array([0.1, 0.2, 0.3])]; const metadata = { text: "Single item as array" }; - const task = new DocumentNodeVectorUpsertTask(); + const task = new ChunkVectorUpsertTask(); const result = await task.run({ repository: repo, doc_id: "doc1", @@ -102,7 +99,7 @@ describe("DocumentNodeVectorUpsertTask", () => { const metadata2 = { text: "Updated document", source: "updated.txt" }; // First upsert - const task1 = new DocumentNodeVectorUpsertTask(); + const task1 = new ChunkVectorUpsertTask(); const result1 = await task1.run({ repository: repo, doc_id: "doc1", @@ -111,7 +108,7 @@ describe("DocumentNodeVectorUpsertTask", () => { }); // Update with same ID - const task2 = new DocumentNodeVectorUpsertTask(); + const task2 = new ChunkVectorUpsertTask(); const result2 = await task2.run({ repository: repo, doc_id: "doc1", @@ -128,7 +125,7 @@ describe("DocumentNodeVectorUpsertTask", () => { const vectors = [new Float32Array([0.1, 0.2]), new Float32Array([0.3, 0.4])]; const metadata = { text: 
"Shared metadata" }; - const task = new DocumentNodeVectorUpsertTask(); + const task = new ChunkVectorUpsertTask(); const result = await task.run({ repository: repo, doc_id: "doc1", @@ -144,7 +141,7 @@ describe("DocumentNodeVectorUpsertTask", () => { const vector = new Int8Array([127, -128, 64, -64, 0]); const metadata = { text: "Quantized vector" }; - const task = new DocumentNodeVectorUpsertTask(); + const task = new ChunkVectorUpsertTask(); const result = await task.run({ repository: repo, doc_id: "doc1", @@ -163,7 +160,7 @@ describe("DocumentNodeVectorUpsertTask", () => { const vector = new Float32Array([0.1, 0.2, 0.3]); const metadata = { text: "Simple metadata" }; - const task = new DocumentNodeVectorUpsertTask(); + const task = new ChunkVectorUpsertTask(); const result = await task.run({ repository: repo, doc_id: "doc1", @@ -185,7 +182,7 @@ describe("DocumentNodeVectorUpsertTask", () => { ); const metadata = { text: "Batch document" }; - const task = new DocumentNodeVectorUpsertTask(); + const task = new ChunkVectorUpsertTask(); const result = await task.run({ repository: repo, doc_id: "batch-doc", @@ -202,12 +199,12 @@ describe("DocumentNodeVectorUpsertTask", () => { test("should resolve repository from string ID", async () => { // Register repository by ID - registerDocumentNodeVectorRepository("test-upsert-repo", repo); + registerChunkVectorRepository("test-upsert-repo", repo); const vector = new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5]); const metadata = { text: "Test document", source: "test.txt" }; - const task = new DocumentNodeVectorUpsertTask(); + const task = new ChunkVectorUpsertTask(); // Pass repository as string ID instead of instance const result = await task.run({ repository: "test-upsert-repo" as any, diff --git a/packages/test/src/test/rag/DocumentRepository.test.ts b/packages/test/src/test/rag/DocumentRepository.test.ts index 5f6bf291..66a3c695 100644 --- a/packages/test/src/test/rag/DocumentRepository.test.ts +++ b/packages/test/src/test/rag/DocumentRepository.test.ts @@ -9,7 +9,7 @@ import { DocumentRepository, DocumentStorageKey, DocumentStorageSchema, - InMemoryDocumentNodeVectorRepository, + InMemoryChunkVectorRepository, InMemoryTabularRepository, NodeIdGenerator, NodeKind, @@ -19,7 +19,7 @@ import { beforeEach, describe, expect, it } from "vitest"; describe("DocumentRepository", () => { let repo: DocumentRepository; - let vectorStorage: InMemoryDocumentNodeVectorRepository; + let vectorStorage: InMemoryChunkVectorRepository; beforeEach(async () => { const tabularStorage = new InMemoryTabularRepository( @@ -28,7 +28,7 @@ describe("DocumentRepository", () => { ); await tabularStorage.setupDatabase(); - vectorStorage = new InMemoryDocumentNodeVectorRepository(3); + vectorStorage = new InMemoryChunkVectorRepository(3); await vectorStorage.setupDatabase(); repo = new DocumentRepository(tabularStorage, vectorStorage); diff --git a/packages/test/src/test/rag/EndToEnd.test.ts b/packages/test/src/test/rag/EndToEnd.test.ts index 1f8f277d..ecbaf1e7 100644 --- a/packages/test/src/test/rag/EndToEnd.test.ts +++ b/packages/test/src/test/rag/EndToEnd.test.ts @@ -10,7 +10,7 @@ import { DocumentRepository, DocumentStorageKey, DocumentStorageSchema, - InMemoryDocumentNodeVectorRepository, + InMemoryChunkVectorRepository, InMemoryTabularRepository, NodeIdGenerator, StructuralParser, @@ -88,7 +88,7 @@ Finds patterns in data.`; ); await tabularStorage.setupDatabase(); - const vectorStorage = new InMemoryDocumentNodeVectorRepository(3); + const vectorStorage = new 
InMemoryChunkVectorRepository(3); await vectorStorage.setupDatabase(); const docRepo = new DocumentRepository(tabularStorage, vectorStorage); diff --git a/packages/test/src/test/rag/FullChain.test.ts b/packages/test/src/test/rag/FullChain.test.ts index 5460e71b..cb4d61e2 100644 --- a/packages/test/src/test/rag/FullChain.test.ts +++ b/packages/test/src/test/rag/FullChain.test.ts @@ -5,17 +5,13 @@ */ import { HierarchicalChunkerTaskOutput } from "@workglow/ai"; -import { - ChunkNode, - InMemoryDocumentNodeVectorRepository, - NodeIdGenerator, -} from "@workglow/storage"; +import { ChunkNode, InMemoryChunkVectorRepository, NodeIdGenerator } from "@workglow/storage"; import { Workflow } from "@workglow/task-graph"; import { describe, expect, it } from "vitest"; describe("Complete chainable workflow", () => { it("should chain from parsing to storage without loops", async () => { - const vectorRepo = new InMemoryDocumentNodeVectorRepository(3); + const vectorRepo = new InMemoryChunkVectorRepository(3); await vectorRepo.setupDatabase(); const markdown = `# Test Document diff --git a/packages/test/src/test/rag/HybridSearchTask.test.ts b/packages/test/src/test/rag/HybridSearchTask.test.ts index 03ad874b..89c1ef4b 100644 --- a/packages/test/src/test/rag/HybridSearchTask.test.ts +++ b/packages/test/src/test/rag/HybridSearchTask.test.ts @@ -5,17 +5,14 @@ */ import { hybridSearch } from "@workglow/ai"; -import { - InMemoryDocumentNodeVectorRepository, - registerDocumentNodeVectorRepository, -} from "@workglow/storage"; +import { InMemoryChunkVectorRepository, registerChunkVectorRepository } from "@workglow/storage"; import { afterEach, beforeEach, describe, expect, test } from "vitest"; -describe("DocumentNodeVectorHybridSearchTask", () => { - let repo: InMemoryDocumentNodeVectorRepository; +describe("ChunkVectorHybridSearchTask", () => { + let repo: InMemoryChunkVectorRepository; beforeEach(async () => { - repo = new InMemoryDocumentNodeVectorRepository(3); + repo = new InMemoryChunkVectorRepository(3); await repo.setupDatabase(); // Populate repository with test data @@ -251,7 +248,7 @@ describe("DocumentNodeVectorHybridSearchTask", () => { test("should resolve repository from string ID", async () => { // Register repository by ID - registerDocumentNodeVectorRepository("test-hybrid-repo", repo); + registerChunkVectorRepository("test-hybrid-repo", repo); const queryVector = new Float32Array([1.0, 0.0, 0.0]); const queryText = "machine learning"; diff --git a/packages/test/src/test/rag/RagWorkflow.test.ts b/packages/test/src/test/rag/RagWorkflow.test.ts index 62508381..ac595af7 100644 --- a/packages/test/src/test/rag/RagWorkflow.test.ts +++ b/packages/test/src/test/rag/RagWorkflow.test.ts @@ -50,9 +50,9 @@ import { DocumentRepository, DocumentStorageKey, DocumentStorageSchema, - InMemoryDocumentNodeVectorRepository, + InMemoryChunkVectorRepository, InMemoryTabularRepository, - registerDocumentNodeVectorRepository, + registerChunkVectorRepository, } from "@workglow/storage"; import { getTaskQueueRegistry, setTaskQueueRegistry, Workflow } from "@workglow/task-graph"; import { readdirSync } from "fs"; @@ -62,7 +62,7 @@ import { registerHuggingfaceLocalModels } from "../../samples"; export { FileLoaderTask } from "@workglow/tasks"; describe("RAG Workflow End-to-End", () => { - let vectorRepo: InMemoryDocumentNodeVectorRepository; + let vectorRepo: InMemoryChunkVectorRepository; let docRepo: DocumentRepository; const vectorRepoName = "rag-test-vector-repo"; const embeddingModel = 
"onnx:Xenova/all-MiniLM-L6-v2:q8"; @@ -79,11 +79,11 @@ describe("RAG Workflow End-to-End", () => { await registerHuggingfaceLocalModels(); // Setup repositories - vectorRepo = new InMemoryDocumentNodeVectorRepository(3); + vectorRepo = new InMemoryChunkVectorRepository(3); await vectorRepo.setupDatabase(); // Register vector repository for use in workflows - registerDocumentNodeVectorRepository(vectorRepoName, vectorRepo); + registerChunkVectorRepository(vectorRepoName, vectorRepo); const tabularRepo = new InMemoryTabularRepository(DocumentStorageSchema, DocumentStorageKey); await tabularRepo.setupDatabase(); From 59f8a8532c7760eab1fa1f762b54a5212337e8fa Mon Sep 17 00:00:00 2001 From: Steven Roussey Date: Tue, 13 Jan 2026 23:35:28 +0000 Subject: [PATCH 09/14] [refactor] Update readme, todo, and change chunk upsert to handle multiple (or singular) metadata --- TODO.md | 23 +- packages/ai/src/task/ChunkVectorUpsertTask.ts | 55 ++- packages/storage/src/chunk-vector/README.md | 463 +++++++----------- packages/storage/src/document/DocumentNode.ts | 29 ++ 4 files changed, 260 insertions(+), 310 deletions(-) diff --git a/TODO.md b/TODO.md index 406d7c1c..16a7a810 100644 --- a/TODO.md +++ b/TODO.md @@ -1,8 +1,23 @@ TODO.md -- [ ] Chunks and nodes are not the same. - - [ ] We need to rename the files related to embedding. - - [ ] And we may need to save the chunk's node path. Or paths? +- [ ] Rename repositories in the packages/storage to use the word Storage instead of Repository. +- [ ] Vector Storage (not chunk storage) + - [ ] Rename the files from packages/storage/src/vector-storage to packages/storage/src/vector + - [ ] No fixed column names, use the schema to define the columns. + - [ ] Option for which column to use if there are multiple, default to the first one. + - [ ] Use @mceachen/sqlite-vec for sqlite storage. +- [ ] Datasets Package + - [ ] Documents repository (mabye rename to DocumentDataset) + - [ ] Chunks repository (maybe rename to ChunkDataset) or DocumentChunksDataset? Or just part of DocumentDataset? Or is it a new thing? + - [ ] Move Model repository to datasets package. +- [ ] Chunk Repository + - [ ] Add to packages/tasks or packages/ai + - [ ] Model like Model repository (although that just has one) + - [ ] Model even closer to Document repositories +- [ ] Chunks and nodes are not always the same. + - [ ] And we may need to save the chunk's node path. Or paths? or document range? Standard metadata? +- [ ] Use Repository to always envelope the storage operations (for transactions, dealing with IDs, etc). + - [ ] Get a better model for question answering. - [ ] Get a better model for named entity recognition, the current one recognized everything as a token, not helpful. - [ ] Titles are not making it into the chunks. 
@@ -13,3 +28,5 @@ TODO.md - [ ] Audio conversion like the image conversion - [ ] rename the registration stuff to not look ugly: registerHuggingfaceTransformers() and registerHuggingfaceTransformersUsingWorkers() and registerHuggingfaceTransformersInsideWorker() - [ ] fix image transferables + +onnx-community/ModernBERT-finetuned-squad-ONNX - summarization diff --git a/packages/ai/src/task/ChunkVectorUpsertTask.ts b/packages/ai/src/task/ChunkVectorUpsertTask.ts index 6d84b3b1..0d79afe8 100644 --- a/packages/ai/src/task/ChunkVectorUpsertTask.ts +++ b/packages/ai/src/task/ChunkVectorUpsertTask.ts @@ -23,26 +23,27 @@ import { TypeSingleOrArray } from "./base/AiTaskSchemas"; const inputSchema = { type: "object", properties: { + repository: TypeChunkVectorRepository({ + title: "Document Chunk Vector Repository", + description: "The document chunk vector repository instance to store vectors in", + }), doc_id: { type: "string", title: "Document ID", description: "The document ID", }, - repository: TypeChunkVectorRepository({ - title: "Document Chunk Vector Repository", - description: "The document chunk vector repository instance to store vectors in", - }), vectors: TypeSingleOrArray( TypedArraySchema({ - title: "Vector", - description: "The vector embedding", + title: "Vectors", + description: "The vector embeddings", }) ), - metadata: { + metadata: TypeSingleOrArray({ type: "object", title: "Metadata", description: "Metadata associated with the vector", - }, + additionalProperties: true, + }), }, required: ["repository", "doc_id", "vectors", "metadata"], additionalProperties: false, @@ -61,14 +62,14 @@ const outputSchema = { title: "Document ID", description: "The document ID", }, - ids: { + chunk_ids: { type: "array", items: { type: "string" }, - title: "IDs", - description: "IDs of upserted vectors", + title: "Chunk IDs", + description: "Chunk IDs of upserted vectors", }, }, - required: ["count", "ids"], + required: ["count", "doc_id", "chunk_ids"], additionalProperties: false, } as const satisfies DataPortSchema; @@ -109,42 +110,50 @@ export class ChunkVectorUpsertTask extends Task< // Normalize inputs to arrays const vectorArray = Array.isArray(vectors) ? vectors : [vectors]; + const metadataArray = Array.isArray(metadata) + ? 
metadata + : Array(vectorArray.length).fill(metadata); const repo = repository as AnyChunkVectorRepository; await context.updateProgress(1, "Upserting vectors"); - const idArray: string[] = []; + const chunk_ids: string[] = []; // Bulk upsert if multiple items if (vectorArray.length > 1) { + if (vectorArray.length !== metadataArray.length) { + throw new Error("Mismatch: vectors and metadata arrays must have the same length"); + } const entities = vectorArray.map((vector, i) => { const chunk_id = `${doc_id}_${i}`; - idArray.push(chunk_id); + const metadataItem = metadataArray[i]; + chunk_ids.push(chunk_id); return { chunk_id, doc_id, - vector: vector as any, // Store TypedArray directly (memory) or as string (SQL) - metadata, + vector, + metadata: metadataItem, }; }); - await repo.putBulk(entities as any); + await repo.putBulk(entities); } else if (vectorArray.length === 1) { // Single upsert const chunk_id = `${doc_id}_0`; - idArray.push(chunk_id); + const metadataItem = metadataArray[0]; + chunk_ids.push(chunk_id); await repo.put({ chunk_id, doc_id, - vector: vectorArray[0] as any, // Store TypedArray directly (memory) or as string (SQL) - metadata, - } as any); + vector: vectorArray[0], + metadata: metadataItem, + }); } return { doc_id, - ids: idArray, - count: vectorArray.length, + chunk_ids, + count: chunk_ids.length, }; } } diff --git a/packages/storage/src/chunk-vector/README.md b/packages/storage/src/chunk-vector/README.md index 3e848f59..2af14dda 100644 --- a/packages/storage/src/chunk-vector/README.md +++ b/packages/storage/src/chunk-vector/README.md @@ -1,15 +1,13 @@ -# Vector Storage Module +# Chunk Vector Storage Module -A flexible vector storage solution with multiple backend implementations for RAG (Retrieval-Augmented Generation) pipelines. Provides a consistent interface for vector CRUD operations with similarity search and hybrid search capabilities. +Storage for document chunk embeddings with vector similarity search capabilities. Extends the tabular repository pattern to add vector search functionality for RAG (Retrieval-Augmented Generation) pipelines. 
## Features - **Multiple Storage Backends:** - - 🧠 `InMemoryVectorRepository` - Fast in-memory storage for testing and small datasets - - 📁 `SqliteVectorRepository` - Persistent SQLite storage for local applications - - 🐘 `PostgresVectorRepository` - PostgreSQL with pgvector extension for production - - 🔍 `SeekDbVectorRepository` - SeekDB/OceanBase with native hybrid search - - 📱 `EdgeVecRepository` - Edge/browser deployment with IndexedDB and WebGPU support + - 🧠 `InMemoryChunkVectorRepository` - Fast in-memory storage for testing and small datasets + - 📁 `SqliteChunkVectorRepository` - Persistent SQLite storage for local applications + - 🐘 `PostgresChunkVectorRepository` - PostgreSQL with pgvector extension for production - **Quantized Vector Support:** - Float32Array (standard 32-bit floating point) @@ -20,17 +18,16 @@ A flexible vector storage solution with multiple backend implementations for RAG - Int16Array (16-bit signed - quantization) - Uint16Array (16-bit unsigned - quantization) -- **Advanced Search Capabilities:** +- **Search Capabilities:** - Vector similarity search (cosine similarity) - - Hybrid search (vector + full-text) + - Hybrid search (vector + full-text keyword matching) - Metadata filtering - Top-K retrieval with score thresholds -- **Production Ready:** - - Type-safe interfaces - - Event emitters for monitoring - - Bulk operations support - - Efficient indexing strategies +- **Built on Tabular Repositories:** + - Extends `ITabularRepository` for standard CRUD operations + - Inherits event emitter pattern for monitoring + - Type-safe schema-based storage ## Installation @@ -40,70 +37,67 @@ bun install @workglow/storage ## Usage -### In-Memory Repository (Testing/Browser) +### In-Memory Repository (Testing/Development) ```typescript -import { InMemoryVectorRepository } from "@workglow/storage"; +import { InMemoryChunkVectorRepository } from "@workglow/storage"; -// Standard Float32 vectors -const repo = new InMemoryVectorRepository<{ text: string; source: string }>(); +// Create repository with 384 dimensions +const repo = new InMemoryChunkVectorRepository(384); await repo.setupDatabase(); -// Upsert vectors -await repo.upsert( - "doc1", - new Float32Array([0.1, 0.2, 0.3, ...]), - { text: "Hello world", source: "example.txt" } -); +// Store a chunk with its embedding +await repo.put({ + chunk_id: "chunk-001", + doc_id: "doc-001", + vector: new Float32Array([0.1, 0.2, 0.3 /* ... 384 dims */]), + metadata: { text: "Hello world", source: "example.txt" }, +}); -// Search for similar vectors -const results = await repo.similaritySearch( - new Float32Array([0.15, 0.25, 0.35, ...]), - { topK: 5, scoreThreshold: 0.7 } -); +// Search for similar chunks +const results = await repo.similaritySearch(new Float32Array([0.15, 0.25, 0.35 /* ... */]), { + topK: 5, + scoreThreshold: 0.7, +}); ``` ### Quantized Vectors (Reduced Storage) ```typescript -import { InMemoryVectorRepository } from "@workglow/storage"; +import { InMemoryChunkVectorRepository } from "@workglow/storage"; // Use Int8Array for 4x smaller storage (binary quantization) -const repo = new InMemoryVectorRepository< - { text: string }, - Int8Array ->(); +const repo = new InMemoryChunkVectorRepository<{ text: string }, Int8Array>(384, Int8Array); await repo.setupDatabase(); // Store quantized vectors -await repo.upsert( - "doc1", - new Int8Array([127, -128, 64, ...]), - { text: "Quantized embedding" } -); +await repo.put({ + chunk_id: "chunk-001", + doc_id: "doc-001", + vector: new Int8Array([127, -128, 64 /* ... 
*/]), + metadata: { category: "ai" }, +}); // Search with quantized query -const results = await repo.similaritySearch( - new Int8Array([100, -50, 75, ...]), - { topK: 5 } -); +const results = await repo.similaritySearch(new Int8Array([100, -50, 75 /* ... */]), { topK: 5 }); ``` ### SQLite Repository (Local Persistence) ```typescript -import { SqliteVectorRepository } from "@workglow/storage"; +import { SqliteChunkVectorRepository } from "@workglow/storage"; -const repo = new SqliteVectorRepository<{ text: string }>( +const repo = new SqliteChunkVectorRepository<{ text: string }>( "./vectors.db", // database path - "embeddings" // table name + "chunks", // table name + 768 // vector dimension ); await repo.setupDatabase(); -// Bulk upsert -await repo.upsertBulk([ - { id: "1", vector: new Float32Array([...]), metadata: { text: "..." } }, - { id: "2", vector: new Float32Array([...]), metadata: { text: "..." } }, +// Bulk insert using inherited tabular methods +await repo.putMany([ + { chunk_id: "1", doc_id: "doc1", vector: new Float32Array([...]), metadata: { text: "..." } }, + { chunk_id: "2", doc_id: "doc1", vector: new Float32Array([...]), metadata: { text: "..." } }, ]); ``` @@ -111,18 +105,25 @@ await repo.upsertBulk([ ```typescript import { Pool } from "pg"; -import { PostgresVectorRepository } from "@workglow/storage"; +import { PostgresChunkVectorRepository } from "@workglow/storage"; const pool = new Pool({ connectionString: "postgresql://..." }); -const repo = new PostgresVectorRepository<{ text: string; category: string }>( +const repo = new PostgresChunkVectorRepository<{ text: string; category: string }>( pool, - "vectors", + "chunks", 384 // vector dimension ); await repo.setupDatabase(); +// Native pgvector similarity search with filter +const results = await repo.similaritySearch(queryVector, { + topK: 10, + filter: { category: "ai" }, + scoreThreshold: 0.5, +}); + // Hybrid search (vector + full-text) -const results = await repo.hybridSearch(queryVector, { +const hybridResults = await repo.hybridSearch(queryVector, { textQuery: "machine learning", topK: 10, vectorWeight: 0.7, @@ -130,106 +131,137 @@ const results = await repo.hybridSearch(queryVector, { }); ``` -### SeekDB (Hybrid Search Database) +## Data Model -```typescript -import mysql from "mysql2/promise"; -import { SeekDbVectorRepository } from "@workglow/storage"; +### ChunkVector Schema -const pool = mysql.createPool({ host: "...", database: "..." }); -const repo = new SeekDbVectorRepository<{ text: string }>( - pool, - "vectors", - 768 // vector dimension -); -await repo.setupDatabase(); +Each chunk vector entry contains: -// Native hybrid search -const results = await repo.hybridSearch(queryVector, { - textQuery: "neural networks", - topK: 5, - vectorWeight: 0.6, -}); +```typescript +interface ChunkVector< + Metadata extends Record = Record, + Vector extends TypedArray = Float32Array, +> { + chunk_id: string; // Unique identifier for the chunk + doc_id: string; // Parent document identifier + vector: Vector; // Embedding vector + metadata: Metadata; // Custom metadata (text content, entities, etc.) 
+} ``` -### EdgeVec (Browser/Edge Deployment) +### Default Schema ```typescript -import { EdgeVecRepository } from "@workglow/storage"; - -const repo = new EdgeVecRepository<{ text: string }>({ - dbName: "my-vectors", // IndexedDB name - enableWebGPU: true, // Enable GPU acceleration -}); -await repo.setupDatabase(); +const ChunkVectorSchema = { + type: "object", + properties: { + chunk_id: { type: "string" }, + doc_id: { type: "string" }, + vector: TypedArraySchema(), + metadata: { type: "object", additionalProperties: true }, + }, + additionalProperties: false, +} as const; -// Works entirely in the browser -await repo.upsert("1", vector, { text: "..." }); -const results = await repo.similaritySearch(queryVector, { topK: 3 }); +const ChunkVectorKey = ["chunk_id"] as const; ``` -## API Documentation +## API Reference -### Core Methods +### IChunkVectorRepository Interface -All repositories implement the `IVectorRepository` interface: +Extends `ITabularRepository` with vector-specific methods: ```typescript -interface IVectorRepository { - // Setup - setupDatabase(): Promise; - - // CRUD Operations - upsert(id: string, vector: Float32Array, metadata: Metadata): Promise; - upsertBulk(items: VectorEntry[]): Promise; - get(id: string): Promise | undefined>; - delete(id: string): Promise; - deleteBulk(ids: string[]): Promise; - deleteByFilter(filter: Partial): Promise; - - // Search - search( - query: Float32Array, - options?: VectorSearchOptions - ): Promise[]>; +interface IChunkVectorRepository extends ITabularRepository< + Schema, + PrimaryKeyNames, + Entity +> { + // Get the vector dimension + getVectorDimensions(): number; + + // Vector similarity search + similaritySearch( + query: TypedArray, + options?: VectorSearchOptions + ): Promise<(Entity & { score: number })[]>; + + // Hybrid search (optional - not all implementations support it) hybridSearch?( - query: Float32Array, - options: HybridSearchOptions - ): Promise[]>; + query: TypedArray, + options: HybridSearchOptions + ): Promise<(Entity & { score: number })[]>; +} +``` - // Utility - size(): Promise; - clear(): Promise; - destroy(): void; +### Inherited Tabular Methods - // Events - on(event: "upsert" | "delete" | "search", callback: Function): void; -} +From `ITabularRepository`: + +```typescript +// Setup +setupDatabase(): Promise; + +// CRUD Operations +put(entity: Entity): Promise; +putMany(entities: Entity[]): Promise; +get(key: PrimaryKey): Promise; +getAll(): Promise; +delete(key: PrimaryKey): Promise; +deleteMany(keys: PrimaryKey[]): Promise; + +// Utility +size(): Promise; +clear(): Promise; +destroy(): void; ``` ### Search Options ```typescript -interface VectorSearchOptions { - topK?: number; // Number of results (default: 10) - filter?: Partial; // Filter by metadata - scoreThreshold?: number; // Minimum score (0-1) +interface VectorSearchOptions> { + readonly topK?: number; // Number of results (default: 10) + readonly filter?: Partial; // Filter by metadata fields + readonly scoreThreshold?: number; // Minimum score 0-1 (default: 0) } interface HybridSearchOptions extends VectorSearchOptions { - textQuery: string; // Full-text query - vectorWeight?: number; // Vector weight 0-1 (default: 0.7) + readonly textQuery: string; // Full-text query keywords + readonly vectorWeight?: number; // Vector weight 0-1 (default: 0.7) } ``` +## Global Registry + +Register and retrieve chunk vector repositories globally: + +```typescript +import { + registerChunkVectorRepository, + getChunkVectorRepository, + 
getGlobalChunkVectorRepositories, +} from "@workglow/storage"; + +// Register a repository +registerChunkVectorRepository("my-chunks", repo); + +// Retrieve by ID +const repo = getChunkVectorRepository("my-chunks"); + +// Get all registered repositories +const allRepos = getGlobalChunkVectorRepositories(); +``` + ## Quantization Benefits -Quantized vectors can significantly reduce storage and improve performance: +Quantized vectors reduce storage and can improve performance: | Vector Type | Bytes/Dim | Storage vs Float32 | Use Case | | ------------ | --------- | ------------------ | ------------------------------------ | | Float32Array | 4 | 100% (baseline) | Standard embeddings | | Float64Array | 8 | 200% | High precision needed | +| Float16Array | 2 | 50% | Great precision/size tradeoff | | Int16Array | 2 | 50% | Good precision/size tradeoff | | Int8Array | 1 | 25% | Binary quantization, max compression | | Uint8Array | 1 | 25% | Quantized embeddings [0-255] | @@ -243,7 +275,7 @@ Quantized vectors can significantly reduce storage and improve performance: ### InMemory -- **Best for:** Testing, small datasets (<10K vectors), browser apps +- **Best for:** Testing, small datasets (<10K vectors), development - **Pros:** Fastest, no dependencies, supports all vector types - **Cons:** No persistence, memory limited @@ -251,196 +283,59 @@ Quantized vectors can significantly reduce storage and improve performance: - **Best for:** Local apps, medium datasets (<100K vectors) - **Pros:** Persistent, single file, no server -- **Cons:** No native vector indexing, slower for large datasets +- **Cons:** No native vector indexing (linear scan), slower for large datasets ### PostgreSQL + pgvector - **Best for:** Production, large datasets (>100K vectors) -- **Pros:** HNSW indexing, efficient, scalable +- **Pros:** Native HNSW/IVFFlat indexing, efficient similarity search, scalable - **Cons:** Requires PostgreSQL server and pgvector extension +- **Setup:** `CREATE EXTENSION vector;` -### SeekDB - -- **Best for:** Hybrid search workloads, production -- **Pros:** Native hybrid search, MySQL-compatible -- **Cons:** Requires SeekDB/OceanBase instance - -### EdgeVec - -- **Best for:** Privacy-sensitive apps, offline-first, edge computing -- **Pros:** No server, IndexedDB persistence, WebGPU acceleration -- **Cons:** Limited by browser storage, smaller datasets - -## Integration with RAG Tasks - -The vector repositories integrate seamlessly with RAG tasks: - -```typescript -import { InMemoryVectorRepository } from "@workglow/storage"; -import { Workflow } from "@workglow/task-graph"; - -const repo = new InMemoryVectorRepository(); -await repo.setupDatabase(); - -const workflow = new Workflow() - // Load and chunk document - .fileLoader({ path: "./doc.md" }) - .textChunker({ chunkSize: 512, chunkOverlap: 50 }) - - // Generate embeddings - .textEmbedding({ model: "Xenova/all-MiniLM-L6-v2" }) - - // Store in vector repository - .vectorStoreUpsert({ repository: repo }); - -await workflow.run(); - -// Later: Search -const searchWorkflow = new Workflow() - .textEmbedding({ text: "What is RAG?", model: "..." }) - .vectorStoreSearch({ repository: repo, topK: 5 }) - .contextBuilder({ format: "markdown" }) - .textQuestionAnswer({ question: "What is RAG?" 
}); - -const result = await searchWorkflow.run(); -``` - -## Hierarchical Document Integration - -For document-level storage and hierarchical context enrichment, use vector repositories alongside document repositories: - -```typescript -import { InMemoryVectorRepository, InMemoryDocumentRepository } from "@workglow/storage"; -import { Workflow } from "@workglow/task-graph"; - -const vectorRepo = new InMemoryVectorRepository(); -const docRepo = new InMemoryDocumentRepository(); -await vectorRepo.setupDatabase(); - -// Ingestion with hierarchical structure -await new Workflow() - .structuralParser({ - text: markdownContent, - title: "Documentation", - format: "markdown", - }) - .hierarchicalChunker({ - maxTokens: 512, - overlap: 50, - strategy: "hierarchical", - }) - .textEmbedding({ model: "Xenova/all-MiniLM-L6-v2" }) - .chunkToVector() - .vectorStoreUpsert({ repository: vectorRepo }) - .run(); - -// Retrieval with parent context -const result = await new Workflow() - .textEmbedding({ text: query, model: "Xenova/all-MiniLM-L6-v2" }) - .vectorStoreSearch({ repository: vectorRepo, topK: 10 }) - .hierarchyJoin({ - documentRepository: docRepo, - includeParentSummaries: true, - includeEntities: true, - }) - .reranker({ query, topK: 5 }) - .contextBuilder({ format: "markdown" }) - .run(); -``` - -### Vector Metadata for Hierarchical Documents - -When using hierarchical chunking, base vector metadata (stored in vector database) includes: - -```typescript -metadata: { - doc_id: string, // Document identifier - chunkId: string, // Chunk identifier - leafNodeId: string, // Reference to document tree node - depth: number, // Hierarchy depth - text: string, // Chunk text content - nodePath: string[], // Node IDs from root to leaf - // From enrichment (optional): - summary?: string, // Summary of the chunk content - entities?: Entity[], // Named entities extracted from the chunk -} -``` - -After `HierarchyJoinTask`, enriched metadata includes additional fields: - -```typescript -enrichedMetadata: { - // ... all base metadata fields above ... - parentSummaries?: string[], // Summaries from ancestor nodes (looked up on-demand) - sectionTitles?: string[], // Titles of ancestor section nodes -} -``` - -Note: `parentSummaries` is not stored in the vector database. It is computed on-demand by `HierarchyJoinTask` using `doc_id` and `leafNodeId` to look up ancestors from the document repository. +## Integration with DocumentRepository -## Document Repository - -The `IDocumentRepository` interface provides storage for hierarchical document structures: - -```typescript -class DocumentRepository { - constructor(tabularStorage: ITabularRepository, vectorStorage: IVectorRepository); - - upsert(document: Document): Promise; - get(doc_id: string): Promise; - getNode(doc_id: string, nodeId: string): Promise; - getAncestors(doc_id: string, nodeId: string): Promise; - getChunks(doc_id: string): Promise; - findChunksByNodeId(doc_id: string, nodeId: string): Promise; - delete(doc_id: string): Promise; - list(): Promise; - search(query: TypedArray, options?: VectorSearchOptions): Promise; -} -``` - -### Document Repository - -The `DocumentRepository` class provides a unified interface for storing hierarchical documents and searching chunks. 
It uses composition of storage backends: - -| Component | Purpose | -| -------------------- | -------------------------------------------- | -| `ITabularRepository` | Stores document structure and metadata | -| `IVectorRepository` | Enables similarity search on document chunks | - -**Example Usage:** +The chunk vector repository works alongside `DocumentRepository` for hierarchical document storage: ```typescript import { DocumentRepository, + InMemoryChunkVectorRepository, InMemoryTabularRepository, - InMemoryVectorRepository, } from "@workglow/storage"; - -// Define schema for document storage -const DocumentStorageSchema = { - type: "object", - properties: { - doc_id: { type: "string" }, - data: { type: "string" }, - }, - required: ["doc_id", "data"], -} as const; +import { DocumentStorageSchema } from "@workglow/storage"; // Initialize storage backends const tabularStorage = new InMemoryTabularRepository(DocumentStorageSchema, ["doc_id"]); await tabularStorage.setupDatabase(); -const vectorStorage = new InMemoryVectorRepository(); +const vectorStorage = new InMemoryChunkVectorRepository(384); await vectorStorage.setupDatabase(); -// Create document repository +// Create document repository with both storages const docRepo = new DocumentRepository(tabularStorage, vectorStorage); -// Use the repository +// Store document structure in tabular, chunks in vector await docRepo.upsert(document); + +// Search chunks by vector similarity const results = await docRepo.search(queryVector, { topK: 5 }); ``` +### Chunk Metadata for Hierarchical Documents + +When using hierarchical chunking, chunk metadata typically includes: + +```typescript +metadata: { + text: string; // Chunk text content + leafNodeId?: string; // Reference to document tree node + depth?: number; // Hierarchy depth + nodePath?: string[]; // Node IDs from root to leaf + summary?: string; // Summary of the chunk content + entities?: Entity[]; // Named entities extracted from the chunk +} +``` + ## License Apache 2.0 diff --git a/packages/storage/src/document/DocumentNode.ts b/packages/storage/src/document/DocumentNode.ts index a769fdbe..6fccdd2b 100644 --- a/packages/storage/src/document/DocumentNode.ts +++ b/packages/storage/src/document/DocumentNode.ts @@ -132,3 +132,32 @@ export function getNodePath(root: DocumentNode, targetNodeId: string): string[] return search(root) ? path : undefined; } + +/** + * Get document range for a node path + */ +export function getDocumentRange(root: DocumentNode, nodePath: string[]): NodeRange { + let currentNode = root as DocumentRootNode | SectionNode | TopicNode; + + // Start from index 1 since nodePath[0] is the root + for (let i = 1; i < nodePath.length; i++) { + const targetId = nodePath[i]; + const children = currentNode.children; + let found: DocumentNode | undefined; + + for (let j = 0; j < children.length; j++) { + if (children[j].nodeId === targetId) { + found = children[j]; + break; + } + } + + if (!found) { + throw new Error(`Node with id ${targetId} not found in path`); + } + + currentNode = found as DocumentRootNode | SectionNode | TopicNode; + } + + return currentNode.range; +} From b44458ddae6e58462bff688fe4edc5986158ee7d Mon Sep 17 00:00:00 2001 From: Steven Roussey Date: Wed, 14 Jan 2026 00:14:16 +0000 Subject: [PATCH 10/14] [refactor] Update storage structure and rename repositories to storage - Refactored storage components to rename repository classes from `Repository` to `Storage`, enhancing clarity in naming conventions. 
- Updated various files to reflect the new `Storage` naming, including `InMemory`, `IndexedDb`, `Postgres`, and `Sqlite` implementations. - Adjusted related documentation and tests to ensure consistency with the new structure. - Improved overall organization of storage-related code for better maintainability and readability. --- bun.lock | 68 +++++++++--------- docs/developers/03_extending.md | 2 +- package.json | 12 ++-- packages/ai/README.md | 2 +- .../ai/src/model/InMemoryModelRepository.ts | 4 +- packages/ai/src/model/ModelRepository.ts | 6 +- packages/ai/src/task/ChunkRetrievalTask.ts | 4 +- .../src/task/ChunkVectorHybridSearchTask.ts | 4 +- packages/ai/src/task/ChunkVectorSearchTask.ts | 4 +- packages/ai/src/task/ChunkVectorUpsertTask.ts | 4 +- packages/storage/README.md | 70 +++++++++---------- packages/storage/src/browser.ts | 10 +-- ...istry.ts => ChunkVectorStorageRegistry.ts} | 16 ++--- ...orRepository.ts => IChunkVectorStorage.ts} | 8 +-- ...itory.ts => InMemoryChunkVectorStorage.ts} | 12 ++-- ...itory.ts => PostgresChunkVectorStorage.ts} | 12 ++-- packages/storage/src/chunk-vector/README.md | 40 +++++------ ...ository.ts => SqliteChunkVectorStorage.ts} | 12 ++-- packages/storage/src/common-server.ts | 26 +++---- packages/storage/src/common.ts | 24 +++---- .../src/document/DocumentRepository.ts | 18 ++--- ...Repository.ts => FsFolderJsonKvStorage.ts} | 16 ++--- ...erKvRepository.ts => FsFolderKvStorage.ts} | 12 ++-- .../kv/{IKvRepository.ts => IKvStorage.ts} | 4 +- ...ryKvRepository.ts => InMemoryKvStorage.ts} | 16 ++--- ...bKvRepository.ts => IndexedDbKvStorage.ts} | 16 ++--- .../src/kv/{KvRepository.ts => KvStorage.ts} | 14 ++-- ...arRepository.ts => KvViaTabularStorage.ts} | 12 ++-- ...esKvRepository.ts => PostgresKvStorage.ts} | 16 ++--- packages/storage/src/kv/README.md | 6 +- ...liteKvRepository.ts => SqliteKvStorage.ts} | 16 ++--- ...seKvRepository.ts => SupabaseKvStorage.ts} | 18 ++--- ...Repository.ts => BaseSqlTabularStorage.ts} | 12 ++-- ...larRepository.ts => BaseTabularStorage.ts} | 14 ++-- ...rRepository.ts => CachedTabularStorage.ts} | 36 +++++----- ...epository.ts => FsFolderTabularStorage.ts} | 18 ++--- ...abularRepository.ts => ITabularStorage.ts} | 4 +- ...epository.ts => InMemoryTabularStorage.ts} | 14 ++-- ...pository.ts => IndexedDbTabularStorage.ts} | 12 ++-- ...epository.ts => PostgresTabularStorage.ts} | 18 ++--- packages/storage/src/tabular/README.md | 32 ++++----- ...ory.ts => SharedInMemoryTabularStorage.ts} | 26 +++---- ...rRepository.ts => SqliteTabularStorage.ts} | 16 ++--- ...epository.ts => SupabaseTabularStorage.ts} | 16 ++--- ...yRegistry.ts => TabularStorageRegistry.ts} | 16 ++--- .../src/storage/TaskGraphTabularRepository.ts | 4 +- .../storage/TaskOutputTabularRepository.ts | 4 +- packages/task-graph/src/task/README.md | 4 +- .../binding/FsFolderTaskGraphRepository.ts | 4 +- .../binding/FsFolderTaskOutputRepository.ts | 4 +- .../binding/InMemoryTaskGraphRepository.ts | 4 +- .../binding/InMemoryTaskOutputRepository.ts | 4 +- .../src/binding/IndexedDbModelRepository.ts | 4 +- .../binding/IndexedDbTaskGraphRepository.ts | 4 +- .../binding/IndexedDbTaskOutputRepository.ts | 4 +- .../src/binding/PostgresModelRepository.ts | 4 +- .../binding/PostgresTaskGraphRepository.ts | 4 +- .../binding/PostgresTaskOutputRepository.ts | 4 +- .../test/src/binding/SqliteModelRepository.ts | 4 +- .../src/binding/SqliteTaskGraphRepository.ts | 4 +- .../src/binding/SqliteTaskOutputRepository.ts | 4 +- .../rag/DocumentNodeRetrievalTask.test.ts | 6 +- 
.../rag/DocumentNodeVectorSearchTask.test.ts | 8 +-- .../DocumentNodeVectorStoreUpsertTask.test.ts | 26 +++---- .../src/test/rag/DocumentRepository.test.ts | 14 ++-- packages/test/src/test/rag/EndToEnd.test.ts | 8 +-- packages/test/src/test/rag/FullChain.test.ts | 4 +- .../src/test/rag/HybridSearchTask.test.ts | 6 +- .../test/src/test/rag/RagWorkflow.test.ts | 10 +-- .../FsFolderJsonKvRepository.test.ts | 6 +- .../storage-kv/FsFolderKvRepository.test.ts | 6 +- .../storage-kv/InMemoryKvRepository.test.ts | 6 +- .../storage-kv/IndexedDbKvRepository.test.ts | 6 +- .../storage-kv/PostgresKvRepository.test.ts | 6 +- .../storage-kv/SqliteKvRepository.test.ts | 6 +- .../storage-kv/SupabaseKvRepository.test.ts | 10 +-- .../storage-kv/genericKvRepositoryTests.ts | 10 +-- .../CachedTabularRepository.test.ts | 62 ++++++++-------- .../FsFolderTabularRepository.test.ts | 12 ++-- .../InMemoryTabularRepository.test.ts | 12 ++-- .../IndexedDbTabularRepository.test.ts | 20 +++--- .../PostgresTabularRepository.test.ts | 10 +-- .../SqliteTabularRepository.test.ts | 10 +-- .../SupabaseTabularRepository.test.ts | 10 +-- ...nericTabularRepositorySubscriptionTests.ts | 6 +- .../genericTabularRepositoryTests.ts | 18 ++--- .../IndexedDbHybridSubscription.test.ts | 18 ++--- .../src/test/task-graph/InputResolver.test.ts | 10 +-- 88 files changed, 564 insertions(+), 564 deletions(-) rename packages/storage/src/chunk-vector/{ChunkVectorRepositoryRegistry.ts => ChunkVectorStorageRegistry.ts} (85%) rename packages/storage/src/chunk-vector/{IChunkVectorRepository.ts => IChunkVectorStorage.ts} (92%) rename packages/storage/src/chunk-vector/{InMemoryChunkVectorRepository.ts => InMemoryChunkVectorStorage.ts} (95%) rename packages/storage/src/chunk-vector/{PostgresChunkVectorRepository.ts => PostgresChunkVectorStorage.ts} (97%) rename packages/storage/src/chunk-vector/{SqliteChunkVectorRepository.ts => SqliteChunkVectorStorage.ts} (95%) rename packages/storage/src/kv/{FsFolderJsonKvRepository.ts => FsFolderJsonKvStorage.ts} (65%) rename packages/storage/src/kv/{FsFolderKvRepository.ts => FsFolderKvStorage.ts} (94%) rename packages/storage/src/kv/{IKvRepository.ts => IKvStorage.ts} (96%) rename packages/storage/src/kv/{InMemoryKvRepository.ts => InMemoryKvStorage.ts} (59%) rename packages/storage/src/kv/{IndexedDbKvRepository.ts => IndexedDbKvStorage.ts} (61%) rename packages/storage/src/kv/{KvRepository.ts => KvStorage.ts} (93%) rename packages/storage/src/kv/{KvViaTabularRepository.ts => KvViaTabularStorage.ts} (94%) rename packages/storage/src/kv/{PostgresKvRepository.ts => PostgresKvStorage.ts} (61%) rename packages/storage/src/kv/{SqliteKvRepository.ts => SqliteKvStorage.ts} (62%) rename packages/storage/src/kv/{SupabaseKvRepository.ts => SupabaseKvStorage.ts} (66%) rename packages/storage/src/tabular/{BaseSqlTabularRepository.ts => BaseSqlTabularStorage.ts} (96%) rename packages/storage/src/tabular/{BaseTabularRepository.ts => BaseTabularStorage.ts} (97%) rename packages/storage/src/tabular/{CachedTabularRepository.ts => CachedTabularStorage.ts} (89%) rename packages/storage/src/tabular/{FsFolderTabularRepository.ts => FsFolderTabularStorage.ts} (96%) rename packages/storage/src/tabular/{ITabularRepository.ts => ITabularStorage.ts} (98%) rename packages/storage/src/tabular/{InMemoryTabularRepository.ts => InMemoryTabularStorage.ts} (96%) rename packages/storage/src/tabular/{IndexedDbTabularRepository.ts => IndexedDbTabularStorage.ts} (98%) rename packages/storage/src/tabular/{PostgresTabularRepository.ts => 
PostgresTabularStorage.ts} (98%) rename packages/storage/src/tabular/{SharedInMemoryTabularRepository.ts => SharedInMemoryTabularStorage.ts} (92%) rename packages/storage/src/tabular/{SqliteTabularRepository.ts => SqliteTabularStorage.ts} (98%) rename packages/storage/src/tabular/{SupabaseTabularRepository.ts => SupabaseTabularStorage.ts} (98%) rename packages/storage/src/tabular/{TabularRepositoryRegistry.ts => TabularStorageRegistry.ts} (84%) diff --git a/bun.lock b/bun.lock index 8db3757a..62fb6c0f 100644 --- a/bun.lock +++ b/bun.lock @@ -5,13 +5,13 @@ "": { "name": "workglow", "dependencies": { - "caniuse-lite": "^1.0.30001763", + "caniuse-lite": "^1.0.30001764", }, "devDependencies": { "@sroussey/changesets-cli": "^2.29.7", - "@types/bun": "^1.3.5", - "@typescript-eslint/eslint-plugin": "^8.52.0", - "@typescript-eslint/parser": "^8.52.0", + "@types/bun": "^1.3.6", + "@typescript-eslint/eslint-plugin": "^8.53.0", + "@typescript-eslint/parser": "^8.53.0", "concurrently": "^9.2.1", "eslint": "^9.39.2", "eslint-plugin-jsx-a11y": "^6.10.2", @@ -20,9 +20,9 @@ "eslint-plugin-regexp": "^2.10.0", "globals": "^17.0.0", "prettier": "^3.7.4", - "turbo": "^2.7.3", + "turbo": "^2.7.4", "typescript": "5.9.3", - "vitest": "^4.0.16", + "vitest": "^4.0.17", }, }, "examples/cli": { @@ -651,7 +651,7 @@ "@types/better-sqlite3": ["@types/better-sqlite3@7.6.13", "", { "dependencies": { "@types/node": "*" } }, "sha512-NMv9ASNARoKksWtsq/SHakpYAYnhBrQgGD8zkLYk/jaK8jUGn08CfEdTRgYhMypUQAfzSP8W6gNLe0q19/t4VA=="], - "@types/bun": ["@types/bun@1.3.5", "", { "dependencies": { "bun-types": "1.3.5" } }, "sha512-RnygCqNrd3srIPEWBd5LFeUYG7plCoH2Yw9WaZGyNmdTEei+gWaHqydbaIRkIkcbXwhBT94q78QljxN0Sk838w=="], + "@types/bun": ["@types/bun@1.3.6", "", { "dependencies": { "bun-types": "1.3.6" } }, "sha512-uWCv6FO/8LcpREhenN1d1b6fcspAB+cefwD7uti8C8VffIv0Um08TKMn98FynpTiU38+y2dUO55T11NgDt8VAA=="], "@types/chai": ["@types/chai@5.2.3", "", { "dependencies": { "@types/deep-eql": "*", "assertion-error": "^2.0.1" } }, "sha512-Mw558oeA9fFbv65/y4mHtXDs9bPnFMZAL/jxdPFUpOHHIXX91mcgEHbS5Lahr+pwZFR8A7GQleRWeI6cGFC2UA=="], @@ -687,25 +687,25 @@ "@types/ws": ["@types/ws@8.18.1", "", { "dependencies": { "@types/node": "*" } }, "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg=="], - "@typescript-eslint/eslint-plugin": ["@typescript-eslint/eslint-plugin@8.52.0", "", { "dependencies": { "@eslint-community/regexpp": "^4.12.2", "@typescript-eslint/scope-manager": "8.52.0", "@typescript-eslint/type-utils": "8.52.0", "@typescript-eslint/utils": "8.52.0", "@typescript-eslint/visitor-keys": "8.52.0", "ignore": "^7.0.5", "natural-compare": "^1.4.0", "ts-api-utils": "^2.4.0" }, "peerDependencies": { "@typescript-eslint/parser": "^8.52.0", "eslint": "^8.57.0 || ^9.0.0", "typescript": ">=4.8.4 <6.0.0" } }, "sha512-okqtOgqu2qmZJ5iN4TWlgfF171dZmx2FzdOv2K/ixL2LZWDStL8+JgQerI2sa8eAEfoydG9+0V96m7V+P8yE1Q=="], + "@typescript-eslint/eslint-plugin": ["@typescript-eslint/eslint-plugin@8.53.0", "", { "dependencies": { "@eslint-community/regexpp": "^4.12.2", "@typescript-eslint/scope-manager": "8.53.0", "@typescript-eslint/type-utils": "8.53.0", "@typescript-eslint/utils": "8.53.0", "@typescript-eslint/visitor-keys": "8.53.0", "ignore": "^7.0.5", "natural-compare": "^1.4.0", "ts-api-utils": "^2.4.0" }, "peerDependencies": { "@typescript-eslint/parser": "^8.53.0", "eslint": "^8.57.0 || ^9.0.0", "typescript": ">=4.8.4 <6.0.0" } }, 
"sha512-eEXsVvLPu8Z4PkFibtuFJLJOTAV/nPdgtSjkGoPpddpFk3/ym2oy97jynY6ic2m6+nc5M8SE1e9v/mHKsulcJg=="], - "@typescript-eslint/parser": ["@typescript-eslint/parser@8.52.0", "", { "dependencies": { "@typescript-eslint/scope-manager": "8.52.0", "@typescript-eslint/types": "8.52.0", "@typescript-eslint/typescript-estree": "8.52.0", "@typescript-eslint/visitor-keys": "8.52.0", "debug": "^4.4.3" }, "peerDependencies": { "eslint": "^8.57.0 || ^9.0.0", "typescript": ">=4.8.4 <6.0.0" } }, "sha512-iIACsx8pxRnguSYhHiMn2PvhvfpopO9FXHyn1mG5txZIsAaB6F0KwbFnUQN3KCiG3Jcuad/Cao2FAs1Wp7vAyg=="], + "@typescript-eslint/parser": ["@typescript-eslint/parser@8.53.0", "", { "dependencies": { "@typescript-eslint/scope-manager": "8.53.0", "@typescript-eslint/types": "8.53.0", "@typescript-eslint/typescript-estree": "8.53.0", "@typescript-eslint/visitor-keys": "8.53.0", "debug": "^4.4.3" }, "peerDependencies": { "eslint": "^8.57.0 || ^9.0.0", "typescript": ">=4.8.4 <6.0.0" } }, "sha512-npiaib8XzbjtzS2N4HlqPvlpxpmZ14FjSJrteZpPxGUaYPlvhzlzUZ4mZyABo0EFrOWnvyd0Xxroq//hKhtAWg=="], - "@typescript-eslint/project-service": ["@typescript-eslint/project-service@8.52.0", "", { "dependencies": { "@typescript-eslint/tsconfig-utils": "^8.52.0", "@typescript-eslint/types": "^8.52.0", "debug": "^4.4.3" }, "peerDependencies": { "typescript": ">=4.8.4 <6.0.0" } }, "sha512-xD0MfdSdEmeFa3OmVqonHi+Cciab96ls1UhIF/qX/O/gPu5KXD0bY9lu33jj04fjzrXHcuvjBcBC+D3SNSadaw=="], + "@typescript-eslint/project-service": ["@typescript-eslint/project-service@8.53.0", "", { "dependencies": { "@typescript-eslint/tsconfig-utils": "^8.53.0", "@typescript-eslint/types": "^8.53.0", "debug": "^4.4.3" }, "peerDependencies": { "typescript": ">=4.8.4 <6.0.0" } }, "sha512-Bl6Gdr7NqkqIP5yP9z1JU///Nmes4Eose6L1HwpuVHwScgDPPuEWbUVhvlZmb8hy0vX9syLk5EGNL700WcBlbg=="], - "@typescript-eslint/scope-manager": ["@typescript-eslint/scope-manager@8.52.0", "", { "dependencies": { "@typescript-eslint/types": "8.52.0", "@typescript-eslint/visitor-keys": "8.52.0" } }, "sha512-ixxqmmCcc1Nf8S0mS0TkJ/3LKcC8mruYJPOU6Ia2F/zUUR4pApW7LzrpU3JmtePbRUTes9bEqRc1Gg4iyRnDzA=="], + "@typescript-eslint/scope-manager": ["@typescript-eslint/scope-manager@8.53.0", "", { "dependencies": { "@typescript-eslint/types": "8.53.0", "@typescript-eslint/visitor-keys": "8.53.0" } }, "sha512-kWNj3l01eOGSdVBnfAF2K1BTh06WS0Yet6JUgb9Cmkqaz3Jlu0fdVUjj9UI8gPidBWSMqDIglmEXifSgDT/D0g=="], - "@typescript-eslint/tsconfig-utils": ["@typescript-eslint/tsconfig-utils@8.52.0", "", { "peerDependencies": { "typescript": ">=4.8.4 <6.0.0" } }, "sha512-jl+8fzr/SdzdxWJznq5nvoI7qn2tNYV/ZBAEcaFMVXf+K6jmXvAFrgo/+5rxgnL152f//pDEAYAhhBAZGrVfwg=="], + "@typescript-eslint/tsconfig-utils": ["@typescript-eslint/tsconfig-utils@8.53.0", "", { "peerDependencies": { "typescript": ">=4.8.4 <6.0.0" } }, "sha512-K6Sc0R5GIG6dNoPdOooQ+KtvT5KCKAvTcY8h2rIuul19vxH5OTQk7ArKkd4yTzkw66WnNY0kPPzzcmWA+XRmiA=="], - "@typescript-eslint/type-utils": ["@typescript-eslint/type-utils@8.52.0", "", { "dependencies": { "@typescript-eslint/types": "8.52.0", "@typescript-eslint/typescript-estree": "8.52.0", "@typescript-eslint/utils": "8.52.0", "debug": "^4.4.3", "ts-api-utils": "^2.4.0" }, "peerDependencies": { "eslint": "^8.57.0 || ^9.0.0", "typescript": ">=4.8.4 <6.0.0" } }, "sha512-JD3wKBRWglYRQkAtsyGz1AewDu3mTc7NtRjR/ceTyGoPqmdS5oCdx/oZMWD5Zuqmo6/MpsYs0wp6axNt88/2EQ=="], + "@typescript-eslint/type-utils": ["@typescript-eslint/type-utils@8.53.0", "", { "dependencies": { "@typescript-eslint/types": "8.53.0", "@typescript-eslint/typescript-estree": "8.53.0", 
"@typescript-eslint/utils": "8.53.0", "debug": "^4.4.3", "ts-api-utils": "^2.4.0" }, "peerDependencies": { "eslint": "^8.57.0 || ^9.0.0", "typescript": ">=4.8.4 <6.0.0" } }, "sha512-BBAUhlx7g4SmcLhn8cnbxoxtmS7hcq39xKCgiutL3oNx1TaIp+cny51s8ewnKMpVUKQUGb41RAUWZ9kxYdovuw=="], - "@typescript-eslint/types": ["@typescript-eslint/types@8.52.0", "", {}, "sha512-LWQV1V4q9V4cT4H5JCIx3481iIFxH1UkVk+ZkGGAV1ZGcjGI9IoFOfg3O6ywz8QqCDEp7Inlg6kovMofsNRaGg=="], + "@typescript-eslint/types": ["@typescript-eslint/types@8.53.0", "", {}, "sha512-Bmh9KX31Vlxa13+PqPvt4RzKRN1XORYSLlAE+sO1i28NkisGbTtSLFVB3l7PWdHtR3E0mVMuC7JilWJ99m2HxQ=="], - "@typescript-eslint/typescript-estree": ["@typescript-eslint/typescript-estree@8.52.0", "", { "dependencies": { "@typescript-eslint/project-service": "8.52.0", "@typescript-eslint/tsconfig-utils": "8.52.0", "@typescript-eslint/types": "8.52.0", "@typescript-eslint/visitor-keys": "8.52.0", "debug": "^4.4.3", "minimatch": "^9.0.5", "semver": "^7.7.3", "tinyglobby": "^0.2.15", "ts-api-utils": "^2.4.0" }, "peerDependencies": { "typescript": ">=4.8.4 <6.0.0" } }, "sha512-XP3LClsCc0FsTK5/frGjolyADTh3QmsLp6nKd476xNI9CsSsLnmn4f0jrzNoAulmxlmNIpeXuHYeEQv61Q6qeQ=="], + "@typescript-eslint/typescript-estree": ["@typescript-eslint/typescript-estree@8.53.0", "", { "dependencies": { "@typescript-eslint/project-service": "8.53.0", "@typescript-eslint/tsconfig-utils": "8.53.0", "@typescript-eslint/types": "8.53.0", "@typescript-eslint/visitor-keys": "8.53.0", "debug": "^4.4.3", "minimatch": "^9.0.5", "semver": "^7.7.3", "tinyglobby": "^0.2.15", "ts-api-utils": "^2.4.0" }, "peerDependencies": { "typescript": ">=4.8.4 <6.0.0" } }, "sha512-pw0c0Gdo7Z4xOG987u3nJ8akL9093yEEKv8QTJ+Bhkghj1xyj8cgPaavlr9rq8h7+s6plUJ4QJYw2gCZodqmGw=="], - "@typescript-eslint/utils": ["@typescript-eslint/utils@8.52.0", "", { "dependencies": { "@eslint-community/eslint-utils": "^4.9.1", "@typescript-eslint/scope-manager": "8.52.0", "@typescript-eslint/types": "8.52.0", "@typescript-eslint/typescript-estree": "8.52.0" }, "peerDependencies": { "eslint": "^8.57.0 || ^9.0.0", "typescript": ">=4.8.4 <6.0.0" } }, "sha512-wYndVMWkweqHpEpwPhwqE2lnD2DxC6WVLupU/DOt/0/v+/+iQbbzO3jOHjmBMnhu0DgLULvOaU4h4pwHYi2oRQ=="], + "@typescript-eslint/utils": ["@typescript-eslint/utils@8.53.0", "", { "dependencies": { "@eslint-community/eslint-utils": "^4.9.1", "@typescript-eslint/scope-manager": "8.53.0", "@typescript-eslint/types": "8.53.0", "@typescript-eslint/typescript-estree": "8.53.0" }, "peerDependencies": { "eslint": "^8.57.0 || ^9.0.0", "typescript": ">=4.8.4 <6.0.0" } }, "sha512-XDY4mXTez3Z1iRDI5mbRhH4DFSt46oaIFsLg+Zn97+sYrXACziXSQcSelMybnVZ5pa1P6xYkPr5cMJyunM1ZDA=="], - "@typescript-eslint/visitor-keys": ["@typescript-eslint/visitor-keys@8.52.0", "", { "dependencies": { "@typescript-eslint/types": "8.52.0", "eslint-visitor-keys": "^4.2.1" } }, "sha512-ink3/Zofus34nmBsPjow63FP5M7IGff0RKAgqR6+CFpdk22M7aLwC9gOcLGYqr7MczLPzZVERW9hRog3O4n1sQ=="], + "@typescript-eslint/visitor-keys": ["@typescript-eslint/visitor-keys@8.53.0", "", { "dependencies": { "@typescript-eslint/types": "8.53.0", "eslint-visitor-keys": "^4.2.1" } }, "sha512-LZ2NqIHFhvFwxG0qZeLL9DvdNAHPGCY5dIRwBhyYeU+LfLhcStE1ImjsuTG/WaVh3XysGaeLW8Rqq7cGkPCFvw=="], "@uiw/codemirror-extensions-basic-setup": ["@uiw/codemirror-extensions-basic-setup@4.25.3", "", { "dependencies": { "@codemirror/autocomplete": "^6.0.0", "@codemirror/commands": "^6.0.0", "@codemirror/language": "^6.0.0", "@codemirror/lint": "^6.0.0", "@codemirror/search": "^6.0.0", "@codemirror/state": "^6.0.0", 
"@codemirror/view": "^6.0.0" } }, "sha512-F1doRyD50CWScwGHG2bBUtUpwnOv/zqSnzkZqJcX5YAHQx6Z1CuX8jdnFMH6qktRrPU1tfpNYftTWu3QIoHiMA=="], @@ -717,19 +717,19 @@ "@vitejs/plugin-react": ["@vitejs/plugin-react@5.1.1", "", { "dependencies": { "@babel/core": "^7.28.5", "@babel/plugin-transform-react-jsx-self": "^7.27.1", "@babel/plugin-transform-react-jsx-source": "^7.27.1", "@rolldown/pluginutils": "1.0.0-beta.47", "@types/babel__core": "^7.20.5", "react-refresh": "^0.18.0" }, "peerDependencies": { "vite": "^4.2.0 || ^5.0.0 || ^6.0.0 || ^7.0.0" } }, "sha512-WQfkSw0QbQ5aJ2CHYw23ZGkqnRwqKHD/KYsMeTkZzPT4Jcf0DcBxBtwMJxnu6E7oxw5+JC6ZAiePgh28uJ1HBA=="], - "@vitest/expect": ["@vitest/expect@4.0.16", "", { "dependencies": { "@standard-schema/spec": "^1.0.0", "@types/chai": "^5.2.2", "@vitest/spy": "4.0.16", "@vitest/utils": "4.0.16", "chai": "^6.2.1", "tinyrainbow": "^3.0.3" } }, "sha512-eshqULT2It7McaJkQGLkPjPjNph+uevROGuIMJdG3V+0BSR2w9u6J9Lwu+E8cK5TETlfou8GRijhafIMhXsimA=="], + "@vitest/expect": ["@vitest/expect@4.0.17", "", { "dependencies": { "@standard-schema/spec": "^1.0.0", "@types/chai": "^5.2.2", "@vitest/spy": "4.0.17", "@vitest/utils": "4.0.17", "chai": "^6.2.1", "tinyrainbow": "^3.0.3" } }, "sha512-mEoqP3RqhKlbmUmntNDDCJeTDavDR+fVYkSOw8qRwJFaW/0/5zA9zFeTrHqNtcmwh6j26yMmwx2PqUDPzt5ZAQ=="], - "@vitest/mocker": ["@vitest/mocker@4.0.16", "", { "dependencies": { "@vitest/spy": "4.0.16", "estree-walker": "^3.0.3", "magic-string": "^0.30.21" }, "peerDependencies": { "msw": "^2.4.9", "vite": "^6.0.0 || ^7.0.0-0" }, "optionalPeers": ["msw", "vite"] }, "sha512-yb6k4AZxJTB+q9ycAvsoxGn+j/po0UaPgajllBgt1PzoMAAmJGYFdDk0uCcRcxb3BrME34I6u8gHZTQlkqSZpg=="], + "@vitest/mocker": ["@vitest/mocker@4.0.17", "", { "dependencies": { "@vitest/spy": "4.0.17", "estree-walker": "^3.0.3", "magic-string": "^0.30.21" }, "peerDependencies": { "msw": "^2.4.9", "vite": "^6.0.0 || ^7.0.0-0" }, "optionalPeers": ["msw", "vite"] }, "sha512-+ZtQhLA3lDh1tI2wxe3yMsGzbp7uuJSWBM1iTIKCbppWTSBN09PUC+L+fyNlQApQoR+Ps8twt2pbSSXg2fQVEQ=="], - "@vitest/pretty-format": ["@vitest/pretty-format@4.0.16", "", { "dependencies": { "tinyrainbow": "^3.0.3" } }, "sha512-eNCYNsSty9xJKi/UdVD8Ou16alu7AYiS2fCPRs0b1OdhJiV89buAXQLpTbe+X8V9L6qrs9CqyvU7OaAopJYPsA=="], + "@vitest/pretty-format": ["@vitest/pretty-format@4.0.17", "", { "dependencies": { "tinyrainbow": "^3.0.3" } }, "sha512-Ah3VAYmjcEdHg6+MwFE17qyLqBHZ+ni2ScKCiW2XrlSBV4H3Z7vYfPfz7CWQ33gyu76oc0Ai36+kgLU3rfF4nw=="], - "@vitest/runner": ["@vitest/runner@4.0.16", "", { "dependencies": { "@vitest/utils": "4.0.16", "pathe": "^2.0.3" } }, "sha512-VWEDm5Wv9xEo80ctjORcTQRJ539EGPB3Pb9ApvVRAY1U/WkHXmmYISqU5E79uCwcW7xYUV38gwZD+RV755fu3Q=="], + "@vitest/runner": ["@vitest/runner@4.0.17", "", { "dependencies": { "@vitest/utils": "4.0.17", "pathe": "^2.0.3" } }, "sha512-JmuQyf8aMWoo/LmNFppdpkfRVHJcsgzkbCA+/Bk7VfNH7RE6Ut2qxegeyx2j3ojtJtKIbIGy3h+KxGfYfk28YQ=="], - "@vitest/snapshot": ["@vitest/snapshot@4.0.16", "", { "dependencies": { "@vitest/pretty-format": "4.0.16", "magic-string": "^0.30.21", "pathe": "^2.0.3" } }, "sha512-sf6NcrYhYBsSYefxnry+DR8n3UV4xWZwWxYbCJUt2YdvtqzSPR7VfGrY0zsv090DAbjFZsi7ZaMi1KnSRyK1XA=="], + "@vitest/snapshot": ["@vitest/snapshot@4.0.17", "", { "dependencies": { "@vitest/pretty-format": "4.0.17", "magic-string": "^0.30.21", "pathe": "^2.0.3" } }, "sha512-npPelD7oyL+YQM2gbIYvlavlMVWUfNNGZPcu0aEUQXt7FXTuqhmgiYupPnAanhKvyP6Srs2pIbWo30K0RbDtRQ=="], - "@vitest/spy": ["@vitest/spy@4.0.16", "", {}, 
"sha512-4jIOWjKP0ZUaEmJm00E0cOBLU+5WE0BpeNr3XN6TEF05ltro6NJqHWxXD0kA8/Zc8Nh23AT8WQxwNG+WeROupw=="], + "@vitest/spy": ["@vitest/spy@4.0.17", "", {}, "sha512-I1bQo8QaP6tZlTomQNWKJE6ym4SHf3oLS7ceNjozxxgzavRAgZDc06T7kD8gb9bXKEgcLNt00Z+kZO6KaJ62Ew=="], - "@vitest/utils": ["@vitest/utils@4.0.16", "", { "dependencies": { "@vitest/pretty-format": "4.0.16", "tinyrainbow": "^3.0.3" } }, "sha512-h8z9yYhV3e1LEfaQ3zdypIrnAg/9hguReGZoS7Gl0aBG5xgA410zBqECqmaF/+RkTggRsfnzc1XaAHA6bmUufA=="], + "@vitest/utils": ["@vitest/utils@4.0.17", "", { "dependencies": { "@vitest/pretty-format": "4.0.17", "tinyrainbow": "^3.0.3" } }, "sha512-RG6iy+IzQpa9SB8HAFHJ9Y+pTzI+h8553MrciN9eC6TFBErqrQaTas4vG+MVj8S4uKk8uTT2p0vgZPnTdxd96w=="], "@workglow/ai": ["@workglow/ai@workspace:packages/ai"], @@ -843,7 +843,7 @@ "buffer": ["buffer@5.7.1", "", { "dependencies": { "base64-js": "^1.3.1", "ieee754": "^1.1.13" } }, "sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ=="], - "bun-types": ["bun-types@1.3.5", "", { "dependencies": { "@types/node": "*" } }, "sha512-inmAYe2PFLs0SUbFOWSVD24sg1jFlMPxOjOSSCYqUgn4Hsc3rDc7dFvfVYjFPNHtov6kgUeulV4SxbuIV/stPw=="], + "bun-types": ["bun-types@1.3.6", "", { "dependencies": { "@types/node": "*" } }, "sha512-OlFwHcnNV99r//9v5IIOgQ9Uk37gZqrNMCcqEaExdkVq3Avwqok1bJFmvGMCkCE0FqzdY8VMOZpfpR3lwI+CsQ=="], "call-bind": ["call-bind@1.0.8", "", { "dependencies": { "call-bind-apply-helpers": "^1.0.0", "es-define-property": "^1.0.0", "get-intrinsic": "^1.2.4", "set-function-length": "^1.2.2" } }, "sha512-oKlSFMcMwpUg2ednkhQ454wfWiU/ul3CkJe/PEHcTKuiX6RpbehUiFMXu13HalGZxfUwCQzZG747YXBn1im9ww=="], @@ -855,7 +855,7 @@ "camelcase-css": ["camelcase-css@2.0.1", "", {}, "sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA=="], - "caniuse-lite": ["caniuse-lite@1.0.30001763", "", {}, "sha512-mh/dGtq56uN98LlNX9qdbKnzINhX0QzhiWBFEkFfsFO4QyCvL8YegrJAazCwXIeqkIob8BlZPGM3xdnY+sgmvQ=="], + "caniuse-lite": ["caniuse-lite@1.0.30001764", "", {}, "sha512-9JGuzl2M+vPL+pz70gtMF9sHdMFbY9FJaQBi186cHKH3pSzDvzoUJUPV6fqiKIMyXbud9ZLg4F3Yza1vJ1+93g=="], "chai": ["chai@6.2.1", "", {}, "sha512-p4Z49OGG5W/WBCPSS/dH3jQ73kD6tiMmUM+bckNK6Jr5JHMG3k9bg/BvKR8lKmtVBKmOiuVaV2ws8s9oSbwysg=="], @@ -1651,19 +1651,19 @@ "tunnel-agent": ["tunnel-agent@0.6.0", "", { "dependencies": { "safe-buffer": "^5.0.1" } }, "sha512-McnNiV1l8RYeY8tBgEpuodCC1mLUdbSN+CYBL7kJsJNInOP8UjDDEwdk6Mw60vdLLrr5NHKZhMAOSrR2NZuQ+w=="], - "turbo": ["turbo@2.7.3", "", { "optionalDependencies": { "turbo-darwin-64": "2.7.3", "turbo-darwin-arm64": "2.7.3", "turbo-linux-64": "2.7.3", "turbo-linux-arm64": "2.7.3", "turbo-windows-64": "2.7.3", "turbo-windows-arm64": "2.7.3" }, "bin": { "turbo": "bin/turbo" } }, "sha512-+HjKlP4OfYk+qzvWNETA3cUO5UuK6b5MSc2UJOKyvBceKucQoQGb2g7HlC2H1GHdkfKrk4YF1VPvROkhVZDDLQ=="], + "turbo": ["turbo@2.7.4", "", { "optionalDependencies": { "turbo-darwin-64": "2.7.4", "turbo-darwin-arm64": "2.7.4", "turbo-linux-64": "2.7.4", "turbo-linux-arm64": "2.7.4", "turbo-windows-64": "2.7.4", "turbo-windows-arm64": "2.7.4" }, "bin": { "turbo": "bin/turbo" } }, "sha512-bkO4AddmDishzJB2ze7aYYPaejMoJVfS0XnaR6RCdXFOY8JGJfQE+l9fKiV7uDPa5Ut44gmOWJL3894CIMeH9g=="], - "turbo-darwin-64": ["turbo-darwin-64@2.7.3", "", { "os": "darwin", "cpu": "x64" }, "sha512-aZHhvRiRHXbJw1EcEAq4aws1hsVVUZ9DPuSFaq9VVFAKCup7niIEwc22glxb7240yYEr1vLafdQ2U294Vcwz+w=="], + "turbo-darwin-64": ["turbo-darwin-64@2.7.4", "", { "os": "darwin", "cpu": "x64" }, 
"sha512-xDR30ltfkSsRfGzABBckvl1nz1cZ3ssTujvdj+TPwOweeDRvZ0e06t5DS0rmRBvyKpgGs42K/EK6Mn2qLlFY9A=="], - "turbo-darwin-arm64": ["turbo-darwin-arm64@2.7.3", "", { "os": "darwin", "cpu": "arm64" }, "sha512-CkVrHSq+Bnhl9sX2LQgqQYVfLTWC2gvI74C4758OmU0djfrssDKU9d4YQF0AYXXhIIRZipSXfxClQziIMD+EAg=="], + "turbo-darwin-arm64": ["turbo-darwin-arm64@2.7.4", "", { "os": "darwin", "cpu": "arm64" }, "sha512-P7sjqXtOL/+nYWPvcDGWhi8wf8M8mZHHB8XEzw2VX7VJrS8IGHyJHGD1AYfDvhAEcr7pnk3gGifz3/xyhI655w=="], - "turbo-linux-64": ["turbo-linux-64@2.7.3", "", { "os": "linux", "cpu": "x64" }, "sha512-GqDsCNnzzr89kMaLGpRALyigUklzgxIrSy2pHZVXyifgczvYPnLglex78Aj3T2gu+T3trPPH2iJ+pWucVOCC2Q=="], + "turbo-linux-64": ["turbo-linux-64@2.7.4", "", { "os": "linux", "cpu": "x64" }, "sha512-GofFOxRO/IhG8BcPyMSSB3Y2+oKQotsaYbHxL9yD6JPb20/o35eo+zUSyazOtilAwDHnak5dorAJFoFU8MIg2A=="], - "turbo-linux-arm64": ["turbo-linux-arm64@2.7.3", "", { "os": "linux", "cpu": "arm64" }, "sha512-NdCDTfIcIo3dWjsiaAHlxu5gW61Ed/8maah1IAF/9E3EtX0aAHNiBMbuYLZaR4vRJ7BeVkYB6xKWRtdFLZ0y3g=="], + "turbo-linux-arm64": ["turbo-linux-arm64@2.7.4", "", { "os": "linux", "cpu": "arm64" }, "sha512-+RQKgNjksVPxYAyAgmDV7w/1qj++qca+nSNTAOKGOfJiDtSvRKoci89oftJ6anGs00uamLKVEQ712TI/tfNAIw=="], - "turbo-windows-64": ["turbo-windows-64@2.7.3", "", { "os": "win32", "cpu": "x64" }, "sha512-7bVvO987daXGSJVYBoG8R4Q+csT1pKIgLJYZevXRQ0Hqw0Vv4mKme/TOjYXs9Qb1xMKh51Tb3bXKDbd8/4G08g=="], + "turbo-windows-64": ["turbo-windows-64@2.7.4", "", { "os": "win32", "cpu": "x64" }, "sha512-rfak1+g+ON3czs1mDYsCS4X74ZmK6gOgRQTXjDICtzvR4o61paqtgAYtNPofcVsMWeF4wvCajSeoAkkeAnQ1kg=="], - "turbo-windows-arm64": ["turbo-windows-arm64@2.7.3", "", { "os": "win32", "cpu": "arm64" }, "sha512-nTodweTbPmkvwMu/a55XvjMsPtuyUSC+sV7f/SR57K36rB2I0YG21qNETN+00LOTUW9B3omd8XkiXJkt4kx/cw=="], + "turbo-windows-arm64": ["turbo-windows-arm64@2.7.4", "", { "os": "win32", "cpu": "arm64" }, "sha512-1ZgBNjNRbDu/fPeqXuX9i26x3CJ/Y1gcwUpQ+Vp7kN9Un6RZ9kzs164f/knrjcu5E+szCRexVjRSJay1k5jApA=="], "type-check": ["type-check@0.4.0", "", { "dependencies": { "prelude-ls": "^1.2.1" } }, "sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew=="], @@ -1697,7 +1697,7 @@ "vite": ["vite@7.2.6", "", { "dependencies": { "esbuild": "^0.25.0", "fdir": "^6.5.0", "picomatch": "^4.0.3", "postcss": "^8.5.6", "rollup": "^4.43.0", "tinyglobby": "^0.2.15" }, "optionalDependencies": { "fsevents": "~2.3.3" }, "peerDependencies": { "@types/node": "^20.19.0 || >=22.12.0", "jiti": ">=1.21.0", "less": "^4.0.0", "lightningcss": "^1.21.0", "sass": "^1.70.0", "sass-embedded": "^1.70.0", "stylus": ">=0.54.8", "sugarss": "^5.0.0", "terser": "^5.16.0", "tsx": "^4.8.1", "yaml": "^2.4.2" }, "optionalPeers": ["@types/node", "jiti", "less", "lightningcss", "sass", "sass-embedded", "stylus", "sugarss", "terser", "tsx", "yaml"], "bin": { "vite": "bin/vite.js" } }, "sha512-tI2l/nFHC5rLh7+5+o7QjKjSR04ivXDF4jcgV0f/bTQ+OJiITy5S6gaynVsEM+7RqzufMnVbIon6Sr5x1SDYaQ=="], - "vitest": ["vitest@4.0.16", "", { "dependencies": { "@vitest/expect": "4.0.16", "@vitest/mocker": "4.0.16", "@vitest/pretty-format": "4.0.16", "@vitest/runner": "4.0.16", "@vitest/snapshot": "4.0.16", "@vitest/spy": "4.0.16", "@vitest/utils": "4.0.16", "es-module-lexer": "^1.7.0", "expect-type": "^1.2.2", "magic-string": "^0.30.21", "obug": "^2.1.1", "pathe": "^2.0.3", "picomatch": "^4.0.3", "std-env": "^3.10.0", "tinybench": "^2.9.0", "tinyexec": "^1.0.2", "tinyglobby": "^0.2.15", "tinyrainbow": "^3.0.3", "vite": "^6.0.0 || ^7.0.0", "why-is-node-running": "^2.3.0" }, 
"peerDependencies": { "@edge-runtime/vm": "*", "@opentelemetry/api": "^1.9.0", "@types/node": "^20.0.0 || ^22.0.0 || >=24.0.0", "@vitest/browser-playwright": "4.0.16", "@vitest/browser-preview": "4.0.16", "@vitest/browser-webdriverio": "4.0.16", "@vitest/ui": "4.0.16", "happy-dom": "*", "jsdom": "*" }, "optionalPeers": ["@edge-runtime/vm", "@opentelemetry/api", "@types/node", "@vitest/browser-playwright", "@vitest/browser-preview", "@vitest/browser-webdriverio", "@vitest/ui", "happy-dom", "jsdom"], "bin": { "vitest": "vitest.mjs" } }, "sha512-E4t7DJ9pESL6E3I8nFjPa4xGUd3PmiWDLsDztS2qXSJWfHtbQnwAWylaBvSNY48I3vr8PTqIZlyK8TE3V3CA4Q=="], + "vitest": ["vitest@4.0.17", "", { "dependencies": { "@vitest/expect": "4.0.17", "@vitest/mocker": "4.0.17", "@vitest/pretty-format": "4.0.17", "@vitest/runner": "4.0.17", "@vitest/snapshot": "4.0.17", "@vitest/spy": "4.0.17", "@vitest/utils": "4.0.17", "es-module-lexer": "^1.7.0", "expect-type": "^1.2.2", "magic-string": "^0.30.21", "obug": "^2.1.1", "pathe": "^2.0.3", "picomatch": "^4.0.3", "std-env": "^3.10.0", "tinybench": "^2.9.0", "tinyexec": "^1.0.2", "tinyglobby": "^0.2.15", "tinyrainbow": "^3.0.3", "vite": "^6.0.0 || ^7.0.0", "why-is-node-running": "^2.3.0" }, "peerDependencies": { "@edge-runtime/vm": "*", "@opentelemetry/api": "^1.9.0", "@types/node": "^20.0.0 || ^22.0.0 || >=24.0.0", "@vitest/browser-playwright": "4.0.17", "@vitest/browser-preview": "4.0.17", "@vitest/browser-webdriverio": "4.0.17", "@vitest/ui": "4.0.17", "happy-dom": "*", "jsdom": "*" }, "optionalPeers": ["@edge-runtime/vm", "@opentelemetry/api", "@types/node", "@vitest/browser-playwright", "@vitest/browser-preview", "@vitest/browser-webdriverio", "@vitest/ui", "happy-dom", "jsdom"], "bin": { "vitest": "vitest.mjs" } }, "sha512-FQMeF0DJdWY0iOnbv466n/0BudNdKj1l5jYgl5JVTwjSsZSlqyXFt/9+1sEyhR6CLowbZpV7O1sCHrzBhucKKg=="], "w3c-keyname": ["w3c-keyname@2.2.8", "", {}, "sha512-dpojBhNsCNN7T82Tm7k26A6G9ML3NkhDsnw9n/eoxSRlVBB4CEtIQ/KTCLI2Fwf3ataSXRhYFkQi3SlnFwPvPQ=="], diff --git a/docs/developers/03_extending.md b/docs/developers/03_extending.md index b9cbddf9..875ccb76 100644 --- a/docs/developers/03_extending.md +++ b/docs/developers/03_extending.md @@ -185,7 +185,7 @@ export class MyTask extends Task { async executeReactive(input: MyTaskInput) { // By the time execute runs, model is a ModelConfig object - // and dataSource is an ITabularRepository instance + // and dataSource is an ITabularStorage instance const { model, dataSource, prompt } = input; // ... 
} diff --git a/package.json b/package.json index e3b41125..e06b6829 100644 --- a/package.json +++ b/package.json @@ -32,7 +32,7 @@ "publish": "bun ./scripts/publish-workspaces.ts" }, "dependencies": { - "caniuse-lite": "^1.0.30001763" + "caniuse-lite": "^1.0.30001764" }, "catalog": { "@sroussey/transformers": "3.8.2", @@ -43,9 +43,9 @@ }, "devDependencies": { "@sroussey/changesets-cli": "^2.29.7", - "@types/bun": "^1.3.5", - "@typescript-eslint/eslint-plugin": "^8.52.0", - "@typescript-eslint/parser": "^8.52.0", + "@types/bun": "^1.3.6", + "@typescript-eslint/eslint-plugin": "^8.53.0", + "@typescript-eslint/parser": "^8.53.0", "concurrently": "^9.2.1", "eslint": "^9.39.2", "eslint-plugin-jsx-a11y": "^6.10.2", @@ -54,9 +54,9 @@ "eslint-plugin-regexp": "^2.10.0", "globals": "^17.0.0", "prettier": "^3.7.4", - "turbo": "^2.7.3", + "turbo": "^2.7.4", "typescript": "5.9.3", - "vitest": "^4.0.16" + "vitest": "^4.0.17" }, "engines": { "bun": "^1.3.0" diff --git a/packages/ai/README.md b/packages/ai/README.md index fd3fde8b..f448d039 100644 --- a/packages/ai/README.md +++ b/packages/ai/README.md @@ -593,7 +593,7 @@ This resolution is handled by the input resolver system, which inspects schema ` | --------------------------------- | ---------------------------------------- | -------------------------- | | `model` | Any AI model configuration | ModelRepository | | `model:TaskName` | Model compatible with specific task type | ModelRepository | -| `repository:tabular` | Tabular data repository | TabularRepositoryRegistry | +| `repository:tabular` | Tabular data repository | TabularStorageRegistry | | `repository:document-node-vector` | Vector storage repository | VectorRepositoryRegistry | | `repository:document` | Document repository | DocumentRepositoryRegistry | diff --git a/packages/ai/src/model/InMemoryModelRepository.ts b/packages/ai/src/model/InMemoryModelRepository.ts index 72fc78a9..45f6a0aa 100644 --- a/packages/ai/src/model/InMemoryModelRepository.ts +++ b/packages/ai/src/model/InMemoryModelRepository.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { InMemoryTabularRepository } from "@workglow/storage"; +import { InMemoryTabularStorage } from "@workglow/storage"; import { ModelRepository } from "./ModelRepository"; import { ModelPrimaryKeyNames, ModelRecordSchema } from "./ModelSchema"; @@ -14,6 +14,6 @@ import { ModelPrimaryKeyNames, ModelRecordSchema } from "./ModelSchema"; */ export class InMemoryModelRepository extends ModelRepository { constructor() { - super(new InMemoryTabularRepository(ModelRecordSchema, ModelPrimaryKeyNames)); + super(new InMemoryTabularStorage(ModelRecordSchema, ModelPrimaryKeyNames)); } } diff --git a/packages/ai/src/model/ModelRepository.ts b/packages/ai/src/model/ModelRepository.ts index 234fb168..5940e2dc 100644 --- a/packages/ai/src/model/ModelRepository.ts +++ b/packages/ai/src/model/ModelRepository.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { type BaseTabularRepository } from "@workglow/storage"; +import { type BaseTabularStorage } from "@workglow/storage"; import { EventEmitter, EventParameters } from "@workglow/util"; import { ModelPrimaryKeyNames, ModelRecord, ModelRecordSchema } from "./ModelSchema"; @@ -37,12 +37,12 @@ export class ModelRepository { /** * Repository for storing and managing Model instances */ - protected readonly modelTabularRepository: BaseTabularRepository< + protected readonly modelTabularRepository: BaseTabularStorage< typeof ModelRecordSchema, typeof ModelPrimaryKeyNames >; constructor( 
- modelTabularRepository: BaseTabularRepository< + modelTabularRepository: BaseTabularStorage< typeof ModelRecordSchema, typeof ModelPrimaryKeyNames > diff --git a/packages/ai/src/task/ChunkRetrievalTask.ts b/packages/ai/src/task/ChunkRetrievalTask.ts index 9f91b457..ee2ff5f7 100644 --- a/packages/ai/src/task/ChunkRetrievalTask.ts +++ b/packages/ai/src/task/ChunkRetrievalTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { AnyChunkVectorRepository, TypeChunkVectorRepository } from "@workglow/storage"; +import { AnyChunkVectorStorage, TypeChunkVectorRepository } from "@workglow/storage"; import { CreateWorkflow, IExecuteContext, @@ -172,7 +172,7 @@ export class DocumentNodeRetrievalTask extends Task< } = input; // Repository is resolved by input resolver system before execution - const repo = repository as AnyChunkVectorRepository; + const repo = repository as AnyChunkVectorStorage; // Determine query vector let queryVector: TypedArray; diff --git a/packages/ai/src/task/ChunkVectorHybridSearchTask.ts b/packages/ai/src/task/ChunkVectorHybridSearchTask.ts index df9b12cf..61a8948f 100644 --- a/packages/ai/src/task/ChunkVectorHybridSearchTask.ts +++ b/packages/ai/src/task/ChunkVectorHybridSearchTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { AnyChunkVectorRepository, TypeChunkVectorRepository } from "@workglow/storage"; +import { AnyChunkVectorStorage, TypeChunkVectorRepository } from "@workglow/storage"; import { CreateWorkflow, IExecuteContext, @@ -171,7 +171,7 @@ export class ChunkVectorHybridSearchTask extends Task< } = input; // Repository is resolved by input resolver system before execution - const repo = repository as AnyChunkVectorRepository; + const repo = repository as AnyChunkVectorStorage; // Check if repository supports hybrid search if (!repo.hybridSearch) { diff --git a/packages/ai/src/task/ChunkVectorSearchTask.ts b/packages/ai/src/task/ChunkVectorSearchTask.ts index 9e30dd5f..45c6d3f6 100644 --- a/packages/ai/src/task/ChunkVectorSearchTask.ts +++ b/packages/ai/src/task/ChunkVectorSearchTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { AnyChunkVectorRepository, TypeChunkVectorRepository } from "@workglow/storage"; +import { AnyChunkVectorStorage, TypeChunkVectorRepository } from "@workglow/storage"; import { CreateWorkflow, IExecuteContext, @@ -131,7 +131,7 @@ export class ChunkVectorSearchTask extends Task< ): Promise { const { repository, query, topK = 10, filter, scoreThreshold = 0 } = input; - const repo = repository as AnyChunkVectorRepository; + const repo = repository as AnyChunkVectorStorage; const results = await repo.similaritySearch(query, { topK, diff --git a/packages/ai/src/task/ChunkVectorUpsertTask.ts b/packages/ai/src/task/ChunkVectorUpsertTask.ts index 0d79afe8..f069efc8 100644 --- a/packages/ai/src/task/ChunkVectorUpsertTask.ts +++ b/packages/ai/src/task/ChunkVectorUpsertTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { AnyChunkVectorRepository, TypeChunkVectorRepository } from "@workglow/storage"; +import { AnyChunkVectorStorage, TypeChunkVectorRepository } from "@workglow/storage"; import { CreateWorkflow, IExecuteContext, @@ -114,7 +114,7 @@ export class ChunkVectorUpsertTask extends Task< ? 
metadata : Array(vectorArray.length).fill(metadata); - const repo = repository as AnyChunkVectorRepository; + const repo = repository as AnyChunkVectorStorage; await context.updateProgress(1, "Upserting vectors"); diff --git a/packages/storage/README.md b/packages/storage/README.md index 254f73fd..f86bfa63 100644 --- a/packages/storage/README.md +++ b/packages/storage/README.md @@ -33,8 +33,8 @@ Modular storage solutions for Workglow.AI platform with multiple backend impleme - [Compound Primary Keys](#compound-primary-keys) - [Custom File Layout (KV on filesystem)](#custom-file-layout-kv-on-filesystem) - [API Reference](#api-reference) - - [IKvRepository\](#ikvrepositorykey-value) - - [ITabularRepository\](#itabularrepositoryschema-primarykeynames) + - [IKvStorage\](#ikvrepositorykey-value) + - [ITabularStorage\](#itabularrepositoryschema-primarykeynames) - [IQueueStorage\](#iqueuestorageinput-output) - [Examples](#examples) - [User Management System](#user-management-system) @@ -47,16 +47,16 @@ Modular storage solutions for Workglow.AI platform with multiple backend impleme ```typescript // Key-Value Storage (simple data) -import { InMemoryKvRepository } from "@workglow/storage"; +import { InMemoryKvStorage } from "@workglow/storage"; -const kvStore = new InMemoryKvRepository(); +const kvStore = new InMemoryKvStorage(); await kvStore.put("user:123", { name: "Alice", age: 30 }); const kvUser = await kvStore.get("user:123"); // { name: "Alice", age: 30 } ``` ```typescript // Tabular Storage (structured data with schemas) -import { InMemoryTabularRepository } from "@workglow/storage"; +import { InMemoryTabularStorage } from "@workglow/storage"; import { JsonSchema } from "@workglow/util"; const userSchema = { @@ -71,7 +71,7 @@ const userSchema = { additionalProperties: false, } as const satisfies JsonSchema; -const userRepo = new InMemoryTabularRepository( +const userRepo = new InMemoryTabularStorage( userSchema, ["id"], // primary key ["email"] // additional indexes @@ -139,7 +139,7 @@ The package uses conditional exports, so importing from `@workglow/storage` auto ```typescript // Import from the top-level package; it resolves to the correct target per environment -import { InMemoryKvRepository, SqliteTabularRepository } from "@workglow/storage"; +import { InMemoryKvStorage, SqliteTabularStorage } from "@workglow/storage"; ``` ## Storage Types @@ -151,10 +151,10 @@ Simple key-value storage for unstructured or semi-structured data. 
#### Basic Usage ```typescript -import { InMemoryKvRepository, FsFolderJsonKvRepository } from "@workglow/storage"; +import { InMemoryKvStorage, FsFolderJsonKvRepository } from "@workglow/storage"; // In-memory (for testing/caching) -const cache = new InMemoryKvRepository(); +const cache = new InMemoryKvStorage(); await cache.put("config", { theme: "dark", language: "en" }); // File-based JSON (persistent) @@ -190,7 +190,7 @@ const supabaseStore = new SupabaseKvRepository(supabase, "settings"); #### Bulk Operations ```typescript -const store = new InMemoryKvRepository(); +const store = new InMemoryKvStorage(); // Bulk insert await store.putBulk([ @@ -209,7 +209,7 @@ const count = await store.size(); // 2 #### Event Handling ```typescript -const store = new InMemoryKvRepository(); +const store = new InMemoryKvStorage(); // Listen to storage events store.on("put", (key, value) => { @@ -232,7 +232,7 @@ Structured storage with schemas, primary keys, and indexing for complex data rel ```typescript import { JsonSchema } from "@workglow/util"; -import { InMemoryTabularRepository } from "@workglow/storage"; +import { InMemoryTabularStorage } from "@workglow/storage"; // Define your entity schema const UserSchema = { @@ -250,7 +250,7 @@ const UserSchema = { } as const satisfies JsonSchema; // Create repository with primary key and indexes -const userRepo = new InMemoryTabularRepository( +const userRepo = new InMemoryTabularStorage( UserSchema, ["id"], // Primary key (can be compound: ["dept", "id"]) ["email", "department", ["department", "age"]] // Indexes for fast lookups @@ -343,9 +343,9 @@ await userRepo.deleteSearch({ ```typescript // SQLite (Node.js/Bun) -import { SqliteTabularRepository } from "@workglow/storage"; +import { SqliteTabularStorage } from "@workglow/storage"; -const sqliteUsers = new SqliteTabularRepository( +const sqliteUsers = new SqliteTabularStorage( "./users.db", "users", UserSchema, @@ -354,11 +354,11 @@ const sqliteUsers = new SqliteTabularRepository( ); // PostgreSQL (Node.js/Bun) -import { PostgresTabularRepository } from "@workglow/storage"; +import { PostgresTabularStorage } from "@workglow/storage"; import { Pool } from "pg"; const pool = new Pool({ connectionString: "postgresql://..." 
}); -const pgUsers = new PostgresTabularRepository( +const pgUsers = new PostgresTabularStorage( pool, "users", UserSchema, @@ -489,13 +489,13 @@ const cloudJobQueue = new SupabaseQueueStorage(supabase, "background-jobs"); ```typescript import { SqliteKvRepository, - PostgresTabularRepository, + PostgresTabularStorage, FsFolderJsonKvRepository, } from "@workglow/storage"; // Mix and match storage backends const cache = new FsFolderJsonKvRepository("./cache"); -const users = new PostgresTabularRepository(pool, "users", UserSchema, ["id"]); +const users = new PostgresTabularStorage(pool, "users", UserSchema, ["id"]); ``` ### Bun Environment @@ -503,7 +503,7 @@ const users = new PostgresTabularRepository(pool, "users", UserSchema, ["id"]); ```typescript // Bun has access to all implementations import { - SqliteTabularRepository, + SqliteTabularStorage, FsFolderJsonKvRepository, PostgresQueueStorage, SupabaseTabularRepository, @@ -513,7 +513,7 @@ import { Database } from "bun:sqlite"; import { createClient } from "@supabase/supabase-js"; const db = new Database("./app.db"); -const data = new SqliteTabularRepository(db, "items", ItemSchema, ["id"]); +const data = new SqliteTabularStorage(db, "items", ItemSchema, ["id"]); // Or use Supabase for cloud storage const supabase = createClient("https://your-project.supabase.co", "your-anon-key"); @@ -532,7 +532,7 @@ Repositories can be registered globally by ID, allowing tasks to reference them import { registerTabularRepository, getTabularRepository, - InMemoryTabularRepository, + InMemoryTabularStorage, } from "@workglow/storage"; // Define your schema @@ -548,7 +548,7 @@ const userSchema = { } as const; // Create and register a repository -const userRepo = new InMemoryTabularRepository(userSchema, ["id"] as const); +const userRepo = new InMemoryTabularStorage(userSchema, ["id"] as const); registerTabularRepository("users", userRepo); // Later, retrieve the repository by ID @@ -616,7 +616,7 @@ const docSchema = TypeDocumentRepository({ All storage implementations support event emission for monitoring and reactive programming: ```typescript -const store = new InMemoryTabularRepository(UserSchema, ["id"]); +const store = new InMemoryTabularStorage(UserSchema, ["id"]); // Monitor all operations store.on("put", (entity) => console.log("User created/updated:", entity)); @@ -645,7 +645,7 @@ const OrderLineSchema = { additionalProperties: false, } as const satisfies JsonSchema; -const orderLines = new InMemoryTabularRepository( +const orderLines = new InMemoryTabularStorage( OrderLineSchema, ["orderId", "lineNumber"], // Compound primary key ["productId"] // Additional index @@ -684,12 +684,12 @@ await files.put("note-1", "Hello world"); ## API Reference -### IKvRepository +### IKvStorage Core interface for key-value storage: ```typescript -interface IKvRepository { +interface IKvStorage { // Core operations put(key: Key, value: Value): Promise; putBulk(items: Array<{ key: Key; value: Value }>): Promise; @@ -708,12 +708,12 @@ interface IKvRepository { } ``` -### ITabularRepository +### ITabularStorage Core interface for tabular storage: ```typescript -interface ITabularRepository { +interface ITabularStorage { // Core operations put(entity: Entity): Promise; putBulk(entities: Entity[]): Promise; @@ -795,7 +795,7 @@ interface IQueueStorage { ```typescript import { JsonSchema, FromSchema } from "@workglow/util"; -import { InMemoryTabularRepository, InMemoryKvRepository } from "@workglow/storage"; +import { InMemoryTabularStorage, InMemoryKvStorage } from 
"@workglow/storage"; // User profile with tabular storage const UserSchema = { @@ -817,14 +817,14 @@ const UserSchema = { additionalProperties: false, } as const satisfies JsonSchema; -const userRepo = new InMemoryTabularRepository( +const userRepo = new InMemoryTabularStorage( UserSchema, ["id"], ["email", "username"] ); // User sessions with KV storage -const sessionStore = new InMemoryKvRepository(); +const sessionStore = new InMemoryKvStorage(); // User management class class UserManager { @@ -1075,13 +1075,13 @@ bun test --grep "Sqlite" # Native tests ```typescript import { describe, test, expect, beforeEach } from "vitest"; -import { InMemoryTabularRepository } from "@workglow/storage"; +import { InMemoryTabularStorage } from "@workglow/storage"; describe("UserRepository", () => { - let userRepo: InMemoryTabularRepository; + let userRepo: InMemoryTabularStorage; beforeEach(() => { - userRepo = new InMemoryTabularRepository( + userRepo = new InMemoryTabularStorage( UserSchema, ["id"], ["email"] diff --git a/packages/storage/src/browser.ts b/packages/storage/src/browser.ts index f960c5ed..2f961441 100644 --- a/packages/storage/src/browser.ts +++ b/packages/storage/src/browser.ts @@ -6,12 +6,12 @@ export * from "./common"; -export * from "./tabular/IndexedDbTabularRepository"; -export * from "./tabular/SharedInMemoryTabularRepository"; -export * from "./tabular/SupabaseTabularRepository"; +export * from "./tabular/IndexedDbTabularStorage"; +export * from "./tabular/SharedInMemoryTabularStorage"; +export * from "./tabular/SupabaseTabularStorage"; -export * from "./kv/IndexedDbKvRepository"; -export * from "./kv/SupabaseKvRepository"; +export * from "./kv/IndexedDbKvStorage"; +export * from "./kv/SupabaseKvStorage"; export * from "./queue/IndexedDbQueueStorage"; export * from "./queue/SupabaseQueueStorage"; diff --git a/packages/storage/src/chunk-vector/ChunkVectorRepositoryRegistry.ts b/packages/storage/src/chunk-vector/ChunkVectorStorageRegistry.ts similarity index 85% rename from packages/storage/src/chunk-vector/ChunkVectorRepositoryRegistry.ts rename to packages/storage/src/chunk-vector/ChunkVectorStorageRegistry.ts index c50d90f2..7c51b929 100644 --- a/packages/storage/src/chunk-vector/ChunkVectorRepositoryRegistry.ts +++ b/packages/storage/src/chunk-vector/ChunkVectorStorageRegistry.ts @@ -10,21 +10,21 @@ import { registerInputResolver, ServiceRegistry, } from "@workglow/util"; -import { AnyChunkVectorRepository } from "./IChunkVectorRepository"; +import { AnyChunkVectorStorage } from "./IChunkVectorStorage"; /** * Service token for the documenbt chunk vector repository registry * Maps repository IDs to IVectorChunkRepository instances */ export const DOCUMENT_CHUNK_VECTOR_REPOSITORIES = createServiceToken< - Map + Map >("storage.document-node-vector.repositories"); // Register default factory if not already registered if (!globalServiceRegistry.has(DOCUMENT_CHUNK_VECTOR_REPOSITORIES)) { globalServiceRegistry.register( DOCUMENT_CHUNK_VECTOR_REPOSITORIES, - (): Map => new Map(), + (): Map => new Map(), true ); } @@ -33,7 +33,7 @@ if (!globalServiceRegistry.has(DOCUMENT_CHUNK_VECTOR_REPOSITORIES)) { * Gets the global document chunk vector repository registry * @returns Map of document chunk vector repository ID to instance */ -export function getGlobalChunkVectorRepositories(): Map { +export function getGlobalChunkVectorRepositories(): Map { return globalServiceRegistry.get(DOCUMENT_CHUNK_VECTOR_REPOSITORIES); } @@ -44,7 +44,7 @@ export function getGlobalChunkVectorRepositories(): 
Map { +): Promise { const repos = registry.has(DOCUMENT_CHUNK_VECTOR_REPOSITORIES) - ? registry.get>(DOCUMENT_CHUNK_VECTOR_REPOSITORIES) + ? registry.get>(DOCUMENT_CHUNK_VECTOR_REPOSITORIES) : getGlobalChunkVectorRepositories(); const repo = repos.get(id); diff --git a/packages/storage/src/chunk-vector/IChunkVectorRepository.ts b/packages/storage/src/chunk-vector/IChunkVectorStorage.ts similarity index 92% rename from packages/storage/src/chunk-vector/IChunkVectorRepository.ts rename to packages/storage/src/chunk-vector/IChunkVectorStorage.ts index 060d7108..0b68d414 100644 --- a/packages/storage/src/chunk-vector/IChunkVectorRepository.ts +++ b/packages/storage/src/chunk-vector/IChunkVectorStorage.ts @@ -11,9 +11,9 @@ import type { TypedArray, TypedArraySchemaOptions, } from "@workglow/util"; -import type { ITabularRepository, TabularEventListeners } from "../tabular/ITabularRepository"; +import type { ITabularStorage, TabularEventListeners } from "../tabular/ITabularStorage"; -export type AnyChunkVectorRepository = IChunkVectorRepository; +export type AnyChunkVectorStorage = IChunkVectorStorage; /** * Options for vector search operations @@ -69,11 +69,11 @@ export type VectorChunkEventParameters< * @typeParam PrimaryKeyNames - Array of property names that form the primary key * @typeParam Entity - The entity type */ -export interface IChunkVectorRepository< +export interface IChunkVectorStorage< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, Entity = FromSchema, -> extends ITabularRepository { +> extends ITabularStorage { /** * Get the vector dimension * @returns The vector dimension diff --git a/packages/storage/src/chunk-vector/InMemoryChunkVectorRepository.ts b/packages/storage/src/chunk-vector/InMemoryChunkVectorStorage.ts similarity index 95% rename from packages/storage/src/chunk-vector/InMemoryChunkVectorRepository.ts rename to packages/storage/src/chunk-vector/InMemoryChunkVectorStorage.ts index dcadfb3f..009c5b44 100644 --- a/packages/storage/src/chunk-vector/InMemoryChunkVectorRepository.ts +++ b/packages/storage/src/chunk-vector/InMemoryChunkVectorStorage.ts @@ -6,13 +6,13 @@ import type { TypedArray } from "@workglow/util"; import { cosineSimilarity } from "@workglow/util"; -import { InMemoryTabularRepository } from "../tabular/InMemoryTabularRepository"; +import { InMemoryTabularStorage } from "../tabular/InMemoryTabularStorage"; import { ChunkVector, ChunkVectorKey, ChunkVectorSchema } from "./ChunkVectorSchema"; import type { HybridSearchOptions, - IChunkVectorRepository, + IChunkVectorStorage, VectorSearchOptions, -} from "./IChunkVectorRepository"; +} from "./IChunkVectorStorage"; /** * Check if metadata matches filter @@ -54,17 +54,17 @@ function textRelevance(text: string, query: string): number { * @template Metadata - The metadata type for the document chunk * @template Vector - The vector type for the document chunk */ -export class InMemoryChunkVectorRepository< +export class InMemoryChunkVectorStorage< Metadata extends Record = Record, Vector extends TypedArray = Float32Array, > - extends InMemoryTabularRepository< + extends InMemoryTabularStorage< typeof ChunkVectorSchema, typeof ChunkVectorKey, ChunkVector > implements - IChunkVectorRepository< + IChunkVectorStorage< typeof ChunkVectorSchema, typeof ChunkVectorKey, ChunkVector diff --git a/packages/storage/src/chunk-vector/PostgresChunkVectorRepository.ts b/packages/storage/src/chunk-vector/PostgresChunkVectorStorage.ts similarity index 97% rename from 
packages/storage/src/chunk-vector/PostgresChunkVectorRepository.ts rename to packages/storage/src/chunk-vector/PostgresChunkVectorStorage.ts index 6ca5c27d..afd68081 100644 --- a/packages/storage/src/chunk-vector/PostgresChunkVectorRepository.ts +++ b/packages/storage/src/chunk-vector/PostgresChunkVectorStorage.ts @@ -6,13 +6,13 @@ import { cosineSimilarity, type TypedArray } from "@workglow/util"; import type { Pool } from "pg"; -import { PostgresTabularRepository } from "../tabular/PostgresTabularRepository"; +import { PostgresTabularStorage } from "../tabular/PostgresTabularStorage"; import { ChunkVector, ChunkVectorKey, ChunkVectorSchema } from "./ChunkVectorSchema"; import type { HybridSearchOptions, - IChunkVectorRepository, + IChunkVectorStorage, VectorSearchOptions, -} from "./IChunkVectorRepository"; +} from "./IChunkVectorStorage"; /** * PostgreSQL document chunk vector repository implementation using pgvector extension. @@ -26,17 +26,17 @@ import type { * @template Metadata - The metadata type for the document chunk * @template Vector - The vector type for the document chunk */ -export class PostgresChunkVectorRepository< +export class PostgresChunkVectorStorage< Metadata extends Record = Record, Vector extends TypedArray = Float32Array, > - extends PostgresTabularRepository< + extends PostgresTabularStorage< typeof ChunkVectorSchema, typeof ChunkVectorKey, ChunkVector > implements - IChunkVectorRepository< + IChunkVectorStorage< typeof ChunkVectorSchema, typeof ChunkVectorKey, ChunkVector diff --git a/packages/storage/src/chunk-vector/README.md b/packages/storage/src/chunk-vector/README.md index 2af14dda..f64c8ca0 100644 --- a/packages/storage/src/chunk-vector/README.md +++ b/packages/storage/src/chunk-vector/README.md @@ -5,9 +5,9 @@ Storage for document chunk embeddings with vector similarity search capabilities ## Features - **Multiple Storage Backends:** - - 🧠 `InMemoryChunkVectorRepository` - Fast in-memory storage for testing and small datasets - - 📁 `SqliteChunkVectorRepository` - Persistent SQLite storage for local applications - - 🐘 `PostgresChunkVectorRepository` - PostgreSQL with pgvector extension for production + - 🧠 `InMemoryChunkVectorStorage` - Fast in-memory storage for testing and small datasets + - 📁 `SqliteChunkVectorStorage` - Persistent SQLite storage for local applications + - 🐘 `PostgresChunkVectorStorage` - PostgreSQL with pgvector extension for production - **Quantized Vector Support:** - Float32Array (standard 32-bit floating point) @@ -25,7 +25,7 @@ Storage for document chunk embeddings with vector similarity search capabilities - Top-K retrieval with score thresholds - **Built on Tabular Repositories:** - - Extends `ITabularRepository` for standard CRUD operations + - Extends `ITabularStorage` for standard CRUD operations - Inherits event emitter pattern for monitoring - Type-safe schema-based storage @@ -40,10 +40,10 @@ bun install @workglow/storage ### In-Memory Repository (Testing/Development) ```typescript -import { InMemoryChunkVectorRepository } from "@workglow/storage"; +import { InMemoryChunkVectorStorage } from "@workglow/storage"; // Create repository with 384 dimensions -const repo = new InMemoryChunkVectorRepository(384); +const repo = new InMemoryChunkVectorStorage(384); await repo.setupDatabase(); // Store a chunk with its embedding @@ -64,10 +64,10 @@ const results = await repo.similaritySearch(new Float32Array([0.15, 0.25, 0.35 / ### Quantized Vectors (Reduced Storage) ```typescript -import { InMemoryChunkVectorRepository } from 
"@workglow/storage"; +import { InMemoryChunkVectorStorage } from "@workglow/storage"; // Use Int8Array for 4x smaller storage (binary quantization) -const repo = new InMemoryChunkVectorRepository<{ text: string }, Int8Array>(384, Int8Array); +const repo = new InMemoryChunkVectorStorage<{ text: string }, Int8Array>(384, Int8Array); await repo.setupDatabase(); // Store quantized vectors @@ -85,9 +85,9 @@ const results = await repo.similaritySearch(new Int8Array([100, -50, 75 /* ... * ### SQLite Repository (Local Persistence) ```typescript -import { SqliteChunkVectorRepository } from "@workglow/storage"; +import { SqliteChunkVectorStorage } from "@workglow/storage"; -const repo = new SqliteChunkVectorRepository<{ text: string }>( +const repo = new SqliteChunkVectorStorage<{ text: string }>( "./vectors.db", // database path "chunks", // table name 768 // vector dimension @@ -105,10 +105,10 @@ await repo.putMany([ ```typescript import { Pool } from "pg"; -import { PostgresChunkVectorRepository } from "@workglow/storage"; +import { PostgresChunkVectorStorage } from "@workglow/storage"; const pool = new Pool({ connectionString: "postgresql://..." }); -const repo = new PostgresChunkVectorRepository<{ text: string; category: string }>( +const repo = new PostgresChunkVectorStorage<{ text: string; category: string }>( pool, "chunks", 384 // vector dimension @@ -168,12 +168,12 @@ const ChunkVectorKey = ["chunk_id"] as const; ## API Reference -### IChunkVectorRepository Interface +### IChunkVectorStorage Interface -Extends `ITabularRepository` with vector-specific methods: +Extends `ITabularStorage` with vector-specific methods: ```typescript -interface IChunkVectorRepository extends ITabularRepository< +interface IChunkVectorStorage extends ITabularStorage< Schema, PrimaryKeyNames, Entity @@ -197,7 +197,7 @@ interface IChunkVectorRepository extends ITabul ### Inherited Tabular Methods -From `ITabularRepository`: +From `ITabularStorage`: ```typescript // Setup @@ -299,16 +299,16 @@ The chunk vector repository works alongside `DocumentRepository` for hierarchica ```typescript import { DocumentRepository, - InMemoryChunkVectorRepository, - InMemoryTabularRepository, + InMemoryChunkVectorStorage, + InMemoryTabularStorage, } from "@workglow/storage"; import { DocumentStorageSchema } from "@workglow/storage"; // Initialize storage backends -const tabularStorage = new InMemoryTabularRepository(DocumentStorageSchema, ["doc_id"]); +const tabularStorage = new InMemoryTabularStorage(DocumentStorageSchema, ["doc_id"]); await tabularStorage.setupDatabase(); -const vectorStorage = new InMemoryChunkVectorRepository(384); +const vectorStorage = new InMemoryChunkVectorStorage(384); await vectorStorage.setupDatabase(); // Create document repository with both storages diff --git a/packages/storage/src/chunk-vector/SqliteChunkVectorRepository.ts b/packages/storage/src/chunk-vector/SqliteChunkVectorStorage.ts similarity index 95% rename from packages/storage/src/chunk-vector/SqliteChunkVectorRepository.ts rename to packages/storage/src/chunk-vector/SqliteChunkVectorStorage.ts index d9609f1f..a23a4fb2 100644 --- a/packages/storage/src/chunk-vector/SqliteChunkVectorRepository.ts +++ b/packages/storage/src/chunk-vector/SqliteChunkVectorStorage.ts @@ -7,13 +7,13 @@ import { Sqlite } from "@workglow/sqlite"; import type { TypedArray } from "@workglow/util"; import { cosineSimilarity } from "@workglow/util"; -import { SqliteTabularRepository } from "../tabular/SqliteTabularRepository"; +import { SqliteTabularStorage } from 
"../tabular/SqliteTabularStorage"; import { ChunkVector, ChunkVectorKey, ChunkVectorSchema } from "./ChunkVectorSchema"; import type { HybridSearchOptions, - IChunkVectorRepository, + IChunkVectorStorage, VectorSearchOptions, -} from "./IChunkVectorRepository"; +} from "./IChunkVectorStorage"; /** * Check if metadata matches filter @@ -34,17 +34,17 @@ function matchesFilter(metadata: Metadata, filter: Partial): * @template Metadata - The metadata type for the document chunk * @template Vector - The vector type for the document chunk */ -export class SqliteChunkVectorRepository< +export class SqliteChunkVectorStorage< Metadata extends Record = Record, Vector extends TypedArray = Float32Array, > - extends SqliteTabularRepository< + extends SqliteTabularStorage< typeof ChunkVectorSchema, typeof ChunkVectorKey, ChunkVector > implements - IChunkVectorRepository< + IChunkVectorStorage< typeof ChunkVectorSchema, typeof ChunkVectorKey, ChunkVector diff --git a/packages/storage/src/common-server.ts b/packages/storage/src/common-server.ts index 096b9bfe..118159b9 100644 --- a/packages/storage/src/common-server.ts +++ b/packages/storage/src/common-server.ts @@ -6,16 +6,16 @@ export * from "./common"; -export * from "./tabular/FsFolderTabularRepository"; -export * from "./tabular/PostgresTabularRepository"; -export * from "./tabular/SqliteTabularRepository"; -export * from "./tabular/SupabaseTabularRepository"; +export * from "./tabular/FsFolderTabularStorage"; +export * from "./tabular/PostgresTabularStorage"; +export * from "./tabular/SqliteTabularStorage"; +export * from "./tabular/SupabaseTabularStorage"; -export * from "./kv/FsFolderJsonKvRepository"; -export * from "./kv/FsFolderKvRepository"; -export * from "./kv/PostgresKvRepository"; -export * from "./kv/SqliteKvRepository"; -export * from "./kv/SupabaseKvRepository"; +export * from "./kv/FsFolderJsonKvStorage"; +export * from "./kv/FsFolderKvStorage"; +export * from "./kv/PostgresKvStorage"; +export * from "./kv/SqliteKvStorage"; +export * from "./kv/SupabaseKvStorage"; export * from "./queue/PostgresQueueStorage"; export * from "./queue/SqliteQueueStorage"; @@ -25,12 +25,12 @@ export * from "./queue-limiter/PostgresRateLimiterStorage"; export * from "./queue-limiter/SqliteRateLimiterStorage"; export * from "./queue-limiter/SupabaseRateLimiterStorage"; -export * from "./chunk-vector/PostgresChunkVectorRepository"; -export * from "./chunk-vector/SqliteChunkVectorRepository"; +export * from "./chunk-vector/PostgresChunkVectorStorage"; +export * from "./chunk-vector/SqliteChunkVectorStorage"; // testing -export * from "./kv/IndexedDbKvRepository"; +export * from "./kv/IndexedDbKvStorage"; export * from "./queue-limiter/IndexedDbRateLimiterStorage"; export * from "./queue/IndexedDbQueueStorage"; -export * from "./tabular/IndexedDbTabularRepository"; +export * from "./tabular/IndexedDbTabularStorage"; export * from "./util/IndexedDbTable"; diff --git a/packages/storage/src/common.ts b/packages/storage/src/common.ts index 9c0d27e3..4971fca5 100644 --- a/packages/storage/src/common.ts +++ b/packages/storage/src/common.ts @@ -4,18 +4,18 @@ * SPDX-License-Identifier: Apache-2.0 */ -export * from "./tabular/BaseTabularRepository"; -export * from "./tabular/CachedTabularRepository"; -export * from "./tabular/InMemoryTabularRepository"; -export * from "./tabular/ITabularRepository"; -export * from "./tabular/TabularRepositoryRegistry"; +export * from "./tabular/BaseTabularStorage"; +export * from "./tabular/CachedTabularStorage"; +export * from 
"./tabular/InMemoryTabularStorage"; +export * from "./tabular/ITabularStorage"; +export * from "./tabular/TabularStorageRegistry"; export * from "./util/RepositorySchema"; -export * from "./kv/IKvRepository"; -export * from "./kv/InMemoryKvRepository"; -export * from "./kv/KvRepository"; -export * from "./kv/KvViaTabularRepository"; +export * from "./kv/IKvStorage"; +export * from "./kv/InMemoryKvStorage"; +export * from "./kv/KvStorage"; +export * from "./kv/KvViaTabularStorage"; export * from "./queue/InMemoryQueueStorage"; export * from "./queue/IQueueStorage"; @@ -34,7 +34,7 @@ export * from "./document/DocumentSchema"; export * from "./document/DocumentStorageSchema"; export * from "./document/StructuralParser"; -export * from "./chunk-vector/ChunkVectorRepositoryRegistry"; +export * from "./chunk-vector/ChunkVectorStorageRegistry"; export * from "./chunk-vector/ChunkVectorSchema"; -export * from "./chunk-vector/IChunkVectorRepository"; -export * from "./chunk-vector/InMemoryChunkVectorRepository"; +export * from "./chunk-vector/IChunkVectorStorage"; +export * from "./chunk-vector/InMemoryChunkVectorStorage"; diff --git a/packages/storage/src/document/DocumentRepository.ts b/packages/storage/src/document/DocumentRepository.ts index 1e61c627..746a16f6 100644 --- a/packages/storage/src/document/DocumentRepository.ts +++ b/packages/storage/src/document/DocumentRepository.ts @@ -7,10 +7,10 @@ import type { TypedArray } from "@workglow/util"; import { ChunkVector } from "../chunk-vector/ChunkVectorSchema"; import type { - AnyChunkVectorRepository, + AnyChunkVectorStorage, VectorSearchOptions, -} from "../chunk-vector/IChunkVectorRepository"; -import type { ITabularRepository } from "../tabular/ITabularRepository"; +} from "../chunk-vector/IChunkVectorStorage"; +import type { ITabularStorage } from "../tabular/ITabularStorage"; import { Document } from "./Document"; import { ChunkNode, DocumentNode } from "./DocumentSchema"; import { @@ -24,12 +24,12 @@ import { * inheritance/interface patterns. */ export class DocumentRepository { - private tabularStorage: ITabularRepository< + private tabularStorage: ITabularStorage< DocumentStorageSchema, DocumentStorageKey, DocumentStorageEntity >; - private vectorStorage?: AnyChunkVectorRepository; + private vectorStorage?: AnyChunkVectorStorage; /** * Creates a new DocumentRepository instance. 
@@ -39,22 +39,22 @@ export class DocumentRepository { * * @example * ```typescript - * const tabularStorage = new InMemoryTabularRepository(DocumentStorageSchema, ["doc_id"]); + * const tabularStorage = new InMemoryTabularStorage(DocumentStorageSchema, ["doc_id"]); * await tabularStorage.setupDatabase(); * - * const vectorStorage = new InMemoryVectorRepository(); + * const vectorStorage = new InMemoryChunkVectorStorage(); * await vectorStorage.setupDatabase(); * * const docRepo = new DocumentRepository(tabularStorage, vectorStorage); * ``` */ constructor( - tabularStorage: ITabularRepository< + tabularStorage: ITabularStorage< typeof DocumentStorageSchema, ["doc_id"], DocumentStorageEntity >, - vectorStorage?: AnyChunkVectorRepository + vectorStorage?: AnyChunkVectorStorage ) { this.tabularStorage = tabularStorage; this.vectorStorage = vectorStorage; diff --git a/packages/storage/src/kv/FsFolderJsonKvRepository.ts b/packages/storage/src/kv/FsFolderJsonKvStorage.ts similarity index 65% rename from packages/storage/src/kv/FsFolderJsonKvRepository.ts rename to packages/storage/src/kv/FsFolderJsonKvStorage.ts index cdb7f5fd..ca085e5a 100644 --- a/packages/storage/src/kv/FsFolderJsonKvRepository.ts +++ b/packages/storage/src/kv/FsFolderJsonKvStorage.ts @@ -5,11 +5,11 @@ */ import { createServiceToken, JsonSchema } from "@workglow/util"; -import { FsFolderTabularRepository } from "../tabular/FsFolderTabularRepository"; -import { DefaultKeyValueKey, DefaultKeyValueSchema, IKvRepository } from "./IKvRepository"; -import { KvViaTabularRepository } from "./KvViaTabularRepository"; +import { FsFolderTabularStorage } from "../tabular/FsFolderTabularStorage"; +import { DefaultKeyValueKey, DefaultKeyValueSchema, IKvStorage } from "./IKvStorage"; +import { KvViaTabularStorage } from "./KvViaTabularStorage"; -export const FS_FOLDER_JSON_KV_REPOSITORY = createServiceToken>( +export const FS_FOLDER_JSON_KV_REPOSITORY = createServiceToken>( "storage.kvRepository.fsFolderJson" ); @@ -21,14 +21,14 @@ export const FS_FOLDER_JSON_KV_REPOSITORY = createServiceToken; /** - * Creates a new KvRepository instance + * Creates a new KvStorage instance */ constructor( public folderPath: string, @@ -36,7 +36,7 @@ export class FsFolderJsonKvRepository extends KvViaTabularRepository { valueSchema: JsonSchema = {} ) { super(keySchema, valueSchema); - this.tabularRepository = new FsFolderTabularRepository( + this.tabularRepository = new FsFolderTabularStorage( folderPath, DefaultKeyValueSchema, DefaultKeyValueKey diff --git a/packages/storage/src/kv/FsFolderKvRepository.ts b/packages/storage/src/kv/FsFolderKvStorage.ts similarity index 94% rename from packages/storage/src/kv/FsFolderKvRepository.ts rename to packages/storage/src/kv/FsFolderKvStorage.ts index f74b7c0a..6c76a705 100644 --- a/packages/storage/src/kv/FsFolderKvRepository.ts +++ b/packages/storage/src/kv/FsFolderKvStorage.ts @@ -7,10 +7,10 @@ import { createServiceToken, JsonSchema } from "@workglow/util"; import { mkdir, readFile, rm, unlink, writeFile } from "fs/promises"; import path from "path"; -import { IKvRepository } from "./IKvRepository"; -import { KvRepository } from "./KvRepository"; +import { IKvStorage } from "./IKvStorage"; +import { KvStorage } from "./KvStorage"; -export const FS_FOLDER_KV_REPOSITORY = createServiceToken>( +export const FS_FOLDER_KV_REPOSITORY = createServiceToken>( "storage.kvRepository.fsFolder" ); @@ -22,13 +22,13 @@ export const FS_FOLDER_KV_REPOSITORY = createServiceToken extends KvRepository { +> extends KvStorage { /** - *
Creates a new KvRepository instance + * Creates a new KvStorage instance */ constructor( public folderPath: string, diff --git a/packages/storage/src/kv/IKvRepository.ts b/packages/storage/src/kv/IKvStorage.ts similarity index 96% rename from packages/storage/src/kv/IKvRepository.ts rename to packages/storage/src/kv/IKvStorage.ts index 650c3d4b..486179ae 100644 --- a/packages/storage/src/kv/IKvRepository.ts +++ b/packages/storage/src/kv/IKvStorage.ts @@ -5,7 +5,7 @@ */ import { DataPortSchemaObject, EventParameters } from "@workglow/util"; -import { JSONValue } from "../tabular/ITabularRepository"; +import { JSONValue } from "../tabular/ITabularStorage"; /** * Default schema types for simple string row data @@ -52,7 +52,7 @@ export type KvEventParameters = * @typeParam Value - Type for the value struct re * @typeParam Combined - Combined type of Key & Value */ -export interface IKvRepository< +export interface IKvStorage< Key extends string | number = string, Value extends any = any, Combined = { key: Key; value: Value }, diff --git a/packages/storage/src/kv/InMemoryKvRepository.ts b/packages/storage/src/kv/InMemoryKvStorage.ts similarity index 59% rename from packages/storage/src/kv/InMemoryKvRepository.ts rename to packages/storage/src/kv/InMemoryKvStorage.ts index 248be097..b72c8aba 100644 --- a/packages/storage/src/kv/InMemoryKvRepository.ts +++ b/packages/storage/src/kv/InMemoryKvStorage.ts @@ -5,11 +5,11 @@ */ import { createServiceToken, JsonSchema } from "@workglow/util"; -import { InMemoryTabularRepository } from "../tabular/InMemoryTabularRepository"; -import { DefaultKeyValueKey, DefaultKeyValueSchema, IKvRepository } from "./IKvRepository"; -import { KvViaTabularRepository } from "./KvViaTabularRepository"; +import { InMemoryTabularStorage } from "../tabular/InMemoryTabularStorage"; +import { DefaultKeyValueKey, DefaultKeyValueSchema, IKvStorage } from "./IKvStorage"; +import { KvViaTabularStorage } from "./KvViaTabularStorage"; -export const MEMORY_KV_REPOSITORY = createServiceToken>( +export const MEMORY_KV_REPOSITORY = createServiceToken>( "storage.kvRepository.inMemory" ); @@ -21,18 +21,18 @@ export const MEMORY_KV_REPOSITORY = createServiceToken; /** - * Creates a new KvRepository instance + * Creates a new KvStorage instance */ constructor(keySchema: JsonSchema = { type: "string" }, valueSchema: JsonSchema = {}) { super(keySchema, valueSchema); - this.tabularRepository = new InMemoryTabularRepository( + this.tabularRepository = new InMemoryTabularStorage( DefaultKeyValueSchema, DefaultKeyValueKey ); diff --git a/packages/storage/src/kv/IndexedDbKvRepository.ts b/packages/storage/src/kv/IndexedDbKvStorage.ts similarity index 61% rename from packages/storage/src/kv/IndexedDbKvRepository.ts rename to packages/storage/src/kv/IndexedDbKvStorage.ts index 81d17fa6..4d01c06f 100644 --- a/packages/storage/src/kv/IndexedDbKvRepository.ts +++ b/packages/storage/src/kv/IndexedDbKvStorage.ts @@ -5,11 +5,11 @@ */ import { createServiceToken, JsonSchema } from "@workglow/util"; -import { IndexedDbTabularRepository } from "../tabular/IndexedDbTabularRepository"; -import { DefaultKeyValueKey, DefaultKeyValueSchema, IKvRepository } from "./IKvRepository"; -import { KvViaTabularRepository } from "./KvViaTabularRepository"; +import { IndexedDbTabularStorage } from "../tabular/IndexedDbTabularStorage"; +import { DefaultKeyValueKey, DefaultKeyValueSchema, IKvStorage } from "./IKvStorage"; +import { KvViaTabularStorage } from "./KvViaTabularStorage"; -export const IDB_KV_REPOSITORY = 
createServiceToken>( +export const IDB_KV_REPOSITORY = createServiceToken>( "storage.kvRepository.indexedDb" ); @@ -21,14 +21,14 @@ export const IDB_KV_REPOSITORY = createServiceToken; /** - * Creates a new KvRepository instance + * Creates a new KvStorage instance */ constructor( public dbName: string, @@ -36,7 +36,7 @@ export class IndexedDbKvRepository extends KvViaTabularRepository { valueSchema: JsonSchema = {} ) { super(keySchema, valueSchema); - this.tabularRepository = new IndexedDbTabularRepository( + this.tabularRepository = new IndexedDbTabularStorage( dbName, DefaultKeyValueSchema, DefaultKeyValueKey diff --git a/packages/storage/src/kv/KvRepository.ts b/packages/storage/src/kv/KvStorage.ts similarity index 93% rename from packages/storage/src/kv/KvRepository.ts rename to packages/storage/src/kv/KvStorage.ts index a58ba457..c25571e7 100644 --- a/packages/storage/src/kv/KvRepository.ts +++ b/packages/storage/src/kv/KvStorage.ts @@ -5,17 +5,17 @@ */ import { createServiceToken, EventEmitter, JsonSchema, makeFingerprint } from "@workglow/util"; -import { JSONValue } from "../tabular/ITabularRepository"; +import { JSONValue } from "../tabular/ITabularStorage"; import { - IKvRepository, + IKvStorage, KvEventListener, KvEventListeners, KvEventName, KvEventParameters, -} from "./IKvRepository"; +} from "./IKvStorage"; export const KV_REPOSITORY = - createServiceToken>("storage.kvRepository"); + createServiceToken>("storage.kvRepository"); /** * Abstract base class for key-value storage repositories. @@ -25,17 +25,17 @@ export const KV_REPOSITORY = * @template Value - The type of the value being stored * @template Combined - Combined type of Key & Value */ -export abstract class KvRepository< +export abstract class KvStorage< Key extends string = string, Value extends any = any, Combined = { key: Key; value: Value }, -> implements IKvRepository +> implements IKvStorage { /** Event emitter for repository events */ protected events = new EventEmitter>(); /** - * Creates a new KvRepository instance + * Creates a new KvStorage instance */ constructor( public keySchema: JsonSchema = { type: "string" }, diff --git a/packages/storage/src/kv/KvViaTabularRepository.ts b/packages/storage/src/kv/KvViaTabularStorage.ts similarity index 94% rename from packages/storage/src/kv/KvViaTabularRepository.ts rename to packages/storage/src/kv/KvViaTabularStorage.ts index bc854394..6b84496d 100644 --- a/packages/storage/src/kv/KvViaTabularRepository.ts +++ b/packages/storage/src/kv/KvViaTabularStorage.ts @@ -4,9 +4,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { BaseTabularRepository } from "../tabular/BaseTabularRepository"; -import { DefaultKeyValueKey, DefaultKeyValueSchema } from "./IKvRepository"; -import { KvRepository } from "./KvRepository"; +import type { BaseTabularStorage } from "../tabular/BaseTabularStorage"; +import { DefaultKeyValueKey, DefaultKeyValueSchema } from "./IKvStorage"; +import { KvStorage } from "./KvStorage"; /** * Abstract base class for key-value storage repositories that uses a tabular repository for storage. 
@@ -16,12 +16,12 @@ import { KvRepository } from "./KvRepository"; * @template Value - The type of the value being stored * @template Combined - Combined type of Key & Value */ -export abstract class KvViaTabularRepository< +export abstract class KvViaTabularStorage< Key extends string = string, Value extends any = any, Combined = { key: Key; value: Value }, -> extends KvRepository { - public abstract tabularRepository: BaseTabularRepository< +> extends KvStorage { + public abstract tabularRepository: BaseTabularStorage< typeof DefaultKeyValueSchema, typeof DefaultKeyValueKey >; diff --git a/packages/storage/src/kv/PostgresKvRepository.ts b/packages/storage/src/kv/PostgresKvStorage.ts similarity index 61% rename from packages/storage/src/kv/PostgresKvRepository.ts rename to packages/storage/src/kv/PostgresKvStorage.ts index c37ae29d..5c6c6083 100644 --- a/packages/storage/src/kv/PostgresKvRepository.ts +++ b/packages/storage/src/kv/PostgresKvStorage.ts @@ -5,11 +5,11 @@ */ import { createServiceToken, JsonSchema } from "@workglow/util"; -import { PostgresTabularRepository } from "../tabular/PostgresTabularRepository"; -import { DefaultKeyValueKey, DefaultKeyValueSchema, IKvRepository } from "./IKvRepository"; -import { KvViaTabularRepository } from "./KvViaTabularRepository"; +import { PostgresTabularStorage } from "../tabular/PostgresTabularStorage"; +import { DefaultKeyValueKey, DefaultKeyValueSchema, IKvStorage } from "./IKvStorage"; +import { KvViaTabularStorage } from "./KvViaTabularStorage"; -export const POSTGRES_KV_REPOSITORY = createServiceToken>( +export const POSTGRES_KV_REPOSITORY = createServiceToken>( "storage.kvRepository.postgres" ); @@ -21,14 +21,14 @@ export const POSTGRES_KV_REPOSITORY = createServiceToken; /** - * Creates a new KvRepository instance + * Creates a new KvStorage instance */ constructor( public db: any, @@ -37,7 +37,7 @@ export class PostgresKvRepository extends KvViaTabularRepository { valueSchema: JsonSchema = {} ) { super(keySchema, valueSchema); - this.tabularRepository = new PostgresTabularRepository( + this.tabularRepository = new PostgresTabularStorage( db, dbName, DefaultKeyValueSchema, diff --git a/packages/storage/src/kv/README.md b/packages/storage/src/kv/README.md index ce7a299f..1345c547 100644 --- a/packages/storage/src/kv/README.md +++ b/packages/storage/src/kv/README.md @@ -23,7 +23,7 @@ A flexible key-value storage solution with multiple backend implementations. 
Pro - 💾 `IndexedDbKvRepository` - Browser IndexedDB storage - 🐘 `PostgresKvRepository` - PostgreSQL database storage - 📁 `SqliteKvRepository` - SQLite database storage - - 🧠 `InMemoryKvRepository` - Volatile memory storage + - 🧠 `InMemoryKvStorage` - Volatile memory storage - Type-safe key/value definitions - JSON value serialization support - Event emitter for storage operations (put/get/delete) @@ -104,9 +104,9 @@ await sqliteRepo.put("temp:789", "cached_value"); ### In-Memory Storage ```typescript -import { InMemoryKvRepository } from "@workglow/storage/kv"; +import { InMemoryKvStorage } from "@workglow/storage/kv"; -const memRepo = new InMemoryKvRepository( +const memRepo = new InMemoryKvStorage( "string", // Key type "json" // Value type ); diff --git a/packages/storage/src/kv/SqliteKvRepository.ts b/packages/storage/src/kv/SqliteKvStorage.ts similarity index 62% rename from packages/storage/src/kv/SqliteKvRepository.ts rename to packages/storage/src/kv/SqliteKvStorage.ts index ee46710b..79fedd69 100644 --- a/packages/storage/src/kv/SqliteKvRepository.ts +++ b/packages/storage/src/kv/SqliteKvStorage.ts @@ -5,11 +5,11 @@ */ import { createServiceToken, JsonSchema } from "@workglow/util"; -import { SqliteTabularRepository } from "../tabular/SqliteTabularRepository"; -import { DefaultKeyValueKey, DefaultKeyValueSchema, IKvRepository } from "./IKvRepository"; -import { KvViaTabularRepository } from "./KvViaTabularRepository"; +import { SqliteTabularStorage } from "../tabular/SqliteTabularStorage"; +import { DefaultKeyValueKey, DefaultKeyValueSchema, IKvStorage } from "./IKvStorage"; +import { KvViaTabularStorage } from "./KvViaTabularStorage"; -export const SQLITE_KV_REPOSITORY = createServiceToken>( +export const SQLITE_KV_REPOSITORY = createServiceToken>( "storage.kvRepository.sqlite" ); @@ -21,14 +21,14 @@ export const SQLITE_KV_REPOSITORY = createServiceToken; /** - * Creates a new KvRepository instance + * Creates a new KvStorage instance */ constructor( public db: any, @@ -37,7 +37,7 @@ export class SqliteKvRepository extends KvViaTabularRepository { valueSchema: JsonSchema = {} ) { super(keySchema, valueSchema); - this.tabularRepository = new SqliteTabularRepository( + this.tabularRepository = new SqliteTabularStorage( db, dbName, DefaultKeyValueSchema, diff --git a/packages/storage/src/kv/SupabaseKvRepository.ts b/packages/storage/src/kv/SupabaseKvStorage.ts similarity index 66% rename from packages/storage/src/kv/SupabaseKvRepository.ts rename to packages/storage/src/kv/SupabaseKvStorage.ts index 2edf0701..ed131376 100644 --- a/packages/storage/src/kv/SupabaseKvRepository.ts +++ b/packages/storage/src/kv/SupabaseKvStorage.ts @@ -6,11 +6,11 @@ import type { SupabaseClient } from "@supabase/supabase-js"; import { createServiceToken, JsonSchema } from "@workglow/util"; -import { SupabaseTabularRepository } from "../tabular/SupabaseTabularRepository"; -import { DefaultKeyValueKey, DefaultKeyValueSchema, IKvRepository } from "./IKvRepository"; -import { KvViaTabularRepository } from "./KvViaTabularRepository"; +import { SupabaseTabularStorage } from "../tabular/SupabaseTabularStorage"; +import { DefaultKeyValueKey, DefaultKeyValueSchema, IKvStorage } from "./IKvStorage"; +import { KvViaTabularStorage } from "./KvViaTabularStorage"; -export const SUPABASE_KV_REPOSITORY = createServiceToken>( +export const SUPABASE_KV_REPOSITORY = createServiceToken>( "storage.kvRepository.supabase" ); @@ -22,14 +22,14 @@ export const SUPABASE_KV_REPOSITORY = createServiceToken; /** - * Creates a new 
SupabaseKvRepository instance + * Creates a new SupabaseKvStorage instance * * @param client - Supabase client instance * @param tableName - Name of the table to store data @@ -41,7 +41,7 @@ export class SupabaseKvRepository extends KvViaTabularRepository { public tableName: string, keySchema: JsonSchema = { type: "string" }, valueSchema: JsonSchema = {}, - tabularRepository?: SupabaseTabularRepository< + tabularRepository?: SupabaseTabularStorage< typeof DefaultKeyValueSchema, typeof DefaultKeyValueKey > @@ -49,6 +49,6 @@ export class SupabaseKvRepository extends KvViaTabularRepository { super(keySchema, valueSchema); this.tabularRepository = tabularRepository ?? - new SupabaseTabularRepository(client, tableName, DefaultKeyValueSchema, DefaultKeyValueKey); + new SupabaseTabularStorage(client, tableName, DefaultKeyValueSchema, DefaultKeyValueKey); } } diff --git a/packages/storage/src/tabular/BaseSqlTabularRepository.ts b/packages/storage/src/tabular/BaseSqlTabularStorage.ts similarity index 96% rename from packages/storage/src/tabular/BaseSqlTabularRepository.ts rename to packages/storage/src/tabular/BaseSqlTabularStorage.ts index 4d2ccd8c..1f21bfdd 100644 --- a/packages/storage/src/tabular/BaseSqlTabularRepository.ts +++ b/packages/storage/src/tabular/BaseSqlTabularStorage.ts @@ -10,10 +10,10 @@ import { JsonSchema, TypedArraySchemaOptions, } from "@workglow/util"; -import { BaseTabularRepository } from "./BaseTabularRepository"; -import { SimplifyPrimaryKey, ValueOptionType } from "./ITabularRepository"; +import { BaseTabularStorage } from "./BaseTabularStorage"; +import { SimplifyPrimaryKey, ValueOptionType } from "./ITabularStorage"; -// BaseTabularRepository is a tabular store that uses SQLite and Postgres use as common code +// BaseSqlTabularStorage is common code shared by the SQLite and Postgres tabular stores /** * Base class for SQL-based tabular repositories that implements common functionality @@ -22,16 +22,16 @@ import { SimplifyPrimaryKey, ValueOptionType } from "./ITabularRepository"; * @template Schema - The schema definition for the entity using JSON Schema * @template PrimaryKeyNames - Array of property names that form the primary key */ -export abstract class BaseSqlTabularRepository< +export abstract class BaseSqlTabularStorage< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, Value = Omit, -> extends BaseTabularRepository { +> extends BaseTabularStorage { /** - * Creates a new instance of BaseSqlTabularRepository + * Creates a new instance of BaseSqlTabularStorage * @param table - The name of the database table to use for storage * @param schema - Schema defining the structure of the entity * @param primaryKeyNames - Array of property names that form the primary key diff --git a/packages/storage/src/tabular/BaseTabularRepository.ts b/packages/storage/src/tabular/BaseTabularStorage.ts similarity index 97% rename from packages/storage/src/tabular/BaseTabularRepository.ts rename to packages/storage/src/tabular/BaseTabularStorage.ts index aea5dfb9..16beec5b 100644 --- a/packages/storage/src/tabular/BaseTabularRepository.ts +++ b/packages/storage/src/tabular/BaseTabularStorage.ts @@ -13,9 +13,9 @@ import { TypedArraySchemaOptions, } from "@workglow/util"; import { - AnyTabularRepository, + AnyTabularStorage, DeleteSearchCriteria, - ITabularRepository, + ITabularStorage, SimplifyPrimaryKey, TabularChangePayload, TabularEventListener, @@ -24,9 +24,9 @@ import {
TabularEventParameters, TabularSubscribeOptions, ValueOptionType, -} from "./ITabularRepository"; +} from "./ITabularStorage"; -export const TABULAR_REPOSITORY = createServiceToken( +export const TABULAR_REPOSITORY = createServiceToken( "storage.tabularRepository" ); @@ -39,14 +39,14 @@ export const TABULAR_REPOSITORY = createServiceToken( * @template Schema - The schema definition for the entity using JSON Schema * @template PrimaryKeyNames - Array of property names that form the primary key */ -export abstract class BaseTabularRepository< +export abstract class BaseTabularStorage< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, Value = Omit, -> implements ITabularRepository { +> implements ITabularStorage { /** Event emitter for repository events */ protected events = new EventEmitter>(); @@ -55,7 +55,7 @@ export abstract class BaseTabularRepository< protected valueSchema: DataPortSchemaObject; /** - * Creates a new BaseTabularRepository instance + * Creates a new BaseTabularStorage instance * @param schema - Schema defining the structure of the entity * @param primaryKeyNames - Array of property names that form the primary key * @param indexes - Array of columns or column arrays to make searchable. Each string or single column creates a single-column index, diff --git a/packages/storage/src/tabular/CachedTabularRepository.ts b/packages/storage/src/tabular/CachedTabularStorage.ts similarity index 89% rename from packages/storage/src/tabular/CachedTabularRepository.ts rename to packages/storage/src/tabular/CachedTabularStorage.ts index db0d8e23..199a516e 100644 --- a/packages/storage/src/tabular/CachedTabularRepository.ts +++ b/packages/storage/src/tabular/CachedTabularStorage.ts @@ -10,52 +10,52 @@ import { FromSchema, TypedArraySchemaOptions, } from "@workglow/util"; -import { BaseTabularRepository } from "./BaseTabularRepository"; -import { InMemoryTabularRepository } from "./InMemoryTabularRepository"; +import { BaseTabularStorage } from "./BaseTabularStorage"; +import { InMemoryTabularStorage } from "./InMemoryTabularStorage"; import { - AnyTabularRepository, + AnyTabularStorage, DeleteSearchCriteria, - ITabularRepository, + ITabularStorage, SimplifyPrimaryKey, TabularSubscribeOptions, -} from "./ITabularRepository"; +} from "./ITabularStorage"; -export const CACHED_TABULAR_REPOSITORY = createServiceToken( +export const CACHED_TABULAR_REPOSITORY = createServiceToken( "storage.tabularRepository.cached" ); /** * A tabular repository wrapper that adds caching layer to a durable repository. - * Uses InMemoryTabularRepository or SharedInMemoryTabularRepository as a cache + * Uses InMemoryTabularStorage or SharedInMemoryTabularStorage as a cache * for faster access to frequently used data. 
* * @template Schema - The schema definition for the entity using JSON Schema * @template PrimaryKeyNames - Array of property names that form the primary key */ -export class CachedTabularRepository< +export class CachedTabularStorage< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, -> extends BaseTabularRepository { - public readonly cache: ITabularRepository; - private durable: ITabularRepository; +> extends BaseTabularStorage { + public readonly cache: ITabularStorage; + private durable: ITabularStorage; private cacheInitialized = false; /** - * Creates a new CachedTabularRepository instance + * Creates a new CachedTabularStorage instance * @param durable - The durable repository to use as the source of truth - * @param cache - Optional cache repository (InMemoryTabularRepository or SharedInMemoryTabularRepository). - * If not provided, a new InMemoryTabularRepository will be created. + * @param cache - Optional cache repository (InMemoryTabularStorage or SharedInMemoryTabularStorage). + * If not provided, a new InMemoryTabularStorage will be created. * @param schema - Schema defining the structure of the entity * @param primaryKeyNames - Array of property names that form the primary key * @param indexes - Array of columns or column arrays to make searchable. Each string or single column creates a single-column index, * while each array creates a compound index with columns in the specified order. */ constructor( - durable: ITabularRepository, - cache?: ITabularRepository, + durable: ITabularStorage, + cache?: ITabularStorage, schema?: Schema, primaryKeyNames?: PrimaryKeyNames, indexes?: readonly (keyof Entity | readonly (keyof Entity)[])[] @@ -65,7 +65,7 @@ export class CachedTabularRepository< // So we require them to be provided or assume they match if (!schema || !primaryKeyNames) { throw new Error( - "Schema and primaryKeyNames must be provided when creating CachedTabularRepository" + "Schema and primaryKeyNames must be provided when creating CachedTabularStorage" ); } @@ -76,7 +76,7 @@ export class CachedTabularRepository< if (cache) { this.cache = cache; } else { - this.cache = new InMemoryTabularRepository( + this.cache = new InMemoryTabularStorage( schema, primaryKeyNames, indexes || [] diff --git a/packages/storage/src/tabular/FsFolderTabularRepository.ts b/packages/storage/src/tabular/FsFolderTabularStorage.ts similarity index 96% rename from packages/storage/src/tabular/FsFolderTabularRepository.ts rename to packages/storage/src/tabular/FsFolderTabularStorage.ts index 9a4fe9b3..f3b531f9 100644 --- a/packages/storage/src/tabular/FsFolderTabularRepository.ts +++ b/packages/storage/src/tabular/FsFolderTabularStorage.ts @@ -15,16 +15,16 @@ import { import { mkdir, readdir, readFile, rm, writeFile } from "node:fs/promises"; import path from "node:path"; import { PollingSubscriptionManager } from "../util/PollingSubscriptionManager"; -import { BaseTabularRepository } from "./BaseTabularRepository"; +import { BaseTabularStorage } from "./BaseTabularStorage"; import { - AnyTabularRepository, + AnyTabularStorage, DeleteSearchCriteria, SimplifyPrimaryKey, TabularChangePayload, TabularSubscribeOptions, -} from "./ITabularRepository"; +} from "./ITabularStorage"; -export const FS_FOLDER_TABULAR_REPOSITORY = createServiceToken( +export const FS_FOLDER_TABULAR_REPOSITORY = createServiceToken( "storage.tabularRepository.fsFolder" ); @@ -35,13 +35,13 @@ export const 
FS_FOLDER_TABULAR_REPOSITORY = createServiceToken, // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, -> extends BaseTabularRepository { +> extends BaseTabularStorage { private folderPath: string; /** Shared polling subscription manager */ private pollingManager: PollingSubscriptionManager< @@ -51,7 +51,7 @@ export class FsFolderTabularRepository< > | null = null; /** - * Creates a new FsFolderTabularRepository instance. + * Creates a new FsFolderTabularStorage instance. * * @param folderPath - The directory path where the JSON files will be stored * @param schema - Schema defining the structure of the entity @@ -218,7 +218,7 @@ export class FsFolderTabularRepository< * @throws {Error} Always throws an error indicating search is not supported */ async search(key: Partial): Promise { - throw new Error("Search not supported for FsFolderTabularRepository"); + throw new Error("Search not supported for FsFolderTabularStorage"); } /** @@ -240,7 +240,7 @@ export class FsFolderTabularRepository< * @throws Error always - deleteSearch is not supported for filesystem storage */ async deleteSearch(_criteria: DeleteSearchCriteria): Promise { - throw new Error("deleteSearch is not supported for FsFolderTabularRepository"); + throw new Error("deleteSearch is not supported for FsFolderTabularStorage"); } /** diff --git a/packages/storage/src/tabular/ITabularRepository.ts b/packages/storage/src/tabular/ITabularStorage.ts similarity index 98% rename from packages/storage/src/tabular/ITabularRepository.ts rename to packages/storage/src/tabular/ITabularStorage.ts index 1064ed7c..126df293 100644 --- a/packages/storage/src/tabular/ITabularRepository.ts +++ b/packages/storage/src/tabular/ITabularStorage.ts @@ -132,7 +132,7 @@ export type SimplifyPrimaryKey< * @typeParam Schema - The schema definition for the entity using JSON Schema * @typeParam PrimaryKeyNames - Array of property names that form the primary key */ -export interface ITabularRepository< +export interface ITabularStorage< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types @@ -212,4 +212,4 @@ export interface ITabularRepository< [Symbol.asyncDispose](): Promise; } -export type AnyTabularRepository = ITabularRepository; +export type AnyTabularStorage = ITabularStorage; diff --git a/packages/storage/src/tabular/InMemoryTabularRepository.ts b/packages/storage/src/tabular/InMemoryTabularStorage.ts similarity index 96% rename from packages/storage/src/tabular/InMemoryTabularRepository.ts rename to packages/storage/src/tabular/InMemoryTabularStorage.ts index dab87deb..b4057017 100644 --- a/packages/storage/src/tabular/InMemoryTabularRepository.ts +++ b/packages/storage/src/tabular/InMemoryTabularStorage.ts @@ -11,17 +11,17 @@ import { makeFingerprint, TypedArraySchemaOptions, } from "@workglow/util"; -import { BaseTabularRepository } from "./BaseTabularRepository"; +import { BaseTabularStorage } from "./BaseTabularStorage"; import { - AnyTabularRepository, + AnyTabularStorage, DeleteSearchCriteria, isSearchCondition, SimplifyPrimaryKey, TabularChangePayload, TabularSubscribeOptions, -} from "./ITabularRepository"; +} from "./ITabularStorage"; -export const MEMORY_TABULAR_REPOSITORY = createServiceToken( +export const MEMORY_TABULAR_REPOSITORY = createServiceToken( "storage.tabularRepository.inMemory" ); @@ -32,18 +32,18 @@ export const MEMORY_TABULAR_REPOSITORY = createServiceToken, // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, -> extends BaseTabularRepository { +> 
extends BaseTabularStorage { /** Internal storage using a Map with fingerprint strings as keys */ values = new Map(); /** - * Creates a new InMemoryTabularRepository instance + * Creates a new InMemoryTabularStorage instance * @param schema - Schema defining the structure of the entity * @param primaryKeyNames - Array of property names that form the primary key * @param indexes - Array of columns or column arrays to make searchable. Each string or single column creates a single-column index, diff --git a/packages/storage/src/tabular/IndexedDbTabularRepository.ts b/packages/storage/src/tabular/IndexedDbTabularStorage.ts similarity index 98% rename from packages/storage/src/tabular/IndexedDbTabularRepository.ts rename to packages/storage/src/tabular/IndexedDbTabularStorage.ts index 50bc483a..428791ba 100644 --- a/packages/storage/src/tabular/IndexedDbTabularRepository.ts +++ b/packages/storage/src/tabular/IndexedDbTabularStorage.ts @@ -17,18 +17,18 @@ import { ExpectedIndexDefinition, MigrationOptions, } from "../util/IndexedDbTable"; -import { BaseTabularRepository } from "./BaseTabularRepository"; +import { BaseTabularStorage } from "./BaseTabularStorage"; import { - AnyTabularRepository, + AnyTabularStorage, DeleteSearchCriteria, isSearchCondition, SearchOperator, SimplifyPrimaryKey, TabularChangePayload, TabularSubscribeOptions, -} from "./ITabularRepository"; +} from "./ITabularStorage"; -export const IDB_TABULAR_REPOSITORY = createServiceToken( +export const IDB_TABULAR_REPOSITORY = createServiceToken( "storage.tabularRepository.indexedDb" ); @@ -38,13 +38,13 @@ export const IDB_TABULAR_REPOSITORY = createServiceToken( * @template Schema - The schema definition for the entity * @template PrimaryKeyNames - Array of property names that form the primary key */ -export class IndexedDbTabularRepository< +export class IndexedDbTabularStorage< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, -> extends BaseTabularRepository { +> extends BaseTabularStorage { /** Promise that resolves to the IndexedDB database instance */ private db: IDBDatabase | undefined; /** Promise to track ongoing database setup to prevent concurrent setup calls */ diff --git a/packages/storage/src/tabular/PostgresTabularRepository.ts b/packages/storage/src/tabular/PostgresTabularStorage.ts similarity index 98% rename from packages/storage/src/tabular/PostgresTabularRepository.ts rename to packages/storage/src/tabular/PostgresTabularStorage.ts index 1054229e..7ecfa0a2 100644 --- a/packages/storage/src/tabular/PostgresTabularRepository.ts +++ b/packages/storage/src/tabular/PostgresTabularStorage.ts @@ -13,9 +13,9 @@ import { TypedArraySchemaOptions, } from "@workglow/util"; import type { Pool } from "pg"; -import { BaseSqlTabularRepository } from "./BaseSqlTabularRepository"; +import { BaseSqlTabularStorage } from "./BaseSqlTabularStorage"; import { - AnyTabularRepository, + AnyTabularStorage, DeleteSearchCriteria, isSearchCondition, SearchOperator, @@ -23,31 +23,31 @@ import { TabularChangePayload, TabularSubscribeOptions, ValueOptionType, -} from "./ITabularRepository"; +} from "./ITabularStorage"; -export const POSTGRES_TABULAR_REPOSITORY = createServiceToken( +export const POSTGRES_TABULAR_REPOSITORY = createServiceToken( "storage.tabularRepository.postgres" ); /** - * A PostgreSQL-based tabular repository implementation that extends BaseSqlTabularRepository. 
+ * A PostgreSQL-based tabular repository implementation that extends BaseSqlTabularStorage. * This class provides persistent storage for data in a PostgreSQL database, * making it suitable for multi-user scenarios. * * @template Schema - The schema definition for the entity * @template PrimaryKeyNames - Array of property names that form the primary key */ -export class PostgresTabularRepository< +export class PostgresTabularStorage< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, -> extends BaseSqlTabularRepository { +> extends BaseSqlTabularStorage { protected db: Pool; /** - * Creates a new PostgresTabularRepository instance. + * Creates a new PostgresTabularStorage instance. * * @param db - PostgreSQL db * @param table - Name of the table to store data (defaults to "tabular_store") @@ -817,7 +817,7 @@ export class PostgresTabularRepository< callback: (change: TabularChangePayload) => void, options?: TabularSubscribeOptions ): () => void { - throw new Error("subscribeToChanges is not supported for PostgresTabularRepository"); + throw new Error("subscribeToChanges is not supported for PostgresTabularStorage"); } /** diff --git a/packages/storage/src/tabular/README.md b/packages/storage/src/tabular/README.md index df169cda..75fdb385 100644 --- a/packages/storage/src/tabular/README.md +++ b/packages/storage/src/tabular/README.md @@ -9,9 +9,9 @@ A collection of storage implementations for tabular data with multiple backend s - [Using TypeBox](#using-typebox) - [Using Zod 4](#using-zod-4) - [Implementations](#implementations) - - [InMemoryTabularRepository](#inmemorytabularrepository) - - [SqliteTabularRepository](#sqlitetabularrepository) - - [PostgresTabularRepository](#postgrestabularrepository) + - [InMemoryTabularStorage](#inmemorytabularstorage) + - [SqliteTabularStorage](#sqlitetabularstorage) + - [PostgresTabularStorage](#postgrestabularstorage) - [IndexedDbTabularRepository](#indexeddbtabularrepository) - [FsFolderTabularRepository](#fsfoldertabularrepository) - [Events](#events) @@ -43,7 +43,7 @@ npm install @workglow/storage ## Basic Usage ```typescript -import { InMemoryTabularRepository } from "@workglow/storage/tabular"; +import { InMemoryTabularStorage } from "@workglow/storage/tabular"; // Define schema and primary keys const schema = { @@ -55,8 +55,8 @@ const schema = { const primaryKeys = ["id"] as const; // Create repository instance (when using const schemas, the next three generics -// on InMemoryTabularRepository are automatically created for you) -const repo = new InMemoryTabularRepository(schema, primaryKeys); +// on InMemoryTabularStorage are automatically created for you) +const repo = new InMemoryTabularStorage(schema, primaryKeys); // Basic operations await repo.put({ id: "1", name: "Alice", age: 30, active: true }); @@ -75,7 +75,7 @@ You can define schemas using plain JSON Schema objects, or use schema libraries TypeBox schemas are JSON Schema compatible and can be used directly: ```typescript -import { InMemoryTabularRepository } from "@workglow/storage/tabular"; +import { InMemoryTabularStorage } from "@workglow/storage/tabular"; import { Type, Static } from "@sinclair/typebox"; import { DataPortSchemaObject, FromSchema } from "@workglow/util"; @@ -98,7 +98,7 @@ type UserEntity = FromSchema; // IMPORTANT: You must explicitly provide generic type parameters for t // TypeScript cannot infer them from TypeBox schemas -const repo = new 
InMemoryTabularRepository( +const repo = new InMemoryTabularStorage( userSchema, primaryKeys, ["email", "active"] as const // Indexes @@ -119,7 +119,7 @@ await repo.put({ Zod 4 has built-in JSON Schema support using the `.toJSONSchema()` method: ```typescript -import { InMemoryTabularRepository } from "@workglow/storage/tabular"; +import { InMemoryTabularStorage } from "@workglow/storage/tabular"; import { z } from "zod"; import { DataPortSchemaObject } from "@workglow/util"; @@ -141,7 +141,7 @@ type UserEntity = z.infer; // IMPORTANT: You must explicitly provide generic type parameters // TypeScript cannot infer them from Zod schemas (even after conversion) -const repo = new InMemoryTabularRepository( +const repo = new InMemoryTabularStorage( userSchema, primaryKeys, ["email", "active"] as const // Indexes @@ -159,14 +159,14 @@ await repo.put({ ## Implementations -### InMemoryTabularRepository +### InMemoryTabularStorage - Ideal for testing/development - No persistence - Fast search capabilities ```typescript -const repo = new InMemoryTabularRepository< +const repo = new InMemoryTabularStorage< typeof schema, typeof primaryKeys, Entity, // required if using TypeBox, Zod, etc, otherwise automatically created @@ -175,13 +175,13 @@ const repo = new InMemoryTabularRepository< >(schema, primaryKeys, ["name", "active"]); ``` -### SqliteTabularRepository +### SqliteTabularStorage - Embedded SQLite database - File-based or in-memory ```typescript -const repo = new SqliteTabularRepository< +const repo = new SqliteTabularStorage< typeof schema, typeof primaryKeys, Entity, // required if using TypeBox, Zod, etc, otherwise automatically created @@ -196,7 +196,7 @@ const repo = new SqliteTabularRepository< ); ``` -### PostgresTabularRepository +### PostgresTabularStorage - PostgreSQL backend - Connection pooling support @@ -207,7 +207,7 @@ import { Pool } from "pg"; const pool = new Pool({ /* config */ }); -const repo = new PostgresTabularRepository< +const repo = new PostgresTabularStorage< typeof schema, typeof primaryKeys, Entity, // required if using TypeBox, Zod, etc, otherwise automatically created diff --git a/packages/storage/src/tabular/SharedInMemoryTabularRepository.ts b/packages/storage/src/tabular/SharedInMemoryTabularStorage.ts similarity index 92% rename from packages/storage/src/tabular/SharedInMemoryTabularRepository.ts rename to packages/storage/src/tabular/SharedInMemoryTabularStorage.ts index 41058df6..84217d7d 100644 --- a/packages/storage/src/tabular/SharedInMemoryTabularRepository.ts +++ b/packages/storage/src/tabular/SharedInMemoryTabularStorage.ts @@ -10,16 +10,16 @@ import { FromSchema, TypedArraySchemaOptions, } from "@workglow/util"; -import { BaseTabularRepository } from "./BaseTabularRepository"; +import { BaseTabularStorage } from "./BaseTabularStorage"; import { - AnyTabularRepository, + AnyTabularStorage, DeleteSearchCriteria, SimplifyPrimaryKey, TabularSubscribeOptions, -} from "./ITabularRepository"; -import { InMemoryTabularRepository } from "./InMemoryTabularRepository"; +} from "./ITabularStorage"; +import { InMemoryTabularStorage } from "./InMemoryTabularStorage"; -export const SHARED_IN_MEMORY_TABULAR_REPOSITORY = createServiceToken( +export const SHARED_IN_MEMORY_TABULAR_REPOSITORY = createServiceToken( "storage.tabularRepository.sharedInMemory" ); @@ -37,27 +37,27 @@ type BroadcastMessage = /** * A tabular repository implementation that shares data across browser tabs/windows - * using BroadcastChannel API. 
Uses InMemoryTabularRepository internally and + * using BroadcastChannel API. Uses InMemoryTabularStorage internally and * synchronizes changes across all instances. * * @template Schema - The schema definition for the entity using JSON Schema * @template PrimaryKeyNames - Array of property names that form the primary key */ -export class SharedInMemoryTabularRepository< +export class SharedInMemoryTabularStorage< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, -> extends BaseTabularRepository { +> extends BaseTabularStorage { private channel: BroadcastChannel | null = null; private channelName: string; - private inMemoryRepo: InMemoryTabularRepository; + private inMemoryRepo: InMemoryTabularStorage; private isInitialized = false; private syncInProgress = false; /** - * Creates a new SharedInMemoryTabularRepository instance + * Creates a new SharedInMemoryTabularStorage instance * @param channelName - Unique name for the BroadcastChannel (defaults to "tabular_store") * @param schema - Schema defining the structure of the entity * @param primaryKeyNames - Array of property names that form the primary key @@ -72,7 +72,7 @@ export class SharedInMemoryTabularRepository< ) { super(schema, primaryKeyNames, indexes); this.channelName = channelName; - this.inMemoryRepo = new InMemoryTabularRepository( + this.inMemoryRepo = new InMemoryTabularStorage( schema, primaryKeyNames, indexes @@ -115,7 +115,7 @@ export class SharedInMemoryTabularRepository< } /** - * Sets up event forwarding from the internal InMemoryTabularRepository + * Sets up event forwarding from the internal InMemoryTabularStorage */ private setupEventForwarding(): void { this.inMemoryRepo.on("put", (entity) => { @@ -333,7 +333,7 @@ export class SharedInMemoryTabularRepository< /** * Subscribes to changes in the repository. - * Delegates to the internal InMemoryTabularRepository which monitors local changes. + * Delegates to the internal InMemoryTabularStorage which monitors local changes. * Changes from other tabs/windows are already propagated via BroadcastChannel. 
* * @param callback - Function called when a change occurs diff --git a/packages/storage/src/tabular/SqliteTabularRepository.ts b/packages/storage/src/tabular/SqliteTabularStorage.ts similarity index 98% rename from packages/storage/src/tabular/SqliteTabularRepository.ts rename to packages/storage/src/tabular/SqliteTabularStorage.ts index 331c71e8..4b53d0b6 100644 --- a/packages/storage/src/tabular/SqliteTabularRepository.ts +++ b/packages/storage/src/tabular/SqliteTabularStorage.ts @@ -12,9 +12,9 @@ import { JsonSchema, TypedArraySchemaOptions, } from "@workglow/util"; -import { BaseSqlTabularRepository } from "./BaseSqlTabularRepository"; +import { BaseSqlTabularStorage } from "./BaseSqlTabularStorage"; import { - AnyTabularRepository, + AnyTabularStorage, DeleteSearchCriteria, isSearchCondition, SearchOperator, @@ -22,18 +22,18 @@ import { TabularChangePayload, TabularSubscribeOptions, ValueOptionType, -} from "./ITabularRepository"; +} from "./ITabularStorage"; // Define local type for SQL operations type ExcludeDateKeyOptionType = Exclude; -export const SQLITE_TABULAR_REPOSITORY = createServiceToken( +export const SQLITE_TABULAR_REPOSITORY = createServiceToken( "storage.tabularRepository.sqlite" ); const Database = Sqlite.Database; -// SqliteTabularRepository is a key-value store that uses SQLite as the backend for +// SqliteTabularStorage is a key-value store that uses SQLite as the backend for // in app data. /** @@ -41,13 +41,13 @@ const Database = Sqlite.Database; * @template Schema - The schema definition for the entity * @template PrimaryKeyNames - Array of property names that form the primary key */ -export class SqliteTabularRepository< +export class SqliteTabularStorage< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, -> extends BaseSqlTabularRepository { +> extends BaseSqlTabularStorage { /** The SQLite database instance */ private db: Sqlite.Database; @@ -738,7 +738,7 @@ export class SqliteTabularRepository< callback: (change: TabularChangePayload) => void, options?: TabularSubscribeOptions ): () => void { - throw new Error("subscribeToChanges is not supported for SqliteTabularRepository"); + throw new Error("subscribeToChanges is not supported for SqliteTabularStorage"); } /** diff --git a/packages/storage/src/tabular/SupabaseTabularRepository.ts b/packages/storage/src/tabular/SupabaseTabularStorage.ts similarity index 98% rename from packages/storage/src/tabular/SupabaseTabularRepository.ts rename to packages/storage/src/tabular/SupabaseTabularStorage.ts index f4d433dd..95bb33d2 100644 --- a/packages/storage/src/tabular/SupabaseTabularRepository.ts +++ b/packages/storage/src/tabular/SupabaseTabularStorage.ts @@ -12,9 +12,9 @@ import { JsonSchema, TypedArraySchemaOptions, } from "@workglow/util"; -import { BaseSqlTabularRepository } from "./BaseSqlTabularRepository"; +import { BaseSqlTabularStorage } from "./BaseSqlTabularStorage"; import { - AnyTabularRepository, + AnyTabularStorage, DeleteSearchCriteria, isSearchCondition, SearchOperator, @@ -23,32 +23,32 @@ import { TabularChangeType, TabularSubscribeOptions, ValueOptionType, -} from "./ITabularRepository"; +} from "./ITabularStorage"; -export const SUPABASE_TABULAR_REPOSITORY = createServiceToken( +export const SUPABASE_TABULAR_REPOSITORY = createServiceToken( "storage.tabularRepository.supabase" ); /** - * A Supabase-based tabular repository implementation that extends BaseSqlTabularRepository. 
+ * A Supabase-based tabular repository implementation that extends BaseSqlTabularStorage. * This class provides persistent storage for data in a Supabase database, * making it suitable for multi-user scenarios. * * @template Schema - The schema definition for the entity * @template PrimaryKeyNames - Array of property names that form the primary key */ -export class SupabaseTabularRepository< +export class SupabaseTabularStorage< Schema extends DataPortSchemaObject, PrimaryKeyNames extends ReadonlyArray, // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, -> extends BaseSqlTabularRepository { +> extends BaseSqlTabularStorage { private client: SupabaseClient; private realtimeChannel: RealtimeChannel | null = null; /** - * Creates a new SupabaseTabularRepository instance. + * Creates a new SupabaseTabularStorage instance. * * @param client - Supabase client instance * @param table - Name of the table to store data (defaults to "tabular_store") diff --git a/packages/storage/src/tabular/TabularRepositoryRegistry.ts b/packages/storage/src/tabular/TabularStorageRegistry.ts similarity index 84% rename from packages/storage/src/tabular/TabularRepositoryRegistry.ts rename to packages/storage/src/tabular/TabularStorageRegistry.ts index 96db1972..c00ea82d 100644 --- a/packages/storage/src/tabular/TabularRepositoryRegistry.ts +++ b/packages/storage/src/tabular/TabularStorageRegistry.ts @@ -10,13 +10,13 @@ import { registerInputResolver, ServiceRegistry, } from "@workglow/util"; -import { AnyTabularRepository } from "./ITabularRepository"; +import { AnyTabularStorage } from "./ITabularStorage"; /** * Service token for the tabular repository registry - * Maps repository IDs to ITabularRepository instances + * Maps repository IDs to ITabularStorage instances */ -export const TABULAR_REPOSITORIES = createServiceToken>( +export const TABULAR_REPOSITORIES = createServiceToken>( "storage.tabular.repositories" ); @@ -24,7 +24,7 @@ export const TABULAR_REPOSITORIES = createServiceToken => new Map(), + (): Map => new Map(), true ); } @@ -33,7 +33,7 @@ if (!globalServiceRegistry.has(TABULAR_REPOSITORIES)) { * Gets the global tabular repository registry * @returns Map of tabular repository ID to instance */ -export function getGlobalTabularRepositories(): Map { +export function getGlobalTabularRepositories(): Map { return globalServiceRegistry.get(TABULAR_REPOSITORIES); } @@ -42,7 +42,7 @@ export function getGlobalTabularRepositories(): Map; diff --git a/packages/task-graph/src/storage/TaskOutputTabularRepository.ts b/packages/task-graph/src/storage/TaskOutputTabularRepository.ts index cf7a5789..7d109061 100644 --- a/packages/task-graph/src/storage/TaskOutputTabularRepository.ts +++ b/packages/task-graph/src/storage/TaskOutputTabularRepository.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { type BaseTabularRepository } from "@workglow/storage"; +import { type BaseTabularStorage } from "@workglow/storage"; import { compress, DataPortSchemaObject, decompress, makeFingerprint } from "@workglow/util"; import { TaskInput, TaskOutput } from "../task/TaskTypes"; import { TaskOutputRepository } from "./TaskOutputRepository"; @@ -27,7 +27,7 @@ export const TaskOutputSchema = { export const TaskOutputPrimaryKeyNames = ["key", "taskType"] as const; -export type TaskOutputRepositoryStorage = BaseTabularRepository< +export type TaskOutputRepositoryStorage = BaseTabularStorage< typeof TaskOutputSchema, typeof TaskOutputPrimaryKeyNames >; diff --git 
a/packages/task-graph/src/task/README.md b/packages/task-graph/src/task/README.md index 8c5489da..13a88491 100644 --- a/packages/task-graph/src/task/README.md +++ b/packages/task-graph/src/task/README.md @@ -252,7 +252,7 @@ This resolution happens automatically before `validateInput()` is called, so by import { Task } from "@workglow/task-graph"; import { TypeTabularRepository } from "@workglow/storage"; -class DataProcessingTask extends Task<{ repository: ITabularRepository; query: string }> { +class DataProcessingTask extends Task<{ repository: ITabularStorage; query: string }> { static readonly type = "DataProcessingTask"; static inputSchema() { @@ -270,7 +270,7 @@ class DataProcessingTask extends Task<{ repository: ITabularRepository; query: s } async execute(input: DataProcessingTaskInput, context: IExecuteContext) { - // repository is guaranteed to be an ITabularRepository instance + // repository is guaranteed to be an ITabularStorage instance const data = await input.repository.getAll(); return { results: data }; } diff --git a/packages/test/src/binding/FsFolderTaskGraphRepository.ts b/packages/test/src/binding/FsFolderTaskGraphRepository.ts index d34df738..14982744 100644 --- a/packages/test/src/binding/FsFolderTaskGraphRepository.ts +++ b/packages/test/src/binding/FsFolderTaskGraphRepository.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { FsFolderTabularRepository } from "@workglow/storage"; +import { FsFolderTabularStorage } from "@workglow/storage"; import { TaskGraphPrimaryKeyNames, TaskGraphSchema, @@ -23,7 +23,7 @@ export const FS_FOLDER_TASK_GRAPH_REPOSITORY = createServiceToken { - let repo: InMemoryChunkVectorRepository; + let repo: InMemoryChunkVectorStorage; beforeEach(async () => { - repo = new InMemoryChunkVectorRepository(3); + repo = new InMemoryChunkVectorStorage(3); await repo.setupDatabase(); // Populate repository with test data diff --git a/packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts b/packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts index 095827e0..e1fd3526 100644 --- a/packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts +++ b/packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts @@ -5,14 +5,14 @@ */ import { ChunkVectorSearchTask } from "@workglow/ai"; -import { InMemoryChunkVectorRepository, registerChunkVectorRepository } from "@workglow/storage"; +import { InMemoryChunkVectorStorage, registerChunkVectorRepository } from "@workglow/storage"; import { afterEach, beforeEach, describe, expect, test } from "vitest"; describe("ChunkVectorSearchTask", () => { - let repo: InMemoryChunkVectorRepository; + let repo: InMemoryChunkVectorStorage; beforeEach(async () => { - repo = new InMemoryChunkVectorRepository(3); + repo = new InMemoryChunkVectorStorage(3); await repo.setupDatabase(); // Populate repository with test data @@ -185,7 +185,7 @@ describe("ChunkVectorSearchTask", () => { }); test("should handle empty repository", async () => { - const emptyRepo = new InMemoryChunkVectorRepository(3); + const emptyRepo = new InMemoryChunkVectorStorage(3); await emptyRepo.setupDatabase(); const queryVector = new Float32Array([1.0, 0.0, 0.0]); diff --git a/packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts b/packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts index c42ec1d3..72c5c6b8 100644 --- a/packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts +++ b/packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts @@ -5,14 +5,14 @@ */ 
import { ChunkVectorUpsertTask } from "@workglow/ai"; -import { InMemoryChunkVectorRepository, registerChunkVectorRepository } from "@workglow/storage"; +import { InMemoryChunkVectorStorage, registerChunkVectorRepository } from "@workglow/storage"; import { afterEach, beforeEach, describe, expect, test } from "vitest"; describe("ChunkVectorUpsertTask", () => { - let repo: InMemoryChunkVectorRepository; + let repo: InMemoryChunkVectorStorage; beforeEach(async () => { - repo = new InMemoryChunkVectorRepository(3); + repo = new InMemoryChunkVectorStorage(3); await repo.setupDatabase(); }); @@ -34,10 +34,10 @@ describe("ChunkVectorUpsertTask", () => { expect(result.count).toBe(1); expect(result.doc_id).toBe("doc1"); - expect(result.ids).toHaveLength(1); + expect(result.chunk_ids).toHaveLength(1); // Verify vector was stored - const retrieved = await repo.get({ chunk_id: result.ids[0] }); + const retrieved = await repo.get({ chunk_id: result.chunk_ids[0] }); expect(retrieved).toBeDefined(); expect(retrieved?.doc_id).toBe("doc1"); expect(retrieved!.metadata).toEqual(metadata); @@ -61,11 +61,11 @@ describe("ChunkVectorUpsertTask", () => { expect(result.count).toBe(3); expect(result.doc_id).toBe("doc1"); - expect(result.ids).toHaveLength(3); + expect(result.chunk_ids).toHaveLength(3); // Verify all vectors were stored for (let i = 0; i < 3; i++) { - const retrieved = await repo.get({ chunk_id: result.ids[i] }); + const retrieved = await repo.get({ chunk_id: result.chunk_ids[i] }); expect(retrieved).toBeDefined(); expect(retrieved?.doc_id).toBe("doc1"); expect(retrieved!.metadata).toEqual(metadata); @@ -87,7 +87,7 @@ describe("ChunkVectorUpsertTask", () => { expect(result.count).toBe(1); expect(result.doc_id).toBe("doc1"); - const retrieved = await repo.get({ chunk_id: result.ids[0] }); + const retrieved = await repo.get({ chunk_id: result.chunk_ids[0] }); expect(retrieved).toBeDefined(); expect(retrieved!.metadata).toEqual(metadata); }); @@ -116,7 +116,7 @@ describe("ChunkVectorUpsertTask", () => { metadata: metadata2, }); - const retrieved = await repo.get({ chunk_id: result2.ids[0] }); + const retrieved = await repo.get({ chunk_id: result2.chunk_ids[0] }); expect(retrieved).toBeDefined(); expect(retrieved!.metadata).toEqual(metadata2); }); @@ -151,7 +151,7 @@ describe("ChunkVectorUpsertTask", () => { expect(result.count).toBe(1); - const retrieved = await repo.get({ chunk_id: result.ids[0] }); + const retrieved = await repo.get({ chunk_id: result.chunk_ids[0] }); expect(retrieved).toBeDefined(); expect(retrieved?.vector).toBeInstanceOf(Int8Array); }); @@ -170,7 +170,7 @@ describe("ChunkVectorUpsertTask", () => { expect(result.count).toBe(1); - const retrieved = await repo.get({ chunk_id: result.ids[0] }); + const retrieved = await repo.get({ chunk_id: result.chunk_ids[0] }); expect(retrieved!.metadata).toEqual(metadata); }); @@ -191,7 +191,7 @@ describe("ChunkVectorUpsertTask", () => { }); expect(result.count).toBe(count); - expect(result.ids).toHaveLength(count); + expect(result.chunk_ids).toHaveLength(count); const size = await repo.size(); expect(size).toBe(count); @@ -217,7 +217,7 @@ describe("ChunkVectorUpsertTask", () => { expect(result.doc_id).toBe("doc1"); // Verify vector was stored - const retrieved = await repo.get({ chunk_id: result.ids[0] }); + const retrieved = await repo.get({ chunk_id: result.chunk_ids[0] }); expect(retrieved).toBeDefined(); expect(retrieved?.doc_id).toBe("doc1"); expect(retrieved!.metadata).toEqual(metadata); diff --git 
a/packages/test/src/test/rag/DocumentRepository.test.ts b/packages/test/src/test/rag/DocumentRepository.test.ts index 66a3c695..64f4ec92 100644 --- a/packages/test/src/test/rag/DocumentRepository.test.ts +++ b/packages/test/src/test/rag/DocumentRepository.test.ts @@ -9,8 +9,8 @@ import { DocumentRepository, DocumentStorageKey, DocumentStorageSchema, - InMemoryChunkVectorRepository, - InMemoryTabularRepository, + InMemoryChunkVectorStorage, + InMemoryTabularStorage, NodeIdGenerator, NodeKind, StructuralParser, @@ -19,16 +19,16 @@ import { beforeEach, describe, expect, it } from "vitest"; describe("DocumentRepository", () => { let repo: DocumentRepository; - let vectorStorage: InMemoryChunkVectorRepository; + let vectorStorage: InMemoryChunkVectorStorage; beforeEach(async () => { - const tabularStorage = new InMemoryTabularRepository( + const tabularStorage = new InMemoryTabularStorage( DocumentStorageSchema, DocumentStorageKey ); await tabularStorage.setupDatabase(); - vectorStorage = new InMemoryChunkVectorRepository(3); + vectorStorage = new InMemoryChunkVectorStorage(3); await vectorStorage.setupDatabase(); repo = new DocumentRepository(tabularStorage, vectorStorage); @@ -205,7 +205,7 @@ Paragraph.`; it("should return empty list for empty repository", async () => { // Create fresh empty repo - const tabularStorage = new InMemoryTabularRepository( + const tabularStorage = new InMemoryTabularStorage( DocumentStorageSchema, DocumentStorageKey ); @@ -341,7 +341,7 @@ Paragraph.`; }); it("should return empty array for search when no vector storage configured", async () => { - const tabularStorage = new InMemoryTabularRepository( + const tabularStorage = new InMemoryTabularStorage( DocumentStorageSchema, DocumentStorageKey ); diff --git a/packages/test/src/test/rag/EndToEnd.test.ts b/packages/test/src/test/rag/EndToEnd.test.ts index ecbaf1e7..3fac3586 100644 --- a/packages/test/src/test/rag/EndToEnd.test.ts +++ b/packages/test/src/test/rag/EndToEnd.test.ts @@ -10,8 +10,8 @@ import { DocumentRepository, DocumentStorageKey, DocumentStorageSchema, - InMemoryChunkVectorRepository, - InMemoryTabularRepository, + InMemoryChunkVectorStorage, + InMemoryTabularStorage, NodeIdGenerator, StructuralParser, } from "@workglow/storage"; @@ -82,13 +82,13 @@ Finds patterns in data.`; }); it("should demonstrate document repository integration", async () => { - const tabularStorage = new InMemoryTabularRepository( + const tabularStorage = new InMemoryTabularStorage( DocumentStorageSchema, DocumentStorageKey ); await tabularStorage.setupDatabase(); - const vectorStorage = new InMemoryChunkVectorRepository(3); + const vectorStorage = new InMemoryChunkVectorStorage(3); await vectorStorage.setupDatabase(); const docRepo = new DocumentRepository(tabularStorage, vectorStorage); diff --git a/packages/test/src/test/rag/FullChain.test.ts b/packages/test/src/test/rag/FullChain.test.ts index cb4d61e2..0ee22153 100644 --- a/packages/test/src/test/rag/FullChain.test.ts +++ b/packages/test/src/test/rag/FullChain.test.ts @@ -5,13 +5,13 @@ */ import { HierarchicalChunkerTaskOutput } from "@workglow/ai"; -import { ChunkNode, InMemoryChunkVectorRepository, NodeIdGenerator } from "@workglow/storage"; +import { ChunkNode, InMemoryChunkVectorStorage, NodeIdGenerator } from "@workglow/storage"; import { Workflow } from "@workglow/task-graph"; import { describe, expect, it } from "vitest"; describe("Complete chainable workflow", () => { it("should chain from parsing to storage without loops", async () => { - const vectorRepo = new 
InMemoryChunkVectorRepository(3); + const vectorRepo = new InMemoryChunkVectorStorage(3); await vectorRepo.setupDatabase(); const markdown = `# Test Document diff --git a/packages/test/src/test/rag/HybridSearchTask.test.ts b/packages/test/src/test/rag/HybridSearchTask.test.ts index 89c1ef4b..79803922 100644 --- a/packages/test/src/test/rag/HybridSearchTask.test.ts +++ b/packages/test/src/test/rag/HybridSearchTask.test.ts @@ -5,14 +5,14 @@ */ import { hybridSearch } from "@workglow/ai"; -import { InMemoryChunkVectorRepository, registerChunkVectorRepository } from "@workglow/storage"; +import { InMemoryChunkVectorStorage, registerChunkVectorRepository } from "@workglow/storage"; import { afterEach, beforeEach, describe, expect, test } from "vitest"; describe("ChunkVectorHybridSearchTask", () => { - let repo: InMemoryChunkVectorRepository; + let repo: InMemoryChunkVectorStorage; beforeEach(async () => { - repo = new InMemoryChunkVectorRepository(3); + repo = new InMemoryChunkVectorStorage(3); await repo.setupDatabase(); // Populate repository with test data diff --git a/packages/test/src/test/rag/RagWorkflow.test.ts b/packages/test/src/test/rag/RagWorkflow.test.ts index ac595af7..412033db 100644 --- a/packages/test/src/test/rag/RagWorkflow.test.ts +++ b/packages/test/src/test/rag/RagWorkflow.test.ts @@ -50,8 +50,8 @@ import { DocumentRepository, DocumentStorageKey, DocumentStorageSchema, - InMemoryChunkVectorRepository, - InMemoryTabularRepository, + InMemoryChunkVectorStorage, + InMemoryTabularStorage, registerChunkVectorRepository, } from "@workglow/storage"; import { getTaskQueueRegistry, setTaskQueueRegistry, Workflow } from "@workglow/task-graph"; @@ -62,7 +62,7 @@ import { registerHuggingfaceLocalModels } from "../../samples"; export { FileLoaderTask } from "@workglow/tasks"; describe("RAG Workflow End-to-End", () => { - let vectorRepo: InMemoryChunkVectorRepository; + let vectorRepo: InMemoryChunkVectorStorage; let docRepo: DocumentRepository; const vectorRepoName = "rag-test-vector-repo"; const embeddingModel = "onnx:Xenova/all-MiniLM-L6-v2:q8"; @@ -79,13 +79,13 @@ describe("RAG Workflow End-to-End", () => { await registerHuggingfaceLocalModels(); // Setup repositories - vectorRepo = new InMemoryChunkVectorRepository(3); + vectorRepo = new InMemoryChunkVectorStorage(3); await vectorRepo.setupDatabase(); // Register vector repository for use in workflows registerChunkVectorRepository(vectorRepoName, vectorRepo); - const tabularRepo = new InMemoryTabularRepository(DocumentStorageSchema, DocumentStorageKey); + const tabularRepo = new InMemoryTabularStorage(DocumentStorageSchema, DocumentStorageKey); await tabularRepo.setupDatabase(); docRepo = new DocumentRepository(tabularRepo, vectorRepo); diff --git a/packages/test/src/test/storage-kv/FsFolderJsonKvRepository.test.ts b/packages/test/src/test/storage-kv/FsFolderJsonKvRepository.test.ts index f71d6c60..0e187a6f 100644 --- a/packages/test/src/test/storage-kv/FsFolderJsonKvRepository.test.ts +++ b/packages/test/src/test/storage-kv/FsFolderJsonKvRepository.test.ts @@ -4,14 +4,14 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { FsFolderJsonKvRepository } from "@workglow/storage"; +import { FsFolderJsonKvStorage } from "@workglow/storage"; import { mkdirSync, rmSync } from "fs"; import { afterEach, beforeEach, describe } from "vitest"; import { runGenericKvRepositoryTests } from "./genericKvRepositoryTests"; const testDir = ".cache/test/kv-fs-folder-json"; -describe("FsFolderJsonKvRepository", () => { +describe("FsFolderJsonKvStorage", 
() => { beforeEach(() => { try { mkdirSync(testDir, { recursive: true }); @@ -25,6 +25,6 @@ describe("FsFolderJsonKvRepository", () => { }); runGenericKvRepositoryTests(async (keyType, valueType) => { - return new FsFolderJsonKvRepository(testDir, keyType, valueType); + return new FsFolderJsonKvStorage(testDir, keyType, valueType); }); }); diff --git a/packages/test/src/test/storage-kv/FsFolderKvRepository.test.ts b/packages/test/src/test/storage-kv/FsFolderKvRepository.test.ts index d2986dad..c446488a 100644 --- a/packages/test/src/test/storage-kv/FsFolderKvRepository.test.ts +++ b/packages/test/src/test/storage-kv/FsFolderKvRepository.test.ts @@ -4,14 +4,14 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { FsFolderKvRepository } from "@workglow/storage"; +import { FsFolderKvStorage } from "@workglow/storage"; import { mkdirSync, rmSync } from "fs"; import { afterEach, beforeEach, describe } from "vitest"; import { runGenericKvRepositoryTests } from "./genericKvRepositoryTests"; const testDir = ".cache/test/kv-fs-folder"; -describe("FsFolderKvRepository", () => { +describe("FsFolderKvStorage", () => { beforeEach(() => { try { mkdirSync(testDir, { recursive: true }); @@ -30,7 +30,7 @@ describe("FsFolderKvRepository", () => { typeof valueType === "object" && valueType !== null && "type" in valueType ? String(valueType.type) : "data"; - return new FsFolderKvRepository( + return new FsFolderKvStorage( testDir, (key) => `${String(key)}.${schemaType}`, keyType, diff --git a/packages/test/src/test/storage-kv/InMemoryKvRepository.test.ts b/packages/test/src/test/storage-kv/InMemoryKvRepository.test.ts index fc064063..e1209e46 100644 --- a/packages/test/src/test/storage-kv/InMemoryKvRepository.test.ts +++ b/packages/test/src/test/storage-kv/InMemoryKvRepository.test.ts @@ -4,12 +4,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { InMemoryKvRepository } from "@workglow/storage"; +import { InMemoryKvStorage } from "@workglow/storage"; import { describe } from "vitest"; import { runGenericKvRepositoryTests } from "./genericKvRepositoryTests"; -describe("InMemoryKvRepository", () => { +describe("InMemoryKvStorage", () => { runGenericKvRepositoryTests( - async (keyType, valueType) => new InMemoryKvRepository(keyType, valueType) + async (keyType, valueType) => new InMemoryKvStorage(keyType, valueType) ); }); diff --git a/packages/test/src/test/storage-kv/IndexedDbKvRepository.test.ts b/packages/test/src/test/storage-kv/IndexedDbKvRepository.test.ts index 01da5907..fb267bd6 100644 --- a/packages/test/src/test/storage-kv/IndexedDbKvRepository.test.ts +++ b/packages/test/src/test/storage-kv/IndexedDbKvRepository.test.ts @@ -4,16 +4,16 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { IndexedDbKvRepository } from "@workglow/storage"; +import { IndexedDbKvStorage } from "@workglow/storage"; import { uuid4 } from "@workglow/util"; import "fake-indexeddb/auto"; import { describe } from "vitest"; import { runGenericKvRepositoryTests } from "./genericKvRepositoryTests"; -describe("IndexedDbKvRepository", () => { +describe("IndexedDbKvStorage", () => { const dbName = `idx_test_${uuid4().replace(/-/g, "_")}`; runGenericKvRepositoryTests( - async (keyType, valueType) => new IndexedDbKvRepository(`${dbName}`, keyType, valueType) + async (keyType, valueType) => new IndexedDbKvStorage(`${dbName}`, keyType, valueType) ); }); diff --git a/packages/test/src/test/storage-kv/PostgresKvRepository.test.ts b/packages/test/src/test/storage-kv/PostgresKvRepository.test.ts index 32b2d9dd..b891e3be 100644 --- 
a/packages/test/src/test/storage-kv/PostgresKvRepository.test.ts +++ b/packages/test/src/test/storage-kv/PostgresKvRepository.test.ts @@ -5,7 +5,7 @@ */ import { PGlite } from "@electric-sql/pglite"; -import { PostgresKvRepository } from "@workglow/storage"; +import { PostgresKvStorage } from "@workglow/storage"; import { uuid4 } from "@workglow/util"; import type { Pool } from "pg"; import { describe } from "vitest"; @@ -13,9 +13,9 @@ import { runGenericKvRepositoryTests } from "./genericKvRepositoryTests"; const db = new PGlite() as unknown as Pool; -describe("PostgresKvRepository", () => { +describe("PostgresKvStorage", () => { runGenericKvRepositoryTests(async (keyType, valueType) => { const dbName = `pg_test_${uuid4().replace(/-/g, "_")}`; - return new PostgresKvRepository(db, dbName, keyType, valueType); + return new PostgresKvStorage(db, dbName, keyType, valueType); }); }); diff --git a/packages/test/src/test/storage-kv/SqliteKvRepository.test.ts b/packages/test/src/test/storage-kv/SqliteKvRepository.test.ts index 1b5792cc..50aa8423 100644 --- a/packages/test/src/test/storage-kv/SqliteKvRepository.test.ts +++ b/packages/test/src/test/storage-kv/SqliteKvRepository.test.ts @@ -4,15 +4,15 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { SqliteKvRepository } from "@workglow/storage"; +import { SqliteKvStorage } from "@workglow/storage"; import { uuid4 } from "@workglow/util"; import { describe } from "vitest"; import { runGenericKvRepositoryTests } from "./genericKvRepositoryTests"; -describe("SqliteKvRepository", () => { +describe("SqliteKvStorage", () => { runGenericKvRepositoryTests( async (keyType, valueType) => - new SqliteKvRepository( + new SqliteKvStorage( ":memory:", `sql_test_${uuid4().replace(/-/g, "_")}`, keyType, diff --git a/packages/test/src/test/storage-kv/SupabaseKvRepository.test.ts b/packages/test/src/test/storage-kv/SupabaseKvRepository.test.ts index 1f750a58..c42c2259 100644 --- a/packages/test/src/test/storage-kv/SupabaseKvRepository.test.ts +++ b/packages/test/src/test/storage-kv/SupabaseKvRepository.test.ts @@ -7,24 +7,24 @@ import { DefaultKeyValueKey, DefaultKeyValueSchema, - SupabaseKvRepository, - SupabaseTabularRepository, + SupabaseKvStorage, + SupabaseTabularStorage, } from "@workglow/storage"; import { uuid4 } from "@workglow/util"; import { describe } from "vitest"; import { createSupabaseMockClient } from "../helpers/SupabaseMockClient"; import { runGenericKvRepositoryTests } from "./genericKvRepositoryTests"; -describe("SupabaseKvRepository", () => { +describe("SupabaseKvStorage", () => { const client = createSupabaseMockClient(); runGenericKvRepositoryTests(async (keyType, valueType) => { const tableName = `supabase_test_${uuid4().replace(/-/g, "_")}`; - return new SupabaseKvRepository( + return new SupabaseKvStorage( client, tableName, keyType, valueType, - new SupabaseTabularRepository(client, tableName, DefaultKeyValueSchema, DefaultKeyValueKey) + new SupabaseTabularStorage(client, tableName, DefaultKeyValueSchema, DefaultKeyValueKey) ); }); }); diff --git a/packages/test/src/test/storage-kv/genericKvRepositoryTests.ts b/packages/test/src/test/storage-kv/genericKvRepositoryTests.ts index bdd5352c..a443a85c 100644 --- a/packages/test/src/test/storage-kv/genericKvRepositoryTests.ts +++ b/packages/test/src/test/storage-kv/genericKvRepositoryTests.ts @@ -4,15 +4,15 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { DefaultKeyValueSchema, IKvRepository } from "@workglow/storage"; +import { DefaultKeyValueSchema, IKvStorage } from 
"@workglow/storage"; import { FromSchema, JsonSchema } from "@workglow/util"; import { afterEach, beforeEach, describe, expect, it } from "vitest"; export function runGenericKvRepositoryTests( - createRepository: (keyType: JsonSchema, valueType: JsonSchema) => Promise> + createRepository: (keyType: JsonSchema, valueType: JsonSchema) => Promise> ) { describe("with default schemas (key and value)", () => { - let repository: IKvRepository< + let repository: IKvStorage< FromSchema, FromSchema >; @@ -64,7 +64,7 @@ export function runGenericKvRepositoryTests( }); describe("with json value", () => { - let repository: IKvRepository; + let repository: IKvStorage; beforeEach(async () => { repository = (await createRepository( @@ -77,7 +77,7 @@ export function runGenericKvRepositoryTests( }, additionalProperties: false, } - )) as IKvRepository; + )) as IKvStorage; await (repository as any).setupDatabase?.(); }); diff --git a/packages/test/src/test/storage-tabular/CachedTabularRepository.test.ts b/packages/test/src/test/storage-tabular/CachedTabularRepository.test.ts index 9b4b6522..bb4f6e1a 100644 --- a/packages/test/src/test/storage-tabular/CachedTabularRepository.test.ts +++ b/packages/test/src/test/storage-tabular/CachedTabularRepository.test.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CachedTabularRepository, InMemoryTabularRepository } from "@workglow/storage"; +import { CachedTabularStorage, InMemoryTabularStorage } from "@workglow/storage"; import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; import { runGenericTabularRepositorySubscriptionTests } from "./genericTabularRepositorySubscriptionTests"; import { @@ -17,15 +17,15 @@ import { const spyOn = vi.spyOn; -describe("CachedTabularRepository", () => { +describe("CachedTabularStorage", () => { describe("generic repository tests", () => { runGenericTabularRepositoryTests( async () => { - const durable = new InMemoryTabularRepository< + const durable = new InMemoryTabularStorage< typeof CompoundSchema, typeof CompoundPrimaryKeyNames >(CompoundSchema, CompoundPrimaryKeyNames); - return new CachedTabularRepository( + return new CachedTabularStorage( durable, undefined, CompoundSchema, @@ -33,7 +33,7 @@ describe("CachedTabularRepository", () => { ); }, async () => { - const durable = new InMemoryTabularRepository< + const durable = new InMemoryTabularStorage< typeof SearchSchema, typeof SearchPrimaryKeyNames >(SearchSchema, SearchPrimaryKeyNames, [ @@ -42,7 +42,7 @@ describe("CachedTabularRepository", () => { ["subcategory", "category"], "value", ]); - return new CachedTabularRepository( + return new CachedTabularStorage( durable, undefined, SearchSchema, @@ -55,11 +55,11 @@ describe("CachedTabularRepository", () => { runGenericTabularRepositorySubscriptionTests( async () => { - const durable = new InMemoryTabularRepository< + const durable = new InMemoryTabularStorage< typeof CompoundSchema, typeof CompoundPrimaryKeyNames >(CompoundSchema, CompoundPrimaryKeyNames); - return new CachedTabularRepository( + return new CachedTabularStorage( durable, undefined, CompoundSchema, @@ -70,15 +70,15 @@ describe("CachedTabularRepository", () => { ); describe("caching behavior", () => { - let durable: InMemoryTabularRepository; - let cached: CachedTabularRepository; + let durable: InMemoryTabularStorage; + let cached: CachedTabularStorage; beforeEach(() => { - durable = new InMemoryTabularRepository< + durable = new InMemoryTabularStorage< typeof CompoundSchema, typeof CompoundPrimaryKeyNames >(CompoundSchema, 
CompoundPrimaryKeyNames); - cached = new CachedTabularRepository( + cached = new CachedTabularStorage( durable, undefined, CompoundSchema, @@ -353,15 +353,15 @@ describe("CachedTabularRepository", () => { }); describe("cache management", () => { - let durable: InMemoryTabularRepository; - let cached: CachedTabularRepository; + let durable: InMemoryTabularStorage; + let cached: CachedTabularStorage; beforeEach(() => { - durable = new InMemoryTabularRepository< + durable = new InMemoryTabularStorage< typeof CompoundSchema, typeof CompoundPrimaryKeyNames >(CompoundSchema, CompoundPrimaryKeyNames); - cached = new CachedTabularRepository( + cached = new CachedTabularStorage( durable, undefined, CompoundSchema, @@ -428,12 +428,12 @@ describe("CachedTabularRepository", () => { it("should handle cache initialization errors gracefully", async () => { // Create a mock durable that throws on getAll - const errorDurable = new InMemoryTabularRepository< + const errorDurable = new InMemoryTabularStorage< typeof CompoundSchema, typeof CompoundPrimaryKeyNames >(CompoundSchema, CompoundPrimaryKeyNames); - const cachedWithError = new CachedTabularRepository< + const cachedWithError = new CachedTabularStorage< typeof CompoundSchema, typeof CompoundPrimaryKeyNames >(errorDurable, undefined, CompoundSchema, CompoundPrimaryKeyNames); @@ -450,15 +450,15 @@ describe("CachedTabularRepository", () => { }); describe("event forwarding", () => { - let durable: InMemoryTabularRepository; - let cached: CachedTabularRepository; + let durable: InMemoryTabularStorage; + let cached: CachedTabularStorage; beforeEach(() => { - durable = new InMemoryTabularRepository< + durable = new InMemoryTabularStorage< typeof CompoundSchema, typeof CompoundPrimaryKeyNames >(CompoundSchema, CompoundPrimaryKeyNames); - cached = new CachedTabularRepository( + cached = new CachedTabularStorage( durable, undefined, CompoundSchema, @@ -569,17 +569,17 @@ describe("CachedTabularRepository", () => { describe("custom cache repository", () => { it("should use provided cache repository", async () => { - const durable = new InMemoryTabularRepository< + const durable = new InMemoryTabularStorage< typeof CompoundSchema, typeof CompoundPrimaryKeyNames >(CompoundSchema, CompoundPrimaryKeyNames); - const customCache = new InMemoryTabularRepository< + const customCache = new InMemoryTabularStorage< typeof CompoundSchema, typeof CompoundPrimaryKeyNames >(CompoundSchema, CompoundPrimaryKeyNames); - const cached = new CachedTabularRepository< + const cached = new CachedTabularStorage< typeof CompoundSchema, typeof CompoundPrimaryKeyNames >(durable, customCache, CompoundSchema, CompoundPrimaryKeyNames); @@ -605,27 +605,27 @@ describe("CachedTabularRepository", () => { describe("constructor validation", () => { it("should throw error if schema and primaryKeyNames are not provided", () => { - const durable = new InMemoryTabularRepository< + const durable = new InMemoryTabularStorage< typeof CompoundSchema, typeof CompoundPrimaryKeyNames >(CompoundSchema, CompoundPrimaryKeyNames); expect(() => { - new CachedTabularRepository(durable); + new CachedTabularStorage(durable); }).toThrow("Schema and primaryKeyNames must be provided"); }); }); describe("deleteSearch", () => { - let durable: InMemoryTabularRepository; - let cached: CachedTabularRepository; + let durable: InMemoryTabularStorage; + let cached: CachedTabularStorage; beforeEach(() => { - durable = new InMemoryTabularRepository( + durable = new InMemoryTabularStorage( SearchSchema, SearchPrimaryKeyNames ); - 
cached = new CachedTabularRepository( + cached = new CachedTabularStorage( durable, undefined, SearchSchema, diff --git a/packages/test/src/test/storage-tabular/FsFolderTabularRepository.test.ts b/packages/test/src/test/storage-tabular/FsFolderTabularRepository.test.ts index a930074a..6cc0e618 100644 --- a/packages/test/src/test/storage-tabular/FsFolderTabularRepository.test.ts +++ b/packages/test/src/test/storage-tabular/FsFolderTabularRepository.test.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { FsFolderTabularRepository } from "@workglow/storage"; +import { FsFolderTabularStorage } from "@workglow/storage"; import { mkdirSync, rmSync } from "fs"; import { afterEach, beforeEach, describe, expect, test } from "vitest"; import { runGenericTabularRepositorySubscriptionTests } from "./genericTabularRepositorySubscriptionTests"; @@ -20,7 +20,7 @@ import { const testDir = ".cache/test/testing"; -describe("FsFolderTabularRepository", () => { +describe("FsFolderTabularStorage", () => { beforeEach(() => { try { mkdirSync(testDir, { recursive: true }); @@ -37,14 +37,14 @@ describe("FsFolderTabularRepository", () => { describe("basic functionality", () => { runGenericTabularRepositoryTests( async () => - new FsFolderTabularRepository( + new FsFolderTabularStorage( testDir, CompoundSchema, CompoundPrimaryKeyNames ), undefined, async () => - new FsFolderTabularRepository( + new FsFolderTabularStorage( testDir, AllTypesSchema, AllTypesPrimaryKeyNames @@ -54,7 +54,7 @@ describe("FsFolderTabularRepository", () => { runGenericTabularRepositorySubscriptionTests( async () => - new FsFolderTabularRepository( + new FsFolderTabularStorage( testDir, CompoundSchema, CompoundPrimaryKeyNames @@ -66,7 +66,7 @@ describe("FsFolderTabularRepository", () => { describe("search functionality", () => { test("should throw error when attempting to search", async () => { try { - const repo = new FsFolderTabularRepository< + const repo = new FsFolderTabularStorage< typeof SearchSchema, typeof SearchPrimaryKeyNames >(testDir, SearchSchema, SearchPrimaryKeyNames, [ diff --git a/packages/test/src/test/storage-tabular/InMemoryTabularRepository.test.ts b/packages/test/src/test/storage-tabular/InMemoryTabularRepository.test.ts index 6c293afa..0f7f0de3 100644 --- a/packages/test/src/test/storage-tabular/InMemoryTabularRepository.test.ts +++ b/packages/test/src/test/storage-tabular/InMemoryTabularRepository.test.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { InMemoryTabularRepository } from "@workglow/storage"; +import { InMemoryTabularStorage } from "@workglow/storage"; import { describe } from "vitest"; import { runGenericTabularRepositorySubscriptionTests } from "./genericTabularRepositorySubscriptionTests"; import { @@ -17,21 +17,21 @@ import { SearchSchema, } from "./genericTabularRepositoryTests"; -describe("InMemoryTabularRepository", () => { +describe("InMemoryTabularStorage", () => { runGenericTabularRepositoryTests( async () => - new InMemoryTabularRepository( + new InMemoryTabularStorage( CompoundSchema, CompoundPrimaryKeyNames ), async () => - new InMemoryTabularRepository( + new InMemoryTabularStorage( SearchSchema, SearchPrimaryKeyNames, ["category", ["category", "subcategory"], ["subcategory", "category"], "value"] ), async () => - new InMemoryTabularRepository( + new InMemoryTabularStorage( AllTypesSchema, AllTypesPrimaryKeyNames ) @@ -39,7 +39,7 @@ describe("InMemoryTabularRepository", () => { runGenericTabularRepositorySubscriptionTests( async () => - new 
InMemoryTabularRepository( + new InMemoryTabularStorage( CompoundSchema, CompoundPrimaryKeyNames ), diff --git a/packages/test/src/test/storage-tabular/IndexedDbTabularRepository.test.ts b/packages/test/src/test/storage-tabular/IndexedDbTabularRepository.test.ts index 6e07ec9d..63483bae 100644 --- a/packages/test/src/test/storage-tabular/IndexedDbTabularRepository.test.ts +++ b/packages/test/src/test/storage-tabular/IndexedDbTabularRepository.test.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { IndexedDbTabularRepository } from "@workglow/storage"; +import { IndexedDbTabularStorage } from "@workglow/storage"; import type { DataPortSchemaObject, FromSchema } from "@workglow/util"; import { uuid4 } from "@workglow/util"; import "fake-indexeddb/auto"; @@ -21,25 +21,25 @@ import { SearchSchema, } from "./genericTabularRepositoryTests"; -describe("IndexedDbTabularRepository", () => { +describe("IndexedDbTabularStorage", () => { const dbName = `idx_test_${uuid4().replace(/-/g, "_")}`; runGenericTabularRepositoryTests( async () => - new IndexedDbTabularRepository( + new IndexedDbTabularStorage( `${dbName}_complex`, CompoundSchema, CompoundPrimaryKeyNames ), async () => - new IndexedDbTabularRepository( + new IndexedDbTabularStorage( `${dbName}_compound`, SearchSchema, SearchPrimaryKeyNames, ["category", ["category", "subcategory"], ["subcategory", "category"], "value"] ), async () => { - const repo = new IndexedDbTabularRepository< + const repo = new IndexedDbTabularStorage< typeof AllTypesSchema, typeof AllTypesPrimaryKeyNames >(`${dbName}_alltypes`, AllTypesSchema, AllTypesPrimaryKeyNames); @@ -52,7 +52,7 @@ describe("IndexedDbTabularRepository", () => { async () => { // Use a unique database name for each test to avoid conflicts const uniqueDbName = `${dbName}_subscription_${Date.now()}_${Math.random().toString(36).slice(2)}`; - return new IndexedDbTabularRepository( + return new IndexedDbTabularStorage( uniqueDbName, CompoundSchema, CompoundPrimaryKeyNames, @@ -99,14 +99,14 @@ describe("IndexedDbTabularRepository", () => { type OptionalEntity = FromSchema; describe("with all required columns (efficient cursor-based)", () => { - let repo: IndexedDbTabularRepository< + let repo: IndexedDbTabularStorage< typeof RequiredColumnsSchema, typeof RequiredColumnsPK, RequiredEntity >; beforeEach(async () => { - repo = new IndexedDbTabularRepository( + repo = new IndexedDbTabularStorage( `${dbName}_required`, RequiredColumnsSchema, RequiredColumnsPK, @@ -192,7 +192,7 @@ describe("IndexedDbTabularRepository", () => { }); describe("with optional columns (full table scan)", () => { - let repo: IndexedDbTabularRepository< + let repo: IndexedDbTabularStorage< typeof OptionalColumnsSchema, typeof OptionalColumnsPK, OptionalEntity, @@ -200,7 +200,7 @@ describe("IndexedDbTabularRepository", () => { >; beforeEach(async () => { - repo = new IndexedDbTabularRepository( + repo = new IndexedDbTabularStorage( `${dbName}_optional`, OptionalColumnsSchema, OptionalColumnsPK, diff --git a/packages/test/src/test/storage-tabular/PostgresTabularRepository.test.ts b/packages/test/src/test/storage-tabular/PostgresTabularRepository.test.ts index 5bc46f37..4db15b82 100644 --- a/packages/test/src/test/storage-tabular/PostgresTabularRepository.test.ts +++ b/packages/test/src/test/storage-tabular/PostgresTabularRepository.test.ts @@ -5,7 +5,7 @@ */ import { PGlite } from "@electric-sql/pglite"; -import { PostgresTabularRepository } from "@workglow/storage"; +import { PostgresTabularStorage } from 
"@workglow/storage"; import { uuid4 } from "@workglow/util"; import type { Pool } from "pg"; import { describe } from "vitest"; @@ -21,17 +21,17 @@ import { const db = new PGlite() as unknown as Pool; -describe("PostgresTabularRepository", () => { +describe("PostgresTabularStorage", () => { runGenericTabularRepositoryTests( async () => - new PostgresTabularRepository( + new PostgresTabularStorage( db, `sql_test_${uuid4().replace(/-/g, "_")}`, CompoundSchema, CompoundPrimaryKeyNames ), async () => - new PostgresTabularRepository( + new PostgresTabularStorage( db, `sql_test_${uuid4().replace(/-/g, "_")}`, SearchSchema, @@ -39,7 +39,7 @@ describe("PostgresTabularRepository", () => { ["category", ["category", "subcategory"], ["subcategory", "category"], "value"] ), async () => { - const repo = new PostgresTabularRepository< + const repo = new PostgresTabularStorage< typeof AllTypesSchema, typeof AllTypesPrimaryKeyNames >( diff --git a/packages/test/src/test/storage-tabular/SqliteTabularRepository.test.ts b/packages/test/src/test/storage-tabular/SqliteTabularRepository.test.ts index dba32ad6..673cf40e 100644 --- a/packages/test/src/test/storage-tabular/SqliteTabularRepository.test.ts +++ b/packages/test/src/test/storage-tabular/SqliteTabularRepository.test.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { SqliteTabularRepository } from "@workglow/storage"; +import { SqliteTabularStorage } from "@workglow/storage"; import { uuid4 } from "@workglow/util"; import { describe } from "vitest"; import { @@ -17,17 +17,17 @@ import { SearchSchema, } from "./genericTabularRepositoryTests"; -describe("SqliteTabularRepository", () => { +describe("SqliteTabularStorage", () => { runGenericTabularRepositoryTests( async () => - new SqliteTabularRepository( + new SqliteTabularStorage( ":memory:", `sql_test_${uuid4().replace(/-/g, "_")}`, CompoundSchema, CompoundPrimaryKeyNames ), async () => - new SqliteTabularRepository( + new SqliteTabularStorage( ":memory:", `sql_test_${uuid4().replace(/-/g, "_")}`, SearchSchema, @@ -35,7 +35,7 @@ describe("SqliteTabularRepository", () => { ["category", ["category", "subcategory"], ["subcategory", "category"], "value"] ), async () => { - const repo = new SqliteTabularRepository< + const repo = new SqliteTabularStorage< typeof AllTypesSchema, typeof AllTypesPrimaryKeyNames >( diff --git a/packages/test/src/test/storage-tabular/SupabaseTabularRepository.test.ts b/packages/test/src/test/storage-tabular/SupabaseTabularRepository.test.ts index 3a4f507e..04a7833c 100644 --- a/packages/test/src/test/storage-tabular/SupabaseTabularRepository.test.ts +++ b/packages/test/src/test/storage-tabular/SupabaseTabularRepository.test.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { SupabaseTabularRepository } from "@workglow/storage"; +import { SupabaseTabularStorage } from "@workglow/storage"; import { uuid4 } from "@workglow/util"; import { describe } from "vitest"; import { createSupabaseMockClient } from "../helpers/SupabaseMockClient"; @@ -20,17 +20,17 @@ import { const client = createSupabaseMockClient(); -describe("SupabaseTabularRepository", () => { +describe("SupabaseTabularStorage", () => { runGenericTabularRepositoryTests( async () => - new SupabaseTabularRepository( + new SupabaseTabularStorage( client, `supabase_test_${uuid4().replace(/-/g, "_")}`, CompoundSchema, CompoundPrimaryKeyNames ), async () => - new SupabaseTabularRepository( + new SupabaseTabularStorage( client, `supabase_test_${uuid4().replace(/-/g, "_")}`, SearchSchema, @@ 
-38,7 +38,7 @@ describe("SupabaseTabularRepository", () => { ["category", ["category", "subcategory"], ["subcategory", "category"], "value"] ), async () => { - const repo = new SupabaseTabularRepository< + const repo = new SupabaseTabularStorage< typeof AllTypesSchema, typeof AllTypesPrimaryKeyNames >( diff --git a/packages/test/src/test/storage-tabular/genericTabularRepositorySubscriptionTests.ts b/packages/test/src/test/storage-tabular/genericTabularRepositorySubscriptionTests.ts index 6c6ca4b6..5da22fd9 100644 --- a/packages/test/src/test/storage-tabular/genericTabularRepositorySubscriptionTests.ts +++ b/packages/test/src/test/storage-tabular/genericTabularRepositorySubscriptionTests.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { ITabularRepository, TabularChangePayload } from "@workglow/storage"; +import { ITabularStorage, TabularChangePayload } from "@workglow/storage"; import { FromSchema, sleep } from "@workglow/util"; import { afterEach, beforeEach, describe, expect, it } from "vitest"; import { CompoundPrimaryKeyNames, CompoundSchema } from "./genericTabularRepositoryTests"; @@ -15,7 +15,7 @@ import { CompoundPrimaryKeyNames, CompoundSchema } from "./genericTabularReposit export function runGenericTabularRepositorySubscriptionTests( createRepository: () => Promise< - ITabularRepository + ITabularStorage >, options?: { /** Whether this repository implementation uses polling (needs longer waits) */ @@ -37,7 +37,7 @@ export function runGenericTabularRepositorySubscriptionTests( const initWaitTime = usesPolling ? Math.max(pollingIntervalMs * 4, 150) : 10; describe("Subscription Tests", () => { - let repository: ITabularRepository; + let repository: ITabularStorage; beforeEach(async () => { repository = await createRepository(); diff --git a/packages/test/src/test/storage-tabular/genericTabularRepositoryTests.ts b/packages/test/src/test/storage-tabular/genericTabularRepositoryTests.ts index ecaeb518..6c1a3f8d 100644 --- a/packages/test/src/test/storage-tabular/genericTabularRepositoryTests.ts +++ b/packages/test/src/test/storage-tabular/genericTabularRepositoryTests.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { ITabularRepository } from "@workglow/storage"; +import { ITabularStorage } from "@workglow/storage"; import { DataPortSchemaObject, FromSchema } from "@workglow/util"; import { afterEach, beforeEach, describe, expect, it } from "vitest"; @@ -79,17 +79,17 @@ export const AllTypesSchema = { export function runGenericTabularRepositoryTests( createCompoundPkRepository: () => Promise< - ITabularRepository + ITabularStorage >, createSearchableRepository?: () => Promise< - ITabularRepository + ITabularStorage >, createAllTypesRepository?: () => Promise< - ITabularRepository + ITabularStorage > ) { describe("with compound primary keys", () => { - let repository: ITabularRepository; + let repository: ITabularStorage; beforeEach(async () => { repository = await createCompoundPkRepository(); @@ -207,7 +207,7 @@ export function runGenericTabularRepositoryTests( // Only run compound index tests if createCompoundRepository is provided if (createSearchableRepository) { describe("with searchable indexes", () => { - let searchableRepo: ITabularRepository; + let searchableRepo: ITabularStorage; beforeEach(async () => { searchableRepo = await createSearchableRepository(); @@ -343,7 +343,7 @@ export function runGenericTabularRepositoryTests( }); describe(`deleteSearch tests`, () => { - let repository: ITabularRepository; + let repository: 
ITabularStorage; beforeEach(async () => { repository = await createSearchableRepository(); @@ -803,7 +803,7 @@ export function runGenericTabularRepositoryTests( }); describe("return value tests with timestamps", () => { - let repository: ITabularRepository; + let repository: ITabularStorage; beforeEach(async () => { repository = await createSearchableRepository(); @@ -983,7 +983,7 @@ export function runGenericTabularRepositoryTests( if (createAllTypesRepository) { describe("data type coverage", () => { type AllTypesRecord = FromSchema; - let repository: ITabularRepository; + let repository: ITabularStorage; beforeEach(async () => { repository = await createAllTypesRepository(); diff --git a/packages/test/src/test/storage-util/IndexedDbHybridSubscription.test.ts b/packages/test/src/test/storage-util/IndexedDbHybridSubscription.test.ts index 424427b1..91194d26 100644 --- a/packages/test/src/test/storage-util/IndexedDbHybridSubscription.test.ts +++ b/packages/test/src/test/storage-util/IndexedDbHybridSubscription.test.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { IndexedDbQueueStorage, IndexedDbTabularRepository } from "@workglow/storage"; +import { IndexedDbQueueStorage, IndexedDbTabularStorage } from "@workglow/storage"; import { sleep, uuid4 } from "@workglow/util"; import "fake-indexeddb/auto"; import { afterEach, beforeEach, describe, expect, it } from "vitest"; @@ -16,7 +16,7 @@ import { afterEach, beforeEach, describe, expect, it } from "vitest"; */ describe("IndexedDB Hybrid Subscription Integration", () => { - describe("IndexedDbTabularRepository with HybridSubscriptionManager", () => { + describe("IndexedDbTabularStorage with HybridSubscriptionManager", () => { const schema = { type: "object" as const, properties: { @@ -37,7 +37,7 @@ describe("IndexedDB Hybrid Subscription Integration", () => { }); it("should use HybridSubscriptionManager instead of PollingSubscriptionManager", async () => { - const repo = new IndexedDbTabularRepository(tableName, schema, ["id"] as const, [], { + const repo = new IndexedDbTabularStorage(tableName, schema, ["id"] as const, [], { useBroadcastChannel: true, backupPollingIntervalMs: 5000, }); @@ -69,7 +69,7 @@ describe("IndexedDB Hybrid Subscription Integration", () => { it("should detect changes faster than polling interval", async () => { // Set a long backup polling interval to ensure we're not relying on it - const repo = new IndexedDbTabularRepository(tableName, schema, ["id"] as const, [], { + const repo = new IndexedDbTabularStorage(tableName, schema, ["id"] as const, [], { useBroadcastChannel: false, backupPollingIntervalMs: 10000, }); @@ -104,7 +104,7 @@ describe("IndexedDB Hybrid Subscription Integration", () => { }); it("should handle multiple rapid changes efficiently", async () => { - const repo = new IndexedDbTabularRepository(tableName, schema, ["id"] as const, [], { + const repo = new IndexedDbTabularStorage(tableName, schema, ["id"] as const, [], { useBroadcastChannel: false, backupPollingIntervalMs: 5000, }); @@ -138,7 +138,7 @@ describe("IndexedDB Hybrid Subscription Integration", () => { }); it("should support disabling BroadcastChannel", async () => { - const repo = new IndexedDbTabularRepository(tableName, schema, ["id"] as const, [], { + const repo = new IndexedDbTabularStorage(tableName, schema, ["id"] as const, [], { useBroadcastChannel: false, backupPollingIntervalMs: 0, }); @@ -165,7 +165,7 @@ describe("IndexedDB Hybrid Subscription Integration", () => { }); it("should handle delete operations", async 
() => { - const repo = new IndexedDbTabularRepository(tableName, schema, ["id"] as const, [], { + const repo = new IndexedDbTabularStorage(tableName, schema, ["id"] as const, [], { useBroadcastChannel: false, backupPollingIntervalMs: 0, }); @@ -199,7 +199,7 @@ describe("IndexedDB Hybrid Subscription Integration", () => { }); it("should handle bulk operations", async () => { - const repo = new IndexedDbTabularRepository(tableName, schema, ["id"] as const, [], { + const repo = new IndexedDbTabularStorage(tableName, schema, ["id"] as const, [], { useBroadcastChannel: false, backupPollingIntervalMs: 0, }); @@ -405,7 +405,7 @@ describe("IndexedDB Hybrid Subscription Integration", () => { }; // Test with hybrid subscription (backup polling disabled) - const hybridRepo = new IndexedDbTabularRepository( + const hybridRepo = new IndexedDbTabularStorage( tableName + "_hybrid", schema, ["id"] as const, diff --git a/packages/test/src/test/task-graph/InputResolver.test.ts b/packages/test/src/test/task-graph/InputResolver.test.ts index 2d69920a..eac89d82 100644 --- a/packages/test/src/test/task-graph/InputResolver.test.ts +++ b/packages/test/src/test/task-graph/InputResolver.test.ts @@ -5,9 +5,9 @@ */ import { - AnyTabularRepository, + AnyTabularStorage, getGlobalTabularRepositories, - InMemoryTabularRepository, + InMemoryTabularStorage, registerTabularRepository, TypeTabularRepository, } from "@workglow/storage"; @@ -32,11 +32,11 @@ describe("InputResolver", () => { additionalProperties: false, } as const; - let testRepo: InMemoryTabularRepository; + let testRepo: InMemoryTabularStorage; beforeEach(async () => { // Create and register a test repository - testRepo = new InMemoryTabularRepository(testEntitySchema, ["id"] as const); + testRepo = new InMemoryTabularStorage(testEntitySchema, ["id"] as const); await testRepo.setupDatabase(); registerTabularRepository("test-repo", testRepo); }); @@ -217,7 +217,7 @@ describe("InputResolver", () => { } async execute( - input: { repository: AnyTabularRepository; query: string }, + input: { repository: AnyTabularStorage; query: string }, _context: IExecuteContext ): Promise<{ results: any[] }> { const { repository } = input; From 5c0c9c698fdfd4428649d983082eb8b9b13c66ee Mon Sep 17 00:00:00 2001 From: Steven Roussey Date: Wed, 14 Jan 2026 03:52:50 +0000 Subject: [PATCH 11/14] [feat] Introduce Dataset Package and Update Dependencies - Added a new `@workglow/dataset` package to manage dataset-related functionalities, including chunk vector storage and document management. - Updated various components to utilize the new dataset package, replacing references to the storage package where applicable. - Enhanced the `bun.lock` and `package.json` files to include the new dataset package as a dependency across relevant modules. - Refactored existing tasks and tests to integrate with the new dataset structure, ensuring compatibility and improved functionality. - Updated documentation to reflect the introduction of the dataset package and its features. 
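
For reviewers, a minimal wiring sketch of how the relocated pieces compose after this change, mirroring the existing DocumentRepository tests. The `@workglow/dataset` import path and its re-exports are assumptions based on the package layout introduced in this patch; the generic tabular storage still comes from `@workglow/storage`.

```typescript
// Illustrative wiring only. The "@workglow/dataset" import path and its
// re-exports are assumptions based on the file moves in this patch; adjust to
// the package's actual export surface.
import {
  DocumentRepository,
  DocumentStorageKey,
  DocumentStorageSchema,
  InMemoryChunkVectorStorage,
} from "@workglow/dataset";
import { InMemoryTabularStorage } from "@workglow/storage";

// Tabular storage holds the document/node rows; the chunk vector storage holds
// embeddings (dimension 3 here, matching the existing tests).
const tabularStorage = new InMemoryTabularStorage(DocumentStorageSchema, DocumentStorageKey);
await tabularStorage.setupDatabase();

const vectorStorage = new InMemoryChunkVectorStorage(3);
await vectorStorage.setupDatabase();

// DocumentRepository composes the two, as in DocumentRepository.test.ts.
const docRepo = new DocumentRepository(tabularStorage, vectorStorage);
```

The intent of the split is that generic tabular/KV storage stays in `@workglow/storage`, while dataset-specific concerns (documents, structural parsing, chunk vectors) live in the new `@workglow/dataset` package.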
--- TODO.md | 2 +- bun.lock | 18 + packages/ai/package.json | 5 + packages/ai/src/task/ChunkRetrievalTask.ts | 2 +- packages/ai/src/task/ChunkToVectorTask.ts | 2 +- .../src/task/ChunkVectorHybridSearchTask.ts | 2 +- packages/ai/src/task/ChunkVectorSearchTask.ts | 2 +- packages/ai/src/task/ChunkVectorUpsertTask.ts | 2 +- packages/ai/src/task/DocumentEnricherTask.ts | 2 +- .../ai/src/task/HierarchicalChunkerTask.ts | 2 +- packages/ai/src/task/HierarchyJoinTask.ts | 4 +- packages/ai/src/task/StructuralParserTask.ts | 2 +- packages/dataset/CHANGELOG.md | 0 packages/dataset/LICENSE | 201 +++ packages/dataset/README.md | 1138 +++++++++++++++++ packages/dataset/package.json | 54 + packages/dataset/src/browser.ts | 7 + packages/dataset/src/bun.ts | 7 + .../src/chunk-vector/ChunkVectorSchema.ts | 0 .../ChunkVectorStorageRegistry.ts | 0 .../src/chunk-vector/IChunkVectorStorage.ts | 105 ++ .../InMemoryChunkVectorStorage.ts | 185 +++ .../PostgresChunkVectorStorage.ts | 293 +++++ .../src/chunk-vector/README.md | 0 .../chunk-vector/SqliteChunkVectorStorage.ts | 192 +++ packages/dataset/src/common-server.ts | 7 + packages/dataset/src/common.ts | 20 + .../src/document/Document.ts | 0 .../src/document/DocumentNode.ts | 6 +- .../src/document/DocumentRepository.ts | 4 +- .../document/DocumentRepositoryRegistry.ts | 0 .../src/document/DocumentSchema.ts | 0 .../src/document/DocumentStorageSchema.ts | 0 .../src/document/StructuralParser.ts | 0 packages/dataset/src/node.ts | 7 + packages/dataset/src/types.ts | 7 + .../src/util/RepositorySchema.ts | 0 packages/dataset/tsconfig.json | 12 + packages/storage/src/common-server.ts | 4 +- packages/storage/src/common.ts | 18 +- .../storage/src/vector/ChunkVectorSchema.ts | 35 + .../src/vector/ChunkVectorStorageRegistry.ts | 83 ++ .../IChunkVectorStorage.ts | 0 .../InMemoryChunkVectorStorage.ts | 0 .../PostgresChunkVectorStorage.ts | 0 packages/storage/src/vector/README.md | 341 +++++ .../SqliteChunkVectorStorage.ts | 0 packages/test/package.json | 5 + .../test/src/test/rag/ChunkToVector.test.ts | 2 +- packages/test/src/test/rag/Document.test.ts | 4 +- .../rag/DocumentNodeRetrievalTask.test.ts | 2 +- .../rag/DocumentNodeVectorSearchTask.test.ts | 2 +- .../DocumentNodeVectorStoreUpsertTask.test.ts | 2 +- .../src/test/rag/DocumentRepository.test.ts | 4 +- packages/test/src/test/rag/EndToEnd.test.ts | 12 +- packages/test/src/test/rag/FullChain.test.ts | 34 +- .../src/test/rag/HierarchicalChunker.test.ts | 2 +- .../src/test/rag/HybridSearchTask.test.ts | 2 +- .../test/src/test/rag/RagWorkflow.test.ts | 4 +- .../src/test/rag/StructuralParser.test.ts | 53 +- .../src/test/task-graph/InputResolver.test.ts | 2 +- packages/test/src/test/util/Document.test.ts | 4 +- 62 files changed, 2777 insertions(+), 128 deletions(-) create mode 100644 packages/dataset/CHANGELOG.md create mode 100644 packages/dataset/LICENSE create mode 100644 packages/dataset/README.md create mode 100644 packages/dataset/package.json create mode 100644 packages/dataset/src/browser.ts create mode 100644 packages/dataset/src/bun.ts rename packages/{storage => dataset}/src/chunk-vector/ChunkVectorSchema.ts (100%) rename packages/{storage => dataset}/src/chunk-vector/ChunkVectorStorageRegistry.ts (100%) create mode 100644 packages/dataset/src/chunk-vector/IChunkVectorStorage.ts create mode 100644 packages/dataset/src/chunk-vector/InMemoryChunkVectorStorage.ts create mode 100644 packages/dataset/src/chunk-vector/PostgresChunkVectorStorage.ts rename packages/{storage => dataset}/src/chunk-vector/README.md (100%) 
create mode 100644 packages/dataset/src/chunk-vector/SqliteChunkVectorStorage.ts create mode 100644 packages/dataset/src/common-server.ts create mode 100644 packages/dataset/src/common.ts rename packages/{storage => dataset}/src/document/Document.ts (100%) rename packages/{storage => dataset}/src/document/DocumentNode.ts (96%) rename packages/{storage => dataset}/src/document/DocumentRepository.ts (97%) rename packages/{storage => dataset}/src/document/DocumentRepositoryRegistry.ts (100%) rename packages/{storage => dataset}/src/document/DocumentSchema.ts (100%) rename packages/{storage => dataset}/src/document/DocumentStorageSchema.ts (100%) rename packages/{storage => dataset}/src/document/StructuralParser.ts (100%) create mode 100644 packages/dataset/src/node.ts create mode 100644 packages/dataset/src/types.ts rename packages/{storage => dataset}/src/util/RepositorySchema.ts (100%) create mode 100644 packages/dataset/tsconfig.json create mode 100644 packages/storage/src/vector/ChunkVectorSchema.ts create mode 100644 packages/storage/src/vector/ChunkVectorStorageRegistry.ts rename packages/storage/src/{chunk-vector => vector}/IChunkVectorStorage.ts (100%) rename packages/storage/src/{chunk-vector => vector}/InMemoryChunkVectorStorage.ts (100%) rename packages/storage/src/{chunk-vector => vector}/PostgresChunkVectorStorage.ts (100%) create mode 100644 packages/storage/src/vector/README.md rename packages/storage/src/{chunk-vector => vector}/SqliteChunkVectorStorage.ts (100%) diff --git a/TODO.md b/TODO.md index 16a7a810..4edd6c9d 100644 --- a/TODO.md +++ b/TODO.md @@ -1,6 +1,6 @@ TODO.md -- [ ] Rename repositories in the packages/storage to use the word Storage instead of Repository. +- [x] Rename repositories in the packages/storage to use the word Storage instead of Repository. - [ ] Vector Storage (not chunk storage) - [ ] Rename the files from packages/storage/src/vector-storage to packages/storage/src/vector - [ ] No fixed column names, use the schema to define the columns. 
diff --git a/bun.lock b/bun.lock index 62fb6c0f..41aaac7f 100644 --- a/bun.lock +++ b/bun.lock @@ -98,12 +98,14 @@ "name": "@workglow/ai", "version": "0.0.85", "devDependencies": { + "@workglow/dataset": "workspace:*", "@workglow/job-queue": "workspace:*", "@workglow/storage": "workspace:*", "@workglow/task-graph": "workspace:*", "@workglow/util": "workspace:*", }, "peerDependencies": { + "@workglow/dataset": "workspace:*", "@workglow/job-queue": "workspace:*", "@workglow/storage": "workspace:*", "@workglow/task-graph": "workspace:*", @@ -138,6 +140,18 @@ "@workglow/util": "workspace:*", }, }, + "packages/dataset": { + "name": "@workglow/dataset", + "version": "0.0.85", + "devDependencies": { + "@workglow/storage": "workspace:*", + "@workglow/util": "workspace:*", + }, + "peerDependencies": { + "@workglow/storage": "workspace:*", + "@workglow/util": "workspace:*", + }, + }, "packages/debug": { "name": "@workglow/debug", "version": "0.0.85", @@ -229,6 +243,7 @@ "@types/pg": "^8.15.5", "@workglow/ai": "workspace:*", "@workglow/ai-provider": "workspace:*", + "@workglow/dataset": "workspace:*", "@workglow/job-queue": "workspace:*", "@workglow/sqlite": "workspace:*", "@workglow/storage": "workspace:*", @@ -244,6 +259,7 @@ "@supabase/supabase-js": "^2.89.0", "@workglow/ai": "workspace:*", "@workglow/ai-provider": "workspace:*", + "@workglow/dataset": "workspace:*", "@workglow/job-queue": "workspace:*", "@workglow/sqlite": "workspace:*", "@workglow/storage": "workspace:*", @@ -737,6 +753,8 @@ "@workglow/cli": ["@workglow/cli@workspace:examples/cli"], + "@workglow/dataset": ["@workglow/dataset@workspace:packages/dataset"], + "@workglow/debug": ["@workglow/debug@workspace:packages/debug"], "@workglow/job-queue": ["@workglow/job-queue@workspace:packages/job-queue"], diff --git a/packages/ai/package.json b/packages/ai/package.json index 0ec91bf3..66c27682 100644 --- a/packages/ai/package.json +++ b/packages/ai/package.json @@ -36,12 +36,16 @@ "access": "public" }, "peerDependencies": { + "@workglow/dataset": "workspace:*", "@workglow/job-queue": "workspace:*", "@workglow/storage": "workspace:*", "@workglow/task-graph": "workspace:*", "@workglow/util": "workspace:*" }, "peerDependenciesMeta": { + "@workglow/dataset": { + "optional": false + }, "@workglow/job-queue": { "optional": false }, @@ -56,6 +60,7 @@ } }, "devDependencies": { + "@workglow/dataset": "workspace:*", "@workglow/job-queue": "workspace:*", "@workglow/storage": "workspace:*", "@workglow/task-graph": "workspace:*", diff --git a/packages/ai/src/task/ChunkRetrievalTask.ts b/packages/ai/src/task/ChunkRetrievalTask.ts index ee2ff5f7..5a6b0d49 100644 --- a/packages/ai/src/task/ChunkRetrievalTask.ts +++ b/packages/ai/src/task/ChunkRetrievalTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { AnyChunkVectorStorage, TypeChunkVectorRepository } from "@workglow/storage"; +import { AnyChunkVectorStorage, TypeChunkVectorRepository } from "@workglow/dataset"; import { CreateWorkflow, IExecuteContext, diff --git a/packages/ai/src/task/ChunkToVectorTask.ts b/packages/ai/src/task/ChunkToVectorTask.ts index c6a149c1..1e2640bd 100644 --- a/packages/ai/src/task/ChunkToVectorTask.ts +++ b/packages/ai/src/task/ChunkToVectorTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { ChunkNodeSchema, type ChunkNode } from "@workglow/storage"; +import { ChunkNodeSchema, type ChunkNode } from "@workglow/dataset"; import { CreateWorkflow, IExecuteContext, diff --git a/packages/ai/src/task/ChunkVectorHybridSearchTask.ts 
b/packages/ai/src/task/ChunkVectorHybridSearchTask.ts index 61a8948f..e8b5a6d6 100644 --- a/packages/ai/src/task/ChunkVectorHybridSearchTask.ts +++ b/packages/ai/src/task/ChunkVectorHybridSearchTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { AnyChunkVectorStorage, TypeChunkVectorRepository } from "@workglow/storage"; +import { AnyChunkVectorStorage, TypeChunkVectorRepository } from "@workglow/dataset"; import { CreateWorkflow, IExecuteContext, diff --git a/packages/ai/src/task/ChunkVectorSearchTask.ts b/packages/ai/src/task/ChunkVectorSearchTask.ts index 45c6d3f6..c94263ec 100644 --- a/packages/ai/src/task/ChunkVectorSearchTask.ts +++ b/packages/ai/src/task/ChunkVectorSearchTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { AnyChunkVectorStorage, TypeChunkVectorRepository } from "@workglow/storage"; +import { AnyChunkVectorStorage, TypeChunkVectorRepository } from "@workglow/dataset"; import { CreateWorkflow, IExecuteContext, diff --git a/packages/ai/src/task/ChunkVectorUpsertTask.ts b/packages/ai/src/task/ChunkVectorUpsertTask.ts index f069efc8..bb5d1152 100644 --- a/packages/ai/src/task/ChunkVectorUpsertTask.ts +++ b/packages/ai/src/task/ChunkVectorUpsertTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { AnyChunkVectorStorage, TypeChunkVectorRepository } from "@workglow/storage"; +import { AnyChunkVectorStorage, TypeChunkVectorRepository } from "@workglow/dataset"; import { CreateWorkflow, IExecuteContext, diff --git a/packages/ai/src/task/DocumentEnricherTask.ts b/packages/ai/src/task/DocumentEnricherTask.ts index 12937f4d..c1900b0b 100644 --- a/packages/ai/src/task/DocumentEnricherTask.ts +++ b/packages/ai/src/task/DocumentEnricherTask.ts @@ -10,7 +10,7 @@ import { type DocumentNode, type Entity, type NodeEnrichment, -} from "@workglow/storage"; +} from "@workglow/dataset"; import { CreateWorkflow, IExecuteContext, diff --git a/packages/ai/src/task/HierarchicalChunkerTask.ts b/packages/ai/src/task/HierarchicalChunkerTask.ts index 1a3d70fa..ed5b07ce 100644 --- a/packages/ai/src/task/HierarchicalChunkerTask.ts +++ b/packages/ai/src/task/HierarchicalChunkerTask.ts @@ -13,7 +13,7 @@ import { type ChunkNode, type DocumentNode, type TokenBudget, -} from "@workglow/storage"; +} from "@workglow/dataset"; import { CreateWorkflow, IExecuteContext, diff --git a/packages/ai/src/task/HierarchyJoinTask.ts b/packages/ai/src/task/HierarchyJoinTask.ts index 09f0ccb5..2a44638b 100644 --- a/packages/ai/src/task/HierarchyJoinTask.ts +++ b/packages/ai/src/task/HierarchyJoinTask.ts @@ -4,12 +4,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { DocumentRepository } from "@workglow/storage"; import { type ChunkMetadata, ChunkMetadataArraySchema, EnrichedChunkMetadataArraySchema, -} from "@workglow/storage"; + type DocumentRepository, +} from "@workglow/dataset"; import { CreateWorkflow, IExecuteContext, diff --git a/packages/ai/src/task/StructuralParserTask.ts b/packages/ai/src/task/StructuralParserTask.ts index afb77abc..cf80ccf5 100644 --- a/packages/ai/src/task/StructuralParserTask.ts +++ b/packages/ai/src/task/StructuralParserTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { DocumentNode, NodeIdGenerator, StructuralParser } from "@workglow/storage"; +import { DocumentNode, NodeIdGenerator, StructuralParser } from "@workglow/dataset"; import { CreateWorkflow, IExecuteContext, diff --git a/packages/dataset/CHANGELOG.md b/packages/dataset/CHANGELOG.md new file mode 100644 index 00000000..e69de29b 
diff --git a/packages/dataset/LICENSE b/packages/dataset/LICENSE new file mode 100644 index 00000000..c9745e3b --- /dev/null +++ b/packages/dataset/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright 2025 Steven Roussey + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/packages/dataset/README.md b/packages/dataset/README.md new file mode 100644 index 00000000..f86bfa63 --- /dev/null +++ b/packages/dataset/README.md @@ -0,0 +1,1138 @@ +# @workglow/storage + +Modular storage solutions for Workglow.AI platform with multiple backend implementations. Provides consistent interfaces for key-value storage, tabular data storage, and job queue persistence. + +- [Quick Start](#quick-start) +- [Installation](#installation) +- [Core Concepts](#core-concepts) + - [Type Safety](#type-safety) + - [Environment Compatibility](#environment-compatibility) + - [Import Patterns](#import-patterns) +- [Storage Types](#storage-types) + - [Key-Value Storage](#key-value-storage) + - [Basic Usage](#basic-usage) + - [Environment-Specific Examples](#environment-specific-examples) + - [Bulk Operations](#bulk-operations) + - [Event Handling](#event-handling) + - [Tabular Storage](#tabular-storage) + - [Schema Definition](#schema-definition) + - [CRUD Operations](#crud-operations) + - [Bulk Operations](#bulk-operations-1) + - [Searching and Filtering](#searching-and-filtering) + - [Environment-Specific Tabular Storage](#environment-specific-tabular-storage) + - [Queue Storage](#queue-storage) + - [Basic Job Queue Operations](#basic-job-queue-operations) + - [Job Management](#job-management) +- [Environment-Specific Usage](#environment-specific-usage) + - [Browser Environment](#browser-environment) + - [Node.js Environment](#nodejs-environment) + - [Bun Environment](#bun-environment) +- [Advanced Features](#advanced-features) + - [Repository Registry](#repository-registry) + - [Event-Driven Architecture](#event-driven-architecture) + - [Compound Primary Keys](#compound-primary-keys) + - [Custom File Layout (KV on filesystem)](#custom-file-layout-kv-on-filesystem) +- [API Reference](#api-reference) + - [IKvStorage\](#ikvrepositorykey-value) + - [ITabularStorage\](#itabularrepositoryschema-primarykeynames) + - [IQueueStorage\](#iqueuestorageinput-output) +- [Examples](#examples) + - [User Management System](#user-management-system) + - [Configuration Management](#configuration-management) +- [Testing](#testing) + - [Writing Tests for Your Storage Usage](#writing-tests-for-your-storage-usage) +- [License](#license) + +## Quick Start + +```typescript +// Key-Value Storage (simple data) +import { InMemoryKvStorage } from "@workglow/storage"; + +const kvStore = new InMemoryKvStorage(); +await kvStore.put("user:123", { name: "Alice", age: 30 }); +const kvUser = await kvStore.get("user:123"); // { name: "Alice", age: 30 } +``` + +```typescript +// Tabular Storage (structured data with schemas) +import { InMemoryTabularStorage } from "@workglow/storage"; +import { JsonSchema } from "@workglow/util"; + +const userSchema = { + type: "object", + properties: { + id: { type: "string" }, + name: { type: "string" }, + email: { type: "string" }, + age: { type: "number" }, + }, + required: ["id", "name", "email", "age"], + additionalProperties: false, +} as const satisfies JsonSchema; + +const userRepo = new InMemoryTabularStorage( + userSchema, + ["id"], // primary key + ["email"] // additional 
indexes +); + +await userRepo.put({ id: "123", name: "Alice", email: "alice@example.com", age: 30 }); +const user = await userRepo.get({ id: "123" }); +``` + +## Installation + +```bash +# Using bun (recommended) +bun install @workglow/storage + +# Using npm +npm install @workglow/storage + +# Using yarn +yarn add @workglow/storage +``` + +## Core Concepts + +### Type Safety + +All storage implementations are fully typed using TypeScript and JSON Schema for runtime validation: + +```typescript +import { JsonSchema, FromSchema } from "@workglow/util"; + +// Define your data structure +const ProductSchema = { + type: "object", + properties: { + id: { type: "string" }, + name: { type: "string" }, + price: { type: "number" }, + category: { type: "string" }, + inStock: { type: "boolean" }, + }, + required: ["id", "name", "price", "category", "inStock"], + additionalProperties: false, +} as const satisfies JsonSchema; + +// TypeScript automatically infers: +// Entity = FromSchema +// PrimaryKey = { id: string } +``` + +### Environment Compatibility + +| Storage Type | Node.js | Bun | Browser | Persistence | +| ------------ | ------- | --- | ------- | ----------- | +| InMemory | ✅ | ✅ | ✅ | ❌ | +| IndexedDB | ❌ | ❌ | ✅ | ✅ | +| SQLite | ✅ | ✅ | ❌ | ✅ | +| PostgreSQL | ✅ | ✅ | ❌ | ✅ | +| Supabase | ✅ | ✅ | ✅ | ✅ | +| FileSystem | ✅ | ✅ | ❌ | ✅ | + +### Import Patterns + +The package uses conditional exports, so importing from `@workglow/storage` automatically selects the right build for your runtime (browser, Node.js, or Bun). + +```typescript +// Import from the top-level package; it resolves to the correct target per environment +import { InMemoryKvStorage, SqliteTabularStorage } from "@workglow/storage"; +``` + +## Storage Types + +### Key-Value Storage + +Simple key-value storage for unstructured or semi-structured data. + +#### Basic Usage + +```typescript +import { InMemoryKvStorage, FsFolderJsonKvRepository } from "@workglow/storage"; + +// In-memory (for testing/caching) +const cache = new InMemoryKvStorage(); +await cache.put("config", { theme: "dark", language: "en" }); + +// File-based JSON (persistent) +const settings = new FsFolderJsonKvRepository("./data/settings"); +await settings.put("user:preferences", { notifications: true }); +``` + +#### Environment-Specific Examples + +```typescript +// Browser (using IndexedDB) +import { IndexedDbKvRepository } from "@workglow/storage"; +const browserStore = new IndexedDbKvRepository("my-app-storage"); + +// Node.js/Bun (using SQLite) +import { SqliteKvRepository } from "@workglow/storage"; +// Pass a file path or a Database instance (see @workglow/sqlite) +const sqliteStore = new SqliteKvRepository("./data.db", "config_table"); + +// PostgreSQL (Node.js/Bun) +import { PostgresKvRepository } from "@workglow/storage"; +import { Pool } from "pg"; +const pool = new Pool({ connectionString: "postgresql://..." 
}); +const pgStore = new PostgresKvRepository(pool, "settings"); + +// Supabase (Node.js/Bun) +import { SupabaseKvRepository } from "@workglow/storage"; +import { createClient } from "@supabase/supabase-js"; +const supabase = createClient("https://your-project.supabase.co", "your-anon-key"); +const supabaseStore = new SupabaseKvRepository(supabase, "settings"); +``` + +#### Bulk Operations + +```typescript +const store = new InMemoryKvStorage(); + +// Bulk insert +await store.putBulk([ + { key: "player1", value: { name: "Alice", score: 100 } }, + { key: "player2", value: { name: "Bob", score: 85 } }, +]); + +// Get all data +const allPlayers = await store.getAll(); +// Result: [{ key: "player1", value: { name: "Alice", score: 100 } }, ...] + +// Get size +const count = await store.size(); // 2 +``` + +#### Event Handling + +```typescript +const store = new InMemoryKvStorage(); + +// Listen to storage events +store.on("put", (key, value) => { + console.log(`Stored: ${key} = ${JSON.stringify(value)}`); +}); + +store.on("get", (key, value) => { + console.log(`Retrieved: ${key} = ${value ? "found" : "not found"}`); +}); + +await store.put("test", { data: "example" }); // Triggers 'put' event +await store.get("test"); // Triggers 'get' event +``` + +### Tabular Storage + +Structured storage with schemas, primary keys, and indexing for complex data relationships. + +#### Schema Definition + +```typescript +import { JsonSchema } from "@workglow/util"; +import { InMemoryTabularStorage } from "@workglow/storage"; + +// Define your entity schema +const UserSchema = { + type: "object", + properties: { + id: { type: "string" }, + email: { type: "string" }, + name: { type: "string" }, + age: { type: "number" }, + department: { type: "string" }, + createdAt: { type: "string" }, + }, + required: ["id", "email", "name", "age", "department", "createdAt"], + additionalProperties: false, +} as const satisfies JsonSchema; + +// Create repository with primary key and indexes +const userRepo = new InMemoryTabularStorage( + UserSchema, + ["id"], // Primary key (can be compound: ["dept", "id"]) + ["email", "department", ["department", "age"]] // Indexes for fast lookups +); +``` + +#### CRUD Operations + +```typescript +// Create +await userRepo.put({ + id: "user_123", + email: "alice@company.com", + name: "Alice Johnson", + age: 28, + department: "Engineering", + createdAt: new Date().toISOString(), +}); + +// Read by primary key +const user = await userRepo.get({ id: "user_123" }); + +// Update (put with same primary key) +await userRepo.put({ + ...user!, + age: 29, // Birthday! 
+}); + +// Delete +await userRepo.delete({ id: "user_123" }); +``` + +#### Bulk Operations + +```typescript +// Bulk insert +await userRepo.putBulk([ + { + id: "1", + email: "alice@co.com", + name: "Alice", + age: 28, + department: "Engineering", + createdAt: "2024-01-01", + }, + { + id: "2", + email: "bob@co.com", + name: "Bob", + age: 32, + department: "Sales", + createdAt: "2024-01-02", + }, + { + id: "3", + email: "carol@co.com", + name: "Carol", + age: 26, + department: "Engineering", + createdAt: "2024-01-03", + }, +]); + +// Get all records +const allUsers = await userRepo.getAll(); + +// Get repository size +const userCount = await userRepo.size(); +``` + +#### Searching and Filtering + +```typescript +// Search by partial match (uses indexes when available) +const engineeringUsers = await userRepo.search({ department: "Engineering" }); +const adultUsers = await userRepo.search({ age: 25 }); // Exact match + +// Delete by search criteria (supports multiple columns) +await userRepo.deleteSearch({ department: "Sales" }); // Equality +await userRepo.deleteSearch({ age: { value: 65, operator: ">=" } }); // Delete users 65 and older + +// Multiple criteria (AND logic) +await userRepo.deleteSearch({ + department: "Sales", + age: { value: 30, operator: "<" }, +}); // Delete young Sales employees +``` + +#### Environment-Specific Tabular Storage + +```typescript +// SQLite (Node.js/Bun) +import { SqliteTabularStorage } from "@workglow/storage"; + +const sqliteUsers = new SqliteTabularStorage( + "./users.db", + "users", + UserSchema, + ["id"], + ["email"] +); + +// PostgreSQL (Node.js/Bun) +import { PostgresTabularStorage } from "@workglow/storage"; +import { Pool } from "pg"; + +const pool = new Pool({ connectionString: "postgresql://..." }); +const pgUsers = new PostgresTabularStorage( + pool, + "users", + UserSchema, + ["id"], + ["email"] +); + +// Supabase (Node.js/Bun) +import { SupabaseTabularRepository } from "@workglow/storage"; +import { createClient } from "@supabase/supabase-js"; + +const supabase = createClient("https://your-project.supabase.co", "your-anon-key"); +const supabaseUsers = new SupabaseTabularRepository( + supabase, + "users", + UserSchema, + ["id"], + ["email"] +); + +// IndexedDB (Browser) +import { IndexedDbTabularRepository } from "@workglow/storage"; +const browserUsers = new IndexedDbTabularRepository( + "users", + UserSchema, + ["id"], + ["email"] +); + +// File-based (Node.js/Bun) +import { FsFolderTabularRepository } from "@workglow/storage"; +const fileUsers = new FsFolderTabularRepository( + "./data/users", + UserSchema, + ["id"], + ["email"] +); +``` + +### Queue Storage + +Persistent job queue storage for background processing and task management. + +> **Note**: Queue storage is primarily used internally by the job queue system. Direct usage is for advanced scenarios. + +#### Basic Job Queue Operations + +```typescript +import { InMemoryQueueStorage, JobStatus } from "@workglow/storage"; + +// Define job input/output types +type ProcessingInput = { text: string; options: any }; +type ProcessingOutput = { result: string; metadata: any }; + +const jobQueue = new InMemoryQueueStorage(); + +// Add job to queue +const jobId = await jobQueue.add({ + input: { text: "Hello world", options: { uppercase: true } }, + run_after: null, // Run immediately + max_retries: 3, +}); + +// Get next job for processing +const job = await jobQueue.next(); +if (job) { + // Process the job... 
+ const result = { result: "HELLO WORLD", metadata: { processed: true } }; + + // Mark as complete + await jobQueue.complete({ + ...job, + output: result, + status: JobStatus.COMPLETED, + }); +} +``` + +#### Job Management + +```typescript +// Check queue status +const pendingCount = await jobQueue.size(JobStatus.PENDING); +const processingCount = await jobQueue.size(JobStatus.PROCESSING); + +// Peek at jobs without removing them +const nextJobs = await jobQueue.peek(JobStatus.PENDING, 5); + +// Progress tracking +await jobQueue.saveProgress(jobId, 50, "Processing...", { step: 1 }); + +// Handle job failures +await jobQueue.abort(jobId); + +// Cleanup old completed jobs +await jobQueue.deleteJobsByStatusAndAge(JobStatus.COMPLETED, 24 * 60 * 60 * 1000); // 24 hours +``` + +## Environment-Specific Usage + +### Browser Environment + +```typescript +import { + IndexedDbKvRepository, + IndexedDbTabularRepository, + IndexedDbQueueStorage, + SupabaseKvRepository, + SupabaseTabularRepository, + SupabaseQueueStorage, +} from "@workglow/storage"; +import { createClient } from "@supabase/supabase-js"; + +// Local browser storage with IndexedDB +const settings = new IndexedDbKvRepository("app-settings"); +const userData = new IndexedDbTabularRepository("users", UserSchema, ["id"]); +const jobQueue = new IndexedDbQueueStorage("background-jobs"); + +// Or use Supabase for cloud storage from the browser +const supabase = createClient("https://your-project.supabase.co", "your-anon-key"); +const cloudSettings = new SupabaseKvRepository(supabase, "app-settings"); +const cloudUserData = new SupabaseTabularRepository(supabase, "users", UserSchema, ["id"]); +const cloudJobQueue = new SupabaseQueueStorage(supabase, "background-jobs"); +``` + +### Node.js Environment + +```typescript +import { + SqliteKvRepository, + PostgresTabularStorage, + FsFolderJsonKvRepository, +} from "@workglow/storage"; + +// Mix and match storage backends +const cache = new FsFolderJsonKvRepository("./cache"); +const users = new PostgresTabularStorage(pool, "users", UserSchema, ["id"]); +``` + +### Bun Environment + +```typescript +// Bun has access to all implementations +import { + SqliteTabularStorage, + FsFolderJsonKvRepository, + PostgresQueueStorage, + SupabaseTabularRepository, +} from "@workglow/storage"; + +import { Database } from "bun:sqlite"; +import { createClient } from "@supabase/supabase-js"; + +const db = new Database("./app.db"); +const data = new SqliteTabularStorage(db, "items", ItemSchema, ["id"]); + +// Or use Supabase for cloud storage +const supabase = createClient("https://your-project.supabase.co", "your-anon-key"); +const cloudData = new SupabaseTabularRepository(supabase, "items", ItemSchema, ["id"]); +``` + +## Advanced Features + +### Repository Registry + +Repositories can be registered globally by ID, allowing tasks to reference them by name rather than passing direct instances. This is useful for configuring repositories once at application startup and referencing them throughout your task graphs. 
+ +#### Registering Repositories + +```typescript +import { + registerTabularRepository, + getTabularRepository, + InMemoryTabularStorage, +} from "@workglow/storage"; + +// Define your schema +const userSchema = { + type: "object", + properties: { + id: { type: "string" }, + name: { type: "string" }, + email: { type: "string" }, + }, + required: ["id", "name", "email"], + additionalProperties: false, +} as const; + +// Create and register a repository +const userRepo = new InMemoryTabularStorage(userSchema, ["id"] as const); +registerTabularRepository("users", userRepo); + +// Later, retrieve the repository by ID +const repo = getTabularRepository("users"); +``` + +#### Using Repositories in Tasks + +When using repositories with tasks, you can pass either the repository ID or a direct instance. The TaskRunner automatically resolves string IDs using the registry. + +```typescript +import { TypeTabularRepository } from "@workglow/storage"; + +// In your task's input schema, use TypeTabularRepository +static inputSchema() { + return { + type: "object", + properties: { + dataSource: TypeTabularRepository({ + title: "User Repository", + description: "Repository containing user records", + }), + }, + required: ["dataSource"], + }; +} + +// Both approaches work: +await task.run({ dataSource: "users" }); // Resolved from registry +await task.run({ dataSource: userRepoInstance }); // Direct instance +``` + +#### Schema Helper Functions + +The package provides schema helper functions for defining repository inputs with proper format annotations: + +```typescript +import { + TypeTabularRepository, + TypeVectorRepository, + TypeDocumentRepository, +} from "@workglow/storage"; + +// Tabular repository (format: "repository:tabular") +const tabularSchema = TypeTabularRepository({ + title: "Data Source", + description: "Tabular data repository", +}); + +// Vector repository (format: "repository:document-node-vector") +const vectorSchema = TypeVectorRepository({ + title: "Embeddings Store", + description: "Vector embeddings repository", +}); + +// Document repository (format: "repository:document") +const docSchema = TypeDocumentRepository({ + title: "Document Store", + description: "Document storage repository", +}); +``` + +### Event-Driven Architecture + +All storage implementations support event emission for monitoring and reactive programming: + +```typescript +const store = new InMemoryTabularStorage(UserSchema, ["id"]); + +// Monitor all operations +store.on("put", (entity) => console.log("User created/updated:", entity)); +store.on("delete", (key) => console.log("User deleted:", key)); +store.on("get", (key, entity) => console.log("User accessed:", entity ? 
"found" : "not found")); + +// Wait for specific events +const [entity] = await store.waitOn("put"); // Waits for next put operation +``` + +### Compound Primary Keys + +```typescript +import { JsonSchema } from "@workglow/util"; + +const OrderLineSchema = { + type: "object", + properties: { + orderId: { type: "string" }, + lineNumber: { type: "number" }, + productId: { type: "string" }, + quantity: { type: "number" }, + price: { type: "number" }, + }, + required: ["orderId", "lineNumber", "productId", "quantity", "price"], + additionalProperties: false, +} as const satisfies JsonSchema; + +const orderLines = new InMemoryTabularStorage( + OrderLineSchema, + ["orderId", "lineNumber"], // Compound primary key + ["productId"] // Additional index +); + +// Use compound keys +await orderLines.put({ + orderId: "ORD-123", + lineNumber: 1, + productId: "PROD-A", + quantity: 2, + price: 19.99, +}); +const line = await orderLines.get({ orderId: "ORD-123", lineNumber: 1 }); +``` + +### Custom File Layout (KV on filesystem) + +```typescript +import { FsFolderKvRepository } from "@workglow/storage"; +import { JsonSchema } from "@workglow/util"; + +// Control how keys map to file paths and value encoding via schemas +const keySchema = { type: "string" } as const satisfies JsonSchema; +const valueSchema = { type: "string" } as const satisfies JsonSchema; + +const files = new FsFolderKvRepository( + "./data/files", + (key) => `${key}.txt`, + keySchema, + valueSchema +); + +await files.put("note-1", "Hello world"); +``` + +## API Reference + +### IKvStorage + +Core interface for key-value storage: + +```typescript +interface IKvStorage { + // Core operations + put(key: Key, value: Value): Promise; + putBulk(items: Array<{ key: Key; value: Value }>): Promise; + get(key: Key): Promise; + delete(key: Key): Promise; + getAll(): Promise | undefined>; + deleteAll(): Promise; + size(): Promise; + + // Event handling + on(event: "put" | "get" | "getAll" | "delete" | "deleteall", callback: Function): void; + off(event: string, callback: Function): void; + once(event: string, callback: Function): void; + waitOn(event: string): Promise; + emit(event: string, ...args: any[]): void; +} +``` + +### ITabularStorage + +Core interface for tabular storage: + +```typescript +interface ITabularStorage { + // Core operations + put(entity: Entity): Promise; + putBulk(entities: Entity[]): Promise; + get(key: PrimaryKey): Promise; + delete(key: PrimaryKey | Entity): Promise; + getAll(): Promise; + deleteAll(): Promise; + size(): Promise; + + // Search operations + search(criteria: Partial): Promise; + deleteSearch(criteria: DeleteSearchCriteria): Promise; + + // Event handling + on(event: "put" | "get" | "search" | "delete" | "clearall", callback: Function): void; + off(event: string, callback: Function): void; + once(event: string, callback: Function): void; + waitOn(event: string): Promise; + emit(event: string, ...args: any[]): void; +} +``` + +#### DeleteSearchCriteria + +The `deleteSearch` method accepts a criteria object that supports multiple columns with optional comparison operators: + +```typescript +// Type definitions +type SearchOperator = "=" | "<" | "<=" | ">" | ">="; + +interface SearchCondition { + readonly value: T; + readonly operator: SearchOperator; +} + +type DeleteSearchCriteria = { + readonly [K in keyof Entity]?: Entity[K] | SearchCondition; +}; + +// Usage examples +// Equality match (direct value) +await repo.deleteSearch({ category: "electronics" }); + +// With comparison operator +await 
repo.deleteSearch({ createdAt: { value: date, operator: "<" } }); + +// Multiple criteria (AND logic) +await repo.deleteSearch({ + category: "electronics", + value: { value: 100, operator: ">=" }, +}); +``` + +### IQueueStorage + +Core interface for job queue storage: + +```typescript +interface IQueueStorage { + add(job: JobStorageFormat): Promise; + get(id: unknown): Promise | undefined>; + next(): Promise | undefined>; + complete(job: JobStorageFormat): Promise; + peek(status?: JobStatus, num?: number): Promise[]>; + size(status?: JobStatus): Promise; + abort(id: unknown): Promise; + saveProgress(id: unknown, progress: number, message: string, details: any): Promise; + deleteAll(): Promise; + getByRunId(runId: string): Promise>>; + outputForInput(input: Input): Promise; + delete(id: unknown): Promise; + deleteJobsByStatusAndAge(status: JobStatus, olderThanMs: number): Promise; +} +``` + +## Examples + +### User Management System + +```typescript +import { JsonSchema, FromSchema } from "@workglow/util"; +import { InMemoryTabularStorage, InMemoryKvStorage } from "@workglow/storage"; + +// User profile with tabular storage +const UserSchema = { + type: "object", + properties: { + id: { type: "string" }, + username: { type: "string" }, + email: { type: "string" }, + firstName: { type: "string" }, + lastName: { type: "string" }, + role: { + type: "string", + enum: ["admin", "user", "guest"], + }, + createdAt: { type: "string" }, + lastLoginAt: { type: "string" }, + }, + required: ["id", "username", "email", "firstName", "lastName", "role", "createdAt"], + additionalProperties: false, +} as const satisfies JsonSchema; + +const userRepo = new InMemoryTabularStorage( + UserSchema, + ["id"], + ["email", "username"] +); + +// User sessions with KV storage +const sessionStore = new InMemoryKvStorage(); + +// User management class +class UserManager { + constructor( + private userRepo: typeof userRepo, + private sessionStore: typeof sessionStore + ) {} + + async createUser(userData: Omit, "id" | "createdAt">) { + const user = { + ...userData, + id: crypto.randomUUID(), + createdAt: new Date().toISOString(), + }; + await this.userRepo.put(user); + return user; + } + + async loginUser(email: string): Promise { + const users = await this.userRepo.search({ email }); + if (!users?.length) throw new Error("User not found"); + + const sessionId = crypto.randomUUID(); + await this.sessionStore.put(sessionId, { + userId: users[0].id, + expiresAt: new Date(Date.now() + 24 * 60 * 60 * 1000).toISOString(), + }); + + // Update last login + await this.userRepo.put({ + ...users[0], + lastLoginAt: new Date().toISOString(), + }); + + return sessionId; + } + + async getSessionUser(sessionId: string) { + const session = await this.sessionStore.get(sessionId); + if (!session || new Date(session.expiresAt) < new Date()) { + return null; + } + return this.userRepo.get({ id: session.userId }); + } +} +``` + +### Configuration Management + +```typescript +// Application settings with typed configuration +type AppConfig = { + database: { + host: string; + port: number; + name: string; + }; + features: { + enableNewUI: boolean; + maxUploadSize: number; + }; + integrations: { + stripe: { apiKey: string; webhook: string }; + sendgrid: { apiKey: string }; + }; +}; + +const configStore = new FsFolderJsonKvRepository("./config"); + +class ConfigManager { + private cache = new Map(); + + constructor(private store: typeof configStore) { + // Listen for config changes + store.on("put", (key, value) => { + this.cache.set(key, 
value); + console.log(`Configuration updated: ${key}`); + }); + } + + async getConfig(environment: string): Promise { + if (this.cache.has(environment)) { + return this.cache.get(environment)!; + } + + const config = await this.store.get(environment); + if (!config) throw new Error(`No configuration for environment: ${environment}`); + + this.cache.set(environment, config); + return config; + } + + async updateConfig(environment: string, updates: Partial) { + const current = await this.getConfig(environment); + const updated = { ...current, ...updates }; + await this.store.put(environment, updated); + } +} +``` + +### Supabase Integration Example + +```typescript +import { createClient } from "@supabase/supabase-js"; +import { JsonSchema } from "@workglow/util"; +import { + SupabaseTabularRepository, + SupabaseKvRepository, + SupabaseQueueStorage, +} from "@workglow/storage"; + +// Initialize Supabase client +const supabase = createClient(process.env.SUPABASE_URL!, process.env.SUPABASE_ANON_KEY!); + +// Define schemas +const ProductSchema = { + type: "object", + properties: { + id: { type: "string" }, + name: { type: "string" }, + price: { type: "number" }, + category: { type: "string" }, + stock: { type: "number", minimum: 0 }, + createdAt: { type: "string", format: "date-time" }, + }, + required: ["id", "name", "price", "category", "stock", "createdAt"], + additionalProperties: false, +} as const satisfies JsonSchema; + +const OrderSchema = { + type: "object", + properties: { + id: { type: "string" }, + customerId: { type: "string" }, + productId: { type: "string" }, + quantity: { type: "number", minimum: 1 }, + status: { + type: "string", + enum: ["pending", "processing", "completed", "cancelled"], + }, + createdAt: { type: "string", format: "date-time" }, + }, + required: ["id", "customerId", "productId", "quantity", "status", "createdAt"], + additionalProperties: false, +} as const satisfies JsonSchema; + +// Create repositories +const products = new SupabaseTabularRepository( + supabase, + "products", + ProductSchema, + ["id"], + ["category", "name"] // Indexed columns for fast searching +); + +const orders = new SupabaseTabularRepository( + supabase, + "orders", + OrderSchema, + ["id"], + ["customerId", "status", ["customerId", "status"]] // Compound index +); + +// Use KV for caching +const cache = new SupabaseKvRepository(supabase, "cache"); + +// Use queue for background processing +type EmailJob = { to: string; subject: string; body: string }; +const emailQueue = new SupabaseQueueStorage(supabase, "emails"); + +// Example usage +async function createOrder(customerId: string, productId: string, quantity: number) { + // Check product availability + const product = await products.get({ id: productId }); + if (!product || product.stock < quantity) { + throw new Error("Insufficient stock"); + } + + // Create order + const order = { + id: crypto.randomUUID(), + customerId, + productId, + quantity, + status: "pending" as const, + createdAt: new Date().toISOString(), + }; + await orders.put(order); + + // Update stock + await products.put({ + ...product, + stock: product.stock - quantity, + }); + + // Queue email notification + await emailQueue.add({ + input: { + to: customerId, + subject: "Order Confirmation", + body: `Your order ${order.id} has been confirmed!`, + }, + run_after: null, + max_retries: 3, + }); + + return order; +} + +// Get customer's orders +async function getCustomerOrders(customerId: string) { + return await orders.search({ customerId }); +} + +// Get orders by 
status +async function getOrdersByStatus(status: string) { + return await orders.search({ status }); +} +``` + +**Important Note** +The implementations assume you have an exec_sql RPC function in your Supabase database for table creation, or that you've created the tables through Supabase migrations. For production use, it's recommended to: + +- Create tables using Supabase migrations rather than runtime table creation +- Set up proper Row Level Security (RLS) policies in Supabase +- Use service role keys for server-side operations that need elevated permissions + +## Testing + +The package includes comprehensive test suites for all storage implementations: + +```bash +# Run all tests +bun test + +# Run specific test suites +bun test --grep "KvRepository" +bun test --grep "TabularRepository" +bun test --grep "QueueStorage" + +# Test specific environments +bun test --grep "InMemory" # Cross-platform tests +bun test --grep "IndexedDb" # Browser tests +bun test --grep "Sqlite" # Native tests +``` + +### Writing Tests for Your Storage Usage + +```typescript +import { describe, test, expect, beforeEach } from "vitest"; +import { InMemoryTabularStorage } from "@workglow/storage"; + +describe("UserRepository", () => { + let userRepo: InMemoryTabularStorage; + + beforeEach(() => { + userRepo = new InMemoryTabularStorage( + UserSchema, + ["id"], + ["email"] + ); + }); + + test("should create and retrieve user", async () => { + const user = { + id: "test-123", + email: "test@example.com", + name: "Test User", + age: 25, + department: "Engineering", + createdAt: new Date().toISOString(), + }; + + await userRepo.put(user); + const retrieved = await userRepo.get({ id: "test-123" }); + + expect(retrieved).toEqual(user); + }); + + test("should find users by department", async () => { + const users = [ + { + id: "1", + email: "alice@co.com", + name: "Alice", + age: 28, + department: "Engineering", + createdAt: "2024-01-01", + }, + { + id: "2", + email: "bob@co.com", + name: "Bob", + age: 32, + department: "Sales", + createdAt: "2024-01-02", + }, + ]; + + await userRepo.putBulk(users); + const engineers = await userRepo.search({ department: "Engineering" }); + + expect(engineers).toHaveLength(1); + expect(engineers![0].name).toBe("Alice"); + }); +}); +``` + +## License + +Apache 2.0 - See [LICENSE](./LICENSE) for details diff --git a/packages/dataset/package.json b/packages/dataset/package.json new file mode 100644 index 00000000..9bd8d372 --- /dev/null +++ b/packages/dataset/package.json @@ -0,0 +1,54 @@ +{ + "name": "@workglow/dataset", + "type": "module", + "version": "0.0.85", + "description": "Dataset package for Workglow.", + "scripts": { + "watch": "concurrently -c 'auto' 'bun:watch-*'", + "watch-browser": "bun build --watch --no-clear-screen --target=browser --sourcemap=external --packages=external --outdir ./dist ./src/browser.ts", + "watch-node": "bun build --watch --no-clear-screen --target=node --sourcemap=external --packages=external --outdir ./dist ./src/node.ts", + "watch-bun": "bun build --watch --no-clear-screen --target=bun --sourcemap=external --packages=external --outdir ./dist ./src/bun.ts", + "watch-types": "tsc --watch --preserveWatchOutput", + "build-package": "bun run build-clean && concurrently -c 'auto' -n 'browser,node,bun,types' 'bun run build-browser' 'bun run build-node' 'bun run build-bun' 'bun run build-types'", + "build-clean": "rm -fr dist/*", + "build-browser": "bun build --target=browser --sourcemap=external --packages=external --outdir ./dist ./src/browser.ts", + 
"build-node": "bun build --target=node --sourcemap=external --packages=external --outdir ./dist ./src/node.ts", + "build-bun": "bun build --target=bun --sourcemap=external --packages=external --outdir ./dist ./src/bun.ts", + "build-types": "rm -f tsconfig.tsbuildinfo && tsc", + "lint": "eslint . --ext ts,tsx --report-unused-disable-directives --max-warnings 0", + "test": "bun test", + "prepare": "node -e \"const pkg=require('./package.json');pkg.exports['.'].bun='./dist/bun.js';pkg.exports['.'].types='./dist/types.d.ts';require('fs').writeFileSync('package.json',JSON.stringify(pkg,null,2))\"" + }, + "peerDependencies": { + "@workglow/storage": "workspace:*", + "@workglow/util": "workspace:*" + }, + "peerDependenciesMeta": { + "@workglow/storage": { + "optional": false + }, + "@workglow/util": { + "optional": false + } + }, + "devDependencies": { + "@workglow/storage": "workspace:*", + "@workglow/util": "workspace:*" + }, + "exports": { + ".": { + "react-native": "./dist/browser.js", + "browser": "./dist/browser.js", + "bun": "./dist/bun.js", + "types": "./dist/types.d.ts", + "import": "./dist/node.js" + } + }, + "files": [ + "dist", + "src/**/*.md" + ], + "publishConfig": { + "access": "public" + } +} \ No newline at end of file diff --git a/packages/dataset/src/browser.ts b/packages/dataset/src/browser.ts new file mode 100644 index 00000000..56ad4a60 --- /dev/null +++ b/packages/dataset/src/browser.ts @@ -0,0 +1,7 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +export * from "./common"; diff --git a/packages/dataset/src/bun.ts b/packages/dataset/src/bun.ts new file mode 100644 index 00000000..66ce35c0 --- /dev/null +++ b/packages/dataset/src/bun.ts @@ -0,0 +1,7 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +export * from "./common-server"; diff --git a/packages/storage/src/chunk-vector/ChunkVectorSchema.ts b/packages/dataset/src/chunk-vector/ChunkVectorSchema.ts similarity index 100% rename from packages/storage/src/chunk-vector/ChunkVectorSchema.ts rename to packages/dataset/src/chunk-vector/ChunkVectorSchema.ts diff --git a/packages/storage/src/chunk-vector/ChunkVectorStorageRegistry.ts b/packages/dataset/src/chunk-vector/ChunkVectorStorageRegistry.ts similarity index 100% rename from packages/storage/src/chunk-vector/ChunkVectorStorageRegistry.ts rename to packages/dataset/src/chunk-vector/ChunkVectorStorageRegistry.ts diff --git a/packages/dataset/src/chunk-vector/IChunkVectorStorage.ts b/packages/dataset/src/chunk-vector/IChunkVectorStorage.ts new file mode 100644 index 00000000..02ffd364 --- /dev/null +++ b/packages/dataset/src/chunk-vector/IChunkVectorStorage.ts @@ -0,0 +1,105 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { + DataPortSchemaObject, + EventParameters, + FromSchema, + TypedArray, + TypedArraySchemaOptions, +} from "@workglow/util"; +import type { ITabularStorage, TabularEventListeners } from "@workglow/storage"; + +export type AnyChunkVectorStorage = IChunkVectorStorage; + +/** + * Options for vector search operations + */ +export interface VectorSearchOptions> { + readonly topK?: number; + readonly filter?: Partial; + readonly scoreThreshold?: number; +} + +/** + * Options for hybrid search (vector + full-text) + */ +export interface HybridSearchOptions< + Metadata = Record, +> extends VectorSearchOptions { + readonly textQuery: string; + readonly vectorWeight?: number; +} + +/** + * Type 
definitions for document chunk vector repository events + */ +export interface VectorChunkEventListeners extends TabularEventListeners< + PrimaryKey, + Entity +> { + similaritySearch: (query: TypedArray, results: (Entity & { score: number })[]) => void; + hybridSearch: (query: TypedArray, results: (Entity & { score: number })[]) => void; +} + +export type VectorChunkEventName = keyof VectorChunkEventListeners; +export type VectorChunkEventListener< + Event extends VectorChunkEventName, + PrimaryKey, + Entity, +> = VectorChunkEventListeners[Event]; + +export type VectorChunkEventParameters< + Event extends VectorChunkEventName, + PrimaryKey, + Entity, +> = EventParameters, Event>; + +/** + * Interface defining the contract for document chunk vector storage repositories. + * These repositories store vector embeddings with metadata for decument chunks. + * Extends ITabularRepository to provide standard storage operations, + * plus vector-specific similarity search capabilities. + * Supports various vector types including quantized formats. + * + * @typeParam Schema - The schema definition for the entity using JSON Schema + * @typeParam PrimaryKeyNames - Array of property names that form the primary key + * @typeParam Entity - The entity type + */ +export interface IChunkVectorStorage< + Schema extends DataPortSchemaObject, + PrimaryKeyNames extends ReadonlyArray, + Entity = FromSchema, +> extends ITabularStorage { + /** + * Get the vector dimension + * @returns The vector dimension + */ + getVectorDimensions(): number; + + /** + * Search for similar vectors using similarity scoring + * @param query - Query vector to compare against + * @param options - Search options (topK, filter, scoreThreshold) + * @returns Array of search results sorted by similarity (highest first) + */ + similaritySearch( + query: TypedArray, + options?: VectorSearchOptions> + ): Promise<(Entity & { score: number })[]>; + + /** + * Hybrid search combining vector similarity with full-text search + * This is optional and may not be supported by all implementations + * @param query - Query vector to compare against + * @param options - Hybrid search options including text query + * @returns Array of search results sorted by combined relevance + */ + hybridSearch?( + query: TypedArray, + options: HybridSearchOptions> + ): Promise<(Entity & { score: number })[]>; +} diff --git a/packages/dataset/src/chunk-vector/InMemoryChunkVectorStorage.ts b/packages/dataset/src/chunk-vector/InMemoryChunkVectorStorage.ts new file mode 100644 index 00000000..46afb138 --- /dev/null +++ b/packages/dataset/src/chunk-vector/InMemoryChunkVectorStorage.ts @@ -0,0 +1,185 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { TypedArray } from "@workglow/util"; +import { cosineSimilarity } from "@workglow/util"; +import { InMemoryTabularStorage } from "@workglow/storage"; +import { ChunkVector, ChunkVectorKey, ChunkVectorSchema } from "./ChunkVectorSchema"; +import type { + HybridSearchOptions, + IChunkVectorStorage, + VectorSearchOptions, +} from "./IChunkVectorStorage"; + +/** + * Check if metadata matches filter + */ +function matchesFilter(metadata: Metadata, filter: Partial): boolean { + for (const [key, value] of Object.entries(filter)) { + if (metadata[key as keyof Metadata] !== value) { + return false; + } + } + return true; +} + +/** + * Simple full-text search scoring (keyword matching) + */ +function textRelevance(text: string, query: string): number { + const textLower = 
text.toLowerCase(); + const queryLower = query.toLowerCase(); + const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); + if (queryWords.length === 0) { + return 0; + } + let matches = 0; + for (const word of queryWords) { + if (textLower.includes(word)) { + matches++; + } + } + return matches / queryWords.length; +} + +/** + * In-memory document chunk vector repository implementation. + * Extends InMemoryTabularRepository for storage. + * Suitable for testing and small-scale browser applications. + * Supports all vector types including quantized formats. + * + * @template Metadata - The metadata type for the document chunk + * @template Vector - The vector type for the document chunk + */ +export class InMemoryChunkVectorStorage< + Metadata extends Record = Record, + Vector extends TypedArray = Float32Array, +> + extends InMemoryTabularStorage< + typeof ChunkVectorSchema, + typeof ChunkVectorKey, + ChunkVector + > + implements + IChunkVectorStorage< + typeof ChunkVectorSchema, + typeof ChunkVectorKey, + ChunkVector + > +{ + private vectorDimensions: number; + private VectorType: new (array: number[]) => TypedArray; + + /** + * Creates a new in-memory document chunk vector repository + * @param dimensions - The number of dimensions of the vector + * @param VectorType - The type of vector to use (defaults to Float32Array) + */ + constructor(dimensions: number, VectorType: new (array: number[]) => TypedArray = Float32Array) { + super(ChunkVectorSchema, ChunkVectorKey); + + this.vectorDimensions = dimensions; + this.VectorType = VectorType; + } + + /** + * Get the vector dimensions + * @returns The vector dimensions + */ + getVectorDimensions(): number { + return this.vectorDimensions; + } + + async similaritySearch( + query: TypedArray, + options: VectorSearchOptions> = {} + ) { + const { topK = 10, filter, scoreThreshold = 0 } = options; + const results: Array & { score: number }> = []; + + const allEntities = (await this.getAll()) || []; + + for (const entity of allEntities) { + const vector = entity.vector; + const metadata = entity.metadata; + + // Apply filter if provided + if (filter && !matchesFilter(metadata, filter)) { + continue; + } + + // Calculate similarity + const score = cosineSimilarity(query, vector); + + // Apply threshold + if (score < scoreThreshold) { + continue; + } + + results.push({ + ...entity, + vector, + score, + }); + } + + // Sort by score descending and take top K + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + return topResults; + } + + async hybridSearch(query: TypedArray, options: HybridSearchOptions>) { + const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; + + if (!textQuery || textQuery.trim().length === 0) { + // Fall back to regular vector search if no text query + return this.similaritySearch(query, { topK, filter, scoreThreshold }); + } + + const results: Array & { score: number }> = []; + const allEntities = (await this.getAll()) || []; + + for (const entity of allEntities) { + // In memory, vectors are stored as TypedArrays directly (not serialized) + const vector = entity.vector; + const metadata = entity.metadata; + + // Apply filter if provided + if (filter && !matchesFilter(metadata, filter)) { + continue; + } + + // Calculate vector similarity + const vectorScore = cosineSimilarity(query, vector); + + // Calculate text relevance (simple keyword matching) + const metadataText = Object.values(metadata).join(" ").toLowerCase(); + const textScore = 
textRelevance(metadataText, textQuery); + + // Combine scores + const combinedScore = vectorWeight * vectorScore + (1 - vectorWeight) * textScore; + + // Apply threshold + if (combinedScore < scoreThreshold) { + continue; + } + + results.push({ + ...entity, + vector, + score: combinedScore, + }); + } + + // Sort by combined score descending and take top K + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + return topResults; + } +} diff --git a/packages/dataset/src/chunk-vector/PostgresChunkVectorStorage.ts b/packages/dataset/src/chunk-vector/PostgresChunkVectorStorage.ts new file mode 100644 index 00000000..5f97b5ae --- /dev/null +++ b/packages/dataset/src/chunk-vector/PostgresChunkVectorStorage.ts @@ -0,0 +1,293 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { cosineSimilarity, type TypedArray } from "@workglow/util"; +import type { Pool } from "pg"; +import { PostgresTabularStorage } from "@workglow/storage"; +import { ChunkVector, ChunkVectorKey, ChunkVectorSchema } from "./ChunkVectorSchema"; +import type { + HybridSearchOptions, + IChunkVectorStorage, + VectorSearchOptions, +} from "./IChunkVectorStorage"; + +/** + * PostgreSQL document chunk vector repository implementation using pgvector extension. + * Extends PostgresTabularRepository for storage. + * Provides efficient vector similarity search with native database support. + * + * Requirements: + * - PostgreSQL database with pgvector extension installed + * - CREATE EXTENSION vector; + * + * @template Metadata - The metadata type for the document chunk + * @template Vector - The vector type for the document chunk + */ +export class PostgresChunkVectorStorage< + Metadata extends Record = Record, + Vector extends TypedArray = Float32Array, +> + extends PostgresTabularStorage< + typeof ChunkVectorSchema, + typeof ChunkVectorKey, + ChunkVector + > + implements + IChunkVectorStorage< + typeof ChunkVectorSchema, + typeof ChunkVectorKey, + ChunkVector + > +{ + private vectorDimensions: number; + private VectorType: new (array: number[]) => TypedArray; + /** + * Creates a new PostgreSQL document chunk vector repository + * @param db - PostgreSQL connection pool + * @param table - The name of the table to use for storage + * @param dimensions - The number of dimensions of the vector + * @param VectorType - The type of vector to use (defaults to Float32Array) + */ + constructor( + db: Pool, + table: string, + dimensions: number, + VectorType: new (array: number[]) => TypedArray = Float32Array + ) { + super(db, table, ChunkVectorSchema, ChunkVectorKey); + + this.vectorDimensions = dimensions; + this.VectorType = VectorType; + } + + getVectorDimensions(): number { + return this.vectorDimensions; + } + + async similaritySearch( + query: TypedArray, + options: VectorSearchOptions = {} + ): Promise & { score: number }>> { + const { topK = 10, filter, scoreThreshold = 0 } = options; + + try { + // Try native pgvector search first + const queryVector = `[${Array.from(query).join(",")}]`; + let sql = ` + SELECT + *, + 1 - (vector <=> $1::vector) as score + FROM "${this.table}" + `; + + const params: any[] = [queryVector]; + let paramIndex = 2; + + if (filter && Object.keys(filter).length > 0) { + const conditions: string[] = []; + for (const [key, value] of Object.entries(filter)) { + conditions.push(`metadata->>'${key}' = $${paramIndex}`); + params.push(String(value)); + paramIndex++; + } + sql += ` WHERE ${conditions.join(" AND ")}`; + } + + 
if (scoreThreshold > 0) { + sql += filter ? " AND" : " WHERE"; + sql += ` (1 - (vector <=> $1::vector)) >= $${paramIndex}`; + params.push(scoreThreshold); + paramIndex++; + } + + sql += ` ORDER BY vector <=> $1::vector LIMIT $${paramIndex}`; + params.push(topK); + + const result = await this.db.query(sql, params); + + // Fetch vectors separately for each result + const results: Array & { score: number }> = []; + for (const row of result.rows) { + const vectorResult = await this.db.query( + `SELECT vector::text FROM "${this.table}" WHERE id = $1`, + [row.id] + ); + const vectorStr = vectorResult.rows[0]?.vector || "[]"; + const vectorArray = JSON.parse(vectorStr); + + results.push({ + ...row, + vector: new this.VectorType(vectorArray), + score: parseFloat(row.score), + } as any); + } + + return results; + } catch (error) { + // Fall back to in-memory similarity calculation if pgvector is not available + console.warn("pgvector query failed, falling back to in-memory search:", error); + return this.searchFallback(query, options); + } + } + + async hybridSearch(query: TypedArray, options: HybridSearchOptions) { + const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; + + if (!textQuery || textQuery.trim().length === 0) { + return this.similaritySearch(query, { topK, filter, scoreThreshold }); + } + + try { + // Try native hybrid search with pgvector + full-text + const queryVector = `[${Array.from(query).join(",")}]`; + const tsQuery = textQuery.split(/\s+/).join(" & "); + + let sql = ` + SELECT + *, + ( + $2 * (1 - (vector <=> $1::vector)) + + $3 * ts_rank(to_tsvector('english', metadata::text), to_tsquery('english', $4)) + ) as score + FROM "${this.table}" + `; + + const params: any[] = [queryVector, vectorWeight, 1 - vectorWeight, tsQuery]; + let paramIndex = 5; + + if (filter && Object.keys(filter).length > 0) { + const conditions: string[] = []; + for (const [key, value] of Object.entries(filter)) { + conditions.push(`metadata->>'${key}' = $${paramIndex}`); + params.push(String(value)); + paramIndex++; + } + sql += ` WHERE ${conditions.join(" AND ")}`; + } + + if (scoreThreshold > 0) { + sql += filter ? 
" AND" : " WHERE"; + sql += ` ( + $2 * (1 - (vector <=> $1::vector)) + + $3 * ts_rank(to_tsvector('english', metadata::text), to_tsquery('english', $4)) + ) >= $${paramIndex}`; + params.push(scoreThreshold); + paramIndex++; + } + + sql += ` ORDER BY score DESC LIMIT $${paramIndex}`; + params.push(topK); + + const result = await this.db.query(sql, params); + + // Fetch vectors separately for each result + const results: Array & { score: number }> = []; + for (const row of result.rows) { + const vectorResult = await this.db.query( + `SELECT vector::text FROM "${this.table}" WHERE id = $1`, + [row.id] + ); + const vectorStr = vectorResult.rows[0]?.vector || "[]"; + const vectorArray = JSON.parse(vectorStr); + + results.push({ + ...row, + vector: new this.VectorType(vectorArray), + score: parseFloat(row.score), + } as any); + } + + return results; + } catch (error) { + // Fall back to in-memory hybrid search + console.warn("pgvector hybrid query failed, falling back to in-memory search:", error); + return this.hybridSearchFallback(query, options); + } + } + + /** + * Fallback search using in-memory cosine similarity + */ + private async searchFallback(query: TypedArray, options: VectorSearchOptions) { + const { topK = 10, filter, scoreThreshold = 0 } = options; + const allRows = (await this.getAll()) || []; + const results: Array & { score: number }> = []; + + for (const row of allRows) { + const vector = row.vector; + const metadata = row.metadata; + + if (filter && !this.matchesFilter(metadata, filter)) { + continue; + } + + const score = cosineSimilarity(query, vector); + + if (score >= scoreThreshold) { + results.push({ ...row, vector, score }); + } + } + + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + return topResults; + } + + /** + * Fallback hybrid search + */ + private async hybridSearchFallback(query: TypedArray, options: HybridSearchOptions) { + const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; + + const allRows = (await this.getAll()) || []; + const results: Array & { score: number }> = []; + const queryLower = textQuery.toLowerCase(); + const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); + + for (const row of allRows) { + const vector = row.vector; + const metadata = row.metadata; + + if (filter && !this.matchesFilter(metadata, filter)) { + continue; + } + + const vectorScore = cosineSimilarity(query, vector); + const metadataText = JSON.stringify(metadata).toLowerCase(); + let textScore = 0; + if (queryWords.length > 0) { + let matches = 0; + for (const word of queryWords) { + if (metadataText.includes(word)) { + matches++; + } + } + textScore = matches / queryWords.length; + } + + const combinedScore = vectorWeight * vectorScore + (1 - vectorWeight) * textScore; + + if (combinedScore >= scoreThreshold) { + results.push({ ...row, vector, score: combinedScore }); + } + } + + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + return topResults; + } + + private matchesFilter(metadata: Metadata, filter: Partial): boolean { + for (const [key, value] of Object.entries(filter)) { + if (metadata[key as keyof Metadata] !== value) { + return false; + } + } + return true; + } +} diff --git a/packages/storage/src/chunk-vector/README.md b/packages/dataset/src/chunk-vector/README.md similarity index 100% rename from packages/storage/src/chunk-vector/README.md rename to packages/dataset/src/chunk-vector/README.md diff --git 
a/packages/dataset/src/chunk-vector/SqliteChunkVectorStorage.ts b/packages/dataset/src/chunk-vector/SqliteChunkVectorStorage.ts new file mode 100644 index 00000000..a4bcbb15 --- /dev/null +++ b/packages/dataset/src/chunk-vector/SqliteChunkVectorStorage.ts @@ -0,0 +1,192 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { Sqlite } from "@workglow/sqlite"; +import type { TypedArray } from "@workglow/util"; +import { cosineSimilarity } from "@workglow/util"; +import { SqliteTabularStorage } from "@workglow/storage"; +import { ChunkVector, ChunkVectorKey, ChunkVectorSchema } from "./ChunkVectorSchema"; +import type { + HybridSearchOptions, + IChunkVectorStorage, + VectorSearchOptions, +} from "./IChunkVectorStorage"; + +/** + * Check if metadata matches filter + */ +function matchesFilter(metadata: Metadata, filter: Partial): boolean { + for (const [key, value] of Object.entries(filter)) { + if (metadata[key as keyof Metadata] !== value) { + return false; + } + } + return true; +} + +/** + * SQLite document chunk vector repository implementation using tabular storage underneath. + * Stores vectors as JSON-encoded arrays with metadata. + * + * @template Metadata - The metadata type for the document chunk + * @template Vector - The vector type for the document chunk + */ +export class SqliteChunkVectorStorage< + Metadata extends Record = Record, + Vector extends TypedArray = Float32Array, +> + extends SqliteTabularStorage< + typeof ChunkVectorSchema, + typeof ChunkVectorKey, + ChunkVector + > + implements + IChunkVectorStorage< + typeof ChunkVectorSchema, + typeof ChunkVectorKey, + ChunkVector + > +{ + private vectorDimensions: number; + private VectorType: new (array: number[]) => TypedArray; + + /** + * Creates a new SQLite document chunk vector repository + * @param dbOrPath - Either a Database instance or a path to the SQLite database file + * @param table - The name of the table to use for storage (defaults to 'vectors') + * @param dimensions - The number of dimensions of the vector + * @param VectorType - The type of vector to use (defaults to Float32Array) + */ + constructor( + dbOrPath: string | Sqlite.Database, + table: string = "vectors", + dimensions: number, + VectorType: new (array: number[]) => TypedArray = Float32Array + ) { + super(dbOrPath, table, ChunkVectorSchema, ChunkVectorKey); + + this.vectorDimensions = dimensions; + this.VectorType = VectorType; + } + + getVectorDimensions(): number { + return this.vectorDimensions; + } + + /** + * Deserialize vector from JSON string + * Defaults to Float32Array for compatibility with typical embedding vectors + */ + private deserializeVector(vectorJson: string): TypedArray { + const array = JSON.parse(vectorJson); + // Default to Float32Array for typical use case (embeddings) + return new this.VectorType(array); + } + + async similaritySearch(query: TypedArray, options: VectorSearchOptions = {}) { + const { topK = 10, filter, scoreThreshold = 0 } = options; + const results: Array & { score: number }> = []; + + const allEntities = (await this.getAll()) || []; + + for (const entity of allEntities) { + // SQLite stores vectors as JSON strings, need to deserialize + const vectorRaw = entity.vector as unknown as string; + const vector = this.deserializeVector(vectorRaw); + const metadata = entity.metadata; + + // Apply filter if provided + if (filter && !matchesFilter(metadata, filter)) { + continue; + } + + // Calculate similarity + const score = cosineSimilarity(query, vector); + 
+ // Apply threshold + if (score < scoreThreshold) { + continue; + } + + results.push({ + ...entity, + vector, + score, + } as any); + } + + // Sort by score descending and take top K + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + return topResults; + } + + async hybridSearch(query: TypedArray, options: HybridSearchOptions) { + const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; + + if (!textQuery || textQuery.trim().length === 0) { + // Fall back to regular vector search if no text query + return this.similaritySearch(query, { topK, filter, scoreThreshold }); + } + + const results: Array & { score: number }> = []; + const allEntities = (await this.getAll()) || []; + const queryLower = textQuery.toLowerCase(); + const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); + + for (const entity of allEntities) { + // SQLite stores vectors as JSON strings, need to deserialize + const vectorRaw = entity.vector as unknown as string; + const vector = + typeof vectorRaw === "string" + ? this.deserializeVector(vectorRaw) + : (vectorRaw as TypedArray); + const metadata = entity.metadata; + + // Apply filter if provided + if (filter && !matchesFilter(metadata, filter)) { + continue; + } + + // Calculate vector similarity + const vectorScore = cosineSimilarity(query, vector); + + // Calculate text relevance (simple keyword matching) + const metadataText = JSON.stringify(metadata).toLowerCase(); + let textScore = 0; + if (queryWords.length > 0) { + let matches = 0; + for (const word of queryWords) { + if (metadataText.includes(word)) { + matches++; + } + } + textScore = matches / queryWords.length; + } + + // Combine scores + const combinedScore = vectorWeight * vectorScore + (1 - vectorWeight) * textScore; + + // Apply threshold + if (combinedScore < scoreThreshold) { + continue; + } + + results.push({ + ...entity, + vector, + score: combinedScore, + } as any); + } + + // Sort by combined score descending and take top K + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + return topResults; + } +} diff --git a/packages/dataset/src/common-server.ts b/packages/dataset/src/common-server.ts new file mode 100644 index 00000000..56ad4a60 --- /dev/null +++ b/packages/dataset/src/common-server.ts @@ -0,0 +1,7 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +export * from "./common"; diff --git a/packages/dataset/src/common.ts b/packages/dataset/src/common.ts new file mode 100644 index 00000000..2cd1e8a7 --- /dev/null +++ b/packages/dataset/src/common.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +export * from "./util/RepositorySchema"; + +export * from "./document/Document"; +export * from "./document/DocumentNode"; +export * from "./document/DocumentRepository"; +export * from "./document/DocumentRepositoryRegistry"; +export * from "./document/DocumentSchema"; +export * from "./document/DocumentStorageSchema"; +export * from "./document/StructuralParser"; + +export * from "./chunk-vector/ChunkVectorSchema"; +export * from "./chunk-vector/ChunkVectorStorageRegistry"; +export * from "./chunk-vector/IChunkVectorStorage"; +export * from "./chunk-vector/InMemoryChunkVectorStorage"; diff --git a/packages/storage/src/document/Document.ts b/packages/dataset/src/document/Document.ts similarity index 100% rename from 
packages/storage/src/document/Document.ts rename to packages/dataset/src/document/Document.ts diff --git a/packages/storage/src/document/DocumentNode.ts b/packages/dataset/src/document/DocumentNode.ts similarity index 96% rename from packages/storage/src/document/DocumentNode.ts rename to packages/dataset/src/document/DocumentNode.ts index 6fccdd2b..dd68a896 100644 --- a/packages/storage/src/document/DocumentNode.ts +++ b/packages/dataset/src/document/DocumentNode.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { sha256 } from "@workglow/util"; +import { sha256, uuid4 } from "@workglow/util"; import { NodeKind, @@ -24,6 +24,7 @@ export class NodeIdGenerator { * Generate doc_id from source URI and content hash */ static async generateDocId(sourceUri: string, content: string): Promise { + return uuid4(); const contentHash = await sha256(content); const combined = `${sourceUri}|${contentHash}`; const hash = await sha256(combined); @@ -38,6 +39,7 @@ export class NodeIdGenerator { kind: NodeKindType, range: NodeRange ): Promise { + return uuid4(); const combined = `${doc_id}|${kind}|${range.startOffset}:${range.endOffset}`; const hash = await sha256(combined); return `node_${hash.substring(0, 16)}`; @@ -47,6 +49,7 @@ export class NodeIdGenerator { * Generate nodeId for child nodes (paragraph, topic) */ static async generateChildNodeId(parentNodeId: string, ordinal: number): Promise { + return uuid4(); const combined = `${parentNodeId}|${ordinal}`; const hash = await sha256(combined); return `node_${hash.substring(0, 16)}`; @@ -60,6 +63,7 @@ export class NodeIdGenerator { leafNodeId: string, chunkOrdinal: number ): Promise { + return uuid4(); const combined = `${doc_id}|${leafNodeId}|${chunkOrdinal}`; const hash = await sha256(combined); return `chunk_${hash.substring(0, 16)}`; diff --git a/packages/storage/src/document/DocumentRepository.ts b/packages/dataset/src/document/DocumentRepository.ts similarity index 97% rename from packages/storage/src/document/DocumentRepository.ts rename to packages/dataset/src/document/DocumentRepository.ts index 746a16f6..bc92afbb 100644 --- a/packages/storage/src/document/DocumentRepository.ts +++ b/packages/dataset/src/document/DocumentRepository.ts @@ -10,7 +10,7 @@ import type { AnyChunkVectorStorage, VectorSearchOptions, } from "../chunk-vector/IChunkVectorStorage"; -import type { ITabularStorage } from "../tabular/ITabularStorage"; +import type { ITabularStorage } from "@workglow/storage"; import { Document } from "./Document"; import { ChunkNode, DocumentNode } from "./DocumentSchema"; import { @@ -204,7 +204,7 @@ export class DocumentRepository { if (!entities) { return []; } - return entities.map((e) => e.doc_id); + return entities.map((e: DocumentStorageEntity) => e.doc_id); } /** diff --git a/packages/storage/src/document/DocumentRepositoryRegistry.ts b/packages/dataset/src/document/DocumentRepositoryRegistry.ts similarity index 100% rename from packages/storage/src/document/DocumentRepositoryRegistry.ts rename to packages/dataset/src/document/DocumentRepositoryRegistry.ts diff --git a/packages/storage/src/document/DocumentSchema.ts b/packages/dataset/src/document/DocumentSchema.ts similarity index 100% rename from packages/storage/src/document/DocumentSchema.ts rename to packages/dataset/src/document/DocumentSchema.ts diff --git a/packages/storage/src/document/DocumentStorageSchema.ts b/packages/dataset/src/document/DocumentStorageSchema.ts similarity index 100% rename from packages/storage/src/document/DocumentStorageSchema.ts 
rename to packages/dataset/src/document/DocumentStorageSchema.ts diff --git a/packages/storage/src/document/StructuralParser.ts b/packages/dataset/src/document/StructuralParser.ts similarity index 100% rename from packages/storage/src/document/StructuralParser.ts rename to packages/dataset/src/document/StructuralParser.ts diff --git a/packages/dataset/src/node.ts b/packages/dataset/src/node.ts new file mode 100644 index 00000000..66ce35c0 --- /dev/null +++ b/packages/dataset/src/node.ts @@ -0,0 +1,7 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +export * from "./common-server"; diff --git a/packages/dataset/src/types.ts b/packages/dataset/src/types.ts new file mode 100644 index 00000000..66ce35c0 --- /dev/null +++ b/packages/dataset/src/types.ts @@ -0,0 +1,7 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +export * from "./common-server"; diff --git a/packages/storage/src/util/RepositorySchema.ts b/packages/dataset/src/util/RepositorySchema.ts similarity index 100% rename from packages/storage/src/util/RepositorySchema.ts rename to packages/dataset/src/util/RepositorySchema.ts diff --git a/packages/dataset/tsconfig.json b/packages/dataset/tsconfig.json new file mode 100644 index 00000000..b4562c70 --- /dev/null +++ b/packages/dataset/tsconfig.json @@ -0,0 +1,12 @@ +{ + "extends": "../../tsconfig.json", + "include": ["src/common.ts", "src/common-server.ts", "src/*/**/*"], + "files": ["./src/types.ts"], + "exclude": ["dist", "src/chunk-vector/PostgresChunkVectorStorage.ts", "src/chunk-vector/SqliteChunkVectorStorage.ts"], + "compilerOptions": { + "composite": true, + "outDir": "./dist", + "baseUrl": "./src", + "rootDir": "./src" + } +} diff --git a/packages/storage/src/common-server.ts b/packages/storage/src/common-server.ts index 118159b9..0480e6a5 100644 --- a/packages/storage/src/common-server.ts +++ b/packages/storage/src/common-server.ts @@ -25,8 +25,8 @@ export * from "./queue-limiter/PostgresRateLimiterStorage"; export * from "./queue-limiter/SqliteRateLimiterStorage"; export * from "./queue-limiter/SupabaseRateLimiterStorage"; -export * from "./chunk-vector/PostgresChunkVectorStorage"; -export * from "./chunk-vector/SqliteChunkVectorStorage"; +export * from "./vector/PostgresChunkVectorStorage"; +export * from "./vector/SqliteChunkVectorStorage"; // testing export * from "./kv/IndexedDbKvStorage"; diff --git a/packages/storage/src/common.ts b/packages/storage/src/common.ts index 4971fca5..2c31740a 100644 --- a/packages/storage/src/common.ts +++ b/packages/storage/src/common.ts @@ -10,8 +10,6 @@ export * from "./tabular/InMemoryTabularStorage"; export * from "./tabular/ITabularStorage"; export * from "./tabular/TabularStorageRegistry"; -export * from "./util/RepositorySchema"; - export * from "./kv/IKvStorage"; export * from "./kv/InMemoryKvStorage"; export * from "./kv/KvStorage"; @@ -26,15 +24,7 @@ export * from "./queue-limiter/IRateLimiterStorage"; export * from "./util/HybridSubscriptionManager"; export * from "./util/PollingSubscriptionManager"; -export * from "./document/Document"; -export * from "./document/DocumentNode"; -export * from "./document/DocumentRepository"; -export * from "./document/DocumentRepositoryRegistry"; -export * from "./document/DocumentSchema"; -export * from "./document/DocumentStorageSchema"; -export * from "./document/StructuralParser"; - -export * from "./chunk-vector/ChunkVectorStorageRegistry"; -export * from 
"./chunk-vector/ChunkVectorSchema"; -export * from "./chunk-vector/IChunkVectorStorage"; -export * from "./chunk-vector/InMemoryChunkVectorStorage"; +export * from "./vector/ChunkVectorSchema"; +export * from "./vector/ChunkVectorStorageRegistry"; +export * from "./vector/IChunkVectorStorage"; +export * from "./vector/InMemoryChunkVectorStorage"; diff --git a/packages/storage/src/vector/ChunkVectorSchema.ts b/packages/storage/src/vector/ChunkVectorSchema.ts new file mode 100644 index 00000000..8ce95438 --- /dev/null +++ b/packages/storage/src/vector/ChunkVectorSchema.ts @@ -0,0 +1,35 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { TypedArraySchema, type DataPortSchemaObject, type TypedArray } from "@workglow/util"; + +/** + * Default schema for document chunk storage with vector embeddings + */ +export const ChunkVectorSchema = { + type: "object", + properties: { + chunk_id: { type: "string" }, + doc_id: { type: "string" }, + vector: TypedArraySchema(), + metadata: { type: "object", additionalProperties: true }, + }, + additionalProperties: false, +} as const satisfies DataPortSchemaObject; +export type ChunkVectorSchema = typeof ChunkVectorSchema; + +export const ChunkVectorKey = ["chunk_id"] as const; +export type ChunkVectorKey = typeof ChunkVectorKey; + +export interface ChunkVector< + Metadata extends Record = Record, + Vector extends TypedArray = Float32Array, +> { + chunk_id: string; + doc_id: string; + vector: Vector; + metadata: Metadata; +} diff --git a/packages/storage/src/vector/ChunkVectorStorageRegistry.ts b/packages/storage/src/vector/ChunkVectorStorageRegistry.ts new file mode 100644 index 00000000..7c51b929 --- /dev/null +++ b/packages/storage/src/vector/ChunkVectorStorageRegistry.ts @@ -0,0 +1,83 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + createServiceToken, + globalServiceRegistry, + registerInputResolver, + ServiceRegistry, +} from "@workglow/util"; +import { AnyChunkVectorStorage } from "./IChunkVectorStorage"; + +/** + * Service token for the documenbt chunk vector repository registry + * Maps repository IDs to IVectorChunkRepository instances + */ +export const DOCUMENT_CHUNK_VECTOR_REPOSITORIES = createServiceToken< + Map +>("storage.document-node-vector.repositories"); + +// Register default factory if not already registered +if (!globalServiceRegistry.has(DOCUMENT_CHUNK_VECTOR_REPOSITORIES)) { + globalServiceRegistry.register( + DOCUMENT_CHUNK_VECTOR_REPOSITORIES, + (): Map => new Map(), + true + ); +} + +/** + * Gets the global document chunk vector repository registry + * @returns Map of document chunk vector repository ID to instance + */ +export function getGlobalChunkVectorRepositories(): Map { + return globalServiceRegistry.get(DOCUMENT_CHUNK_VECTOR_REPOSITORIES); +} + +/** + * Registers a vector repository globally by ID + * @param id The unique identifier for this repository + * @param repository The repository instance to register + */ +export function registerChunkVectorRepository( + id: string, + repository: AnyChunkVectorStorage +): void { + const repos = getGlobalChunkVectorRepositories(); + repos.set(id, repository); +} + +/** + * Gets a document chunk vector repository by ID from the global registry + * @param id The repository identifier + * @returns The repository instance or undefined if not found + */ +export function getChunkVectorRepository(id: string): AnyChunkVectorStorage | undefined { + return 
getGlobalChunkVectorRepositories().get(id); +} + +/** + * Resolves a repository ID to an IVectorChunkRepository from the registry. + * Used by the input resolver system. + */ +async function resolveChunkVectorRepositoryFromRegistry( + id: string, + format: string, + registry: ServiceRegistry +): Promise { + const repos = registry.has(DOCUMENT_CHUNK_VECTOR_REPOSITORIES) + ? registry.get>(DOCUMENT_CHUNK_VECTOR_REPOSITORIES) + : getGlobalChunkVectorRepositories(); + + const repo = repos.get(id); + if (!repo) { + throw new Error(`Document chunk vector repository "${id}" not found in registry`); + } + return repo; +} + +// Register the repository resolver for format: "repository:document-node-vector" +registerInputResolver("repository:document-node-vector", resolveChunkVectorRepositoryFromRegistry); diff --git a/packages/storage/src/chunk-vector/IChunkVectorStorage.ts b/packages/storage/src/vector/IChunkVectorStorage.ts similarity index 100% rename from packages/storage/src/chunk-vector/IChunkVectorStorage.ts rename to packages/storage/src/vector/IChunkVectorStorage.ts diff --git a/packages/storage/src/chunk-vector/InMemoryChunkVectorStorage.ts b/packages/storage/src/vector/InMemoryChunkVectorStorage.ts similarity index 100% rename from packages/storage/src/chunk-vector/InMemoryChunkVectorStorage.ts rename to packages/storage/src/vector/InMemoryChunkVectorStorage.ts diff --git a/packages/storage/src/chunk-vector/PostgresChunkVectorStorage.ts b/packages/storage/src/vector/PostgresChunkVectorStorage.ts similarity index 100% rename from packages/storage/src/chunk-vector/PostgresChunkVectorStorage.ts rename to packages/storage/src/vector/PostgresChunkVectorStorage.ts diff --git a/packages/storage/src/vector/README.md b/packages/storage/src/vector/README.md new file mode 100644 index 00000000..f64c8ca0 --- /dev/null +++ b/packages/storage/src/vector/README.md @@ -0,0 +1,341 @@ +# Chunk Vector Storage Module + +Storage for document chunk embeddings with vector similarity search capabilities. Extends the tabular repository pattern to add vector search functionality for RAG (Retrieval-Augmented Generation) pipelines. 
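+
+The similarity score returned by `similaritySearch` (and used as the vector component of `hybridSearch`) is cosine similarity, computed with the `cosineSimilarity` helper from `@workglow/util`: 1.0 for vectors pointing in the same direction, 0 for orthogonal vectors. The sketch below is illustrative only — it is not the library implementation, and the function name is made up for this example:
+
+```typescript
+// Minimal sketch (illustrative only) of the cosine-similarity score shape.
+// The real helper is `cosineSimilarity` from "@workglow/util", which also
+// handles the quantized TypedArray variants listed under Features.
+function cosineSimilaritySketch(a: ArrayLike<number>, b: ArrayLike<number>): number {
+  if (a.length !== b.length) {
+    throw new Error("Vectors must have the same number of dimensions");
+  }
+  let dot = 0;
+  let normA = 0;
+  let normB = 0;
+  for (let i = 0; i < a.length; i++) {
+    dot += a[i] * b[i];
+    normA += a[i] * a[i];
+    normB += b[i] * b[i];
+  }
+  if (normA === 0 || normB === 0) {
+    return 0; // degenerate zero vector
+  }
+  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
+}
+
+// Example: two nearby 3-dimensional vectors score close to 1.
+cosineSimilaritySketch(new Float32Array([0.1, 0.2, 0.3]), new Float32Array([0.15, 0.25, 0.35]));
+```
+
+The PostgreSQL backend computes the equivalent score in SQL as `1 - (vector <=> query)` using pgvector's cosine distance operator, and falls back to this in-memory calculation when the pgvector query fails.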
+ +## Features + +- **Multiple Storage Backends:** + - 🧠 `InMemoryChunkVectorStorage` - Fast in-memory storage for testing and small datasets + - 📁 `SqliteChunkVectorStorage` - Persistent SQLite storage for local applications + - 🐘 `PostgresChunkVectorStorage` - PostgreSQL with pgvector extension for production + +- **Quantized Vector Support:** + - Float32Array (standard 32-bit floating point) + - Float16Array (16-bit floating point) + - Float64Array (64-bit high precision) + - Int8Array (8-bit signed - binary quantization) + - Uint8Array (8-bit unsigned - quantization) + - Int16Array (16-bit signed - quantization) + - Uint16Array (16-bit unsigned - quantization) + +- **Search Capabilities:** + - Vector similarity search (cosine similarity) + - Hybrid search (vector + full-text keyword matching) + - Metadata filtering + - Top-K retrieval with score thresholds + +- **Built on Tabular Repositories:** + - Extends `ITabularStorage` for standard CRUD operations + - Inherits event emitter pattern for monitoring + - Type-safe schema-based storage + +## Installation + +```bash +bun install @workglow/storage +``` + +## Usage + +### In-Memory Repository (Testing/Development) + +```typescript +import { InMemoryChunkVectorStorage } from "@workglow/storage"; + +// Create repository with 384 dimensions +const repo = new InMemoryChunkVectorStorage(384); +await repo.setupDatabase(); + +// Store a chunk with its embedding +await repo.put({ + chunk_id: "chunk-001", + doc_id: "doc-001", + vector: new Float32Array([0.1, 0.2, 0.3 /* ... 384 dims */]), + metadata: { text: "Hello world", source: "example.txt" }, +}); + +// Search for similar chunks +const results = await repo.similaritySearch(new Float32Array([0.15, 0.25, 0.35 /* ... */]), { + topK: 5, + scoreThreshold: 0.7, +}); +``` + +### Quantized Vectors (Reduced Storage) + +```typescript +import { InMemoryChunkVectorStorage } from "@workglow/storage"; + +// Use Int8Array for 4x smaller storage (binary quantization) +const repo = new InMemoryChunkVectorStorage<{ text: string }, Int8Array>(384, Int8Array); +await repo.setupDatabase(); + +// Store quantized vectors +await repo.put({ + chunk_id: "chunk-001", + doc_id: "doc-001", + vector: new Int8Array([127, -128, 64 /* ... */]), + metadata: { category: "ai" }, +}); + +// Search with quantized query +const results = await repo.similaritySearch(new Int8Array([100, -50, 75 /* ... */]), { topK: 5 }); +``` + +### SQLite Repository (Local Persistence) + +```typescript +import { SqliteChunkVectorStorage } from "@workglow/storage"; + +const repo = new SqliteChunkVectorStorage<{ text: string }>( + "./vectors.db", // database path + "chunks", // table name + 768 // vector dimension +); +await repo.setupDatabase(); + +// Bulk insert using inherited tabular methods +await repo.putMany([ + { chunk_id: "1", doc_id: "doc1", vector: new Float32Array([...]), metadata: { text: "..." } }, + { chunk_id: "2", doc_id: "doc1", vector: new Float32Array([...]), metadata: { text: "..." } }, +]); +``` + +### PostgreSQL with pgvector + +```typescript +import { Pool } from "pg"; +import { PostgresChunkVectorStorage } from "@workglow/storage"; + +const pool = new Pool({ connectionString: "postgresql://..." 
}); +const repo = new PostgresChunkVectorStorage<{ text: string; category: string }>( + pool, + "chunks", + 384 // vector dimension +); +await repo.setupDatabase(); + +// Native pgvector similarity search with filter +const results = await repo.similaritySearch(queryVector, { + topK: 10, + filter: { category: "ai" }, + scoreThreshold: 0.5, +}); + +// Hybrid search (vector + full-text) +const hybridResults = await repo.hybridSearch(queryVector, { + textQuery: "machine learning", + topK: 10, + vectorWeight: 0.7, + filter: { category: "ai" }, +}); +``` + +## Data Model + +### ChunkVector Schema + +Each chunk vector entry contains: + +```typescript +interface ChunkVector< + Metadata extends Record = Record, + Vector extends TypedArray = Float32Array, +> { + chunk_id: string; // Unique identifier for the chunk + doc_id: string; // Parent document identifier + vector: Vector; // Embedding vector + metadata: Metadata; // Custom metadata (text content, entities, etc.) +} +``` + +### Default Schema + +```typescript +const ChunkVectorSchema = { + type: "object", + properties: { + chunk_id: { type: "string" }, + doc_id: { type: "string" }, + vector: TypedArraySchema(), + metadata: { type: "object", additionalProperties: true }, + }, + additionalProperties: false, +} as const; + +const ChunkVectorKey = ["chunk_id"] as const; +``` + +## API Reference + +### IChunkVectorStorage Interface + +Extends `ITabularStorage` with vector-specific methods: + +```typescript +interface IChunkVectorStorage extends ITabularStorage< + Schema, + PrimaryKeyNames, + Entity +> { + // Get the vector dimension + getVectorDimensions(): number; + + // Vector similarity search + similaritySearch( + query: TypedArray, + options?: VectorSearchOptions + ): Promise<(Entity & { score: number })[]>; + + // Hybrid search (optional - not all implementations support it) + hybridSearch?( + query: TypedArray, + options: HybridSearchOptions + ): Promise<(Entity & { score: number })[]>; +} +``` + +### Inherited Tabular Methods + +From `ITabularStorage`: + +```typescript +// Setup +setupDatabase(): Promise; + +// CRUD Operations +put(entity: Entity): Promise; +putMany(entities: Entity[]): Promise; +get(key: PrimaryKey): Promise; +getAll(): Promise; +delete(key: PrimaryKey): Promise; +deleteMany(keys: PrimaryKey[]): Promise; + +// Utility +size(): Promise; +clear(): Promise; +destroy(): void; +``` + +### Search Options + +```typescript +interface VectorSearchOptions> { + readonly topK?: number; // Number of results (default: 10) + readonly filter?: Partial; // Filter by metadata fields + readonly scoreThreshold?: number; // Minimum score 0-1 (default: 0) +} + +interface HybridSearchOptions extends VectorSearchOptions { + readonly textQuery: string; // Full-text query keywords + readonly vectorWeight?: number; // Vector weight 0-1 (default: 0.7) +} +``` + +## Global Registry + +Register and retrieve chunk vector repositories globally: + +```typescript +import { + registerChunkVectorRepository, + getChunkVectorRepository, + getGlobalChunkVectorRepositories, +} from "@workglow/storage"; + +// Register a repository +registerChunkVectorRepository("my-chunks", repo); + +// Retrieve by ID +const repo = getChunkVectorRepository("my-chunks"); + +// Get all registered repositories +const allRepos = getGlobalChunkVectorRepositories(); +``` + +## Quantization Benefits + +Quantized vectors reduce storage and can improve performance: + +| Vector Type | Bytes/Dim | Storage vs Float32 | Use Case | +| ------------ | --------- | ------------------ | 
------------------------------------ | +| Float32Array | 4 | 100% (baseline) | Standard embeddings | +| Float64Array | 8 | 200% | High precision needed | +| Float16Array | 2 | 50% | Great precision/size tradeoff | +| Int16Array | 2 | 50% | Good precision/size tradeoff | +| Int8Array | 1 | 25% | Binary quantization, max compression | +| Uint8Array | 1 | 25% | Quantized embeddings [0-255] | + +**Example:** A 768-dimensional embedding: + +- Float32: 3,072 bytes +- Int8: 768 bytes (75% reduction!) + +## Performance Considerations + +### InMemory + +- **Best for:** Testing, small datasets (<10K vectors), development +- **Pros:** Fastest, no dependencies, supports all vector types +- **Cons:** No persistence, memory limited + +### SQLite + +- **Best for:** Local apps, medium datasets (<100K vectors) +- **Pros:** Persistent, single file, no server +- **Cons:** No native vector indexing (linear scan), slower for large datasets + +### PostgreSQL + pgvector + +- **Best for:** Production, large datasets (>100K vectors) +- **Pros:** Native HNSW/IVFFlat indexing, efficient similarity search, scalable +- **Cons:** Requires PostgreSQL server and pgvector extension +- **Setup:** `CREATE EXTENSION vector;` + +## Integration with DocumentRepository + +The chunk vector repository works alongside `DocumentRepository` for hierarchical document storage: + +```typescript +import { + DocumentRepository, + InMemoryChunkVectorStorage, + InMemoryTabularStorage, +} from "@workglow/storage"; +import { DocumentStorageSchema } from "@workglow/storage"; + +// Initialize storage backends +const tabularStorage = new InMemoryTabularStorage(DocumentStorageSchema, ["doc_id"]); +await tabularStorage.setupDatabase(); + +const vectorStorage = new InMemoryChunkVectorStorage(384); +await vectorStorage.setupDatabase(); + +// Create document repository with both storages +const docRepo = new DocumentRepository(tabularStorage, vectorStorage); + +// Store document structure in tabular, chunks in vector +await docRepo.upsert(document); + +// Search chunks by vector similarity +const results = await docRepo.search(queryVector, { topK: 5 }); +``` + +### Chunk Metadata for Hierarchical Documents + +When using hierarchical chunking, chunk metadata typically includes: + +```typescript +metadata: { + text: string; // Chunk text content + leafNodeId?: string; // Reference to document tree node + depth?: number; // Hierarchy depth + nodePath?: string[]; // Node IDs from root to leaf + summary?: string; // Summary of the chunk content + entities?: Entity[]; // Named entities extracted from the chunk +} +``` + +## License + +Apache 2.0 diff --git a/packages/storage/src/chunk-vector/SqliteChunkVectorStorage.ts b/packages/storage/src/vector/SqliteChunkVectorStorage.ts similarity index 100% rename from packages/storage/src/chunk-vector/SqliteChunkVectorStorage.ts rename to packages/storage/src/vector/SqliteChunkVectorStorage.ts diff --git a/packages/test/package.json b/packages/test/package.json index 19cf0a83..8c4077e3 100644 --- a/packages/test/package.json +++ b/packages/test/package.json @@ -35,6 +35,7 @@ "peerDependencies": { "@workglow/ai": "workspace:*", "@workglow/ai-provider": "workspace:*", + "@workglow/dataset": "workspace:*", "@workglow/job-queue": "workspace:*", "@workglow/storage": "workspace:*", "@workglow/task-graph": "workspace:*", @@ -52,6 +53,9 @@ "@workglow/ai-provider": { "optional": false }, + "@workglow/dataset": { + "optional": false + }, "@workglow/job-queue": { "optional": false }, @@ -75,6 +79,7 @@ "@electric-sql/pglite": 
"^0.3.11", "@workglow/ai": "workspace:*", "@workglow/ai-provider": "workspace:*", + "@workglow/dataset": "workspace:*", "@workglow/job-queue": "workspace:*", "@workglow/sqlite": "workspace:*", "@workglow/storage": "workspace:*", diff --git a/packages/test/src/test/rag/ChunkToVector.test.ts b/packages/test/src/test/rag/ChunkToVector.test.ts index e1d00ea1..eea38722 100644 --- a/packages/test/src/test/rag/ChunkToVector.test.ts +++ b/packages/test/src/test/rag/ChunkToVector.test.ts @@ -6,7 +6,7 @@ import "@workglow/ai"; // Trigger Workflow prototype extensions import type { ChunkToVectorTaskOutput, HierarchicalChunkerTaskOutput } from "@workglow/ai"; -import { type ChunkNode, NodeIdGenerator, StructuralParser } from "@workglow/storage"; +import { type ChunkNode, NodeIdGenerator, StructuralParser } from "@workglow/dataset"; import { Workflow } from "@workglow/task-graph"; import { describe, expect, it } from "vitest"; diff --git a/packages/test/src/test/rag/Document.test.ts b/packages/test/src/test/rag/Document.test.ts index e89195dd..cc6e841a 100644 --- a/packages/test/src/test/rag/Document.test.ts +++ b/packages/test/src/test/rag/Document.test.ts @@ -4,8 +4,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { ChunkNode, DocumentNode } from "@workglow/storage"; -import { Document, NodeKind } from "@workglow/storage"; +import type { ChunkNode, DocumentNode } from "@workglow/dataset"; +import { Document, NodeKind } from "@workglow/dataset"; import { describe, expect, test } from "vitest"; describe("Document", () => { diff --git a/packages/test/src/test/rag/DocumentNodeRetrievalTask.test.ts b/packages/test/src/test/rag/DocumentNodeRetrievalTask.test.ts index 0a5aaf72..e4fb9ea2 100644 --- a/packages/test/src/test/rag/DocumentNodeRetrievalTask.test.ts +++ b/packages/test/src/test/rag/DocumentNodeRetrievalTask.test.ts @@ -5,7 +5,7 @@ */ import { retrieval } from "@workglow/ai"; -import { InMemoryChunkVectorStorage, registerChunkVectorRepository } from "@workglow/storage"; +import { InMemoryChunkVectorStorage, registerChunkVectorRepository } from "@workglow/dataset"; import { afterEach, beforeEach, describe, expect, test } from "vitest"; describe("DocumentNodeRetrievalTask", () => { diff --git a/packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts b/packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts index e1fd3526..1b61392d 100644 --- a/packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts +++ b/packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts @@ -5,7 +5,7 @@ */ import { ChunkVectorSearchTask } from "@workglow/ai"; -import { InMemoryChunkVectorStorage, registerChunkVectorRepository } from "@workglow/storage"; +import { InMemoryChunkVectorStorage, registerChunkVectorRepository } from "@workglow/dataset"; import { afterEach, beforeEach, describe, expect, test } from "vitest"; describe("ChunkVectorSearchTask", () => { diff --git a/packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts b/packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts index 72c5c6b8..50bf53e2 100644 --- a/packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts +++ b/packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts @@ -5,7 +5,7 @@ */ import { ChunkVectorUpsertTask } from "@workglow/ai"; -import { InMemoryChunkVectorStorage, registerChunkVectorRepository } from "@workglow/storage"; +import { InMemoryChunkVectorStorage, registerChunkVectorRepository } from "@workglow/dataset"; import { afterEach, beforeEach, 
describe, expect, test } from "vitest"; describe("ChunkVectorUpsertTask", () => { diff --git a/packages/test/src/test/rag/DocumentRepository.test.ts b/packages/test/src/test/rag/DocumentRepository.test.ts index 64f4ec92..15bb6530 100644 --- a/packages/test/src/test/rag/DocumentRepository.test.ts +++ b/packages/test/src/test/rag/DocumentRepository.test.ts @@ -4,17 +4,17 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { InMemoryTabularStorage } from "@workglow/storage"; import { Document, DocumentRepository, DocumentStorageKey, DocumentStorageSchema, InMemoryChunkVectorStorage, - InMemoryTabularStorage, NodeIdGenerator, NodeKind, StructuralParser, -} from "@workglow/storage"; +} from "@workglow/dataset"; import { beforeEach, describe, expect, it } from "vitest"; describe("DocumentRepository", () => { diff --git a/packages/test/src/test/rag/EndToEnd.test.ts b/packages/test/src/test/rag/EndToEnd.test.ts index 3fac3586..fbfb3afe 100644 --- a/packages/test/src/test/rag/EndToEnd.test.ts +++ b/packages/test/src/test/rag/EndToEnd.test.ts @@ -11,13 +11,18 @@ import { DocumentStorageKey, DocumentStorageSchema, InMemoryChunkVectorStorage, - InMemoryTabularStorage, NodeIdGenerator, StructuralParser, -} from "@workglow/storage"; -import { describe, expect, it } from "vitest"; +} from "@workglow/dataset"; +import { InMemoryTabularStorage } from "@workglow/storage"; +import { beforeAll, describe, expect, it } from "vitest"; +import { registerTasks } from "../../binding/RegisterTasks"; describe("End-to-end hierarchical RAG", () => { + beforeAll(async () => { + registerTasks(); + }); + it("should demonstrate chainable design (chunks → text array)", async () => { // Sample markdown document const markdown = `# Machine Learning @@ -36,7 +41,6 @@ Finds patterns in data.`; const doc_id = await NodeIdGenerator.generateDocId("ml-guide", markdown); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "ML Guide"); - // CHAINABLE DESIGN TEST - Use workflow to verify chaining const chunkResult = await hierarchicalChunker({ doc_id, documentTree: root, diff --git a/packages/test/src/test/rag/FullChain.test.ts b/packages/test/src/test/rag/FullChain.test.ts index 0ee22153..d7635b39 100644 --- a/packages/test/src/test/rag/FullChain.test.ts +++ b/packages/test/src/test/rag/FullChain.test.ts @@ -5,11 +5,16 @@ */ import { HierarchicalChunkerTaskOutput } from "@workglow/ai"; -import { ChunkNode, InMemoryChunkVectorStorage, NodeIdGenerator } from "@workglow/storage"; +import { ChunkNode, InMemoryChunkVectorStorage, NodeIdGenerator } from "@workglow/dataset"; import { Workflow } from "@workglow/task-graph"; -import { describe, expect, it } from "vitest"; +import { beforeAll, describe, expect, it } from "vitest"; +import { registerTasks } from "../../binding/RegisterTasks"; describe("Complete chainable workflow", () => { + beforeAll(async () => { + registerTasks(); + }); + it("should chain from parsing to storage without loops", async () => { const vectorRepo = new InMemoryChunkVectorStorage(3); await vectorRepo.setupDatabase(); @@ -45,7 +50,6 @@ This is the second section with more content.`; // Verify the chain worked - final output from hierarchicalChunker expect(result.doc_id).toBeDefined(); - expect(result.doc_id).toMatch(/^doc_[0-9a-f]{16}$/); expect(result.chunks).toBeDefined(); expect(result.text).toBeDefined(); expect(result.count).toBeGreaterThan(0); @@ -90,30 +94,6 @@ This is the second section with more content.`; } }); - it("should generate consistent doc_id across chains", async () => { - const 
markdown = "# Test\n\nContent."; - - // Run twice with same content - const result1 = await new Workflow() - .structuralParser({ - text: markdown, - title: "Test", - sourceUri: "test.md", - }) - .run(); - - const result2 = await new Workflow() - .structuralParser({ - text: markdown, - title: "Test", - sourceUri: "test.md", - }) - .run(); - - // Should generate same doc_id (deterministic) - expect(result1.doc_id).toBe(result2.doc_id); - }); - it("should allow doc_id override for variant creation", async () => { const markdown = "# Test\n\nContent."; const customId = await NodeIdGenerator.generateDocId("custom", markdown); diff --git a/packages/test/src/test/rag/HierarchicalChunker.test.ts b/packages/test/src/test/rag/HierarchicalChunker.test.ts index c7f943dc..a1fc8c3a 100644 --- a/packages/test/src/test/rag/HierarchicalChunker.test.ts +++ b/packages/test/src/test/rag/HierarchicalChunker.test.ts @@ -5,7 +5,7 @@ */ import { hierarchicalChunker } from "@workglow/ai"; -import { estimateTokens, NodeIdGenerator, StructuralParser } from "@workglow/storage"; +import { estimateTokens, NodeIdGenerator, StructuralParser } from "@workglow/dataset"; import { Workflow } from "@workglow/task-graph"; import { describe, expect, it } from "vitest"; diff --git a/packages/test/src/test/rag/HybridSearchTask.test.ts b/packages/test/src/test/rag/HybridSearchTask.test.ts index 79803922..7d1ac428 100644 --- a/packages/test/src/test/rag/HybridSearchTask.test.ts +++ b/packages/test/src/test/rag/HybridSearchTask.test.ts @@ -5,7 +5,7 @@ */ import { hybridSearch } from "@workglow/ai"; -import { InMemoryChunkVectorStorage, registerChunkVectorRepository } from "@workglow/storage"; +import { InMemoryChunkVectorStorage, registerChunkVectorRepository } from "@workglow/dataset"; import { afterEach, beforeEach, describe, expect, test } from "vitest"; describe("ChunkVectorHybridSearchTask", () => { diff --git a/packages/test/src/test/rag/RagWorkflow.test.ts b/packages/test/src/test/rag/RagWorkflow.test.ts index 412033db..c3a55c3f 100644 --- a/packages/test/src/test/rag/RagWorkflow.test.ts +++ b/packages/test/src/test/rag/RagWorkflow.test.ts @@ -46,14 +46,14 @@ import { VectorStoreUpsertTaskOutput, } from "@workglow/ai"; import { register_HFT_InlineJobFns } from "@workglow/ai-provider"; +import { InMemoryTabularStorage } from "@workglow/storage"; import { DocumentRepository, DocumentStorageKey, DocumentStorageSchema, InMemoryChunkVectorStorage, - InMemoryTabularStorage, registerChunkVectorRepository, -} from "@workglow/storage"; +} from "@workglow/dataset"; import { getTaskQueueRegistry, setTaskQueueRegistry, Workflow } from "@workglow/task-graph"; import { readdirSync } from "fs"; import { join } from "path"; diff --git a/packages/test/src/test/rag/StructuralParser.test.ts b/packages/test/src/test/rag/StructuralParser.test.ts index 51d3b8e9..43d0059a 100644 --- a/packages/test/src/test/rag/StructuralParser.test.ts +++ b/packages/test/src/test/rag/StructuralParser.test.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { NodeIdGenerator, NodeKind, StructuralParser } from "@workglow/storage"; +import { NodeKind, StructuralParser } from "@workglow/dataset"; import { describe, expect, it } from "vitest"; describe("StructuralParser", () => { @@ -150,55 +150,4 @@ Third paragraph here.`; expect(root.children[0].kind).toBe(NodeKind.PARAGRAPH); }); }); - - describe("NodeIdGenerator", () => { - it("should generate consistent docIds", async () => { - const id1 = await NodeIdGenerator.generateDocId("source1", "content"); 
- const id2 = await NodeIdGenerator.generateDocId("source1", "content"); - - expect(id1).toBe(id2); - expect(id1).toMatch(/^doc_[0-9a-f]{16}$/); - }); - - it("should generate different IDs for different content", () => { - const id1 = NodeIdGenerator.generateDocId("source", "content1"); - const id2 = NodeIdGenerator.generateDocId("source", "content2"); - - expect(id1).not.toBe(id2); - }); - - it("should generate consistent structural node IDs", async () => { - const doc_id = "doc_test"; - const range = { startOffset: 0, endOffset: 100 }; - - const id1 = await NodeIdGenerator.generateStructuralNodeId(doc_id, NodeKind.SECTION, range); - const id2 = await NodeIdGenerator.generateStructuralNodeId(doc_id, NodeKind.SECTION, range); - - expect(id1).toBe(id2); - expect(id1).toMatch(/^node_[0-9a-f]{16}$/); - }); - - it("should generate consistent child node IDs", async () => { - const parentId = "node_parent"; - const ordinal = 2; - - const id1 = await NodeIdGenerator.generateChildNodeId(parentId, ordinal); - const id2 = await NodeIdGenerator.generateChildNodeId(parentId, ordinal); - - expect(id1).toBe(id2); - expect(id1).toMatch(/^node_[0-9a-f]{16}$/); - }); - - it("should generate consistent chunk IDs", async () => { - const doc_id = "doc_test"; - const leafNodeId = "node_leaf"; - const ordinal = 0; - - const id1 = await NodeIdGenerator.generateChunkId(doc_id, leafNodeId, ordinal); - const id2 = await NodeIdGenerator.generateChunkId(doc_id, leafNodeId, ordinal); - - expect(id1).toBe(id2); - expect(id1).toMatch(/^chunk_[0-9a-f]{16}$/); - }); - }); }); diff --git a/packages/test/src/test/task-graph/InputResolver.test.ts b/packages/test/src/test/task-graph/InputResolver.test.ts index eac89d82..d52ddbf3 100644 --- a/packages/test/src/test/task-graph/InputResolver.test.ts +++ b/packages/test/src/test/task-graph/InputResolver.test.ts @@ -9,8 +9,8 @@ import { getGlobalTabularRepositories, InMemoryTabularStorage, registerTabularRepository, - TypeTabularRepository, } from "@workglow/storage"; +import { TypeTabularRepository } from "@workglow/dataset"; import { IExecuteContext, resolveSchemaInputs, Task, TaskRegistry } from "@workglow/task-graph"; import { getInputResolvers, diff --git a/packages/test/src/test/util/Document.test.ts b/packages/test/src/test/util/Document.test.ts index 85cc0635..c81c3b15 100644 --- a/packages/test/src/test/util/Document.test.ts +++ b/packages/test/src/test/util/Document.test.ts @@ -4,8 +4,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { ChunkNode, DocumentNode } from "@workglow/storage"; -import { Document, NodeKind } from "@workglow/storage"; +import type { ChunkNode, DocumentNode } from "@workglow/dataset"; +import { Document, NodeKind } from "@workglow/dataset"; import { describe, expect, test } from "vitest"; describe("Document", () => { From 1d889436db84b3a27b1d05ade0de85e1c644218c Mon Sep 17 00:00:00 2001 From: Steven Roussey Date: Wed, 14 Jan 2026 07:24:32 +0000 Subject: [PATCH 12/14] [refactor] Transition from Repository to Storage or Dataset Naming Convention - Updated the codebase to replace references from `repository` to `dataset`, enhancing clarity in data management terminology. - Refactored various tasks and schemas to align with the new dataset structure, including `DocumentChunkDataset` and `DocumentDataset`. - Removed deprecated chunk vector storage components and introduced new vector storage implementations for PostgreSQL and SQLite. 
- Enhanced the TODO list with new items related to dataset management and improved documentation to reflect these changes. - Updated tests to ensure compatibility with the new dataset naming and structure. --- TODO.md | 9 +- docs/developers/03_extending.md | 18 +- packages/ai-provider/package.json | 4 +- packages/ai/package.json | 4 +- packages/ai/src/model/ModelRegistry.ts | 2 +- packages/ai/src/task/ChunkRetrievalTask.ts | 12 +- .../src/task/ChunkVectorHybridSearchTask.ts | 12 +- packages/ai/src/task/ChunkVectorSearchTask.ts | 14 +- packages/ai/src/task/ChunkVectorUpsertTask.ts | 12 +- .../ai/src/task/HierarchicalChunkerTask.ts | 7 +- packages/ai/src/task/HierarchyJoinTask.ts | 2 +- packages/ai/src/task/StructuralParserTask.ts | 7 +- packages/dataset/README.md | 18 +- packages/dataset/package.json | 4 +- .../ChunkVectorStorageRegistry.ts | 83 ----- .../src/chunk-vector/IChunkVectorStorage.ts | 105 ------- .../InMemoryChunkVectorStorage.ts | 185 ----------- .../PostgresChunkVectorStorage.ts | 293 ------------------ .../chunk-vector/SqliteChunkVectorStorage.ts | 192 ------------ packages/dataset/src/common.ts | 12 +- .../document-chunk/DocumentChunkDataset.ts | 127 ++++++++ .../DocumentChunkDatasetRegistry.ts | 79 +++++ .../DocumentChunkSchema.ts} | 21 +- .../README.md | 177 ++++++----- .../dataset/src/document/DocumentDataset.ts | 204 ++++++++++++ .../src/document/DocumentDatasetRegistry.ts | 79 +++++ packages/dataset/src/document/DocumentNode.ts | 57 ---- .../src/document/DocumentRepository.ts | 18 +- .../document/DocumentRepositoryRegistry.ts | 79 ----- .../src/document/DocumentStorageSchema.ts | 7 + .../dataset/src/document/StructuralParser.ts | 26 +- packages/dataset/src/util/DatasetSchema.ts | 89 ++++++ packages/dataset/src/util/RepositorySchema.ts | 96 ------ packages/dataset/tsconfig.json | 2 +- packages/debug/package.json | 4 +- packages/job-queue/package.json | 4 +- packages/sqlite/package.json | 8 +- packages/storage/README.md | 8 +- packages/storage/package.json | 4 +- packages/storage/src/common-server.ts | 4 +- packages/storage/src/common.ts | 6 +- .../src/tabular/TabularStorageRegistry.ts | 6 +- .../storage/src/vector/ChunkVectorSchema.ts | 35 --- .../src/vector/ChunkVectorStorageRegistry.ts | 83 ----- ...hunkVectorStorage.ts => IVectorStorage.ts} | 75 +++-- ...torStorage.ts => InMemoryVectorStorage.ts} | 80 +++-- ...torStorage.ts => PostgresVectorStorage.ts} | 152 +++++---- packages/storage/src/vector/README.md | 172 ++++++---- ...ectorStorage.ts => SqliteVectorStorage.ts} | 92 +++--- packages/task-graph/package.json | 4 +- packages/task-graph/src/task/InputResolver.ts | 4 +- packages/task-graph/src/task/README.md | 2 +- packages/tasks/package.json | 4 +- packages/test/package.json | 4 +- .../test/src/test/rag/ChunkToVector.test.ts | 5 +- ....ts => DocumentChunkRetrievalTask.test.ts} | 67 ++-- ...est.ts => DocumentChunkSearchTask.test.ts} | 72 +++-- ...est.ts => DocumentChunkUpsertTask.test.ts} | 58 ++-- .../src/test/rag/DocumentRepository.test.ts | 154 ++++----- packages/test/src/test/rag/EndToEnd.test.ts | 28 +- packages/test/src/test/rag/FullChain.test.ts | 15 +- .../src/test/rag/HierarchicalChunker.test.ts | 13 +- .../src/test/rag/HybridSearchTask.test.ts | 69 +++-- .../test/src/test/rag/RagWorkflow.test.ts | 45 ++- .../src/test/task-graph/InputResolver.test.ts | 80 ++--- packages/util/package.json | 4 +- packages/util/src/di/InputResolverRegistry.ts | 29 +- 67 files changed, 1554 insertions(+), 1892 deletions(-) delete mode 100644 
packages/dataset/src/chunk-vector/ChunkVectorStorageRegistry.ts
 delete mode 100644 packages/dataset/src/chunk-vector/IChunkVectorStorage.ts
 delete mode 100644 packages/dataset/src/chunk-vector/InMemoryChunkVectorStorage.ts
 delete mode 100644 packages/dataset/src/chunk-vector/PostgresChunkVectorStorage.ts
 delete mode 100644 packages/dataset/src/chunk-vector/SqliteChunkVectorStorage.ts
 create mode 100644 packages/dataset/src/document-chunk/DocumentChunkDataset.ts
 create mode 100644 packages/dataset/src/document-chunk/DocumentChunkDatasetRegistry.ts
 rename packages/dataset/src/{chunk-vector/ChunkVectorSchema.ts => document-chunk/DocumentChunkSchema.ts} (53%)
 rename packages/dataset/src/{chunk-vector => document-chunk}/README.md (65%)
 create mode 100644 packages/dataset/src/document/DocumentDataset.ts
 create mode 100644 packages/dataset/src/document/DocumentDatasetRegistry.ts
 delete mode 100644 packages/dataset/src/document/DocumentRepositoryRegistry.ts
 create mode 100644 packages/dataset/src/util/DatasetSchema.ts
 delete mode 100644 packages/dataset/src/util/RepositorySchema.ts
 delete mode 100644 packages/storage/src/vector/ChunkVectorSchema.ts
 delete mode 100644 packages/storage/src/vector/ChunkVectorStorageRegistry.ts
 rename packages/storage/src/vector/{IChunkVectorStorage.ts => IVectorStorage.ts} (53%)
 rename packages/storage/src/vector/{InMemoryChunkVectorStorage.ts => InMemoryVectorStorage.ts} (68%)
 rename packages/storage/src/vector/{PostgresChunkVectorStorage.ts => PostgresVectorStorage.ts} (58%)
 rename packages/storage/src/vector/{SqliteChunkVectorStorage.ts => SqliteVectorStorage.ts} (65%)
 rename packages/test/src/test/rag/{DocumentNodeRetrievalTask.test.ts => DocumentChunkRetrievalTask.test.ts} (86%)
 rename packages/test/src/test/rag/{DocumentNodeVectorSearchTask.test.ts => DocumentChunkSearchTask.test.ts} (81%)
 rename packages/test/src/test/rag/{DocumentNodeVectorStoreUpsertTask.test.ts => DocumentChunkUpsertTask.test.ts} (81%)

diff --git a/TODO.md b/TODO.md
index 4edd6c9d..bab50c98 100644
--- a/TODO.md
+++ b/TODO.md
@@ -2,13 +2,13 @@ TODO.md
 
 - [x] Rename repositories in the packages/storage to use the word Storage instead of Repository.
 - [ ] Vector Storage (not chunk storage)
-  - [ ] Rename the files from packages/storage/src/vector-storage to packages/storage/src/vector
-  - [ ] No fixed column names, use the schema to define the columns.
+  - [x] Rename the files from packages/storage/src/vector-storage to packages/storage/src/vector
+  - [x] No fixed column names, use the schema to define the columns.
   - [ ] Option for which column to use if there are multiple, default to the first one.
   - [ ] Use @mceachen/sqlite-vec for sqlite storage.
 - [ ] Datasets Package
-  - [ ] Documents repository (mabye rename to DocumentDataset)
-  - [ ] Chunks repository (maybe rename to ChunkDataset) or DocumentChunksDataset? Or just part of DocumentDataset? Or is it a new thing?
+  - [x] Documents dataset (maybe rename to DocumentDataset)
+  - [ ] Chunks Package (or part of DocumentDataset?)
   - [ ] Move Model repository to datasets package.
 - [ ] Chunk Repository
   - [ ] Add to packages/tasks or packages/ai
@@ -17,6 +17,7 @@ TODO.md
   - [ ] Chunks and nodes are not always the same.
   - [ ] And we may need to save the chunk's node path. Or paths? or document range? Standard metadata?
 - [ ] Use Repository to always envelope the storage operations (for transactions, dealing with IDs, etc).
+- [ ] Instead of passing doc_id around, pass a document key that is unknonwn (string or object) - [ ] Get a better model for question answering. - [ ] Get a better model for named entity recognition, the current one recognized everything as a token, not helpful. diff --git a/docs/developers/03_extending.md b/docs/developers/03_extending.md index 875ccb76..f0e27797 100644 --- a/docs/developers/03_extending.md +++ b/docs/developers/03_extending.md @@ -138,13 +138,13 @@ When defining task input schemas, you can use `format` annotations to enable aut The system supports several format annotations out of the box: -| Format | Description | Helper Function | -| --------------------------------- | ----------------------------------- | ----------------------------- | -| `model` | Any AI model configuration | `TypeModel()` | -| `model:TaskName` | Model compatible with specific task | — | -| `repository:tabular` | Tabular data repository | `TypeTabularRepository()` | -| `repository:document-node-vector` | Vector storage repository | `TypeChunkVectorRepository()` | -| `repository:document` | Document repository | `TypeDocumentRepository()` | +| Format | Description | Helper Function | +| ------------------------------ | ----------------------------------- | ----------------------------- | +| `model` | Any AI model configuration | `TypeModel()` | +| `model:TaskName` | Model compatible with specific task | — | +| `storage:tabular` | Tabular data dataset | `TypeTabularRepository()` | +| `dataset:document-node-vector` | Vector storage dataset | `TypeChunkVectorRepository()` | +| `dataset:document` | Document dataset | `TypeDocumentRepository()` | ### Example: Using Format Annotations @@ -316,7 +316,7 @@ await new Workflow() }) .chunkToVector() .vectorStoreUpsert({ - repository: vectorRepo, + dataset: vectorDataset, }) .run(); ``` @@ -330,7 +330,7 @@ const answer = await new Workflow() model: "Xenova/all-MiniLM-L6-v2", }) .vectorStoreSearch({ - repository: vectorRepo, + dataset: vectorDataset, topK: 10, }) .reranker({ diff --git a/packages/ai-provider/package.json b/packages/ai-provider/package.json index e13702e7..3db05d11 100644 --- a/packages/ai-provider/package.json +++ b/packages/ai-provider/package.json @@ -17,8 +17,8 @@ }, "exports": { ".": { - "bun": "./dist/index.js", - "types": "./dist/types.d.ts", + "bun": "./src/index.ts", + "types": "./src/types.ts", "import": "./dist/index.js" } }, diff --git a/packages/ai/package.json b/packages/ai/package.json index 66c27682..218a0f7f 100644 --- a/packages/ai/package.json +++ b/packages/ai/package.json @@ -23,8 +23,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./dist/bun.js", - "types": "./dist/types.d.ts", + "bun": "./src/bun.ts", + "types": "./src/types.ts", "import": "./dist/node.js" } }, diff --git a/packages/ai/src/model/ModelRegistry.ts b/packages/ai/src/model/ModelRegistry.ts index 2a7a3deb..f570dcf8 100644 --- a/packages/ai/src/model/ModelRegistry.ts +++ b/packages/ai/src/model/ModelRegistry.ts @@ -59,7 +59,7 @@ async function resolveModelFromRegistry( if (Array.isArray(id)) { const results = await Promise.all(id.map((i) => modelRepo.findByName(i))); - return results.filter((model) => model !== undefined) as ModelConfig[]; + return results.filter((model): model is NonNullable => model !== undefined); } const model = await modelRepo.findByName(id); diff --git a/packages/ai/src/task/ChunkRetrievalTask.ts b/packages/ai/src/task/ChunkRetrievalTask.ts index 5a6b0d49..7b16587b 100644 --- 
a/packages/ai/src/task/ChunkRetrievalTask.ts +++ b/packages/ai/src/task/ChunkRetrievalTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { AnyChunkVectorStorage, TypeChunkVectorRepository } from "@workglow/dataset"; +import { DocumentChunkDataset, TypeDocumentChunkDataset } from "@workglow/dataset"; import { CreateWorkflow, IExecuteContext, @@ -25,7 +25,7 @@ import { TextEmbeddingTask } from "./TextEmbeddingTask"; const inputSchema = { type: "object", properties: { - repository: TypeChunkVectorRepository({ + dataset: TypeDocumentChunkDataset({ title: "Document Chunk Vector Repository", description: "The document chunk vector repository instance to search in", }), @@ -72,14 +72,14 @@ const inputSchema = { default: false, }, }, - required: ["repository", "query"], + required: ["dataset", "query"], if: { properties: { query: { type: "string" }, }, }, then: { - required: ["repository", "query", "model"], + required: ["dataset", "query", "model"], }, additionalProperties: false, } as const satisfies DataPortSchema; @@ -162,7 +162,7 @@ export class DocumentNodeRetrievalTask extends Task< async execute(input: RetrievalTaskInput, context: IExecuteContext): Promise { const { - repository, + dataset, query, topK = 5, filter, @@ -172,7 +172,7 @@ export class DocumentNodeRetrievalTask extends Task< } = input; // Repository is resolved by input resolver system before execution - const repo = repository as AnyChunkVectorStorage; + const repo = dataset as DocumentChunkDataset; // Determine query vector let queryVector: TypedArray; diff --git a/packages/ai/src/task/ChunkVectorHybridSearchTask.ts b/packages/ai/src/task/ChunkVectorHybridSearchTask.ts index e8b5a6d6..486287da 100644 --- a/packages/ai/src/task/ChunkVectorHybridSearchTask.ts +++ b/packages/ai/src/task/ChunkVectorHybridSearchTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { AnyChunkVectorStorage, TypeChunkVectorRepository } from "@workglow/dataset"; +import { DocumentChunkDataset, TypeDocumentChunkDataset } from "@workglow/dataset"; import { CreateWorkflow, IExecuteContext, @@ -22,7 +22,7 @@ import { const inputSchema = { type: "object", properties: { - repository: TypeChunkVectorRepository({ + dataset: TypeDocumentChunkDataset({ title: "Document Chunk Vector Repository", description: "The document chunk vector repository instance to search in (must support hybridSearch)", @@ -71,7 +71,7 @@ const inputSchema = { default: false, }, }, - required: ["repository", "queryVector", "queryText"], + required: ["dataset", "queryVector", "queryText"], additionalProperties: false, } as const satisfies DataPortSchema; @@ -160,7 +160,7 @@ export class ChunkVectorHybridSearchTask extends Task< context: IExecuteContext ): Promise { const { - repository, + dataset, queryVector, queryText, topK = 10, @@ -171,11 +171,11 @@ export class ChunkVectorHybridSearchTask extends Task< } = input; // Repository is resolved by input resolver system before execution - const repo = repository as AnyChunkVectorStorage; + const repo = dataset as DocumentChunkDataset; // Check if repository supports hybrid search if (!repo.hybridSearch) { - throw new Error("Repository does not support hybrid search."); + throw new Error("Dataset does not support hybrid search."); } // Convert to Float32Array for repository search (repo expects Float32Array by default) diff --git a/packages/ai/src/task/ChunkVectorSearchTask.ts b/packages/ai/src/task/ChunkVectorSearchTask.ts index c94263ec..b27e366c 100644 --- 
a/packages/ai/src/task/ChunkVectorSearchTask.ts +++ b/packages/ai/src/task/ChunkVectorSearchTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { AnyChunkVectorStorage, TypeChunkVectorRepository } from "@workglow/dataset"; +import { DocumentChunkDataset, TypeDocumentChunkDataset } from "@workglow/dataset"; import { CreateWorkflow, IExecuteContext, @@ -22,7 +22,7 @@ import { const inputSchema = { type: "object", properties: { - repository: TypeChunkVectorRepository({ + dataset: TypeDocumentChunkDataset({ title: "Vector Repository", description: "The vector repository instance to search in", }), @@ -51,7 +51,7 @@ const inputSchema = { default: 0, }, }, - required: ["repository", "query"], + required: ["dataset", "query"], additionalProperties: false, } as const satisfies DataPortSchema; @@ -103,7 +103,7 @@ export type VectorStoreSearchTaskInput = FromSchema; /** - * Task for searching similar vectors in a vector repository. + * Task for searching similar vectors in a document chunk dataset. * Returns top-K most similar vectors with their metadata and scores. */ export class ChunkVectorSearchTask extends Task< @@ -114,7 +114,7 @@ export class ChunkVectorSearchTask extends Task< public static type = "ChunkVectorSearchTask"; public static category = "Vector Store"; public static title = "Vector Store Search"; - public static description = "Search for similar vectors in a vector repository"; + public static description = "Search for similar vectors in a document chunk dataset"; public static cacheable = true; public static inputSchema(): DataPortSchema { @@ -129,9 +129,9 @@ export class ChunkVectorSearchTask extends Task< input: VectorStoreSearchTaskInput, context: IExecuteContext ): Promise { - const { repository, query, topK = 10, filter, scoreThreshold = 0 } = input; + const { dataset, query, topK = 10, filter, scoreThreshold = 0 } = input; - const repo = repository as AnyChunkVectorStorage; + const repo = dataset as DocumentChunkDataset; const results = await repo.similaritySearch(query, { topK, diff --git a/packages/ai/src/task/ChunkVectorUpsertTask.ts b/packages/ai/src/task/ChunkVectorUpsertTask.ts index bb5d1152..880c89a3 100644 --- a/packages/ai/src/task/ChunkVectorUpsertTask.ts +++ b/packages/ai/src/task/ChunkVectorUpsertTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { AnyChunkVectorStorage, TypeChunkVectorRepository } from "@workglow/dataset"; +import { DocumentChunkDataset, TypeDocumentChunkDataset } from "@workglow/dataset"; import { CreateWorkflow, IExecuteContext, @@ -23,7 +23,7 @@ import { TypeSingleOrArray } from "./base/AiTaskSchemas"; const inputSchema = { type: "object", properties: { - repository: TypeChunkVectorRepository({ + dataset: TypeDocumentChunkDataset({ title: "Document Chunk Vector Repository", description: "The document chunk vector repository instance to store vectors in", }), @@ -45,7 +45,7 @@ const inputSchema = { additionalProperties: true, }), }, - required: ["repository", "doc_id", "vectors", "metadata"], + required: ["dataset", "doc_id", "vectors", "metadata"], additionalProperties: false, } as const satisfies DataPortSchema; @@ -91,7 +91,7 @@ export class ChunkVectorUpsertTask extends Task< public static type = "ChunkVectorUpsertTask"; public static category = "Vector Store"; public static title = "Vector Store Upsert"; - public static description = "Store vector embeddings with metadata in a vector repository"; + public static description = "Store vector embeddings with metadata in a document chunk dataset"; 
public static cacheable = false; // Has side effects public static inputSchema(): DataPortSchema { @@ -106,7 +106,7 @@ export class ChunkVectorUpsertTask extends Task< input: VectorStoreUpsertTaskInput, context: IExecuteContext ): Promise { - const { repository, doc_id, vectors, metadata } = input; + const { dataset, doc_id, vectors, metadata } = input; // Normalize inputs to arrays const vectorArray = Array.isArray(vectors) ? vectors : [vectors]; @@ -114,7 +114,7 @@ export class ChunkVectorUpsertTask extends Task< ? metadata : Array(vectorArray.length).fill(metadata); - const repo = repository as AnyChunkVectorStorage; + const repo = dataset as DocumentChunkDataset; await context.updateProgress(1, "Upserting vectors"); diff --git a/packages/ai/src/task/HierarchicalChunkerTask.ts b/packages/ai/src/task/HierarchicalChunkerTask.ts index ed5b07ce..96ef41e3 100644 --- a/packages/ai/src/task/HierarchicalChunkerTask.ts +++ b/packages/ai/src/task/HierarchicalChunkerTask.ts @@ -9,7 +9,6 @@ import { estimateTokens, getChildren, hasChildren, - NodeIdGenerator, type ChunkNode, type DocumentNode, type TokenBudget, @@ -21,7 +20,7 @@ import { Task, Workflow, } from "@workglow/task-graph"; -import { DataPortSchema, FromSchema } from "@workglow/util"; +import { DataPortSchema, FromSchema, uuid4 } from "@workglow/util"; const inputSchema = { type: "object", @@ -207,7 +206,7 @@ export class HierarchicalChunkerTask extends Task< if (estimateTokens(text) <= tokenBudget.maxTokensPerChunk - tokenBudget.reservedTokens) { // Text fits in one chunk - const chunkId = await NodeIdGenerator.generateChunkId(doc_id, leafNodeId, 0); + const chunkId = uuid4(); chunks.push({ chunkId, doc_id, @@ -226,7 +225,7 @@ export class HierarchicalChunkerTask extends Task< const endOffset = Math.min(startOffset + maxChars, text.length); const chunkText = text.substring(startOffset, endOffset); - const chunkId = await NodeIdGenerator.generateChunkId(doc_id, leafNodeId, chunkOrdinal); + const chunkId = uuid4(); chunks.push({ chunkId, diff --git a/packages/ai/src/task/HierarchyJoinTask.ts b/packages/ai/src/task/HierarchyJoinTask.ts index 2a44638b..9d13f0a6 100644 --- a/packages/ai/src/task/HierarchyJoinTask.ts +++ b/packages/ai/src/task/HierarchyJoinTask.ts @@ -5,9 +5,9 @@ */ import { - type ChunkMetadata, ChunkMetadataArraySchema, EnrichedChunkMetadataArraySchema, + type ChunkMetadata, type DocumentRepository, } from "@workglow/dataset"; import { diff --git a/packages/ai/src/task/StructuralParserTask.ts b/packages/ai/src/task/StructuralParserTask.ts index cf80ccf5..48934885 100644 --- a/packages/ai/src/task/StructuralParserTask.ts +++ b/packages/ai/src/task/StructuralParserTask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { DocumentNode, NodeIdGenerator, StructuralParser } from "@workglow/dataset"; +import { DocumentNode, StructuralParser } from "@workglow/dataset"; import { CreateWorkflow, IExecuteContext, @@ -12,7 +12,7 @@ import { Task, Workflow, } from "@workglow/task-graph"; -import { DataPortSchema, FromSchema } from "@workglow/util"; +import { DataPortSchema, FromSchema, uuid4 } from "@workglow/util"; const inputSchema = { type: "object", @@ -104,8 +104,7 @@ export class StructuralParserTask extends Task< const { text, title, format = "auto", sourceUri, doc_id: providedDocId } = input; // Generate or use provided doc_id - const doc_id = - providedDocId || (await NodeIdGenerator.generateDocId(sourceUri || "document", text)); + const doc_id = providedDocId || uuid4(); // Parse based on format let 
documentTree: DocumentNode; diff --git a/packages/dataset/README.md b/packages/dataset/README.md index f86bfa63..e9ad476b 100644 --- a/packages/dataset/README.md +++ b/packages/dataset/README.md @@ -592,22 +592,22 @@ import { TypeDocumentRepository, } from "@workglow/storage"; -// Tabular repository (format: "repository:tabular") +// Tabular repository (format: "storage:tabular") const tabularSchema = TypeTabularRepository({ title: "Data Source", - description: "Tabular data repository", + description: "Tabular data storage", }); -// Vector repository (format: "repository:document-node-vector") +// Vector repository (format: "dataset:document-chunk") const vectorSchema = TypeVectorRepository({ title: "Embeddings Store", - description: "Vector embeddings repository", + description: "Vector embeddings dataset", }); -// Document repository (format: "repository:document") +// Document repository (format: "dataset:document") const docSchema = TypeDocumentRepository({ title: "Document Store", - description: "Document storage repository", + description: "Document storage dataset", }); ``` @@ -1081,11 +1081,7 @@ describe("UserRepository", () => { let userRepo: InMemoryTabularStorage; beforeEach(() => { - userRepo = new InMemoryTabularStorage( - UserSchema, - ["id"], - ["email"] - ); + userRepo = new InMemoryTabularStorage(UserSchema, ["id"], ["email"]); }); test("should create and retrieve user", async () => { diff --git a/packages/dataset/package.json b/packages/dataset/package.json index 9bd8d372..ce883eab 100644 --- a/packages/dataset/package.json +++ b/packages/dataset/package.json @@ -39,8 +39,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./dist/bun.js", - "types": "./dist/types.d.ts", + "bun": "./src/bun.ts", + "types": "./src/types.ts", "import": "./dist/node.js" } }, diff --git a/packages/dataset/src/chunk-vector/ChunkVectorStorageRegistry.ts b/packages/dataset/src/chunk-vector/ChunkVectorStorageRegistry.ts deleted file mode 100644 index 7c51b929..00000000 --- a/packages/dataset/src/chunk-vector/ChunkVectorStorageRegistry.ts +++ /dev/null @@ -1,83 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import { - createServiceToken, - globalServiceRegistry, - registerInputResolver, - ServiceRegistry, -} from "@workglow/util"; -import { AnyChunkVectorStorage } from "./IChunkVectorStorage"; - -/** - * Service token for the documenbt chunk vector repository registry - * Maps repository IDs to IVectorChunkRepository instances - */ -export const DOCUMENT_CHUNK_VECTOR_REPOSITORIES = createServiceToken< - Map ->("storage.document-node-vector.repositories"); - -// Register default factory if not already registered -if (!globalServiceRegistry.has(DOCUMENT_CHUNK_VECTOR_REPOSITORIES)) { - globalServiceRegistry.register( - DOCUMENT_CHUNK_VECTOR_REPOSITORIES, - (): Map => new Map(), - true - ); -} - -/** - * Gets the global document chunk vector repository registry - * @returns Map of document chunk vector repository ID to instance - */ -export function getGlobalChunkVectorRepositories(): Map { - return globalServiceRegistry.get(DOCUMENT_CHUNK_VECTOR_REPOSITORIES); -} - -/** - * Registers a vector repository globally by ID - * @param id The unique identifier for this repository - * @param repository The repository instance to register - */ -export function registerChunkVectorRepository( - id: string, - repository: AnyChunkVectorStorage -): void { - const repos = getGlobalChunkVectorRepositories(); - 
repos.set(id, repository); -} - -/** - * Gets a document chunk vector repository by ID from the global registry - * @param id The repository identifier - * @returns The repository instance or undefined if not found - */ -export function getChunkVectorRepository(id: string): AnyChunkVectorStorage | undefined { - return getGlobalChunkVectorRepositories().get(id); -} - -/** - * Resolves a repository ID to an IVectorChunkRepository from the registry. - * Used by the input resolver system. - */ -async function resolveChunkVectorRepositoryFromRegistry( - id: string, - format: string, - registry: ServiceRegistry -): Promise { - const repos = registry.has(DOCUMENT_CHUNK_VECTOR_REPOSITORIES) - ? registry.get>(DOCUMENT_CHUNK_VECTOR_REPOSITORIES) - : getGlobalChunkVectorRepositories(); - - const repo = repos.get(id); - if (!repo) { - throw new Error(`Document chunk vector repository "${id}" not found in registry`); - } - return repo; -} - -// Register the repository resolver for format: "repository:document-node-vector" -registerInputResolver("repository:document-node-vector", resolveChunkVectorRepositoryFromRegistry); diff --git a/packages/dataset/src/chunk-vector/IChunkVectorStorage.ts b/packages/dataset/src/chunk-vector/IChunkVectorStorage.ts deleted file mode 100644 index 02ffd364..00000000 --- a/packages/dataset/src/chunk-vector/IChunkVectorStorage.ts +++ /dev/null @@ -1,105 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import type { - DataPortSchemaObject, - EventParameters, - FromSchema, - TypedArray, - TypedArraySchemaOptions, -} from "@workglow/util"; -import type { ITabularStorage, TabularEventListeners } from "@workglow/storage"; - -export type AnyChunkVectorStorage = IChunkVectorStorage; - -/** - * Options for vector search operations - */ -export interface VectorSearchOptions> { - readonly topK?: number; - readonly filter?: Partial; - readonly scoreThreshold?: number; -} - -/** - * Options for hybrid search (vector + full-text) - */ -export interface HybridSearchOptions< - Metadata = Record, -> extends VectorSearchOptions { - readonly textQuery: string; - readonly vectorWeight?: number; -} - -/** - * Type definitions for document chunk vector repository events - */ -export interface VectorChunkEventListeners extends TabularEventListeners< - PrimaryKey, - Entity -> { - similaritySearch: (query: TypedArray, results: (Entity & { score: number })[]) => void; - hybridSearch: (query: TypedArray, results: (Entity & { score: number })[]) => void; -} - -export type VectorChunkEventName = keyof VectorChunkEventListeners; -export type VectorChunkEventListener< - Event extends VectorChunkEventName, - PrimaryKey, - Entity, -> = VectorChunkEventListeners[Event]; - -export type VectorChunkEventParameters< - Event extends VectorChunkEventName, - PrimaryKey, - Entity, -> = EventParameters, Event>; - -/** - * Interface defining the contract for document chunk vector storage repositories. - * These repositories store vector embeddings with metadata for decument chunks. - * Extends ITabularRepository to provide standard storage operations, - * plus vector-specific similarity search capabilities. - * Supports various vector types including quantized formats. 
- * - * @typeParam Schema - The schema definition for the entity using JSON Schema - * @typeParam PrimaryKeyNames - Array of property names that form the primary key - * @typeParam Entity - The entity type - */ -export interface IChunkVectorStorage< - Schema extends DataPortSchemaObject, - PrimaryKeyNames extends ReadonlyArray, - Entity = FromSchema, -> extends ITabularStorage { - /** - * Get the vector dimension - * @returns The vector dimension - */ - getVectorDimensions(): number; - - /** - * Search for similar vectors using similarity scoring - * @param query - Query vector to compare against - * @param options - Search options (topK, filter, scoreThreshold) - * @returns Array of search results sorted by similarity (highest first) - */ - similaritySearch( - query: TypedArray, - options?: VectorSearchOptions> - ): Promise<(Entity & { score: number })[]>; - - /** - * Hybrid search combining vector similarity with full-text search - * This is optional and may not be supported by all implementations - * @param query - Query vector to compare against - * @param options - Hybrid search options including text query - * @returns Array of search results sorted by combined relevance - */ - hybridSearch?( - query: TypedArray, - options: HybridSearchOptions> - ): Promise<(Entity & { score: number })[]>; -} diff --git a/packages/dataset/src/chunk-vector/InMemoryChunkVectorStorage.ts b/packages/dataset/src/chunk-vector/InMemoryChunkVectorStorage.ts deleted file mode 100644 index 46afb138..00000000 --- a/packages/dataset/src/chunk-vector/InMemoryChunkVectorStorage.ts +++ /dev/null @@ -1,185 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import type { TypedArray } from "@workglow/util"; -import { cosineSimilarity } from "@workglow/util"; -import { InMemoryTabularStorage } from "@workglow/storage"; -import { ChunkVector, ChunkVectorKey, ChunkVectorSchema } from "./ChunkVectorSchema"; -import type { - HybridSearchOptions, - IChunkVectorStorage, - VectorSearchOptions, -} from "./IChunkVectorStorage"; - -/** - * Check if metadata matches filter - */ -function matchesFilter(metadata: Metadata, filter: Partial): boolean { - for (const [key, value] of Object.entries(filter)) { - if (metadata[key as keyof Metadata] !== value) { - return false; - } - } - return true; -} - -/** - * Simple full-text search scoring (keyword matching) - */ -function textRelevance(text: string, query: string): number { - const textLower = text.toLowerCase(); - const queryLower = query.toLowerCase(); - const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); - if (queryWords.length === 0) { - return 0; - } - let matches = 0; - for (const word of queryWords) { - if (textLower.includes(word)) { - matches++; - } - } - return matches / queryWords.length; -} - -/** - * In-memory document chunk vector repository implementation. - * Extends InMemoryTabularRepository for storage. - * Suitable for testing and small-scale browser applications. - * Supports all vector types including quantized formats. 
- * - * @template Metadata - The metadata type for the document chunk - * @template Vector - The vector type for the document chunk - */ -export class InMemoryChunkVectorStorage< - Metadata extends Record = Record, - Vector extends TypedArray = Float32Array, -> - extends InMemoryTabularStorage< - typeof ChunkVectorSchema, - typeof ChunkVectorKey, - ChunkVector - > - implements - IChunkVectorStorage< - typeof ChunkVectorSchema, - typeof ChunkVectorKey, - ChunkVector - > -{ - private vectorDimensions: number; - private VectorType: new (array: number[]) => TypedArray; - - /** - * Creates a new in-memory document chunk vector repository - * @param dimensions - The number of dimensions of the vector - * @param VectorType - The type of vector to use (defaults to Float32Array) - */ - constructor(dimensions: number, VectorType: new (array: number[]) => TypedArray = Float32Array) { - super(ChunkVectorSchema, ChunkVectorKey); - - this.vectorDimensions = dimensions; - this.VectorType = VectorType; - } - - /** - * Get the vector dimensions - * @returns The vector dimensions - */ - getVectorDimensions(): number { - return this.vectorDimensions; - } - - async similaritySearch( - query: TypedArray, - options: VectorSearchOptions> = {} - ) { - const { topK = 10, filter, scoreThreshold = 0 } = options; - const results: Array & { score: number }> = []; - - const allEntities = (await this.getAll()) || []; - - for (const entity of allEntities) { - const vector = entity.vector; - const metadata = entity.metadata; - - // Apply filter if provided - if (filter && !matchesFilter(metadata, filter)) { - continue; - } - - // Calculate similarity - const score = cosineSimilarity(query, vector); - - // Apply threshold - if (score < scoreThreshold) { - continue; - } - - results.push({ - ...entity, - vector, - score, - }); - } - - // Sort by score descending and take top K - results.sort((a, b) => b.score - a.score); - const topResults = results.slice(0, topK); - - return topResults; - } - - async hybridSearch(query: TypedArray, options: HybridSearchOptions>) { - const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; - - if (!textQuery || textQuery.trim().length === 0) { - // Fall back to regular vector search if no text query - return this.similaritySearch(query, { topK, filter, scoreThreshold }); - } - - const results: Array & { score: number }> = []; - const allEntities = (await this.getAll()) || []; - - for (const entity of allEntities) { - // In memory, vectors are stored as TypedArrays directly (not serialized) - const vector = entity.vector; - const metadata = entity.metadata; - - // Apply filter if provided - if (filter && !matchesFilter(metadata, filter)) { - continue; - } - - // Calculate vector similarity - const vectorScore = cosineSimilarity(query, vector); - - // Calculate text relevance (simple keyword matching) - const metadataText = Object.values(metadata).join(" ").toLowerCase(); - const textScore = textRelevance(metadataText, textQuery); - - // Combine scores - const combinedScore = vectorWeight * vectorScore + (1 - vectorWeight) * textScore; - - // Apply threshold - if (combinedScore < scoreThreshold) { - continue; - } - - results.push({ - ...entity, - vector, - score: combinedScore, - }); - } - - // Sort by combined score descending and take top K - results.sort((a, b) => b.score - a.score); - const topResults = results.slice(0, topK); - - return topResults; - } -} diff --git a/packages/dataset/src/chunk-vector/PostgresChunkVectorStorage.ts 
b/packages/dataset/src/chunk-vector/PostgresChunkVectorStorage.ts deleted file mode 100644 index 5f97b5ae..00000000 --- a/packages/dataset/src/chunk-vector/PostgresChunkVectorStorage.ts +++ /dev/null @@ -1,293 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import { cosineSimilarity, type TypedArray } from "@workglow/util"; -import type { Pool } from "pg"; -import { PostgresTabularStorage } from "@workglow/storage"; -import { ChunkVector, ChunkVectorKey, ChunkVectorSchema } from "./ChunkVectorSchema"; -import type { - HybridSearchOptions, - IChunkVectorStorage, - VectorSearchOptions, -} from "./IChunkVectorStorage"; - -/** - * PostgreSQL document chunk vector repository implementation using pgvector extension. - * Extends PostgresTabularRepository for storage. - * Provides efficient vector similarity search with native database support. - * - * Requirements: - * - PostgreSQL database with pgvector extension installed - * - CREATE EXTENSION vector; - * - * @template Metadata - The metadata type for the document chunk - * @template Vector - The vector type for the document chunk - */ -export class PostgresChunkVectorStorage< - Metadata extends Record = Record, - Vector extends TypedArray = Float32Array, -> - extends PostgresTabularStorage< - typeof ChunkVectorSchema, - typeof ChunkVectorKey, - ChunkVector - > - implements - IChunkVectorStorage< - typeof ChunkVectorSchema, - typeof ChunkVectorKey, - ChunkVector - > -{ - private vectorDimensions: number; - private VectorType: new (array: number[]) => TypedArray; - /** - * Creates a new PostgreSQL document chunk vector repository - * @param db - PostgreSQL connection pool - * @param table - The name of the table to use for storage - * @param dimensions - The number of dimensions of the vector - * @param VectorType - The type of vector to use (defaults to Float32Array) - */ - constructor( - db: Pool, - table: string, - dimensions: number, - VectorType: new (array: number[]) => TypedArray = Float32Array - ) { - super(db, table, ChunkVectorSchema, ChunkVectorKey); - - this.vectorDimensions = dimensions; - this.VectorType = VectorType; - } - - getVectorDimensions(): number { - return this.vectorDimensions; - } - - async similaritySearch( - query: TypedArray, - options: VectorSearchOptions = {} - ): Promise & { score: number }>> { - const { topK = 10, filter, scoreThreshold = 0 } = options; - - try { - // Try native pgvector search first - const queryVector = `[${Array.from(query).join(",")}]`; - let sql = ` - SELECT - *, - 1 - (vector <=> $1::vector) as score - FROM "${this.table}" - `; - - const params: any[] = [queryVector]; - let paramIndex = 2; - - if (filter && Object.keys(filter).length > 0) { - const conditions: string[] = []; - for (const [key, value] of Object.entries(filter)) { - conditions.push(`metadata->>'${key}' = $${paramIndex}`); - params.push(String(value)); - paramIndex++; - } - sql += ` WHERE ${conditions.join(" AND ")}`; - } - - if (scoreThreshold > 0) { - sql += filter ? 
" AND" : " WHERE"; - sql += ` (1 - (vector <=> $1::vector)) >= $${paramIndex}`; - params.push(scoreThreshold); - paramIndex++; - } - - sql += ` ORDER BY vector <=> $1::vector LIMIT $${paramIndex}`; - params.push(topK); - - const result = await this.db.query(sql, params); - - // Fetch vectors separately for each result - const results: Array & { score: number }> = []; - for (const row of result.rows) { - const vectorResult = await this.db.query( - `SELECT vector::text FROM "${this.table}" WHERE id = $1`, - [row.id] - ); - const vectorStr = vectorResult.rows[0]?.vector || "[]"; - const vectorArray = JSON.parse(vectorStr); - - results.push({ - ...row, - vector: new this.VectorType(vectorArray), - score: parseFloat(row.score), - } as any); - } - - return results; - } catch (error) { - // Fall back to in-memory similarity calculation if pgvector is not available - console.warn("pgvector query failed, falling back to in-memory search:", error); - return this.searchFallback(query, options); - } - } - - async hybridSearch(query: TypedArray, options: HybridSearchOptions) { - const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; - - if (!textQuery || textQuery.trim().length === 0) { - return this.similaritySearch(query, { topK, filter, scoreThreshold }); - } - - try { - // Try native hybrid search with pgvector + full-text - const queryVector = `[${Array.from(query).join(",")}]`; - const tsQuery = textQuery.split(/\s+/).join(" & "); - - let sql = ` - SELECT - *, - ( - $2 * (1 - (vector <=> $1::vector)) + - $3 * ts_rank(to_tsvector('english', metadata::text), to_tsquery('english', $4)) - ) as score - FROM "${this.table}" - `; - - const params: any[] = [queryVector, vectorWeight, 1 - vectorWeight, tsQuery]; - let paramIndex = 5; - - if (filter && Object.keys(filter).length > 0) { - const conditions: string[] = []; - for (const [key, value] of Object.entries(filter)) { - conditions.push(`metadata->>'${key}' = $${paramIndex}`); - params.push(String(value)); - paramIndex++; - } - sql += ` WHERE ${conditions.join(" AND ")}`; - } - - if (scoreThreshold > 0) { - sql += filter ? 
" AND" : " WHERE"; - sql += ` ( - $2 * (1 - (vector <=> $1::vector)) + - $3 * ts_rank(to_tsvector('english', metadata::text), to_tsquery('english', $4)) - ) >= $${paramIndex}`; - params.push(scoreThreshold); - paramIndex++; - } - - sql += ` ORDER BY score DESC LIMIT $${paramIndex}`; - params.push(topK); - - const result = await this.db.query(sql, params); - - // Fetch vectors separately for each result - const results: Array & { score: number }> = []; - for (const row of result.rows) { - const vectorResult = await this.db.query( - `SELECT vector::text FROM "${this.table}" WHERE id = $1`, - [row.id] - ); - const vectorStr = vectorResult.rows[0]?.vector || "[]"; - const vectorArray = JSON.parse(vectorStr); - - results.push({ - ...row, - vector: new this.VectorType(vectorArray), - score: parseFloat(row.score), - } as any); - } - - return results; - } catch (error) { - // Fall back to in-memory hybrid search - console.warn("pgvector hybrid query failed, falling back to in-memory search:", error); - return this.hybridSearchFallback(query, options); - } - } - - /** - * Fallback search using in-memory cosine similarity - */ - private async searchFallback(query: TypedArray, options: VectorSearchOptions) { - const { topK = 10, filter, scoreThreshold = 0 } = options; - const allRows = (await this.getAll()) || []; - const results: Array & { score: number }> = []; - - for (const row of allRows) { - const vector = row.vector; - const metadata = row.metadata; - - if (filter && !this.matchesFilter(metadata, filter)) { - continue; - } - - const score = cosineSimilarity(query, vector); - - if (score >= scoreThreshold) { - results.push({ ...row, vector, score }); - } - } - - results.sort((a, b) => b.score - a.score); - const topResults = results.slice(0, topK); - - return topResults; - } - - /** - * Fallback hybrid search - */ - private async hybridSearchFallback(query: TypedArray, options: HybridSearchOptions) { - const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; - - const allRows = (await this.getAll()) || []; - const results: Array & { score: number }> = []; - const queryLower = textQuery.toLowerCase(); - const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); - - for (const row of allRows) { - const vector = row.vector; - const metadata = row.metadata; - - if (filter && !this.matchesFilter(metadata, filter)) { - continue; - } - - const vectorScore = cosineSimilarity(query, vector); - const metadataText = JSON.stringify(metadata).toLowerCase(); - let textScore = 0; - if (queryWords.length > 0) { - let matches = 0; - for (const word of queryWords) { - if (metadataText.includes(word)) { - matches++; - } - } - textScore = matches / queryWords.length; - } - - const combinedScore = vectorWeight * vectorScore + (1 - vectorWeight) * textScore; - - if (combinedScore >= scoreThreshold) { - results.push({ ...row, vector, score: combinedScore }); - } - } - - results.sort((a, b) => b.score - a.score); - const topResults = results.slice(0, topK); - - return topResults; - } - - private matchesFilter(metadata: Metadata, filter: Partial): boolean { - for (const [key, value] of Object.entries(filter)) { - if (metadata[key as keyof Metadata] !== value) { - return false; - } - } - return true; - } -} diff --git a/packages/dataset/src/chunk-vector/SqliteChunkVectorStorage.ts b/packages/dataset/src/chunk-vector/SqliteChunkVectorStorage.ts deleted file mode 100644 index a4bcbb15..00000000 --- a/packages/dataset/src/chunk-vector/SqliteChunkVectorStorage.ts +++ /dev/null 
@@ -1,192 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import { Sqlite } from "@workglow/sqlite"; -import type { TypedArray } from "@workglow/util"; -import { cosineSimilarity } from "@workglow/util"; -import { SqliteTabularStorage } from "@workglow/storage"; -import { ChunkVector, ChunkVectorKey, ChunkVectorSchema } from "./ChunkVectorSchema"; -import type { - HybridSearchOptions, - IChunkVectorStorage, - VectorSearchOptions, -} from "./IChunkVectorStorage"; - -/** - * Check if metadata matches filter - */ -function matchesFilter(metadata: Metadata, filter: Partial): boolean { - for (const [key, value] of Object.entries(filter)) { - if (metadata[key as keyof Metadata] !== value) { - return false; - } - } - return true; -} - -/** - * SQLite document chunk vector repository implementation using tabular storage underneath. - * Stores vectors as JSON-encoded arrays with metadata. - * - * @template Metadata - The metadata type for the document chunk - * @template Vector - The vector type for the document chunk - */ -export class SqliteChunkVectorStorage< - Metadata extends Record = Record, - Vector extends TypedArray = Float32Array, -> - extends SqliteTabularStorage< - typeof ChunkVectorSchema, - typeof ChunkVectorKey, - ChunkVector - > - implements - IChunkVectorStorage< - typeof ChunkVectorSchema, - typeof ChunkVectorKey, - ChunkVector - > -{ - private vectorDimensions: number; - private VectorType: new (array: number[]) => TypedArray; - - /** - * Creates a new SQLite document chunk vector repository - * @param dbOrPath - Either a Database instance or a path to the SQLite database file - * @param table - The name of the table to use for storage (defaults to 'vectors') - * @param dimensions - The number of dimensions of the vector - * @param VectorType - The type of vector to use (defaults to Float32Array) - */ - constructor( - dbOrPath: string | Sqlite.Database, - table: string = "vectors", - dimensions: number, - VectorType: new (array: number[]) => TypedArray = Float32Array - ) { - super(dbOrPath, table, ChunkVectorSchema, ChunkVectorKey); - - this.vectorDimensions = dimensions; - this.VectorType = VectorType; - } - - getVectorDimensions(): number { - return this.vectorDimensions; - } - - /** - * Deserialize vector from JSON string - * Defaults to Float32Array for compatibility with typical embedding vectors - */ - private deserializeVector(vectorJson: string): TypedArray { - const array = JSON.parse(vectorJson); - // Default to Float32Array for typical use case (embeddings) - return new this.VectorType(array); - } - - async similaritySearch(query: TypedArray, options: VectorSearchOptions = {}) { - const { topK = 10, filter, scoreThreshold = 0 } = options; - const results: Array & { score: number }> = []; - - const allEntities = (await this.getAll()) || []; - - for (const entity of allEntities) { - // SQLite stores vectors as JSON strings, need to deserialize - const vectorRaw = entity.vector as unknown as string; - const vector = this.deserializeVector(vectorRaw); - const metadata = entity.metadata; - - // Apply filter if provided - if (filter && !matchesFilter(metadata, filter)) { - continue; - } - - // Calculate similarity - const score = cosineSimilarity(query, vector); - - // Apply threshold - if (score < scoreThreshold) { - continue; - } - - results.push({ - ...entity, - vector, - score, - } as any); - } - - // Sort by score descending and take top K - results.sort((a, b) => b.score - a.score); - const topResults = 
results.slice(0, topK); - - return topResults; - } - - async hybridSearch(query: TypedArray, options: HybridSearchOptions) { - const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; - - if (!textQuery || textQuery.trim().length === 0) { - // Fall back to regular vector search if no text query - return this.similaritySearch(query, { topK, filter, scoreThreshold }); - } - - const results: Array & { score: number }> = []; - const allEntities = (await this.getAll()) || []; - const queryLower = textQuery.toLowerCase(); - const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); - - for (const entity of allEntities) { - // SQLite stores vectors as JSON strings, need to deserialize - const vectorRaw = entity.vector as unknown as string; - const vector = - typeof vectorRaw === "string" - ? this.deserializeVector(vectorRaw) - : (vectorRaw as TypedArray); - const metadata = entity.metadata; - - // Apply filter if provided - if (filter && !matchesFilter(metadata, filter)) { - continue; - } - - // Calculate vector similarity - const vectorScore = cosineSimilarity(query, vector); - - // Calculate text relevance (simple keyword matching) - const metadataText = JSON.stringify(metadata).toLowerCase(); - let textScore = 0; - if (queryWords.length > 0) { - let matches = 0; - for (const word of queryWords) { - if (metadataText.includes(word)) { - matches++; - } - } - textScore = matches / queryWords.length; - } - - // Combine scores - const combinedScore = vectorWeight * vectorScore + (1 - vectorWeight) * textScore; - - // Apply threshold - if (combinedScore < scoreThreshold) { - continue; - } - - results.push({ - ...entity, - vector, - score: combinedScore, - } as any); - } - - // Sort by combined score descending and take top K - results.sort((a, b) => b.score - a.score); - const topResults = results.slice(0, topK); - - return topResults; - } -} diff --git a/packages/dataset/src/common.ts b/packages/dataset/src/common.ts index 2cd1e8a7..a07aafe3 100644 --- a/packages/dataset/src/common.ts +++ b/packages/dataset/src/common.ts @@ -4,17 +4,17 @@ * SPDX-License-Identifier: Apache-2.0 */ -export * from "./util/RepositorySchema"; +export * from "./util/DatasetSchema"; export * from "./document/Document"; +export * from "./document/DocumentDataset"; +export * from "./document/DocumentDatasetRegistry"; export * from "./document/DocumentNode"; export * from "./document/DocumentRepository"; -export * from "./document/DocumentRepositoryRegistry"; export * from "./document/DocumentSchema"; export * from "./document/DocumentStorageSchema"; export * from "./document/StructuralParser"; -export * from "./chunk-vector/ChunkVectorSchema"; -export * from "./chunk-vector/ChunkVectorStorageRegistry"; -export * from "./chunk-vector/IChunkVectorStorage"; -export * from "./chunk-vector/InMemoryChunkVectorStorage"; +export * from "./document-chunk/DocumentChunkDataset"; +export * from "./document-chunk/DocumentChunkDatasetRegistry"; +export * from "./document-chunk/DocumentChunkSchema"; diff --git a/packages/dataset/src/document-chunk/DocumentChunkDataset.ts b/packages/dataset/src/document-chunk/DocumentChunkDataset.ts new file mode 100644 index 00000000..c99da309 --- /dev/null +++ b/packages/dataset/src/document-chunk/DocumentChunkDataset.ts @@ -0,0 +1,127 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { VectorSearchOptions } from "@workglow/storage"; +import type { TypedArray } from "@workglow/util"; +import type 
{ DocumentChunk, DocumentChunkStorage } from "./DocumentChunkSchema"; + +/** + * Document Chunk Dataset + * + * A dataset-specific wrapper around vector storage for document chunks. + * This provides a domain-specific API for working with document chunk embeddings + * in RAG pipelines. + */ +export class DocumentChunkDataset { + private storage: DocumentChunkStorage; + + constructor(storage: DocumentChunkStorage) { + this.storage = storage; + } + + /** + * Get the underlying storage instance + */ + getStorage(): DocumentChunkStorage { + return this.storage; + } + + /** + * Store a document chunk + */ + async put(chunk: DocumentChunk): Promise { + return this.storage.put(chunk); + } + + /** + * Store multiple document chunks + */ + async putBulk(chunks: DocumentChunk[]): Promise { + return this.storage.putBulk(chunks); + } + + /** + * Get a document chunk by ID + */ + async get(chunk_id: string): Promise { + return this.storage.get({ chunk_id } as any); + } + + /** + * Delete a document chunk + */ + async delete(chunk_id: string): Promise { + return this.storage.delete({ chunk_id } as any); + } + + /** + * Search for similar chunks using vector similarity + */ + async similaritySearch( + query: TypedArray, + options?: VectorSearchOptions> + ): Promise> { + return this.storage.similaritySearch(query, options); + } + + /** + * Hybrid search (vector + full-text) + */ + async hybridSearch( + query: TypedArray, + options: VectorSearchOptions> & { + textQuery: string; + vectorWeight?: number; + } + ): Promise> { + if (this.storage.hybridSearch) { + return this.storage.hybridSearch(query, options); + } + throw new Error("Hybrid search not supported by this storage backend"); + } + + /** + * Get all chunks + */ + async getAll(): Promise { + return this.storage.getAll(); + } + + /** + * Get the count of stored chunks + */ + async size(): Promise { + return this.storage.size(); + } + + /** + * Clear all chunks + */ + async clear(): Promise { + return (this.storage as any).clear(); + } + + /** + * Destroy the storage + */ + destroy(): void { + return this.storage.destroy(); + } + + /** + * Setup the database/storage + */ + async setupDatabase(): Promise { + return this.storage.setupDatabase(); + } + + /** + * Get the vector dimensions + */ + getVectorDimensions(): number { + return this.storage.getVectorDimensions(); + } +} diff --git a/packages/dataset/src/document-chunk/DocumentChunkDatasetRegistry.ts b/packages/dataset/src/document-chunk/DocumentChunkDatasetRegistry.ts new file mode 100644 index 00000000..76b1b90d --- /dev/null +++ b/packages/dataset/src/document-chunk/DocumentChunkDatasetRegistry.ts @@ -0,0 +1,79 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + createServiceToken, + globalServiceRegistry, + registerInputResolver, + ServiceRegistry, +} from "@workglow/util"; +import type { DocumentChunkDataset } from "./DocumentChunkDataset"; + +/** + * Service token for the document chunk dataset registry + * Maps dataset IDs to DocumentChunkDataset instances + */ +export const DOCUMENT_CHUNK_DATASET = + createServiceToken>("dataset.document-chunk"); + +// Register default factory if not already registered +if (!globalServiceRegistry.has(DOCUMENT_CHUNK_DATASET)) { + globalServiceRegistry.register( + DOCUMENT_CHUNK_DATASET, + (): Map => new Map(), + true + ); +} + +/** + * Gets the global document chunk dataset registry + * @returns Map of document chunk dataset ID to instance + */ +export function getGlobalDocumentChunkDataset(): Map 
{ + return globalServiceRegistry.get(DOCUMENT_CHUNK_DATASET); +} + +/** + * Registers a document chunk dataset globally by ID + * @param id The unique identifier for this dataset + * @param dataset The dataset instance to register + */ +export function registerDocumentChunkDataset(id: string, dataset: DocumentChunkDataset): void { + const datasets = getGlobalDocumentChunkDataset(); + datasets.set(id, dataset); +} + +/** + * Gets a document chunk dataset by ID from the global registry + * @param id The dataset identifier + * @returns The dataset instance or undefined if not found + */ +export function getDocumentChunkDataset(id: string): DocumentChunkDataset | undefined { + return getGlobalDocumentChunkDataset().get(id); +} + +/** + * Resolves a dataset ID to a DocumentChunkDataset from the registry. + * Used by the input resolver system. + */ +async function resolveDocumentChunkDatasetFromRegistry( + id: string, + format: string, + registry: ServiceRegistry +): Promise { + const datasets = registry.has(DOCUMENT_CHUNK_DATASET) + ? registry.get>(DOCUMENT_CHUNK_DATASET) + : getGlobalDocumentChunkDataset(); + + const dataset = datasets.get(id); + if (!dataset) { + throw new Error(`Document chunk dataset "${id}" not found in registry`); + } + return dataset; +} + +// Register the dataset resolver for format: "dataset:document-chunk" +registerInputResolver("dataset:document-chunk", resolveDocumentChunkDatasetFromRegistry); diff --git a/packages/dataset/src/chunk-vector/ChunkVectorSchema.ts b/packages/dataset/src/document-chunk/DocumentChunkSchema.ts similarity index 53% rename from packages/dataset/src/chunk-vector/ChunkVectorSchema.ts rename to packages/dataset/src/document-chunk/DocumentChunkSchema.ts index 8ce95438..c490f5c5 100644 --- a/packages/dataset/src/chunk-vector/ChunkVectorSchema.ts +++ b/packages/dataset/src/document-chunk/DocumentChunkSchema.ts @@ -4,32 +4,39 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { IVectorStorage } from "@workglow/storage"; import { TypedArraySchema, type DataPortSchemaObject, type TypedArray } from "@workglow/util"; /** * Default schema for document chunk storage with vector embeddings */ -export const ChunkVectorSchema = { +export const DocumentChunkSchema = { type: "object", properties: { chunk_id: { type: "string" }, doc_id: { type: "string" }, vector: TypedArraySchema(), - metadata: { type: "object", additionalProperties: true }, + metadata: { type: "object", format: "metadata", additionalProperties: true }, }, additionalProperties: false, } as const satisfies DataPortSchemaObject; -export type ChunkVectorSchema = typeof ChunkVectorSchema; +export type DocumentChunkSchema = typeof DocumentChunkSchema; -export const ChunkVectorKey = ["chunk_id"] as const; -export type ChunkVectorKey = typeof ChunkVectorKey; +export const DocumentChunkPrimaryKey = ["chunk_id"] as const; +export type DocumentChunkPrimaryKey = typeof DocumentChunkPrimaryKey; -export interface ChunkVector< +export interface DocumentChunk< Metadata extends Record = Record, - Vector extends TypedArray = Float32Array, + Vector extends TypedArray = TypedArray, > { chunk_id: string; doc_id: string; vector: Vector; metadata: Metadata; } + +export type DocumentChunkStorage = IVectorStorage< + Record, + typeof DocumentChunkSchema, + DocumentChunk +>; diff --git a/packages/dataset/src/chunk-vector/README.md b/packages/dataset/src/document-chunk/README.md similarity index 65% rename from packages/dataset/src/chunk-vector/README.md rename to packages/dataset/src/document-chunk/README.md index 
f64c8ca0..b2d2ef04 100644 --- a/packages/dataset/src/chunk-vector/README.md +++ b/packages/dataset/src/document-chunk/README.md @@ -1,52 +1,38 @@ -# Chunk Vector Storage Module +# Document Chunk Dataset -Storage for document chunk embeddings with vector similarity search capabilities. Extends the tabular repository pattern to add vector search functionality for RAG (Retrieval-Augmented Generation) pipelines. +Document-specific schema and utilities for storing document chunk embeddings. Uses the general-purpose vector storage from `@workglow/storage` with a predefined schema for document chunks in RAG (Retrieval-Augmented Generation) pipelines. ## Features -- **Multiple Storage Backends:** - - 🧠 `InMemoryChunkVectorStorage` - Fast in-memory storage for testing and small datasets - - 📁 `SqliteChunkVectorStorage` - Persistent SQLite storage for local applications - - 🐘 `PostgresChunkVectorStorage` - PostgreSQL with pgvector extension for production - -- **Quantized Vector Support:** - - Float32Array (standard 32-bit floating point) - - Float16Array (16-bit floating point) - - Float64Array (64-bit high precision) - - Int8Array (8-bit signed - binary quantization) - - Uint8Array (8-bit unsigned - quantization) - - Int16Array (16-bit signed - quantization) - - Uint16Array (16-bit unsigned - quantization) - -- **Search Capabilities:** - - Vector similarity search (cosine similarity) - - Hybrid search (vector + full-text keyword matching) - - Metadata filtering - - Top-K retrieval with score thresholds - -- **Built on Tabular Repositories:** - - Extends `ITabularStorage` for standard CRUD operations - - Inherits event emitter pattern for monitoring - - Type-safe schema-based storage +- **Predefined Schema**: `DocumentChunkSchema` with fields for chunk_id, doc_id, vector, and metadata +- **Registry Pattern**: Register and retrieve chunk storage instances globally +- **Type Safety**: Full TypeScript type definitions for document chunks +- **Storage Agnostic**: Works with any vector storage backend (InMemory, SQLite, PostgreSQL) ## Installation ```bash -bun install @workglow/storage +bun install @workglow/dataset @workglow/storage ``` ## Usage -### In-Memory Repository (Testing/Development) +### Basic Usage with InMemoryVectorStorage ```typescript -import { InMemoryChunkVectorStorage } from "@workglow/storage"; - -// Create repository with 384 dimensions -const repo = new InMemoryChunkVectorStorage(384); +import { DocumentChunkSchema, DocumentChunkPrimaryKey } from "@workglow/dataset"; +import { InMemoryVectorStorage } from "@workglow/storage"; + +// Create storage using the DocumentChunkSchema +const repo = new InMemoryVectorStorage( + DocumentChunkSchema, + DocumentChunkPrimaryKey, + [], // indexes (optional) + 384 // vector dimensions +); await repo.setupDatabase(); -// Store a chunk with its embedding +// Store a document chunk with its embedding await repo.put({ chunk_id: "chunk-001", doc_id: "doc-001", @@ -64,10 +50,17 @@ const results = await repo.similaritySearch(new Float32Array([0.15, 0.25, 0.35 / ### Quantized Vectors (Reduced Storage) ```typescript -import { InMemoryChunkVectorStorage } from "@workglow/storage"; +import { DocumentChunkSchema, DocumentChunkPrimaryKey } from "@workglow/dataset"; +import { InMemoryVectorStorage } from "@workglow/storage"; // Use Int8Array for 4x smaller storage (binary quantization) -const repo = new InMemoryChunkVectorStorage<{ text: string }, Int8Array>(384, Int8Array); +const repo = new InMemoryVectorStorage( + DocumentChunkSchema, + 
DocumentChunkPrimaryKey, + [], + 384, + Int8Array // Specify vector type +); await repo.setupDatabase(); // Store quantized vectors @@ -82,15 +75,19 @@ await repo.put({ const results = await repo.similaritySearch(new Int8Array([100, -50, 75 /* ... */]), { topK: 5 }); ``` -### SQLite Repository (Local Persistence) +### SQLite Storage (Local Persistence) ```typescript -import { SqliteChunkVectorStorage } from "@workglow/storage"; - -const repo = new SqliteChunkVectorStorage<{ text: string }>( - "./vectors.db", // database path - "chunks", // table name - 768 // vector dimension +import { DocumentChunkSchema, DocumentChunkPrimaryKey } from "@workglow/dataset"; +import { SqliteVectorStorage } from "@workglow/storage"; + +const repo = new SqliteVectorStorage( + "./vectors.db", // database path + "chunks", // table name + DocumentChunkSchema, + DocumentChunkPrimaryKey, + [], // indexes + 768 // vector dimension ); await repo.setupDatabase(); @@ -105,12 +102,16 @@ await repo.putMany([ ```typescript import { Pool } from "pg"; -import { PostgresChunkVectorStorage } from "@workglow/storage"; +import { DocumentChunkSchema, DocumentChunkPrimaryKey } from "@workglow/dataset"; +import { PostgresVectorStorage } from "@workglow/storage"; const pool = new Pool({ connectionString: "postgresql://..." }); -const repo = new PostgresChunkVectorStorage<{ text: string; category: string }>( +const repo = new PostgresVectorStorage( pool, "chunks", + DocumentChunkSchema, + DocumentChunkPrimaryKey, + [], 384 // vector dimension ); await repo.setupDatabase(); @@ -131,39 +132,47 @@ const hybridResults = await repo.hybridSearch(queryVector, { }); ``` -## Data Model +## Schema Definition -### ChunkVector Schema +### DocumentChunkSchema -Each chunk vector entry contains: +The predefined schema for document chunks: ```typescript -interface ChunkVector< - Metadata extends Record = Record, - Vector extends TypedArray = Float32Array, -> { - chunk_id: string; // Unique identifier for the chunk - doc_id: string; // Parent document identifier - vector: Vector; // Embedding vector - metadata: Metadata; // Custom metadata (text content, entities, etc.) -} -``` - -### Default Schema +import { TypedArraySchema } from "@workglow/util"; -```typescript -const ChunkVectorSchema = { +export const DocumentChunkSchema = { type: "object", properties: { chunk_id: { type: "string" }, doc_id: { type: "string" }, - vector: TypedArraySchema(), - metadata: { type: "object", additionalProperties: true }, + vector: TypedArraySchema(), // Automatically detected as vector column + metadata: { + type: "object", + format: "metadata", // Marked for filtering support + additionalProperties: true, + }, }, additionalProperties: false, } as const; -const ChunkVectorKey = ["chunk_id"] as const; +export const DocumentChunkPrimaryKey = ["chunk_id"] as const; +``` + +### DocumentChunk Type + +TypeScript interface for document chunks: + +```typescript +interface DocumentChunk< + Metadata extends Record = Record, + Vector extends TypedArray = Float32Array, +> { + chunk_id: string; // Unique identifier for the chunk + doc_id: string; // Parent document identifier + vector: Vector; // Embedding vector + metadata: Metadata; // Custom metadata (text content, entities, etc.) 
+}
 ```

 ## API Reference

@@ -234,23 +243,29 @@ interface HybridSearchOptions extends VectorSearchOptions {

 ## Global Registry

-Register and retrieve chunk vector repositories globally:
+Register and retrieve chunk vector storage instances globally:

 ```typescript
 import {
+  DocumentChunkSchema,
+  DocumentChunkPrimaryKey,
+  DocumentChunkDataset,
-  registerChunkVectorRepository,
+  registerDocumentChunkDataset,
-  getChunkVectorRepository,
-  getGlobalChunkVectorRepositories,
-} from "@workglow/storage";
+  getDocumentChunkDataset,
+  getGlobalDocumentChunkDataset,
+} from "@workglow/dataset";
+import { InMemoryVectorStorage } from "@workglow/storage";
+
+// Create and register a storage instance
+const repo = new InMemoryVectorStorage(DocumentChunkSchema, DocumentChunkPrimaryKey, [], 384);
+await repo.setupDatabase();

-// Register a repository
-registerChunkVectorRepository("my-chunks", repo);
+registerDocumentChunkDataset("my-chunks", new DocumentChunkDataset(repo));

 // Retrieve by ID
-const repo = getChunkVectorRepository("my-chunks");
+const retrievedRepo = getDocumentChunkDataset("my-chunks");

-// Get all registered repositories
-const allRepos = getGlobalChunkVectorRepositories();
+// Get all registered storage instances
+const allRepos = getGlobalDocumentChunkDataset();
 ```

 ## Quantization Benefits

@@ -294,21 +309,27 @@ Quantized vectors reduce storage and can improve performance:

 ## Integration with DocumentRepository

-The chunk vector repository works alongside `DocumentRepository` for hierarchical document storage:
+Document chunk storage works alongside `DocumentRepository` for hierarchical document management:

 ```typescript
 import {
   DocumentRepository,
-  InMemoryChunkVectorStorage,
-  InMemoryTabularStorage,
-} from "@workglow/storage";
-import { DocumentStorageSchema } from "@workglow/storage";
+  DocumentStorageSchema,
+  DocumentChunkSchema,
+  DocumentChunkPrimaryKey,
+} from "@workglow/dataset";
+import { InMemoryTabularStorage, InMemoryVectorStorage } from "@workglow/storage";

 // Initialize storage backends
 const tabularStorage = new InMemoryTabularStorage(DocumentStorageSchema, ["doc_id"]);
 await tabularStorage.setupDatabase();

-const vectorStorage = new InMemoryChunkVectorStorage(384);
+const vectorStorage = new InMemoryVectorStorage(
+  DocumentChunkSchema,
+  DocumentChunkPrimaryKey,
+  [],
+  384
+);
 await vectorStorage.setupDatabase();

 // Create document repository with both storages
diff --git a/packages/dataset/src/document/DocumentDataset.ts b/packages/dataset/src/document/DocumentDataset.ts
new file mode 100644
index 00000000..763dbb44
--- /dev/null
+++ b/packages/dataset/src/document/DocumentDataset.ts
@@ -0,0 +1,204 @@
+/**
+ * @license
+ * Copyright 2025 Steven Roussey
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import type { VectorSearchOptions } from "@workglow/storage";
+import type { TypedArray } from "@workglow/util";
+import type { DocumentChunk, DocumentChunkStorage } from "../document-chunk/DocumentChunkSchema";
+import { Document } from "./Document";
+import { ChunkNode, DocumentNode } from "./DocumentSchema";
+import { DocumentStorageEntity, DocumentTabularStorage } from "./DocumentStorageSchema";
+
+/**
+ * Document dataset that uses TabularStorage for document persistence and VectorStorage for chunk persistence and similarity search.
+ * This is a unified implementation that composes storage backends rather than using
+ * inheritance/interface patterns.
+ */
+export class DocumentDataset {
+  private tabularStorage: DocumentTabularStorage;
+  private vectorStorage?: DocumentChunkStorage;
+
+  /**
+   * Creates a new DocumentDataset instance.
+ * + * @param tabularStorage - Pre-initialized tabular storage for document persistence + * @param vectorStorage - Pre-initialized vector storage for chunk similarity search + * + * @example + * ```typescript + * const tabularStorage = new InMemoryTabularStorage(DocumentStorageSchema, ["doc_id"]); + * await tabularStorage.setupDatabase(); + * + * const vectorStorage = new InMemoryVectorStorage(); + * await vectorStorage.setupDatabase(); + * + * const docDataset = new DocumentDataset(tabularStorage, vectorStorage); + * ``` + */ + constructor(tabularStorage: DocumentTabularStorage, vectorStorage?: DocumentChunkStorage) { + this.tabularStorage = tabularStorage; + this.vectorStorage = vectorStorage; + } + + /** + * Upsert a document + */ + async upsert(document: Document): Promise { + const serialized = JSON.stringify(document.toJSON ? document.toJSON() : document); + await this.tabularStorage.put({ + doc_id: document.doc_id, + data: serialized, + }); + } + + /** + * Get a document by ID + */ + async get(doc_id: string): Promise { + const entity = await this.tabularStorage.get({ doc_id: doc_id }); + if (!entity) { + return undefined; + } + return Document.fromJSON(entity.data); + } + + /** + * Delete a document + */ + async delete(doc_id: string): Promise { + await this.tabularStorage.delete({ doc_id: doc_id }); + } + + /** + * Get a specific node by ID + */ + async getNode(doc_id: string, nodeId: string): Promise { + const doc = await this.get(doc_id); + if (!doc) { + return undefined; + } + + // Traverse tree to find node + const traverse = (node: any): any => { + if (node.nodeId === nodeId) { + return node; + } + if (node.children && Array.isArray(node.children)) { + for (const child of node.children) { + const found = traverse(child); + if (found) return found; + } + } + return undefined; + }; + + return traverse(doc.root); + } + + /** + * Get ancestors of a node (from root to node) + */ + async getAncestors(doc_id: string, nodeId: string): Promise { + const doc = await this.get(doc_id); + if (!doc) { + return []; + } + + // Get path from root to target node + const path: string[] = []; + const findPath = (node: any): boolean => { + path.push(node.nodeId); + if (node.nodeId === nodeId) { + return true; + } + if (node.children && Array.isArray(node.children)) { + for (const child of node.children) { + if (findPath(child)) { + return true; + } + } + } + path.pop(); + return false; + }; + + if (!findPath(doc.root)) { + return []; + } + + // Collect nodes along the path + const ancestors: any[] = []; + let currentNode: any = doc.root; + ancestors.push(currentNode); + + for (let i = 1; i < path.length; i++) { + const targetId = path[i]; + if (currentNode.children && Array.isArray(currentNode.children)) { + const found = currentNode.children.find((child: any) => child.nodeId === targetId); + if (found) { + currentNode = found; + ancestors.push(currentNode); + } else { + break; + } + } else { + break; + } + } + + return ancestors; + } + + /** + * Get chunks for a document + */ + async getChunks(doc_id: string): Promise { + const doc = await this.get(doc_id); + if (!doc) { + return []; + } + return doc.getChunks(); + } + + /** + * Find chunks that contain a specific nodeId in their path + */ + async findChunksByNodeId(doc_id: string, nodeId: string): Promise { + const doc = await this.get(doc_id); + if (!doc) { + return []; + } + if (doc.findChunksByNodeId) { + return doc.findChunksByNodeId(nodeId); + } + // Fallback implementation + const chunks = doc.getChunks(); + return chunks.filter((chunk) 
=> chunk.nodePath && chunk.nodePath.includes(nodeId)); + } + + /** + * List all document IDs + */ + async list(): Promise { + const entities = await this.tabularStorage.getAll(); + if (!entities) { + return []; + } + return entities.map((e: DocumentStorageEntity) => e.doc_id); + } + + /** + * Search for similar vectors using the vector storage + * @param query - Query vector to search for + * @param options - Search options (topK, filter, scoreThreshold) + * @returns Array of search results sorted by similarity + */ + async search( + query: TypedArray, + options?: VectorSearchOptions> + ): Promise, TypedArray>>> { + return this.vectorStorage?.similaritySearch(query, options) || []; + } +} diff --git a/packages/dataset/src/document/DocumentDatasetRegistry.ts b/packages/dataset/src/document/DocumentDatasetRegistry.ts new file mode 100644 index 00000000..2e395da4 --- /dev/null +++ b/packages/dataset/src/document/DocumentDatasetRegistry.ts @@ -0,0 +1,79 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + createServiceToken, + globalServiceRegistry, + registerInputResolver, + ServiceRegistry, +} from "@workglow/util"; +import type { DocumentDataset } from "./DocumentDataset"; + +/** + * Service token for the document dataset registry + * Maps dataset IDs to DocumentDataset instances + */ +export const DOCUMENT_DATASETS = + createServiceToken>("dataset.documents"); + +// Register default factory if not already registered +if (!globalServiceRegistry.has(DOCUMENT_DATASETS)) { + globalServiceRegistry.register( + DOCUMENT_DATASETS, + (): Map => new Map(), + true + ); +} + +/** + * Gets the global document dataset registry + * @returns Map of document dataset ID to instance + */ +export function getGlobalDocumentDatasets(): Map { + return globalServiceRegistry.get(DOCUMENT_DATASETS); +} + +/** + * Registers a document dataset globally by ID + * @param id The unique identifier for this dataset + * @param dataset The dataset instance to register + */ +export function registerDocumentDataset(id: string, dataset: DocumentDataset): void { + const datasets = getGlobalDocumentDatasets(); + datasets.set(id, dataset); +} + +/** + * Gets a document dataset by ID from the global registry + * @param id The dataset identifier + * @returns The dataset instance or undefined if not found + */ +export function getDocumentDataset(id: string): DocumentDataset | undefined { + return getGlobalDocumentDatasets().get(id); +} + +/** + * Resolves a dataset ID to a DocumentDataset from the registry. + * Used by the input resolver system. + */ +async function resolveDocumentDatasetFromRegistry( + id: string, + format: string, + registry: ServiceRegistry +): Promise { + const datasets = registry.has(DOCUMENT_DATASETS) + ? 
registry.get>(DOCUMENT_DATASETS) + : getGlobalDocumentDatasets(); + + const dataset = datasets.get(id); + if (!dataset) { + throw new Error(`Document dataset "${id}" not found in registry`); + } + return dataset; +} + +// Register the dataset resolver for format: "dataset:document" +registerInputResolver("dataset:document", resolveDocumentDatasetFromRegistry); diff --git a/packages/dataset/src/document/DocumentNode.ts b/packages/dataset/src/document/DocumentNode.ts index dd68a896..2d11a25e 100644 --- a/packages/dataset/src/document/DocumentNode.ts +++ b/packages/dataset/src/document/DocumentNode.ts @@ -4,72 +4,15 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { sha256, uuid4 } from "@workglow/util"; - import { NodeKind, type DocumentNode, type DocumentRootNode, - type NodeKind as NodeKindType, type NodeRange, type SectionNode, type TopicNode, } from "./DocumentSchema"; -/** - * Utility functions for ID generation - */ -export class NodeIdGenerator { - /** - * Generate doc_id from source URI and content hash - */ - static async generateDocId(sourceUri: string, content: string): Promise { - return uuid4(); - const contentHash = await sha256(content); - const combined = `${sourceUri}|${contentHash}`; - const hash = await sha256(combined); - return `doc_${hash.substring(0, 16)}`; - } - - /** - * Generate nodeId for structural nodes (document, section) - */ - static async generateStructuralNodeId( - doc_id: string, - kind: NodeKindType, - range: NodeRange - ): Promise { - return uuid4(); - const combined = `${doc_id}|${kind}|${range.startOffset}:${range.endOffset}`; - const hash = await sha256(combined); - return `node_${hash.substring(0, 16)}`; - } - - /** - * Generate nodeId for child nodes (paragraph, topic) - */ - static async generateChildNodeId(parentNodeId: string, ordinal: number): Promise { - return uuid4(); - const combined = `${parentNodeId}|${ordinal}`; - const hash = await sha256(combined); - return `node_${hash.substring(0, 16)}`; - } - - /** - * Generate chunkId - */ - static async generateChunkId( - doc_id: string, - leafNodeId: string, - chunkOrdinal: number - ): Promise { - return uuid4(); - const combined = `${doc_id}|${leafNodeId}|${chunkOrdinal}`; - const hash = await sha256(combined); - return `chunk_${hash.substring(0, 16)}`; - } -} - /** * Approximate token counting (v1) */ diff --git a/packages/dataset/src/document/DocumentRepository.ts b/packages/dataset/src/document/DocumentRepository.ts index bc92afbb..b26a0620 100644 --- a/packages/dataset/src/document/DocumentRepository.ts +++ b/packages/dataset/src/document/DocumentRepository.ts @@ -4,13 +4,9 @@ * SPDX-License-Identifier: Apache-2.0 */ +import type { AnyVectorStorage, ITabularStorage, VectorSearchOptions } from "@workglow/storage"; import type { TypedArray } from "@workglow/util"; -import { ChunkVector } from "../chunk-vector/ChunkVectorSchema"; -import type { - AnyChunkVectorStorage, - VectorSearchOptions, -} from "../chunk-vector/IChunkVectorStorage"; -import type { ITabularStorage } from "@workglow/storage"; +import type { DocumentChunk } from "../document-chunk/DocumentChunkSchema"; import { Document } from "./Document"; import { ChunkNode, DocumentNode } from "./DocumentSchema"; import { @@ -29,7 +25,7 @@ export class DocumentRepository { DocumentStorageKey, DocumentStorageEntity >; - private vectorStorage?: AnyChunkVectorStorage; + private vectorStorage?: AnyVectorStorage; /** * Creates a new DocumentRepository instance. 
@@ -54,7 +50,7 @@ export class DocumentRepository { ["doc_id"], DocumentStorageEntity >, - vectorStorage?: AnyChunkVectorStorage + vectorStorage?: AnyVectorStorage ) { this.tabularStorage = tabularStorage; this.vectorStorage = vectorStorage; @@ -204,7 +200,7 @@ export class DocumentRepository { if (!entities) { return []; } - return entities.map((e: DocumentStorageEntity) => e.doc_id); + return entities.map((e) => e.doc_id); } /** @@ -216,7 +212,7 @@ export class DocumentRepository { async search( query: TypedArray, options?: VectorSearchOptions> - ): Promise, TypedArray>>> { - return this.vectorStorage?.similaritySearch(query, options) || []; + ): Promise, TypedArray> & { score: number }>> { + return (this.vectorStorage?.similaritySearch(query, options) || []) as any; } } diff --git a/packages/dataset/src/document/DocumentRepositoryRegistry.ts b/packages/dataset/src/document/DocumentRepositoryRegistry.ts deleted file mode 100644 index 0f011539..00000000 --- a/packages/dataset/src/document/DocumentRepositoryRegistry.ts +++ /dev/null @@ -1,79 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import { - createServiceToken, - globalServiceRegistry, - registerInputResolver, - ServiceRegistry, -} from "@workglow/util"; -import type { DocumentRepository } from "./DocumentRepository"; - -/** - * Service token for the document repository registry - * Maps repository IDs to DocumentRepository instances - */ -export const DOCUMENT_REPOSITORIES = - createServiceToken>("document.repositories"); - -// Register default factory if not already registered -if (!globalServiceRegistry.has(DOCUMENT_REPOSITORIES)) { - globalServiceRegistry.register( - DOCUMENT_REPOSITORIES, - (): Map => new Map(), - true - ); -} - -/** - * Gets the global document repository registry - * @returns Map of document repository ID to instance - */ -export function getGlobalDocumentRepositories(): Map { - return globalServiceRegistry.get(DOCUMENT_REPOSITORIES); -} - -/** - * Registers a document repository globally by ID - * @param id The unique identifier for this repository - * @param repository The repository instance to register - */ -export function registerDocumentRepository(id: string, repository: DocumentRepository): void { - const repos = getGlobalDocumentRepositories(); - repos.set(id, repository); -} - -/** - * Gets a document repository by ID from the global registry - * @param id The repository identifier - * @returns The repository instance or undefined if not found - */ -export function getDocumentRepository(id: string): DocumentRepository | undefined { - return getGlobalDocumentRepositories().get(id); -} - -/** - * Resolves a repository ID to a DocumentRepository from the registry. - * Used by the input resolver system. - */ -async function resolveDocumentRepositoryFromRegistry( - id: string, - format: string, - registry: ServiceRegistry -): Promise { - const repos = registry.has(DOCUMENT_REPOSITORIES) - ? 
registry.get>(DOCUMENT_REPOSITORIES) - : getGlobalDocumentRepositories(); - - const repo = repos.get(id); - if (!repo) { - throw new Error(`Document repository "${id}" not found in registry`); - } - return repo; -} - -// Register the repository resolver for format: "repository:document" -registerInputResolver("repository:document", resolveDocumentRepositoryFromRegistry); diff --git a/packages/dataset/src/document/DocumentStorageSchema.ts b/packages/dataset/src/document/DocumentStorageSchema.ts index 40e518c6..d65eb454 100644 --- a/packages/dataset/src/document/DocumentStorageSchema.ts +++ b/packages/dataset/src/document/DocumentStorageSchema.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { ITabularStorage } from "@workglow/storage"; import { TypedArraySchemaOptions, type DataPortSchemaObject, @@ -41,3 +42,9 @@ export const DocumentStorageKey = ["doc_id"] as const; export type DocumentStorageKey = typeof DocumentStorageKey; export type DocumentStorageEntity = FromSchema; + +export type DocumentTabularStorage = ITabularStorage< + typeof DocumentStorageSchema, + DocumentStorageKey, + DocumentStorageEntity +>; diff --git a/packages/dataset/src/document/StructuralParser.ts b/packages/dataset/src/document/StructuralParser.ts index 3f66033b..b272f2eb 100644 --- a/packages/dataset/src/document/StructuralParser.ts +++ b/packages/dataset/src/document/StructuralParser.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { NodeIdGenerator } from "./DocumentNode"; +import { uuid4 } from "@workglow/util"; import { type DocumentRootNode, NodeKind, @@ -28,10 +28,7 @@ export class StructuralParser { let currentOffset = 0; const root: DocumentRootNode = { - nodeId: await NodeIdGenerator.generateStructuralNodeId(doc_id, NodeKind.DOCUMENT, { - startOffset: 0, - endOffset: text.length, - }), + nodeId: uuid4(), kind: NodeKind.DOCUMENT, range: { startOffset: 0, endOffset: text.length }, text: title, @@ -51,10 +48,7 @@ export class StructuralParser { const paragraphEndOffset = currentOffset; const paragraph: ParagraphNode = { - nodeId: await NodeIdGenerator.generateChildNodeId( - currentParentStack[currentParentStack.length - 1].nodeId, - currentParentStack[currentParentStack.length - 1].children.length - ), + nodeId: uuid4(), kind: NodeKind.PARAGRAPH, range: { startOffset: paragraphStartOffset, @@ -102,10 +96,7 @@ export class StructuralParser { const sectionStartOffset = currentOffset; const section: SectionNode = { - nodeId: await NodeIdGenerator.generateStructuralNodeId(doc_id, NodeKind.SECTION, { - startOffset: sectionStartOffset, - endOffset: text.length, // Will be updated when section closes - }), + nodeId: uuid4(), kind: NodeKind.SECTION, level, title: headerTitle, @@ -159,10 +150,7 @@ export class StructuralParser { title: string ): Promise { const root: DocumentRootNode = { - nodeId: await NodeIdGenerator.generateStructuralNodeId(doc_id, NodeKind.DOCUMENT, { - startOffset: 0, - endOffset: text.length, - }), + nodeId: uuid4(), kind: NodeKind.DOCUMENT, range: { startOffset: 0, endOffset: text.length }, text: title, @@ -186,7 +174,7 @@ export class StructuralParser { const endOffset = startOffset + paragraphText.length; const paragraph: ParagraphNode = { - nodeId: await NodeIdGenerator.generateChildNodeId(root.nodeId, paragraphIndex), + nodeId: uuid4(), kind: NodeKind.PARAGRAPH, range: { startOffset, @@ -213,7 +201,7 @@ export class StructuralParser { const endOffset = startOffset + paragraphText.length; const paragraph: ParagraphNode = { - nodeId: await 
NodeIdGenerator.generateChildNodeId(root.nodeId, paragraphIndex),
+        nodeId: uuid4(),
         kind: NodeKind.PARAGRAPH,
         range: {
           startOffset,
diff --git a/packages/dataset/src/util/DatasetSchema.ts b/packages/dataset/src/util/DatasetSchema.ts
new file mode 100644
index 00000000..d0a75f8f
--- /dev/null
+++ b/packages/dataset/src/util/DatasetSchema.ts
@@ -0,0 +1,89 @@
+/**
+ * @license
+ * Copyright 2025 Steven Roussey
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import type { JsonSchema } from "@workglow/util";
+
+/**
+ * Semantic format types for dataset schema annotations.
+ * These are used by the InputResolver to determine how to resolve string IDs.
+ */
+export type DatasetSemantic = "dataset:tabular" | "dataset:document-chunk" | "dataset:document";
+
+/**
+ * Creates a JSON schema for a tabular dataset input.
+ * The schema accepts either a string ID (resolved from registry) or a direct dataset instance.
+ *
+ * @param options Additional schema options to merge
+ * @returns JSON schema for tabular dataset input
+ *
+ * @example
+ * ```typescript
+ * const inputSchema = {
+ *   type: "object",
+ *   properties: {
+ *     dataSource: TypeTabularStorage({
+ *       title: "User Database",
+ *       description: "Dataset containing user records",
+ *     }),
+ *   },
+ *   required: ["dataSource"],
+ * } as const;
+ * ```
+ */
+export function TypeTabularStorage = {}>(options: O = {} as O) {
+  return {
+    title: "Tabular Storage",
+    description: "Storage ID or instance for tabular data storage",
+    ...options,
+    format: "storage:tabular" as const,
+    oneOf: [
+      { type: "string" as const, title: "Storage ID" },
+      { title: "Storage Instance", additionalProperties: true },
+    ],
+  } as const satisfies JsonSchema;
+}
+
+/**
+ * Creates a JSON schema for a document chunk dataset input.
+ * The schema accepts either a string ID (resolved from registry) or a direct dataset instance.
+ *
+ * @param options Additional schema options to merge
+ * @returns JSON schema for document chunk dataset input
+ */
+export function TypeDocumentChunkDataset = {}>(
+  options: O = {} as O
+) {
+  return {
+    title: "Document Chunk Dataset",
+    description: "Dataset ID or instance for document chunk data storage",
+    ...options,
+    format: "dataset:document-chunk" as const,
+    anyOf: [
+      { type: "string" as const, title: "Dataset ID" },
+      { title: "Dataset Instance", additionalProperties: true },
+    ],
+  } as const satisfies JsonSchema;
+}
+
+/**
+ * Creates a JSON schema for a document dataset input.
+ * The schema accepts either a string ID (resolved from registry) or a direct dataset instance.
+ *
+ * @param options Additional schema options to merge
+ * @returns JSON schema for document dataset input
+ */
+export function TypeDocumentDataset = {}>(options: O = {} as O) {
+  return {
+    title: "Document Dataset",
+    description: "Dataset ID or instance for document data storage",
+    ...options,
+    format: "dataset:document" as const,
+    anyOf: [
+      { type: "string" as const, title: "Dataset ID" },
+      { title: "Dataset Instance", additionalProperties: true },
+    ],
+  } as const satisfies JsonSchema;
+}
diff --git a/packages/dataset/src/util/RepositorySchema.ts b/packages/dataset/src/util/RepositorySchema.ts
deleted file mode 100644
index b7ccc3f1..00000000
--- a/packages/dataset/src/util/RepositorySchema.ts
+++ /dev/null
@@ -1,96 +0,0 @@
-/**
- * @license
- * Copyright 2025 Steven Roussey
- * SPDX-License-Identifier: Apache-2.0
- */
-
-import type { JsonSchema } from "@workglow/util";
-
-/**
- * Semantic format types for repository schema annotations.
- * These are used by the InputResolver to determine how to resolve string IDs. - */ -export type RepositorySemantic = - | "repository:tabular" - | "repository:document-node-vector" - | "repository:document"; - -/** - * Creates a JSON schema for a tabular repository input. - * The schema accepts either a string ID (resolved from registry) or a direct repository instance. - * - * @param options Additional schema options to merge - * @returns JSON schema for tabular repository input - * - * @example - * ```typescript - * const inputSchema = { - * type: "object", - * properties: { - * dataSource: TypeTabularRepository({ - * title: "User Database", - * description: "Repository containing user records", - * }), - * }, - * required: ["dataSource"], - * } as const; - * ``` - */ -export function TypeTabularRepository = {}>( - options: O = {} as O -) { - return { - title: "Tabular Repository", - description: "Repository ID or instance for tabular data storage", - ...options, - format: "repository:tabular" as const, - oneOf: [ - { type: "string" as const, title: "Repository ID" }, - { title: "Repository Instance", additionalProperties: true }, - ], - } as const satisfies JsonSchema; -} - -/** - * Creates a JSON schema for a vector repository input. - * The schema accepts either a string ID (resolved from registry) or a direct repository instance. - * - * @param options Additional schema options to merge - * @returns JSON schema for vector repository input - */ -export function TypeChunkVectorRepository = {}>( - options: O = {} as O -) { - return { - title: "Document Chunk Vector Repository", - description: "Repository ID or instance for document chunk vector data storage", - ...options, - format: "repository:document-node-vector" as const, - anyOf: [ - { type: "string" as const, title: "Repository ID" }, - { title: "Repository Instance", additionalProperties: true }, - ], - } as const satisfies JsonSchema; -} - -/** - * Creates a JSON schema for a document repository input. - * The schema accepts either a string ID (resolved from registry) or a direct repository instance. 
- * - * @param options Additional schema options to merge - * @returns JSON schema for document repository input - */ -export function TypeDocumentRepository = {}>( - options: O = {} as O -) { - return { - title: "Document Repository", - description: "Repository ID or instance for document data storage", - ...options, - format: "repository:document" as const, - anyOf: [ - { type: "string" as const, title: "Repository ID" }, - { title: "Repository Instance", additionalProperties: true }, - ], - } as const satisfies JsonSchema; -} diff --git a/packages/dataset/tsconfig.json b/packages/dataset/tsconfig.json index b4562c70..5aab665e 100644 --- a/packages/dataset/tsconfig.json +++ b/packages/dataset/tsconfig.json @@ -2,7 +2,7 @@ "extends": "../../tsconfig.json", "include": ["src/common.ts", "src/common-server.ts", "src/*/**/*"], "files": ["./src/types.ts"], - "exclude": ["dist", "src/chunk-vector/PostgresChunkVectorStorage.ts", "src/chunk-vector/SqliteChunkVectorStorage.ts"], + "exclude": ["dist"], "compilerOptions": { "composite": true, "outDir": "./dist", diff --git a/packages/debug/package.json b/packages/debug/package.json index f5f48dd4..ad0577e6 100644 --- a/packages/debug/package.json +++ b/packages/debug/package.json @@ -19,8 +19,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./dist/browser.js", - "types": "./dist/types.d.ts", + "bun": "./src/browser.ts", + "types": "./src/types.ts", "import": "./dist/browser.js" } }, diff --git a/packages/job-queue/package.json b/packages/job-queue/package.json index be6b080e..85a4fe08 100644 --- a/packages/job-queue/package.json +++ b/packages/job-queue/package.json @@ -23,8 +23,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./dist/bun.js", - "types": "./dist/types.d.ts", + "bun": "./src/bun.ts", + "types": "./src/types.ts", "import": "./dist/node.js" } }, diff --git a/packages/sqlite/package.json b/packages/sqlite/package.json index 73f530a8..2518526f 100644 --- a/packages/sqlite/package.json +++ b/packages/sqlite/package.json @@ -38,13 +38,13 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./dist/bun.js", - "types": "./dist/types.d.ts", + "bun": "./src/bun.ts", + "types": "./src/types.ts", "import": "./dist/node.js" }, "./bun": { - "types": "./dist/bun.d.ts", - "import": "./dist/bun.js" + "types": "./src/bun.ts", + "import": "./src/bun.ts" }, "./node": { "types": "./dist/node.d.ts", diff --git a/packages/storage/README.md b/packages/storage/README.md index f86bfa63..f65c45de 100644 --- a/packages/storage/README.md +++ b/packages/storage/README.md @@ -592,7 +592,7 @@ import { TypeDocumentRepository, } from "@workglow/storage"; -// Tabular repository (format: "repository:tabular") +// Tabular repository (format: "storage:tabular") const tabularSchema = TypeTabularRepository({ title: "Data Source", description: "Tabular data repository", @@ -1081,11 +1081,7 @@ describe("UserRepository", () => { let userRepo: InMemoryTabularStorage; beforeEach(() => { - userRepo = new InMemoryTabularStorage( - UserSchema, - ["id"], - ["email"] - ); + userRepo = new InMemoryTabularStorage(UserSchema, ["id"], ["email"]); }); test("should create and retrieve user", async () => { diff --git a/packages/storage/package.json b/packages/storage/package.json index 237807c4..d09c4bd4 100644 --- a/packages/storage/package.json +++ b/packages/storage/package.json @@ -45,8 +45,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", 
- "bun": "./dist/bun.js", - "types": "./dist/types.d.ts", + "bun": "./src/bun.ts", + "types": "./src/types.ts", "import": "./dist/node.js" } }, diff --git a/packages/storage/src/common-server.ts b/packages/storage/src/common-server.ts index 0480e6a5..96db5266 100644 --- a/packages/storage/src/common-server.ts +++ b/packages/storage/src/common-server.ts @@ -25,8 +25,8 @@ export * from "./queue-limiter/PostgresRateLimiterStorage"; export * from "./queue-limiter/SqliteRateLimiterStorage"; export * from "./queue-limiter/SupabaseRateLimiterStorage"; -export * from "./vector/PostgresChunkVectorStorage"; -export * from "./vector/SqliteChunkVectorStorage"; +export * from "./vector/PostgresVectorStorage"; +export * from "./vector/SqliteVectorStorage"; // testing export * from "./kv/IndexedDbKvStorage"; diff --git a/packages/storage/src/common.ts b/packages/storage/src/common.ts index 2c31740a..45036734 100644 --- a/packages/storage/src/common.ts +++ b/packages/storage/src/common.ts @@ -24,7 +24,5 @@ export * from "./queue-limiter/IRateLimiterStorage"; export * from "./util/HybridSubscriptionManager"; export * from "./util/PollingSubscriptionManager"; -export * from "./vector/ChunkVectorSchema"; -export * from "./vector/ChunkVectorStorageRegistry"; -export * from "./vector/IChunkVectorStorage"; -export * from "./vector/InMemoryChunkVectorStorage"; +export * from "./vector/InMemoryVectorStorage"; +export * from "./vector/IVectorStorage"; diff --git a/packages/storage/src/tabular/TabularStorageRegistry.ts b/packages/storage/src/tabular/TabularStorageRegistry.ts index c00ea82d..1d84edfc 100644 --- a/packages/storage/src/tabular/TabularStorageRegistry.ts +++ b/packages/storage/src/tabular/TabularStorageRegistry.ts @@ -70,10 +70,10 @@ function resolveRepositoryFromRegistry( : getGlobalTabularRepositories(); const repo = repos.get(id); if (!repo) { - throw new Error(`Tabular repository "${id}" not found in registry`); + throw new Error(`Tabular storage "${id}" not found in registry`); } return repo; } -// Register the repository resolver for format: "repository:tabular" -registerInputResolver("repository:tabular", resolveRepositoryFromRegistry); +// Register the repository resolver for format: "storage:tabular" +registerInputResolver("storage:tabular", resolveRepositoryFromRegistry); diff --git a/packages/storage/src/vector/ChunkVectorSchema.ts b/packages/storage/src/vector/ChunkVectorSchema.ts deleted file mode 100644 index 8ce95438..00000000 --- a/packages/storage/src/vector/ChunkVectorSchema.ts +++ /dev/null @@ -1,35 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import { TypedArraySchema, type DataPortSchemaObject, type TypedArray } from "@workglow/util"; - -/** - * Default schema for document chunk storage with vector embeddings - */ -export const ChunkVectorSchema = { - type: "object", - properties: { - chunk_id: { type: "string" }, - doc_id: { type: "string" }, - vector: TypedArraySchema(), - metadata: { type: "object", additionalProperties: true }, - }, - additionalProperties: false, -} as const satisfies DataPortSchemaObject; -export type ChunkVectorSchema = typeof ChunkVectorSchema; - -export const ChunkVectorKey = ["chunk_id"] as const; -export type ChunkVectorKey = typeof ChunkVectorKey; - -export interface ChunkVector< - Metadata extends Record = Record, - Vector extends TypedArray = Float32Array, -> { - chunk_id: string; - doc_id: string; - vector: Vector; - metadata: Metadata; -} diff --git 
a/packages/storage/src/vector/ChunkVectorStorageRegistry.ts b/packages/storage/src/vector/ChunkVectorStorageRegistry.ts deleted file mode 100644 index 7c51b929..00000000 --- a/packages/storage/src/vector/ChunkVectorStorageRegistry.ts +++ /dev/null @@ -1,83 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import { - createServiceToken, - globalServiceRegistry, - registerInputResolver, - ServiceRegistry, -} from "@workglow/util"; -import { AnyChunkVectorStorage } from "./IChunkVectorStorage"; - -/** - * Service token for the documenbt chunk vector repository registry - * Maps repository IDs to IVectorChunkRepository instances - */ -export const DOCUMENT_CHUNK_VECTOR_REPOSITORIES = createServiceToken< - Map ->("storage.document-node-vector.repositories"); - -// Register default factory if not already registered -if (!globalServiceRegistry.has(DOCUMENT_CHUNK_VECTOR_REPOSITORIES)) { - globalServiceRegistry.register( - DOCUMENT_CHUNK_VECTOR_REPOSITORIES, - (): Map => new Map(), - true - ); -} - -/** - * Gets the global document chunk vector repository registry - * @returns Map of document chunk vector repository ID to instance - */ -export function getGlobalChunkVectorRepositories(): Map { - return globalServiceRegistry.get(DOCUMENT_CHUNK_VECTOR_REPOSITORIES); -} - -/** - * Registers a vector repository globally by ID - * @param id The unique identifier for this repository - * @param repository The repository instance to register - */ -export function registerChunkVectorRepository( - id: string, - repository: AnyChunkVectorStorage -): void { - const repos = getGlobalChunkVectorRepositories(); - repos.set(id, repository); -} - -/** - * Gets a document chunk vector repository by ID from the global registry - * @param id The repository identifier - * @returns The repository instance or undefined if not found - */ -export function getChunkVectorRepository(id: string): AnyChunkVectorStorage | undefined { - return getGlobalChunkVectorRepositories().get(id); -} - -/** - * Resolves a repository ID to an IVectorChunkRepository from the registry. - * Used by the input resolver system. - */ -async function resolveChunkVectorRepositoryFromRegistry( - id: string, - format: string, - registry: ServiceRegistry -): Promise { - const repos = registry.has(DOCUMENT_CHUNK_VECTOR_REPOSITORIES) - ? 
registry.get>(DOCUMENT_CHUNK_VECTOR_REPOSITORIES) - : getGlobalChunkVectorRepositories(); - - const repo = repos.get(id); - if (!repo) { - throw new Error(`Document chunk vector repository "${id}" not found in registry`); - } - return repo; -} - -// Register the repository resolver for format: "repository:document-node-vector" -registerInputResolver("repository:document-node-vector", resolveChunkVectorRepositoryFromRegistry); diff --git a/packages/storage/src/vector/IChunkVectorStorage.ts b/packages/storage/src/vector/IVectorStorage.ts similarity index 53% rename from packages/storage/src/vector/IChunkVectorStorage.ts rename to packages/storage/src/vector/IVectorStorage.ts index 0b68d414..2634d361 100644 --- a/packages/storage/src/vector/IChunkVectorStorage.ts +++ b/packages/storage/src/vector/IVectorStorage.ts @@ -8,17 +8,20 @@ import type { DataPortSchemaObject, EventParameters, FromSchema, + JsonSchema, TypedArray, TypedArraySchemaOptions, } from "@workglow/util"; import type { ITabularStorage, TabularEventListeners } from "../tabular/ITabularStorage"; -export type AnyChunkVectorStorage = IChunkVectorStorage; +export type AnyVectorStorage = IVectorStorage; /** * Options for vector search operations */ -export interface VectorSearchOptions> { +export interface VectorSearchOptions< + Metadata extends Record | undefined = Record, +> { readonly topK?: number; readonly filter?: Partial; readonly scoreThreshold?: number; @@ -28,7 +31,7 @@ export interface VectorSearchOptions> { * Options for hybrid search (vector + full-text) */ export interface HybridSearchOptions< - Metadata = Record, + Metadata extends Record | undefined = Record, > extends VectorSearchOptions { readonly textQuery: string; readonly vectorWeight?: number; @@ -37,7 +40,7 @@ export interface HybridSearchOptions< /** * Type definitions for document chunk vector repository events */ -export interface VectorChunkEventListeners extends TabularEventListeners< +export interface VectorEventListeners extends TabularEventListeners< PrimaryKey, Entity > { @@ -45,23 +48,23 @@ export interface VectorChunkEventListeners extends TabularEv hybridSearch: (query: TypedArray, results: (Entity & { score: number })[]) => void; } -export type VectorChunkEventName = keyof VectorChunkEventListeners; -export type VectorChunkEventListener< - Event extends VectorChunkEventName, +export type VectorEventName = keyof VectorEventListeners; +export type VectorEventListener< + Event extends VectorEventName, PrimaryKey, Entity, -> = VectorChunkEventListeners[Event]; +> = VectorEventListeners[Event]; -export type VectorChunkEventParameters< - Event extends VectorChunkEventName, +export type VectorEventParameters< + Event extends VectorEventName, PrimaryKey, Entity, -> = EventParameters, Event>; +> = EventParameters, Event>; /** - * Interface defining the contract for document chunk vector storage repositories. - * These repositories store vector embeddings with metadata for decument chunks. - * Extends ITabularRepository to provide standard storage operations, + * Interface defining the contract for vector storage repositories. + * These repositories store vector embeddings with metadata. + * Extends ITabularStorage to provide standard storage operations, * plus vector-specific similarity search capabilities. * Supports various vector types including quantized formats. 
* @@ -69,10 +72,13 @@ export type VectorChunkEventParameters< * @typeParam PrimaryKeyNames - Array of property names that form the primary key * @typeParam Entity - The entity type */ -export interface IChunkVectorStorage< +export interface IVectorStorage< + Metadata extends Record | undefined, Schema extends DataPortSchemaObject, - PrimaryKeyNames extends ReadonlyArray, Entity = FromSchema, + PrimaryKeyNames extends ReadonlyArray = ReadonlyArray< + keyof Schema["properties"] + >, > extends ITabularStorage { /** * Get the vector dimension @@ -88,7 +94,7 @@ export interface IChunkVectorStorage< */ similaritySearch( query: TypedArray, - options?: VectorSearchOptions> + options?: VectorSearchOptions ): Promise<(Entity & { score: number })[]>; /** @@ -100,6 +106,39 @@ export interface IChunkVectorStorage< */ hybridSearch?( query: TypedArray, - options: HybridSearchOptions> + options: HybridSearchOptions ): Promise<(Entity & { score: number })[]>; } + +/** + * TODO: Given a schema, return the vector column by searching for a property with a TypedArray format (or TypedArray:xxx format) + */ + +export function getVectorProperty( + schema: Schema +): keyof Schema["properties"] | undefined { + for (const [key, value] of Object.entries(schema.properties)) { + if ( + typeof value !== "boolean" && + value.type === "array" && + (value.format === "TypedArray" || value.format?.startsWith("TypedArray:")) + ) { + return key; + } + } + return undefined; +} + +/** + * Given a schema, return the property which is an object with format "metadata" + */ +export function getMetadataProperty( + schema: Schema +): keyof Schema["properties"] | undefined { + for (const [key, value] of Object.entries(schema.properties)) { + if (typeof value !== "boolean" && value.type === "object" && value.format === "metadata") { + return key; + } + } + return undefined; +} diff --git a/packages/storage/src/vector/InMemoryChunkVectorStorage.ts b/packages/storage/src/vector/InMemoryVectorStorage.ts similarity index 68% rename from packages/storage/src/vector/InMemoryChunkVectorStorage.ts rename to packages/storage/src/vector/InMemoryVectorStorage.ts index 009c5b44..4ef376f3 100644 --- a/packages/storage/src/vector/InMemoryChunkVectorStorage.ts +++ b/packages/storage/src/vector/InMemoryVectorStorage.ts @@ -4,15 +4,21 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { TypedArray } from "@workglow/util"; +import type { + DataPortSchemaObject, + FromSchema, + TypedArray, + TypedArraySchemaOptions, +} from "@workglow/util"; import { cosineSimilarity } from "@workglow/util"; import { InMemoryTabularStorage } from "../tabular/InMemoryTabularStorage"; -import { ChunkVector, ChunkVectorKey, ChunkVectorSchema } from "./ChunkVectorSchema"; -import type { - HybridSearchOptions, - IChunkVectorStorage, - VectorSearchOptions, -} from "./IChunkVectorStorage"; +import { + getMetadataProperty, + getVectorProperty, + type HybridSearchOptions, + type IVectorStorage, + type VectorSearchOptions, +} from "./IVectorStorage"; /** * Check if metadata matches filter @@ -54,35 +60,45 @@ function textRelevance(text: string, query: string): number { * @template Metadata - The metadata type for the document chunk * @template Vector - The vector type for the document chunk */ -export class InMemoryChunkVectorStorage< +export class InMemoryVectorStorage< + Schema extends DataPortSchemaObject, + PrimaryKeyNames extends ReadonlyArray, Metadata extends Record = Record, Vector extends TypedArray = Float32Array, + Entity = FromSchema, > - extends 
InMemoryTabularStorage< - typeof ChunkVectorSchema, - typeof ChunkVectorKey, - ChunkVector - > - implements - IChunkVectorStorage< - typeof ChunkVectorSchema, - typeof ChunkVectorKey, - ChunkVector - > + extends InMemoryTabularStorage + implements IVectorStorage { private vectorDimensions: number; private VectorType: new (array: number[]) => TypedArray; + private vectorPropertyName: keyof Entity; + private metadataPropertyName: keyof Entity | undefined; /** * Creates a new in-memory document chunk vector repository * @param dimensions - The number of dimensions of the vector * @param VectorType - The type of vector to use (defaults to Float32Array) */ - constructor(dimensions: number, VectorType: new (array: number[]) => TypedArray = Float32Array) { - super(ChunkVectorSchema, ChunkVectorKey); + constructor( + schema: Schema, + primaryKeyNames: PrimaryKeyNames, + indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [], + dimensions: number, + VectorType: new (array: number[]) => TypedArray = Float32Array + ) { + super(schema, primaryKeyNames, indexes); this.vectorDimensions = dimensions; this.VectorType = VectorType; + + // Cache vector and metadata property names from schema + const vectorProp = getVectorProperty(schema); + if (!vectorProp) { + throw new Error("Schema must have a property with type array and format TypedArray"); + } + this.vectorPropertyName = vectorProp as keyof Entity; + this.metadataPropertyName = getMetadataProperty(schema) as keyof Entity | undefined; } /** @@ -98,13 +114,15 @@ export class InMemoryChunkVectorStorage< options: VectorSearchOptions> = {} ) { const { topK = 10, filter, scoreThreshold = 0 } = options; - const results: Array & { score: number }> = []; + const results: Array = []; const allEntities = (await this.getAll()) || []; for (const entity of allEntities) { - const vector = entity.vector; - const metadata = entity.metadata; + const vector = entity[this.vectorPropertyName] as TypedArray; + const metadata = this.metadataPropertyName + ? (entity[this.metadataPropertyName] as Metadata) + : ({} as Metadata); // Apply filter if provided if (filter && !matchesFilter(metadata, filter)) { @@ -121,9 +139,8 @@ export class InMemoryChunkVectorStorage< results.push({ ...entity, - vector, score, - }); + } as Entity & { score: number }); } // Sort by score descending and take top K @@ -141,13 +158,15 @@ export class InMemoryChunkVectorStorage< return this.similaritySearch(query, { topK, filter, scoreThreshold }); } - const results: Array & { score: number }> = []; + const results: Array = []; const allEntities = (await this.getAll()) || []; for (const entity of allEntities) { // In memory, vectors are stored as TypedArrays directly (not serialized) - const vector = entity.vector; - const metadata = entity.metadata; + const vector = entity[this.vectorPropertyName] as TypedArray; + const metadata = this.metadataPropertyName + ? 
(entity[this.metadataPropertyName] as Metadata) + : ({} as Metadata); // Apply filter if provided if (filter && !matchesFilter(metadata, filter)) { @@ -171,9 +190,8 @@ export class InMemoryChunkVectorStorage< results.push({ ...entity, - vector, score: combinedScore, - }); + } as Entity & { score: number }); } // Sort by combined score descending and take top K diff --git a/packages/storage/src/vector/PostgresChunkVectorStorage.ts b/packages/storage/src/vector/PostgresVectorStorage.ts similarity index 58% rename from packages/storage/src/vector/PostgresChunkVectorStorage.ts rename to packages/storage/src/vector/PostgresVectorStorage.ts index afd68081..62f93b8c 100644 --- a/packages/storage/src/vector/PostgresChunkVectorStorage.ts +++ b/packages/storage/src/vector/PostgresVectorStorage.ts @@ -4,63 +4,81 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { cosineSimilarity, type TypedArray } from "@workglow/util"; +import type { + DataPortSchemaObject, + FromSchema, + TypedArray, + TypedArraySchemaOptions, +} from "@workglow/util"; +import { cosineSimilarity } from "@workglow/util"; import type { Pool } from "pg"; import { PostgresTabularStorage } from "../tabular/PostgresTabularStorage"; -import { ChunkVector, ChunkVectorKey, ChunkVectorSchema } from "./ChunkVectorSchema"; -import type { - HybridSearchOptions, - IChunkVectorStorage, - VectorSearchOptions, -} from "./IChunkVectorStorage"; +import { + getMetadataProperty, + getVectorProperty, + type HybridSearchOptions, + type IVectorStorage, + type VectorSearchOptions, +} from "./IVectorStorage"; /** - * PostgreSQL document chunk vector repository implementation using pgvector extension. - * Extends PostgresTabularRepository for storage. + * PostgreSQL vector repository implementation using pgvector extension. + * Extends PostgresTabularStorage for storage. * Provides efficient vector similarity search with native database support. 
* * Requirements: * - PostgreSQL database with pgvector extension installed * - CREATE EXTENSION vector; * - * @template Metadata - The metadata type for the document chunk - * @template Vector - The vector type for the document chunk + * @template Metadata - The metadata type + * @template Vector - The vector type */ -export class PostgresChunkVectorStorage< +export class PostgresVectorStorage< + Schema extends DataPortSchemaObject, + PrimaryKeyNames extends ReadonlyArray, Metadata extends Record = Record, Vector extends TypedArray = Float32Array, + Entity = FromSchema, > - extends PostgresTabularStorage< - typeof ChunkVectorSchema, - typeof ChunkVectorKey, - ChunkVector - > - implements - IChunkVectorStorage< - typeof ChunkVectorSchema, - typeof ChunkVectorKey, - ChunkVector - > + extends PostgresTabularStorage + implements IVectorStorage { private vectorDimensions: number; private VectorType: new (array: number[]) => TypedArray; + private vectorPropertyName: keyof Entity; + private metadataPropertyName: keyof Entity | undefined; + /** - * Creates a new PostgreSQL document chunk vector repository + * Creates a new PostgreSQL vector repository * @param db - PostgreSQL connection pool * @param table - The name of the table to use for storage + * @param schema - The schema definition for the entity + * @param primaryKeyNames - Array of property names that form the primary key + * @param indexes - Array of columns or column arrays to make searchable * @param dimensions - The number of dimensions of the vector * @param VectorType - The type of vector to use (defaults to Float32Array) */ constructor( db: Pool, table: string, + schema: Schema, + primaryKeyNames: PrimaryKeyNames, + indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [], dimensions: number, VectorType: new (array: number[]) => TypedArray = Float32Array ) { - super(db, table, ChunkVectorSchema, ChunkVectorKey); + super(db, table, schema, primaryKeyNames, indexes); this.vectorDimensions = dimensions; this.VectorType = VectorType; + + // Cache vector and metadata property names from schema + const vectorProp = getVectorProperty(schema); + if (!vectorProp) { + throw new Error("Schema must have a property with type array and format TypedArray"); + } + this.vectorPropertyName = vectorProp as keyof Entity; + this.metadataPropertyName = getMetadataProperty(schema) as keyof Entity | undefined; } getVectorDimensions(): number { @@ -70,26 +88,29 @@ export class PostgresChunkVectorStorage< async similaritySearch( query: TypedArray, options: VectorSearchOptions = {} - ): Promise & { score: number }>> { + ): Promise> { const { topK = 10, filter, scoreThreshold = 0 } = options; try { // Try native pgvector search first const queryVector = `[${Array.from(query).join(",")}]`; + const vectorCol = String(this.vectorPropertyName); + const metadataCol = this.metadataPropertyName ? 
String(this.metadataPropertyName) : null; + let sql = ` SELECT *, - 1 - (vector <=> $1::vector) as score + 1 - (${vectorCol} <=> $1::vector) as score FROM "${this.table}" `; const params: any[] = [queryVector]; let paramIndex = 2; - if (filter && Object.keys(filter).length > 0) { + if (filter && Object.keys(filter).length > 0 && metadataCol) { const conditions: string[] = []; for (const [key, value] of Object.entries(filter)) { - conditions.push(`metadata->>'${key}' = $${paramIndex}`); + conditions.push(`${metadataCol}->>'${key}' = $${paramIndex}`); params.push(String(value)); paramIndex++; } @@ -98,31 +119,31 @@ export class PostgresChunkVectorStorage< if (scoreThreshold > 0) { sql += filter ? " AND" : " WHERE"; - sql += ` (1 - (vector <=> $1::vector)) >= $${paramIndex}`; + sql += ` (1 - (${vectorCol} <=> $1::vector)) >= $${paramIndex}`; params.push(scoreThreshold); paramIndex++; } - sql += ` ORDER BY vector <=> $1::vector LIMIT $${paramIndex}`; + sql += ` ORDER BY ${vectorCol} <=> $1::vector LIMIT $${paramIndex}`; params.push(topK); const result = await this.db.query(sql, params); // Fetch vectors separately for each result - const results: Array & { score: number }> = []; + const results: Array = []; for (const row of result.rows) { const vectorResult = await this.db.query( - `SELECT vector::text FROM "${this.table}" WHERE id = $1`, - [row.id] + `SELECT ${vectorCol}::text FROM "${this.table}" WHERE ${this.getPrimaryKeyWhereClause(row)}`, + this.getPrimaryKeyValues(row) ); - const vectorStr = vectorResult.rows[0]?.vector || "[]"; + const vectorStr = vectorResult.rows[0]?.[vectorCol] || "[]"; const vectorArray = JSON.parse(vectorStr); results.push({ ...row, - vector: new this.VectorType(vectorArray), + [this.vectorPropertyName]: new this.VectorType(vectorArray), score: parseFloat(row.score), - } as any); + } as Entity & { score: number }); } return results; @@ -144,13 +165,15 @@ export class PostgresChunkVectorStorage< // Try native hybrid search with pgvector + full-text const queryVector = `[${Array.from(query).join(",")}]`; const tsQuery = textQuery.split(/\s+/).join(" & "); + const vectorCol = String(this.vectorPropertyName); + const metadataCol = this.metadataPropertyName ? String(this.metadataPropertyName) : null; let sql = ` SELECT *, ( - $2 * (1 - (vector <=> $1::vector)) + - $3 * ts_rank(to_tsvector('english', metadata::text), to_tsquery('english', $4)) + $2 * (1 - (${vectorCol} <=> $1::vector)) + + $3 * ts_rank(to_tsvector('english', ${metadataCol || "''"}::text), to_tsquery('english', $4)) ) as score FROM "${this.table}" `; @@ -158,10 +181,10 @@ export class PostgresChunkVectorStorage< const params: any[] = [queryVector, vectorWeight, 1 - vectorWeight, tsQuery]; let paramIndex = 5; - if (filter && Object.keys(filter).length > 0) { + if (filter && Object.keys(filter).length > 0 && metadataCol) { const conditions: string[] = []; for (const [key, value] of Object.entries(filter)) { - conditions.push(`metadata->>'${key}' = $${paramIndex}`); + conditions.push(`${metadataCol}->>'${key}' = $${paramIndex}`); params.push(String(value)); paramIndex++; } @@ -171,8 +194,8 @@ export class PostgresChunkVectorStorage< if (scoreThreshold > 0) { sql += filter ? 
" AND" : " WHERE"; sql += ` ( - $2 * (1 - (vector <=> $1::vector)) + - $3 * ts_rank(to_tsvector('english', metadata::text), to_tsquery('english', $4)) + $2 * (1 - (${vectorCol} <=> $1::vector)) + + $3 * ts_rank(to_tsvector('english', ${metadataCol || "''"}::text), to_tsquery('english', $4)) ) >= $${paramIndex}`; params.push(scoreThreshold); paramIndex++; @@ -184,20 +207,20 @@ export class PostgresChunkVectorStorage< const result = await this.db.query(sql, params); // Fetch vectors separately for each result - const results: Array & { score: number }> = []; + const results: Array = []; for (const row of result.rows) { const vectorResult = await this.db.query( - `SELECT vector::text FROM "${this.table}" WHERE id = $1`, - [row.id] + `SELECT ${vectorCol}::text FROM "${this.table}" WHERE ${this.getPrimaryKeyWhereClause(row)}`, + this.getPrimaryKeyValues(row) ); - const vectorStr = vectorResult.rows[0]?.vector || "[]"; + const vectorStr = vectorResult.rows[0]?.[vectorCol] || "[]"; const vectorArray = JSON.parse(vectorStr); results.push({ ...row, - vector: new this.VectorType(vectorArray), + [this.vectorPropertyName]: new this.VectorType(vectorArray), score: parseFloat(row.score), - } as any); + } as Entity & { score: number }); } return results; @@ -214,11 +237,13 @@ export class PostgresChunkVectorStorage< private async searchFallback(query: TypedArray, options: VectorSearchOptions) { const { topK = 10, filter, scoreThreshold = 0 } = options; const allRows = (await this.getAll()) || []; - const results: Array & { score: number }> = []; + const results: Array = []; for (const row of allRows) { - const vector = row.vector; - const metadata = row.metadata; + const vector = row[this.vectorPropertyName] as TypedArray; + const metadata = this.metadataPropertyName + ? (row[this.metadataPropertyName] as Metadata) + : ({} as Metadata); if (filter && !this.matchesFilter(metadata, filter)) { continue; @@ -227,7 +252,7 @@ export class PostgresChunkVectorStorage< const score = cosineSimilarity(query, vector); if (score >= scoreThreshold) { - results.push({ ...row, vector, score }); + results.push({ ...row, score } as Entity & { score: number }); } } @@ -244,13 +269,15 @@ export class PostgresChunkVectorStorage< const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; const allRows = (await this.getAll()) || []; - const results: Array & { score: number }> = []; + const results: Array = []; const queryLower = textQuery.toLowerCase(); const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); for (const row of allRows) { - const vector = row.vector; - const metadata = row.metadata; + const vector = row[this.vectorPropertyName] as TypedArray; + const metadata = this.metadataPropertyName + ? 
(row[this.metadataPropertyName] as Metadata) + : ({} as Metadata); if (filter && !this.matchesFilter(metadata, filter)) { continue; @@ -272,7 +299,7 @@ export class PostgresChunkVectorStorage< const combinedScore = vectorWeight * vectorScore + (1 - vectorWeight) * textScore; if (combinedScore >= scoreThreshold) { - results.push({ ...row, vector, score: combinedScore }); + results.push({ ...row, score: combinedScore } as Entity & { score: number }); } } @@ -282,6 +309,17 @@ export class PostgresChunkVectorStorage< return topResults; } + private getPrimaryKeyWhereClause(row: any): string { + const conditions = this.primaryKeyNames.map( + (key, idx) => `${String(key)} = $${idx + 1}` + ); + return conditions.join(" AND "); + } + + private getPrimaryKeyValues(row: any): any[] { + return this.primaryKeyNames.map((key) => row[key]); + } + private matchesFilter(metadata: Metadata, filter: Partial): boolean { for (const [key, value] of Object.entries(filter)) { if (metadata[key as keyof Metadata] !== value) { diff --git a/packages/storage/src/vector/README.md b/packages/storage/src/vector/README.md index f64c8ca0..c452c372 100644 --- a/packages/storage/src/vector/README.md +++ b/packages/storage/src/vector/README.md @@ -1,13 +1,13 @@ -# Chunk Vector Storage Module +# Vector Storage Module -Storage for document chunk embeddings with vector similarity search capabilities. Extends the tabular repository pattern to add vector search functionality for RAG (Retrieval-Augmented Generation) pipelines. +General-purpose vector storage with similarity search capabilities. Schema-driven approach that automatically detects vector and metadata columns. Extends the tabular storage pattern to add vector search functionality. ## Features - **Multiple Storage Backends:** - - 🧠 `InMemoryChunkVectorStorage` - Fast in-memory storage for testing and small datasets - - 📁 `SqliteChunkVectorStorage` - Persistent SQLite storage for local applications - - 🐘 `PostgresChunkVectorStorage` - PostgreSQL with pgvector extension for production + - 🧠 `InMemoryVectorStorage` - Fast in-memory storage for testing and small datasets + - 📁 `SqliteVectorStorage` - Persistent SQLite storage for local applications + - 🐘 `PostgresVectorStorage` - PostgreSQL with pgvector extension for production - **Quantized Vector Support:** - Float32Array (standard 32-bit floating point) @@ -37,24 +37,40 @@ bun install @workglow/storage ## Usage -### In-Memory Repository (Testing/Development) +### In-Memory Storage (Testing/Development) ```typescript -import { InMemoryChunkVectorStorage } from "@workglow/storage"; +import { InMemoryVectorStorage } from "@workglow/storage"; +import { TypedArraySchema } from "@workglow/util"; -// Create repository with 384 dimensions -const repo = new InMemoryChunkVectorStorage(384); +// Define your schema with a vector column +const MyVectorSchema = { + type: "object", + properties: { + id: { type: "string" }, + embedding: TypedArraySchema(), // Vector column (automatically detected) + metadata: { type: "object", format: "metadata", additionalProperties: true }, + }, + additionalProperties: false, +} as const; + +// Create repository with schema +const repo = new InMemoryVectorStorage( + MyVectorSchema, + ["id"], // Primary key + [], // Indexes (optional) + 384 // Vector dimensions +); await repo.setupDatabase(); -// Store a chunk with its embedding +// Store entities with embeddings await repo.put({ - chunk_id: "chunk-001", - doc_id: "doc-001", - vector: new Float32Array([0.1, 0.2, 0.3 /* ... 
384 dims */]), + id: "item-001", + embedding: new Float32Array([0.1, 0.2, 0.3 /* ... 384 dims */]), metadata: { text: "Hello world", source: "example.txt" }, }); -// Search for similar chunks +// Search for similar vectors const results = await repo.similaritySearch(new Float32Array([0.15, 0.25, 0.35 /* ... */]), { topK: 5, scoreThreshold: 0.7, @@ -64,40 +80,70 @@ const results = await repo.similaritySearch(new Float32Array([0.15, 0.25, 0.35 / ### Quantized Vectors (Reduced Storage) ```typescript -import { InMemoryChunkVectorStorage } from "@workglow/storage"; +import { InMemoryVectorStorage } from "@workglow/storage"; +import { TypedArraySchema } from "@workglow/util"; + +const QuantizedSchema = { + type: "object", + properties: { + id: { type: "string" }, + embedding: TypedArraySchema(), + tags: { type: "object", format: "metadata", additionalProperties: true }, + }, + additionalProperties: false, +} as const; // Use Int8Array for 4x smaller storage (binary quantization) -const repo = new InMemoryChunkVectorStorage<{ text: string }, Int8Array>(384, Int8Array); +const repo = new InMemoryVectorStorage( + QuantizedSchema, + ["id"], + [], + 384, + Int8Array // Specify vector type +); await repo.setupDatabase(); // Store quantized vectors await repo.put({ - chunk_id: "chunk-001", - doc_id: "doc-001", - vector: new Int8Array([127, -128, 64 /* ... */]), - metadata: { category: "ai" }, + id: "item-001", + embedding: new Int8Array([127, -128, 64 /* ... */]), + tags: { category: "ai" }, }); // Search with quantized query const results = await repo.similaritySearch(new Int8Array([100, -50, 75 /* ... */]), { topK: 5 }); ``` -### SQLite Repository (Local Persistence) +### SQLite Storage (Local Persistence) ```typescript -import { SqliteChunkVectorStorage } from "@workglow/storage"; +import { SqliteVectorStorage } from "@workglow/storage"; +import { TypedArraySchema } from "@workglow/util"; + +const MySchema = { + type: "object", + properties: { + id: { type: "string" }, + vector: TypedArraySchema(), + data: { type: "object", format: "metadata", additionalProperties: true }, + }, + additionalProperties: false, +} as const; -const repo = new SqliteChunkVectorStorage<{ text: string }>( - "./vectors.db", // database path - "chunks", // table name - 768 // vector dimension +const repo = new SqliteVectorStorage( + "./vectors.db", // database path + "vectors", // table name + MySchema, + ["id"], + [], + 768 // vector dimension ); await repo.setupDatabase(); // Bulk insert using inherited tabular methods await repo.putMany([ - { chunk_id: "1", doc_id: "doc1", vector: new Float32Array([...]), metadata: { text: "..." } }, - { chunk_id: "2", doc_id: "doc1", vector: new Float32Array([...]), metadata: { text: "..." } }, + { id: "1", vector: new Float32Array([...]), data: { text: "..." } }, + { id: "2", vector: new Float32Array([...]), data: { text: "..." } }, ]); ``` @@ -105,12 +151,26 @@ await repo.putMany([ ```typescript import { Pool } from "pg"; -import { PostgresChunkVectorStorage } from "@workglow/storage"; +import { PostgresVectorStorage } from "@workglow/storage"; +import { TypedArraySchema } from "@workglow/util"; + +const MySchema = { + type: "object", + properties: { + id: { type: "string" }, + vector: TypedArraySchema(), + info: { type: "object", format: "metadata", additionalProperties: true }, + }, + additionalProperties: false, +} as const; const pool = new Pool({ connectionString: "postgresql://..." 
}); -const repo = new PostgresChunkVectorStorage<{ text: string; category: string }>( +const repo = new PostgresVectorStorage( pool, - "chunks", + "vectors", + MySchema, + ["id"], + [], 384 // vector dimension ); await repo.setupDatabase(); @@ -131,41 +191,36 @@ const hybridResults = await repo.hybridSearch(queryVector, { }); ``` -## Data Model +## Schema-Driven Design -### ChunkVector Schema - -Each chunk vector entry contains: +The vector storage automatically detects which column contains the vector by looking for properties with `format: "TypedArray"` in your schema: ```typescript -interface ChunkVector< - Metadata extends Record = Record, - Vector extends TypedArray = Float32Array, -> { - chunk_id: string; // Unique identifier for the chunk - doc_id: string; // Parent document identifier - vector: Vector; // Embedding vector - metadata: Metadata; // Custom metadata (text content, entities, etc.) -} -``` - -### Default Schema +import { TypedArraySchema } from "@workglow/util"; -```typescript -const ChunkVectorSchema = { +// Vector column is automatically detected by the storage implementation +const MySchema = { type: "object", properties: { - chunk_id: { type: "string" }, - doc_id: { type: "string" }, - vector: TypedArraySchema(), - metadata: { type: "object", additionalProperties: true }, + id: { type: "string" }, + embedding: TypedArraySchema(), // ← Detected as vector column + metadata: { + type: "object", + format: "metadata", // ← Detected as metadata column (optional) + additionalProperties: true, + }, + created_at: { type: "string" }, }, additionalProperties: false, } as const; - -const ChunkVectorKey = ["chunk_id"] as const; ``` +**Key Points:** + +- **Vector Column**: Any property with `type: "array"` and `format: "TypedArray"` (or `format: "TypedArray:*"`) +- **Metadata Column**: Any property with `type: "object"` and `format: "metadata"` (optional, used for filtering) +- **Flexible Schema**: Add any additional properties you need - the storage will work with your schema + ## API Reference ### IChunkVectorStorage Interface @@ -237,11 +292,8 @@ interface HybridSearchOptions extends VectorSearchOptions { Register and retrieve chunk vector repositories globally: ```typescript -import { - registerChunkVectorRepository, - getChunkVectorRepository, - getGlobalChunkVectorRepositories, -} from "@workglow/storage"; +import { getChunkVectorRepository, getGlobalChunkVectorRepositories } from "@workglow/storage"; +import { registerChunkVectorRepository, getGlobalChunkVectorRepositories } from "@workglow/dataset"; // Register a repository registerChunkVectorRepository("my-chunks", repo); diff --git a/packages/storage/src/vector/SqliteChunkVectorStorage.ts b/packages/storage/src/vector/SqliteVectorStorage.ts similarity index 65% rename from packages/storage/src/vector/SqliteChunkVectorStorage.ts rename to packages/storage/src/vector/SqliteVectorStorage.ts index a23a4fb2..384a73dc 100644 --- a/packages/storage/src/vector/SqliteChunkVectorStorage.ts +++ b/packages/storage/src/vector/SqliteVectorStorage.ts @@ -5,15 +5,21 @@ */ import { Sqlite } from "@workglow/sqlite"; -import type { TypedArray } from "@workglow/util"; +import type { + DataPortSchemaObject, + FromSchema, + TypedArray, + TypedArraySchemaOptions, +} from "@workglow/util"; import { cosineSimilarity } from "@workglow/util"; import { SqliteTabularStorage } from "../tabular/SqliteTabularStorage"; -import { ChunkVector, ChunkVectorKey, ChunkVectorSchema } from "./ChunkVectorSchema"; -import type { - HybridSearchOptions, - 
IChunkVectorStorage, - VectorSearchOptions, -} from "./IChunkVectorStorage"; +import { + getMetadataProperty, + getVectorProperty, + type HybridSearchOptions, + type IVectorStorage, + type VectorSearchOptions, +} from "./IVectorStorage"; /** * Check if metadata matches filter @@ -28,33 +34,31 @@ function matchesFilter(metadata: Metadata, filter: Partial): } /** - * SQLite document chunk vector repository implementation using tabular storage underneath. + * SQLite vector repository implementation using tabular storage underneath. * Stores vectors as JSON-encoded arrays with metadata. * - * @template Metadata - The metadata type for the document chunk - * @template Vector - The vector type for the document chunk + * @template Vector - The vector type for the vector + * @template Metadata - The metadata type for the vector + * @template Schema - The schema for the vector + * @template PrimaryKeyNames - The primary key names for the vector */ -export class SqliteChunkVectorStorage< - Metadata extends Record = Record, +export class SqliteVectorStorage< + Schema extends DataPortSchemaObject, + PrimaryKeyNames extends ReadonlyArray, Vector extends TypedArray = Float32Array, + Metadata extends Record | undefined = Record, + Entity = FromSchema, > - extends SqliteTabularStorage< - typeof ChunkVectorSchema, - typeof ChunkVectorKey, - ChunkVector - > - implements - IChunkVectorStorage< - typeof ChunkVectorSchema, - typeof ChunkVectorKey, - ChunkVector - > + extends SqliteTabularStorage + implements IVectorStorage { private vectorDimensions: number; private VectorType: new (array: number[]) => TypedArray; + private vectorPropertyName: keyof Entity; + private metadataPropertyName: keyof Entity | undefined; /** - * Creates a new SQLite document chunk vector repository + * Creates a new SQLite vector repository * @param dbOrPath - Either a Database instance or a path to the SQLite database file * @param table - The name of the table to use for storage (defaults to 'vectors') * @param dimensions - The number of dimensions of the vector @@ -63,13 +67,24 @@ export class SqliteChunkVectorStorage< constructor( dbOrPath: string | Sqlite.Database, table: string = "vectors", + schema: Schema, + primaryKeyNames: PrimaryKeyNames, + indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [], dimensions: number, VectorType: new (array: number[]) => TypedArray = Float32Array ) { - super(dbOrPath, table, ChunkVectorSchema, ChunkVectorKey); + super(dbOrPath, table, schema, primaryKeyNames, indexes); this.vectorDimensions = dimensions; this.VectorType = VectorType; + + // Cache vector and metadata property names from schema + const vectorProp = getVectorProperty(schema); + if (!vectorProp) { + throw new Error("Schema must have a property with type array and format TypedArray"); + } + this.vectorPropertyName = vectorProp as keyof Entity; + this.metadataPropertyName = getMetadataProperty(schema) as keyof Entity | undefined; } getVectorDimensions(): number { @@ -88,15 +103,17 @@ export class SqliteChunkVectorStorage< async similaritySearch(query: TypedArray, options: VectorSearchOptions = {}) { const { topK = 10, filter, scoreThreshold = 0 } = options; - const results: Array & { score: number }> = []; + const results: Array = []; const allEntities = (await this.getAll()) || []; for (const entity of allEntities) { // SQLite stores vectors as JSON strings, need to deserialize - const vectorRaw = entity.vector as unknown as string; + const vectorRaw = entity[this.vectorPropertyName] as unknown as string; const vector 
= this.deserializeVector(vectorRaw); - const metadata = entity.metadata; + const metadata = this.metadataPropertyName + ? (entity[this.metadataPropertyName] as Metadata) + : ({} as Metadata); // Apply filter if provided if (filter && !matchesFilter(metadata, filter)) { @@ -113,9 +130,8 @@ export class SqliteChunkVectorStorage< results.push({ ...entity, - vector, score, - } as any); + } as Entity & { score: number }); } // Sort by score descending and take top K @@ -133,19 +149,18 @@ export class SqliteChunkVectorStorage< return this.similaritySearch(query, { topK, filter, scoreThreshold }); } - const results: Array & { score: number }> = []; + const results: Array = []; const allEntities = (await this.getAll()) || []; const queryLower = textQuery.toLowerCase(); const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); for (const entity of allEntities) { // SQLite stores vectors as JSON strings, need to deserialize - const vectorRaw = entity.vector as unknown as string; - const vector = - typeof vectorRaw === "string" - ? this.deserializeVector(vectorRaw) - : (vectorRaw as TypedArray); - const metadata = entity.metadata; + const vectorRaw = entity[this.vectorPropertyName] as unknown as string; + const vector = this.deserializeVector(vectorRaw); + const metadata = this.metadataPropertyName + ? (entity[this.metadataPropertyName] as Metadata) + : ({} as Metadata); // Apply filter if provided if (filter && !matchesFilter(metadata, filter)) { @@ -178,9 +193,8 @@ export class SqliteChunkVectorStorage< results.push({ ...entity, - vector, score: combinedScore, - } as any); + } as Entity & { score: number }); } // Sort by combined score descending and take top K diff --git a/packages/task-graph/package.json b/packages/task-graph/package.json index 61e0f400..86955350 100644 --- a/packages/task-graph/package.json +++ b/packages/task-graph/package.json @@ -23,8 +23,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./dist/bun.js", - "types": "./dist/types.d.ts", + "bun": "./src/bun.ts", + "types": "./src/types.ts", "import": "./dist/node.js" } }, diff --git a/packages/task-graph/src/task/InputResolver.ts b/packages/task-graph/src/task/InputResolver.ts index 282a519a..87439126 100644 --- a/packages/task-graph/src/task/InputResolver.ts +++ b/packages/task-graph/src/task/InputResolver.ts @@ -42,7 +42,7 @@ function getSchemaFormat(schema: unknown): string | undefined { /** * Gets the format prefix from a format string. 
* For "model:TextEmbedding" returns "model" - * For "repository:tabular" returns "repository" + * For "storage:tabular" returns "storage" */ function getFormatPrefix(format: string): string { const colonIndex = format.indexOf(":"); @@ -88,7 +88,7 @@ export async function resolveSchemaInputs>( const format = getSchemaFormat(propSchema); if (!format) continue; - // Try full format first (e.g., "repository:document-node-vector"), then fall back to prefix (e.g., "repository") + // Try full format first (e.g., "dataset:document-chunk"), then fall back to prefix (e.g., "dataset") let resolver = resolvers.get(format); if (!resolver) { const prefix = getFormatPrefix(format); diff --git a/packages/task-graph/src/task/README.md b/packages/task-graph/src/task/README.md index 13a88491..e83ff8c5 100644 --- a/packages/task-graph/src/task/README.md +++ b/packages/task-graph/src/task/README.md @@ -239,7 +239,7 @@ The TaskRunner automatically resolves schema-annotated string inputs to their co ### How It Works -When a task's input schema includes properties with `format` annotations (such as `"model"`, `"model:TaskName"`, or `"repository:tabular"`), the TaskRunner inspects each input property: +When a task's input schema includes properties with `format` annotations (such as `"model"`, `"model:TaskName"`, or `"storage:tabular"`), the TaskRunner inspects each input property: - **String values** are looked up in the appropriate registry and resolved to instances - **Object values** (already instances) pass through unchanged diff --git a/packages/tasks/package.json b/packages/tasks/package.json index 335e494a..5062a565 100644 --- a/packages/tasks/package.json +++ b/packages/tasks/package.json @@ -23,8 +23,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./dist/bun.js", - "types": "./dist/types.d.ts", + "bun": "./src/bun.ts", + "types": "./src/types.ts", "import": "./dist/node.js" } }, diff --git a/packages/test/package.json b/packages/test/package.json index 8c4077e3..985c87e1 100644 --- a/packages/test/package.json +++ b/packages/test/package.json @@ -23,8 +23,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./dist/bun.js", - "types": "./dist/types.d.ts", + "bun": "./src/bun.ts", + "types": "./src/types.ts", "import": "./dist/node.js" } }, diff --git a/packages/test/src/test/rag/ChunkToVector.test.ts b/packages/test/src/test/rag/ChunkToVector.test.ts index eea38722..e4f7a7d5 100644 --- a/packages/test/src/test/rag/ChunkToVector.test.ts +++ b/packages/test/src/test/rag/ChunkToVector.test.ts @@ -6,14 +6,15 @@ import "@workglow/ai"; // Trigger Workflow prototype extensions import type { ChunkToVectorTaskOutput, HierarchicalChunkerTaskOutput } from "@workglow/ai"; -import { type ChunkNode, NodeIdGenerator, StructuralParser } from "@workglow/dataset"; +import { type ChunkNode, StructuralParser } from "@workglow/dataset"; import { Workflow } from "@workglow/task-graph"; +import { uuid4 } from "@workglow/util"; import { describe, expect, it } from "vitest"; describe("ChunkToVectorTask", () => { it("should transform chunks and vectors to vector store format", async () => { const markdown = "# Test\n\nContent."; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); // Generate chunks using workflow diff --git a/packages/test/src/test/rag/DocumentNodeRetrievalTask.test.ts 
b/packages/test/src/test/rag/DocumentChunkRetrievalTask.test.ts similarity index 86% rename from packages/test/src/test/rag/DocumentNodeRetrievalTask.test.ts rename to packages/test/src/test/rag/DocumentChunkRetrievalTask.test.ts index e4fb9ea2..453331ef 100644 --- a/packages/test/src/test/rag/DocumentNodeRetrievalTask.test.ts +++ b/packages/test/src/test/rag/DocumentChunkRetrievalTask.test.ts @@ -5,15 +5,36 @@ */ import { retrieval } from "@workglow/ai"; -import { InMemoryChunkVectorStorage, registerChunkVectorRepository } from "@workglow/dataset"; +import { + DocumentChunk, + DocumentChunkDataset, + DocumentChunkPrimaryKey, + DocumentChunkSchema, + registerDocumentChunkDataset, +} from "@workglow/dataset"; +import { InMemoryVectorStorage } from "@workglow/storage"; import { afterEach, beforeEach, describe, expect, test } from "vitest"; describe("DocumentNodeRetrievalTask", () => { - let repo: InMemoryChunkVectorStorage; + let storage: InMemoryVectorStorage< + typeof DocumentChunkSchema, + typeof DocumentChunkPrimaryKey, + Record, + Float32Array, + DocumentChunk + >; + let dataset: DocumentChunkDataset; beforeEach(async () => { - repo = new InMemoryChunkVectorStorage(3); - await repo.setupDatabase(); + storage = new InMemoryVectorStorage< + typeof DocumentChunkSchema, + typeof DocumentChunkPrimaryKey, + Record, + Float32Array, + DocumentChunk + >(DocumentChunkSchema, DocumentChunkPrimaryKey, [], 3, Float32Array); + await storage.setupDatabase(); + dataset = new DocumentChunkDataset(storage); // Populate repository with test data const vectors = [ @@ -34,7 +55,7 @@ describe("DocumentNodeRetrievalTask", () => { for (let i = 0; i < vectors.length; i++) { const doc_id = `doc${i + 1}`; - await repo.put({ + await dataset.put({ chunk_id: `${doc_id}_0`, doc_id, vector: vectors[i], @@ -44,14 +65,14 @@ describe("DocumentNodeRetrievalTask", () => { }); afterEach(() => { - repo.destroy(); + storage.destroy(); }); test("should retrieve chunks with query vector", async () => { const queryVector = new Float32Array([1.0, 0.0, 0.0]); const result = await retrieval({ - repository: repo, + dataset, query: queryVector, topK: 3, }); @@ -71,7 +92,7 @@ describe("DocumentNodeRetrievalTask", () => { const queryVector = new Float32Array([1.0, 0.0, 0.0]); const result = await retrieval({ - repository: repo, + dataset, query: queryVector, topK: 5, }); @@ -93,7 +114,7 @@ describe("DocumentNodeRetrievalTask", () => { const queryVector = new Float32Array([0.0, 1.0, 0.0]); const result = await retrieval({ - repository: repo, + dataset, query: queryVector, topK: 5, }); @@ -109,7 +130,7 @@ describe("DocumentNodeRetrievalTask", () => { const queryVector = new Float32Array([0.0, 0.0, 1.0]); const result = await retrieval({ - repository: repo, + dataset, query: queryVector, topK: 5, }); @@ -125,7 +146,7 @@ describe("DocumentNodeRetrievalTask", () => { const queryVector = new Float32Array([1.0, 0.0, 0.0]); const result = await retrieval({ - repository: repo, + dataset, query: queryVector, topK: 3, returnVectors: true, @@ -140,7 +161,7 @@ describe("DocumentNodeRetrievalTask", () => { const queryVector = new Float32Array([1.0, 0.0, 0.0]); const result = await retrieval({ - repository: repo, + dataset, query: queryVector, topK: 3, returnVectors: false, @@ -153,7 +174,7 @@ describe("DocumentNodeRetrievalTask", () => { const queryVector = new Float32Array([1.0, 0.0, 0.0]); const result = await retrieval({ - repository: repo, + dataset, query: queryVector, topK: 2, }); @@ -164,7 +185,7 @@ describe("DocumentNodeRetrievalTask", () => 
{ test("should apply metadata filter", async () => { // Add a document with specific metadata for filtering - await repo.put({ + await dataset.put({ chunk_id: "filtered_doc_0", doc_id: "filtered_doc", vector: new Float32Array([1.0, 0.0, 0.0]), @@ -177,7 +198,7 @@ describe("DocumentNodeRetrievalTask", () => { const queryVector = new Float32Array([1.0, 0.0, 0.0]); const result = await retrieval({ - repository: repo, + dataset, query: queryVector, topK: 10, filter: { category: "test" }, @@ -191,7 +212,7 @@ describe("DocumentNodeRetrievalTask", () => { const queryVector = new Float32Array([1.0, 0.0, 0.0]); const result = await retrieval({ - repository: repo, + dataset, query: queryVector, topK: 10, scoreThreshold: 0.9, @@ -206,7 +227,7 @@ describe("DocumentNodeRetrievalTask", () => { const queryEmbedding = new Float32Array([1.0, 0.0, 0.0]); const result = await retrieval({ - repository: repo, + dataset, query: queryEmbedding, topK: 3, }); @@ -219,7 +240,7 @@ describe("DocumentNodeRetrievalTask", () => { await expect( // @ts-expect-error - query is string but no model is provided retrieval({ - repository: repo, + dataset, query: "test query string", topK: 3, }) @@ -230,7 +251,7 @@ describe("DocumentNodeRetrievalTask", () => { const queryVector = new Float32Array([1.0, 0.0, 0.0]); const result = await retrieval({ - repository: repo, + dataset, query: queryVector, }); @@ -241,7 +262,7 @@ describe("DocumentNodeRetrievalTask", () => { test("should JSON.stringify metadata when no text/content/chunk fields", async () => { // Add document with only non-standard metadata - await repo.put({ + await dataset.put({ chunk_id: "json_doc_0", doc_id: "json_doc", vector: new Float32Array([1.0, 0.0, 0.0]), @@ -254,7 +275,7 @@ describe("DocumentNodeRetrievalTask", () => { const queryVector = new Float32Array([1.0, 0.0, 0.0]); const result = await retrieval({ - repository: repo, + dataset, query: queryVector, topK: 10, }); @@ -268,13 +289,13 @@ describe("DocumentNodeRetrievalTask", () => { test("should resolve repository from string ID", async () => { // Register repository by ID - registerChunkVectorRepository("test-retrieval-repo", repo); + registerDocumentChunkDataset("test-retrieval-repo", dataset); const queryVector = new Float32Array([1.0, 0.0, 0.0]); // Pass repository as string ID instead of instance const result = await retrieval({ - repository: "test-retrieval-repo" as any, + dataset: "test-retrieval-repo", query: queryVector, topK: 3, }); diff --git a/packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts b/packages/test/src/test/rag/DocumentChunkSearchTask.test.ts similarity index 81% rename from packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts rename to packages/test/src/test/rag/DocumentChunkSearchTask.test.ts index 1b61392d..180fe118 100644 --- a/packages/test/src/test/rag/DocumentNodeVectorSearchTask.test.ts +++ b/packages/test/src/test/rag/DocumentChunkSearchTask.test.ts @@ -5,15 +5,36 @@ */ import { ChunkVectorSearchTask } from "@workglow/ai"; -import { InMemoryChunkVectorStorage, registerChunkVectorRepository } from "@workglow/dataset"; +import { + DocumentChunk, + DocumentChunkDataset, + DocumentChunkPrimaryKey, + DocumentChunkSchema, + registerDocumentChunkDataset, +} from "@workglow/dataset"; +import { InMemoryVectorStorage } from "@workglow/storage"; import { afterEach, beforeEach, describe, expect, test } from "vitest"; describe("ChunkVectorSearchTask", () => { - let repo: InMemoryChunkVectorStorage; + let storage: InMemoryVectorStorage< + typeof DocumentChunkSchema, 
+ typeof DocumentChunkPrimaryKey, + Record, + Float32Array, + DocumentChunk + >; + let dataset: DocumentChunkDataset; beforeEach(async () => { - repo = new InMemoryChunkVectorStorage(3); - await repo.setupDatabase(); + storage = new InMemoryVectorStorage< + typeof DocumentChunkSchema, + typeof DocumentChunkPrimaryKey, + Record, + Float32Array, + DocumentChunk + >(DocumentChunkSchema, DocumentChunkPrimaryKey, [], 3, Float32Array); + await storage.setupDatabase(); + dataset = new DocumentChunkDataset(storage); // Populate repository with test data const vectors = [ @@ -34,7 +55,7 @@ describe("ChunkVectorSearchTask", () => { for (let i = 0; i < vectors.length; i++) { const doc_id = `doc${i + 1}`; - await repo.put({ + await dataset.put({ chunk_id: `${doc_id}_0`, doc_id, vector: vectors[i], @@ -44,7 +65,7 @@ describe("ChunkVectorSearchTask", () => { }); afterEach(() => { - repo.destroy(); + storage.destroy(); }); test("should search and return top K results", async () => { @@ -52,7 +73,7 @@ describe("ChunkVectorSearchTask", () => { const task = new ChunkVectorSearchTask(); const result = await task.run({ - repository: repo, + dataset: dataset, query: queryVector, topK: 3, }); @@ -77,7 +98,7 @@ describe("ChunkVectorSearchTask", () => { const task = new ChunkVectorSearchTask(); const result = await task.run({ - repository: repo, + dataset, query: queryVector, topK: 2, }); @@ -91,7 +112,7 @@ describe("ChunkVectorSearchTask", () => { const task = new ChunkVectorSearchTask(); const result = await task.run({ - repository: repo, + dataset, query: queryVector, topK: 10, filter: { category: "tech" }, @@ -109,7 +130,7 @@ describe("ChunkVectorSearchTask", () => { const task = new ChunkVectorSearchTask(); const result = await task.run({ - repository: repo, + dataset, query: queryVector, topK: 10, scoreThreshold: 0.9, @@ -126,7 +147,7 @@ describe("ChunkVectorSearchTask", () => { const task = new ChunkVectorSearchTask(); const result = await task.run({ - repository: repo, + dataset, query: queryVector, topK: 10, filter: { category: "nonexistent" }, @@ -144,7 +165,7 @@ describe("ChunkVectorSearchTask", () => { const task = new ChunkVectorSearchTask(); const result = await task.run({ - repository: repo, + dataset, query: queryVector, }); @@ -158,7 +179,7 @@ describe("ChunkVectorSearchTask", () => { const task = new ChunkVectorSearchTask(); const result = await task.run({ - repository: repo, + dataset, query: queryVector, topK: 3, }); @@ -173,7 +194,7 @@ describe("ChunkVectorSearchTask", () => { const task = new ChunkVectorSearchTask(); const result = await task.run({ - repository: repo, + dataset, query: queryVector, topK: 5, }); @@ -185,14 +206,21 @@ describe("ChunkVectorSearchTask", () => { }); test("should handle empty repository", async () => { - const emptyRepo = new InMemoryChunkVectorStorage(3); - await emptyRepo.setupDatabase(); + const emptyStorage = new InMemoryVectorStorage< + typeof DocumentChunkSchema, + typeof DocumentChunkPrimaryKey, + Record, + Float32Array, + DocumentChunk + >(DocumentChunkSchema, DocumentChunkPrimaryKey, [], 3, Float32Array); + await emptyStorage.setupDatabase(); + const emptyDataset = new DocumentChunkDataset(emptyStorage); const queryVector = new Float32Array([1.0, 0.0, 0.0]); const task = new ChunkVectorSearchTask(); const result = await task.run({ - repository: emptyRepo, + dataset: emptyDataset, query: queryVector, topK: 10, }); @@ -201,7 +229,7 @@ describe("ChunkVectorSearchTask", () => { expect(result.ids).toHaveLength(0); expect(result.scores).toHaveLength(0); - 
emptyRepo.destroy(); + emptyStorage.destroy(); }); test("should combine filter and score threshold", async () => { @@ -209,7 +237,7 @@ describe("ChunkVectorSearchTask", () => { const task = new ChunkVectorSearchTask(); const result = await task.run({ - repository: repo, + dataset, query: queryVector, topK: 10, filter: { category: "tech" }, @@ -226,15 +254,15 @@ describe("ChunkVectorSearchTask", () => { }); test("should resolve repository from string ID", async () => { - // Register repository by ID - registerChunkVectorRepository("test-vector-repo", repo); + // Register dataset by ID + registerDocumentChunkDataset("test-vector-repo", dataset); const queryVector = new Float32Array([1.0, 0.0, 0.0]); const task = new ChunkVectorSearchTask(); // Pass repository as string ID instead of instance const result = await task.run({ - repository: "test-vector-repo" as any, + dataset: "test-vector-repo", query: queryVector, topK: 3, }); diff --git a/packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts b/packages/test/src/test/rag/DocumentChunkUpsertTask.test.ts similarity index 81% rename from packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts rename to packages/test/src/test/rag/DocumentChunkUpsertTask.test.ts index 50bf53e2..af96859c 100644 --- a/packages/test/src/test/rag/DocumentNodeVectorStoreUpsertTask.test.ts +++ b/packages/test/src/test/rag/DocumentChunkUpsertTask.test.ts @@ -5,19 +5,27 @@ */ import { ChunkVectorUpsertTask } from "@workglow/ai"; -import { InMemoryChunkVectorStorage, registerChunkVectorRepository } from "@workglow/dataset"; +import { + DocumentChunkDataset, + DocumentChunkPrimaryKey, + DocumentChunkSchema, + registerDocumentChunkDataset, +} from "@workglow/dataset"; +import { InMemoryVectorStorage } from "@workglow/storage"; import { afterEach, beforeEach, describe, expect, test } from "vitest"; describe("ChunkVectorUpsertTask", () => { - let repo: InMemoryChunkVectorStorage; + let storage: InMemoryVectorStorage; + let dataset: DocumentChunkDataset; beforeEach(async () => { - repo = new InMemoryChunkVectorStorage(3); - await repo.setupDatabase(); + storage = new InMemoryVectorStorage(DocumentChunkSchema, DocumentChunkPrimaryKey, [], 3); + await storage.setupDatabase(); + dataset = new DocumentChunkDataset(storage as any); }); afterEach(() => { - repo.destroy(); + storage.destroy(); }); test("should upsert a single vector", async () => { @@ -26,7 +34,7 @@ describe("ChunkVectorUpsertTask", () => { const task = new ChunkVectorUpsertTask(); const result = await task.run({ - repository: repo, + dataset, doc_id: "doc1", vectors: vector, metadata: metadata, @@ -37,7 +45,7 @@ describe("ChunkVectorUpsertTask", () => { expect(result.chunk_ids).toHaveLength(1); // Verify vector was stored - const retrieved = await repo.get({ chunk_id: result.chunk_ids[0] }); + const retrieved = await dataset.get(result.chunk_ids[0]); expect(retrieved).toBeDefined(); expect(retrieved?.doc_id).toBe("doc1"); expect(retrieved!.metadata).toEqual(metadata); @@ -53,7 +61,7 @@ describe("ChunkVectorUpsertTask", () => { const task = new ChunkVectorUpsertTask(); const result = await task.run({ - repository: repo, + dataset, doc_id: "doc1", vectors: vectors, metadata: metadata, @@ -65,7 +73,7 @@ describe("ChunkVectorUpsertTask", () => { // Verify all vectors were stored for (let i = 0; i < 3; i++) { - const retrieved = await repo.get({ chunk_id: result.chunk_ids[i] }); + const retrieved = await dataset.get(result.chunk_ids[i]); expect(retrieved).toBeDefined(); 
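
Several of the tests in this file pass the dataset as a registered string ID rather than an instance (`dataset: "test-vector-repo"`). The registration side, `registerDocumentChunkDataset`, appears in the imports, but the lookup half is not shown in this patch, so the following is only a rough sketch of the pattern; `getDocumentChunkDataset` and `resolveDataset` are hypothetical names used for illustration:

```typescript
// Hypothetical illustration of a string-ID dataset registry; the actual API in
// @workglow/dataset may use different names, typing, and storage.
type DatasetLike = object; // stand-in for DocumentChunkDataset

const registry = new Map<string, DatasetLike>();

function registerDocumentChunkDataset(id: string, dataset: DatasetLike): void {
  registry.set(id, dataset);
}

// Hypothetical lookup used when a task input arrives as a string ID.
function getDocumentChunkDataset(id: string): DatasetLike | undefined {
  return registry.get(id);
}

// Roughly how a task could resolve `dataset: "test-vector-repo"` to an instance.
function resolveDataset(input: string | DatasetLike): DatasetLike {
  if (typeof input === "string") {
    const found = getDocumentChunkDataset(input);
    if (!found) throw new Error(`No dataset registered under id "${input}"`);
    return found;
  }
  return input;
}
```

This mirrors the format-prefix resolution described in InputResolver later in the patch, where a schema format such as `dataset:document-chunk` is tried in full before falling back to the `dataset` prefix.
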
expect(retrieved?.doc_id).toBe("doc1"); expect(retrieved!.metadata).toEqual(metadata); @@ -78,7 +86,7 @@ describe("ChunkVectorUpsertTask", () => { const task = new ChunkVectorUpsertTask(); const result = await task.run({ - repository: repo, + dataset, doc_id: "doc1", vectors: vector, metadata: metadata, @@ -87,7 +95,7 @@ describe("ChunkVectorUpsertTask", () => { expect(result.count).toBe(1); expect(result.doc_id).toBe("doc1"); - const retrieved = await repo.get({ chunk_id: result.chunk_ids[0] }); + const retrieved = await dataset.get(result.chunk_ids[0]); expect(retrieved).toBeDefined(); expect(retrieved!.metadata).toEqual(metadata); }); @@ -101,7 +109,7 @@ describe("ChunkVectorUpsertTask", () => { // First upsert const task1 = new ChunkVectorUpsertTask(); const result1 = await task1.run({ - repository: repo, + dataset, doc_id: "doc1", vectors: vector1, metadata: metadata1, @@ -110,13 +118,13 @@ describe("ChunkVectorUpsertTask", () => { // Update with same ID const task2 = new ChunkVectorUpsertTask(); const result2 = await task2.run({ - repository: repo, + dataset, doc_id: "doc1", vectors: vector2, metadata: metadata2, }); - const retrieved = await repo.get({ chunk_id: result2.chunk_ids[0] }); + const retrieved = await dataset.get(result2.chunk_ids[0]); expect(retrieved).toBeDefined(); expect(retrieved!.metadata).toEqual(metadata2); }); @@ -127,7 +135,7 @@ describe("ChunkVectorUpsertTask", () => { const task = new ChunkVectorUpsertTask(); const result = await task.run({ - repository: repo, + dataset, doc_id: "doc1", vectors: vectors, metadata: metadata, @@ -143,7 +151,7 @@ describe("ChunkVectorUpsertTask", () => { const task = new ChunkVectorUpsertTask(); const result = await task.run({ - repository: repo, + dataset, doc_id: "doc1", vectors: vector, metadata: metadata, @@ -151,7 +159,7 @@ describe("ChunkVectorUpsertTask", () => { expect(result.count).toBe(1); - const retrieved = await repo.get({ chunk_id: result.chunk_ids[0] }); + const retrieved = await dataset.get(result.chunk_ids[0]); expect(retrieved).toBeDefined(); expect(retrieved?.vector).toBeInstanceOf(Int8Array); }); @@ -162,7 +170,7 @@ describe("ChunkVectorUpsertTask", () => { const task = new ChunkVectorUpsertTask(); const result = await task.run({ - repository: repo, + dataset, doc_id: "doc1", vectors: vector, metadata: metadata, @@ -170,7 +178,7 @@ describe("ChunkVectorUpsertTask", () => { expect(result.count).toBe(1); - const retrieved = await repo.get({ chunk_id: result.chunk_ids[0] }); + const retrieved = await dataset.get(result.chunk_ids[0]); expect(retrieved!.metadata).toEqual(metadata); }); @@ -184,7 +192,7 @@ describe("ChunkVectorUpsertTask", () => { const task = new ChunkVectorUpsertTask(); const result = await task.run({ - repository: repo, + dataset, doc_id: "batch-doc", vectors: vectors, metadata: metadata, @@ -193,13 +201,13 @@ describe("ChunkVectorUpsertTask", () => { expect(result.count).toBe(count); expect(result.chunk_ids).toHaveLength(count); - const size = await repo.size(); + const size = await dataset.size(); expect(size).toBe(count); }); test("should resolve repository from string ID", async () => { - // Register repository by ID - registerChunkVectorRepository("test-upsert-repo", repo); + // Register dataset by ID + registerDocumentChunkDataset("test-upsert-repo", dataset); const vector = new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5]); const metadata = { text: "Test document", source: "test.txt" }; @@ -207,7 +215,7 @@ describe("ChunkVectorUpsertTask", () => { const task = new ChunkVectorUpsertTask(); // 
Pass repository as string ID instead of instance const result = await task.run({ - repository: "test-upsert-repo" as any, + dataset: "test-upsert-repo", doc_id: "doc1", vectors: vector, metadata: metadata, @@ -217,7 +225,7 @@ describe("ChunkVectorUpsertTask", () => { expect(result.doc_id).toBe("doc1"); // Verify vector was stored - const retrieved = await repo.get({ chunk_id: result.chunk_ids[0] }); + const retrieved = await dataset.get(result.chunk_ids[0]); expect(retrieved).toBeDefined(); expect(retrieved?.doc_id).toBe("doc1"); expect(retrieved!.metadata).toEqual(metadata); diff --git a/packages/test/src/test/rag/DocumentRepository.test.ts b/packages/test/src/test/rag/DocumentRepository.test.ts index 15bb6530..aa61bfca 100644 --- a/packages/test/src/test/rag/DocumentRepository.test.ts +++ b/packages/test/src/test/rag/DocumentRepository.test.ts @@ -4,22 +4,25 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { InMemoryTabularStorage } from "@workglow/storage"; import { Document, - DocumentRepository, + DocumentChunk, + DocumentChunkDataset, + DocumentChunkPrimaryKey, + DocumentChunkSchema, + DocumentDataset, DocumentStorageKey, DocumentStorageSchema, - InMemoryChunkVectorStorage, - NodeIdGenerator, NodeKind, StructuralParser, } from "@workglow/dataset"; +import { InMemoryTabularStorage, InMemoryVectorStorage } from "@workglow/storage"; +import { uuid4 } from "@workglow/util"; import { beforeEach, describe, expect, it } from "vitest"; -describe("DocumentRepository", () => { - let repo: DocumentRepository; - let vectorStorage: InMemoryChunkVectorStorage; +describe("DocumentDataset", () => { + let dataset: DocumentDataset; + let vectorDataset: DocumentChunkDataset; beforeEach(async () => { const tabularStorage = new InMemoryTabularStorage( @@ -28,21 +31,28 @@ describe("DocumentRepository", () => { ); await tabularStorage.setupDatabase(); - vectorStorage = new InMemoryChunkVectorStorage(3); - await vectorStorage.setupDatabase(); - - repo = new DocumentRepository(tabularStorage, vectorStorage); + const storage = new InMemoryVectorStorage< + typeof DocumentChunkSchema, + typeof DocumentChunkPrimaryKey, + Record, + Float32Array, + DocumentChunk + >(DocumentChunkSchema, DocumentChunkPrimaryKey, [], 3, Float32Array); + await storage.setupDatabase(); + vectorDataset = new DocumentChunkDataset(storage); + + dataset = new DocumentDataset(tabularStorage, vectorDataset as any); }); it("should store and retrieve documents", async () => { const markdown = "# Test\n\nContent."; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); const doc = new Document(doc_id, root, { title: "Test Document" }); - await repo.upsert(doc); - const retrieved = await repo.get(doc_id); + await dataset.upsert(doc); + const retrieved = await dataset.get(doc_id); expect(retrieved).toBeDefined(); expect(retrieved?.doc_id).toBe(doc_id); @@ -51,15 +61,15 @@ describe("DocumentRepository", () => { it("should retrieve nodes by ID", async () => { const markdown = "# Section\n\nParagraph."; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); const doc = new Document(doc_id, root, { title: "Test" }); - await repo.upsert(doc); + await dataset.upsert(doc); // Get a child node const firstChild = root.children[0]; - const retrieved = await repo.getNode(doc_id, firstChild.nodeId); + const retrieved = 
await dataset.getNode(doc_id, firstChild.nodeId); expect(retrieved).toBeDefined(); expect(retrieved?.nodeId).toBe(firstChild.nodeId); @@ -72,11 +82,11 @@ describe("DocumentRepository", () => { Paragraph.`; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); const doc = new Document(doc_id, root, { title: "Test" }); - await repo.upsert(doc); + await dataset.upsert(doc); // Find a deeply nested node const section = root.children.find((c) => c.kind === NodeKind.SECTION); @@ -85,7 +95,7 @@ Paragraph.`; const subsection = (section as any).children.find((c: any) => c.kind === NodeKind.SECTION); expect(subsection).toBeDefined(); - const ancestors = await repo.getAncestors(doc_id, subsection.nodeId); + const ancestors = await dataset.getAncestors(doc_id, subsection.nodeId); // Should include root, section, and subsection expect(ancestors.length).toBeGreaterThanOrEqual(3); @@ -96,7 +106,7 @@ Paragraph.`; it("should handle chunks", async () => { const markdown = "# Test\n\nContent."; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); const doc = new Document(doc_id, root, { title: "Test" }); @@ -114,10 +124,10 @@ Paragraph.`; doc.setChunks(chunks); - await repo.upsert(doc); + await dataset.upsert(doc); // Retrieve chunks - const retrievedChunks = await repo.getChunks(doc_id); + const retrievedChunks = await dataset.getChunks(doc_id); expect(retrievedChunks).toBeDefined(); expect(retrievedChunks.length).toBe(1); }); @@ -126,8 +136,8 @@ Paragraph.`; const markdown1 = "# Doc 1"; const markdown2 = "# Doc 2"; - const id1 = await NodeIdGenerator.generateDocId("test1", markdown1); - const id2 = await NodeIdGenerator.generateDocId("test2", markdown2); + const id1 = uuid4(); + const id2 = uuid4(); const root1 = await StructuralParser.parseMarkdown(id1, markdown1, "Doc 1"); const root2 = await StructuralParser.parseMarkdown(id2, markdown2, "Doc 2"); @@ -135,10 +145,10 @@ Paragraph.`; const doc1 = new Document(id1, root1, { title: "Doc 1" }); const doc2 = new Document(id2, root2, { title: "Doc 2" }); - await repo.upsert(doc1); - await repo.upsert(doc2); + await dataset.upsert(doc1); + await dataset.upsert(doc2); - const list = await repo.list(); + const list = await dataset.list(); expect(list.length).toBe(2); expect(list).toContain(id1); expect(list).toContain(id2); @@ -146,60 +156,60 @@ Paragraph.`; it("should delete documents", async () => { const markdown = "# Test"; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); const doc = new Document(doc_id, root, { title: "Test" }); - await repo.upsert(doc); + await dataset.upsert(doc); - expect(await repo.get(doc_id)).toBeDefined(); + expect(await dataset.get(doc_id)).toBeDefined(); - await repo.delete(doc_id); + await dataset.delete(doc_id); - expect(await repo.get(doc_id)).toBeUndefined(); + expect(await dataset.get(doc_id)).toBeUndefined(); }); it("should return undefined for non-existent document", async () => { - const result = await repo.get("non-existent-doc-id"); + const result = await dataset.get("non-existent-doc-id"); expect(result).toBeUndefined(); }); it("should return undefined for node in non-existent document", async () => { - const result = await repo.getNode("non-existent-doc-id", 
"some-node-id"); + const result = await dataset.getNode("non-existent-doc-id", "some-node-id"); expect(result).toBeUndefined(); }); it("should return undefined for non-existent node", async () => { const markdown = "# Test"; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); const doc = new Document(doc_id, root, { title: "Test" }); - await repo.upsert(doc); + await dataset.upsert(doc); - const result = await repo.getNode(doc_id, "non-existent-node-id"); + const result = await dataset.getNode(doc_id, "non-existent-node-id"); expect(result).toBeUndefined(); }); it("should return empty array for ancestors of non-existent document", async () => { - const result = await repo.getAncestors("non-existent-doc-id", "some-node-id"); + const result = await dataset.getAncestors("non-existent-doc-id", "some-node-id"); expect(result).toEqual([]); }); it("should return empty array for ancestors of non-existent node", async () => { const markdown = "# Test"; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); const doc = new Document(doc_id, root, { title: "Test" }); - await repo.upsert(doc); + await dataset.upsert(doc); - const result = await repo.getAncestors(doc_id, "non-existent-node-id"); + const result = await dataset.getAncestors(doc_id, "non-existent-node-id"); expect(result).toEqual([]); }); it("should return empty array for chunks of non-existent document", async () => { - const result = await repo.getChunks("non-existent-doc-id"); + const result = await dataset.getChunks("non-existent-doc-id"); expect(result).toEqual([]); }); @@ -210,40 +220,40 @@ Paragraph.`; DocumentStorageKey ); await tabularStorage.setupDatabase(); - const emptyRepo = new DocumentRepository(tabularStorage); + const emptyDataset = new DocumentDataset(tabularStorage); - const result = await emptyRepo.list(); + const result = await emptyDataset.list(); expect(result).toEqual([]); }); it("should not throw when deleting non-existent document", async () => { // Just verify delete completes without error - await repo.delete("non-existent-doc-id"); + await dataset.delete("non-existent-doc-id"); // If we get here, it didn't throw expect(true).toBe(true); }); it("should update existing document on upsert", async () => { const markdown = "# Test"; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); const doc1 = new Document(doc_id, root, { title: "Original Title" }); - await repo.upsert(doc1); + await dataset.upsert(doc1); const doc2 = new Document(doc_id, root, { title: "Updated Title" }); - await repo.upsert(doc2); + await dataset.upsert(doc2); - const retrieved = await repo.get(doc_id); + const retrieved = await dataset.get(doc_id); expect(retrieved?.metadata.title).toBe("Updated Title"); - const list = await repo.list(); + const list = await dataset.list(); expect(list.length).toBe(1); }); it("should find chunks by node ID", async () => { const markdown = "# Test\n\nContent."; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); const doc = new Document(doc_id, root, { title: "Test" }); @@ -265,45 +275,45 @@ Paragraph.`; }, ]; 
doc.setChunks(chunks); - await repo.upsert(doc); + await dataset.upsert(doc); - const result = await repo.findChunksByNodeId(doc_id, root.nodeId); + const result = await dataset.findChunksByNodeId(doc_id, root.nodeId); expect(result.length).toBe(2); }); it("should return empty array for findChunksByNodeId with no matches", async () => { const markdown = "# Test"; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); const doc = new Document(doc_id, root, { title: "Test" }); doc.setChunks([]); - await repo.upsert(doc); + await dataset.upsert(doc); - const result = await repo.findChunksByNodeId(doc_id, "non-matching-node"); + const result = await dataset.findChunksByNodeId(doc_id, "non-matching-node"); expect(result).toEqual([]); }); it("should return empty array for findChunksByNodeId with non-existent document", async () => { - const result = await repo.findChunksByNodeId("non-existent-doc", "some-node"); + const result = await dataset.findChunksByNodeId("non-existent-doc", "some-node"); expect(result).toEqual([]); }); it("should search with vector storage", async () => { // Add vectors to vector storage - await vectorStorage.put({ + await vectorDataset.put({ chunk_id: "chunk_1", doc_id: "doc1", vector: new Float32Array([1.0, 0.0, 0.0]), metadata: { text: "First chunk" }, }); - await vectorStorage.put({ + await vectorDataset.put({ chunk_id: "chunk_2", doc_id: "doc1", vector: new Float32Array([0.8, 0.2, 0.0]), metadata: { text: "Second chunk" }, }); - await vectorStorage.put({ + await vectorDataset.put({ chunk_id: "chunk_3", doc_id: "doc2", vector: new Float32Array([0.0, 1.0, 0.0]), @@ -311,20 +321,20 @@ Paragraph.`; }); const queryVector = new Float32Array([1.0, 0.0, 0.0]); - const results = await repo.search(queryVector, { topK: 2 }); + const results = await dataset.search(queryVector, { topK: 2 }); expect(results.length).toBe(2); expect(results[0].chunk_id).toBe("chunk_1"); }); it("should search with score threshold", async () => { - await vectorStorage.put({ + await vectorDataset.put({ chunk_id: "chunk_1", doc_id: "doc1", vector: new Float32Array([1.0, 0.0, 0.0]), metadata: { text: "Matching chunk" }, }); - await vectorStorage.put({ + await vectorDataset.put({ chunk_id: "chunk_2", doc_id: "doc1", vector: new Float32Array([0.0, 1.0, 0.0]), @@ -332,7 +342,7 @@ Paragraph.`; }); const queryVector = new Float32Array([1.0, 0.0, 0.0]); - const results = await repo.search(queryVector, { topK: 10, scoreThreshold: 0.9 }); + const results = await dataset.search(queryVector, { topK: 10, scoreThreshold: 0.9 }); expect(results.length).toBeGreaterThanOrEqual(1); results.forEach((r: any) => { @@ -347,10 +357,10 @@ Paragraph.`; ); await tabularStorage.setupDatabase(); - const repoWithoutVector = new DocumentRepository(tabularStorage); + const datasetWithoutVector = new DocumentDataset(tabularStorage); const queryVector = new Float32Array([1.0, 0.0, 0.0]); - const results = await repoWithoutVector.search(queryVector); + const results = await datasetWithoutVector.search(queryVector); expect(results).toEqual([]); }); @@ -359,7 +369,7 @@ Paragraph.`; describe("Document", () => { it("should manage chunks", async () => { const markdown = "# Test"; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); const doc = new Document(doc_id, root, { title: "Test" }); @@ -382,7 
+392,7 @@ describe("Document", () => { it("should serialize and deserialize", async () => { const markdown = "# Test"; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); const doc = new Document(doc_id, root, { title: "Test" }); @@ -411,7 +421,7 @@ describe("Document", () => { it("should find chunks by nodeId", async () => { const markdown = "# Test"; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); const doc = new Document(doc_id, root, { title: "Test" }); @@ -450,7 +460,7 @@ describe("Document", () => { it("should return empty array when no chunks match nodeId", async () => { const markdown = "# Test"; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); const doc = new Document(doc_id, root, { title: "Test" }); @@ -472,7 +482,7 @@ describe("Document", () => { it("should handle empty chunks in findChunksByNodeId", async () => { const markdown = "# Test"; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); const doc = new Document(doc_id, root, { title: "Test" }); diff --git a/packages/test/src/test/rag/EndToEnd.test.ts b/packages/test/src/test/rag/EndToEnd.test.ts index fbfb3afe..bcbaf6b7 100644 --- a/packages/test/src/test/rag/EndToEnd.test.ts +++ b/packages/test/src/test/rag/EndToEnd.test.ts @@ -7,14 +7,17 @@ import { hierarchicalChunker } from "@workglow/ai"; import { Document, + DocumentChunk, + DocumentChunkDataset, + DocumentChunkPrimaryKey, + DocumentChunkSchema, DocumentRepository, DocumentStorageKey, DocumentStorageSchema, - InMemoryChunkVectorStorage, - NodeIdGenerator, StructuralParser, } from "@workglow/dataset"; -import { InMemoryTabularStorage } from "@workglow/storage"; +import { InMemoryTabularStorage, InMemoryVectorStorage } from "@workglow/storage"; +import { uuid4 } from "@workglow/util"; import { beforeAll, describe, expect, it } from "vitest"; import { registerTasks } from "../../binding/RegisterTasks"; @@ -38,7 +41,7 @@ Uses labeled data. 
Finds patterns in data.`; // Parse into hierarchical tree - const doc_id = await NodeIdGenerator.generateDocId("ml-guide", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "ML Guide"); const chunkResult = await hierarchicalChunker({ @@ -62,7 +65,7 @@ Finds patterns in data.`; it("should manage document chunks", async () => { const markdown = "# Test Document\n\nThis is test content."; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); const doc = new Document(doc_id, root, { title: "Test" }); @@ -92,10 +95,17 @@ Finds patterns in data.`; ); await tabularStorage.setupDatabase(); - const vectorStorage = new InMemoryChunkVectorStorage(3); - await vectorStorage.setupDatabase(); + const storage = new InMemoryVectorStorage< + typeof DocumentChunkSchema, + typeof DocumentChunkPrimaryKey, + Record, + Float32Array, + DocumentChunk + >(DocumentChunkSchema, DocumentChunkPrimaryKey, [], 3, Float32Array); + await storage.setupDatabase(); + const vectorDataset = new DocumentChunkDataset(storage); - const docRepo = new DocumentRepository(tabularStorage, vectorStorage); + const docRepo = new DocumentRepository(tabularStorage, vectorDataset as any); // Create document with enriched hierarchy const markdown = `# Guide @@ -108,7 +118,7 @@ Content about topic A. Content about topic B.`; - const doc_id = await NodeIdGenerator.generateDocId("guide", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Guide"); const doc = new Document(doc_id, root, { title: "Guide" }); diff --git a/packages/test/src/test/rag/FullChain.test.ts b/packages/test/src/test/rag/FullChain.test.ts index d7635b39..71108063 100644 --- a/packages/test/src/test/rag/FullChain.test.ts +++ b/packages/test/src/test/rag/FullChain.test.ts @@ -5,7 +5,9 @@ */ import { HierarchicalChunkerTaskOutput } from "@workglow/ai"; -import { ChunkNode, InMemoryChunkVectorStorage, NodeIdGenerator } from "@workglow/dataset"; +import { ChunkNode, DocumentChunkPrimaryKey, DocumentChunkSchema } from "@workglow/dataset"; +import { uuid4 } from "@workglow/util"; +import { InMemoryVectorStorage } from "@workglow/storage"; import { Workflow } from "@workglow/task-graph"; import { beforeAll, describe, expect, it } from "vitest"; import { registerTasks } from "../../binding/RegisterTasks"; @@ -16,8 +18,13 @@ describe("Complete chainable workflow", () => { }); it("should chain from parsing to storage without loops", async () => { - const vectorRepo = new InMemoryChunkVectorStorage(3); - await vectorRepo.setupDatabase(); + const storage = new InMemoryVectorStorage( + DocumentChunkSchema, + DocumentChunkPrimaryKey, + [], + 3 + ); + await storage.setupDatabase(); const markdown = `# Test Document @@ -96,7 +103,7 @@ This is the second section with more content.`; it("should allow doc_id override for variant creation", async () => { const markdown = "# Test\n\nContent."; - const customId = await NodeIdGenerator.generateDocId("custom", markdown); + const customId = uuid4(); const result = (await new Workflow() .structuralParser({ diff --git a/packages/test/src/test/rag/HierarchicalChunker.test.ts b/packages/test/src/test/rag/HierarchicalChunker.test.ts index a1fc8c3a..1cc87eb5 100644 --- a/packages/test/src/test/rag/HierarchicalChunker.test.ts +++ b/packages/test/src/test/rag/HierarchicalChunker.test.ts @@ -5,8 +5,9 @@ */ import { 
hierarchicalChunker } from "@workglow/ai"; -import { estimateTokens, NodeIdGenerator, StructuralParser } from "@workglow/dataset"; +import { estimateTokens, StructuralParser } from "@workglow/dataset"; import { Workflow } from "@workglow/task-graph"; +import { uuid4 } from "@workglow/util"; import { describe, expect, it } from "vitest"; describe("HierarchicalChunkerTask", () => { @@ -19,7 +20,7 @@ This is a paragraph that should fit in one chunk. This is another paragraph.`; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); const result = await hierarchicalChunker({ @@ -52,7 +53,7 @@ This is another paragraph.`; const longText = "Lorem ipsum dolor sit amet. ".repeat(100); const markdown = `# Section\n\n${longText}`; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Long"); const maxTokens = 100; @@ -78,7 +79,7 @@ This is another paragraph.`; const text = "Word ".repeat(200); const markdown = `# Section\n\n${text}`; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Overlap"); const maxTokens = 50; @@ -117,7 +118,7 @@ Paragraph 1. Paragraph 2.`; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Flat"); const result = await new Workflow() @@ -141,7 +142,7 @@ Paragraph 2.`; Paragraph content.`; - const doc_id = await NodeIdGenerator.generateDocId("test", markdown); + const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Paths"); const result = await hierarchicalChunker({ diff --git a/packages/test/src/test/rag/HybridSearchTask.test.ts b/packages/test/src/test/rag/HybridSearchTask.test.ts index 7d1ac428..1ff96898 100644 --- a/packages/test/src/test/rag/HybridSearchTask.test.ts +++ b/packages/test/src/test/rag/HybridSearchTask.test.ts @@ -5,15 +5,36 @@ */ import { hybridSearch } from "@workglow/ai"; -import { InMemoryChunkVectorStorage, registerChunkVectorRepository } from "@workglow/dataset"; +import { + DocumentChunk, + DocumentChunkDataset, + DocumentChunkPrimaryKey, + DocumentChunkSchema, + registerDocumentChunkDataset, +} from "@workglow/dataset"; +import { InMemoryVectorStorage } from "@workglow/storage"; import { afterEach, beforeEach, describe, expect, test } from "vitest"; describe("ChunkVectorHybridSearchTask", () => { - let repo: InMemoryChunkVectorStorage; + let storage: InMemoryVectorStorage< + typeof DocumentChunkSchema, + typeof DocumentChunkPrimaryKey, + Record, + Float32Array, + DocumentChunk + >; + let dataset: DocumentChunkDataset; beforeEach(async () => { - repo = new InMemoryChunkVectorStorage(3); - await repo.setupDatabase(); + storage = new InMemoryVectorStorage< + typeof DocumentChunkSchema, + typeof DocumentChunkPrimaryKey, + Record, + Float32Array, + DocumentChunk + >(DocumentChunkSchema, DocumentChunkPrimaryKey, [], 3, Float32Array); + await storage.setupDatabase(); + dataset = new DocumentChunkDataset(storage); // Populate repository with test data const vectors = [ @@ -34,17 +55,17 @@ describe("ChunkVectorHybridSearchTask", () => { for (let i = 0; i < vectors.length; i++) { const doc_id = `doc${i + 1}`; - await repo.put({ - id: `${doc_id}_0`, + await 
dataset.put({ + chunk_id: `${doc_id}_0`, doc_id, - vector: vectors[i] as any, + vector: vectors[i], metadata: metadata[i], - } as any); + }); } }); afterEach(() => { - repo.destroy(); + storage.destroy(); }); test("should perform hybrid search with vector and text query", async () => { @@ -52,7 +73,7 @@ describe("ChunkVectorHybridSearchTask", () => { const queryText = "machine learning"; const result = await hybridSearch({ - repository: repo, + dataset, queryVector: queryVector, queryText: queryText, topK: 3, @@ -75,7 +96,7 @@ describe("ChunkVectorHybridSearchTask", () => { const queryText = "machine"; const result = await hybridSearch({ - repository: repo, + dataset, queryVector: queryVector, queryText: queryText, topK: 5, @@ -95,7 +116,7 @@ describe("ChunkVectorHybridSearchTask", () => { // Test with high vector weight const resultHighVector = await hybridSearch({ - repository: repo, + dataset, queryVector: queryVector, queryText: queryText, topK: 5, @@ -104,7 +125,7 @@ describe("ChunkVectorHybridSearchTask", () => { // Test with low vector weight (high text weight) const resultHighText = await hybridSearch({ - repository: repo, + dataset, queryVector: queryVector, queryText: queryText, topK: 5, @@ -121,7 +142,7 @@ describe("ChunkVectorHybridSearchTask", () => { const queryText = "machine"; const result = await hybridSearch({ - repository: repo, + dataset, queryVector: queryVector, queryText: queryText, topK: 3, @@ -138,7 +159,7 @@ describe("ChunkVectorHybridSearchTask", () => { const queryText = "machine"; const result = await hybridSearch({ - repository: repo, + dataset, queryVector: queryVector, queryText: queryText, topK: 3, @@ -153,7 +174,7 @@ describe("ChunkVectorHybridSearchTask", () => { const queryText = "learning"; const result = await hybridSearch({ - repository: repo, + dataset, queryVector: queryVector, queryText: queryText, topK: 10, @@ -171,7 +192,7 @@ describe("ChunkVectorHybridSearchTask", () => { const queryText = "machine"; const result = await hybridSearch({ - repository: repo, + dataset, queryVector: queryVector, queryText: queryText, topK: 10, @@ -189,7 +210,7 @@ describe("ChunkVectorHybridSearchTask", () => { const queryText = "document"; const result = await hybridSearch({ - repository: repo, + dataset, queryVector: queryVector, queryText: queryText, topK: 2, @@ -204,7 +225,7 @@ describe("ChunkVectorHybridSearchTask", () => { const queryText = "learning"; const result = await hybridSearch({ - repository: repo, + dataset, queryVector: queryVector, queryText: queryText, }); @@ -219,7 +240,7 @@ describe("ChunkVectorHybridSearchTask", () => { const queryText = "machine"; const result = await hybridSearch({ - repository: repo, + dataset, queryVector: queryVector, queryText: queryText, topK: 5, @@ -236,7 +257,7 @@ describe("ChunkVectorHybridSearchTask", () => { const queryText = "machine"; const result = await hybridSearch({ - repository: repo, + dataset, queryVector: queryVector, queryText: queryText, topK: 3, @@ -247,15 +268,15 @@ describe("ChunkVectorHybridSearchTask", () => { }); test("should resolve repository from string ID", async () => { - // Register repository by ID - registerChunkVectorRepository("test-hybrid-repo", repo); + // Register dataset by ID + registerDocumentChunkDataset("test-hybrid-repo", dataset); const queryVector = new Float32Array([1.0, 0.0, 0.0]); const queryText = "machine learning"; // Pass repository as string ID instead of instance const result = await hybridSearch({ - repository: "test-hybrid-repo" as any, + dataset: "test-hybrid-repo", 
queryVector: queryVector, queryText: queryText, topK: 3, diff --git a/packages/test/src/test/rag/RagWorkflow.test.ts b/packages/test/src/test/rag/RagWorkflow.test.ts index c3a55c3f..6390f81a 100644 --- a/packages/test/src/test/rag/RagWorkflow.test.ts +++ b/packages/test/src/test/rag/RagWorkflow.test.ts @@ -46,14 +46,17 @@ import { VectorStoreUpsertTaskOutput, } from "@workglow/ai"; import { register_HFT_InlineJobFns } from "@workglow/ai-provider"; -import { InMemoryTabularStorage } from "@workglow/storage"; import { + DocumentChunk, + DocumentChunkDataset, + DocumentChunkPrimaryKey, + DocumentChunkSchema, DocumentRepository, DocumentStorageKey, DocumentStorageSchema, - InMemoryChunkVectorStorage, - registerChunkVectorRepository, + registerDocumentChunkDataset, } from "@workglow/dataset"; +import { InMemoryTabularStorage, InMemoryVectorStorage } from "@workglow/storage"; import { getTaskQueueRegistry, setTaskQueueRegistry, Workflow } from "@workglow/task-graph"; import { readdirSync } from "fs"; import { join } from "path"; @@ -62,7 +65,14 @@ import { registerHuggingfaceLocalModels } from "../../samples"; export { FileLoaderTask } from "@workglow/tasks"; describe("RAG Workflow End-to-End", () => { - let vectorRepo: InMemoryChunkVectorStorage; + let storage: InMemoryVectorStorage< + typeof DocumentChunkSchema, + typeof DocumentChunkPrimaryKey, + Record, + Float32Array, + DocumentChunk + >; + let vectorDataset: DocumentChunkDataset; let docRepo: DocumentRepository; const vectorRepoName = "rag-test-vector-repo"; const embeddingModel = "onnx:Xenova/all-MiniLM-L6-v2:q8"; @@ -79,16 +89,23 @@ describe("RAG Workflow End-to-End", () => { await registerHuggingfaceLocalModels(); // Setup repositories - vectorRepo = new InMemoryChunkVectorStorage(3); - await vectorRepo.setupDatabase(); - - // Register vector repository for use in workflows - registerChunkVectorRepository(vectorRepoName, vectorRepo); + storage = new InMemoryVectorStorage< + typeof DocumentChunkSchema, + typeof DocumentChunkPrimaryKey, + Record, + Float32Array, + DocumentChunk + >(DocumentChunkSchema, DocumentChunkPrimaryKey, [], 3, Float32Array); + await storage.setupDatabase(); + vectorDataset = new DocumentChunkDataset(storage); + + // Register vector dataset for use in workflows + registerDocumentChunkDataset(vectorRepoName, vectorDataset); const tabularRepo = new InMemoryTabularStorage(DocumentStorageSchema, DocumentStorageKey); await tabularRepo.setupDatabase(); - docRepo = new DocumentRepository(tabularRepo, vectorRepo); + docRepo = new DocumentRepository(tabularRepo, vectorDataset as any); }); afterAll(async () => { @@ -133,7 +150,7 @@ describe("RAG Workflow End-to-End", () => { model: embeddingModel, }) .vectorStoreUpsert({ - repository: vectorRepoName, + dataset: vectorRepoName, }); const result = (await ingestionWorkflow.run()) as VectorStoreUpsertTaskOutput; @@ -156,7 +173,7 @@ describe("RAG Workflow End-to-End", () => { const searchWorkflow = new Workflow(); searchWorkflow.retrieval({ - repository: vectorRepoName, + dataset: vectorRepoName, query, model: embeddingModel, topK: 5, @@ -192,7 +209,7 @@ describe("RAG Workflow End-to-End", () => { console.log(`\nAnswering question: "${question}"`); const retrievalResult = await retrieval({ - repository: vectorRepoName, + dataset: vectorRepoName, query: question, model: embeddingModel, topK: 3, @@ -235,7 +252,7 @@ describe("RAG Workflow End-to-End", () => { // Step 1: Retrieve context const retrievalWorkflow = new Workflow(); retrievalWorkflow.retrieval({ - repository: 
vectorRepoName, + dataset: vectorRepoName, query: question, model: embeddingModel, topK: 3, diff --git a/packages/test/src/test/task-graph/InputResolver.test.ts b/packages/test/src/test/task-graph/InputResolver.test.ts index d52ddbf3..87be1759 100644 --- a/packages/test/src/test/task-graph/InputResolver.test.ts +++ b/packages/test/src/test/task-graph/InputResolver.test.ts @@ -4,13 +4,13 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { TypeTabularStorage } from "@workglow/dataset"; import { AnyTabularStorage, getGlobalTabularRepositories, InMemoryTabularStorage, registerTabularRepository, } from "@workglow/storage"; -import { TypeTabularRepository } from "@workglow/dataset"; import { IExecuteContext, resolveSchemaInputs, Task, TaskRegistry } from "@workglow/task-graph"; import { getInputResolvers, @@ -32,19 +32,19 @@ describe("InputResolver", () => { additionalProperties: false, } as const; - let testRepo: InMemoryTabularStorage; + let testDataset: InMemoryTabularStorage; beforeEach(async () => { // Create and register a test repository - testRepo = new InMemoryTabularStorage(testEntitySchema, ["id"] as const); - await testRepo.setupDatabase(); - registerTabularRepository("test-repo", testRepo); + testDataset = new InMemoryTabularStorage(testEntitySchema, ["id"] as const); + await testDataset.setupDatabase(); + registerTabularRepository("test-dataset", testDataset); }); afterEach(() => { // Clean up the registry - getGlobalTabularRepositories().delete("test-repo"); - testRepo.destroy(); + getGlobalTabularRepositories().delete("test-dataset"); + testDataset.destroy(); }); describe("resolveSchemaInputs", () => { @@ -52,47 +52,47 @@ describe("InputResolver", () => { const schema: DataPortSchema = { type: "object", properties: { - repository: TypeTabularRepository(), + dataset: TypeTabularStorage(), }, }; - const input = { repository: testRepo }; + const input = { dataset: testDataset }; const resolved = await resolveSchemaInputs(input, schema, { registry: globalServiceRegistry, }); - expect(resolved.repository).toBe(testRepo); + expect(resolved.dataset).toBe(testDataset); }); - test("should resolve string repository ID to instance", async () => { + test("should resolve string dataset ID to instance", async () => { const schema: DataPortSchema = { type: "object", properties: { - repository: TypeTabularRepository(), + dataset: TypeTabularStorage(), }, }; - const input = { repository: "test-repo" }; + const input = { dataset: "test-dataset" }; const resolved = await resolveSchemaInputs(input, schema, { registry: globalServiceRegistry, }); - expect(resolved.repository).toBe(testRepo); + expect(resolved.dataset).toBe(testDataset); }); - test("should throw error for unknown repository ID", async () => { + test("should throw error for unknown dataset ID", async () => { const schema: DataPortSchema = { type: "object", properties: { - repository: TypeTabularRepository(), + dataset: TypeTabularStorage(), }, }; - const input = { repository: "non-existent-repo" }; + const input = { dataset: "non-existent-dataset" }; await expect( resolveSchemaInputs(input, schema, { registry: globalServiceRegistry }) - ).rejects.toThrow('Tabular repository "non-existent-repo" not found'); + ).rejects.toThrow('Tabular storage "non-existent-dataset" not found'); }); test("should not resolve properties without format annotation", async () => { @@ -183,24 +183,24 @@ describe("InputResolver", () => { }); describe("Integration with Task", () => { - // Define a test task that uses a repository - class RepositoryConsumerTask 
extends Task< - { repository: any; query: string }, + // Define a test task that uses a dataset + class DatasetConsumerTask extends Task< + { dataset: AnyTabularStorage | string; query: string }, { results: any[] } > { - public static type = "RepositoryConsumerTask"; + public static type = "DatasetConsumerTask"; public static inputSchema(): DataPortSchema { return { type: "object", properties: { - repository: TypeTabularRepository({ - title: "Data Repository", - description: "Repository to query", + dataset: TypeTabularStorage({ + title: "Data Storage", + description: "Storage to query", }), query: { type: "string", title: "Query" }, }, - required: ["repository", "query"], + required: ["dataset", "query"], additionalProperties: false, }; } @@ -217,31 +217,31 @@ describe("InputResolver", () => { } async execute( - input: { repository: AnyTabularStorage; query: string }, + input: { dataset: AnyTabularStorage; query: string }, _context: IExecuteContext ): Promise<{ results: any[] }> { - const { repository } = input; - // In a real task, we'd search the repository - const results = await repository.getAll(); + const { dataset } = input; + // In a real task, we'd search the dataset + const results = await dataset.getAll(); return { results: results ?? [] }; } } beforeEach(() => { - TaskRegistry.registerTask(RepositoryConsumerTask); + TaskRegistry.registerTask(DatasetConsumerTask); }); afterEach(() => { - TaskRegistry.all.delete(RepositoryConsumerTask.type); + TaskRegistry.all.delete(DatasetConsumerTask.type); }); - test("should resolve repository when running task with string ID", async () => { + test("should resolve dataset when running task with string ID", async () => { // Add some test data - await testRepo.put({ id: "1", name: "Test Item" }); + await testDataset.put({ id: "1", name: "Test Item" }); - const task = new RepositoryConsumerTask(); + const task = new DatasetConsumerTask(); const result = await task.run({ - repository: "test-repo", + dataset: "test-dataset", query: "test", }); @@ -249,12 +249,12 @@ describe("InputResolver", () => { expect(result.results[0]).toEqual({ id: "1", name: "Test Item" }); }); - test("should work with direct repository instance", async () => { - await testRepo.put({ id: "2", name: "Direct Item" }); + test("should work with direct dataset instance", async () => { + await testDataset.put({ id: "2", name: "Direct Item" }); - const task = new RepositoryConsumerTask(); + const task = new DatasetConsumerTask(); const result = await task.run({ - repository: testRepo, + dataset: testDataset, query: "test", }); diff --git a/packages/util/package.json b/packages/util/package.json index bb841966..4cb1bb5f 100644 --- a/packages/util/package.json +++ b/packages/util/package.json @@ -23,8 +23,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./dist/bun.js", - "types": "./dist/types.d.ts", + "bun": "./src/bun.ts", + "types": "./src/types.ts", "import": "./dist/node.js" } }, diff --git a/packages/util/src/di/InputResolverRegistry.ts b/packages/util/src/di/InputResolverRegistry.ts index 064fb9f9..fa8c8ac3 100644 --- a/packages/util/src/di/InputResolverRegistry.ts +++ b/packages/util/src/di/InputResolverRegistry.ts @@ -4,8 +4,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { createServiceToken, globalServiceRegistry } from "./ServiceRegistry"; import type { ServiceRegistry } from "./ServiceRegistry"; +import { createServiceToken, globalServiceRegistry } from "./ServiceRegistry"; /** * A resolver function that converts a string 
ID to an instance. @@ -13,7 +13,7 @@ import type { ServiceRegistry } from "./ServiceRegistry"; * Throws an error if the ID is not found. * * @param id The string ID to resolve - * @param format The full format string (e.g., "model:TextEmbedding", "repository:tabular") + * @param format The full format string (e.g., "model:TextEmbedding", "storage:tabular") * @param registry The service registry to use for lookups */ export type InputResolverFn = ( @@ -26,9 +26,8 @@ export type InputResolverFn = ( * Service token for the input resolver registry. * Maps format prefixes to resolver functions. */ -export const INPUT_RESOLVERS = createServiceToken>( - "task.input.resolvers" -); +export const INPUT_RESOLVERS = + createServiceToken>("task.input.resolvers"); // Register default factory if not already registered if (!globalServiceRegistry.has(INPUT_RESOLVERS)) { @@ -51,7 +50,7 @@ export function getInputResolvers(): Map { * Registers an input resolver for a format prefix. * The resolver will be called for any format that starts with this prefix. * - * @param formatPrefix The format prefix to match (e.g., "model", "repository") + * @param formatPrefix The format prefix to match (e.g., "model", "dataset") * @param resolver The resolver function * * @example @@ -64,16 +63,16 @@ export function getInputResolvers(): Map { * return model; * }); * - * // Register repository resolver - * registerInputResolver("repository", (id, format, registry) => { - * const repoType = format.split(":")[1]; // "tabular", "vector", etc. - * if (repoType === "tabular") { - * const repos = registry.get(TABULAR_REPOSITORIES); - * const repo = repos.get(id); - * if (!repo) throw new Error(`Repository "${id}" not found`); - * return repo; + * // Register dataset resolver + * registerInputResolver("dataset", (id, format, registry) => { + * const datasetType = format.split(":")[1]; // "tabular", "vector", etc. + * if (datasetType === "tabular") { + * const datasets = registry.get(TABULAR_DATASETS); + * const dataset = datasets.get(id); + * if (!dataset) throw new Error(`Dataset "${id}" not found`); + * return dataset; * } - * throw new Error(`Unknown repository type: ${repoType}`); + * throw new Error(`Unknown dataset type: ${datasetType}`); * }); * ``` */ From 965d9121647e3e1c965eb9518a4c9ad715c03e8e Mon Sep 17 00:00:00 2001 From: Steven Roussey Date: Sun, 18 Jan 2026 06:53:50 +0000 Subject: [PATCH 13/14] [feat] Implement Auto-Generated Primary Keys in Tabular Storage - Introduced support for auto-generated primary keys across all TabularStorage implementations, enhancing security and simplifying client interactions. - Updated schemas to include `x-auto-generated: true` for primary key fields, allowing automatic ID generation during entity insertion. - Refactored `put` and `putBulk` methods to handle entities with optional auto-generated keys, ensuring compatibility with existing data structures. - Enhanced documentation and README files to reflect the new auto-generated key features and usage examples. - Added comprehensive tests to validate the functionality of auto-generated keys across various storage backends, including InMemory, SQLite, Postgres, and IndexedDB. - Updated existing tests to ensure they align with the new auto-generated key logic and structure. 
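
For illustration, a minimal sketch of the new insert flow (constructor and `put` return type follow the in-memory storage and README changes later in this patch; the schema and entity names here are made up for the example):

```typescript
import { InMemoryTabularStorage } from "@workglow/storage";

// Marking `id` with "x-auto-generated": true lets the storage generate it;
// integer keys use an autoincrement strategy, string keys a UUID.
const UserSchema = {
  type: "object",
  properties: {
    id: { type: "integer", "x-auto-generated": true },
    name: { type: "string" },
  },
  required: ["id", "name"],
  additionalProperties: false,
} as const;

const storage = new InMemoryTabularStorage(UserSchema, ["id"] as const);
await storage.setupDatabase();

// `id` may be omitted on insert; the returned entity carries the generated key.
const user = await storage.put({ name: "Alice" });
console.log(user.id); // 1 (auto-generated)
```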
--- TODO.md | 18 +- packages/ai/src/task/ChunkVectorUpsertTask.ts | 33 +-- packages/ai/src/task/HierarchyJoinTask.ts | 17 +- packages/dataset/src/common.ts | 1 - .../document-chunk/DocumentChunkDataset.ts | 17 +- .../src/document-chunk/DocumentChunkSchema.ts | 19 +- packages/dataset/src/document-chunk/README.md | 14 +- packages/dataset/src/document/Document.ts | 21 +- .../dataset/src/document/DocumentDataset.ts | 25 +- .../src/document/DocumentRepository.ts | 218 ------------------ .../src/document/DocumentStorageSchema.ts | 7 + packages/storage/README.md | 41 ++++ .../src/tabular/BaseSqlTabularStorage.ts | 21 +- .../storage/src/tabular/BaseTabularStorage.ts | 137 ++++++++++- .../src/tabular/CachedTabularStorage.ts | 24 +- .../src/tabular/FsFolderTabularStorage.ts | 76 +++++- .../storage/src/tabular/ITabularStorage.ts | 24 +- .../src/tabular/InMemoryTabularStorage.ts | 85 +++++-- .../src/tabular/IndexedDbTabularStorage.ts | 158 +++++++++---- .../src/tabular/PostgresTabularStorage.ts | 145 +++++++++--- packages/storage/src/tabular/README.md | 172 ++++++++++++++ .../tabular/SharedInMemoryTabularStorage.ts | 24 +- .../src/tabular/SqliteTabularStorage.ts | 182 +++++++++++++-- .../src/tabular/SupabaseTabularStorage.ts | 108 +++++---- packages/storage/src/util/IndexedDbTable.ts | 21 +- packages/storage/src/vector/README.md | 14 +- packages/test/src/test/rag/Document.test.ts | 4 +- .../src/test/rag/DocumentRepository.test.ts | 107 ++++----- packages/test/src/test/rag/EndToEnd.test.ts | 15 +- .../test/src/test/rag/RagWorkflow.test.ts | 8 +- .../InMemoryTabularRepository.test.ts | 18 ++ .../IndexedDbTabularRepository.test.ts | 20 ++ .../PostgresTabularRepository.test.ts | 22 ++ .../SqliteTabularRepository.test.ts | 22 ++ .../genericTabularRepositoryTests.ts | 179 ++++++++++++++ packages/test/src/test/util/Document.test.ts | 4 +- packages/util/src/json-schema/JsonSchema.ts | 1 + 37 files changed, 1467 insertions(+), 555 deletions(-) delete mode 100644 packages/dataset/src/document/DocumentRepository.ts diff --git a/TODO.md b/TODO.md index bab50c98..9c5e90eb 100644 --- a/TODO.md +++ b/TODO.md @@ -10,13 +10,8 @@ TODO.md - [x] Documents dataset (mabye rename to DocumentDataset) - [ ] Chunks Package (or part of DocumentDataset?) - [ ] Move Model repository to datasets package. -- [ ] Chunk Repository - - [ ] Add to packages/tasks or packages/ai - - [ ] Model like Model repository (although that just has one) - - [ ] Model even closer to Document repositories - [ ] Chunks and nodes are not always the same. - [ ] And we may need to save the chunk's node path. Or paths? or document range? Standard metadata? -- [ ] Use Repository to always envelope the storage operations (for transactions, dealing with IDs, etc). - [ ] Instead of passing doc_id around, pass a document key that is unknonwn (string or object) - [ ] Get a better model for question answering. @@ -31,3 +26,16 @@ TODO.md - [ ] fix image transferables onnx-community/ModernBERT-finetuned-squad-ONNX - summarization + +- [x] Auto-generated primary keys for TabularStorage + - [x] Schema annotation with `x-auto-generated: true` + - [x] Type system with `InsertEntity` for optional auto-generated keys + - [x] Support for autoincrement (integer) and UUID (string) strategies + - [x] Configurable client-provided keys: "never", "if-missing", "always" + - [x] Implementations: InMemory, SQLite, Postgres, Supabase, IndexedDB, FsFolder + - [x] Comprehensive test suite (342 tests pass) + - [x] Documentation updated + +Rework the Document Dataset. 
Currently there is a Document storage of tabular storage type, and that should be registered as a "dataset:document:source" meaning the source material in node format. And there is already a "dataset:document-chunk" for the chunk/vector storage which should be registered as a "dataset:document:chunk" with a well defined metadata schema. The two combined should be registered as a "dataset:document" which is the complete document with its source and all its chunks and metadata. This is for convenience but not used by tasks or ai tasks. + +The sqlitevectorstorage currently does not use a built in vector search. Use @mceachen/sqlite-vec for sqlite storage vector indexing. diff --git a/packages/ai/src/task/ChunkVectorUpsertTask.ts b/packages/ai/src/task/ChunkVectorUpsertTask.ts index 880c89a3..cbe69869 100644 --- a/packages/ai/src/task/ChunkVectorUpsertTask.ts +++ b/packages/ai/src/task/ChunkVectorUpsertTask.ts @@ -118,47 +118,50 @@ export class ChunkVectorUpsertTask extends Task< await context.updateProgress(1, "Upserting vectors"); - const chunk_ids: string[] = []; - // Bulk upsert if multiple items if (vectorArray.length > 1) { if (vectorArray.length !== metadataArray.length) { throw new Error("Mismatch: vectors and metadata arrays must have the same length"); } const entities = vectorArray.map((vector, i) => { - const chunk_id = `${doc_id}_${i}`; const metadataItem = metadataArray[i]; - chunk_ids.push(chunk_id); return { - chunk_id, doc_id, vector, metadata: metadataItem, }; }); - await repo.putBulk(entities); + const results = await repo.putBulk(entities); + const chunk_ids = results.map((r) => r.chunk_id); + return { + doc_id, + chunk_ids, + count: chunk_ids.length, + }; } else if (vectorArray.length === 1) { // Single upsert - const chunk_id = `${doc_id}_0`; const metadataItem = metadataArray[0]; - chunk_ids.push(chunk_id); - await repo.put({ - chunk_id, + const result = await repo.put({ doc_id, vector: vectorArray[0], metadata: metadataItem, }); + return { + doc_id, + chunk_ids: [result.chunk_id], + count: 1, + }; } return { doc_id, - chunk_ids, - count: chunk_ids.length, + chunk_ids: [], + count: 0, }; } } -export const vectorStoreUpsert = ( +export const chunkVectorUpsert = ( input: VectorStoreUpsertTaskInput, config?: JobQueueTaskConfig ) => { @@ -167,7 +170,7 @@ export const vectorStoreUpsert = ( declare module "@workglow/task-graph" { interface Workflow { - vectorStoreUpsert: CreateWorkflow< + chunkVectorUpsert: CreateWorkflow< VectorStoreUpsertTaskInput, VectorStoreUpsertTaskOutput, JobQueueTaskConfig @@ -175,4 +178,4 @@ declare module "@workglow/task-graph" { } } -Workflow.prototype.vectorStoreUpsert = CreateWorkflow(ChunkVectorUpsertTask); +Workflow.prototype.chunkVectorUpsert = CreateWorkflow(ChunkVectorUpsertTask); diff --git a/packages/ai/src/task/HierarchyJoinTask.ts b/packages/ai/src/task/HierarchyJoinTask.ts index 9d13f0a6..87d693f1 100644 --- a/packages/ai/src/task/HierarchyJoinTask.ts +++ b/packages/ai/src/task/HierarchyJoinTask.ts @@ -7,8 +7,9 @@ import { ChunkMetadataArraySchema, EnrichedChunkMetadataArraySchema, + TypeDocumentDataset, type ChunkMetadata, - type DocumentRepository, + type DocumentDataset, } from "@workglow/dataset"; import { CreateWorkflow, @@ -22,10 +23,10 @@ import { DataPortSchema, FromSchema } from "@workglow/util"; const inputSchema = { type: "object", properties: { - documentRepository: { - title: "Document Repository", - description: "The document repository to query for hierarchy", - }, + documents: TypeDocumentDataset({ + title: "Documents", + 
description: "The documents dataset to query for hierarchy", + }), chunks: { type: "array", items: { type: "string" }, @@ -58,7 +59,7 @@ const inputSchema = { default: true, }, }, - required: ["documentRepository", "chunks", "ids", "metadata", "scores"], + required: ["documents", "chunks", "ids", "metadata", "scores"], additionalProperties: false, } as const satisfies DataPortSchema; @@ -125,7 +126,7 @@ export class HierarchyJoinTask extends Task< context: IExecuteContext ): Promise { const { - documentRepository, + documents, chunks, ids, metadata, @@ -134,7 +135,7 @@ export class HierarchyJoinTask extends Task< includeEntities = true, } = input; - const repo = documentRepository as DocumentRepository; + const repo = documents as DocumentDataset; const enrichedMetadata: any[] = []; for (let i = 0; i < ids.length; i++) { diff --git a/packages/dataset/src/common.ts b/packages/dataset/src/common.ts index a07aafe3..4152fe23 100644 --- a/packages/dataset/src/common.ts +++ b/packages/dataset/src/common.ts @@ -10,7 +10,6 @@ export * from "./document/Document"; export * from "./document/DocumentDataset"; export * from "./document/DocumentDatasetRegistry"; export * from "./document/DocumentNode"; -export * from "./document/DocumentRepository"; export * from "./document/DocumentSchema"; export * from "./document/DocumentStorageSchema"; export * from "./document/StructuralParser"; diff --git a/packages/dataset/src/document-chunk/DocumentChunkDataset.ts b/packages/dataset/src/document-chunk/DocumentChunkDataset.ts index c99da309..5e6bf09f 100644 --- a/packages/dataset/src/document-chunk/DocumentChunkDataset.ts +++ b/packages/dataset/src/document-chunk/DocumentChunkDataset.ts @@ -6,7 +6,12 @@ import type { VectorSearchOptions } from "@workglow/storage"; import type { TypedArray } from "@workglow/util"; -import type { DocumentChunk, DocumentChunkStorage } from "./DocumentChunkSchema"; +import type { + DocumentChunk, + DocumentChunkKey, + DocumentChunkStorage, + InsertDocumentChunk, +} from "./DocumentChunkSchema"; /** * Document Chunk Dataset @@ -32,14 +37,14 @@ export class DocumentChunkDataset { /** * Store a document chunk */ - async put(chunk: DocumentChunk): Promise { + async put(chunk: InsertDocumentChunk): Promise { return this.storage.put(chunk); } /** * Store multiple document chunks */ - async putBulk(chunks: DocumentChunk[]): Promise { + async putBulk(chunks: InsertDocumentChunk[]): Promise { return this.storage.putBulk(chunks); } @@ -47,14 +52,16 @@ export class DocumentChunkDataset { * Get a document chunk by ID */ async get(chunk_id: string): Promise { - return this.storage.get({ chunk_id } as any); + const key: DocumentChunkKey = { chunk_id }; + return this.storage.get(key); } /** * Delete a document chunk */ async delete(chunk_id: string): Promise { - return this.storage.delete({ chunk_id } as any); + const key: DocumentChunkKey = { chunk_id }; + return this.storage.delete(key); } /** diff --git a/packages/dataset/src/document-chunk/DocumentChunkSchema.ts b/packages/dataset/src/document-chunk/DocumentChunkSchema.ts index c490f5c5..e59b72b3 100644 --- a/packages/dataset/src/document-chunk/DocumentChunkSchema.ts +++ b/packages/dataset/src/document-chunk/DocumentChunkSchema.ts @@ -13,7 +13,7 @@ import { TypedArraySchema, type DataPortSchemaObject, type TypedArray } from "@w export const DocumentChunkSchema = { type: "object", properties: { - chunk_id: { type: "string" }, + chunk_id: { type: "string", "x-auto-generated": true }, doc_id: { type: "string" }, vector: TypedArraySchema(), 
metadata: { type: "object", format: "metadata", additionalProperties: true }, @@ -35,8 +35,23 @@ export interface DocumentChunk< metadata: Metadata; } +/** + * Type for inserting document chunks - chunk_id is optional (auto-generated) + */ +export type InsertDocumentChunk< + Metadata extends Record = Record, + Vector extends TypedArray = TypedArray, +> = Omit, "chunk_id"> & + Partial, "chunk_id">>; + +/** + * Type for the primary key of document chunks + */ +export type DocumentChunkKey = { chunk_id: string }; + export type DocumentChunkStorage = IVectorStorage< Record, typeof DocumentChunkSchema, - DocumentChunk + DocumentChunk, + DocumentChunkPrimaryKey >; diff --git a/packages/dataset/src/document-chunk/README.md b/packages/dataset/src/document-chunk/README.md index b2d2ef04..65725148 100644 --- a/packages/dataset/src/document-chunk/README.md +++ b/packages/dataset/src/document-chunk/README.md @@ -307,13 +307,13 @@ Quantized vectors reduce storage and can improve performance: - **Cons:** Requires PostgreSQL server and pgvector extension - **Setup:** `CREATE EXTENSION vector;` -## Integration with DocumentRepository +## Integration with DocumentDataset -Document chunk storage works alongside `DocumentRepository` for hierarchical document management: +Document chunk storage works alongside `DocumentDataset` for hierarchical document management: ```typescript import { - DocumentRepository, + DocumentDataset, DocumentStorageSchema, DocumentChunkSchema, DocumentChunkPrimaryKey, @@ -332,14 +332,14 @@ const vectorStorage = new InMemoryVectorStorage( ); await vectorStorage.setupDatabase(); -// Create document repository with both storages -const docRepo = new DocumentRepository(tabularStorage, vectorStorage); +// Create document dataset with both storages +const docDataset = new DocumentDataset(tabularStorage, vectorStorage); // Store document structure in tabular, chunks in vector -await docRepo.upsert(document); +await docDataset.upsert(document); // Search chunks by vector similarity -const results = await docRepo.search(queryVector, { topK: 5 }); +const results = await docDataset.search(queryVector, { topK: 5 }); ``` ### Chunk Metadata for Hierarchical Documents diff --git a/packages/dataset/src/document/Document.ts b/packages/dataset/src/document/Document.ts index d0351f4d..4edecbfa 100644 --- a/packages/dataset/src/document/Document.ts +++ b/packages/dataset/src/document/Document.ts @@ -15,16 +15,16 @@ import type { ChunkNode, DocumentMetadata, DocumentNode } from "./DocumentSchema * - Separate persistence for document structure vs vectors */ export class Document { - public readonly doc_id: string; + public doc_id: string | undefined; public readonly metadata: DocumentMetadata; public readonly root: DocumentNode; private chunks: ChunkNode[]; constructor( - doc_id: string, root: DocumentNode, metadata: DocumentMetadata, - chunks: ChunkNode[] = [] + chunks: ChunkNode[] = [], + doc_id?: string ) { this.doc_id = doc_id; this.root = root; @@ -46,6 +46,13 @@ export class Document { return this.chunks; } + /** + * Set the document ID + */ + setDocId(doc_id: string): void { + this.doc_id = doc_id; + } + /** * Find chunks by nodeId */ @@ -57,13 +64,11 @@ export class Document { * Serialize to JSON */ toJSON(): { - doc_id: string; metadata: DocumentMetadata; root: DocumentNode; chunks: ChunkNode[]; } { return { - doc_id: this.doc_id, metadata: this.metadata, root: this.root, chunks: this.chunks, @@ -73,9 +78,9 @@ export class Document { /** * Deserialize from JSON */ - static fromJSON(json: 
string): Document { + static fromJSON(json: string, doc_id?: string): Document { const obj = JSON.parse(json); - const doc = new Document(obj.doc_id, obj.root, obj.metadata, obj.chunks); - return doc; + return new Document(obj.root, obj.metadata, obj.chunks, doc_id); } + } diff --git a/packages/dataset/src/document/DocumentDataset.ts b/packages/dataset/src/document/DocumentDataset.ts index 763dbb44..eb4ceac9 100644 --- a/packages/dataset/src/document/DocumentDataset.ts +++ b/packages/dataset/src/document/DocumentDataset.ts @@ -9,7 +9,11 @@ import type { TypedArray } from "@workglow/util"; import type { DocumentChunk, DocumentChunkStorage } from "../document-chunk/DocumentChunkSchema"; import { Document } from "./Document"; import { ChunkNode, DocumentNode } from "./DocumentSchema"; -import { DocumentStorageEntity, DocumentTabularStorage } from "./DocumentStorageSchema"; +import { + DocumentStorageEntity, + DocumentTabularStorage, + InsertDocumentStorageEntity, +} from "./DocumentStorageSchema"; /** * Document dataset that uses TabularStorage for document persistence and VectorStorage for chunk persistence and similarity search. @@ -44,13 +48,22 @@ export class DocumentDataset { /** * Upsert a document + * @returns The document with the generated doc_id if it was auto-generated */ - async upsert(document: Document): Promise { - const serialized = JSON.stringify(document.toJSON ? document.toJSON() : document); - await this.tabularStorage.put({ + async upsert(document: Document): Promise { + const serialized = JSON.stringify(document.toJSON()); + + const insertEntity: InsertDocumentStorageEntity = { doc_id: document.doc_id, data: serialized, - }); + }; + const entity = await this.tabularStorage.put(insertEntity); + + // If doc_id was auto-generated, return document with the generated ID + if (document.doc_id !== entity.doc_id) { + document.setDocId(entity.doc_id); + } + return document; } /** @@ -61,7 +74,7 @@ export class DocumentDataset { if (!entity) { return undefined; } - return Document.fromJSON(entity.data); + return Document.fromJSON(entity.data, entity.doc_id); } /** diff --git a/packages/dataset/src/document/DocumentRepository.ts b/packages/dataset/src/document/DocumentRepository.ts deleted file mode 100644 index b26a0620..00000000 --- a/packages/dataset/src/document/DocumentRepository.ts +++ /dev/null @@ -1,218 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import type { AnyVectorStorage, ITabularStorage, VectorSearchOptions } from "@workglow/storage"; -import type { TypedArray } from "@workglow/util"; -import type { DocumentChunk } from "../document-chunk/DocumentChunkSchema"; -import { Document } from "./Document"; -import { ChunkNode, DocumentNode } from "./DocumentSchema"; -import { - DocumentStorageEntity, - DocumentStorageKey, - DocumentStorageSchema, -} from "./DocumentStorageSchema"; -/** - * Document repository that uses TabularStorage for persistence and VectorStorage for search. - * This is a unified implementation that composes storage backends rather than using - * inheritance/interface patterns. - */ -export class DocumentRepository { - private tabularStorage: ITabularStorage< - DocumentStorageSchema, - DocumentStorageKey, - DocumentStorageEntity - >; - private vectorStorage?: AnyVectorStorage; - - /** - * Creates a new DocumentRepository instance. 
- * - * @param tabularStorage - Pre-initialized tabular storage for document persistence - * @param vectorStorage - Pre-initialized vector storage for chunk similarity search - * - * @example - * ```typescript - * const tabularStorage = new InMemoryTabularStorage(DocumentStorageSchema, ["doc_id"]); - * await tabularStorage.setupDatabase(); - * - * const vectorStorage = new InMemoryVectorStorage(); - * await vectorStorage.setupDatabase(); - * - * const docRepo = new DocumentRepository(tabularStorage, vectorStorage); - * ``` - */ - constructor( - tabularStorage: ITabularStorage< - typeof DocumentStorageSchema, - ["doc_id"], - DocumentStorageEntity - >, - vectorStorage?: AnyVectorStorage - ) { - this.tabularStorage = tabularStorage; - this.vectorStorage = vectorStorage; - } - - /** - * Upsert a document - */ - async upsert(document: Document): Promise { - const serialized = JSON.stringify(document.toJSON ? document.toJSON() : document); - await this.tabularStorage.put({ - doc_id: document.doc_id, - data: serialized, - }); - } - - /** - * Get a document by ID - */ - async get(doc_id: string): Promise { - const entity = await this.tabularStorage.get({ doc_id: doc_id }); - if (!entity) { - return undefined; - } - return Document.fromJSON(entity.data); - } - - /** - * Delete a document - */ - async delete(doc_id: string): Promise { - await this.tabularStorage.delete({ doc_id: doc_id }); - } - - /** - * Get a specific node by ID - */ - async getNode(doc_id: string, nodeId: string): Promise { - const doc = await this.get(doc_id); - if (!doc) { - return undefined; - } - - // Traverse tree to find node - const traverse = (node: any): any => { - if (node.nodeId === nodeId) { - return node; - } - if (node.children && Array.isArray(node.children)) { - for (const child of node.children) { - const found = traverse(child); - if (found) return found; - } - } - return undefined; - }; - - return traverse(doc.root); - } - - /** - * Get ancestors of a node (from root to node) - */ - async getAncestors(doc_id: string, nodeId: string): Promise { - const doc = await this.get(doc_id); - if (!doc) { - return []; - } - - // Get path from root to target node - const path: string[] = []; - const findPath = (node: any): boolean => { - path.push(node.nodeId); - if (node.nodeId === nodeId) { - return true; - } - if (node.children && Array.isArray(node.children)) { - for (const child of node.children) { - if (findPath(child)) { - return true; - } - } - } - path.pop(); - return false; - }; - - if (!findPath(doc.root)) { - return []; - } - - // Collect nodes along the path - const ancestors: any[] = []; - let currentNode: any = doc.root; - ancestors.push(currentNode); - - for (let i = 1; i < path.length; i++) { - const targetId = path[i]; - if (currentNode.children && Array.isArray(currentNode.children)) { - const found = currentNode.children.find((child: any) => child.nodeId === targetId); - if (found) { - currentNode = found; - ancestors.push(currentNode); - } else { - break; - } - } else { - break; - } - } - - return ancestors; - } - - /** - * Get chunks for a document - */ - async getChunks(doc_id: string): Promise { - const doc = await this.get(doc_id); - if (!doc) { - return []; - } - return doc.getChunks(); - } - - /** - * Find chunks that contain a specific nodeId in their path - */ - async findChunksByNodeId(doc_id: string, nodeId: string): Promise { - const doc = await this.get(doc_id); - if (!doc) { - return []; - } - if (doc.findChunksByNodeId) { - return doc.findChunksByNodeId(nodeId); - } - // Fallback 
implementation - const chunks = doc.getChunks(); - return chunks.filter((chunk) => chunk.nodePath && chunk.nodePath.includes(nodeId)); - } - - /** - * List all document IDs - */ - async list(): Promise { - const entities = await this.tabularStorage.getAll(); - if (!entities) { - return []; - } - return entities.map((e) => e.doc_id); - } - - /** - * Search for similar vectors using the vector storage - * @param query - Query vector to search for - * @param options - Search options (topK, filter, scoreThreshold) - * @returns Array of search results sorted by similarity - */ - async search( - query: TypedArray, - options?: VectorSearchOptions> - ): Promise, TypedArray> & { score: number }>> { - return (this.vectorStorage?.similaritySearch(query, options) || []) as any; - } -} diff --git a/packages/dataset/src/document/DocumentStorageSchema.ts b/packages/dataset/src/document/DocumentStorageSchema.ts index d65eb454..73598491 100644 --- a/packages/dataset/src/document/DocumentStorageSchema.ts +++ b/packages/dataset/src/document/DocumentStorageSchema.ts @@ -19,6 +19,7 @@ export const DocumentStorageSchema = { properties: { doc_id: { type: "string", + "x-auto-generated": true, title: "Document ID", description: "Unique identifier for the document", }, @@ -43,6 +44,12 @@ export type DocumentStorageKey = typeof DocumentStorageKey; export type DocumentStorageEntity = FromSchema; +/** + * Type for inserting documents - doc_id is optional (auto-generated) + */ +export type InsertDocumentStorageEntity = Omit & + Partial>; + export type DocumentTabularStorage = ITabularStorage< typeof DocumentStorageSchema, DocumentStorageKey, diff --git a/packages/storage/README.md b/packages/storage/README.md index f65c45de..a22cecb8 100644 --- a/packages/storage/README.md +++ b/packages/storage/README.md @@ -257,6 +257,47 @@ const userRepo = new InMemoryTabularStorage( ); ``` +#### Auto-Generated Primary Keys + +TabularStorage supports automatic ID generation by marking schema properties with `x-auto-generated: true`: + +```typescript +const UserSchema = { + type: "object", + properties: { + id: { type: "integer", "x-auto-generated": true }, // Auto-increment + name: { type: "string" }, + email: { type: "string" }, + }, + required: ["id", "name", "email"], +} as const; + +const storage = new PostgresTabularStorage(db, "users", UserSchema, ["id"] as const); + +// Insert without providing ID - database generates it +const user = await storage.put({ name: "Alice", email: "alice@example.com" }); +console.log(user.id); // 1 (auto-generated) +``` + +**Generation Strategy** (inferred from type): +- `type: "integer"` → Auto-increment (SERIAL, INTEGER PRIMARY KEY, counter) +- `type: "string"` → UUID via `uuid4()` from `@workglow/util` + +**Configuration Options:** + +```typescript +new PostgresTabularStorage( + db, "users", UserSchema, ["id"], [], + { clientProvidedKeys: "if-missing" } // "never" | "if-missing" | "always" +); +``` + +- `"if-missing"` (default): Use client value if provided, generate otherwise +- `"never"`: Always generate (most secure) +- `"always"`: Require client value (for testing) + +See [Tabular Storage README](./src/tabular/README.md) for detailed documentation. 
+ #### CRUD Operations ```typescript diff --git a/packages/storage/src/tabular/BaseSqlTabularStorage.ts b/packages/storage/src/tabular/BaseSqlTabularStorage.ts index 1f21bfdd..2a3a89a1 100644 --- a/packages/storage/src/tabular/BaseSqlTabularStorage.ts +++ b/packages/storage/src/tabular/BaseSqlTabularStorage.ts @@ -10,8 +10,13 @@ import { JsonSchema, TypedArraySchemaOptions, } from "@workglow/util"; -import { BaseTabularStorage } from "./BaseTabularStorage"; -import { SimplifyPrimaryKey, ValueOptionType } from "./ITabularStorage"; +import { BaseTabularStorage, ClientProvidedKeysOption } from "./BaseTabularStorage"; +import { + AutoGeneratedKeys, + InsertEntity, + SimplifyPrimaryKey, + ValueOptionType, +} from "./ITabularStorage"; // BaseTabularStorage is a tabular store that uses SQLite and Postgres use as common code @@ -29,7 +34,11 @@ export abstract class BaseSqlTabularStorage< Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, Value = Omit, -> extends BaseTabularStorage { + InsertType extends InsertEntity> = InsertEntity< + Entity, + AutoGeneratedKeys + >, +> extends BaseTabularStorage { /** * Creates a new instance of BaseSqlTabularStorage * @param table - The name of the database table to use for storage @@ -37,14 +46,16 @@ export abstract class BaseSqlTabularStorage< * @param primaryKeyNames - Array of property names that form the primary key * @param indexes - Array of columns or column arrays to make searchable. Each string or single column creates a single-column index, * while each array creates a compound index with columns in the specified order. + * @param clientProvidedKeys - How to handle client-provided values for auto-generated keys */ constructor( protected readonly table: string = "tabular_store", schema: Schema, primaryKeyNames: PrimaryKeyNames, - indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [] + indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [], + clientProvidedKeys: ClientProvidedKeysOption = "if-missing" ) { - super(schema, primaryKeyNames, indexes); + super(schema, primaryKeyNames, indexes, clientProvidedKeys); this.validateTableAndSchema(); } diff --git a/packages/storage/src/tabular/BaseTabularStorage.ts b/packages/storage/src/tabular/BaseTabularStorage.ts index 16beec5b..582ed370 100644 --- a/packages/storage/src/tabular/BaseTabularStorage.ts +++ b/packages/storage/src/tabular/BaseTabularStorage.ts @@ -14,7 +14,9 @@ import { } from "@workglow/util"; import { AnyTabularStorage, + AutoGeneratedKeys, DeleteSearchCriteria, + InsertEntity, ITabularStorage, SimplifyPrimaryKey, TabularChangePayload, @@ -30,6 +32,16 @@ export const TABULAR_REPOSITORY = createServiceToken( "storage.tabularRepository" ); +/** + * Options for controlling how client-provided values for auto-generated keys are handled + */ +export type ClientProvidedKeysOption = "never" | "if-missing" | "always"; + +/** + * Generation strategy for auto-generated keys + */ +export type KeyGenerationStrategy = "autoincrement" | "uuid"; + /** * Abstract base class for tabular storage repositories. 
* Provides functionality for storing and retrieving data with typed @@ -46,7 +58,8 @@ export abstract class BaseTabularStorage< Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, Value = Omit, -> implements ITabularStorage { + InsertType = InsertEntity>, +> implements ITabularStorage { /** Event emitter for repository events */ protected events = new EventEmitter>(); @@ -54,18 +67,28 @@ export abstract class BaseTabularStorage< protected primaryKeySchema: DataPortSchemaObject; protected valueSchema: DataPortSchemaObject; + /** Name of the auto-generated key column (only first primary key column can be auto-generated) */ + protected autoGeneratedKeyName: keyof Entity | null = null; + /** Strategy for generating the auto-generated key */ + protected autoGeneratedKeyStrategy: KeyGenerationStrategy | null = null; + /** How to handle client-provided values for auto-generated keys */ + protected clientProvidedKeys: ClientProvidedKeysOption; + /** * Creates a new BaseTabularStorage instance * @param schema - Schema defining the structure of the entity * @param primaryKeyNames - Array of property names that form the primary key * @param indexes - Array of columns or column arrays to make searchable. Each string or single column creates a single-column index, * while each array creates a compound index with columns in the specified order. + * @param clientProvidedKeys - How to handle client-provided values for auto-generated keys */ constructor( protected schema: Schema, protected primaryKeyNames: PrimaryKeyNames, - indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [] + indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [], + clientProvidedKeys: ClientProvidedKeysOption = "if-missing" ) { + this.clientProvidedKeys = clientProvidedKeys; const primaryKeyProps: Record = {}; const valueProps: Record = {}; const primaryKeySet = new Set(primaryKeyNames); @@ -138,6 +161,44 @@ export abstract class BaseTabularStorage< } } } + + // Detect and validate auto-generated keys + // Only the first primary key column can be auto-generated + const autoGeneratedKeys: string[] = []; + for (const key of primaryKeyNames) { + const keyStr = String(key); + const propDef = (schema.properties as any)[keyStr]; + if (propDef && typeof propDef === "object" && "x-auto-generated" in propDef) { + if (propDef["x-auto-generated"] === true) { + autoGeneratedKeys.push(keyStr); + } + } + } + + if (autoGeneratedKeys.length > 1) { + throw new Error( + `Multiple auto-generated keys detected: ${autoGeneratedKeys.join(", ")}. ` + + `Only the first primary key column can be auto-generated.` + ); + } + + if (autoGeneratedKeys.length > 0) { + const autoGenKeyName = autoGeneratedKeys[0]; + const firstPrimaryKey = String(primaryKeyNames[0]); + + if (autoGenKeyName !== firstPrimaryKey) { + throw new Error( + `Auto-generated key "${autoGenKeyName}" must be the first primary key column. 
` + + `Current first primary key is "${firstPrimaryKey}".` + ); + } + + this.autoGeneratedKeyName = autoGenKeyName as keyof Entity; + this.autoGeneratedKeyStrategy = this.determineGenerationStrategy( + autoGenKeyName, + schema.properties[autoGenKeyName as keyof Schema["properties"]] + ); + } } protected filterCompoundKeys( @@ -239,8 +300,8 @@ export abstract class BaseTabularStorage< /** * Core abstract methods that must be implemented by concrete repositories */ - abstract put(value: Entity): Promise; - abstract putBulk(values: Entity[]): Promise; + abstract put(value: InsertType): Promise; + abstract putBulk(values: InsertType[]): Promise; abstract get(key: PrimaryKey): Promise; abstract delete(key: PrimaryKey | Entity): Promise; abstract getAll(): Promise; @@ -400,6 +461,74 @@ export abstract class BaseTabularStorage< return bestMatch; } + /** + * Checks if this storage has an auto-generated key configured + * @returns true if an auto-generated key is configured + */ + protected hasAutoGeneratedKey(): boolean { + return this.autoGeneratedKeyName !== null; + } + + /** + * Checks if a given column name is the auto-generated key + * @param name - Column name to check + * @returns true if the column is the auto-generated key + */ + protected isAutoGeneratedKey(name: string): boolean { + return this.autoGeneratedKeyName !== null && String(this.autoGeneratedKeyName) === name; + } + + /** + * Determines the generation strategy for an auto-generated key based on its type + * @param columnName - Name of the column + * @param typeDef - JSON Schema type definition for the column + * @returns The generation strategy to use + */ + protected determineGenerationStrategy( + columnName: string, + typeDef: any + ): KeyGenerationStrategy { + // Extract the actual type if it's a union with null + let actualType = typeDef; + if (typeDef && typeof typeDef === "object") { + if (typeDef.anyOf && Array.isArray(typeDef.anyOf)) { + actualType = typeDef.anyOf.find((t: any) => t.type !== "null") || typeDef; + } else if (typeDef.oneOf && Array.isArray(typeDef.oneOf)) { + actualType = typeDef.oneOf.find((t: any) => t.type !== "null") || typeDef; + } + } + + if (typeof actualType !== "object") { + return "uuid"; + } + + // Integer types use autoincrement + if (actualType.type === "integer") { + return "autoincrement"; + } + + // Default to UUID for strings and other types + return "uuid"; + } + + /** + * Generates a key value for client-side key generation + * Override in storage classes that generate keys client-side (InMemory, IndexedDB for UUIDs) + * SQL-based storages typically generate keys server-side + * @param columnName - Name of the column to generate a key for + * @param strategy - The generation strategy to use + * @returns The generated key value + */ + protected generateKeyValue( + columnName: string, + strategy: KeyGenerationStrategy + ): Promise | string | number { + throw new Error( + `generateKeyValue not implemented for ${this.constructor.name}. ` + + `Column: ${columnName}, Strategy: ${strategy}` + ); + } + /** * Sets up the database/storage for the repository. * Must be called before using any other methods (except for in-memory implementations). 
diff --git a/packages/storage/src/tabular/CachedTabularStorage.ts b/packages/storage/src/tabular/CachedTabularStorage.ts index 199a516e..429951b6 100644 --- a/packages/storage/src/tabular/CachedTabularStorage.ts +++ b/packages/storage/src/tabular/CachedTabularStorage.ts @@ -10,11 +10,13 @@ import { FromSchema, TypedArraySchemaOptions, } from "@workglow/util"; -import { BaseTabularStorage } from "./BaseTabularStorage"; +import { BaseTabularStorage, ClientProvidedKeysOption } from "./BaseTabularStorage"; import { InMemoryTabularStorage } from "./InMemoryTabularStorage"; import { AnyTabularStorage, + AutoGeneratedKeys, DeleteSearchCriteria, + InsertEntity, ITabularStorage, SimplifyPrimaryKey, TabularSubscribeOptions, @@ -38,7 +40,12 @@ export class CachedTabularStorage< // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, -> extends BaseTabularStorage { + Value = Omit, + InsertType extends InsertEntity> = InsertEntity< + Entity, + AutoGeneratedKeys + >, +> extends BaseTabularStorage { public readonly cache: ITabularStorage; private durable: ITabularStorage; private cacheInitialized = false; @@ -52,13 +59,15 @@ export class CachedTabularStorage< * @param primaryKeyNames - Array of property names that form the primary key * @param indexes - Array of columns or column arrays to make searchable. Each string or single column creates a single-column index, * while each array creates a compound index with columns in the specified order. + * @param clientProvidedKeys - How to handle client-provided values for auto-generated keys */ constructor( durable: ITabularStorage, cache?: ITabularStorage, schema?: Schema, primaryKeyNames?: PrimaryKeyNames, - indexes?: readonly (keyof Entity | readonly (keyof Entity)[])[] + indexes?: readonly (keyof Entity | readonly (keyof Entity)[])[], + clientProvidedKeys: ClientProvidedKeysOption = "if-missing" ) { // Extract schema and primaryKeyNames from durable repository if not provided // Note: This is a limitation - we can't always extract these from an interface @@ -69,7 +78,7 @@ export class CachedTabularStorage< ); } - super(schema, primaryKeyNames, indexes || []); + super(schema, primaryKeyNames, indexes || [], clientProvidedKeys); this.durable = durable; // Create cache if not provided @@ -79,7 +88,8 @@ export class CachedTabularStorage< this.cache = new InMemoryTabularStorage( schema, primaryKeyNames, - indexes || [] + indexes || [], + clientProvidedKeys ); } @@ -133,7 +143,7 @@ export class CachedTabularStorage< * @returns The stored entity * @emits 'put' event with the stored entity when successful */ - async put(value: Entity): Promise { + async put(value: InsertType): Promise { await this.initializeCache(); // Write to durable first (source of truth) @@ -151,7 +161,7 @@ export class CachedTabularStorage< * @returns Array of stored entities * @emits 'put' event for each value stored */ - async putBulk(values: Entity[]): Promise { + async putBulk(values: InsertType[]): Promise { await this.initializeCache(); // Write to durable first (source of truth) diff --git a/packages/storage/src/tabular/FsFolderTabularStorage.ts b/packages/storage/src/tabular/FsFolderTabularStorage.ts index f3b531f9..a179b0dc 100644 --- a/packages/storage/src/tabular/FsFolderTabularStorage.ts +++ b/packages/storage/src/tabular/FsFolderTabularStorage.ts @@ -11,14 +11,17 @@ import { makeFingerprint, sleep, TypedArraySchemaOptions, + uuid4, } from "@workglow/util"; import { mkdir, readdir, readFile, rm, writeFile } from "node:fs/promises"; import path from "node:path"; 
import { PollingSubscriptionManager } from "../util/PollingSubscriptionManager"; -import { BaseTabularStorage } from "./BaseTabularStorage"; +import { BaseTabularStorage, ClientProvidedKeysOption, KeyGenerationStrategy } from "./BaseTabularStorage"; import { AnyTabularStorage, + AutoGeneratedKeys, DeleteSearchCriteria, + InsertEntity, SimplifyPrimaryKey, TabularChangePayload, TabularSubscribeOptions, @@ -41,8 +44,12 @@ export class FsFolderTabularStorage< // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, -> extends BaseTabularStorage { + Value = Omit, + InsertType = InsertEntity>, +> extends BaseTabularStorage { private folderPath: string; + /** Counter for auto-incrementing integer keys */ + private autoIncrementCounter = 0; /** Shared polling subscription manager */ private pollingManager: PollingSubscriptionManager< Entity, @@ -57,14 +64,16 @@ export class FsFolderTabularStorage< * @param schema - Schema defining the structure of the entity * @param primaryKeyNames - Array of property names that form the primary key * @param indexes - Note: indexes are not supported in this implementation. + * @param clientProvidedKeys - How to handle client-provided values for auto-generated keys */ constructor( folderPath: string, schema: Schema, primaryKeyNames: PrimaryKeyNames, - indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [] + indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [], + clientProvidedKeys: ClientProvidedKeysOption = "if-missing" ) { - super(schema, primaryKeyNames, indexes); + super(schema, primaryKeyNames, indexes, clientProvidedKeys); this.folderPath = path.join(folderPath); } @@ -85,17 +94,60 @@ export class FsFolderTabularStorage< } } + /** + * Generates a key value for auto-generated keys + * @param columnName - Name of the column to generate a key for + * @param strategy - The generation strategy to use + * @returns The generated key value + */ + protected generateKeyValue(columnName: string, strategy: KeyGenerationStrategy): string | number { + if (strategy === "autoincrement") { + return ++this.autoIncrementCounter; + } else { + return uuid4(); + } + } + /** * Stores a row in the repository - * @param entity - The entity to store + * @param entity - The entity to store (may be missing auto-generated keys) * @returns The stored entity * @emits 'put' event when successful */ - async put(entity: Entity): Promise { + async put(entity: InsertType): Promise { + let entityToStore = entity as unknown as Entity; + + // Handle auto-generated keys + if (this.hasAutoGeneratedKey() && this.autoGeneratedKeyName) { + const keyName = this.autoGeneratedKeyName as string; + const clientProvidedValue = (entity as any)[keyName]; + const hasClientValue = clientProvidedValue !== undefined && clientProvidedValue !== null; + + let shouldGenerate = false; + if (this.clientProvidedKeys === "never") { + shouldGenerate = true; + } else if (this.clientProvidedKeys === "always") { + if (!hasClientValue) { + throw new Error( + `Auto-generated key "${keyName}" is required when clientProvidedKeys is "always"` + ); + } + shouldGenerate = false; + } else { + // "if-missing" + shouldGenerate = !hasClientValue; + } + + if (shouldGenerate) { + const generatedValue = this.generateKeyValue(keyName, this.autoGeneratedKeyStrategy!); + entityToStore = { ...entity, [keyName]: generatedValue } as Entity; + } + } + await this.setupDirectory(); - const filePath = await this.getFilePath(entity); + const filePath = await this.getFilePath(entityToStore); try { - await 
writeFile(filePath, JSON.stringify(entity)); + await writeFile(filePath, JSON.stringify(entityToStore)); } catch (error) { try { // CI system sometimes has issues temporarily @@ -105,17 +157,17 @@ export class FsFolderTabularStorage< console.error("Error writing file", filePath, error); } } - this.events.emit("put", entity); - return entity; + this.events.emit("put", entityToStore); + return entityToStore; } /** * Stores multiple rows in the repository in a bulk operation - * @param entities - Array of entities to store + * @param entities - Array of entities to store (may be missing auto-generated keys) * @returns Array of stored entities * @emits 'put' event for each entity stored */ - async putBulk(entities: Entity[]): Promise { + async putBulk(entities: InsertType[]): Promise { await this.setupDirectory(); return await Promise.all(entities.map(async (entity) => this.put(entity))); } diff --git a/packages/storage/src/tabular/ITabularStorage.ts b/packages/storage/src/tabular/ITabularStorage.ts index 126df293..0665d9dd 100644 --- a/packages/storage/src/tabular/ITabularStorage.ts +++ b/packages/storage/src/tabular/ITabularStorage.ts @@ -124,6 +124,23 @@ export type SimplifyPrimaryKey< KeyName extends ReadonlyArray, > = Entity extends any ? Pick> : never; +/** + * Extracts property names marked as auto-generated from the schema. + * Properties with `x-auto-generated: true` are considered auto-generated. + */ +export type AutoGeneratedKeys = { + [K in keyof Schema["properties"]]: Schema["properties"][K] extends { "x-auto-generated": true } + ? K + : never; +}[keyof Schema["properties"]]; + +/** + * Entity type for insertion - auto-generated keys are optional. + * This allows clients to omit auto-generated keys when inserting entities. + */ +export type InsertEntity = Omit & + Partial>; + /** * Interface defining the contract for tabular storage repositories. 
* Provides a flexible interface for storing and retrieving data with typed @@ -138,10 +155,11 @@ export interface ITabularStorage< // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, + InsertType = InsertEntity>, > { // Core methods - put(value: Entity): Promise; - putBulk(values: Entity[]): Promise; + put(value: InsertType): Promise; + putBulk(values: InsertType[]): Promise; get(key: PrimaryKey): Promise; delete(key: PrimaryKey | Entity): Promise; getAll(): Promise; @@ -212,4 +230,4 @@ export interface ITabularStorage< [Symbol.asyncDispose](): Promise; } -export type AnyTabularStorage = ITabularStorage; +export type AnyTabularStorage = ITabularStorage; diff --git a/packages/storage/src/tabular/InMemoryTabularStorage.ts b/packages/storage/src/tabular/InMemoryTabularStorage.ts index b4057017..b9fc616b 100644 --- a/packages/storage/src/tabular/InMemoryTabularStorage.ts +++ b/packages/storage/src/tabular/InMemoryTabularStorage.ts @@ -10,11 +10,14 @@ import { FromSchema, makeFingerprint, TypedArraySchemaOptions, + uuid4, } from "@workglow/util"; -import { BaseTabularStorage } from "./BaseTabularStorage"; +import { BaseTabularStorage, ClientProvidedKeysOption, KeyGenerationStrategy } from "./BaseTabularStorage"; import { AnyTabularStorage, + AutoGeneratedKeys, DeleteSearchCriteria, + InsertEntity, isSearchCondition, SimplifyPrimaryKey, TabularChangePayload, @@ -38,9 +41,16 @@ export class InMemoryTabularStorage< // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, -> extends BaseTabularStorage { + Value = Omit, + InsertType extends InsertEntity> = InsertEntity< + Entity, + AutoGeneratedKeys + >, +> extends BaseTabularStorage { /** Internal storage using a Map with fingerprint strings as keys */ values = new Map(); + /** Counter for auto-incrementing integer keys */ + private autoIncrementCounter = 0; /** * Creates a new InMemoryTabularStorage instance @@ -48,13 +58,15 @@ export class InMemoryTabularStorage< * @param primaryKeyNames - Array of property names that form the primary key * @param indexes - Array of columns or column arrays to make searchable. Each string or single column creates a single-column index, * while each array creates a compound index with columns in the specified order. 
+ * @param clientProvidedKeys - How to handle client-provided values for auto-generated keys */ constructor( schema: Schema, primaryKeyNames: PrimaryKeyNames, - indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [] + indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [], + clientProvidedKeys: ClientProvidedKeysOption = "if-missing" ) { - super(schema, primaryKeyNames, indexes); + super(schema, primaryKeyNames, indexes, clientProvidedKeys); } /** @@ -64,27 +76,72 @@ export class InMemoryTabularStorage< // No setup needed for in-memory storage } + /** + * Generates a key value for auto-generated keys + * @param columnName - Name of the column to generate a key for + * @param strategy - The generation strategy to use + * @returns The generated key value + */ + protected generateKeyValue(columnName: string, strategy: KeyGenerationStrategy): string | number { + if (strategy === "autoincrement") { + return ++this.autoIncrementCounter; + } else { + return uuid4(); + } + } + /** * Stores a key-value pair in the repository - * @param value - The combined object to store - * @returns The stored entity + * @param value - The combined object to store (may be missing auto-generated keys) + * @returns The stored entity with all keys filled in * @emits 'put' event with the stored entity when successful */ - async put(value: Entity): Promise { - const { key } = this.separateKeyValueFromCombined(value); + async put(value: InsertType): Promise { + let entityToStore = value as unknown as Entity; + + // Handle auto-generated keys + if (this.hasAutoGeneratedKey() && this.autoGeneratedKeyName) { + const keyName = this.autoGeneratedKeyName as string; + const clientProvidedValue = (value as any)[keyName]; + const hasClientValue = clientProvidedValue !== undefined && clientProvidedValue !== null; + + let shouldGenerate = false; + if (this.clientProvidedKeys === "never") { + // Always generate, ignore client value + shouldGenerate = true; + } else if (this.clientProvidedKeys === "always") { + // Always use client value, error if missing + if (!hasClientValue) { + throw new Error( + `Auto-generated key "${keyName}" is required when clientProvidedKeys is "always"` + ); + } + shouldGenerate = false; + } else { + // "if-missing" - generate only if client didn't provide + shouldGenerate = !hasClientValue; + } + + if (shouldGenerate) { + const generatedValue = this.generateKeyValue(keyName, this.autoGeneratedKeyStrategy!); + entityToStore = { ...value, [keyName]: generatedValue } as Entity; + } + } + + const { key } = this.separateKeyValueFromCombined(entityToStore); const id = await makeFingerprint(key); - this.values.set(id, value); - this.events.emit("put", value); - return value; + this.values.set(id, entityToStore); + this.events.emit("put", entityToStore); + return entityToStore; } /** * Stores multiple key-value pairs in the repository in a bulk operation - * @param values - Array of combined objects to store - * @returns Array of stored entities + * @param values - Array of combined objects to store (may be missing auto-generated keys) + * @returns Array of stored entities with all keys filled in * @emits 'put' event for each value stored */ - async putBulk(values: Entity[]): Promise { + async putBulk(values: InsertType[]): Promise { return await Promise.all(values.map(async (value) => this.put(value))); } diff --git a/packages/storage/src/tabular/IndexedDbTabularStorage.ts b/packages/storage/src/tabular/IndexedDbTabularStorage.ts index 428791ba..8605c784 100644 --- 
a/packages/storage/src/tabular/IndexedDbTabularStorage.ts +++ b/packages/storage/src/tabular/IndexedDbTabularStorage.ts @@ -10,6 +10,7 @@ import { FromSchema, makeFingerprint, TypedArraySchemaOptions, + uuid4, } from "@workglow/util"; import { HybridSubscriptionManager } from "../util/HybridSubscriptionManager"; import { @@ -17,10 +18,12 @@ import { ExpectedIndexDefinition, MigrationOptions, } from "../util/IndexedDbTable"; -import { BaseTabularStorage } from "./BaseTabularStorage"; +import { BaseTabularStorage, ClientProvidedKeysOption, KeyGenerationStrategy } from "./BaseTabularStorage"; import { AnyTabularStorage, + AutoGeneratedKeys, DeleteSearchCriteria, + InsertEntity, isSearchCondition, SearchOperator, SimplifyPrimaryKey, @@ -44,7 +47,12 @@ export class IndexedDbTabularStorage< // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, -> extends BaseTabularStorage { + Value = Omit, + InsertType extends InsertEntity> = InsertEntity< + Entity, + AutoGeneratedKeys + >, +> extends BaseTabularStorage { /** Promise that resolves to the IndexedDB database instance */ private db: IDBDatabase | undefined; /** Promise to track ongoing database setup to prevent concurrent setup calls */ @@ -71,6 +79,7 @@ export class IndexedDbTabularStorage< * @param indexes - Array of columns or column arrays to make searchable. Each string or single column creates a single-column index, * while each array creates a compound index with columns in the specified order. * @param migrationOptions - Options for handling database schema migrations + * @param clientProvidedKeys - How to handle client-provided values for auto-generated keys */ constructor( public table: string = "tabular_store", @@ -80,9 +89,10 @@ export class IndexedDbTabularStorage< migrationOptions: MigrationOptions & { readonly useBroadcastChannel?: boolean; readonly backupPollingIntervalMs?: number; - } = {} + } = {}, + clientProvidedKeys: ClientProvidedKeysOption = "if-missing" ) { - super(schema, primaryKeyNames, indexes); + super(schema, primaryKeyNames, indexes, clientProvidedKeys); this.migrationOptions = migrationOptions; this.hybridOptions = { useBroadcastChannel: migrationOptions.useBroadcastChannel ?? true, @@ -149,35 +159,117 @@ export class IndexedDbTabularStorage< const primaryKey = pkColumns.length === 1 ? pkColumns[0] : pkColumns; + // Determine if we should use autoIncrement + // IndexedDB autoIncrement only works with single numeric keys + const useAutoIncrement = + this.hasAutoGeneratedKey() && + this.autoGeneratedKeyStrategy === "autoincrement" && + pkColumns.length === 1; + // Ensure that our table is created/upgraded only if the structure (indexes) has changed. return await ensureIndexedDbTable( this.table, primaryKey, expectedIndexes, - this.migrationOptions + this.migrationOptions, + useAutoIncrement + ); + } + + /** + * Generates a key value for UUID keys + * Integer autoincrement keys are handled by IndexedDB's autoIncrement + * @param columnName - Name of the column to generate a key for + * @param strategy - The generation strategy to use + * @returns The generated key value + */ + protected generateKeyValue(columnName: string, strategy: KeyGenerationStrategy): string | number { + if (strategy === "uuid") { + return uuid4(); + } + // autoincrement is handled by IndexedDB's autoIncrement option + throw new Error( + `IndexedDB autoincrement keys are generated by the database, not client-side. Column: ${columnName}` ); } /** * Stores a row in the repository. - * @param record - The entity to store. 
+ * @param record - The entity to store (may be missing auto-generated keys). * @returns The stored entity * @emits put - Emitted when the value is successfully stored */ - async put(record: Entity): Promise { + async put(record: InsertType): Promise { const db = await this.getDb(); - const { key } = this.separateKeyValueFromCombined(record); + let recordToStore = record as unknown as Entity; + + // Handle auto-generated keys + if (this.hasAutoGeneratedKey() && this.autoGeneratedKeyName) { + const keyName = String(this.autoGeneratedKeyName); + const clientProvidedValue = (record as any)[keyName]; + const hasClientValue = clientProvidedValue !== undefined && clientProvidedValue !== null; + + if (this.autoGeneratedKeyStrategy === "uuid") { + // UUID generation - must be done client-side + let shouldGenerate = false; + if (this.clientProvidedKeys === "never") { + shouldGenerate = true; + } else if (this.clientProvidedKeys === "always") { + if (!hasClientValue) { + throw new Error( + `Auto-generated key "${keyName}" is required when clientProvidedKeys is "always"` + ); + } + shouldGenerate = false; + } else { + // "if-missing" + shouldGenerate = !hasClientValue; + } + + if (shouldGenerate) { + const generatedValue = this.generateKeyValue(keyName, "uuid"); + recordToStore = { ...record, [keyName]: generatedValue } as Entity; + } + } else if (this.autoGeneratedKeyStrategy === "autoincrement") { + // Autoincrement handled by IndexedDB + // If clientProvidedKeys is "always", require the value + if (this.clientProvidedKeys === "always" && !hasClientValue) { + throw new Error( + `Auto-generated key "${keyName}" is required when clientProvidedKeys is "always"` + ); + } + // If clientProvidedKeys is "never", omit the key to let IDB generate + if (this.clientProvidedKeys === "never") { + const { [keyName]: _, ...rest } = record as Record; + recordToStore = rest as Entity; + } + // "if-missing": use client value if provided, omit if not + } + } + // Merge key and value, ensuring all fields are at the root level for indexing return new Promise((resolve, reject) => { const transaction = db.transaction(this.table, "readwrite"); const store = transaction.objectStore(this.table); - const request = store.put(record); + const request = store.put(recordToStore); request.onerror = () => { reject(request.error); }; request.onsuccess = () => { - this.events.emit("put", record); - resolve(record); + // For autoincrement keys, we need to update the record with the generated key + if ( + this.hasAutoGeneratedKey() && + this.autoGeneratedKeyName && + this.autoGeneratedKeyStrategy === "autoincrement" + ) { + const keyName = String(this.autoGeneratedKeyName); + if (recordToStore[keyName as keyof Entity] === undefined) { + // Get the generated key from the request result + recordToStore = { ...recordToStore, [keyName]: request.result } as Entity; + } + } + this.events.emit("put", recordToStore); + resolve(recordToStore); }; transaction.oncomplete = () => { // Notify hybrid manager of local change @@ -188,49 +280,13 @@ export class IndexedDbTabularStorage< /** * Stores multiple rows in the repository in a bulk operation. - * @param records - Array of entities to store. + * @param records - Array of entities to store (may be missing auto-generated keys). 
* @returns Array of stored entities * @emits put - Emitted for each record successfully stored */ - async putBulk(records: Entity[]): Promise { - const db = await this.getDb(); - return new Promise((resolve, reject) => { - const transaction = db.transaction(this.table, "readwrite"); - const store = transaction.objectStore(this.table); - - let completed = 0; - let hasError = false; - - transaction.onerror = () => { - if (!hasError) { - hasError = true; - reject(transaction.error); - } - }; - - transaction.oncomplete = () => { - if (!hasError) { - // Notify hybrid manager of local change - this.hybridManager?.notifyLocalChange(); - resolve(records); - } - }; - - // Add all records to the transaction - for (const record of records) { - const request = store.put(record); - request.onsuccess = () => { - this.events.emit("put", record); - completed++; - }; - request.onerror = () => { - if (!hasError) { - hasError = true; - reject(request.error); - } - }; - } - }); + async putBulk(records: InsertType[]): Promise { + // Use individual put calls to ensure auto-generated keys are handled correctly + return await Promise.all(records.map((record) => this.put(record))); } protected getPrimaryKeyAsOrderedArray(key: PrimaryKey) { diff --git a/packages/storage/src/tabular/PostgresTabularStorage.ts b/packages/storage/src/tabular/PostgresTabularStorage.ts index 7ecfa0a2..1befd896 100644 --- a/packages/storage/src/tabular/PostgresTabularStorage.ts +++ b/packages/storage/src/tabular/PostgresTabularStorage.ts @@ -14,9 +14,12 @@ import { } from "@workglow/util"; import type { Pool } from "pg"; import { BaseSqlTabularStorage } from "./BaseSqlTabularStorage"; +import { ClientProvidedKeysOption } from "./BaseTabularStorage"; import { AnyTabularStorage, + AutoGeneratedKeys, DeleteSearchCriteria, + InsertEntity, isSearchCondition, SearchOperator, SimplifyPrimaryKey, @@ -43,7 +46,12 @@ export class PostgresTabularStorage< // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, -> extends BaseSqlTabularStorage { + Value = Omit, + InsertType extends InsertEntity> = InsertEntity< + Entity, + AutoGeneratedKeys + >, +> extends BaseSqlTabularStorage { protected db: Pool; /** @@ -55,15 +63,17 @@ export class PostgresTabularStorage< * @param primaryKeyNames - Array of property names that form the primary key * @param indexes - Array of columns or column arrays to make searchable. Each string or single column creates a single-column index, * while each array creates a compound index with columns in the specified order. 
+ * @param clientProvidedKeys - How to handle client-provided values for auto-generated keys */ constructor( db: Pool, table: string = "tabular_store", schema: Schema, primaryKeyNames: PrimaryKeyNames, - indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [] + indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [], + clientProvidedKeys: ClientProvidedKeysOption = "if-missing" ) { - super(table, schema, primaryKeyNames, indexes); + super(table, schema, primaryKeyNames, indexes, clientProvidedKeys); this.db = db; } @@ -263,11 +273,27 @@ export class PostgresTabularStorage< /** * Generates the SQL column definitions for primary key fields with constraints + * Handles auto-generated keys using SERIAL for integers and UUID DEFAULT for strings * @returns SQL string containing primary key column definitions */ protected constructPrimaryKeyColumns($delimiter: string = ""): string { const cols = Object.entries(this.primaryKeySchema.properties) .map(([key, typeDef]) => { + // Check if this is an auto-generated key + if (this.isAutoGeneratedKey(key)) { + if (this.autoGeneratedKeyStrategy === "autoincrement") { + // Use SERIAL or BIGSERIAL for auto-increment + const sqlType = this.mapTypeToSQL(typeDef); + const isSmallInt = sqlType.includes("SMALLINT"); + const isBigInt = sqlType.includes("BIGINT"); + const serialType = isBigInt ? "BIGSERIAL" : isSmallInt ? "SMALLSERIAL" : "SERIAL"; + return `${$delimiter}${key}${$delimiter} ${serialType}`; + } else if (this.autoGeneratedKeyStrategy === "uuid") { + // Use UUID with DEFAULT gen_random_uuid() + return `${$delimiter}${key}${$delimiter} UUID DEFAULT gen_random_uuid()`; + } + } + const sqlType = this.mapTypeToSQL(typeDef); let constraints = "NOT NULL"; @@ -323,7 +349,7 @@ export class PostgresTabularStorage< if (value && ArrayBuffer.isView(value) && !(value instanceof DataView)) { // It's a TypedArray const array = Array.from(value as unknown as TypedArray); - return `[${array.join(",")}]` as any; + return `[${array.join(",")}]`; } // If it's already a string (serialized), return as-is if (typeof value === "string") { @@ -468,39 +494,93 @@ export class PostgresTabularStorage< * Stores or updates a row in the database. * Uses UPSERT (INSERT ... ON CONFLICT DO UPDATE) for atomic operations. * - * @param entity - The entity to store + * @param entity - The entity to store (may be missing auto-generated keys) * @returns The entity with any server-generated fields updated * @emits "put" event with the updated entity when successful */ - async put(entity: Entity): Promise { + async put(entity: InsertType): Promise { const db = this.db; - const { key, value } = this.separateKeyValueFromCombined(entity); - const sql = ` - INSERT INTO "${this.table}" ( - ${this.primaryKeyColumnList('"')} ${this.valueColumnList() ? ", " + this.valueColumnList('"') : ""} - ) - VALUES ( - ${[...this.primaryKeyColumns(), ...this.valueColumns()] - .map((_, i) => `$${i + 1}`) - .join(", ")} - ) - ${ - !this.valueColumnList() - ? 
"" - : ` + + // Determine which columns to include in INSERT + const columnsToInsert: string[] = []; + const paramsToInsert: ValueOptionType[] = []; + let paramIndex = 1; + + // Handle primary key columns + const pkColumns = this.primaryKeyColumns(); + for (const col of pkColumns) { + const colStr = String(col); + + // Check if this is an auto-generated key + if (this.isAutoGeneratedKey(colStr)) { + const clientProvidedValue = (entity as any)[col]; + const hasClientValue = clientProvidedValue !== undefined && clientProvidedValue !== null; + + let shouldUseClientValue = false; + if (this.clientProvidedKeys === "never") { + // Never use client value, let database generate + shouldUseClientValue = false; + } else if (this.clientProvidedKeys === "always") { + if (!hasClientValue) { + throw new Error( + `Auto-generated key "${colStr}" is required when clientProvidedKeys is "always"` + ); + } + shouldUseClientValue = true; + } else { + // "if-missing" - use client value if provided + shouldUseClientValue = hasClientValue; + } + + if (shouldUseClientValue) { + columnsToInsert.push(colStr); + paramsToInsert.push(this.jsToSqlValue(colStr, clientProvidedValue)); + } + // Otherwise skip it - let database generate via SERIAL or DEFAULT + continue; + } + + // Regular primary key column + columnsToInsert.push(colStr); + const value = (entity as any)[col]; + paramsToInsert.push(this.jsToSqlValue(colStr, value)); + } + + // Handle value columns + const valueColumns = this.valueColumns(); + for (const col of valueColumns) { + const colStr = String(col); + columnsToInsert.push(colStr); + const value = (entity as any)[col]; + paramsToInsert.push(this.jsToSqlValue(colStr, value)); + } + + const columnList = columnsToInsert.map((c) => `"${c}"`).join(", "); + const placeholders = columnsToInsert.map((_, i) => `$${i + 1}`).join(", "); + + // Build ON CONFLICT clause if there are value columns + const conflictClause = + valueColumns.length > 0 + ? ` ON CONFLICT (${this.primaryKeyColumnList('"')}) DO UPDATE SET - ${(this.valueColumns() as string[]) - .map((col, i) => `"${col}" = $${i + this.primaryKeyColumns().length + 1}`) + ${(valueColumns as string[]) + .map((col) => { + const colIdx = columnsToInsert.indexOf(String(col)); + return `"${col}" = $${colIdx + 1}`; + }) .join(", ")} ` - } + : ""; + + const sql = ` + INSERT INTO "${this.table}" (${columnList}) + VALUES (${placeholders}) + ${conflictClause} RETURNING * `; - const primaryKeyParams = this.getPrimaryKeyAsOrderedArray(key); - const valueParams = this.getValueAsOrderedArray(value); - const params = [...primaryKeyParams, ...valueParams]; + const params = paramsToInsert; const result = await db.query(sql, params); const updatedEntity = result.rows[0] as Entity; @@ -516,13 +596,19 @@ export class PostgresTabularStorage< /** * Stores multiple rows in the database in a bulk operation. - * Uses batch INSERT with ON CONFLICT for better performance. + * Uses individual put calls to ensure auto-generated keys are handled correctly. 
* - * @param entities - Array of entities to store + * @param entities - Array of entities to store (may be missing auto-generated keys) * @returns Array of entities with any server-generated fields updated * @emits "put" event for each entity stored */ - async putBulk(entities: Entity[]): Promise { + async putBulk(entities: InsertType[]): Promise { + if (entities.length === 0) return []; + + // Use individual put calls to ensure auto-generated keys are handled correctly + return await Promise.all(entities.map((entity) => this.put(entity))); + + /* Original bulk implementation - keeping for reference but using simpler approach above if (entities.length === 0) return []; const db = this.db; @@ -592,6 +678,7 @@ export class PostgresTabularStorage< } return updatedEntities; + */ } /** diff --git a/packages/storage/src/tabular/README.md b/packages/storage/src/tabular/README.md index 75fdb385..8dd7b898 100644 --- a/packages/storage/src/tabular/README.md +++ b/packages/storage/src/tabular/README.md @@ -157,6 +157,178 @@ await repo.put({ }); ``` +## Auto-Generated Primary Keys + +TabularStorage supports automatic generation of primary keys, allowing the storage backend to generate IDs when entities are inserted without them. This is useful for: + +- Security: Preventing clients from choosing arbitrary IDs +- Simplicity: No need to generate IDs client-side +- Database features: Leveraging native auto-increment and UUID generation + +### Schema Configuration + +Mark a primary key column as auto-generated using the `x-auto-generated: true` annotation: + +```typescript +const UserSchema = { + type: "object", + properties: { + id: { type: "integer", "x-auto-generated": true }, // Auto-increment + name: { type: "string" }, + email: { type: "string" }, + }, + required: ["id", "name", "email"], + additionalProperties: false, +} as const satisfies DataPortSchemaObject; + +const DocumentSchema = { + type: "object", + properties: { + id: { type: "string", "x-auto-generated": true }, // UUID + title: { type: "string" }, + content: { type: "string" }, + }, + required: ["id", "title", "content"], + additionalProperties: false, +} as const satisfies DataPortSchemaObject; +``` + +**Generation Strategy (inferred from column type):** +- `type: "integer"` → Auto-increment (SERIAL, INTEGER PRIMARY KEY, counter) +- `type: "string"` → UUID via `uuid4()` from `@workglow/util` + +**Constraints:** +- Only the **first column** in a compound primary key can be auto-generated +- Only **one column** can be auto-generated per table + +### Basic Usage + +```typescript +import { InMemoryTabularStorage } from "@workglow/storage/tabular"; + +const userStorage = new InMemoryTabularStorage(UserSchema, ["id"] as const); +await userStorage.setupDatabase(); + +// Insert without providing ID - it will be auto-generated +const user = await userStorage.put({ + name: "Alice", + email: "alice@example.com" +}); +console.log(user.id); // 1 (auto-generated) + +// TypeScript enforces: id is optional on insert, required on returned entity +``` + +### Client-Provided Keys Configuration + +Control whether clients can provide values for auto-generated keys: + +```typescript +const storage = new PostgresTabularStorage( + db, + "users", + UserSchema, + ["id"] as const, + [], // indexes + { clientProvidedKeys: "if-missing" } // configuration +); +``` + +**Options:** + +| Setting | Behavior | Use Case | +|---------|----------|----------| +| `"if-missing"` (default) | Use client value if provided, generate otherwise | Flexible - supports both 
auto-generation and client-specified IDs | +| `"never"` | Always generate, ignore client values | Maximum security - never trust client IDs | +| `"always"` | Require client to provide value | Testing/migration - enforce client-side ID generation | + +**Examples:** + +```typescript +// Default: "if-missing" - flexible +const flexibleStorage = new InMemoryTabularStorage( + UserSchema, + ["id"] as const +); + +// Without ID - auto-generated +await flexibleStorage.put({ name: "Bob", email: "bob@example.com" }); + +// With ID - uses client value +await flexibleStorage.put({ id: 999, name: "Charlie", email: "charlie@example.com" }); + +// Secure mode: "never" - always generate +const secureStorage = new PostgresTabularStorage( + db, + "users", + UserSchema, + ["id"] as const, + [], + { clientProvidedKeys: "never" } +); + +// Even if client provides id, it will be ignored and regenerated +const result = await secureStorage.put({ + id: 999, // Ignored! + name: "Diana", + email: "diana@example.com" +}); +// result.id will be database-generated, NOT 999 + +// Testing mode: "always" - require client ID +const testStorage = new InMemoryTabularStorage( + UserSchema, + ["id"] as const, + [], + { clientProvidedKeys: "always" } +); + +// Must provide ID or throws error +await testStorage.put({ + id: 1, + name: "Eve", + email: "eve@example.com" +}); // OK + +await testStorage.put({ + name: "Frank", + email: "frank@example.com" +}); // Throws Error! +``` + +### Backend-Specific Behavior + +Each storage backend implements auto-generation differently: + +| Backend | Integer (autoincrement) | String (UUID) | +|---------|------------------------|---------------| +| **InMemoryTabularStorage** | Internal counter (1, 2, 3...) | `uuid4()` from `@workglow/util` | +| **SqliteTabularStorage** | `INTEGER PRIMARY KEY AUTOINCREMENT` | `uuid4()` client-side | +| **PostgresTabularStorage** | `SERIAL`/`BIGSERIAL` | `gen_random_uuid()` database-side | +| **SupabaseTabularStorage** | `SERIAL` | `gen_random_uuid()` database-side | +| **IndexedDbTabularStorage** | `autoIncrement: true` | `uuid4()` client-side | +| **FsFolderTabularStorage** | Internal counter | `uuid4()` from `@workglow/util` | + +### Constraints + +1. **Only first column**: Only the first primary key column can be auto-generated +2. **Single auto-gen key**: Only one column per table can be auto-generated +3. 
**Type inference**: Generation strategy is inferred from column type (integer → autoincrement, string → UUID) + +### Type Safety + +TypeScript enforces correct usage through the type system: + +```typescript +// Auto-generated key is OPTIONAL on insert +const entity = { name: "Alice", email: "alice@example.com" }; +await storage.put(entity); // ✅ OK - id can be omitted + +// Returned entity has ALL fields REQUIRED +const result = await storage.put(entity); +const id: number = result.id; // ✅ OK - id is guaranteed to exist +``` + ## Implementations ### InMemoryTabularStorage diff --git a/packages/storage/src/tabular/SharedInMemoryTabularStorage.ts b/packages/storage/src/tabular/SharedInMemoryTabularStorage.ts index 84217d7d..83e760c1 100644 --- a/packages/storage/src/tabular/SharedInMemoryTabularStorage.ts +++ b/packages/storage/src/tabular/SharedInMemoryTabularStorage.ts @@ -10,10 +10,12 @@ import { FromSchema, TypedArraySchemaOptions, } from "@workglow/util"; -import { BaseTabularStorage } from "./BaseTabularStorage"; +import { BaseTabularStorage, ClientProvidedKeysOption } from "./BaseTabularStorage"; import { AnyTabularStorage, + AutoGeneratedKeys, DeleteSearchCriteria, + InsertEntity, SimplifyPrimaryKey, TabularSubscribeOptions, } from "./ITabularStorage"; @@ -49,7 +51,12 @@ export class SharedInMemoryTabularStorage< // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, -> extends BaseTabularStorage { + Value = Omit, + InsertType extends InsertEntity> = InsertEntity< + Entity, + AutoGeneratedKeys + >, +> extends BaseTabularStorage { private channel: BroadcastChannel | null = null; private channelName: string; private inMemoryRepo: InMemoryTabularStorage; @@ -63,19 +70,22 @@ export class SharedInMemoryTabularStorage< * @param primaryKeyNames - Array of property names that form the primary key * @param indexes - Array of columns or column arrays to make searchable. Each string or single column creates a single-column index, * while each array creates a compound index with columns in the specified order. 
+ * @param clientProvidedKeys - How to handle client-provided values for auto-generated keys */ constructor( channelName: string = "tabular_store", schema: Schema, primaryKeyNames: PrimaryKeyNames, - indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [] + indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [], + clientProvidedKeys: ClientProvidedKeysOption = "if-missing" ) { - super(schema, primaryKeyNames, indexes); + super(schema, primaryKeyNames, indexes, clientProvidedKeys); this.channelName = channelName; this.inMemoryRepo = new InMemoryTabularStorage( schema, primaryKeyNames, - indexes + indexes, + clientProvidedKeys ); // Forward events from the in-memory repository @@ -243,7 +253,7 @@ export class SharedInMemoryTabularStorage< * @returns The stored entity * @emits 'put' event with the stored entity when successful */ - async put(value: Entity): Promise { + async put(value: InsertType): Promise { const result = await this.inMemoryRepo.put(value); this.broadcast({ type: "PUT", entity: value }); return result; @@ -255,7 +265,7 @@ export class SharedInMemoryTabularStorage< * @returns Array of stored entities * @emits 'put' event for each value stored */ - async putBulk(values: Entity[]): Promise { + async putBulk(values: InsertType[]): Promise { const result = await this.inMemoryRepo.putBulk(values); this.broadcast({ type: "PUT_BULK", entities: values }); return result; diff --git a/packages/storage/src/tabular/SqliteTabularStorage.ts b/packages/storage/src/tabular/SqliteTabularStorage.ts index 4b53d0b6..c6d9146c 100644 --- a/packages/storage/src/tabular/SqliteTabularStorage.ts +++ b/packages/storage/src/tabular/SqliteTabularStorage.ts @@ -11,11 +11,15 @@ import { FromSchema, JsonSchema, TypedArraySchemaOptions, + uuid4, } from "@workglow/util"; import { BaseSqlTabularStorage } from "./BaseSqlTabularStorage"; +import { ClientProvidedKeysOption, KeyGenerationStrategy } from "./BaseTabularStorage"; import { AnyTabularStorage, + AutoGeneratedKeys, DeleteSearchCriteria, + InsertEntity, isSearchCondition, SearchOperator, SimplifyPrimaryKey, @@ -47,7 +51,12 @@ export class SqliteTabularStorage< // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, -> extends BaseSqlTabularStorage { + Value = Omit, + InsertType extends InsertEntity> = InsertEntity< + Entity, + AutoGeneratedKeys + >, +> extends BaseSqlTabularStorage { /** The SQLite database instance */ private db: Sqlite.Database; @@ -59,15 +68,17 @@ export class SqliteTabularStorage< * @param primaryKeyNames - Array of property names that form the primary key * @param indexes - Array of columns or column arrays to make searchable. Each string or single column creates a single-column index, * while each array creates a compound index with columns in the specified order. 
+ * @param clientProvidedKeys - How to handle client-provided values for auto-generated keys */ constructor( dbOrPath: string | Sqlite.Database, table: string = "tabular_store", schema: Schema, primaryKeyNames: PrimaryKeyNames, - indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [] + indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [], + clientProvidedKeys: ClientProvidedKeysOption = "if-missing" ) { - super(table, schema, primaryKeyNames, indexes); + super(table, schema, primaryKeyNames, indexes, clientProvidedKeys); if (typeof dbOrPath === "string") { this.db = new Database(dbOrPath); } else { @@ -75,17 +86,45 @@ export class SqliteTabularStorage< } } + /** + * Override to handle SQLite's INTEGER PRIMARY KEY for auto-increment + */ + protected constructPrimaryKeyColumns($delimiter: string = ""): string { + const cols = Object.entries(this.primaryKeySchema.properties) + .map(([key, typeDef]) => { + // Check if this is an auto-generated key with autoincrement strategy + if (this.isAutoGeneratedKey(key) && this.autoGeneratedKeyStrategy === "autoincrement") { + // SQLite uses INTEGER PRIMARY KEY for auto-increment + return `${$delimiter}${key}${$delimiter} INTEGER PRIMARY KEY AUTOINCREMENT`; + } + const sqlType = this.mapTypeToSQL(typeDef); + return `${$delimiter}${key}${$delimiter} ${sqlType} NOT NULL`; + }) + .join(", "); + return cols; + } + /** * Creates the database table if it doesn't exist with the defined schema. * Must be called before using any other methods. */ public async setupDatabase(): Promise { - const sql = ` - CREATE TABLE IF NOT EXISTS \`${this.table}\` ( - ${this.constructPrimaryKeyColumns()} ${this.constructValueColumns()}, - PRIMARY KEY (${this.primaryKeyColumnList()}) - ) - `; + // For auto-generated INTEGER PRIMARY KEY, we don't use the PRIMARY KEY constraint separately + const hasAutoIncrementKey = + this.hasAutoGeneratedKey() && this.autoGeneratedKeyStrategy === "autoincrement"; + + const sql = hasAutoIncrementKey + ? ` + CREATE TABLE IF NOT EXISTS \`${this.table}\` ( + ${this.constructPrimaryKeyColumns()} ${this.constructValueColumns()} + ) + ` + : ` + CREATE TABLE IF NOT EXISTS \`${this.table}\` ( + ${this.constructPrimaryKeyColumns()} ${this.constructValueColumns()}, + PRIMARY KEY (${this.primaryKeyColumnList()}) + ) + `; this.db.exec(sql); // Get primary key columns to avoid creating redundant indexes @@ -317,30 +356,117 @@ export class SqliteTabularStorage< } } + /** + * Generates a key value for string UUID keys + * Integer keys are auto-generated by SQLite's INTEGER PRIMARY KEY + * @param columnName - Name of the column to generate a key for + * @param strategy - The generation strategy to use + * @returns The generated key value + */ + protected generateKeyValue(columnName: string, strategy: KeyGenerationStrategy): string | number { + if (strategy === "uuid") { + return uuid4(); + } + // autoincrement is handled by SQLite INTEGER PRIMARY KEY + throw new Error( + `SQLite autoincrement keys are generated by the database, not client-side. 
Column: ${columnName}` + ); + } + /** * Stores a key-value pair in the database - * @param entity - The entity to store + * @param entity - The entity to store (may be missing auto-generated keys) * @returns The entity with any server-generated fields updated * @emits 'put' event when successful */ - async put(entity: Entity): Promise { + async put(entity: InsertType): Promise { const db = this.db; - const { key, value } = this.separateKeyValueFromCombined(entity); + let entityToInsert = entity as unknown as Entity; + + // Handle auto-generated keys + if (this.hasAutoGeneratedKey() && this.autoGeneratedKeyName) { + const keyName = String(this.autoGeneratedKeyName); + const clientProvidedValue = (entity as any)[keyName]; + const hasClientValue = clientProvidedValue !== undefined && clientProvidedValue !== null; + + let shouldUseClientValue = false; + if (this.clientProvidedKeys === "never") { + // Always generate, ignore client value + shouldUseClientValue = false; + } else if (this.clientProvidedKeys === "always") { + // Always use client value, error if missing + if (!hasClientValue) { + throw new Error( + `Auto-generated key "${keyName}" is required when clientProvidedKeys is "always"` + ); + } + shouldUseClientValue = true; + } else { + // "if-missing" - use client value if provided + shouldUseClientValue = hasClientValue; + } + + // For UUID strategy, generate client-side if needed + if (this.autoGeneratedKeyStrategy === "uuid" && !shouldUseClientValue) { + const generatedValue = this.generateKeyValue(keyName, "uuid"); + entityToInsert = { ...entity, [keyName]: generatedValue } as Entity; + } else if (this.autoGeneratedKeyStrategy === "uuid" && shouldUseClientValue) { + // Client provided UUID, use it + entityToInsert = entity as unknown as Entity; + } + // For autoincrement strategy, we handle it differently below + } + + // Determine which columns to include in INSERT + let columnsToInsert: string[] = []; + let paramsToInsert: ValueOptionType[] = []; + + // Handle primary key columns + const pkColumns = this.primaryKeyColumns(); + for (const col of pkColumns) { + const colStr = String(col); + // Skip autoincrement keys that should be generated by database + if ( + this.isAutoGeneratedKey(colStr) && + this.autoGeneratedKeyStrategy === "autoincrement" && + this.clientProvidedKeys !== "always" + ) { + const clientProvidedValue = (entityToInsert as any)[colStr]; + const hasClientValue = + clientProvidedValue !== undefined && clientProvidedValue !== null; + if (this.clientProvidedKeys === "if-missing" && hasClientValue) { + // Client provided value for autoincrement key in "if-missing" mode + columnsToInsert.push(colStr); + paramsToInsert.push(this.jsToSqlValue(colStr, clientProvidedValue)); + } + // Otherwise skip it - let SQLite generate + continue; + } + columnsToInsert.push(colStr); + const value = (entityToInsert as any)[colStr]; + paramsToInsert.push(this.jsToSqlValue(colStr, value)); + } + + // Handle value columns + const valueColumns = this.valueColumns(); + for (const col of valueColumns) { + const colStr = String(col); + columnsToInsert.push(colStr); + const value = (entityToInsert as any)[colStr]; + paramsToInsert.push(this.jsToSqlValue(colStr, value)); + } + + const columnList = columnsToInsert.map((c) => `\`${c}\``).join(", "); + const placeholders = columnsToInsert.map(() => "?").join(", "); + const sql = ` - INSERT OR REPLACE INTO \`${ - this.table - }\` (${this.primaryKeyColumnList()} ${this.valueColumnList() ? 
", " + this.valueColumnList() : ""}) - VALUES ( - ${this.primaryKeyColumns().map((i) => "?")} - ${this.valueColumns().length > 0 ? ", " + this.valueColumns().map((i) => "?") : ""} - ) + INSERT OR REPLACE INTO \`${this.table}\` (${columnList}) + VALUES (${placeholders}) RETURNING * `; const stmt = db.prepare(sql); - const primaryKeyParams = this.getPrimaryKeyAsOrderedArray(key); - const valueParams = this.getValueAsOrderedArray(value); - const params = [...primaryKeyParams, ...valueParams]; + const params = paramsToInsert; // CRITICAL: Ensure all params are SQLite-compatible before binding // SQLite only accepts: string, number, bigint, boolean, null, Uint8Array @@ -438,13 +564,18 @@ export class SqliteTabularStorage< /** * Stores multiple key-value pairs in the database in a bulk operation - * @param entities - Array of entities to store + * @param entities - Array of entities to store (may be missing auto-generated keys) * @returns Array of entities with any server-generated fields updated * @emits 'put' event for each entity stored */ - async putBulk(entities: Entity[]): Promise { + async putBulk(entities: InsertType[]): Promise { if (entities.length === 0) return []; + // Use individual put calls to ensure auto-generated keys are handled correctly + // Each put() call will handle auto-generated keys appropriately + return await Promise.all(entities.map((entity) => this.put(entity))); + + /* Original bulk implementation - keeping for reference but using simpler approach above const db = this.db; // For SQLite bulk inserts with RETURNING, we need to do them individually @@ -452,7 +583,7 @@ export class SqliteTabularStorage< const updatedEntities: Entity[] = []; // Use a transaction for better performance - const transaction = db.transaction((entitiesToInsert: Entity[]) => { + const transaction = db.transaction((entitiesToInsert: any[]) => { for (const entity of entitiesToInsert) { const { key, value } = this.separateKeyValueFromCombined(entity); const sql = ` @@ -516,6 +647,7 @@ export class SqliteTabularStorage< } return updatedEntities; + */ } /** diff --git a/packages/storage/src/tabular/SupabaseTabularStorage.ts b/packages/storage/src/tabular/SupabaseTabularStorage.ts index 95bb33d2..f679b095 100644 --- a/packages/storage/src/tabular/SupabaseTabularStorage.ts +++ b/packages/storage/src/tabular/SupabaseTabularStorage.ts @@ -13,9 +13,12 @@ import { TypedArraySchemaOptions, } from "@workglow/util"; import { BaseSqlTabularStorage } from "./BaseSqlTabularStorage"; +import { ClientProvidedKeysOption } from "./BaseTabularStorage"; import { AnyTabularStorage, + AutoGeneratedKeys, DeleteSearchCriteria, + InsertEntity, isSearchCondition, SearchOperator, SimplifyPrimaryKey, @@ -43,7 +46,12 @@ export class SupabaseTabularStorage< // computed types Entity = FromSchema, PrimaryKey = SimplifyPrimaryKey, -> extends BaseSqlTabularStorage { + Value = Omit, + InsertType extends InsertEntity> = InsertEntity< + Entity, + AutoGeneratedKeys + >, +> extends BaseSqlTabularStorage { private client: SupabaseClient; private realtimeChannel: RealtimeChannel | null = null; @@ -56,15 +64,17 @@ export class SupabaseTabularStorage< * @param primaryKeyNames - Array of property names that form the primary key * @param indexes - Array of columns or column arrays to make searchable. Each string or single column creates a single-column index, * while each array creates a compound index with columns in the specified order. 
+ * @param clientProvidedKeys - How to handle client-provided values for auto-generated keys */ constructor( client: SupabaseClient, table: string = "tabular_store", schema: Schema, primaryKeyNames: PrimaryKeyNames, - indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [] + indexes: readonly (keyof Entity | readonly (keyof Entity)[])[] = [], + clientProvidedKeys: ClientProvidedKeysOption = "if-missing" ) { - super(table, schema, primaryKeyNames, indexes); + super(table, schema, primaryKeyNames, indexes, clientProvidedKeys); this.client = client; } @@ -250,11 +260,27 @@ export class SupabaseTabularStorage< /** * Generates the SQL column definitions for primary key fields with constraints + * Handles auto-generated keys using SERIAL for integers and UUID DEFAULT for strings * @returns SQL string containing primary key column definitions */ protected constructPrimaryKeyColumns($delimiter: string = ""): string { const cols = Object.entries(this.primaryKeySchema.properties) .map(([key, typeDef]) => { + // Check if this is an auto-generated key + if (this.isAutoGeneratedKey(key)) { + if (this.autoGeneratedKeyStrategy === "autoincrement") { + // Use SERIAL or BIGSERIAL for auto-increment + const sqlType = this.mapTypeToSQL(typeDef); + const isSmallInt = sqlType.includes("SMALLINT"); + const isBigInt = sqlType.includes("BIGINT"); + const serialType = isBigInt ? "BIGSERIAL" : isSmallInt ? "SMALLSERIAL" : "SERIAL"; + return `${$delimiter}${key}${$delimiter} ${serialType}`; + } else if (this.autoGeneratedKeyStrategy === "uuid") { + // Use UUID with DEFAULT gen_random_uuid() + return `${$delimiter}${key}${$delimiter} UUID DEFAULT gen_random_uuid()`; + } + } + const sqlType = this.mapTypeToSQL(typeDef); let constraints = "NOT NULL"; @@ -355,13 +381,43 @@ export class SupabaseTabularStorage< * Stores or updates a row in the database. * Uses UPSERT (INSERT ... ON CONFLICT DO UPDATE) for atomic operations. * - * @param entity - The entity to store + * @param entity - The entity to store (may be missing auto-generated keys) * @returns The entity with any server-generated fields updated * @emits "put" event with the updated entity when successful */ - async put(entity: Entity): Promise { + async put(entity: InsertType): Promise { + // Handle auto-generated keys + let entityToInsert = { ...entity }; + + if (this.hasAutoGeneratedKey() && this.autoGeneratedKeyName) { + const keyName = String(this.autoGeneratedKeyName); + const clientProvidedValue = (entity as any)[keyName]; + const hasClientValue = clientProvidedValue !== undefined && clientProvidedValue !== null; + + let shouldOmitKey = false; + if (this.clientProvidedKeys === "never") { + // Never use client value, let database generate + shouldOmitKey = true; + } else if (this.clientProvidedKeys === "always") { + if (!hasClientValue) { + throw new Error( + `Auto-generated key "${keyName}" is required when clientProvidedKeys is "always"` + ); + } + shouldOmitKey = false; + } else { + // "if-missing" - omit key if client didn't provide it + shouldOmitKey = !hasClientValue; + } + + if (shouldOmitKey) { + // Omit the auto-generated key so Supabase generates it + delete (entityToInsert as any)[keyName]; + } + } + // Normalize optional fields: convert undefined to null for optional fields - const normalizedEntity = { ...entity } as any; + const normalizedEntity = { ...entityToInsert } as any; const requiredSet = new Set(this.valueSchema.required ?? 
[]); for (const key in this.valueSchema.properties) { if (!(key in normalizedEntity) || normalizedEntity[key] === undefined) { @@ -370,7 +426,7 @@ export class SupabaseTabularStorage< } } } - const { data, error } = await this.client + const { data, error} = await this.client .from(this.table) .upsert(normalizedEntity, { onConflict: this.primaryKeyColumnList() }) .select() @@ -391,45 +447,17 @@ export class SupabaseTabularStorage< /** * Stores multiple rows in the database in a bulk operation. - * Uses batch INSERT with ON CONFLICT for better performance. + * Uses individual put calls to ensure auto-generated keys are handled correctly. * - * @param entities - Array of entities to store + * @param entities - Array of entities to store (may be missing auto-generated keys) * @returns Array of entities with any server-generated fields updated * @emits "put" event for each entity stored */ - async putBulk(entities: Entity[]): Promise { + async putBulk(entities: InsertType[]): Promise { if (entities.length === 0) return []; - // Normalize optional fields: convert undefined to null for optional fields - const requiredSet = new Set(this.valueSchema.required ?? []); - const normalizedEntities = entities.map((entity) => { - const normalized = { ...entity } as any; - for (const key in this.valueSchema.properties) { - if (!(key in normalized) || normalized[key] === undefined) { - if (!requiredSet.has(key)) { - normalized[key] = null; - } - } - } - return normalized; - }); - const { data, error } = await this.client - .from(this.table) - .upsert(normalizedEntities, { onConflict: this.primaryKeyColumnList() }) - .select(); - - if (error) throw error; - const updatedEntities = data as Entity[]; - - // Convert all columns from SQL to JS values and emit events - for (const entity of updatedEntities) { - for (const key in this.schema.properties) { - // @ts-ignore - entity[key] = this.sqlToJsValue(key, entity[key]); - } - this.events.emit("put", entity); - } - return updatedEntities; + // Use individual put calls to ensure auto-generated keys are handled correctly + return await Promise.all(entities.map((entity) => this.put(entity))); } /** diff --git a/packages/storage/src/util/IndexedDbTable.ts b/packages/storage/src/util/IndexedDbTable.ts index 8014277d..81521a95 100644 --- a/packages/storage/src/util/IndexedDbTable.ts +++ b/packages/storage/src/util/IndexedDbTable.ts @@ -335,7 +335,8 @@ async function performDestructiveMigration( tableName: string, primaryKey: string | string[], expectedIndexes: ExpectedIndexDefinition[], - options: MigrationOptions = {} + options: MigrationOptions = {}, + autoIncrement: boolean = false ): Promise { if (!options.allowDestructiveMigration) { throw new Error( @@ -411,7 +412,7 @@ async function performDestructiveMigration( } // Create new object store with new schema - const store = db.createObjectStore(tableName, { keyPath: primaryKey }); + const store = db.createObjectStore(tableName, { keyPath: primaryKey, autoIncrement }); // Create indexes for (const idx of expectedIndexes) { @@ -444,7 +445,8 @@ async function createNewDatabase( tableName: string, primaryKey: string | string[], expectedIndexes: ExpectedIndexDefinition[], - options: MigrationOptions = {} + options: MigrationOptions = {}, + autoIncrement: boolean = false ): Promise { options.onMigrationProgress?.(`Creating new database: ${tableName}`, 0); @@ -468,7 +470,7 @@ async function createNewDatabase( } // Create main object store - const store = db.createObjectStore(tableName, { keyPath: primaryKey }); + const 
store = db.createObjectStore(tableName, { keyPath: primaryKey, autoIncrement }); // Create indexes for (const idx of expectedIndexes) { @@ -500,7 +502,8 @@ export async function ensureIndexedDbTable( tableName: string, primaryKey: string | string[], expectedIndexes: ExpectedIndexDefinition[] = [], - options: MigrationOptions = {} + options: MigrationOptions = {}, + autoIncrement: boolean = false ): Promise { try { // Try to open existing database at current version (or create if doesn't exist) @@ -523,7 +526,7 @@ export async function ensureIndexedDbTable( `Database ${tableName} does not exist or has version conflict, creating...`, 0 ); - return await createNewDatabase(tableName, primaryKey, expectedIndexes, options); + return await createNewDatabase(tableName, primaryKey, expectedIndexes, options, autoIncrement); } // If database was just created, we need to create the stores @@ -549,7 +552,7 @@ export async function ensureIndexedDbTable( } // Create main object store - const store = db.createObjectStore(tableName, { keyPath: primaryKey }); + const store = db.createObjectStore(tableName, { keyPath: primaryKey, autoIncrement }); // Create indexes for (const idx of expectedIndexes) { @@ -596,7 +599,7 @@ export async function ensureIndexedDbTable( // Object store doesn't exist, create it options.onMigrationProgress?.(`Object store ${tableName} does not exist, creating...`, 0); db.close(); - return await createNewDatabase(tableName, primaryKey, expectedIndexes, options); + return await createNewDatabase(tableName, primaryKey, expectedIndexes, options, autoIncrement); } // Compare schemas to determine what migration is needed @@ -638,7 +641,7 @@ export async function ensureIndexedDbTable( `Schema change requires object store recreation for ${tableName}`, 0 ); - db = await performDestructiveMigration(db, tableName, primaryKey, expectedIndexes, options); + db = await performDestructiveMigration(db, tableName, primaryKey, expectedIndexes, options, autoIncrement); } else { options.onMigrationProgress?.(`Performing incremental migration for ${tableName}`, 0); db = await performIncrementalMigration(db, tableName, diff, options); diff --git a/packages/storage/src/vector/README.md b/packages/storage/src/vector/README.md index c452c372..0a6a129d 100644 --- a/packages/storage/src/vector/README.md +++ b/packages/storage/src/vector/README.md @@ -344,13 +344,13 @@ Quantized vectors reduce storage and can improve performance: - **Cons:** Requires PostgreSQL server and pgvector extension - **Setup:** `CREATE EXTENSION vector;` -## Integration with DocumentRepository +## Integration with DocumentDataset -The chunk vector repository works alongside `DocumentRepository` for hierarchical document storage: +The chunk vector repository works alongside `DocumentDataset` for hierarchical document storage: ```typescript import { - DocumentRepository, + DocumentDataset, InMemoryChunkVectorStorage, InMemoryTabularStorage, } from "@workglow/storage"; @@ -363,14 +363,14 @@ await tabularStorage.setupDatabase(); const vectorStorage = new InMemoryChunkVectorStorage(384); await vectorStorage.setupDatabase(); -// Create document repository with both storages -const docRepo = new DocumentRepository(tabularStorage, vectorStorage); +// Create document dataset with both storages +const docDataset = new DocumentDataset(tabularStorage, vectorStorage); // Store document structure in tabular, chunks in vector -await docRepo.upsert(document); +await docDataset.upsert(document); // Search chunks by vector similarity -const results = 
await docRepo.search(queryVector, { topK: 5 }); +const results = await docDataset.search(queryVector, { topK: 5 }); ``` ### Chunk Metadata for Hierarchical Documents diff --git a/packages/test/src/test/rag/Document.test.ts b/packages/test/src/test/rag/Document.test.ts index cc6e841a..05fbdac5 100644 --- a/packages/test/src/test/rag/Document.test.ts +++ b/packages/test/src/test/rag/Document.test.ts @@ -29,7 +29,7 @@ describe("Document", () => { ]; test("setChunks and getChunks", () => { - const doc = new Document("doc1", createTestDocumentNode(), { title: "Test" }); + const doc = new Document(createTestDocumentNode(), { title: "Test" }, [], "doc1"); doc.setChunks(createTestChunks()); @@ -40,7 +40,7 @@ describe("Document", () => { }); test("findChunksByNodeId", () => { - const doc = new Document("doc1", createTestDocumentNode(), { title: "Test" }); + const doc = new Document(createTestDocumentNode(), { title: "Test" }, [], "doc1"); doc.setChunks(createTestChunks()); diff --git a/packages/test/src/test/rag/DocumentRepository.test.ts b/packages/test/src/test/rag/DocumentRepository.test.ts index aa61bfca..aa6c37a2 100644 --- a/packages/test/src/test/rag/DocumentRepository.test.ts +++ b/packages/test/src/test/rag/DocumentRepository.test.ts @@ -49,13 +49,14 @@ describe("DocumentDataset", () => { const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); - const doc = new Document(doc_id, root, { title: "Test Document" }); + const doc = new Document(root, { title: "Test Document" }); - await dataset.upsert(doc); - const retrieved = await dataset.get(doc_id); + const inserted = await dataset.upsert(doc); + const retrieved = await dataset.get(inserted.doc_id!); expect(retrieved).toBeDefined(); - expect(retrieved?.doc_id).toBe(doc_id); + expect(retrieved?.doc_id).toBeDefined(); + expect(retrieved?.doc_id).toBe(inserted.doc_id); expect(retrieved?.metadata.title).toBe("Test Document"); }); @@ -64,12 +65,12 @@ describe("DocumentDataset", () => { const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); - const doc = new Document(doc_id, root, { title: "Test" }); - await dataset.upsert(doc); + const doc = new Document(root, { title: "Test" }); + const inserted = await dataset.upsert(doc); // Get a child node const firstChild = root.children[0]; - const retrieved = await dataset.getNode(doc_id, firstChild.nodeId); + const retrieved = await dataset.getNode(inserted.doc_id!, firstChild.nodeId); expect(retrieved).toBeDefined(); expect(retrieved?.nodeId).toBe(firstChild.nodeId); @@ -85,8 +86,8 @@ Paragraph.`; const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); - const doc = new Document(doc_id, root, { title: "Test" }); - await dataset.upsert(doc); + const doc = new Document(root, { title: "Test" }); + const inserted = await dataset.upsert(doc); // Find a deeply nested node const section = root.children.find((c) => c.kind === NodeKind.SECTION); @@ -95,7 +96,7 @@ Paragraph.`; const subsection = (section as any).children.find((c: any) => c.kind === NodeKind.SECTION); expect(subsection).toBeDefined(); - const ancestors = await dataset.getAncestors(doc_id, subsection.nodeId); + const ancestors = await dataset.getAncestors(inserted.doc_id!, subsection.nodeId); // Should include root, section, and subsection expect(ancestors.length).toBeGreaterThanOrEqual(3); @@ -109,13 +110,13 @@ Paragraph.`; const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, 
"Test"); - const doc = new Document(doc_id, root, { title: "Test" }); + const doc = new Document(root, { title: "Test" }); // Add chunks const chunks = [ { chunkId: "chunk_1", - doc_id, + doc_id: doc_id, text: "Test chunk", nodePath: [root.nodeId], depth: 1, @@ -124,10 +125,10 @@ Paragraph.`; doc.setChunks(chunks); - await dataset.upsert(doc); + const inserted = await dataset.upsert(doc); // Retrieve chunks - const retrievedChunks = await dataset.getChunks(doc_id); + const retrievedChunks = await dataset.getChunks(inserted.doc_id!); expect(retrievedChunks).toBeDefined(); expect(retrievedChunks.length).toBe(1); }); @@ -142,16 +143,16 @@ Paragraph.`; const root1 = await StructuralParser.parseMarkdown(id1, markdown1, "Doc 1"); const root2 = await StructuralParser.parseMarkdown(id2, markdown2, "Doc 2"); - const doc1 = new Document(id1, root1, { title: "Doc 1" }); - const doc2 = new Document(id2, root2, { title: "Doc 2" }); + const doc1 = new Document(root1, { title: "Doc 1" }); + const doc2 = new Document(root2, { title: "Doc 2" }); - await dataset.upsert(doc1); - await dataset.upsert(doc2); + const inserted1 = await dataset.upsert(doc1); + const inserted2 = await dataset.upsert(doc2); const list = await dataset.list(); expect(list.length).toBe(2); - expect(list).toContain(id1); - expect(list).toContain(id2); + expect(list).toContain(inserted1.doc_id); + expect(list).toContain(inserted2.doc_id); }); it("should delete documents", async () => { @@ -159,14 +160,14 @@ Paragraph.`; const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); - const doc = new Document(doc_id, root, { title: "Test" }); - await dataset.upsert(doc); + const doc = new Document(root, { title: "Test" }); + const inserted = await dataset.upsert(doc); - expect(await dataset.get(doc_id)).toBeDefined(); + expect(await dataset.get(inserted.doc_id!)).toBeDefined(); - await dataset.delete(doc_id); + await dataset.delete(inserted.doc_id!); - expect(await dataset.get(doc_id)).toBeUndefined(); + expect(await dataset.get(inserted.doc_id!)).toBeUndefined(); }); it("should return undefined for non-existent document", async () => { @@ -184,10 +185,10 @@ Paragraph.`; const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); - const doc = new Document(doc_id, root, { title: "Test" }); - await dataset.upsert(doc); + const doc = new Document(root, { title: "Test" }); + const inserted = await dataset.upsert(doc); - const result = await dataset.getNode(doc_id, "non-existent-node-id"); + const result = await dataset.getNode(inserted.doc_id!, "non-existent-node-id"); expect(result).toBeUndefined(); }); @@ -201,10 +202,10 @@ Paragraph.`; const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); - const doc = new Document(doc_id, root, { title: "Test" }); - await dataset.upsert(doc); + const doc = new Document(root, { title: "Test" }); + const inserted = await dataset.upsert(doc); - const result = await dataset.getAncestors(doc_id, "non-existent-node-id"); + const result = await dataset.getAncestors(inserted.doc_id!, "non-existent-node-id"); expect(result).toEqual([]); }); @@ -238,13 +239,13 @@ Paragraph.`; const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); - const doc1 = new Document(doc_id, root, { title: "Original Title" }); - await dataset.upsert(doc1); + const doc1 = new Document(root, { title: "Original Title" }); + const inserted1 = await dataset.upsert(doc1); - 
const doc2 = new Document(doc_id, root, { title: "Updated Title" }); + const doc2 = new Document(root, { title: "Updated Title" }, [], inserted1.doc_id); await dataset.upsert(doc2); - const retrieved = await dataset.get(doc_id); + const retrieved = await dataset.get(inserted1.doc_id!); expect(retrieved?.metadata.title).toBe("Updated Title"); const list = await dataset.list(); @@ -256,28 +257,29 @@ Paragraph.`; const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); - const doc = new Document(doc_id, root, { title: "Test" }); + const doc = new Document(root, { title: "Test" }); + const inserted = await dataset.upsert(doc); const chunks = [ { chunkId: "chunk_1", - doc_id, + doc_id: inserted.doc_id!, text: "First chunk", nodePath: [root.nodeId, "child-1"], depth: 2, }, { chunkId: "chunk_2", - doc_id, + doc_id: inserted.doc_id!, text: "Second chunk", nodePath: [root.nodeId, "child-2"], depth: 2, }, ]; - doc.setChunks(chunks); - await dataset.upsert(doc); + inserted.setChunks(chunks); + await dataset.upsert(inserted); - const result = await dataset.findChunksByNodeId(doc_id, root.nodeId); + const result = await dataset.findChunksByNodeId(inserted.doc_id!, root.nodeId); expect(result.length).toBe(2); }); @@ -286,11 +288,11 @@ Paragraph.`; const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); - const doc = new Document(doc_id, root, { title: "Test" }); + const doc = new Document(root, { title: "Test" }); doc.setChunks([]); - await dataset.upsert(doc); + const inserted = await dataset.upsert(doc); - const result = await dataset.findChunksByNodeId(doc_id, "non-matching-node"); + const result = await dataset.findChunksByNodeId(inserted.doc_id!, "non-matching-node"); expect(result).toEqual([]); }); @@ -372,7 +374,7 @@ describe("Document", () => { const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); - const doc = new Document(doc_id, root, { title: "Test" }); + const doc = new Document(root, { title: "Test" }, [], doc_id); const chunks = [ { @@ -395,7 +397,7 @@ describe("Document", () => { const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); - const doc = new Document(doc_id, root, { title: "Test" }); + const doc = new Document(root, { title: "Test" }, [], doc_id); const chunks = [ { @@ -408,11 +410,12 @@ describe("Document", () => { ]; doc.setChunks(chunks); - // Serialize + // Serialize (doc_id is NOT included in JSON) const json = doc.toJSON(); + expect(json).not.toHaveProperty("doc_id"); - // Deserialize - const restored = Document.fromJSON(JSON.stringify(json)); + // Deserialize (doc_id is passed separately) + const restored = Document.fromJSON(JSON.stringify(json), doc_id); expect(restored.doc_id).toBe(doc.doc_id); expect(restored.metadata.title).toBe(doc.metadata.title); @@ -424,7 +427,7 @@ describe("Document", () => { const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); - const doc = new Document(doc_id, root, { title: "Test" }); + const doc = new Document(root, { title: "Test" }, [], doc_id); const chunks = [ { @@ -463,7 +466,7 @@ describe("Document", () => { const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); - const doc = new Document(doc_id, root, { title: "Test" }); + const doc = new Document(root, { title: "Test" }, [], doc_id); const chunks = [ { @@ -485,7 +488,7 @@ describe("Document", () => { const doc_id 
= uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); - const doc = new Document(doc_id, root, { title: "Test" }); + const doc = new Document(root, { title: "Test" }, [], doc_id); doc.setChunks([]); const result = doc.findChunksByNodeId("any-node"); diff --git a/packages/test/src/test/rag/EndToEnd.test.ts b/packages/test/src/test/rag/EndToEnd.test.ts index bcbaf6b7..f0b1b2eb 100644 --- a/packages/test/src/test/rag/EndToEnd.test.ts +++ b/packages/test/src/test/rag/EndToEnd.test.ts @@ -11,7 +11,7 @@ import { DocumentChunkDataset, DocumentChunkPrimaryKey, DocumentChunkSchema, - DocumentRepository, + DocumentDataset, DocumentStorageKey, DocumentStorageSchema, StructuralParser, @@ -68,7 +68,7 @@ Finds patterns in data.`; const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Test"); - const doc = new Document(doc_id, root, { title: "Test" }); + const doc = new Document(root, { title: "Test" }, [], doc_id); const chunks = [ { @@ -105,7 +105,7 @@ Finds patterns in data.`; await storage.setupDatabase(); const vectorDataset = new DocumentChunkDataset(storage); - const docRepo = new DocumentRepository(tabularStorage, vectorDataset as any); + const docRepo = new DocumentDataset(tabularStorage, storage); // Create document with enriched hierarchy const markdown = `# Guide @@ -121,7 +121,8 @@ Content about topic B.`; const doc_id = uuid4(); const root = await StructuralParser.parseMarkdown(doc_id, markdown, "Guide"); - const doc = new Document(doc_id, root, { title: "Guide" }); + const doc = new Document(root, { title: "Guide" }); + const inserted = await docRepo.upsert(doc); // Enrich (in real workflow this would use DocumentEnricherTask) // For test, manually add enrichment @@ -132,12 +133,12 @@ Content about topic B.`; }, }; - const enrichedDoc = new Document(doc_id, enrichedRoot as any, doc.metadata); + const enrichedDoc = new Document(enrichedRoot as any, doc.metadata, [], inserted.doc_id); await docRepo.upsert(enrichedDoc); // Generate chunks using workflow (without embedding to avoid model requirement) const chunkResult = await hierarchicalChunker({ - doc_id, + doc_id: inserted.doc_id!, documentTree: enrichedRoot, maxTokens: 256, overlap: 25, @@ -150,7 +151,7 @@ Content about topic B.`; await docRepo.upsert(enrichedDoc); // Verify chunks were stored - const retrieved = await docRepo.getChunks(doc_id); + const retrieved = await docRepo.getChunks(inserted.doc_id!); expect(retrieved).toBeDefined(); expect(retrieved.length).toBe(chunkResult.count); }); diff --git a/packages/test/src/test/rag/RagWorkflow.test.ts b/packages/test/src/test/rag/RagWorkflow.test.ts index 6390f81a..1e938d46 100644 --- a/packages/test/src/test/rag/RagWorkflow.test.ts +++ b/packages/test/src/test/rag/RagWorkflow.test.ts @@ -51,7 +51,7 @@ import { DocumentChunkDataset, DocumentChunkPrimaryKey, DocumentChunkSchema, - DocumentRepository, + DocumentDataset, DocumentStorageKey, DocumentStorageSchema, registerDocumentChunkDataset, @@ -73,7 +73,7 @@ describe("RAG Workflow End-to-End", () => { DocumentChunk >; let vectorDataset: DocumentChunkDataset; - let docRepo: DocumentRepository; + let docRepo: DocumentDataset; const vectorRepoName = "rag-test-vector-repo"; const embeddingModel = "onnx:Xenova/all-MiniLM-L6-v2:q8"; const summaryModel = "onnx:Falconsai/text_summarization:fp32"; @@ -105,7 +105,7 @@ describe("RAG Workflow End-to-End", () => { const tabularRepo = new InMemoryTabularStorage(DocumentStorageSchema, DocumentStorageKey); await 
tabularRepo.setupDatabase(); - docRepo = new DocumentRepository(tabularRepo, vectorDataset as any); + docRepo = new DocumentDataset(tabularRepo, storage); }); afterAll(async () => { @@ -149,7 +149,7 @@ describe("RAG Workflow End-to-End", () => { .textEmbedding({ model: embeddingModel, }) - .vectorStoreUpsert({ + .chunkVectorUpsert({ dataset: vectorRepoName, }); diff --git a/packages/test/src/test/storage-tabular/InMemoryTabularRepository.test.ts b/packages/test/src/test/storage-tabular/InMemoryTabularRepository.test.ts index 0f7f0de3..8a9a5f4d 100644 --- a/packages/test/src/test/storage-tabular/InMemoryTabularRepository.test.ts +++ b/packages/test/src/test/storage-tabular/InMemoryTabularRepository.test.ts @@ -10,11 +10,16 @@ import { runGenericTabularRepositorySubscriptionTests } from "./genericTabularRe import { AllTypesPrimaryKeyNames, AllTypesSchema, + AutoIncrementPrimaryKeyNames, + AutoIncrementSchema, CompoundPrimaryKeyNames, CompoundSchema, + runAutoGeneratedKeyTests, runGenericTabularRepositoryTests, SearchPrimaryKeyNames, SearchSchema, + UuidPrimaryKeyNames, + UuidSchema, } from "./genericTabularRepositoryTests"; describe("InMemoryTabularStorage", () => { @@ -45,4 +50,17 @@ describe("InMemoryTabularStorage", () => { ), { usesPolling: false } ); + + runAutoGeneratedKeyTests( + async () => + new InMemoryTabularStorage( + AutoIncrementSchema, + AutoIncrementPrimaryKeyNames + ), + async () => + new InMemoryTabularStorage( + UuidSchema, + UuidPrimaryKeyNames + ) + ); }); diff --git a/packages/test/src/test/storage-tabular/IndexedDbTabularRepository.test.ts b/packages/test/src/test/storage-tabular/IndexedDbTabularRepository.test.ts index 63483bae..41d9da0a 100644 --- a/packages/test/src/test/storage-tabular/IndexedDbTabularRepository.test.ts +++ b/packages/test/src/test/storage-tabular/IndexedDbTabularRepository.test.ts @@ -14,11 +14,16 @@ import { runGenericTabularRepositorySubscriptionTests } from "./genericTabularRe import { AllTypesPrimaryKeyNames, AllTypesSchema, + AutoIncrementPrimaryKeyNames, + AutoIncrementSchema, CompoundPrimaryKeyNames, CompoundSchema, + runAutoGeneratedKeyTests, runGenericTabularRepositoryTests, SearchPrimaryKeyNames, SearchSchema, + UuidPrimaryKeyNames, + UuidSchema, } from "./genericTabularRepositoryTests"; describe("IndexedDbTabularStorage", () => { @@ -331,4 +336,19 @@ describe("IndexedDbTabularStorage", () => { }); }); }); + + runAutoGeneratedKeyTests( + async () => + new IndexedDbTabularStorage( + `${dbName}_autoinc_${uuid4().replace(/-/g, "_")}`, + AutoIncrementSchema, + AutoIncrementPrimaryKeyNames + ), + async () => + new IndexedDbTabularStorage( + `${dbName}_uuid_${uuid4().replace(/-/g, "_")}`, + UuidSchema, + UuidPrimaryKeyNames + ) + ); }); diff --git a/packages/test/src/test/storage-tabular/PostgresTabularRepository.test.ts b/packages/test/src/test/storage-tabular/PostgresTabularRepository.test.ts index 4db15b82..8d6fc70f 100644 --- a/packages/test/src/test/storage-tabular/PostgresTabularRepository.test.ts +++ b/packages/test/src/test/storage-tabular/PostgresTabularRepository.test.ts @@ -12,11 +12,16 @@ import { describe } from "vitest"; import { AllTypesPrimaryKeyNames, AllTypesSchema, + AutoIncrementPrimaryKeyNames, + AutoIncrementSchema, CompoundPrimaryKeyNames, CompoundSchema, + runAutoGeneratedKeyTests, runGenericTabularRepositoryTests, SearchPrimaryKeyNames, SearchSchema, + UuidPrimaryKeyNames, + UuidSchema, } from "./genericTabularRepositoryTests"; const db = new PGlite() as unknown as Pool; @@ -52,4 +57,21 @@ 
describe("PostgresTabularStorage", () => { return repo; } ); + + runAutoGeneratedKeyTests( + async () => + new PostgresTabularStorage( + db, + `autoinc_test_${uuid4().replace(/-/g, "_")}`, + AutoIncrementSchema, + AutoIncrementPrimaryKeyNames + ), + async () => + new PostgresTabularStorage( + db, + `uuid_test_${uuid4().replace(/-/g, "_")}`, + UuidSchema, + UuidPrimaryKeyNames + ) + ); }); diff --git a/packages/test/src/test/storage-tabular/SqliteTabularRepository.test.ts b/packages/test/src/test/storage-tabular/SqliteTabularRepository.test.ts index 673cf40e..bd19b2f2 100644 --- a/packages/test/src/test/storage-tabular/SqliteTabularRepository.test.ts +++ b/packages/test/src/test/storage-tabular/SqliteTabularRepository.test.ts @@ -10,11 +10,16 @@ import { describe } from "vitest"; import { AllTypesPrimaryKeyNames, AllTypesSchema, + AutoIncrementPrimaryKeyNames, + AutoIncrementSchema, CompoundPrimaryKeyNames, CompoundSchema, + runAutoGeneratedKeyTests, runGenericTabularRepositoryTests, SearchPrimaryKeyNames, SearchSchema, + UuidPrimaryKeyNames, + UuidSchema, } from "./genericTabularRepositoryTests"; describe("SqliteTabularStorage", () => { @@ -48,4 +53,21 @@ describe("SqliteTabularStorage", () => { return repo; } ); + + runAutoGeneratedKeyTests( + async () => + new SqliteTabularStorage( + ":memory:", + `autoinc_test_${uuid4().replace(/-/g, "_")}`, + AutoIncrementSchema, + AutoIncrementPrimaryKeyNames + ), + async () => + new SqliteTabularStorage( + ":memory:", + `uuid_test_${uuid4().replace(/-/g, "_")}`, + UuidSchema, + UuidPrimaryKeyNames + ) + ); }); diff --git a/packages/test/src/test/storage-tabular/genericTabularRepositoryTests.ts b/packages/test/src/test/storage-tabular/genericTabularRepositoryTests.ts index 6c1a3f8d..cea71594 100644 --- a/packages/test/src/test/storage-tabular/genericTabularRepositoryTests.ts +++ b/packages/test/src/test/storage-tabular/genericTabularRepositoryTests.ts @@ -77,6 +77,30 @@ export const AllTypesSchema = { additionalProperties: false, } as const satisfies DataPortSchemaObject; +export const AutoIncrementPrimaryKeyNames = ["id"] as const; +export const AutoIncrementSchema = { + type: "object", + properties: { + id: { type: "integer", "x-auto-generated": true }, + name: { type: "string" }, + email: { type: "string" }, + }, + required: ["id", "name", "email"], + additionalProperties: false, +} as const satisfies DataPortSchemaObject; + +export const UuidPrimaryKeyNames = ["id"] as const; +export const UuidSchema = { + type: "object", + properties: { + id: { type: "string", "x-auto-generated": true }, + title: { type: "string" }, + content: { type: "string" }, + }, + required: ["id", "title", "content"], + additionalProperties: false, +} as const satisfies DataPortSchemaObject; + export function runGenericTabularRepositoryTests( createCompoundPkRepository: () => Promise< ITabularStorage @@ -1159,3 +1183,158 @@ export function runGenericTabularRepositoryTests( }); } } + +/** + * Tests for auto-generated keys functionality + */ +export function runAutoGeneratedKeyTests( + createAutoIncrementRepository: () => Promise< + ITabularStorage + >, + createUuidRepository: () => Promise< + ITabularStorage + > +) { + describe("Auto-Generated Keys", () => { + describe("AutoIncrement Strategy", () => { + let repository: ITabularStorage; + + beforeEach(async () => { + repository = await createAutoIncrementRepository(); + await repository.setupDatabase?.(); + }); + + afterEach(async () => { + await repository.deleteAll(); + repository.destroy(); + }); + + it("should 
auto-generate integer ID when not provided", async () => { + const entity = { name: "Test User", email: "test@example.com" }; + const result = await repository.put(entity as any); + + expect(result.id).toBeDefined(); + expect(typeof result.id).toBe("number"); + expect(result.name).toBe("Test User"); + expect(result.email).toBe("test@example.com"); + }); + + it("should auto-generate sequential IDs", async () => { + const entity1 = { name: "User 1", email: "user1@example.com" }; + const entity2 = { name: "User 2", email: "user2@example.com" }; + const entity3 = { name: "User 3", email: "user3@example.com" }; + + const result1 = await repository.put(entity1 as any); + const result2 = await repository.put(entity2 as any); + const result3 = await repository.put(entity3 as any); + + expect(result1.id).toBeDefined(); + expect(result2.id).toBeDefined(); + expect(result3.id).toBeDefined(); + + // IDs should be sequential (though we don't enforce specific values) + expect(result2.id).toBeGreaterThan(result1.id); + expect(result3.id).toBeGreaterThan(result2.id); + }); + + it("should handle putBulk with auto-generated IDs", async () => { + const entities = [ + { name: "Bulk 1", email: "bulk1@example.com" }, + { name: "Bulk 2", email: "bulk2@example.com" }, + { name: "Bulk 3", email: "bulk3@example.com" }, + ]; + + const results = await repository.putBulk(entities as any); + + expect(results).toHaveLength(3); + for (const result of results) { + expect(result.id).toBeDefined(); + expect(typeof result.id).toBe("number"); + } + }); + + it("should retrieve entity by auto-generated ID", async () => { + const entity = { name: "Retrievable", email: "retrieve@example.com" }; + const inserted = await repository.put(entity as any); + + const retrieved = await repository.get({ id: inserted.id }); + + expect(retrieved).toBeDefined(); + expect(retrieved!.id).toBe(inserted.id); + expect(retrieved!.name).toBe("Retrievable"); + expect(retrieved!.email).toBe("retrieve@example.com"); + }); + }); + + describe("UUID Strategy", () => { + let repository: ITabularStorage; + + beforeEach(async () => { + repository = await createUuidRepository(); + await repository.setupDatabase?.(); + }); + + afterEach(async () => { + await repository.deleteAll(); + repository.destroy(); + }); + + it("should auto-generate UUID when not provided", async () => { + const entity = { title: "Test Doc", content: "Test content" }; + const result = await repository.put(entity as any); + + expect(result.id).toBeDefined(); + expect(typeof result.id).toBe("string"); + expect(result.id.length).toBeGreaterThan(0); + // UUID v4 format check (loose) + expect(result.id).toMatch(/^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i); + expect(result.title).toBe("Test Doc"); + expect(result.content).toBe("Test content"); + }); + + it("should generate unique UUIDs", async () => { + const entity1 = { title: "Doc 1", content: "Content 1" }; + const entity2 = { title: "Doc 2", content: "Content 2" }; + + const result1 = await repository.put(entity1 as any); + const result2 = await repository.put(entity2 as any); + + expect(result1.id).toBeDefined(); + expect(result2.id).toBeDefined(); + expect(result1.id).not.toBe(result2.id); + }); + + it("should handle putBulk with auto-generated UUIDs", async () => { + const entities = [ + { title: "Bulk Doc 1", content: "Bulk content 1" }, + { title: "Bulk Doc 2", content: "Bulk content 2" }, + { title: "Bulk Doc 3", content: "Bulk content 3" }, + ]; + + const results = await repository.putBulk(entities as any); + 
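+        // putBulk here delegates key generation to put() per entity (see the SQLite and
+        // Supabase implementations above), so each returned row should get a distinct UUID.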
+ expect(results).toHaveLength(3); + const ids = new Set(); + for (const result of results) { + expect(result.id).toBeDefined(); + expect(typeof result.id).toBe("string"); + ids.add(result.id); + } + // All IDs should be unique + expect(ids.size).toBe(3); + }); + + it("should retrieve entity by auto-generated UUID", async () => { + const entity = { title: "Retrievable", content: "Can be found" }; + const inserted = await repository.put(entity as any); + + const retrieved = await repository.get({ id: inserted.id }); + + expect(retrieved).toBeDefined(); + expect(retrieved!.id).toBe(inserted.id); + expect(retrieved!.title).toBe("Retrievable"); + expect(retrieved!.content).toBe("Can be found"); + }); + }); + }); +} diff --git a/packages/test/src/test/util/Document.test.ts b/packages/test/src/test/util/Document.test.ts index c81c3b15..5f3d5c4d 100644 --- a/packages/test/src/test/util/Document.test.ts +++ b/packages/test/src/test/util/Document.test.ts @@ -29,7 +29,7 @@ describe("Document", () => { ]; test("setChunks and getChunks", () => { - const doc = new Document("doc1", createTestDocumentNode(), { title: "Test" }); + const doc = new Document(createTestDocumentNode(), { title: "Test" }, [], "doc1"); doc.setChunks(createTestChunks()); @@ -40,7 +40,7 @@ describe("Document", () => { }); test("findChunksByNodeId", () => { - const doc = new Document("doc1", createTestDocumentNode(), { title: "Test" }); + const doc = new Document(createTestDocumentNode(), { title: "Test" }, [], "doc1"); doc.setChunks(createTestChunks()); diff --git a/packages/util/src/json-schema/JsonSchema.ts b/packages/util/src/json-schema/JsonSchema.ts index 7c3bb522..ffb9f42e 100644 --- a/packages/util/src/json-schema/JsonSchema.ts +++ b/packages/util/src/json-schema/JsonSchema.ts @@ -19,6 +19,7 @@ export type JsonSchemaCustomProps = { "x-ui-group-priority"?: number; "x-ui-group-open"?: boolean; "x-ui"?: unknown; + "x-auto-generated"?: boolean; // marks a primary key column as auto-generated by storage backend }; export type JsonSchema = From 72da8b1cff150fc90587c3bd8931a69e041625af Mon Sep 17 00:00:00 2001 From: Steven Roussey Date: Sun, 18 Jan 2026 07:03:14 +0000 Subject: [PATCH 14/14] [chore] Update GitHub Actions Workflow to Use Latest Bun Version --- .github/workflows/test.yml | 4 +++- packages/ai-provider/package.json | 4 ++-- packages/ai/package.json | 4 ++-- packages/ai/src/model/ModelRegistry.ts | 7 +------ packages/ai/src/task/DocumentEnricherTask.ts | 2 +- packages/dataset/package.json | 4 ++-- packages/debug/package.json | 4 ++-- packages/job-queue/package.json | 4 ++-- packages/sqlite/package.json | 8 ++++---- packages/storage/package.json | 4 ++-- packages/task-graph/package.json | 4 ++-- packages/task-graph/src/task/InputResolver.ts | 8 +++++--- packages/tasks/package.json | 4 ++-- packages/test/package.json | 4 ++-- packages/util/package.json | 4 ++-- 15 files changed, 34 insertions(+), 35 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c72cf215..0a1e0e8c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,6 +16,8 @@ jobs: steps: - uses: actions/checkout@v4 - uses: oven-sh/setup-bun@v2 + with: + bun-version: latest - run: bun i - - run: bun run build + - run: bun run rebuild - run: bun test diff --git a/packages/ai-provider/package.json b/packages/ai-provider/package.json index 3db05d11..e13702e7 100644 --- a/packages/ai-provider/package.json +++ b/packages/ai-provider/package.json @@ -17,8 +17,8 @@ }, "exports": { ".": { - "bun": 
"./src/index.ts", - "types": "./src/types.ts", + "bun": "./dist/index.js", + "types": "./dist/types.d.ts", "import": "./dist/index.js" } }, diff --git a/packages/ai/package.json b/packages/ai/package.json index 218a0f7f..66c27682 100644 --- a/packages/ai/package.json +++ b/packages/ai/package.json @@ -23,8 +23,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./src/bun.ts", - "types": "./src/types.ts", + "bun": "./dist/bun.js", + "types": "./dist/types.d.ts", "import": "./dist/node.js" } }, diff --git a/packages/ai/src/model/ModelRegistry.ts b/packages/ai/src/model/ModelRegistry.ts index f570dcf8..1256ec49 100644 --- a/packages/ai/src/model/ModelRegistry.ts +++ b/packages/ai/src/model/ModelRegistry.ts @@ -52,16 +52,11 @@ async function resolveModelFromRegistry( id: string, format: string, registry: ServiceRegistry -): Promise { +): Promise { const modelRepo = registry.has(MODEL_REPOSITORY) ? registry.get(MODEL_REPOSITORY) : getGlobalModelRepository(); - if (Array.isArray(id)) { - const results = await Promise.all(id.map((i) => modelRepo.findByName(i))); - return results.filter((model): model is NonNullable => model !== undefined); - } - const model = await modelRepo.findByName(id); if (!model) { throw new Error(`Model "${id}" not found in repository`); diff --git a/packages/ai/src/task/DocumentEnricherTask.ts b/packages/ai/src/task/DocumentEnricherTask.ts index c1900b0b..4558f6cf 100644 --- a/packages/ai/src/task/DocumentEnricherTask.ts +++ b/packages/ai/src/task/DocumentEnricherTask.ts @@ -374,7 +374,7 @@ export class DocumentEnricherTask extends Task< } } - const text = node.text.trim(); + const text = node.text?.trim(); if (text && extract) { const nodeEntities = await extract(text); if (nodeEntities?.length) { diff --git a/packages/dataset/package.json b/packages/dataset/package.json index ce883eab..9bd8d372 100644 --- a/packages/dataset/package.json +++ b/packages/dataset/package.json @@ -39,8 +39,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./src/bun.ts", - "types": "./src/types.ts", + "bun": "./dist/bun.js", + "types": "./dist/types.d.ts", "import": "./dist/node.js" } }, diff --git a/packages/debug/package.json b/packages/debug/package.json index ad0577e6..f5f48dd4 100644 --- a/packages/debug/package.json +++ b/packages/debug/package.json @@ -19,8 +19,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./src/browser.ts", - "types": "./src/types.ts", + "bun": "./dist/browser.js", + "types": "./dist/types.d.ts", "import": "./dist/browser.js" } }, diff --git a/packages/job-queue/package.json b/packages/job-queue/package.json index 85a4fe08..be6b080e 100644 --- a/packages/job-queue/package.json +++ b/packages/job-queue/package.json @@ -23,8 +23,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./src/bun.ts", - "types": "./src/types.ts", + "bun": "./dist/bun.js", + "types": "./dist/types.d.ts", "import": "./dist/node.js" } }, diff --git a/packages/sqlite/package.json b/packages/sqlite/package.json index 2518526f..73f530a8 100644 --- a/packages/sqlite/package.json +++ b/packages/sqlite/package.json @@ -38,13 +38,13 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./src/bun.ts", - "types": "./src/types.ts", + "bun": "./dist/bun.js", + "types": "./dist/types.d.ts", "import": "./dist/node.js" }, "./bun": { - "types": "./src/bun.ts", - "import": "./src/bun.ts" + "types": "./dist/bun.d.ts", 
+ "import": "./dist/bun.js" }, "./node": { "types": "./dist/node.d.ts", diff --git a/packages/storage/package.json b/packages/storage/package.json index d09c4bd4..237807c4 100644 --- a/packages/storage/package.json +++ b/packages/storage/package.json @@ -45,8 +45,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./src/bun.ts", - "types": "./src/types.ts", + "bun": "./dist/bun.js", + "types": "./dist/types.d.ts", "import": "./dist/node.js" } }, diff --git a/packages/task-graph/package.json b/packages/task-graph/package.json index 86955350..61e0f400 100644 --- a/packages/task-graph/package.json +++ b/packages/task-graph/package.json @@ -23,8 +23,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./src/bun.ts", - "types": "./src/types.ts", + "bun": "./dist/bun.js", + "types": "./dist/types.d.ts", "import": "./dist/node.js" } }, diff --git a/packages/task-graph/src/task/InputResolver.ts b/packages/task-graph/src/task/InputResolver.ts index 87439126..59dbfbcc 100644 --- a/packages/task-graph/src/task/InputResolver.ts +++ b/packages/task-graph/src/task/InputResolver.ts @@ -101,10 +101,12 @@ export async function resolveSchemaInputs>( if (typeof value === "string") { resolved[key] = await resolver(value, format, config.registry); } - // Handle arrays of strings - pass the entire array to the resolver - // (resolvers like resolveModelFromRegistry handle arrays even though typed as string) + // Handle arrays of strings - iterate and resolve each element else if (Array.isArray(value) && value.every((item) => typeof item === "string")) { - resolved[key] = await resolver(value as unknown as string, format, config.registry); + const results = await Promise.all( + (value as string[]).map((item) => resolver(item, format, config.registry)) + ); + resolved[key] = results.filter((result) => result !== undefined); } // Skip if not a string or array of strings (already resolved or direct instance) } diff --git a/packages/tasks/package.json b/packages/tasks/package.json index 5062a565..335e494a 100644 --- a/packages/tasks/package.json +++ b/packages/tasks/package.json @@ -23,8 +23,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./src/bun.ts", - "types": "./src/types.ts", + "bun": "./dist/bun.js", + "types": "./dist/types.d.ts", "import": "./dist/node.js" } }, diff --git a/packages/test/package.json b/packages/test/package.json index 985c87e1..8c4077e3 100644 --- a/packages/test/package.json +++ b/packages/test/package.json @@ -23,8 +23,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./src/bun.ts", - "types": "./src/types.ts", + "bun": "./dist/bun.js", + "types": "./dist/types.d.ts", "import": "./dist/node.js" } }, diff --git a/packages/util/package.json b/packages/util/package.json index 4cb1bb5f..bb841966 100644 --- a/packages/util/package.json +++ b/packages/util/package.json @@ -23,8 +23,8 @@ ".": { "react-native": "./dist/browser.js", "browser": "./dist/browser.js", - "bun": "./src/bun.ts", - "types": "./src/types.ts", + "bun": "./dist/bun.js", + "types": "./dist/types.d.ts", "import": "./dist/node.js" } },