From 7b453f7745cad73fc81e7884faf473aecda99556 Mon Sep 17 00:00:00 2001 From: Yiming Date: Wed, 14 Feb 2024 21:24:49 +0800 Subject: [PATCH] fix: improve generated typing for polymorphic models (#1002) --- package.json | 5 +- .../src/enhancements/create-enhancement.ts | 56 ++++++++++--------- .../runtime/src/enhancements/default-auth.ts | 6 +- packages/runtime/src/enhancements/delegate.ts | 6 +- packages/runtime/src/enhancements/omit.ts | 8 +-- packages/runtime/src/enhancements/password.ts | 11 ++-- .../src/enhancements/policy/handler.ts | 6 +- .../runtime/src/enhancements/policy/index.ts | 4 +- .../src/enhancements/policy/policy-utils.ts | 4 +- packages/runtime/src/enhancements/proxy.ts | 13 +++-- .../runtime/src/enhancements/query-utils.ts | 4 +- .../src/plugins/enhancer/enhance/index.ts | 39 ++++++++++++- .../src/plugins/prisma/schema-generator.ts | 9 ++- packages/sdk/src/utils.ts | 10 ++++ .../integration/tests/misc/stacktrace.test.ts | 2 +- 15 files changed, 123 insertions(+), 60 deletions(-) diff --git a/package.json b/package.json index 83d9f7072..fd0c2763d 100644 --- a/package.json +++ b/package.json @@ -9,7 +9,10 @@ "test-ci": "ZENSTACK_TEST=1 pnpm -r run test --silent --forceExit", "publish-all": "pnpm --filter \"./packages/**\" -r publish --access public", "publish-preview": "pnpm --filter \"./packages/**\" -r publish --force --registry https://preview.registry.zenstack.dev/", - "unpublish-preview": "pnpm --recursive --shell-mode exec -- npm unpublish -f --registry https://preview.registry.zenstack.dev/ \"\\$PNPM_PACKAGE_NAME\"" + "unpublish-preview": "pnpm --recursive --shell-mode exec -- npm unpublish -f --registry https://preview.registry.zenstack.dev/ \"\\$PNPM_PACKAGE_NAME\"", + "publish-next": "pnpm --filter \"./packages/**\" -r publish --access public --tag next", + "publish-preview-next": "pnpm --filter \"./packages/**\" -r publish --force --registry https://preview.registry.zenstack.dev/ --tag next", + "unpublish-preview-next": "pnpm --recursive --shell-mode exec -- npm unpublish -f --registry https://preview.registry.zenstack.dev/ --tag next \"\\$PNPM_PACKAGE_NAME\"" }, "keywords": [], "author": "", diff --git a/packages/runtime/src/enhancements/create-enhancement.ts b/packages/runtime/src/enhancements/create-enhancement.ts index dbca40874..1b9796970 100644 --- a/packages/runtime/src/enhancements/create-enhancement.ts +++ b/packages/runtime/src/enhancements/create-enhancement.ts @@ -32,60 +32,64 @@ export type TransactionIsolationLevel = | 'Snapshot' | 'Serializable'; -/** - * Options for {@link createEnhancement} - */ export type EnhancementOptions = { /** - * Policy definition + * The kinds of enhancements to apply. By default all enhancements are applied. */ - policy: PolicyDef; + kinds?: EnhancementKind[]; /** - * Model metadata + * Whether to log Prisma query */ - modelMeta: ModelMeta; + logPrismaQuery?: boolean; /** - * Zod schemas for validation + * Hook for transforming errors before they are thrown to the caller. */ - zodSchemas?: ZodSchemas; + errorTransformer?: ErrorTransformer; /** - * Whether to log Prisma query + * The `maxWait` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack. */ - logPrismaQuery?: boolean; + transactionMaxWait?: number; /** - * The Node module that contains PrismaClient + * The `timeout` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack. */ - // eslint-disable-next-line @typescript-eslint/no-explicit-any - prismaModule: any; + transactionTimeout?: number; /** - * The kinds of enhancements to apply. By default all enhancements are applied. + * The `isolationLevel` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack. */ - kinds?: EnhancementKind[]; + transactionIsolationLevel?: TransactionIsolationLevel; +}; +/** + * Options for {@link createEnhancement} + * + * @private + */ +export type InternalEnhancementOptions = EnhancementOptions & { /** - * Hook for transforming errors before they are thrown to the caller. + * Policy definition */ - errorTransformer?: ErrorTransformer; + policy: PolicyDef; /** - * The `maxWait` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack. + * Model metadata */ - transactionMaxWait?: number; + modelMeta: ModelMeta; /** - * The `timeout` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack. + * Zod schemas for validation */ - transactionTimeout?: number; + zodSchemas?: ZodSchemas; /** - * The `isolationLevel` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack. + * The Node module that contains PrismaClient */ - transactionIsolationLevel?: TransactionIsolationLevel; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + prismaModule: any; }; /** @@ -103,13 +107,15 @@ let hasDefaultAuth: boolean | undefined = undefined; * Gets a Prisma client enhanced with all enhancement behaviors, including access * policy, field validation, field omission and password hashing. * + * @private + * * @param prisma The Prisma client to enhance. * @param context Context. * @param options Options. */ export function createEnhancement( prisma: DbClient, - options: EnhancementOptions, + options: InternalEnhancementOptions, context?: EnhancementContext ) { if (!prisma) { diff --git a/packages/runtime/src/enhancements/default-auth.ts b/packages/runtime/src/enhancements/default-auth.ts index 9e0a64a4f..cce9af782 100644 --- a/packages/runtime/src/enhancements/default-auth.ts +++ b/packages/runtime/src/enhancements/default-auth.ts @@ -4,7 +4,7 @@ import deepcopy from 'deepcopy'; import { FieldInfo, NestedWriteVisitor, PrismaWriteActionType, enumerate, getFields } from '../cross'; import { DbClientContract } from '../types'; -import { EnhancementContext, EnhancementOptions } from './create-enhancement'; +import { EnhancementContext, InternalEnhancementOptions } from './create-enhancement'; import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy'; /** @@ -14,7 +14,7 @@ import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './prox */ export function withDefaultAuth( prisma: DbClient, - options: EnhancementOptions, + options: InternalEnhancementOptions, context?: EnhancementContext ): DbClient { return makeProxy( @@ -31,7 +31,7 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { constructor( prisma: DbClientContract, model: string, - options: EnhancementOptions, + options: InternalEnhancementOptions, private readonly context?: EnhancementContext ) { super(prisma, model, options); diff --git a/packages/runtime/src/enhancements/delegate.ts b/packages/runtime/src/enhancements/delegate.ts index 0a1e39d8c..7032a965a 100644 --- a/packages/runtime/src/enhancements/delegate.ts +++ b/packages/runtime/src/enhancements/delegate.ts @@ -15,13 +15,13 @@ import { resolveField, } from '../cross'; import type { CrudContract, DbClientContract } from '../types'; -import type { EnhancementOptions } from './create-enhancement'; +import type { InternalEnhancementOptions } from './create-enhancement'; import { Logger } from './logger'; import { DefaultPrismaProxyHandler, makeProxy } from './proxy'; import { QueryUtils } from './query-utils'; import { formatObject, prismaClientValidationError } from './utils'; -export function withDelegate(prisma: DbClient, options: EnhancementOptions): DbClient { +export function withDelegate(prisma: DbClient, options: InternalEnhancementOptions): DbClient { return makeProxy( prisma, options.modelMeta, @@ -34,7 +34,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { private readonly logger: Logger; private readonly queryUtils: QueryUtils; - constructor(prisma: DbClientContract, model: string, options: EnhancementOptions) { + constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) { super(prisma, model, options); this.logger = new Logger(prisma); this.queryUtils = new QueryUtils(prisma, this.options); diff --git a/packages/runtime/src/enhancements/omit.ts b/packages/runtime/src/enhancements/omit.ts index e51f9cb47..fa834166d 100644 --- a/packages/runtime/src/enhancements/omit.ts +++ b/packages/runtime/src/enhancements/omit.ts @@ -1,9 +1,9 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { enumerate, getModelFields, resolveField, type ModelMeta } from '../cross'; +import { enumerate, getModelFields, resolveField } from '../cross'; import { DbClientContract } from '../types'; -import { EnhancementOptions } from './create-enhancement'; +import { InternalEnhancementOptions } from './create-enhancement'; import { DefaultPrismaProxyHandler, makeProxy } from './proxy'; /** @@ -11,7 +11,7 @@ import { DefaultPrismaProxyHandler, makeProxy } from './proxy'; * * @private */ -export function withOmit(prisma: DbClient, options: EnhancementOptions): DbClient { +export function withOmit(prisma: DbClient, options: InternalEnhancementOptions): DbClient { return makeProxy( prisma, options.modelMeta, @@ -21,7 +21,7 @@ export function withOmit(prisma: DbClient, options: Enh } class OmitHandler extends DefaultPrismaProxyHandler { - constructor(prisma: DbClientContract, model: string, options: EnhancementOptions) { + constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) { super(prisma, model, options); } diff --git a/packages/runtime/src/enhancements/password.ts b/packages/runtime/src/enhancements/password.ts index 7fef04dd8..f83939792 100644 --- a/packages/runtime/src/enhancements/password.ts +++ b/packages/runtime/src/enhancements/password.ts @@ -3,9 +3,9 @@ import { hash } from 'bcryptjs'; import { DEFAULT_PASSWORD_SALT_LENGTH } from '../constants'; -import { NestedWriteVisitor, type ModelMeta, type PrismaWriteActionType } from '../cross'; +import { NestedWriteVisitor, type PrismaWriteActionType } from '../cross'; import { DbClientContract } from '../types'; -import { EnhancementOptions } from './create-enhancement'; +import { InternalEnhancementOptions } from './create-enhancement'; import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy'; /** @@ -13,7 +13,10 @@ import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './prox * * @private */ -export function withPassword(prisma: DbClient, options: EnhancementOptions): DbClient { +export function withPassword( + prisma: DbClient, + options: InternalEnhancementOptions +): DbClient { return makeProxy( prisma, options.modelMeta, @@ -23,7 +26,7 @@ export function withPassword(prisma: DbClient, op } class PasswordHandler extends DefaultPrismaProxyHandler { - constructor(prisma: DbClientContract, model: string, options: EnhancementOptions) { + constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) { super(prisma, model, options); } diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 6ae173fcd..383ee356f 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -16,8 +16,8 @@ import { type FieldInfo, type ModelMeta, } from '../../cross'; -import { type CrudContract, type DbClientContract, PolicyOperationKind } from '../../types'; -import type { EnhancementContext, EnhancementOptions } from '../create-enhancement'; +import { PolicyOperationKind, type CrudContract, type DbClientContract } from '../../types'; +import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; import { Logger } from '../logger'; import { PrismaProxyHandler } from '../proxy'; import { QueryUtils } from '../query-utils'; @@ -49,7 +49,7 @@ export class PolicyProxyHandler implements Pr constructor( private readonly prisma: DbClient, model: string, - private readonly options: EnhancementOptions, + private readonly options: InternalEnhancementOptions, private readonly context?: EnhancementContext ) { this.logger = new Logger(prisma); diff --git a/packages/runtime/src/enhancements/policy/index.ts b/packages/runtime/src/enhancements/policy/index.ts index 4d1e8b89d..e197e18c1 100644 --- a/packages/runtime/src/enhancements/policy/index.ts +++ b/packages/runtime/src/enhancements/policy/index.ts @@ -3,7 +3,7 @@ import { getIdFields } from '../../cross'; import { DbClientContract } from '../../types'; import { hasAllFields } from '../../validation'; -import type { EnhancementContext, EnhancementOptions } from '../create-enhancement'; +import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; import { makeProxy } from '../proxy'; import { PolicyProxyHandler } from './handler'; @@ -19,7 +19,7 @@ import { PolicyProxyHandler } from './handler'; */ export function withPolicy( prisma: DbClient, - options: EnhancementOptions, + options: InternalEnhancementOptions, context?: EnhancementContext ): DbClient { const { modelMeta, policy } = options; diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 1f4629359..00c6a51b6 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -19,7 +19,7 @@ import { import { enumerate, getFields, getModelFields, resolveField, zip, type FieldInfo, type ModelMeta } from '../../cross'; import { AuthUser, CrudContract, DbClientContract, PolicyOperationKind } from '../../types'; import { getVersion } from '../../version'; -import type { EnhancementContext, EnhancementOptions } from '../create-enhancement'; +import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; import { Logger } from '../logger'; import { QueryUtils } from '../query-utils'; import type { InputCheckFunc, PolicyDef, ReadFieldCheckFunc, ZodSchemas } from '../types'; @@ -38,7 +38,7 @@ export class PolicyUtil extends QueryUtils { constructor( private readonly db: DbClientContract, - options: EnhancementOptions, + options: InternalEnhancementOptions, context?: EnhancementContext, private readonly shouldLogQuery = false ) { diff --git a/packages/runtime/src/enhancements/proxy.ts b/packages/runtime/src/enhancements/proxy.ts index e0302f7e9..a3141ad0a 100644 --- a/packages/runtime/src/enhancements/proxy.ts +++ b/packages/runtime/src/enhancements/proxy.ts @@ -3,7 +3,7 @@ import { PRISMA_PROXY_ENHANCER } from '../constants'; import type { ModelMeta } from '../cross'; import type { DbClientContract } from '../types'; -import { EnhancementOptions } from './create-enhancement'; +import { InternalEnhancementOptions } from './create-enhancement'; import { createDeferredPromise } from './policy/promise'; /** @@ -67,7 +67,7 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler { constructor( protected readonly prisma: DbClientContract, protected readonly model: string, - protected readonly options: EnhancementOptions + protected readonly options: InternalEnhancementOptions ) {} async findUnique(args: any): Promise { @@ -241,7 +241,7 @@ export function makeProxy( return propVal; } - return createHandlerProxy(makeHandler(target, prop), propVal, errorTransformer); + return createHandlerProxy(makeHandler(target, prop), propVal, prop, errorTransformer); }, }); @@ -252,6 +252,7 @@ export function makeProxy( function createHandlerProxy( handler: T, origTarget: any, + model: string, errorTransformer?: ErrorTransformer ): T { return new Proxy(handler, { @@ -282,7 +283,7 @@ function createHandlerProxy( if (capture.stack && err instanceof Error) { // save the original stack and replace it with a clean one (err as any).internalStack = err.stack; - err.stack = cleanCallStack(capture.stack, propKey.toString(), err.message); + err.stack = cleanCallStack(capture.stack, model, propKey.toString(), err.message); } if (errorTransformer) { @@ -308,9 +309,9 @@ function createHandlerProxy( } // Filter out @zenstackhq/runtime stack (generated by proxy) from stack trace -function cleanCallStack(stack: string, method: string, message: string) { +function cleanCallStack(stack: string, model: string, method: string, message: string) { // message line - let resultStack = `Error calling enhanced Prisma method \`${method}\`: ${message}`; + let resultStack = `Error calling enhanced Prisma method \`${model}.${method}\`: ${message}`; const lines = stack.split('\n'); let foundMarker = false; diff --git a/packages/runtime/src/enhancements/query-utils.ts b/packages/runtime/src/enhancements/query-utils.ts index f92353081..6959b922f 100644 --- a/packages/runtime/src/enhancements/query-utils.ts +++ b/packages/runtime/src/enhancements/query-utils.ts @@ -9,11 +9,11 @@ import { } from '../cross'; import { CrudContract, DbClientContract } from '../types'; import { getVersion } from '../version'; -import { EnhancementOptions } from './create-enhancement'; +import { InternalEnhancementOptions } from './create-enhancement'; import { prismaClientUnknownRequestError, prismaClientValidationError } from './utils'; export class QueryUtils { - constructor(private readonly prisma: DbClientContract, private readonly options: EnhancementOptions) {} + constructor(private readonly prisma: DbClientContract, private readonly options: InternalEnhancementOptions) {} getIdFields(model: string) { return getIdFields(this.options.modelMeta, model, true); diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index c33de08b0..1d42b5912 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -95,7 +95,8 @@ async function processClientTypes(model: Model, prismaClientDir: string) { removeAuxRelationFields(desc, toRemove, traversal); fixDelegateUnionType(desc, delegateModels, toReplaceText, traversal); removeCreateFromDelegateInputTypes(desc, delegateModels, toRemove, traversal); - removeToplevelCreates(desc, delegateModels, toRemove, traversal); + removeDelegateToplevelCreates(desc, delegateModels, toRemove, traversal); + removeDiscriminatorFromConcreteInputTypes(desc, delegateModels, toRemove); }); toRemove.forEach((n) => n.remove()); @@ -134,7 +135,6 @@ function fixDelegateUnionType( delegateModels.forEach(([delegate, concreteModels]) => { if (name === `$${delegate.name}Payload`) { const discriminator = getDiscriminatorField(delegate); - // const discriminator = 'delegateType'; // delegate.fields.find((f) => hasAttribute(f, '@discriminator')); if (discriminator) { toReplaceText.push([ desc, @@ -178,7 +178,40 @@ function removeCreateFromDelegateInputTypes( }); } -function removeToplevelCreates( +function removeDiscriminatorFromConcreteInputTypes( + desc: Node, + delegateModels: [DataModel, DataModel[]][], + toRemove: (PropertySignature | MethodSignature)[] +) { + if (!desc.isKind(SyntaxKind.TypeAliasDeclaration)) { + return; + } + + const name = desc.getName(); + delegateModels.forEach(([delegate, concretes]) => { + const discriminator = getDiscriminatorField(delegate); + if (!discriminator) { + return; + } + + concretes.forEach((concrete) => { + // remove discriminator field from the create/update input of concrete models + const regex = new RegExp(`\\${concrete.name}(Unchecked)?(Create|Update).*Input`); + if (regex.test(name)) { + desc.forEachDescendant((d, innerTraversal) => { + if (d.isKind(SyntaxKind.PropertySignature)) { + if (d.getName() === discriminator.name) { + toRemove.push(d); + } + innerTraversal.skip(); + } + }); + } + }); + }); +} + +function removeDelegateToplevelCreates( desc: Node, delegateModels: [DataModel, DataModel[]][], toRemove: (PropertySignature | MethodSignature)[], diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 6592d83b5..72d2a02e6 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -297,7 +297,14 @@ export class PrismaSchemaGenerator { const model = decl.isView ? prisma.addView(decl.name) : prisma.addModel(decl.name); for (const field of decl.fields) { if (field.$inheritedFrom) { - if (field.$inheritedFrom.isAbstract || this.mode === 'logical' || isIdField(field)) { + if ( + // abstract inheritance is always kept + field.$inheritedFrom.isAbstract || + // logical schema keeps all inherited fields + this.mode === 'logical' || + // id fields are always kept + isIdField(field) + ) { this.generateModelField(model, field); } } else { diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index 0bd98e63e..01d5d274d 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -383,6 +383,16 @@ export function isDelegateModel(node: AstNode) { return isDataModel(node) && hasAttribute(node, '@@delegate'); } +export function isDiscriminatorField(field: DataModelField) { + const model = field.$inheritedFrom ?? field.$container; + const delegateAttr = getAttribute(model, '@@delegate'); + if (!delegateAttr) { + return false; + } + const arg = delegateAttr.args[0]?.value; + return isDataModelFieldReference(arg) && arg.target.$refText === field.name; +} + export function getIdFields(dataModel: DataModel) { const fieldLevelId = getModelFieldsWithBases(dataModel).find((f) => f.attributes.some((attr) => attr.decl.$refText === '@id') diff --git a/tests/integration/tests/misc/stacktrace.test.ts b/tests/integration/tests/misc/stacktrace.test.ts index 08454d529..f652c5514 100644 --- a/tests/integration/tests/misc/stacktrace.test.ts +++ b/tests/integration/tests/misc/stacktrace.test.ts @@ -31,7 +31,7 @@ describe('Stack trace tests', () => { } expect(error?.stack).toContain( - "Error calling enhanced Prisma method `create`: denied by policy: model entities failed 'create' check" + "Error calling enhanced Prisma method `model.create`: denied by policy: model entities failed 'create' check" ); expect(error?.stack).toContain(`misc/stacktrace.test.ts`); expect((error as any).internalStack).toBeTruthy();