Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions packages/core/src/com/caller-context.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
export interface CallerContext {
getStore(): unknown;
run<R>(store: unknown, callback: () => R): R;
}

let activeCallerContext: CallerContext | undefined;
export function setActiveCallerContext(ctx: CallerContext | undefined): void {
activeCallerContext = ctx;
}

export function getCurrentCaller(): unknown {
return activeCallerContext?.getStore();
}

export function runWithCaller<R>(identity: unknown, fn: () => R): R {
if (identity !== undefined && activeCallerContext) {
return activeCallerContext.run(identity, fn);
}
return fn();
}
17 changes: 14 additions & 3 deletions packages/core/src/com/communication.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ import {
ReconnectFunction,
RemoteValueListener,
} from '../remote-value.js';
import { getCurrentCaller, runWithCaller } from './caller-context.js';

export interface ConfigEnvironmentRecord extends EnvironmentRecord {
registerMessageHandler?: boolean;
Expand Down Expand Up @@ -373,6 +374,7 @@ export class Communication {
data: { api, method, args: serializeApiCallArguments(args) },
callbackId,
origin,
callerIdentity: getCurrentCaller(),
};
this.callWithCallback(envId, message, callbackId, res, rej);
}
Expand Down Expand Up @@ -698,6 +700,7 @@ export class Communication {
type: 'callback',
forwardingChain: message.forwardingChain,
error: new CircularForwardingError(message, this.rootEnvId, env.id),
callerIdentity: message.callerIdentity,
});
return;
} else if (this.DEBUG) {
Expand Down Expand Up @@ -771,6 +774,7 @@ export class Communication {
},
handlerId: listenerHandlerId,
origin,
callerIdentity: getCurrentCaller(),
};
// sometimes the callback will never happen since target environment is already dead
this.sendTo(envId, message);
Expand Down Expand Up @@ -802,6 +806,7 @@ export class Communication {
handlerId: '',
callbackId,
origin,
callerIdentity: getCurrentCaller(),
};
message.handlerId = this.createHandlerRecord(envId, api, method, fn, message);

Expand Down Expand Up @@ -939,7 +944,9 @@ export class Communication {
const dispatcher = this.eventDispatchers.get(namespacedHandlerId)?.dispatcher;
if (dispatcher) {
this.eventDispatchers.delete(namespacedHandlerId);
const data = await this.apiCall(message.origin, message.data.api, message.data.method, [dispatcher]);
const data = await runWithCaller(message.callerIdentity, () =>
this.apiCall(message.origin, message.data.api, message.data.method, [dispatcher]),
);
if (message.callbackId) {
this.sendTo(message.from, {
to: message.from,
Expand All @@ -957,7 +964,9 @@ export class Communication {
try {
const dispatcher = this.getDispatcher(message.from, message);

const data = await this.apiCall(message.origin, message.data.api, message.data.method, [dispatcher]);
const data = await runWithCaller(message.callerIdentity, () =>
this.apiCall(message.origin, message.data.api, message.data.method, [dispatcher]),
);

if (message.callbackId) {
this.sendTo(message.from, {
Expand All @@ -984,7 +993,9 @@ export class Communication {
private async handleCall(message: CallMessage): Promise<void> {
try {
const args = deserializeApiCallArguments(message.data.args);
const data = await this.apiCall(message.origin, message.data.api, message.data.method, args);
const data = await runWithCaller(message.callerIdentity, () =>
this.apiCall(message.origin, message.data.api, message.data.method, args),
);

if (message.callbackId) {
this.sendTo(message.from, {
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/com/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ export * from './hosts/ws-client-host.js';
export * from './initializers/index.js';
export * from './hosts/index.js';
export * from './communication-errors.js';
export * from './caller-context.js';
1 change: 1 addition & 0 deletions packages/core/src/com/message-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ export interface BaseMessage {
error?: Error;
origin: string;
forwardingChain?: string[];
callerIdentity?: unknown;
}

export interface CallMessage extends BaseMessage {
Expand Down
9 changes: 8 additions & 1 deletion packages/core/src/communication.feature.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { BaseHost } from './com/hosts/base-host.js';
import { Communication, type ConfigEnvironmentRecord, type CommunicationOptions } from './com/communication.js';
import { setActiveCallerContext, type CallerContext } from './com/caller-context.js';
import { LoggerService } from './com/logger-service.js';
import type { Target } from './com/types.js';
import { Config } from './entities/config.js';
Expand All @@ -22,6 +23,7 @@ export interface IComConfig {
connectedEnvironments?: {
[environmentId: string]: ConfigEnvironmentRecord;
};
callerContext?: CallerContext;
}
export default class COM extends Feature<'COM'> {
id = 'COM' as const;
Expand Down Expand Up @@ -70,6 +72,7 @@ COM.setup(
resolvedContexts,
publicPath,
connectedEnvironments = {},
callerContext,
},
loggerTransports,
[RUN_OPTIONS]: runOptions,
Expand Down Expand Up @@ -102,6 +105,7 @@ COM.setup(
isNode,
comOptions,
);
setActiveCallerContext(callerContext);
// manually register window initialization api service to be used during
// start of managed iframe in packages/core/src/com/initializers/iframe.ts
communication.registerAPI({ id: WindowInitializerService.apiId }, new WindowInitializerService());
Expand All @@ -110,7 +114,10 @@ COM.setup(
{ environment: communication.getEnvironmentId() },
{ severity: loggerSeverity, maxLogMessages, logToConsole },
);
onDispose(() => communication.dispose());
onDispose(() => {
setActiveCallerContext(undefined);
return communication.dispose();
});
return {
loggerService,
communication,
Expand Down
7 changes: 6 additions & 1 deletion packages/runtime-node/src/node-env-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { IDisposable, SetMultiMap } from '@dazl/patterns';
import { fileURLToPath } from 'node:url';
import { parseArgs } from 'node:util';
import { extname } from 'node:path';
import { ConnectionHandlers, WsServerHost } from './ws-node-host.js';
import { ConnectionHandlers, IdentityExtractor, WsServerHost } from './ws-node-host.js';
import { ILaunchHttpServerOptions, launchEngineHttpServer } from './launch-http-server.js';
import { workerThreadInitializer2 } from './worker-thread-initializer2.js';
import { bindMetricsListener, type PerformanceMetrics } from './metrics-utils.js';
Expand Down Expand Up @@ -52,9 +52,11 @@ export class NodeEnvManager implements IDisposable {
runtimeOptions: Map<string, string | boolean | undefined>,
{
connectionHandlers,
identityExtractor,
...serverOptions
}: ILaunchHttpServerOptions & {
connectionHandlers?: ConnectionHandlers;
identityExtractor?: IdentityExtractor;
} = {},
lazy = false,
) {
Expand All @@ -67,6 +69,9 @@ export class NodeEnvManager implements IDisposable {
runtimeOptions.set('enginePort', port.toString());

const clientsHost = new WsServerHost(socketServer);
if (identityExtractor) {
clientsHost.setIdentityExtractor(identityExtractor);
}
const disposeOnConnectionOpen = connectionHandlers?.onConnectionOpen
? clientsHost.registerConnectionHandler(connectionHandlers.onConnectionOpen)
: undefined;
Expand Down
19 changes: 19 additions & 0 deletions packages/runtime-node/src/ws-node-host.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ export interface ConnectionHandlers {
onConnectionReconnect?: IConnectionHandler;
}

/**
* Extracts caller identity from Socket.IO handshake.
* Returned value must be serializable (JSON-safe) since it is added to messages.
*/
export type IdentityExtractor<T = unknown> = (handshake: io.Socket['handshake']) => T;

export class WsHost extends BaseHost {
constructor(private socket: io.Socket) {
super();
Expand All @@ -32,6 +38,8 @@ export class WsServerHost extends BaseHost implements IDisposable {
private reconnectionHandlers = new Set<IConnectionHandler>();
private socketToEnvId = new Map<string, { socket: io.Socket; clientID: string }>();
private clientIdToSocket = new Map<string, io.Socket>();
private identityStore = new Map<string, unknown>();
private identityExtractor?: IdentityExtractor;
private disposables = new SafeDisposable(WsServerHost.name);
dispose = this.disposables.dispose;
isDisposed = this.disposables.isDisposed;
Expand All @@ -46,6 +54,10 @@ export class WsServerHost extends BaseHost implements IDisposable {
this.disposables.add('clear reconnection handlers', () => this.reconnectionHandlers.clear());
}

public setIdentityExtractor(extractor: IdentityExtractor) {
this.identityExtractor = extractor;
}

public registerConnectionHandler(handler: IConnectionHandler) {
this.connectionHandlers.add(handler);
return () => {
Expand Down Expand Up @@ -98,6 +110,11 @@ export class WsServerHost extends BaseHost implements IDisposable {
const existingSocket = this.clientIdToSocket.get(clientId);

this.clientIdToSocket.set(clientId, socket);
if (this.identityExtractor) {
this.callWithErrorHandling(() => {
this.identityStore.set(clientId, this.identityExtractor!(socket.handshake));
});
}

const connectionEvent: IConnectionEvent = {
clientId,
Expand Down Expand Up @@ -128,6 +145,7 @@ export class WsServerHost extends BaseHost implements IDisposable {
// modify message to be able to forward it
message.from = fromId;
message.origin = originId;
message.callerIdentity = this.identityStore.get(clientId); // 'callerIdentity' must be always set on server side to avoid spoofing from client

this.emitMessageHandlers(message);
};
Expand All @@ -149,6 +167,7 @@ export class WsServerHost extends BaseHost implements IDisposable {
}
if (this.clientIdToSocket.get(clientId) === socket) {
this.clientIdToSocket.delete(clientId);
this.identityStore.delete(clientId);
for (const handler of this.disconnectionHandlers) {
this.callWithErrorHandling(() => handler(connectionEvent));
}
Expand Down
73 changes: 73 additions & 0 deletions packages/runtime-node/test-kit/entrypoints/c.node.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import { bindMetricsListener, bindRpcListener, ParentPortHost } from '@dazl/engine-runtime-node';
import { workerData } from 'node:worker_threads';
import { AsyncLocalStorage } from 'node:async_hooks';
import { COM, FeatureClass, RuntimeEngine, TopLevelConfig } from '@dazl/engine-core';
import CallerIdentityFeature from '../feature/caller-identity.feature.js';
import { cEnv } from '../feature/envs.js';
import '../feature/caller-identity.c.env.js';

const options = workerData?.runtimeOptions as Map<string, string> | undefined;
const verbose = options?.get('verbose') ?? false;
const env = cEnv;

if (verbose) {
console.log(`[${env.env}: Started with options: `, options);
}

let activateValue: unknown;
export function getActivateValue() {
return activateValue;
}

export function runEnv({
Feature = CallerIdentityFeature,
topLevelConfig = [],
}: { Feature?: FeatureClass; topLevelConfig?: TopLevelConfig } = {}) {
return new RuntimeEngine(
env,
[
...(workerData
? [
COM.configure({
config: {
host: new ParentPortHost(env.env),
id: env.env,
callerContext: new AsyncLocalStorage(),
},
}),
]
: []),
...topLevelConfig,
],
new Map(options?.entries() ?? []),
).run(Feature);
}

if (workerData) {
const unbindMetricsListener = bindMetricsListener();
let running: ReturnType<typeof runEnv>;
const unbindActivateListener = bindRpcListener('activate', (value: unknown) => {
activateValue = value;
unbindActivateListener();
running = runEnv();
});
const unbindTerminationListener = bindRpcListener('terminate', async () => {
if (verbose) {
console.log(`[${env.env}]: Termination Requested. Waiting for engine.`);
}
unbindTerminationListener();
unbindMetricsListener();
try {
const engine = await running;
if (verbose) {
console.log(`[${env.env}]: Terminating`);
}
return engine.shutdown();
} catch (e) {
console.error('[${env.name}]: Error while shutting down', e);
return;
}
});
} else {
console.log('running engine in test mode');
}
11 changes: 11 additions & 0 deletions packages/runtime-node/test-kit/feature/caller-identity.c.env.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { getCurrentCaller } from '@dazl/engine-core';
import { cEnv } from './envs.js';
import CallerIdentityFeature from './caller-identity.feature.js';

CallerIdentityFeature.setup(cEnv, () => {
return {
identityService: {
whoAmI: () => getCurrentCaller(),
},
};
});
11 changes: 11 additions & 0 deletions packages/runtime-node/test-kit/feature/caller-identity.feature.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { COM, Feature, Service } from '@dazl/engine-core';
import { cEnv } from './envs.js';
import { IdentityService } from './types.js';

export default class CallerIdentityFeature extends Feature<'caller-identity'> {
id = 'caller-identity' as const;
dependencies = [COM];
api = {
identityService: Service.withType<IdentityService>().defineEntity(cEnv).allowRemoteAccess(),
};
}
1 change: 1 addition & 0 deletions packages/runtime-node/test-kit/feature/envs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ import { Environment } from '@dazl/engine-core';

export const aEnv = new Environment('a', 'node', 'single');
export const bEnv = new Environment('b', 'node', 'single');
export const cEnv = new Environment('c', 'node', 'single');
4 changes: 4 additions & 0 deletions packages/runtime-node/test-kit/feature/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@ export type EchoService = {
echoChained: () => Promise<string>;
getActivateValue: () => unknown;
};

export type IdentityService = {
whoAmI: () => unknown;
};
Loading