Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions examples/commons/extensible_metadata.cddl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
MetadataScalar = null / bool / int / float / text
MessageMetadata = {
? provider: text,
? model: text,
? modelType: text,
? runId: text,
? threadId: text,
? systemFingerprint: text,
? serviceTier: text,
* text => MetadataScalar,
}
70 changes: 58 additions & 12 deletions packages/cddl2py/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ interface ResolveTypeOptions {
quoteForwardReferences?: boolean
}

const STRING_RECORD_KEY_TYPES = new Set(['str', 'text', 'tstr'])

export function transform (assignments: Assignment[], options?: TransformOptions): string {
const ctx: Context = {
pydantic: options?.pydantic ?? false,
Expand Down Expand Up @@ -135,6 +137,7 @@ function generateGroup (group: Group, ctx: Context): string {
}

const props = properties as Property[]
const extraItemsType = getExtraItemsType(props, ctx)

if (props.length === 1) {
const prop = props[0]
Expand All @@ -150,7 +153,7 @@ function generateGroup (group: Group, ctx: Context): string {
}

const mixins = props.filter(isUnNamedProperty)
const ownProps = props.filter(p => !isUnNamedProperty(p))
const ownProps = props.filter(p => !isUnNamedProperty(p) && !isExtensibleRecordProperty(p))

const simpleMixinBases: string[] = []
const unionMixinGroups: string[][] = []
Expand Down Expand Up @@ -186,10 +189,10 @@ function generateGroup (group: Group, ctx: Context): string {
}

if (unionMixinGroups.length > 0) {
return comments + generateGroupWithUnionMixins(name, simpleMixinBases, unionMixinGroups, ownProps, ctx)
return comments + generateGroupWithUnionMixins(name, simpleMixinBases, unionMixinGroups, ownProps, extraItemsType, ctx)
}

return comments + generateClass(name, simpleMixinBases, ownProps, ctx)
return comments + generateClass(name, simpleMixinBases, ownProps, ctx, extraItemsType)
}

function generateGroupWithChoices (name: string, properties: (Property | Property[])[], ctx: Context): string {
Expand Down Expand Up @@ -259,6 +262,7 @@ function generateGroupWithUnionMixins (
simpleBases: string[],
unionGroups: string[][],
ownProps: Property[],
extraItemsType: string | undefined,
ctx: Context
): string {
if (ownProps.length === 0 && simpleBases.length === 0) {
Expand All @@ -278,7 +282,7 @@ function generateGroupWithUnionMixins (

if (ownProps.length > 0) {
const baseName = `_${name}Fields`
blocks.push(generateClass(baseName, [], ownProps, ctx))
blocks.push(generateClass(baseName, [], ownProps, ctx, extraItemsType))

for (let i = 0; i < unionTypes.length; i++) {
const variantName = `_${name}Variant${i}`
Expand All @@ -291,7 +295,7 @@ function generateGroupWithUnionMixins (
const variantName = `_${name}Variant${i}`
variantNames.push(variantName)
const bases = [unionTypes[i], ...simpleBases]
blocks.push(generateClass(variantName, bases, [], ctx))
blocks.push(generateClass(variantName, bases, [], ctx, extraItemsType))
}
}
} else {
Expand Down Expand Up @@ -366,7 +370,13 @@ function generateArrayAssignment (arr: CDDLArray, ctx: Context): string {
// Class generation (TypedDict or Pydantic BaseModel)
// ---------------------------------------------------------------------------

function generateClass (name: string, bases: string[], props: Property[], ctx: Context): string {
function generateClass (
name: string,
bases: string[],
props: Property[],
ctx: Context,
extraItemsType?: string
): string {
const lines: string[] = []

let classDecl: string
Expand All @@ -381,17 +391,25 @@ function generateClass (name: string, bases: string[], props: Property[], ctx: C
} else {
ctx.typingExtensionsImports.add('TypedDict')
const typedDictBases = bases.filter((base) => isModelCompatibleBase(base, ctx))
if (typedDictBases.length > 0) {
classDecl = `class ${name}(${typedDictBases.join(', ')}):`
} else {
classDecl = `class ${name}(TypedDict):`
}
const baseList = typedDictBases.length > 0 ? typedDictBases.join(', ') : 'TypedDict'
classDecl = extraItemsType
? `class ${name}(${baseList}, extra_items=${extraItemsType}):`
: `class ${name}(${baseList}):`
}

lines.push(classDecl)

if (ctx.pydantic && extraItemsType) {
ctx.pydanticImports.add('ConfigDict')
ctx.pydanticImports.add('Field')
lines.push(` __pydantic_extra__: dict[str, ${extraItemsType}] = Field(init=False)`)
lines.push(` model_config = ConfigDict(extra='allow')`)
}

if (props.length === 0) {
lines.push(' pass')
if (lines.length === 1) {
lines.push(' pass')
}
return lines.join('\n')
}

Expand Down Expand Up @@ -470,6 +488,34 @@ function generateField (prop: Property, ctx: Context): string | null {
return ` ${propName}: ${typeStr}${commentSuffix}`
}

function isExtensibleRecordProperty (prop: Property): boolean {
return !isUnNamedProperty(prop) &&
prop.Occurrence.m === Infinity &&
!prop.HasCut &&
STRING_RECORD_KEY_TYPES.has(prop.Name)
}

function getExtraItemsType (props: Property[], ctx: Context): string | undefined {
const types = props
.filter(isExtensibleRecordProperty)
.flatMap((prop) => {
const cddlTypes = Array.isArray(prop.Type) ? prop.Type : [prop.Type]
return cddlTypes.map((type) => resolveType(type, ctx))
})

if (types.length === 0) {
return
}

const uniqueTypes = [...new Set(types)]
if (uniqueTypes.length === 1) {
return uniqueTypes[0]
}

ctx.typingImports.add('Union')
return `Union[${uniqueTypes.join(', ')}]`
}

// ---------------------------------------------------------------------------
// Type resolution
// ---------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html

exports[`extensible metadata > should render extensible metadata as a TypedDict with extra_items 1`] = `
"# compiled with https://www.npmjs.com/package/cddl2py

from __future__ import annotations

from typing import Union
from typing_extensions import NotRequired, TypedDict

MetadataScalar = Union[None, bool, int, float, str]

class MessageMetadata(TypedDict, extra_items=MetadataScalar):
provider: NotRequired[str]
model: NotRequired[str]
model_type: NotRequired[str]
run_id: NotRequired[str]
thread_id: NotRequired[str]
system_fingerprint: NotRequired[str]
service_tier: NotRequired[str]
"
`;
71 changes: 71 additions & 0 deletions packages/cddl2py/tests/extensible_metadata.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import url from 'node:url'
import path from 'node:path'
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'

import cli from '../src/cli.js'
import { normalizeSnapshotOutput } from './snapshot.js'

const __dirname = url.fileURLToPath(new URL('.', import.meta.url))
const cddlFile = path.join(__dirname, '..', '..', '..', 'examples', 'commons', 'extensible_metadata.cddl')

vi.mock('../src/constants', () => ({
pkg: {
name: 'cddl2py',
version: '0.1.0',
author: 'Test Author',
description: 'Generate Python types from CDDL'
},
NATIVE_TYPE_MAP: {
any: 'Any',
number: 'Union[int, float]',
int: 'int',
uint: 'int',
nint: 'int',
float: 'float',
float16: 'float',
float32: 'float',
float64: 'float',
bool: 'bool',
bstr: 'bytes',
bytes: 'bytes',
tstr: 'str',
text: 'str',
str: 'str',
nil: 'None',
null: 'None',
}
}))

describe('extensible metadata', () => {
let exitOrig = process.exit
let logOrig = console.log
let errorOrig = console.error

beforeEach(() => {
process.exit = vi.fn() as any
console.log = vi.fn()
console.error = vi.fn()
})

afterEach(() => {
process.exit = exitOrig
console.log = logOrig
console.error = errorOrig
})

it('should render extensible metadata as a TypedDict with extra_items', async () => {
await cli([cddlFile])

expect(process.exit).not.toHaveBeenCalledWith(1)
expect(console.error).not.toHaveBeenCalled()

const output = vi.mocked(console.log).mock.calls.flat().join('\n')

expect(output).toContain('class MessageMetadata(TypedDict, extra_items=MetadataScalar):')
expect(output).toContain('provider: NotRequired[str]')
expect(output).toContain('model_type: NotRequired[str]')
expect(output).toContain('system_fingerprint: NotRequired[str]')
expect(output).not.toContain('text: NotRequired[MetadataScalar]')
expect(normalizeSnapshotOutput(output)).toMatchSnapshot()
})
})
42 changes: 42 additions & 0 deletions packages/cddl2py/tests/transform_edge_cases.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,48 @@ describe('transform edge cases', () => {
expect(output).toContain('maybe_enabled: Optional[bool] = Field(default=False)')
})

it('should emit extensible TypedDict properties as extra_items', () => {
const output = transform([
variable('metadata-scalar', ['null', 'bool', 'int', 'float', 'text']),
group('message-metadata', [
property('provider', 'text', {
Occurrence: { n: 0, m: 1 }
}),
property('model', 'text', {
Occurrence: { n: 0, m: 1 }
}),
property('text', groupRef('metadata-scalar'), {
Occurrence: { n: 0, m: Infinity }
})
])
])

expect(output).toContain('class MessageMetadata(TypedDict, extra_items=MetadataScalar):')
expect(output).toContain('provider: NotRequired[str]')
expect(output).toContain('model: NotRequired[str]')
expect(output).not.toContain('text: NotRequired[MetadataScalar]')
})

it('should emit typed extra fields for pydantic models', () => {
const output = transform([
variable('metadata-scalar', ['null', 'bool', 'int', 'float', 'text']),
group('message-metadata', [
property('provider', 'text', {
Occurrence: { n: 0, m: 1 }
}),
property('text', groupRef('metadata-scalar'), {
Occurrence: { n: 0, m: Infinity }
})
])
], { pydantic: true })

expect(output).toContain('from pydantic import BaseModel, ConfigDict, Field')
expect(output).toContain('class MessageMetadata(BaseModel):')
expect(output).toContain('__pydantic_extra__: dict[str, MetadataScalar] = Field(init=False)')
expect(output).toContain(`model_config = ConfigDict(extra='allow')`)
expect(output).not.toContain('text: MetadataScalar')
})

it('should cover direct type-resolution edge cases', () => {
const output = transform([
variable('direct-range', {
Expand Down
Loading