Skip to content

Commit

Permalink
Semantic: method check generalization
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanjermakov committed Apr 4, 2024
1 parent ac66c4e commit bda7df7
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 62 deletions.
4 changes: 3 additions & 1 deletion src/e2e.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ const compile = async (files: { [path: string]: string }): Promise<Context> => {
ctx.check = true
pkg.modules.forEach(m => checkModule(m, ctx))

if (ctx.errors.length > 0) return ctx

mkdirSync(pkg.path, { recursive: true })
writeFileSync(join(pkg.path, 'package.json'), JSON.stringify({ name: 'test', type: 'module' }))
await emitPackage(true, pkg, ctx)
Expand Down Expand Up @@ -101,7 +103,7 @@ const compileStd = async (): Promise<void> => {
}

const run = (ctx: Context): SpawnSyncReturns<Buffer> => {
if (ctx.errors.length > 0) throw Error('semantic errors')
if (ctx.errors.length > 0) throw Error(`semantic errors:\n${ctx.errors.map(e => `\t${e.message}`).join('\n')}`)
return spawnSync('node', ['dist/test/mod.js'], { cwd: 'tmp' })
}

Expand Down
9 changes: 7 additions & 2 deletions src/semantic/error.ts
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,16 @@ export const missingVarInitError = (ctx: Context, varDef: VarDef): SemanticError
return semanticError(39, ctx, varDef, msg)
}

export const noImplFoundError = (ctx: Context, name: Name, methodDef: MethodDef, operand: Operand): SemanticError => {
export const noImplFoundError = (
ctx: Context,
node: AstNode<any>,
methodDef: MethodDef,
operand: Operand
): SemanticError => {
const traitVid = vidToString(methodDef.rel.implDef.vid)
const operandType = virtualTypeToString(operand.type!)
const msg = `no impl of trait \`${traitVid}\` found for type \`${operandType}\``
return semanticError(40, ctx, name, msg)
return semanticError(40, ctx, node, msg)
}

export const unexpectedRefutablePatternError = (ctx: Context, patternExpr: PatternExpr): SemanticError => {
Expand Down
134 changes: 76 additions & 58 deletions src/semantic/instance.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import { checkCallArgs, checkType } from '.'
import { checkCallArgs } from '.'
import { UnaryExpr } from '../ast/expr'
import { MethodCallOp } from '../ast/op'
import { CallOp, MethodCallOp } from '../ast/op'
import { Name, Operand } from '../ast/operand'
import { Type } from '../ast/type'
import { Context, addError } from '../scope'
import { getInstanceForType, resolveMethodImpl, resolveTypeImpl } from '../scope/trait'
import { vidFromString, vidToString } from '../scope/util'
import { vidFromScope, vidFromString, vidToString } from '../scope/util'
import { MethodDef, VirtualIdentifier, resolveVid, typeKinds } from '../scope/vid'
import { VirtualFnType, VirtualType, combine, genericToVirtual, typeToVirtual } from '../typecheck'
import {
Expand Down Expand Up @@ -101,64 +102,75 @@ export const checkMethodCall = (expr: UnaryExpr, mCall: MethodCallOp, ctx: Conte
return
}

const methodName = mCall.name.value
const methodVid = { names: [...typeVid.names, methodName] }
const methodVid = { names: [...typeVid.names, mCall.name.value] }
const args = [expr.operand, ...mCall.call.args.map(a => a.expr)]
return checkMethodCall_(args, mCall.call, methodVid, ctx, mCall.typeArgs)
}

export const checkMethodCall_ = (
args: Operand[],
call: CallOp,
methodVid: VirtualIdentifier,
ctx: Context,
typeArgs?: Type[]
): VirtualType | undefined => {
const typeVid = vidFromScope(methodVid)
const ref = resolveVid(methodVid, ctx, ['method-def'])
if (!ref || ref.def.kind !== 'method-def') {
// hint if it is a field call
ctx.silent = true
const fieldType = checkFieldAccess(operand, mCall.name, ctx)
ctx.silent = false
if (fieldType) {
const note = `to access field \`${methodName}\`, surround operand in parentheses`
addError(ctx, notFoundError(ctx, mCall.name, vidToString(methodVid), 'method', [note]))
} else {
addError(ctx, notFoundError(ctx, mCall.name, vidToString(methodVid), 'method'))
}
addError(ctx, notFoundError(ctx, call, vidToString(methodVid), 'method'))
return
}

mCall.call.methodDef = ref.def
call.methodDef = ref.def
const fnType = <VirtualFnType>ref.def.fn.type

// TODO: check required type args (that cannot be inferred via `resolveFnGenerics`)
mCall.typeArgs.forEach(typeArg => checkType(typeArg, ctx))

const operandTypeRef = resolveVid(typeVid, ctx, typeKinds)
if (ref.def.rel.instanceDef.kind === 'trait-def') {
const resolved = resolveMethodImpl(operand.type!, ref.def, ctx)
if (ref.def.rel.instanceDef.kind === 'trait-def' && !ref.def.fn.static && args.length > 0) {
const self = args[0]
const resolved = resolveMethodImpl(self.type!, ref.def, ctx)
if (resolved) {
mCall.call.impl = resolved
call.impl = resolved

ctx.moduleStack.at(-1)!.relImports.push(call.impl)
// TODO: upcast only happen to the direct implType, but not its supertypes
// Use case: std::range has to return Iter<T> instead of RangeIter. When RangeIter is passed into a method
// where MapIter is expected, upcast of RangeIter for MapIter does not upcast it to Iter
upcast(self, self.type!, ref.def.rel.implType, ctx)
} else {
if (operandTypeRef && operandTypeRef.def.kind !== 'trait-def' && operandTypeRef.def.kind !== 'generic') {
addError(ctx, noImplFoundError(ctx, mCall.name, ref.def, operand))
addError(ctx, noImplFoundError(ctx, call, ref.def, self))
}
}
} else {
mCall.call.impl = ref.def.rel
}
if (mCall.call.impl) {
ctx.moduleStack.at(-1)!.relImports.push(mCall.call.impl)
// TODO: upcast only happen to the direct implType, but not its supertypes
// Use case: std::range has to return Iter<T> instead of RangeIter. When RangeIter is passed into a method
// where MapIter is expected, upcast of RangeIter for MapIter does not upcast it to Iter
upcast(operand, operand.type!, ref.def.rel.implType, ctx)
call.impl = ref.def.rel
ctx.moduleStack.at(-1)!.relImports.push(call.impl)
}

let genericMaps = makeMethodGenericMaps(operand, ref.def, mCall, ctx)
const args = ref.def.fn.static ? mCall.call.args.map(a => a.expr) : [operand, ...mCall.call.args.map(a => a.expr)]
let genericMaps = makeMethodGenericMaps(
args.map(a => a.type!),
ref.def,
call,
ctx,
typeArgs
)
args.forEach(a => {
a.type = resolveType(a.type!, genericMaps, ctx)
})
const paramTypes = fnType.paramTypes.map(pt => resolveType(pt, genericMaps, ctx))
checkCallArgs(mCall, args, paramTypes, ctx)
checkCallArgs(call, args, paramTypes, ctx)
// recalculate generic maps since malleable args might've been updated
genericMaps = makeMethodGenericMaps(operand, ref.def, mCall, ctx)
genericMaps = makeMethodGenericMaps(
args.map(a => a.type!),
ref.def,
call,
ctx,
typeArgs
)

const implForType = getInstanceForType(ref.def.rel.instanceDef, ctx)
const implForGenericMap = makeGenericMapOverStructure(operand.type!, implForType)
mCall.call.generics = fnType.generics.map((g, i) => {
const typeArg = mCall.typeArgs.at(i)
const implForGenericMap = selfType ? makeGenericMapOverStructure(selfType, implForType) : new Map()
call.generics = fnType.generics.map((g, i) => {
const typeArg = typeArgs?.at(i)
if (!typeArg) return { generic: g, impls: [] }
const vTypeArg = typeToVirtual(typeArg, ctx)
const t = resolveType(g, [implForGenericMap], ctx)
Expand All @@ -178,36 +190,42 @@ export const checkMethodCall = (expr: UnaryExpr, mCall: MethodCallOp, ctx: Conte
}

export const makeMethodGenericMaps = (
lOperand: Operand,
argTypes: VirtualType[],
methodDef: MethodDef,
call: MethodCallOp,
ctx: Context
call: CallOp,
ctx: Context,
typeArgs?: Type[]
): Map<string, VirtualType>[] => {
const maps = []
const self = !methodDef.fn.static ? argTypes[0] : undefined

if (call.call.impl) {
const operandRel = resolveTypeImpl(lOperand.type!, call.call.impl.forType, ctx)
if (call.impl && self) {
const operandRel = resolveTypeImpl(self, call.impl.forType, ctx)
if (operandRel) {
const operandImplGenericMap = makeGenericMapOverStructure(operandRel.impl.implType, call.call.impl.forType)
const operandImplGenericMap = makeGenericMapOverStructure(operandRel.impl.implType, call.impl.forType)
maps.push(operandImplGenericMap)
}
}

const fnType = <VirtualFnType>methodDef.fn.type
const typeArgs = call.typeArgs.map(tp => typeToVirtual(tp, ctx))
const fnTypeArgGenericMap = makeFnTypeArgGenericMap(fnType, typeArgs)
maps.push(fnTypeArgGenericMap)

// TODO: only for non-static methods
const implForType = getInstanceForType(methodDef.rel.instanceDef, ctx)
const implForGenericMap = makeGenericMapOverStructure(lOperand.type!, implForType)
// if Self type param is explicit, `resolveGenericsOverStructure` treats it as regular generic and interrupts
// further mapping in `fnGenericMap`, thus should be removed
implForGenericMap.delete(selfType.name)
maps.push(implForGenericMap)

const args = [lOperand.type!, ...call.call.args.map(a => a.expr.type!)]
const fnGenericMap = makeFnGenericMap(fnType, args)
if (typeArgs) {
const fnTypeArgGenericMap = makeFnTypeArgGenericMap(
fnType,
typeArgs.map(tp => typeToVirtual(tp, ctx))
)
maps.push(fnTypeArgGenericMap)
}

if (self) {
const implForType = getInstanceForType(methodDef.rel.instanceDef, ctx)
const implForGenericMap = makeGenericMapOverStructure(self, implForType)
// if Self type param is explicit, `resolveGenericsOverStructure` treats it as regular generic and interrupts
// further mapping in `fnGenericMap`, thus should be removed
implForGenericMap.delete(selfType.name)
maps.push(implForGenericMap)
}

const fnGenericMap = makeFnGenericMap(fnType, argTypes)
maps.push(fnGenericMap)

return maps
Expand Down
2 changes: 1 addition & 1 deletion src/semantic/semantic.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,7 @@ type Node<T> {
}
impl <T> Node<T> {
fn child()
fn child(self)
}
fn main() {
Expand Down

0 comments on commit bda7df7

Please sign in to comment.