diff --git a/packages/core/src/com/caller-context.ts b/packages/core/src/com/caller-context.ts new file mode 100644 index 000000000..757e860fd --- /dev/null +++ b/packages/core/src/com/caller-context.ts @@ -0,0 +1,20 @@ +export interface CallerContext { + getStore(): unknown; + run(store: unknown, callback: () => R): R; +} + +let activeCallerContext: CallerContext | undefined; +export function setActiveCallerContext(ctx: CallerContext | undefined): void { + activeCallerContext = ctx; +} + +export function getCurrentCaller(): unknown { + return activeCallerContext?.getStore(); +} + +export function runWithCaller(identity: unknown, fn: () => R): R { + if (identity !== undefined && activeCallerContext) { + return activeCallerContext.run(identity, fn); + } + return fn(); +} diff --git a/packages/core/src/com/communication.ts b/packages/core/src/com/communication.ts index 1d741a046..62a945f06 100644 --- a/packages/core/src/com/communication.ts +++ b/packages/core/src/com/communication.ts @@ -62,6 +62,7 @@ import { ReconnectFunction, RemoteValueListener, } from '../remote-value.js'; +import { getCurrentCaller, runWithCaller } from './caller-context.js'; export interface ConfigEnvironmentRecord extends EnvironmentRecord { registerMessageHandler?: boolean; @@ -373,6 +374,7 @@ export class Communication { data: { api, method, args: serializeApiCallArguments(args) }, callbackId, origin, + callerIdentity: getCurrentCaller(), }; this.callWithCallback(envId, message, callbackId, res, rej); } @@ -698,6 +700,7 @@ export class Communication { type: 'callback', forwardingChain: message.forwardingChain, error: new CircularForwardingError(message, this.rootEnvId, env.id), + callerIdentity: message.callerIdentity, }); return; } else if (this.DEBUG) { @@ -771,6 +774,7 @@ export class Communication { }, handlerId: listenerHandlerId, origin, + callerIdentity: getCurrentCaller(), }; // sometimes the callback will never happen since target environment is already dead this.sendTo(envId, message); @@ -802,6 +806,7 @@ export class Communication { handlerId: '', callbackId, origin, + callerIdentity: getCurrentCaller(), }; message.handlerId = this.createHandlerRecord(envId, api, method, fn, message); @@ -939,7 +944,9 @@ export class Communication { const dispatcher = this.eventDispatchers.get(namespacedHandlerId)?.dispatcher; if (dispatcher) { this.eventDispatchers.delete(namespacedHandlerId); - const data = await this.apiCall(message.origin, message.data.api, message.data.method, [dispatcher]); + const data = await runWithCaller(message.callerIdentity, () => + this.apiCall(message.origin, message.data.api, message.data.method, [dispatcher]), + ); if (message.callbackId) { this.sendTo(message.from, { to: message.from, @@ -957,7 +964,9 @@ export class Communication { try { const dispatcher = this.getDispatcher(message.from, message); - const data = await this.apiCall(message.origin, message.data.api, message.data.method, [dispatcher]); + const data = await runWithCaller(message.callerIdentity, () => + this.apiCall(message.origin, message.data.api, message.data.method, [dispatcher]), + ); if (message.callbackId) { this.sendTo(message.from, { @@ -984,7 +993,9 @@ export class Communication { private async handleCall(message: CallMessage): Promise { try { const args = deserializeApiCallArguments(message.data.args); - const data = await this.apiCall(message.origin, message.data.api, message.data.method, args); + const data = await runWithCaller(message.callerIdentity, () => + this.apiCall(message.origin, message.data.api, message.data.method, args), + ); if (message.callbackId) { this.sendTo(message.from, { diff --git a/packages/core/src/com/index.ts b/packages/core/src/com/index.ts index 387a4db9a..d9eba9b65 100644 --- a/packages/core/src/com/index.ts +++ b/packages/core/src/com/index.ts @@ -11,3 +11,4 @@ export * from './hosts/ws-client-host.js'; export * from './initializers/index.js'; export * from './hosts/index.js'; export * from './communication-errors.js'; +export * from './caller-context.js'; diff --git a/packages/core/src/com/message-types.ts b/packages/core/src/com/message-types.ts index 03e4c7c20..590e4614b 100644 --- a/packages/core/src/com/message-types.ts +++ b/packages/core/src/com/message-types.ts @@ -13,6 +13,7 @@ export interface BaseMessage { error?: Error; origin: string; forwardingChain?: string[]; + callerIdentity?: unknown; } export interface CallMessage extends BaseMessage { diff --git a/packages/core/src/communication.feature.ts b/packages/core/src/communication.feature.ts index f30e86c74..9f335a222 100644 --- a/packages/core/src/communication.feature.ts +++ b/packages/core/src/communication.feature.ts @@ -1,5 +1,6 @@ import { BaseHost } from './com/hosts/base-host.js'; import { Communication, type ConfigEnvironmentRecord, type CommunicationOptions } from './com/communication.js'; +import { setActiveCallerContext, type CallerContext } from './com/caller-context.js'; import { LoggerService } from './com/logger-service.js'; import type { Target } from './com/types.js'; import { Config } from './entities/config.js'; @@ -22,6 +23,7 @@ export interface IComConfig { connectedEnvironments?: { [environmentId: string]: ConfigEnvironmentRecord; }; + callerContext?: CallerContext; } export default class COM extends Feature<'COM'> { id = 'COM' as const; @@ -70,6 +72,7 @@ COM.setup( resolvedContexts, publicPath, connectedEnvironments = {}, + callerContext, }, loggerTransports, [RUN_OPTIONS]: runOptions, @@ -102,6 +105,7 @@ COM.setup( isNode, comOptions, ); + setActiveCallerContext(callerContext); // manually register window initialization api service to be used during // start of managed iframe in packages/core/src/com/initializers/iframe.ts communication.registerAPI({ id: WindowInitializerService.apiId }, new WindowInitializerService()); @@ -110,7 +114,10 @@ COM.setup( { environment: communication.getEnvironmentId() }, { severity: loggerSeverity, maxLogMessages, logToConsole }, ); - onDispose(() => communication.dispose()); + onDispose(() => { + setActiveCallerContext(undefined); + return communication.dispose(); + }); return { loggerService, communication, diff --git a/packages/runtime-node/src/node-env-manager.ts b/packages/runtime-node/src/node-env-manager.ts index 4ea34bae8..f306001dd 100644 --- a/packages/runtime-node/src/node-env-manager.ts +++ b/packages/runtime-node/src/node-env-manager.ts @@ -10,7 +10,7 @@ import { IDisposable, SetMultiMap } from '@dazl/patterns'; import { fileURLToPath } from 'node:url'; import { parseArgs } from 'node:util'; import { extname } from 'node:path'; -import { ConnectionHandlers, WsServerHost } from './ws-node-host.js'; +import { ConnectionHandlers, IdentityExtractor, WsServerHost } from './ws-node-host.js'; import { ILaunchHttpServerOptions, launchEngineHttpServer } from './launch-http-server.js'; import { workerThreadInitializer2 } from './worker-thread-initializer2.js'; import { bindMetricsListener, type PerformanceMetrics } from './metrics-utils.js'; @@ -52,9 +52,11 @@ export class NodeEnvManager implements IDisposable { runtimeOptions: Map, { connectionHandlers, + identityExtractor, ...serverOptions }: ILaunchHttpServerOptions & { connectionHandlers?: ConnectionHandlers; + identityExtractor?: IdentityExtractor; } = {}, lazy = false, ) { @@ -67,6 +69,9 @@ export class NodeEnvManager implements IDisposable { runtimeOptions.set('enginePort', port.toString()); const clientsHost = new WsServerHost(socketServer); + if (identityExtractor) { + clientsHost.setIdentityExtractor(identityExtractor); + } const disposeOnConnectionOpen = connectionHandlers?.onConnectionOpen ? clientsHost.registerConnectionHandler(connectionHandlers.onConnectionOpen) : undefined; diff --git a/packages/runtime-node/src/ws-node-host.ts b/packages/runtime-node/src/ws-node-host.ts index 33d1196dc..15063c6e9 100644 --- a/packages/runtime-node/src/ws-node-host.ts +++ b/packages/runtime-node/src/ws-node-host.ts @@ -14,6 +14,12 @@ export interface ConnectionHandlers { onConnectionReconnect?: IConnectionHandler; } +/** + * Extracts caller identity from Socket.IO handshake. + * Returned value must be serializable (JSON-safe) since it is added to messages. + */ +export type IdentityExtractor = (handshake: io.Socket['handshake']) => T; + export class WsHost extends BaseHost { constructor(private socket: io.Socket) { super(); @@ -32,6 +38,8 @@ export class WsServerHost extends BaseHost implements IDisposable { private reconnectionHandlers = new Set(); private socketToEnvId = new Map(); private clientIdToSocket = new Map(); + private identityStore = new Map(); + private identityExtractor?: IdentityExtractor; private disposables = new SafeDisposable(WsServerHost.name); dispose = this.disposables.dispose; isDisposed = this.disposables.isDisposed; @@ -46,6 +54,10 @@ export class WsServerHost extends BaseHost implements IDisposable { this.disposables.add('clear reconnection handlers', () => this.reconnectionHandlers.clear()); } + public setIdentityExtractor(extractor: IdentityExtractor) { + this.identityExtractor = extractor; + } + public registerConnectionHandler(handler: IConnectionHandler) { this.connectionHandlers.add(handler); return () => { @@ -98,6 +110,11 @@ export class WsServerHost extends BaseHost implements IDisposable { const existingSocket = this.clientIdToSocket.get(clientId); this.clientIdToSocket.set(clientId, socket); + if (this.identityExtractor) { + this.callWithErrorHandling(() => { + this.identityStore.set(clientId, this.identityExtractor!(socket.handshake)); + }); + } const connectionEvent: IConnectionEvent = { clientId, @@ -128,6 +145,7 @@ export class WsServerHost extends BaseHost implements IDisposable { // modify message to be able to forward it message.from = fromId; message.origin = originId; + message.callerIdentity = this.identityStore.get(clientId); // 'callerIdentity' must be always set on server side to avoid spoofing from client this.emitMessageHandlers(message); }; @@ -149,6 +167,7 @@ export class WsServerHost extends BaseHost implements IDisposable { } if (this.clientIdToSocket.get(clientId) === socket) { this.clientIdToSocket.delete(clientId); + this.identityStore.delete(clientId); for (const handler of this.disconnectionHandlers) { this.callWithErrorHandling(() => handler(connectionEvent)); } diff --git a/packages/runtime-node/test-kit/entrypoints/c.node.ts b/packages/runtime-node/test-kit/entrypoints/c.node.ts new file mode 100644 index 000000000..9ae5030dd --- /dev/null +++ b/packages/runtime-node/test-kit/entrypoints/c.node.ts @@ -0,0 +1,73 @@ +import { bindMetricsListener, bindRpcListener, ParentPortHost } from '@dazl/engine-runtime-node'; +import { workerData } from 'node:worker_threads'; +import { AsyncLocalStorage } from 'node:async_hooks'; +import { COM, FeatureClass, RuntimeEngine, TopLevelConfig } from '@dazl/engine-core'; +import CallerIdentityFeature from '../feature/caller-identity.feature.js'; +import { cEnv } from '../feature/envs.js'; +import '../feature/caller-identity.c.env.js'; + +const options = workerData?.runtimeOptions as Map | undefined; +const verbose = options?.get('verbose') ?? false; +const env = cEnv; + +if (verbose) { + console.log(`[${env.env}: Started with options: `, options); +} + +let activateValue: unknown; +export function getActivateValue() { + return activateValue; +} + +export function runEnv({ + Feature = CallerIdentityFeature, + topLevelConfig = [], +}: { Feature?: FeatureClass; topLevelConfig?: TopLevelConfig } = {}) { + return new RuntimeEngine( + env, + [ + ...(workerData + ? [ + COM.configure({ + config: { + host: new ParentPortHost(env.env), + id: env.env, + callerContext: new AsyncLocalStorage(), + }, + }), + ] + : []), + ...topLevelConfig, + ], + new Map(options?.entries() ?? []), + ).run(Feature); +} + +if (workerData) { + const unbindMetricsListener = bindMetricsListener(); + let running: ReturnType; + const unbindActivateListener = bindRpcListener('activate', (value: unknown) => { + activateValue = value; + unbindActivateListener(); + running = runEnv(); + }); + const unbindTerminationListener = bindRpcListener('terminate', async () => { + if (verbose) { + console.log(`[${env.env}]: Termination Requested. Waiting for engine.`); + } + unbindTerminationListener(); + unbindMetricsListener(); + try { + const engine = await running; + if (verbose) { + console.log(`[${env.env}]: Terminating`); + } + return engine.shutdown(); + } catch (e) { + console.error('[${env.name}]: Error while shutting down', e); + return; + } + }); +} else { + console.log('running engine in test mode'); +} diff --git a/packages/runtime-node/test-kit/feature/caller-identity.c.env.ts b/packages/runtime-node/test-kit/feature/caller-identity.c.env.ts new file mode 100644 index 000000000..4d0faa018 --- /dev/null +++ b/packages/runtime-node/test-kit/feature/caller-identity.c.env.ts @@ -0,0 +1,11 @@ +import { getCurrentCaller } from '@dazl/engine-core'; +import { cEnv } from './envs.js'; +import CallerIdentityFeature from './caller-identity.feature.js'; + +CallerIdentityFeature.setup(cEnv, () => { + return { + identityService: { + whoAmI: () => getCurrentCaller(), + }, + }; +}); diff --git a/packages/runtime-node/test-kit/feature/caller-identity.feature.ts b/packages/runtime-node/test-kit/feature/caller-identity.feature.ts new file mode 100644 index 000000000..118193394 --- /dev/null +++ b/packages/runtime-node/test-kit/feature/caller-identity.feature.ts @@ -0,0 +1,11 @@ +import { COM, Feature, Service } from '@dazl/engine-core'; +import { cEnv } from './envs.js'; +import { IdentityService } from './types.js'; + +export default class CallerIdentityFeature extends Feature<'caller-identity'> { + id = 'caller-identity' as const; + dependencies = [COM]; + api = { + identityService: Service.withType().defineEntity(cEnv).allowRemoteAccess(), + }; +} diff --git a/packages/runtime-node/test-kit/feature/envs.ts b/packages/runtime-node/test-kit/feature/envs.ts index 6a1aa466a..c52823311 100644 --- a/packages/runtime-node/test-kit/feature/envs.ts +++ b/packages/runtime-node/test-kit/feature/envs.ts @@ -2,3 +2,4 @@ import { Environment } from '@dazl/engine-core'; export const aEnv = new Environment('a', 'node', 'single'); export const bEnv = new Environment('b', 'node', 'single'); +export const cEnv = new Environment('c', 'node', 'single'); diff --git a/packages/runtime-node/test-kit/feature/types.ts b/packages/runtime-node/test-kit/feature/types.ts index 52c357b3d..74fc618f4 100644 --- a/packages/runtime-node/test-kit/feature/types.ts +++ b/packages/runtime-node/test-kit/feature/types.ts @@ -3,3 +3,7 @@ export type EchoService = { echoChained: () => Promise; getActivateValue: () => unknown; }; + +export type IdentityService = { + whoAmI: () => unknown; +}; diff --git a/packages/runtime-node/test/caller-identity.spec.ts b/packages/runtime-node/test/caller-identity.spec.ts new file mode 100644 index 000000000..0d1df8d8e --- /dev/null +++ b/packages/runtime-node/test/caller-identity.spec.ts @@ -0,0 +1,50 @@ +import { createDisposables } from '@dazl/create-disposables'; +import { BaseHost, Communication, WsClientHost } from '@dazl/engine-core'; +import { NodeEnvManager, type NodeEnvsFeatureMapping } from '@dazl/engine-runtime-node'; +import { expect } from 'chai'; +import { cEnv } from '../test-kit/feature/envs.js'; +import { IdentityService } from '../test-kit/feature/types.js'; + +describe('Caller identity propagation with autoLaunch', () => { + const disposables = createDisposables(); + const disposeAfterTest = void }>(obj: T) => { + disposables.add(() => obj.dispose()); + return obj; + }; + + afterEach(() => disposables.dispose()); + + /** + * Cross-process: client -> autoLaunched gateway -> worker-thread env. + */ + it('propagates caller identity from client to worker-thread env', async () => { + const featureEnvironmentsMapping: NodeEnvsFeatureMapping = { + featureToEnvironments: { + 'caller-identity': [cEnv.env], + }, + availableEnvironments: { + [cEnv.env]: { env: cEnv.env, endpointType: 'single', envType: 'node' }, + }, + }; + + const meta = { url: import.meta.resolve('../test-kit/entrypoints/') }; + const manager = new NodeEnvManager(meta, featureEnvironmentsMapping); + disposables.add(() => manager.dispose()); + + const { port: gatewayPort } = await manager.autoLaunch(new Map([['feature', 'caller-identity']]), { + identityExtractor: (handshake) => ({ userId: handshake.auth?.userId ?? 'anonymous' }), + }); + + const clientHost = disposeAfterTest( + new WsClientHost(`http://localhost:${gatewayPort}`, { auth: { userId: 'worker-user' } }), + ); + await clientHost.connected; + const clientCom = disposeAfterTest(new Communication(new BaseHost(), 'client-host')); + clientCom.registerEnv(cEnv.env, clientHost); + clientCom.registerMessageHandler(clientHost); + + const api = clientCom.apiProxy({ id: cEnv.env }, { id: 'caller-identity.identityService' }); + + expect(await api.whoAmI()).to.deep.equal({ userId: 'worker-user' }); + }); +}); diff --git a/packages/runtime-node/test/caller-identity.unit.ts b/packages/runtime-node/test/caller-identity.unit.ts new file mode 100644 index 000000000..abef9b53d --- /dev/null +++ b/packages/runtime-node/test/caller-identity.unit.ts @@ -0,0 +1,319 @@ +import { createDisposables } from '@dazl/create-disposables'; +import { BaseHost, Communication, WsClientHost, getCurrentCaller, setActiveCallerContext } from '@dazl/engine-core'; +import { WsServerHost } from '@dazl/engine-runtime-node'; +import { expect } from 'chai'; +import { safeListeningHttpServer } from 'create-listening-server'; +import { AsyncLocalStorage } from 'node:async_hooks'; +import type { Socket } from 'node:net'; +import { deferred } from 'promise-assist'; +import * as io from 'socket.io'; + +interface IIdentityTestApi { + whoAmI: () => unknown; +} + +interface IAsyncIdentityTestApi { + whoAmI: () => Promise; +} + +interface IGatedApi { + /** Resolves once `gate` resolves; returns identity captured at call entry and at completion. */ + callGated: (label: string) => Promise<{ label: string; entry: unknown; exit: unknown }>; +} + +describe('Caller identity propagation', () => { + const COMMUNICATION_ID = 'identity-test'; + + let socketServer: io.Server | undefined; + let serverTopology: Record = {}; + let port: number; + + const disposables = createDisposables(); + const disposeAfterTest = void }>(obj: T) => { + disposables.add(() => obj.dispose()); + return obj; + }; + afterEach(() => disposables.dispose()); + + beforeEach(async () => { + setActiveCallerContext(new AsyncLocalStorage()); + disposables.add(() => setActiveCallerContext(undefined)); + const { httpServer: server, port: servingPort } = await safeListeningHttpServer(3060); + port = servingPort; + socketServer = new io.Server(server, { cors: {} }); + const connections = new Set(); + disposables.add(async () => { + await socketServer?.close(); + socketServer = undefined; + }); + disposables.add(() => (serverTopology = {})); + const onConnection = (connection: Socket): void => { + connections.add(connection); + disposables.add(() => { + connections.delete(connection); + }); + }; + server.on('connection', onConnection); + disposables.add(() => { + for (const connection of connections) { + connection.destroy(); + } + }); + }); + + /** + * Single-hop: client → server. The server extracts identity from the handshake, + * stamps every inbound message, and the API handler reads it via getCurrentCaller(). + */ + it('exposes caller identity inside server-side API handler', async () => { + const nameSpace = socketServer!.of('processing'); + serverTopology['server-host'] = `http://localhost:${port}/processing`; + + const serverHost = disposeAfterTest(new WsServerHost(nameSpace)); + serverHost.setIdentityExtractor((handshake) => ({ userId: handshake.auth?.userId ?? 'anonymous' })); + + const serverCom = new Communication(serverHost, 'server-host', {}, {}, true); + serverCom.registerAPI( + { id: COMMUNICATION_ID }, + { + whoAmI: () => getCurrentCaller(), + }, + ); + + const clientHost = disposeAfterTest( + new WsClientHost(serverTopology['server-host'], { auth: { userId: 'u-42' } }), + ); + await clientHost.connected; + const clientCom = new Communication(clientHost, 'client-host', serverTopology); + + const api = clientCom.apiProxy({ id: 'server-host' }, { id: COMMUNICATION_ID }); + + expect(await api.whoAmI()).to.deep.equal({ userId: 'u-42' }); + }); + + /** + * Concurrent calls from different clients to the same server-side API must each + * see their own identity. AsyncLocalStorage is the mechanism that prevents the + * identities from bleeding across in-flight calls. + */ + it('keeps caller identities isolated across concurrent in-flight calls', async () => { + const nameSpace = socketServer!.of('processing'); + serverTopology['server-host'] = `http://localhost:${port}/processing`; + + const serverHost = disposeAfterTest(new WsServerHost(nameSpace)); + serverHost.setIdentityExtractor((handshake) => ({ userId: handshake.auth?.userId as string })); + + const gateA = deferred(); + const gateB = deferred(); + const gateC = deferred(); + const gates: Record; resolve: () => void }> = { + A: gateA, + B: gateB, + C: gateC, + }; + + const serverCom = new Communication(serverHost, 'server-host', {}, {}, true); + serverCom.registerAPI( + { id: COMMUNICATION_ID }, + { + callGated: async (label) => { + const entry = getCurrentCaller(); + await gates[label]!.promise; + const exit = getCurrentCaller(); + return { label, entry, exit }; + }, + }, + ); + + const makeClient = async (userId: string) => { + const host = disposeAfterTest(new WsClientHost(serverTopology['server-host']!, { auth: { userId } })); + await host.connected; + const com = new Communication(host, `client-${userId}`, serverTopology); + return com.apiProxy({ id: 'server-host' }, { id: COMMUNICATION_ID }); + }; + + const apiA = await makeClient('user-A'); + const apiB = await makeClient('user-B'); + const apiC = await makeClient('user-C'); + + // Fire all three concurrently; they all suspend inside the handler awaiting their gate. + const pA = apiA.callGated('A'); + const pB = apiB.callGated('B'); + const pC = apiC.callGated('C'); + + // Resolve in a deliberately interleaved order so the runtime cannot rely on FIFO. + gateB.resolve(); + gateC.resolve(); + gateA.resolve(); + + const [resA, resB, resC] = await Promise.all([pA, pB, pC]); + + // Identity must be stable from entry through completion of each handler invocation. + expect(resA).to.deep.equal({ + label: 'A', + entry: { userId: 'user-A' }, + exit: { userId: 'user-A' }, + }); + expect(resB).to.deep.equal({ + label: 'B', + entry: { userId: 'user-B' }, + exit: { userId: 'user-B' }, + }); + expect(resC).to.deep.equal({ + label: 'C', + entry: { userId: 'user-C' }, + exit: { userId: 'user-C' }, + }); + }); + + /** + * Chained: client → processing → workspace. + * + * The original client connects via socket; the WsServerHost stamps its identity + * on every inbound message. The processing env handles the call and delegates + * to the workspace env (in-process, sharing a BaseHost bus). The workspace + * handler must observe the *original* client's identity, not undefined and not + * processing's own. + */ + it('propagates caller identity through chained service calls (client → processing → workspace)', async () => { + const nameSpace = socketServer!.of('processing'); + serverTopology['processing-env'] = `http://localhost:${port}/processing`; + + const serverHost = disposeAfterTest(new WsServerHost(nameSpace)); + serverHost.setIdentityExtractor((handshake) => ({ userId: handshake.auth?.userId as string })); + + // Shared in-process bus between processing and workspace. Processing also + // listens on the WsServerHost so it receives messages coming from the + // socket client. + const innerBus = new BaseHost('inner-bus'); + + const processingCom = new Communication(innerBus, 'processing-env', {}, {}, true); + processingCom.registerMessageHandler(serverHost); + + const workspaceCom = new Communication(innerBus, 'workspace-env', {}, {}, true); + + workspaceCom.registerAPI({ id: 'workspace-api' }, { whoAmI: () => getCurrentCaller() }); + + const workspaceProxy = processingCom.apiProxy( + { id: 'workspace-env' }, + { id: 'workspace-api' }, + ); + + processingCom.registerAPI( + { id: 'processing-api' }, + { whoAmI: () => workspaceProxy.whoAmI() }, + ); + + const clientHost = disposeAfterTest( + new WsClientHost(serverTopology['processing-env'], { auth: { userId: 'u-99' } }), + ); + await clientHost.connected; + const clientCom = new Communication(clientHost, 'client-env', serverTopology); + + const api = clientCom.apiProxy({ id: 'processing-env' }, { id: 'processing-api' }); + + expect(await api.whoAmI()).to.deep.equal({ userId: 'u-99' }); + }); + + /** + * A client subscribes to a server-side event API. The subscribe handler on the + * server should observe the original client's identity at subscription time + * (e.g. so the service can scope events to that user). + */ + it('propagates caller identity into a remote listener subscription handler', async () => { + interface ISubscribableApi { + sub: (cb: (data: string) => void) => void; + unsub: (cb: (data: string) => void) => void; + } + + const nameSpace = socketServer!.of('processing'); + serverTopology['server-host'] = `http://localhost:${port}/processing`; + + const serverHost = disposeAfterTest(new WsServerHost(nameSpace)); + serverHost.setIdentityExtractor((handshake) => ({ userId: handshake.auth?.userId as string })); + + const subscriberIdentity = deferred(); + + const serverCom = new Communication(serverHost, 'server-host', {}, {}, true); + serverCom.registerAPI( + { id: COMMUNICATION_ID }, + { + sub: (_cb) => { + subscriberIdentity.resolve(getCurrentCaller()); + }, + unsub: (_cb) => {}, + }, + ); + + const clientHost = disposeAfterTest( + new WsClientHost(serverTopology['server-host'], { auth: { userId: 'sub-user' } }), + ); + await clientHost.connected; + const clientCom = new Communication(clientHost, 'client-host', serverTopology); + + const api = clientCom.apiProxy( + { id: 'server-host' }, + { id: COMMUNICATION_ID }, + { + sub: { listener: true, removeListener: 'unsub' }, + unsub: { removeListener: 'sub' }, + }, + ); + + await api.sub(() => {}); + + expect(await subscriberIdentity.promise).to.deep.equal({ userId: 'sub-user' }); + }); + + /** + * Same idea for unsubscribe: when the last listener is removed and the runtime + * sends an UnListenMessage, the server-side unsub handler should observe the + * caller's identity (e.g. to clean up per-user state). + */ + it('propagates caller identity into a remote listener unsubscription handler', async () => { + interface ISubscribableApi { + sub: (cb: (data: string) => void) => void; + unsub: (cb: (data: string) => void) => void; + } + + const nameSpace = socketServer!.of('processing'); + serverTopology['server-host'] = `http://localhost:${port}/processing`; + + const serverHost = disposeAfterTest(new WsServerHost(nameSpace)); + serverHost.setIdentityExtractor((handshake) => ({ userId: handshake.auth?.userId as string })); + + const unsubscriberIdentity = deferred(); + + const serverCom = new Communication(serverHost, 'server-host', {}, {}, true); + serverCom.registerAPI( + { id: COMMUNICATION_ID }, + { + sub: (_cb) => {}, + unsub: (_cb) => { + unsubscriberIdentity.resolve(getCurrentCaller()); + }, + }, + ); + + const clientHost = disposeAfterTest( + new WsClientHost(serverTopology['server-host'], { auth: { userId: 'unsub-user' } }), + ); + await clientHost.connected; + const clientCom = new Communication(clientHost, 'client-host', serverTopology); + + const api = clientCom.apiProxy( + { id: 'server-host' }, + { id: COMMUNICATION_ID }, + { + sub: { listener: true, removeListener: 'unsub' }, + unsub: { removeListener: 'sub' }, + }, + ); + + const handler = () => {}; + await api.sub(handler); + await api.unsub(handler); + + expect(await unsubscriberIdentity.promise).to.deep.equal({ userId: 'unsub-user' }); + }); +});