Skip to content

Commit

Permalink
Codegen: upcasting; for-expr pattern matching fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanjermakov committed Mar 12, 2024
1 parent 4e2ef47 commit 461c08d
Show file tree
Hide file tree
Showing 16 changed files with 161 additions and 69 deletions.
41 changes: 16 additions & 25 deletions src/codegen/js/expr.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,4 @@
import {
emitLines,
emitVirtualTraits,
extractValue,
indent,
jsError,
jsRelName,
jsString,
jsVariable,
nextVariable
} from '.'
import { emitLines, emitUpcasts, extractValue, indent, jsError, jsRelName, jsString, jsVariable, nextVariable } from '.'
import { Module, Param } from '../../ast'
import { BinaryExpr, Expr, OperandExpr, UnaryExpr } from '../../ast/expr'
import { MatchExpr, Pattern, PatternExpr } from '../../ast/match'
Expand Down Expand Up @@ -51,26 +41,25 @@ export const emitUnaryExpr = (unaryExpr: UnaryExpr, module: Module, ctx: Context
const call = unaryExpr.op
const args = call.args.map(a => {
const { emit, resultVar: res } = emitExpr(a.expr, module, ctx)
const argTraits = a.expr.traits
const traitEmit = argTraits ? emitVirtualTraits(res, argTraits) : ''
return { emit: emitLines([emit, traitEmit]), resultVar: res }
const upcastEmit = a.expr.upcasts ? emitUpcasts(res, a.expr.upcasts) : ''
return { emit: emitLines([emit, upcastEmit]), resultVar: res }
})
const genericTypes = call.generics?.map(g => emitGeneric(g, module, ctx)) ?? []
const jsArgs = [...args, ...genericTypes]
const traitEmit = unaryExpr.traits ? emitVirtualTraits(resultVar, unaryExpr.traits) : ''
const upcastEmit = unaryExpr.upcasts ? emitUpcasts(resultVar, unaryExpr.upcasts) : ''
const variantDef = call.variantDef
if (variantDef) {
const variantName = `${variantDef.typeDef.name.value}.${variantDef.variant.name.value}`
const call = jsVariable(resultVar, `${variantName}(${jsArgs.map(a => a.resultVar).join(', ')})`)
return {
emit: emitLines([...jsArgs.map(a => a.emit), call, traitEmit]),
emit: emitLines([...jsArgs.map(a => a.emit), call, upcastEmit]),
resultVar
}
} else {
const operand = emitOperand(unaryExpr.operand, module, ctx)
const call = jsVariable(resultVar, `${operand.resultVar}(${jsArgs.map(a => a.resultVar).join(', ')})`)
return {
emit: emitLines([operand.emit, ...jsArgs.map(a => a.emit), call, traitEmit]),
emit: emitLines([operand.emit, ...jsArgs.map(a => a.emit), call, upcastEmit]),
resultVar
}
}
Expand Down Expand Up @@ -106,10 +95,8 @@ export const emitBinaryExpr = (binaryExpr: BinaryExpr, module: Module, ctx: Cont
? jsArgs.map(a => a.resultVar)
: [lOp.resultVar, ...jsArgs.map(a => a.resultVar)]
).join(', ')
const upcastRels = binaryExpr.lOperand.traits
const upcastEmit = upcastRels
? [...upcastRels.entries()].map(([name, rel]) => `${lOp.resultVar}.${name} = ${jsRelName(rel)}`)
: ''
const upcasts = binaryExpr.lOperand.upcasts
const upcastEmit = upcasts ? emitUpcasts(lOp.resultVar, upcasts) : ''
const callerEmit = call.impl ? jsRelName(call.impl) : `${lOp.resultVar}.${relTypeName(methodDef.rel)}`
const callEmit = jsVariable(resultVar, `${callerEmit}().${methodName}(${argsEmit})`)
return {
Expand Down Expand Up @@ -170,22 +157,26 @@ export const emitOperand = (operand: Operand, module: Module, ctx: Context): Emi
case 'for-expr': {
const expr = emitExpr(operand.expr, module, ctx)
const iteratorVar = nextVariable(ctx)
const upcastsEmit = operand.expr.upcasts ? emitUpcasts(expr.resultVar, operand.expr.upcasts) : ''
const iterator = {
emit: emitLines([
expr.emit,
upcastsEmit,
// TODO: do not invoke `iter` if it is already Iter
jsVariable(iteratorVar, `${expr.resultVar}.Iterable().iter(${expr.resultVar})`)
]),
resultVar: iteratorVar
}
const iterateeVarOption = nextVariable(ctx)
const iterateeVar = nextVariable(ctx)
const thenBlock = emitLines([
jsVariable(iterateeVar, `${extractValue(iterateeVarOption)}.value`),
emitPattern(operand.pattern, module, ctx, `${iterateeVar}`),
...emitBlockStatements(operand.block, module, ctx)
])
const block = emitLines([
jsVariable(iterateeVar, `${iterator.resultVar}.Iter().next(${iterator.resultVar})`),
`if (${iterateeVar}.$noisVariant === "Some") {`,
jsVariable(iterateeVarOption, `${iterator.resultVar}.Iter().next(${iterator.resultVar})`),
`if (${iterateeVarOption}.$noisVariant === "Some") {`,
indent(thenBlock),
`} else {`,
indent(`break;`),
Expand Down Expand Up @@ -354,10 +345,10 @@ export const emitPattern = (
const name = pattern.expr.value
return jsVariable(name, assignVar, pub)
case 'con-pattern':
const patterns = pattern.expr.fieldPatterns.flatMap(f => {
const patterns = pattern.expr.fieldPatterns.map(f => {
const fieldAssign = f.name.value
if (f.pattern) {
return [emitPattern(f.pattern, module, ctx, `${assignVar}.value.${fieldAssign}`)]
return emitPattern(f.pattern, module, ctx, `${assignVar}.value.${fieldAssign}`)
}
return jsVariable(fieldAssign, `${assignVar}.value.${fieldAssign}`)
})
Expand Down
9 changes: 7 additions & 2 deletions src/codegen/js/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { Context } from '../../scope'
import { InstanceRelation, relTypeName } from '../../scope/trait'
import { concatVid, vidFromScope, vidFromString } from '../../scope/util'
import { VirtualIdentifier } from '../../scope/vid'
import { Upcast } from '../../semantic/upcast'
import { VirtualType, virtualTypeToString } from '../../typecheck'
import { groupBy } from '../../util/array'
import { unreachable } from '../../util/todo'
Expand Down Expand Up @@ -126,6 +127,10 @@ export const jsGenericTypeName = (type: VirtualType): string => {
}
}

export const emitVirtualTraits = (resultVar: string, traits: Map<string, InstanceRelation>): string => {
return emitLines([...traits.entries()].map(([name, rel]) => `${resultVar}.${name} = ${jsRelName(rel)}`))
export const emitUpcasts = (resultVar: string, upcasts: Map<string, Upcast>): string => {
const args = [...upcasts.entries()].map(
([k, v]) => `[${[...v.traits.entries()].map(([tk, tv]) => `[${jsString(tk)}, ${jsRelName(tv)}]`).join(', ')}]`
)
args.unshift(resultVar)
return `${resultVar}.upcast(${args.join(', ')})`
}
10 changes: 9 additions & 1 deletion src/codegen/js/statement.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,15 @@ export const emitVariant = (v: Variant, typeDef: TypeDef, module: Module, ctx: C
const props = [
`$noisType: ${type}`,
...(typeDef.variants.length > 1 ? [`$noisVariant: ${name}`] : []),
`value: {${fields}}`
`value: {${fields}}`,
`upcast: ${emitUpcastFn(v, typeDef, module, ctx)}`
].join(',\n')
return `function(${fieldNames.join(', ')}) {\n${indent(`return {\n${indent(props)}\n}`)}\n}`
}

export const emitUpcastFn = (v: Variant, typeDef: TypeDef, module: Module, ctx: Context) => {
const params = ['value', 'Self', ...typeDef.generics.map(g => g.name.value)]
const selfG = emitLines(['for (const [trait, impl] of Self) {', indent('value[trait] = impl;'), '}'])
const block = emitLines([selfG])
return `function(${params.join(', ')}) {\n${indent(block)}\n}`
}
1 change: 1 addition & 0 deletions src/scope/trait.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ describe('trait', () => {
expect(formatImplTypes(findSuperRelChains(vidFromString('std::list::List'), ctx))).toEqual([
['std::iter::Iterable<T>'],
['std::iter::Collector<T>'],
['std::io::Display'],
['std::io::Display']
])

Expand Down
9 changes: 8 additions & 1 deletion src/scope/trait.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ export interface InstanceRelation {
* Instance def
*/
instanceDef: TraitDef | ImplDef
/**
* Generics
*/
generics: VirtualGeneric[]
/**
* There are two types of implementations:
* - trait impl: implementing trait for some type
Expand Down Expand Up @@ -85,11 +89,12 @@ const getImplRel = (instance: TraitDef | ImplDef, ctx: Context): InstanceRelatio
}

const getTraitImplRel = (instance: TraitDef, module: Module, ctx: Context): InstanceRelation | undefined => {
const generics = instance.generics.filter(g => g.name.value !== 'Self').map(g => genericToVirtual(g, ctx))
const traitType: VirtualType = {
kind: 'vid-type',
identifier: { names: [...module.identifier.names, instance.name.value] },
// self args are for bounds and should be excluded from virtual types
typeArgs: instance.generics.filter(g => g.name.value !== 'Self').map(g => genericToVirtual(g, ctx))
typeArgs: generics
}
const ref = resolveVid(traitType.identifier, ctx, ['trait-def'])
assert(!!ref, 'traitDef did not find itself by name')
Expand All @@ -101,6 +106,7 @@ const getTraitImplRel = (instance: TraitDef, module: Module, ctx: Context): Inst
implDef: traitRef,
forDef: traitRef,
instanceDef: instance,
generics,
inherent: false
}
}
Expand Down Expand Up @@ -134,6 +140,7 @@ const getImplImplRel = (instance: ImplDef, module: Module, ctx: Context): Instan
implDef: <VirtualIdentifierMatch<TypeDef | TraitDef>>ref,
forDef: forDef,
instanceDef: instance,
generics: instance.generics.map(g => genericToVirtual(g, ctx)),
inherent: !instance.forTrait
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/semantic/expr.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import { checkExhaustion } from './exhaust'
import { checkAccessExpr } from './instance'
import { checkPattern } from './match'
import { operatorImplMap } from './op'
import { upcast } from './upcast'

export const checkExpr = (expr: Expr, ctx: Context): void => {
switch (expr.kind) {
Expand Down Expand Up @@ -312,6 +313,10 @@ export const checkForExpr = (forExpr: ForExpr, ctx: Context): void => {
// TODO: break with a value
forExpr.type = unknownType

// TODO: only one needed
upcast(forExpr.expr, forExpr.expr.type!, iter, ctx)
upcast(forExpr.expr, forExpr.expr.type!, iterable, ctx)

module.scopeStack.pop()
}

Expand Down
26 changes: 5 additions & 21 deletions src/semantic/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,7 @@ import {
instanceScope,
unwindScope
} from '../scope'
import {
InstanceRelation,
findSuperRelChains,
relTypeName,
resolveTypeImpl,
traitDefToVirtualType,
typeDefToVirtualType
} from '../scope/trait'
import { InstanceRelation, findSuperRelChains, traitDefToVirtualType, typeDefToVirtualType } from '../scope/trait'
import { idToVid, vidEq, vidToString } from '../scope/util'
import { Definition, MethodDef, NameDef, resolveVid, typeKinds } from '../scope/vid'
import { VirtualType, genericToVirtual, isAssignable, typeEq, typeToVirtual } from '../typecheck'
Expand All @@ -35,6 +28,7 @@ import { notFoundError, semanticError, typeError, unknownTypeError } from './err
import { checkClosureExpr, checkExpr } from './expr'
import { checkPattern } from './match'
import { typeNames } from './type-def'
import { Upcast, upcast } from './upcast'
import { useExprToVids } from './use-expr'

export interface Checked {
Expand All @@ -50,7 +44,7 @@ export interface Static {
}

export interface Virtual {
traits: Map<string, InstanceRelation>
upcasts: Map<string, Upcast>
}

export const prepareModule = (module: Module): void => {
Expand Down Expand Up @@ -308,12 +302,7 @@ const checkFnDef = (fnDef: FnDef, ctx: Context): void => {
if (!isAssignable(rs.type!, returnTypeResolved, ctx)) {
addError(ctx, typeError(rs, rs.type!, returnTypeResolved, ctx))
}
const res = resolveTypeImpl(rs.type!, returnTypeResolved, ctx)
if (res) {
rs.traits ??= new Map()
rs.traits.set(relTypeName(res.trait), res.impl)
module.relImports.push(res.impl)
}
upcast(rs, rs.type!, returnTypeResolved, ctx)
})
} else {
if (!module.compiled && instScope?.rel.instanceDef.kind !== 'trait-def') {
Expand Down Expand Up @@ -734,11 +723,6 @@ export const checkCallArgs = (node: AstNode<any>, args: Operand[], paramTypes: V
addError(ctx, typeError(arg, argType, paramType, ctx))
}

const res = resolveTypeImpl(argType, paramType, ctx)
if (res) {
arg.traits ??= new Map()
arg.traits.set(relTypeName(res.trait), res.impl)
ctx.moduleStack.at(-1)!.relImports.push(res.impl)
}
upcast(arg, argType, paramType, ctx)
}
}
38 changes: 38 additions & 0 deletions src/semantic/upcast.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import { Virtual } from '.'
import { Context } from '../scope'
import { InstanceRelation, relTypeName, resolveTypeImpl } from '../scope/trait'
import { VirtualType } from '../typecheck'
import { makeGenericMapOverStructure } from '../typecheck/generic'

export interface Upcast {
traits: Map<string, InstanceRelation>
}

export const upcast = (
virtual: Partial<Virtual>,
type: VirtualType,
returnTypeResolved: VirtualType,
ctx: Context
): void => {
const res = resolveTypeImpl(type, returnTypeResolved, ctx)
if (res) {
const genericMap = makeGenericMapOverStructure(type, res.impl.forType)
virtual.upcasts ??= new Map()
virtual.upcasts.set('Self', { traits: new Map([[relTypeName(res.trait), res.impl]]) })
for (const g of res.impl.generics) {
const gUpcast: Upcast = { traits: new Map() }
const concreteG = genericMap.get(g.name)
if (concreteG) {
for (const b of g.bounds) {
const gRes = resolveTypeImpl(concreteG, b, ctx)
if (gRes) {
gUpcast.traits.set(relTypeName(gRes.trait), gRes.impl)
ctx.moduleStack.at(-1)!.relImports.push(gRes.impl)
}
}
}
virtual.upcasts.set(g.name, gUpcast)
}
ctx.moduleStack.at(-1)!.relImports.push(res.impl)
}
}
10 changes: 9 additions & 1 deletion src/std/bool.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
* @returns {Boolean}
*/
Bool.Bool = function(value) {
return { $noisType: 'std::bool::Bool', value }
return {
$noisType: 'std::bool::Bool',
value,
upcast: function(value, Self) {
for (const [trait, impl] of Self) {
value[trait] = impl;
}
}
}
}

/**
Expand Down
10 changes: 9 additions & 1 deletion src/std/char.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
* @returns {Char}
*/
Char.Char = function(value) {
return { $noisType: 'std::char::Char', value }
return {
$noisType: 'std::char::Char',
value,
upcast: function(value, Self) {
for (const [trait, impl] of Self) {
value[trait] = impl;
}
}
}
}

/**
Expand Down
10 changes: 9 additions & 1 deletion src/std/float.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
* @returns {Float}
*/
Float.Float = function(value) {
return { $noisType: 'std::float::Float', value }
return {
$noisType: 'std::float::Float',
value,
upcast: function(value, Self) {
for (const [trait, impl] of Self) {
value[trait] = impl;
}
}
}
}

/**
Expand Down
10 changes: 9 additions & 1 deletion src/std/int.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
* @returns {Int}
*/
Int.Int = function(value) {
return { $noisType: 'std::int::Int', value }
return {
$noisType: 'std::int::Int',
value,
upcast: function(value, Self) {
for (const [trait, impl] of Self) {
value[trait] = impl;
}
}
}
}

/**
Expand Down
4 changes: 2 additions & 2 deletions src/std/iter/mapIter.no
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ impl <T, U> Iter<U> for MapIter<T, U> {
}

pub trait MapAdapter<T> {
fn map<U>(self, f: |T|: U): MapIter<T, U>
fn map<U>(self, f: |T|: U): Iter<U>
}

impl <T> MapAdapter<T> for Iter<T> {
fn map<U>(self, f: |T|: U): MapIter<T, U> {
fn map<U>(self, f: |T|: U): Iter<U> {
MapIter(iter: self, f: f)
}
}
Expand Down
Loading

0 comments on commit 461c08d

Please sign in to comment.