Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions scripts/libevm-allowed-packages.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"github.com/ava-labs/libevm/accounts"
"github.com/ava-labs/libevm/accounts/external"
"github.com/ava-labs/libevm/accounts/keystore"
"github.com/ava-labs/libevm/accounts/scwallet"
"github.com/ava-labs/libevm/common"
"github.com/ava-labs/libevm/common/bitutil"
"github.com/ava-labs/libevm/common/compiler"
"github.com/ava-labs/libevm/common/hexutil"
"github.com/ava-labs/libevm/common/lru"
"github.com/ava-labs/libevm/common/math"
"github.com/ava-labs/libevm/common/prque"
"github.com/ava-labs/libevm/consensus/misc/eip4844"
"github.com/ava-labs/libevm/core/asm"
"github.com/ava-labs/libevm/core/bloombits"
"github.com/ava-labs/libevm/core/rawdb"
"github.com/ava-labs/libevm/core/state"
"github.com/ava-labs/libevm/core/types"
"github.com/ava-labs/libevm/core/vm"
"github.com/ava-labs/libevm/crypto"
"github.com/ava-labs/libevm/crypto/blake2b"
"github.com/ava-labs/libevm/crypto/bls12381"
"github.com/ava-labs/libevm/crypto/bn256"
"github.com/ava-labs/libevm/crypto/kzg4844"
"github.com/ava-labs/libevm/eth/tracers/js"
"github.com/ava-labs/libevm/eth/tracers/logger"
"github.com/ava-labs/libevm/eth/tracers/native"
"github.com/ava-labs/libevm/ethdb"
"github.com/ava-labs/libevm/ethdb/leveldb"
"github.com/ava-labs/libevm/ethdb/memorydb"
"github.com/ava-labs/libevm/ethdb/pebble"
"github.com/ava-labs/libevm/event"
"github.com/ava-labs/libevm/libevm"
"github.com/ava-labs/libevm/libevm/legacy"
"github.com/ava-labs/libevm/libevm/options"
"github.com/ava-labs/libevm/libevm/stateconf"
"github.com/ava-labs/libevm/log"
"github.com/ava-labs/libevm/metrics"
"github.com/ava-labs/libevm/rlp"
"github.com/ava-labs/libevm/trie"
"github.com/ava-labs/libevm/trie/testutil"
"github.com/ava-labs/libevm/trie/trienode"
"github.com/ava-labs/libevm/trie/triestate"
"github.com/ava-labs/libevm/trie/utils"
"github.com/ava-labs/libevm/triedb"
"github.com/ava-labs/libevm/triedb/database"
196 changes: 196 additions & 0 deletions tests/imports_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
package tests

import (
"bufio"
"fmt"
"go/ast"
"go/parser"
"go/token"
"os"
"path/filepath"
"regexp"
"slices"
"strings"
"testing"

"github.com/ava-labs/avalanchego/utils/set"
"github.com/stretchr/testify/require"
)

// TestDoNotImportFromGraft ensures that files outside the graft directory
// do not import packages from within the graft directory.
func TestDoNotImportFromGraft(t *testing.T) {
graftRegex := regexp.MustCompile(`^github\.com/ava-labs/avalanchego/graft(/|$)`)

// Find all graft imports in the codebase (excluding the graft directory itself)
foundImports, err := findImportsMatchingPattern("..", graftRegex, func(path string, _ string, _ *ast.ImportSpec) bool {
return strings.Contains(path, "/graft/")
})
require.NoError(t, err, "Failed to find graft imports")

if len(foundImports) == 0 {
return
}

// After this point, there are imports from the graft directory, and the test will fail.
// The remaining code is just necessary to pretty-print the error message,
// to make it easier to find and fix the disallowed imports.
sortedImports := make([]string, 0, len(foundImports))
for importPath := range foundImports {
sortedImports = append(sortedImports, importPath)
}
slices.Sort(sortedImports)

var errorMsg strings.Builder
errorMsg.WriteString("Files outside the graft directory must not import from the graft directory!\n\n")
for _, importPath := range sortedImports {
files := foundImports[importPath]
fileList := files.List()
slices.Sort(fileList)

errorMsg.WriteString(fmt.Sprintf("- %s\n", importPath))
errorMsg.WriteString(fmt.Sprintf(" Used in %d file(s):\n", len(fileList)))
for _, file := range fileList {
errorMsg.WriteString(fmt.Sprintf(" • %s\n", file))
}
errorMsg.WriteString("\n")
}
require.Fail(t, errorMsg.String())
}

// TestLibevmImportsAreAllowed ensures that all libevm imports in the graft directory
// are explicitly allowed via the libevm-allowed-packages.txt file.
func TestLibevmImportsAreAllowed(t *testing.T) {
allowedPackages, err := loadAllowedPackages("../scripts/libevm-allowed-packages.txt")
require.NoError(t, err, "Failed to load allowed packages")

// Find all libevm imports in source files, excluding underscore and "eth*" named imports
libevmRegex := regexp.MustCompile(`^github\.com/ava-labs/libevm/`)
foundImports, err := findImportsMatchingPattern("../graft", libevmRegex, func(_ string, _ string, imp *ast.ImportSpec) bool {
// Skip underscore and "eth*" named imports
return imp.Name != nil && (imp.Name.Name == "_" || strings.HasPrefix(imp.Name.Name, "eth"))
})
require.NoError(t, err, "Failed to find libevm imports")

var disallowedImports set.Set[string]
for importPath := range foundImports {
if !allowedPackages.Contains(importPath) {
disallowedImports.Add(importPath)
}
}

if len(disallowedImports) == 0 {
return
}

// After this point, there are disallowed imports, and the test will fail.
// The remaining code is just necessary to pretty-print the error message,
// to make it easier to find and fix the disallowed imports.
sortedDisallowed := disallowedImports.List()
slices.Sort(sortedDisallowed)

var errorMsg strings.Builder
errorMsg.WriteString("Files inside the graft directory must not import forbidden libevm packages!\nIf a package is safe to import, add it to ./scripts/libevm-allowed-packages.txt.\n\n")
for _, importPath := range sortedDisallowed {
files := foundImports[importPath]
fileList := files.List()
slices.Sort(fileList)

errorMsg.WriteString(fmt.Sprintf("- %s\n", importPath))
errorMsg.WriteString(fmt.Sprintf(" Used in %d file(s):\n", len(fileList)))
for _, file := range fileList {
errorMsg.WriteString(fmt.Sprintf(" • %s\n", file))
}
errorMsg.WriteString("\n")
}
require.Fail(t, errorMsg.String())
}

// loadAllowedPackages reads the allowed packages from the specified file
func loadAllowedPackages(filename string) (set.Set[string], error) {
file, err := os.Open(filename)
if err != nil {
return nil, fmt.Errorf("failed to open allowed packages file: %w", err)
}
defer file.Close()

allowed := set.Set[string]{}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}

line = strings.Trim(line, `"`)
allowed.Add(line)
}

if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("failed to read allowed packages file: %w", err)
}

return allowed, nil
}

// importFilter is a function that can filter imports based on the file path,
// import path, and AST import spec (which contains the import name).
// Return true to skip this import.
type importFilter func(filePath string, importPath string, importSpec *ast.ImportSpec) bool

// findImportsMatchingPattern is a generalized function that finds all imports
// matching a given regex pattern in the specified directory.
// The filterFunc can be used to skip certain files or imports (return true to skip).
// Returns a map of import paths to the set of files that contain them.
func findImportsMatchingPattern(
rootDir string,
importRegex *regexp.Regexp,
filterFunc importFilter,
) (map[string]set.Set[string], error) {
imports := make(map[string]set.Set[string])

err := filepath.Walk(rootDir, func(path string, _ os.FileInfo, err error) error {
if err != nil || !strings.HasSuffix(path, ".go") {
return err
}

// Skip generated files, main_test.go, and tempextrastest directory
filename := filepath.Base(path)
if strings.HasPrefix(filename, "gen_") || strings.Contains(path, "core/main_test.go") || strings.Contains(path, "tempextrastest/") {
return nil
}

node, err := parser.ParseFile(token.NewFileSet(), path, nil, parser.ParseComments)
if err != nil {
return fmt.Errorf("failed to parse %s: %w", path, err)
}

for _, imp := range node.Imports {
if imp.Path == nil {
continue
}

importPath := strings.Trim(imp.Path.Value, `"`)
if !importRegex.MatchString(importPath) {
continue
}

if filterFunc != nil && filterFunc(path, importPath, imp) {
continue
}

if _, exists := imports[importPath]; !exists {
imports[importPath] = set.Set[string]{}
}
fileSet := imports[importPath]
fileSet.Add(path)
imports[importPath] = fileSet
}
return nil
})
if err != nil {
return nil, err
}

return imports, nil
}
Loading