diff --git a/ast/expression_extractor.go b/ast/expression_extractor.go index e3b52c71e..a02c159ec 100644 --- a/ast/expression_extractor.go +++ b/ast/expression_extractor.go @@ -146,6 +146,7 @@ type ExpressionExtractor struct { ReferenceExtractor ReferenceExtractor MemberExtractor MemberExtractor PathExtractor PathExtractor + IdentifierPrefix string nextIdentifier int } @@ -156,19 +157,18 @@ func (extractor *ExpressionExtractor) Extract(expression Expression) ExpressionE } func (extractor *ExpressionExtractor) FreshIdentifier() string { - defer func() { - extractor.nextIdentifier++ - }() + identifier := extractor.nextIdentifier + extractor.nextIdentifier++ + return extractor.FormatIdentifier(identifier) +} + +func (extractor *ExpressionExtractor) FormatIdentifier(identifier int) string { // TODO: improve // NOTE: to avoid naming clashes with identifiers in the program, // include characters that can't be represented in source: // - \x00 = Null character // - \x1F = Information Separator One - return extractor.FormatIdentifier(extractor.nextIdentifier) -} - -func (extractor *ExpressionExtractor) FormatIdentifier(identifier int) string { - return fmt.Sprintf("\x00exp\x1F%d", identifier) + return fmt.Sprintf("\x00%s\x1F%d", extractor.IdentifierPrefix, identifier) } type ExtractedExpression struct { diff --git a/interpreter/condition_test.go b/interpreter/condition_test.go index 026949e95..33fa80eea 100644 --- a/interpreter/condition_test.go +++ b/interpreter/condition_test.go @@ -27,6 +27,9 @@ import ( "github.com/onflow/cadence/activations" "github.com/onflow/cadence/ast" + "github.com/onflow/cadence/bbq" + "github.com/onflow/cadence/bbq/compiler" + . "github.com/onflow/cadence/bbq/test_utils" "github.com/onflow/cadence/bbq/vm" compilerUtils "github.com/onflow/cadence/bbq/vm/test" "github.com/onflow/cadence/common" @@ -1659,3 +1662,134 @@ func TestInterpretFunctionExpressionPostConditions(t *testing.T) { ) }) } + +func TestInterpretFunctionBeforePostConditionAndInheritedBeforePostCondition(t *testing.T) { + + t.Parallel() + + importLocation := common.NewAddressLocation( + nil, + common.MustBytesToAddress([]byte{0x1}), + "", + ) + + const importCode = ` + struct interface SI { + + fun test(a: Int, b: Int) { + post { + before(a) == 0 + } + } + } + ` + + const testCode = ` + import SI from 0x1 + + struct S: SI { + + fun test(a: Int, b: Int) { + post { + before(b) == 1 + } + } + } + + fun test() { + S().test(a: 0, b: 1) + } + ` + + if *compile { + + programs := CompiledPrograms{} + + _ = ParseCheckAndCompile(t, + importCode, + importLocation, + programs, + ) + + _, err := compilerUtils.CompileAndInvokeWithOptionsAndPrograms(t, + testCode, + "test", + compilerUtils.CompilerAndVMOptions{ + ParseCheckAndCompileOptions: ParseCheckAndCompileOptions{ + ParseAndCheckOptions: &ParseAndCheckOptions{ + CheckerConfig: &sema.Config{ + ImportHandler: func(checker *sema.Checker, importedLocation common.Location, _ ast.Range) (sema.Import, error) { + importedProgram, ok := programs[importedLocation] + if !ok { + return nil, fmt.Errorf("cannot find program for location %s", importedLocation) + } + + return sema.ElaborationImport{ + Elaboration: importedProgram.DesugaredElaboration.OriginalElaboration(), + }, nil + }, + }, + }, + CompilerConfig: &compiler.Config{ + ImportHandler: func(location common.Location) *bbq.InstructionProgram { + return programs[location].Program + }, + LocationHandler: func(identifiers []ast.Identifier, location common.Location) ([]sema.ResolvedLocation, error) { + return []sema.ResolvedLocation{ + { + Location: location, + Identifiers: identifiers, + }, + }, nil + }, + }, + }, + }, + programs, + ) + require.NoError(t, err) + + } else { + + importedChecker, err := ParseAndCheckWithOptions(t, + importCode, + ParseAndCheckOptions{ + Location: importLocation, + }, + ) + require.NoError(t, err) + + inter, err := parseCheckAndInterpretWithOptions(t, + testCode, + ParseCheckAndInterpretOptions{ + ParseAndCheckOptions: &ParseAndCheckOptions{ + CheckerConfig: &sema.Config{ + ImportHandler: func(checker *sema.Checker, importedLocation common.Location, _ ast.Range) (sema.Import, error) { + return sema.ElaborationImport{ + Elaboration: importedChecker.Elaboration, + }, nil + }, + }, + }, + InterpreterConfig: &interpreter.Config{ + ImportLocationHandler: func(inter *interpreter.Interpreter, location common.Location) interpreter.Import { + + program := interpreter.ProgramFromChecker(importedChecker) + subInterpreter, err := inter.NewSubInterpreter(program, location) + if err != nil { + panic(err) + } + + return interpreter.InterpreterImport{ + Interpreter: subInterpreter, + } + }, + }, + }, + ) + require.NoError(t, err) + + _, err = inter.Invoke("test") + require.NoError(t, err) + } +} diff --git a/sema/before_extractor.go b/sema/before_extractor.go index b99842f5f..aa134c9f1 100644 --- a/sema/before_extractor.go +++ b/sema/before_extractor.go @@ -25,18 +25,23 @@ import ( type BeforeExtractor struct { ExpressionExtractor *ast.ExpressionExtractor - report func(error) memoryGauge common.MemoryGauge + report func(error) } -func NewBeforeExtractor(memoryGauge common.MemoryGauge, report func(error)) *BeforeExtractor { +func NewBeforeExtractor( + memoryGauge common.MemoryGauge, + identifierPrefix string, + report func(error), +) *BeforeExtractor { beforeExtractor := &BeforeExtractor{ - report: report, memoryGauge: memoryGauge, + report: report, } expressionExtractor := &ast.ExpressionExtractor{ InvocationExtractor: beforeExtractor, FunctionExtractor: beforeExtractor, + IdentifierPrefix: identifierPrefix, MemoryGauge: memoryGauge, } beforeExtractor.ExpressionExtractor = expressionExtractor diff --git a/sema/before_extractor_test.go b/sema/before_extractor_test.go index 5d016a58e..92c0b40e7 100644 --- a/sema/before_extractor_test.go +++ b/sema/before_extractor_test.go @@ -42,7 +42,7 @@ func TestBeforeExtractor(t *testing.T) { require.Empty(t, errs) - extractor := NewBeforeExtractor(nil, nil) + extractor := NewBeforeExtractor(nil, "", nil) identifier1 := ast.Identifier{ Identifier: extractor.ExpressionExtractor.FormatIdentifier(0), diff --git a/sema/checker.go b/sema/checker.go index 924c52b63..1d2bec91c 100644 --- a/sema/checker.go +++ b/sema/checker.go @@ -22,6 +22,7 @@ import ( goErrors "errors" "math" "math/big" + "strings" "github.com/rivo/uniseg" @@ -2695,7 +2696,12 @@ func (checker *Checker) maybeAddResourceInvalidation(resource Resource, invalida func (checker *Checker) beforeExtractor() *BeforeExtractor { if checker._beforeExtractor == nil { - checker._beforeExtractor = NewBeforeExtractor(checker.memoryGauge, checker.report) + checker._beforeExtractor = NewBeforeExtractor( + checker.memoryGauge, + // TODO: improve + strings.ReplaceAll(checker.Location.ID(), ".", "\x00"), + checker.report, + ) } return checker._beforeExtractor }