Skip to content

Commit

Permalink
Codegen: proper static dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanjermakov committed Mar 10, 2024
1 parent 01aa35b commit 80f1ba4
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 123 deletions.
4 changes: 2 additions & 2 deletions src/ast/op.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { ParseNode, filterNonAstNodes } from '../parser'
import { InstanceRelation } from '../scope/trait'
import { MethodDef, VariantDef } from '../scope/vid'
import { ConcreteGeneric, VirtualGeneric, VirtualType } from '../typecheck'
import { ConcreteGeneric } from '../typecheck'
import { Arg, AstNode, AstNodeKind, buildArg } from './index'

export type PostfixOp = CallOp | UnwrapOp | BindOp
Expand Down Expand Up @@ -111,9 +111,9 @@ export const buildBinaryOp = (node: ParseNode): BinaryOp => {
export interface CallOp extends AstNode<'call-op'> {
args: Arg[]
methodDef?: MethodDef
impls?: InstanceRelation[]
variantDef?: VariantDef
generics?: ConcreteGeneric[]
impl?: InstanceRelation
}

export const buildCallOp = (node: ParseNode): CallOp => {
Expand Down
7 changes: 0 additions & 7 deletions src/ast/operand.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { LexerToken } from '../lexer/lexer'
import { ParseNode, ParseTree, filterNonAstNodes } from '../parser'
import { nameLikeTokens } from '../parser/fns'
import { InstanceRelation } from '../scope/trait'
import { VirtualIdentifierMatch } from '../scope/vid'
import { Typed } from '../semantic'
import { Expr, buildExpr } from './expr'
Expand Down Expand Up @@ -161,7 +160,6 @@ export const buildClosureExpr = (node: ParseNode): ClosureExpr => {

export interface ListExpr extends AstNode<'list-expr'>, Partial<Typed> {
exprs: Expr[]
impls?: InstanceRelation[]
}

export const buildListExpr = (node: ParseNode): ListExpr => {
Expand All @@ -172,7 +170,6 @@ export const buildListExpr = (node: ParseNode): ListExpr => {

export interface StringLiteral extends AstNode<'string-literal'>, Partial<Typed> {
value: string
impls?: InstanceRelation[]
}

export const buildStringLiteral = (node: ParseNode): StringLiteral => {
Expand All @@ -181,7 +178,6 @@ export const buildStringLiteral = (node: ParseNode): StringLiteral => {

export interface CharLiteral extends AstNode<'char-literal'>, Partial<Typed> {
value: string
impls?: InstanceRelation[]
}

export const buildCharLiteral = (node: ParseNode): CharLiteral => {
Expand All @@ -190,7 +186,6 @@ export const buildCharLiteral = (node: ParseNode): CharLiteral => {

export interface IntLiteral extends AstNode<'int-literal'>, Partial<Typed> {
value: string
impls?: InstanceRelation[]
}

export const buildIntLiteral = (node: ParseNode): IntLiteral => {
Expand All @@ -199,7 +194,6 @@ export const buildIntLiteral = (node: ParseNode): IntLiteral => {

export interface FloatLiteral extends AstNode<'float-literal'>, Partial<Typed> {
value: string
impls?: InstanceRelation[]
}

export const buildFloatLiteral = (node: ParseNode): FloatLiteral => {
Expand All @@ -208,7 +202,6 @@ export const buildFloatLiteral = (node: ParseNode): FloatLiteral => {

export interface BoolLiteral extends AstNode<'bool-literal'>, Partial<Typed> {
value: string
impls?: InstanceRelation[]
}

export const buildBoolLiteral = (node: ParseNode): BoolLiteral => {
Expand Down
119 changes: 58 additions & 61 deletions src/codegen/js/expr.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,19 @@ export const emitUnaryExpr = (unaryExpr: UnaryExpr, module: Module, ctx: Context
const args = unaryExpr.op.args.map(a => emitExpr(a.expr, module, ctx))
const genericTypes = unaryExpr.op.generics?.map(g => emitGeneric(g, module, ctx)) ?? []
const jsArgs = [...args, ...genericTypes]
const impls: string[] = []
if (unaryExpr.op.impls !== undefined) {
for (const impl of unaryExpr.op.impls) {
impls.push(`${resultVar}.${relTypeName(impl)} = ${jsRelName(impl)};`)
}
}
const variantDef = unaryExpr.op.variantDef
if (variantDef) {
const variantName = `${variantDef.typeDef.name.value}.${variantDef.variant.name.value}`
const call = jsVariable(resultVar, `${variantName}(${jsArgs.map(a => a.resultVar).join(', ')})`)
return {
emit: emitLines([...jsArgs.map(a => a.emit), call, ...impls]),
emit: emitLines([...jsArgs.map(a => a.emit), call]),
resultVar
}
} else {
const operand = emitOperand(unaryExpr.operand, module, ctx)
const call = jsVariable(resultVar, `${operand.resultVar}(${jsArgs.map(a => a.resultVar).join(', ')})`)
return {
emit: emitLines([operand.emit, ...jsArgs.map(a => a.emit), call, ...impls]),
emit: emitLines([operand.emit, ...jsArgs.map(a => a.emit), call]),
resultVar
}
}
Expand All @@ -73,53 +67,68 @@ export const emitUnaryExpr = (unaryExpr: UnaryExpr, module: Module, ctx: Context
export const emitBinaryExpr = (binaryExpr: BinaryExpr, module: Module, ctx: Context): EmitExpr => {
const lOp = emitOperand(binaryExpr.lOperand, module, ctx)
const resultVar = nextVariable(ctx)
if (binaryExpr.binaryOp.kind === 'access-op') {
if (binaryExpr.rOperand.kind === 'identifier') {
const accessor = binaryExpr.rOperand.names.at(-1)!.value
const rOp = emitOperand(binaryExpr.rOperand, module, ctx)
switch (binaryExpr.binaryOp.kind) {
case 'access-op': {
if (binaryExpr.rOperand.kind === 'identifier') {
const accessor = binaryExpr.rOperand.names.at(-1)!.value
return {
emit: emitLines([lOp.emit, jsVariable(resultVar, `${lOp.resultVar}.value.${accessor}`)]),
resultVar
}
}
if (binaryExpr.rOperand.kind === 'unary-expr' && binaryExpr.rOperand.op.kind === 'call-op') {
const call = binaryExpr.rOperand.op
const methodDef = call.methodDef!
const methodName = methodDef.fn.name.value
const args = call.args.map(a => emitExpr(a.expr, module, ctx))
const genericTypes = call.generics?.map(g => emitGeneric(g, module, ctx)) ?? []
const jsArgs = [...args, ...genericTypes]
const argsEmit = (
methodDef.fn.static
? jsArgs.map(a => a.resultVar)
: [lOp.resultVar, ...jsArgs.map(a => a.resultVar)]
).join(', ')
return {
emit: emitLines([
lOp.emit,
emitLines(jsArgs.map(a => a.emit)),
jsVariable(
resultVar,
`${
call.impl ? jsRelName(call.impl) : jsError('dynamic dispatch')
}().${methodName}(${argsEmit})`
)
]),
resultVar
}
}
return {
emit: emitLines([lOp.emit, jsVariable(resultVar, `${lOp.resultVar}.${accessor}`)]),
emit: jsError('unwrap/bind ops'),
resultVar
}
}
if (binaryExpr.rOperand.kind === 'unary-expr' && binaryExpr.rOperand.op.kind === 'call-op') {
const callOp = binaryExpr.rOperand.op
const methodDef = callOp.methodDef!
const traitName = relTypeName(methodDef.rel)
const methodName = methodDef.fn.name.value
const args = callOp.args.map(a => emitExpr(a.expr, module, ctx))
const genericTypes = callOp.generics?.map(g => emitGeneric(g, module, ctx)) ?? []
const jsArgs = [...args, ...genericTypes]
const argsEmit = (
methodDef.fn.static ? jsArgs.map(a => a.resultVar) : [lOp.resultVar, ...jsArgs.map(a => a.resultVar)]
).join(', ')
case 'assign-op': {
// TODO: assign all js fields, including $noisType and impls
return {
emit: emitLines([
lOp.emit,
emitLines(jsArgs.map(a => a.emit)),
jsVariable(resultVar, `${lOp.resultVar}.${traitName}().${methodName}(${argsEmit})`)
rOp.emit,
`${extractValue(lOp.resultVar)} = ${extractValue(rOp.resultVar)}`
]),
resultVar
}
}
return {
emit: jsError('unwrap/bind ops'),
resultVar
}
}
const rOp = emitOperand(binaryExpr.rOperand, module, ctx)
if (binaryExpr.binaryOp.kind === 'assign-op') {
return {
emit: emitLines([lOp.emit, rOp.emit, `${extractValue(lOp.resultVar)} = ${extractValue(rOp.resultVar)}`]),
resultVar
default: {
const trait = operatorImplMap.get(binaryExpr.binaryOp.kind)!.names.at(-2)!
const method = operatorImplMap.get(binaryExpr.binaryOp.kind)!.names.at(-1)!
const assign = jsVariable(resultVar, `${trait}().${method}(${lOp.resultVar}, ${rOp.resultVar})`)
return {
emit: emitLines([lOp.emit, rOp.emit, assign]),
resultVar
}
}
}
const trait = operatorImplMap.get(binaryExpr.binaryOp.kind)!.names.at(-2)!
const method = operatorImplMap.get(binaryExpr.binaryOp.kind)!.names.at(-1)!
const assign = jsVariable(resultVar, `${lOp.resultVar}.${trait}().${method}(${lOp.resultVar}, ${rOp.resultVar})`)
return {
emit: emitLines([lOp.emit, rOp.emit, assign]),
resultVar
}
}

export const emitOperand = (operand: Operand, module: Module, ctx: Context): EmitExpr => {
Expand Down Expand Up @@ -181,17 +190,10 @@ export const emitOperand = (operand: Operand, module: Module, ctx: Context): Emi
return emitExpr(operand, module, ctx)
case 'list-expr':
const items = operand.exprs.map(e => emitExpr(e, module, ctx))
const impls: string[] = []
if (operand.impls !== undefined) {
for (const impl of operand.impls) {
impls.push(`${resultVar}.${relTypeName(impl)} = ${jsRelName(impl)};`)
}
}
return {
emit: emitLines([
...items.map(i => i.emit),
jsVariable(resultVar, `List.List([${items.map(i => i.resultVar).join(', ')}])`),
...impls
jsVariable(resultVar, `List.List([${items.map(i => i.resultVar).join(', ')}])`)
]),
resultVar
}
Expand All @@ -213,11 +215,11 @@ export const emitOperand = (operand: Operand, module: Module, ctx: Context): Emi
} else {
const arg = nextVariable(ctx)
const args = nextVariable(ctx)
const relName = relTypeName(operand.ref.def.rel)
const relName = jsRelName(operand.ref.def.rel)
const fnName = operand.ref.def.fn.name.value
return {
emit: '',
resultVar: `(function(${arg}, ...${args}) { ${arg}.${relName}().${fnName}(${arg}, ${args}); })`
resultVar: `(function(${arg}, ...${args}) { return ${relName}().${fnName}(${arg}, ${args}); })`
}
}
}
Expand Down Expand Up @@ -247,13 +249,7 @@ export const emitLiteral = (operand: Operand, module: Module, ctx: Context, resu
default:
return unreachable()
}
const impls: string[] = []
if (operand.impls !== undefined) {
for (const impl of operand.impls) {
impls.push(`${resultVar}.${relTypeName(impl)} = ${jsRelName(impl)};`)
}
}
return { emit: emitLines([jsVariable(resultVar, constructorEmit), ...impls]), resultVar }
return { emit: emitLines([jsVariable(resultVar, constructorEmit)]), resultVar }
}

export const emitMatchExpr = (matchExpr: MatchExpr, module: Module, ctx: Context, resultVar: string): EmitExpr => {
Expand Down Expand Up @@ -344,9 +340,9 @@ export const emitPattern = (
const patterns = pattern.expr.fieldPatterns.flatMap(f => {
const fieldAssign = f.name.value
if (f.pattern) {
return [emitPattern(f.pattern, module, ctx, `${assignVar}.${fieldAssign}`)]
return [emitPattern(f.pattern, module, ctx, `${assignVar}.value.${fieldAssign}`)]
}
return jsVariable(fieldAssign, `${assignVar}.${fieldAssign}`)
return jsVariable(fieldAssign, `${assignVar}.value.${fieldAssign}`)
})
return emitLines([...patterns])
case 'list-expr':
Expand All @@ -366,6 +362,7 @@ export const emitPattern = (

export const emitGeneric = (generic: ConcreteGeneric, module: Module, ctx: Context): EmitExpr => {
const resultVar = nextVariable(ctx)
// TODO: only insert bounded types
const impls: string[] = []
for (const impl of generic.impls) {
impls.push(`${resultVar}.${relTypeName(impl)} = ${jsRelName(impl)};`)
Expand Down
64 changes: 34 additions & 30 deletions src/scope/trait.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import { AstNode, Module } from '../ast'
import { Module } from '../ast'
import { ImplDef, TraitDef } from '../ast/statement'
import { TypeDef } from '../ast/type-def'
import { notFoundError, semanticError } from '../semantic/error'
import { VidType, VirtualType, genericToVirtual, isAssignable, typeToVirtual, virtualTypeToString } from '../typecheck'
import { notFoundError } from '../semantic/error'
import { VidType, VirtualGeneric, VirtualType, genericToVirtual, isAssignable, typeToVirtual } from '../typecheck'
import { makeGenericMapOverStructure, resolveType } from '../typecheck/generic'
import { groupByHaving } from '../util/array'
import { assert } from '../util/todo'
import { Context, addError, addWarning, defKey } from './index'
import { Context, addError, defKey } from './index'
import { concatVid, idToVid, vidEq, vidFromString, vidToString } from './util'
import { VirtualIdentifier, VirtualIdentifierMatch, resolveVid, typeKinds } from './vid'
import { MethodDef, VirtualIdentifier, VirtualIdentifierMatch, resolveVid, typeKinds } from './vid'

/**
* Description of type/trait/impl relations
Expand Down Expand Up @@ -207,28 +206,33 @@ export const relTypeName = (rel: InstanceRelation): string => {
}
}

export const resolveImplsForType = (type: VirtualType, node: AstNode<any>, ctx: Context): InstanceRelation[] => {
const inherentImpl = ctx.impls.find(i => i.inherent && isAssignable(type, i.implType, ctx))
const traitImpls = groupByHaving(
ctx.impls.filter(i => i.instanceDef.kind === 'impl-def' && !i.inherent),
relTypeName,
// TODO: include self-bounded impls
i => isAssignable(type, i.forType, ctx)
)
const rels = []
if (inherentImpl) {
rels.push(inherentImpl)
}
for (const [, is] of traitImpls) {
if (is.length === 0) continue
// TODO: detect at implDef level
const vid = vidToString(is[0].implDef.vid)
if (is.length !== 1) {
const msg = `clashing impls of trait \`${vid}\` for type \`${virtualTypeToString(type)}\``
addWarning(ctx, semanticError(ctx, node, msg))
}
rels.push(is[0])
}
ctx.moduleStack.at(-1)!.relImports.push(...rels)
return rels
export const resolveGenericImpls = (generic: VirtualGeneric, ctx: Context): InstanceRelation[] => {
return generic.bounds.flatMap(b => {
const candidates = ctx.impls
.filter(i => i.instanceDef.kind === 'impl-def' && isAssignable(b, i.implType, ctx))
.toSorted((a, b) => relOrdering(b) - relOrdering(a))
return candidates.length > 0 ? [candidates.at(0)!] : []
})
}

export const resolveMethodImpl = (type: VirtualType, method: MethodDef, ctx: Context): InstanceRelation | undefined => {
const candidates = ctx.impls
.filter(
i =>
i.instanceDef.kind === 'impl-def' &&
isAssignable(type, i.forType, ctx) &&
isAssignable(i.implType, method.rel.implType, ctx) &&
(!i.inherent ||
i.instanceDef.block.statements.find(
s => s.kind === 'fn-def' && s.name.value === method.fn.name.value
))
)
.toSorted((a, b) => relOrdering(b) - relOrdering(a))
return candidates.at(0)
}

export const relOrdering = (rel: InstanceRelation): number => {
let score = 0
if (rel.instanceDef.kind === 'impl-def') score += 8
return score
}
Loading

0 comments on commit 80f1ba4

Please sign in to comment.