diff --git a/ast/parser.go b/ast/parser.go index 98f9c56bd6..3107117494 100644 --- a/ast/parser.go +++ b/ast/parser.go @@ -2630,7 +2630,7 @@ func (p *Parser) regoV1Import(imp *Import) { path := imp.Path.Value.(Ref) if len(path) == 1 || !path[1].Equal(RegoV1CompatibleRef[1]) || len(path) > 2 { - p.errorf(imp.Path.Location, "invalid import, must be `%s`", RegoV1CompatibleRef) + p.errorf(imp.Path.Location, "invalid import `%s`, must be `%s`", path, RegoV1CompatibleRef) return } diff --git a/ast/parser_test.go b/ast/parser_test.go index 6d7b97602e..8895ca3ebd 100644 --- a/ast/parser_test.go +++ b/ast/parser_test.go @@ -1337,9 +1337,10 @@ func TestFutureAndRegoV1ImportsExtraction(t *testing.T) { } func TestRegoV1Import(t *testing.T) { - assertParseErrorContains(t, "rego", "import rego", "invalid import, must be `rego.v1`") - assertParseErrorContains(t, "rego.foo", "import rego.foo", "invalid import, must be `rego.v1`") - assertParseErrorContains(t, "rego.foo.bar", "import rego.foo.bar", "invalid import, must be `rego.v1`") + assertParseErrorContains(t, "rego", "import rego", "invalid import `rego`, must be `rego.v1`") + assertParseErrorContains(t, "rego.foo", "import rego.foo", "invalid import `rego.foo`, must be `rego.v1`") + assertParseErrorContains(t, "rego.foo.bar", "import rego.foo.bar", "invalid import `rego.foo.bar`, must be `rego.v1`") + assertParseErrorContains(t, "rego.v1.bar", "import rego.v1.bar", "invalid import `rego.v1.bar`, must be `rego.v1`") assertParseErrorContains(t, "rego.v1 + alias", "import rego.v1 as xyz", "`rego` imports cannot be aliased") assertParseImport(t, "import rego.v1", diff --git a/cmd/bench_test.go b/cmd/bench_test.go index f370be9e02..2bb4dd1c37 100644 --- a/cmd/bench_test.go +++ b/cmd/bench_test.go @@ -60,6 +60,44 @@ func TestRunBenchmark(t *testing.T) { } } +func TestRunBenchmarkWithQueryImport(t *testing.T) { + params := testBenchParams() + // We add the rego.v1 import .. + params.imports = newrepeatedStringFlag([]string{"rego.v1"}) + + // .. which provides the 'in' keyword + args := []string{`"a" in ["a", "b", "c"]`} + var buf bytes.Buffer + + rc, err := benchMain(args, params, &buf, &goBenchRunner{}) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + if rc != 0 { + t.Fatalf("Unexpected return code %d, expected 0", rc) + } + + // Expect a json serialized benchmark result with histogram fields + var br testing.BenchmarkResult + err = util.UnmarshalJSON(buf.Bytes(), &br) + if err != nil { + t.Fatalf("Unexpected error unmarshalling output: %s", err) + } + + if br.N == 0 || br.T == 0 || br.MemAllocs == 0 || br.MemBytes == 0 { + t.Fatalf("Expected benchmark results to be non-zero, got: %+v", br) + } + + if _, ok := br.Extra["histogram_timer_rego_query_eval_ns_count"]; !ok { + t.Fatalf("Expected benchmark results to contain histogram_timer_rego_query_eval_ns_count, got: %+v", br) + } + + if float64(br.N) != br.Extra["histogram_timer_rego_query_eval_ns_count"] { + t.Fatalf("Expected 'histogram_timer_rego_query_eval_ns_count' to be equal to N") + } +} + func TestRunBenchmarkE2E(t *testing.T) { params := testBenchParams() params.e2e = true diff --git a/cmd/eval_test.go b/cmd/eval_test.go index c5adbdc9c2..028319a7ed 100755 --- a/cmd/eval_test.go +++ b/cmd/eval_test.go @@ -2294,3 +2294,83 @@ p contains 2 if { } } } + +func TestWithQueryImports(t *testing.T) { + tests := []struct { + note string + query string + imports []string + exp string + expErrs []string + }{ + { + note: "no imports, none required", + query: "1 + 2", + exp: "3\n", + }, + { + note: "future keyword used, future.keywords imported", + query: `"b" in ["a", "b", "c"]`, + imports: []string{"future.keywords.in"}, + exp: "true\n", + }, + { + note: "future keyword used, rego.v1 imported", + query: `"b" in ["a", "b", "c"]`, + imports: []string{"rego.v1"}, + exp: "true\n", + }, + { + note: "future keyword used, invalid rego.v2 imported", + query: `"b" in ["a", "b", "c"]`, + imports: []string{"rego.v2"}, + expErrs: []string{ + "1:8: rego_parse_error: invalid import `rego.v2`, must be `rego.v1`", + }, + }, + { + note: "future keyword used, no imports", + query: `"b" in ["a", "b", "c"]`, + expErrs: []string{ + "1:5: rego_unsafe_var_error: var in is unsafe (hint: `import future.keywords.in` to import a future keyword)", + }, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + params := newEvalCommandParams() + _ = params.outputFormat.Set(evalPrettyOutput) + params.imports = newrepeatedStringFlag(tc.imports) + + var buf bytes.Buffer + + defined, err := eval([]string{tc.query}, params, &buf) + + if len(tc.expErrs) == 0 { + if err != nil { + t.Fatalf("Unexpected error: %v, buf: %s", err, buf.String()) + } + + if !defined { + t.Fatal("expected result to be defined") + } + + if buf.String() != tc.exp { + t.Fatalf("expected:\n\n%s\n\ngot:\n\n%s", tc.exp, buf.String()) + } + } else { + if err == nil { + t.Fatal("expected error, got none") + } + + actual := buf.String() + for _, expErr := range tc.expErrs { + if !strings.Contains(actual, expErr) { + t.Fatalf("expected error:\n\n%v\n\ngot\n\n%v", expErr, actual) + } + } + } + }) + } +} diff --git a/rego/rego.go b/rego/rego.go index 17873e380a..25379fed6f 100644 --- a/rego/rego.go +++ b/rego/rego.go @@ -1761,14 +1761,15 @@ func (r *Rego) prepare(ctx context.Context, qType queryType, extras []extraStage return err } - futureImports := []*ast.Import{} + queryImports := []*ast.Import{} for _, imp := range imports { - if imp.Path.Value.(ast.Ref).HasPrefix(ast.Ref([]*ast.Term{ast.FutureRootDocument})) { - futureImports = append(futureImports, imp) + path := imp.Path.Value.(ast.Ref) + if path.HasPrefix([]*ast.Term{ast.FutureRootDocument}) || path.HasPrefix([]*ast.Term{ast.RegoRootDocument}) { + queryImports = append(queryImports, imp) } } - r.parsedQuery, err = r.parseQuery(futureImports, r.metrics) + r.parsedQuery, err = r.parseQuery(queryImports, r.metrics) if err != nil { return err } @@ -1921,7 +1922,7 @@ func (r *Rego) parseRawInput(rawInput *interface{}, m metrics.Metrics) (ast.Valu return ast.InterfaceToValue(*rawPtr) } -func (r *Rego) parseQuery(futureImports []*ast.Import, m metrics.Metrics) (ast.Body, error) { +func (r *Rego) parseQuery(queryImports []*ast.Import, m metrics.Metrics) (ast.Body, error) { if r.parsedQuery != nil { return r.parsedQuery, nil } @@ -1929,7 +1930,11 @@ func (r *Rego) parseQuery(futureImports []*ast.Import, m metrics.Metrics) (ast.B m.Timer(metrics.RegoQueryParse).Start() defer m.Timer(metrics.RegoQueryParse).Stop() - popts, err := future.ParserOptionsFromFutureImports(futureImports) + popts, err := future.ParserOptionsFromFutureImports(queryImports) + if err != nil { + return nil, err + } + popts, err = parserOptionsFromRegoVersionImport(queryImports, popts) if err != nil { return nil, err } @@ -1937,6 +1942,17 @@ func (r *Rego) parseQuery(futureImports []*ast.Import, m metrics.Metrics) (ast.B return ast.ParseBodyWithOpts(r.query, popts) } +func parserOptionsFromRegoVersionImport(imports []*ast.Import, popts ast.ParserOptions) (ast.ParserOptions, error) { + for _, imp := range imports { + path := imp.Path.Value.(ast.Ref) + if ast.Compare(path, ast.RegoV1CompatibleRef) == 0 { + popts.RegoVersion = ast.RegoV1 + return popts, nil + } + } + return popts, nil +} + func (r *Rego) compileModules(ctx context.Context, txn storage.Transaction, m metrics.Metrics) error { // Only compile again if there are new modules.