diff --git a/CMakeLists.txt b/CMakeLists.txt index f2d81d8c6..14f24bb58 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,7 @@ endif() if (NOT LLVM_BUILD_TYPE) set(LLVM_BUILD_TYPE ${CMAKE_BUILD_TYPE}) endif() -set(LLVM_CONFIG_EXECUTABLE ${CMAKE_SOURCE_DIR}/third_party/llvm-${LLVM_BUILD_TYPE}-install/bin/llvm-config) +set(LLVM_CONFIG_EXECUTABLE ${CMAKE_SOURCE_DIR}/third_party/llvm-Release-install/bin/llvm-config) if(NOT EXISTS ${LLVM_CONFIG_EXECUTABLE}) message(FATAL_ERROR "llvm-config could not be found!") endif() @@ -28,7 +28,7 @@ execute_process( OUTPUT_STRIP_TRAILING_WHITESPACE ) -set(LLVM_CXXFLAGS "${LLVM_CXXFLAGS} -fno-exceptions -fno-rtti") +set(LLVM_CXXFLAGS "${LLVM_CXXFLAGS} -fno-exceptions -fno-rtti -Wno-deprecated-enum-enum-conversion") execute_process( COMMAND ${LLVM_CONFIG_EXECUTABLE} --libs @@ -84,10 +84,10 @@ set(GTEST_CXXFLAGS "-DGTEST_HAS_RTTI=0") set(GTEST_INCLUDEDIR "${LLVM_SRC}/utils/unittest/googletest/include") set(GTEST_LIBS "-lgtest_main -lgtest") -set(LLVM_CXXFLAGS "${LLVM_CXXFLAGS} -std=c++17") +set(LLVM_CXXFLAGS "${LLVM_CXXFLAGS} -std=c++20") if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") - set(CMAKE_XCODE_ATTRIBUTE_CLANG_CXX_LANGUAGE_STANDARD "c++17") + set(CMAKE_XCODE_ATTRIBUTE_CLANG_CXX_LANGUAGE_STANDARD "c++20") set(LLVM_CXXFLAGS "${LLVM_CXXFLAGS} -fvisibility-inlines-hidden") set(PASS_LDFLAGS "-Wl,-undefined,dynamic_lookup") endif() @@ -129,11 +129,32 @@ find_library(HIREDIS_LIBRARY third_party/hiredis-install/lib NO_DEFAULT_PATH) -find_library(ALIVE_LIBRARY alive2 PATHS "${ALIVE_BUILDDIR}" NO_DEFAULT_PATH) -if (ALIVE_LIBRARY) - message(STATUS "Alive2: ${ALIVE_LIBRARY}") +find_library(ALIVE_IR ir PATHS "${ALIVE_BUILDDIR}" NO_DEFAULT_PATH) +if (ALIVE_IR) + message(STATUS "Alive2 IR") else() - message(SEND_ERROR "Alive2 not found") + message(SEND_ERROR "Alive2 libir.a not found") +endif() + +find_library(ALIVE_SMT smt PATHS "${ALIVE_BUILDDIR}" NO_DEFAULT_PATH) +if (ALIVE_SMT) + message(STATUS "Alive2 SMT") +else() + 
message(SEND_ERROR "Alive2 libsmt.a not found") +endif() + +find_library(ALIVE_TOOLS tools PATHS "${ALIVE_BUILDDIR}" NO_DEFAULT_PATH) +if (ALIVE_TOOLS) + message(STATUS "Alive2 TOOLS") +else() + message(SEND_ERROR "Alive2 libtools.a not found") +endif() + +find_library(ALIVE_UTIL util PATHS "${ALIVE_BUILDDIR}" NO_DEFAULT_PATH) +if (ALIVE_UTIL) + message(STATUS "Alive2 UTIL") +else() + message(SEND_ERROR "Alive2 libutil.a not found") endif() set(Z3 "${CMAKE_SOURCE_DIR}/third_party/z3-install/bin/z3") @@ -147,6 +168,8 @@ else() message(SEND_ERROR "Z3 shared lib not found") endif() +set(ALIVE_LIBRARY ${ALIVE_IR} ${ALIVE_SMT} ${ALIVE_TOOLS} ${ALIVE_UTIL} ${Z3_LIBRARY}) + set(SOUPER_CLANG_TOOL_FILES lib/ClangTool/Actions.cpp include/souper/ClangTool/Actions.h @@ -195,12 +218,23 @@ set(SOUPER_INFER_FILES include/souper/Infer/Interpreter.h lib/Infer/Preconditions.cpp include/souper/Infer/Preconditions.h + lib/Infer/SynthUtils.cpp + include/souper/Infer/SynthUtils.h ) add_library(souperInfer STATIC ${SOUPER_INFER_FILES} ) +set(SOUPER_GENERALIZE_FILES + lib/Generalize/Reducer.cpp + include/souper/Generalize/Reducer.h +) + +add_library(souperGeneralize STATIC + ${SOUPER_GENERALIZE_FILES} +) + set(SOUPER_INST_FILES lib/Inst/Inst.cpp include/souper/Inst/Inst.h @@ -239,6 +273,16 @@ add_library(souperTool STATIC ${SOUPER_TOOL_FILES} ) +set(SOUPER_CODEGEN_FILES + lib/Codegen/Codegen.cpp + lib/Codegen/MachineCost.cpp + include/souper/Codegen/Codegen.h +) + +add_library(souperCodegen STATIC + ${SOUPER_CODEGEN_FILES} +) + set(SOUPER_SOURCES ${SOUPER_EXTRACTOR_FILES} ${SOUPER_INST_FILES} @@ -246,12 +290,9 @@ set(SOUPER_SOURCES ${SOUPER_PARSER_FILES} ${SOUPER_SMTLIB2_FILES} ${SOUPER_TOOL_FILES} - ${SOUPER_INFER_FILES}) - -set(SOUPER_CODEGEN_FILES - lib/Codegen/Codegen.cpp - include/souper/Codegen/Codegen.h -) + ${SOUPER_INFER_FILES} + ${SOUPER_GENERALIZE_FILES} + ${SOUPER_CODEGEN_FILES}) add_library(souperPass SHARED ${KLEE_EXPR_FILES} @@ -265,10 +306,6 @@ 
add_library(souperPassProfileAll SHARED lib/Pass/Pass.cpp ) -add_library(souperCodegen SHARED - ${SOUPER_CODEGEN_FILES} -) - target_compile_definitions(souperPassProfileAll PRIVATE DYNAMIC_PROFILE_ALL=1) add_executable(clang-souper @@ -295,6 +332,14 @@ add_executable(souper-check tools/souper-check.cpp ) +add_executable(generalize + tools/generalize.cpp +) + +add_executable(matcher-gen + tools/matcher-gen.cpp +) + add_executable(souper-interpret tools/souper-interpret.cpp ) @@ -319,6 +364,10 @@ add_executable(parser_tests unittests/Parser/ParserTests.cpp ) +add_executable(codegen_tests + unittests/Codegen/CodegenTests.cpp +) + add_executable(interpreter_tests unittests/Interpreter/InterpreterInfra.cpp unittests/Interpreter/InterpreterTests.cpp) @@ -362,8 +411,8 @@ configure_file( ) foreach(target souper internal-solver-test lexer-test parser-test souper-check count-insts - souper2llvm souper-interpret - souperExtractor souperInfer souperInst souperKVStore souperParser + souper2llvm souper-interpret generalize matcher-gen + souperExtractor souperInfer souperGeneralize souperInst souperKVStore souperParser souperSMTLIB2 souperTool souperPass souperPassProfileAll kleeExpr souperCodegen) set_target_properties(${target} PROPERTIES COMPILE_FLAGS "${LLVM_CXXFLAGS}") @@ -373,7 +422,7 @@ foreach(target souperClangTool clang-souper) set_target_properties(${target} PROPERTIES COMPILE_FLAGS "${CLANG_CXXFLAGS} ${LLVM_CXXFLAGS}") target_include_directories(${target} PRIVATE "${LLVM_INCLUDEDIR}" ${CLANG_INCLUDEDIR}) endforeach() -foreach(target extractor_tests inst_tests parser_tests interpreter_tests bulk_tests) +foreach(target extractor_tests inst_tests parser_tests interpreter_tests bulk_tests codegen_tests) set_target_properties(${target} PROPERTIES COMPILE_FLAGS "${GTEST_CXXFLAGS} ${LLVM_CXXFLAGS}") target_include_directories(${target} PRIVATE "${LLVM_INCLUDEDIR}" "${GTEST_INCLUDEDIR}") endforeach() @@ -381,18 +430,20 @@ endforeach() # static target_link_libraries(kleeExpr 
${LLVM_LIBS} ${LLVM_LDFLAGS}) target_link_libraries(souperClangTool souperExtractor souperTool ${CLANG_LIBS} ${LLVM_LIBS} ${LLVM_LDFLAGS}) -target_link_libraries(souperExtractor souperParser souperKVStore souperInfer souperInst kleeExpr) +target_link_libraries(souperExtractor souperParser souperKVStore souperInfer souperInst kleeExpr souperCodegen) target_link_libraries(souperInfer souperExtractor ${LLVM_LIBS} ${LLVM_LDFLAGS} ${Z3_LIBRARY}) +target_link_libraries(souperGeneralize souperInfer ${LLVM_LIBS} ${LLVM_LDFLAGS} ${Z3_LIBRARY}) target_link_libraries(souperInst ${LLVM_LIBS} ${LLVM_LDFLAGS}) target_link_libraries(souperKVStore ${HIREDIS_LIBRARY} ${LLVM_LIBS} ${LLVM_LDFLAGS}) target_link_libraries(souperParser souperInst ${LLVM_LIBS} ${LLVM_LDFLAGS} ${ALIVE_LIBRARY}) target_link_libraries(souperSMTLIB2 ${LLVM_LIBS} ${LLVM_LDFLAGS}) target_link_libraries(souperTool souperExtractor souperSMTLIB2) +target_link_libraries(souperCodegen ${LLVM_LIBS} ${LLVM_LDFLAGS}) # dynamic target_link_libraries(souperCodegen ${PASS_LDFLAGS}) -target_link_libraries(souperPass souperCodegen ${PASS_LDFLAGS} ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY}) -target_link_libraries(souperPassProfileAll souperCodegen ${PASS_LDFLAGS} ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY}) +target_link_libraries(souperPass ${PASS_LDFLAGS} ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY}) +target_link_libraries(souperPassProfileAll ${PASS_LDFLAGS} ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY}) # executables target_link_libraries(souper souperExtractor souperKVStore souperParser souperSMTLIB2 souperTool kleeExpr ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY}) @@ -400,6 +451,8 @@ target_link_libraries(internal-solver-test souperSMTLIB2) target_link_libraries(lexer-test souperParser) target_link_libraries(parser-test souperParser) target_link_libraries(souper-check souperTool souperExtractor souperKVStore souperSMTLIB2 souperParser ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY}) 
+target_link_libraries(generalize souperTool souperExtractor souperKVStore souperSMTLIB2 souperParser souperInfer souperGeneralize ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY}) +target_link_libraries(matcher-gen souperTool souperExtractor souperKVStore souperSMTLIB2 souperParser ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY}) target_link_libraries(souper-interpret souperTool souperExtractor souperKVStore souperSMTLIB2 souperParser ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY}) target_link_libraries(clang-souper souperClangTool souperExtractor souperKVStore souperParser souperSMTLIB2 souperTool kleeExpr ${CLANG_LIBS} ${LLVM_LIBS} ${LLVM_LDFLAGS} ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY}) target_link_libraries(count-insts souperParser) @@ -407,6 +460,7 @@ target_link_libraries(souper2llvm souperParser souperCodegen) target_link_libraries(extractor_tests souperExtractor souperParser ${GTEST_LIBS} ${ALIVE_LIBRARY}) target_link_libraries(inst_tests souperInfer souperPass souperInst souperExtractor ${GTEST_LIBS} ${ALIVE_LIBRARY}) target_link_libraries(parser_tests souperParser ${GTEST_LIBS} ${ALIVE_LIBRARY}) +target_link_libraries(codegen_tests souperCodegen souperInst ${GTEST_LIBS} ${ALIVE_LIBRARY}) target_link_libraries(interpreter_tests souperInfer souperInst ${GTEST_LIBS} ${ALIVE_LIBRARY}) target_link_libraries(bulk_tests souperInfer souperInst ${GTEST_LIBS} ${ALIVE_LIBRARY} ${Z3_LIBRARY}) @@ -436,27 +490,12 @@ configure_file( add_custom_target(check COMMAND ${CMAKE_BINARY_DIR}/run_lit - DEPENDS extractor_tests inst_tests parser-test parser_tests profileRuntime souper souper-check souper-interpret souperPass souper2llvm souperPassProfileAll count-insts interpreter_tests bulk_tests + DEPENDS extractor_tests inst_tests parser-test parser_tests profileRuntime souper souper-check souper-interpret souperPass souper2llvm souperPassProfileAll count-insts interpreter_tests bulk_tests codegen_tests USES_TERMINAL) # we want assertions even in release mode! 
string(REPLACE "-DNDEBUG" "" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}") -find_program(GO_EXECUTABLE NAMES go DOC "go executable") -if(NOT GO_EXECUTABLE STREQUAL "GO_EXECUTABLE-NOTFOUND") - add_executable(souperweb-backend - tools/souperweb-backend.cpp - ) - set_target_properties(souperweb-backend - PROPERTIES COMPILE_FLAGS "${LLVM_CXXFLAGS}") - target_include_directories(souperweb-backend PRIVATE "${LLVM_INCLUDEDIR}") - - target_link_libraries(souperweb-backend souperTool souperExtractor souperPass souperKVStore souperSMTLIB2 souperParser souperInst ${HIREDIS_LIBRARY}) - - add_custom_target(souperweb ALL COMMAND ${GO_EXECUTABLE} build -o ${CMAKE_BINARY_DIR}/souperweb ${CMAKE_SOURCE_DIR}/tools/souperweb.go - COMMENT "Building souperweb") -endif() - add_library(profileRuntime STATIC runtime/souperPassProfile.c) @@ -472,8 +511,10 @@ configure_file(${CMAKE_SOURCE_DIR}/utils/py_souper2llvm.in ${CMAKE_BINARY_DIR}/p configure_file(${CMAKE_SOURCE_DIR}/include/souper/Tool/GetSolver.h.in ${CMAKE_BINARY_DIR}/include/souper/Tool/GetSolver.h @ONLY) if (BUILD_CLANG_TOOL) - configure_file(${CMAKE_SOURCE_DIR}/utils/sclang.in ${CMAKE_BINARY_DIR}/sclang @ONLY) - configure_file(${CMAKE_SOURCE_DIR}/utils/sclang.in ${CMAKE_BINARY_DIR}/sclang++ @ONLY) +configure_file(${CMAKE_SOURCE_DIR}/utils/sclang.in ${CMAKE_BINARY_DIR}/sclang @ONLY) +configure_file(${CMAKE_SOURCE_DIR}/utils/sclang.in ${CMAKE_BINARY_DIR}/sclang++ @ONLY) +configure_file(${CMAKE_SOURCE_DIR}/utils/mclang.in ${CMAKE_BINARY_DIR}/mclang @ONLY) +configure_file(${CMAKE_SOURCE_DIR}/utils/mclang.in ${CMAKE_BINARY_DIR}/mclang++ @ONLY) endif() add_subdirectory(docs) diff --git a/Dockerfile b/Dockerfile index b1c911694..362194091 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,34 +1,24 @@ from ubuntu:20.04 run set -x; \ - echo 'debconf debconf/frontend select Noninteractive' | debconf-set-selections \ - && apt-get update -y -qq \ - && apt-get dist-upgrade -y -qq \ - && apt-get autoremove -y -qq \ - && apt-get remove -y 
-qq clang llvm llvm-runtime \ - && apt-get install libgmp10 \ - && echo 'ca-certificates valgrind libc6-dev libgmp-dev cmake ninja-build make autoconf automake libtool golang-go python subversion re2c git clang' > /usr/src/build-deps \ - && apt-get install -y $(cat /usr/src/build-deps) --no-install-recommends \ - && git clone https://github.com/antirez/redis /usr/src/redis - -run export CC=clang CXX=clang++ \ - && cd /usr/src/redis \ - && git checkout 5.0.3 \ - && make -j10 \ - && make install - -run export GOPATH=/usr/src/go \ - && go get github.com/gomodule/redigo/redis + echo 'debconf debconf/frontend select Noninteractive' | debconf-set-selections \ + && apt-get update -y -qq \ + && apt-get dist-upgrade -y -qq \ + && apt-get autoremove -y -qq \ + && apt-get remove -y -qq clang llvm llvm-runtime \ + && apt-get install libgmp10 \ + && echo 'ca-certificates valgrind libc6-dev libgmp-dev cmake ninja-build make autoconf automake libtool python python3 subversion re2c git clang libstdc++-10-dev redis' > /usr/src/build-deps \ + && apt-get install -y $(cat /usr/src/build-deps) --no-install-recommends add build_deps.sh /usr/src/souper/build_deps.sh add clone_and_test.sh /usr/src/souper/clone_and_test.sh run export CC=clang CXX=clang++ \ - && cd /usr/src/souper \ -# && ./build_deps.sh Debug \ -# && rm -r third_party/llvm-Debug-build \ - && ./build_deps.sh Release \ - && rm -r third_party/llvm-Release-build + && cd /usr/src/souper \ +# && ./build_deps.sh Debug \ +# && rm -r third_party/llvm-Debug-build \ + && ./build_deps.sh Release \ + && rm -r third_party/llvm-Release-build add CMakeLists.txt /usr/src/souper/CMakeLists.txt @@ -41,19 +31,16 @@ add tools /usr/src/souper/tools add utils /usr/src/souper/utils add unittests /usr/src/souper/unittests -run export GOPATH=/usr/src/go \ - && export LD_LIBRARY_PATH=/usr/src/souper/third_party/z3-install/lib:$LD_LIBRARY_PATH \ - && mkdir -p /usr/src/souper-build \ - && cd /usr/src/souper-build \ - && 
CC=/usr/src/souper/third_party/llvm-Release-install/bin/clang CXX=/usr/src/souper/third_party/llvm-Release-install/bin/clang++ cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DTEST_SYNTHESIS=ON ../souper \ - && ninja souperweb souperweb-backend \ - && ninja check \ - && cp souperweb souperweb-backend /usr/local/bin \ - && cd .. \ - && rm -rf /usr/src/souper-build \ - && strip /usr/local/bin/* \ - && groupadd -r souper \ - && useradd -m -r -g souper souper \ - && mkdir /data \ - && chown souper:souper /data \ - && rm -rf /usr/local/include /usr/local/lib/*.a /usr/local/lib/*.la +run export LD_LIBRARY_PATH=/usr/src/souper/third_party/z3-install/lib:$LD_LIBRARY_PATH \ + && mkdir -p /usr/src/souper-build \ + && cd /usr/src/souper-build \ + && CC=/usr/src/souper/third_party/llvm-Release-install/bin/clang CXX=/usr/src/souper/third_party/llvm-Release-install/bin/clang++ cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DTEST_SYNTHESIS=ON ../souper \ + && ninja \ + && ninja check \ + && cd .. \ + && rm -rf /usr/src/souper-build \ + && groupadd -r souper \ + && useradd -m -r -g souper souper \ + && mkdir /data \ + && chown souper:souper /data \ + && rm -rf /usr/local/include /usr/local/lib/*.a /usr/local/lib/*.la diff --git a/README.md b/README.md index ef7743c95..1c22f6328 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ Souper is a superoptimizer for LLVM IR. It uses an SMT solver to help identify missing peephole optimizations in LLVM's midend optimizers. +The architecture and concepts of Souper are described in [Souper: A synthesizing superoptimizer](https://arxiv.org/pdf/1711.04422.pdf). + # Requirements Souper should work on any reasonably modern Linux or OS X machine. @@ -11,11 +13,6 @@ http://llvm.org/docs/GettingStarted.html#getting-a-modern-host-c-toolchain You will also need CMake to build Souper and its dependencies. -If you have Go installed, you will also need the Redigo Redis client: -``` -$ go get github.com/gomodule/redigo/redis -``` - # Building Souper 1. 
Download and build dependencies: diff --git a/build_deps.sh b/build_deps.sh index efec652ec..4b4ff9ba3 100755 --- a/build_deps.sh +++ b/build_deps.sh @@ -21,17 +21,18 @@ fi ncpus=$(command nproc 2>/dev/null || command sysctl -n hw.ncpu 2>/dev/null || echo 8) -# hiredis version 0.14.0 -hiredis_commit=685030652cd98c5414ce554ff5b356dfe8437870 +# hiredis latest as of May 7 2021 +hiredis_commit=667dbf536524ba3f28c1d964793db1055c5a64f2 llvm_repo=https://github.com/regehr/llvm-project.git # llvm_commit specifies the git branch or hash to checkout to -llvm_commit=disable-peepholes-v04 +llvm_commit=disable-peepholes-llvm12-v02 klee_repo=https://github.com/rsas/klee klee_branch=pure-bv-qf-llvm-7.0 -alive_commit=v1 +alive_commit=v2 alive_repo=https://github.com/manasij7479/alive2.git z3_repo=https://github.com/Z3Prover/z3.git -z3_commit=z3-4.8.9 +# latest as of May 25 2021 +z3_commit=322531e95cb7da59b4596000ffbc92d792433f17 llvm_build_type=Release if [ -n "$1" ] ; then @@ -75,7 +76,7 @@ mkdir -p $llvm_srcdir mkdir -p $llvm_builddir -cmake_flags="-DCMAKE_INSTALL_PREFIX=$llvm_installdir -DLLVM_ENABLE_ASSERTIONS=ON -DLLVM_FORCE_ENABLE_STATS=ON -DCMAKE_BUILD_TYPE=$llvm_build_type -DLLVM_ENABLE_Z3_SOLVER=OFF -DLLVM_ENABLE_PROJECTS=\'llvm;clang;compiler-rt\'" +cmake_flags="-DCMAKE_INSTALL_PREFIX=$llvm_installdir -DLLVM_ENABLE_ASSERTIONS=ON -DLLVM_FORCE_ENABLE_STATS=ON -DCMAKE_BUILD_TYPE=$llvm_build_type -DLLVM_ENABLE_Z3_SOLVER=OFF -DLLVM_ENABLE_PROJECTS=\'llvm;clang;openmp;compiler-rt\'" if [ -n "`which ninja`" ] ; then (cd $llvm_builddir && cmake ${llvm_srcdir}/llvm -G Ninja $cmake_flags -DCMAKE_CXX_FLAGS="-DDISABLE_WRONG_OPTIMIZATIONS_DEFAULT_VALUE=true -DDISABLE_PEEPHOLES_DEFAULT_VALUE=false" "$@") @@ -114,5 +115,5 @@ mkdir -p $hiredis_installdir/include/hiredis mkdir -p $hiredis_installdir/lib (cd $hiredis_srcdir && git checkout $hiredis_commit && make libhiredis.a && - cp -r hiredis.h async.h read.h sds.h adapters ${hiredis_installdir}/include/hiredis && + cp -r alloc.h 
hiredis.h async.h read.h sds.h adapters ${hiredis_installdir}/include/hiredis && cp libhiredis.a ${hiredis_installdir}/lib) diff --git a/build_docker.sh b/build_docker.sh index f0012d7c4..02da70cb6 100755 --- a/build_docker.sh +++ b/build_docker.sh @@ -2,7 +2,7 @@ # docker system prune -a -tar cz Dockerfile build_deps.sh clone_and_test.sh CMakeLists.txt docs include lib patches runtime scripts test tools utils unittests | docker build -t souperweb - +tar cz Dockerfile build_deps.sh clone_and_test.sh CMakeLists.txt docs include lib runtime test tools utils unittests | docker build -t souperweb - container=$(/usr/bin/docker run -d souperweb true) docker export $container | docker import - souperweb_squashed docker build -t souperweb_final - < Dockerfile.metadata diff --git a/clone_and_test.sh b/clone_and_test.sh index a2fb96f47..d8bcbaae8 100755 --- a/clone_and_test.sh +++ b/clone_and_test.sh @@ -14,8 +14,8 @@ echo "TRAVIS_EVENT_TYPE set to push"; fi if [ -z ${TRAVIS_BRANCH} ]; then -TRAVIS_BRANCH="master"; -echo "TRAVIS_BRANCH set to master"; +TRAVIS_BRANCH="main"; +echo "TRAVIS_BRANCH set to main"; fi # check if this is a pull request or a push @@ -32,7 +32,6 @@ else ln -s /usr/src/souper/third_party; fi -export GOPATH=/usr/src/go Z3=/usr/bin/z3 SRCDIR="$PWD" diff --git a/docs/Doxyfile.in b/docs/Doxyfile.in index 7ad3b1d6c..7ba7f5095 100644 --- a/docs/Doxyfile.in +++ b/docs/Doxyfile.in @@ -1119,13 +1119,6 @@ VERBATIM_HEADERS = YES ALPHABETICAL_INDEX = YES -# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in -# which the alphabetical index list will be split. -# Minimum value: 1, maximum value: 20, default value: 5. -# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. - -COLS_IN_ALPHA_INDEX = 5 - # In case all classes in a project start with a common prefix, all classes will # be put under the same header in the alphabetical index. 
The IGNORE_PREFIX tag # can be used to specify a prefix (or a list of prefixes) that should be ignored diff --git a/include/souper/Codegen/Codegen.h b/include/souper/Codegen/Codegen.h index ed8a88d18..b6fe6534b 100644 --- a/include/souper/Codegen/Codegen.h +++ b/include/souper/Codegen/Codegen.h @@ -16,13 +16,17 @@ #define SOUPER_CODEGEN_CODEGEN_H #include "souper/Inst/Inst.h" +#include "souper/Parser/Parser.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" +#include "llvm/IR/Verifier.h" #include +#include "llvm/Support/MemoryBuffer.h" + namespace souper { class Codegen { @@ -47,6 +51,19 @@ class Codegen { llvm::Value *getValue(Inst *I); }; +// Returns false if there are no errors and true if an error is found. +// NOTE(review): this declaration takes no output stream (no OS parameter), +// so where any error message is reported is not visible here -- confirm. +bool genModule(InstContext &IC, Inst *I, llvm::Module &Module); + +struct BackendCost { + std::vector C; +}; + +void getBackendCost(InstContext &IC, Inst *I, BackendCost &BC); + +bool compareCosts(const BackendCost &C1, const BackendCost &C2); + } // namespace souper #endif // SOUPER_CODEGEN_CODEGEN_H diff --git a/include/souper/Extractor/Solver.h b/include/souper/Extractor/Solver.h index a9648cfe1..8299b0a40 100644 --- a/include/souper/Extractor/Solver.h +++ b/include/souper/Extractor/Solver.h @@ -44,6 +44,14 @@ class Solver { InstMapping Mapping, bool &IsValid, std::vector> *Model) = 0; + virtual std::error_code + isSatisfiable(llvm::StringRef Query, bool &Result, + unsigned NumModels, + std::vector *Models, + unsigned Timeout = 0) = 0; + + virtual SMTLIBSolver *getSMTLIBSolver() = 0; + virtual std::string getName() = 0; virtual @@ -90,8 +98,9 @@ class Solver { virtual std::error_code abstractPrecondition(const BlockPCs &BPCs, const std::vector &PCs, - InstMapping &Mapping, InstContext &IC, - bool &FoundWeakest) = 0; + InstMapping &Mapping, InstContext &IC, 
bool &FoundWeakest, + std::vector> &KBResults, + std::vector> &CRResults) = 0; }; std::unique_ptr createBaseSolver( diff --git a/include/souper/Generalize/Reducer.h b/include/souper/Generalize/Reducer.h new file mode 100644 index 000000000..701151765 --- /dev/null +++ b/include/souper/Generalize/Reducer.h @@ -0,0 +1,69 @@ +#ifndef SOUPER_GENERALIZE_REDUCER_H +#define SOUPER_GENERALIZE_REDUCER_H + +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/GraphWriter.h" +#include "llvm/Support/KnownBits.h" +#include "souper/Parser/Parser.h" +#include "souper/Infer/EnumerativeSynthesis.h" +#include "souper/Infer/ConstantSynthesis.h" +#include "souper/Infer/Pruning.h" +#include "souper/Infer/SynthUtils.h" + +namespace souper { + +class Reducer { +public: + Reducer(InstContext &IC_, Solver *S_) : IC(IC_), S(S_), varnum(0), numSolverCalls(0) {} + + ParsedReplacement ReduceGreedy(ParsedReplacement Input); + + ParsedReplacement ReducePairsGreedy(ParsedReplacement Input); + + ParsedReplacement ReduceTriplesGreedy(ParsedReplacement Input); + + // Eventually replace the functions in Preconditions{.h/.cpp} with this. + // Does not produce exhaustive result. TODO Have an option to wrap in a cegis loop. 
+ bool inferKBPrecondition(ParsedReplacement &Input, std::vector Targets); + + ParsedReplacement ReduceGreedyKBIFY(ParsedReplacement Input); + + ParsedReplacement ReduceRedundantPhis(ParsedReplacement Input); + + // Assumes Input is valid + ParsedReplacement ReducePCs(ParsedReplacement Input); + + ParsedReplacement WeakenKB(ParsedReplacement Input); + + ParsedReplacement WeakenCR(ParsedReplacement Input); + + ParsedReplacement WeakenDB(ParsedReplacement Input); + + ParsedReplacement WeakenOther(ParsedReplacement Input); + + ParsedReplacement ReducePCsToDF(ParsedReplacement Input); + + ParsedReplacement ReducePoison(ParsedReplacement Input); + + bool VerifyInput(ParsedReplacement &Input); + + bool safeToRemove(Inst *I, ParsedReplacement &Input); + + Inst *Eliminate(ParsedReplacement &Input, Inst *I); + + void ReduceRec(ParsedReplacement Input_, std::vector &Results); + void Stats() { + llvm::outs() << "Solver Calls: " << numSolverCalls << "\n"; + } +private: + InstContext &IC; + Solver *S; + int varnum; + int numSolverCalls; + std::unordered_set DNR; +}; + +} + +#endif + diff --git a/include/souper/Infer/AliveDriver.h b/include/souper/Infer/AliveDriver.h index a2fc40d4f..18d4c9b74 100644 --- a/include/souper/Infer/AliveDriver.h +++ b/include/souper/Infer/AliveDriver.h @@ -30,7 +30,7 @@ class AliveDriver { typedef std::unordered_map Cache; public: AliveDriver(Inst *LHS_, Inst *PreCondition_, InstContext &IC_, - std::vector ExtraInputs = {}); + const std::vector &ExtraInputs = {}, bool WidthIndep = false); std::map synthesizeConstants(souper::Inst *RHS); std::map synthesizeConstantsWithCegis(souper::Inst *RHS, InstContext &IC); @@ -41,6 +41,12 @@ class AliveDriver { delete(p.second); } } + + std::vector> getValidTypings() { + return ValidTypings; + } + + bool WidthIndependentMode; // This probably doesn't need to be public private: Inst *LHS, *PreCondition; @@ -51,7 +57,6 @@ class AliveDriver { std::unordered_map TypeCache; - IR::Type &getType(int n); IR::Type 
&getOverflowType(int n); @@ -65,6 +70,8 @@ class AliveDriver { std::unordered_map NamesCache; bool IsLHS; + std::vector> ValidTypings; + InstContext &IC; smt::smt_initializer smt_init; }; diff --git a/include/souper/Infer/EnumerativeSynthesis.h b/include/souper/Infer/EnumerativeSynthesis.h index c15b15d14..84f8ad659 100644 --- a/include/souper/Infer/EnumerativeSynthesis.h +++ b/include/souper/Infer/EnumerativeSynthesis.h @@ -30,6 +30,8 @@ namespace souper { class EnumerativeSynthesis { public: + EnumerativeSynthesis(); + // Synthesize an instruction from the specification in LHS std::error_code synthesize(SMTLIBSolver *SMTSolver, const BlockPCs &BPCs, @@ -38,6 +40,10 @@ class EnumerativeSynthesis { bool CheckAllGuesses, InstContext &IC, unsigned Timeout); + std::vector + generateExprs(InstContext &IC, size_t CountLimit, + std::vector Vars, size_t Width); + }; } diff --git a/include/souper/Infer/Interpreter.h b/include/souper/Infer/Interpreter.h index 43ad8e544..d07353b38 100644 --- a/include/souper/Infer/Interpreter.h +++ b/include/souper/Infer/Interpreter.h @@ -115,15 +115,18 @@ EvalValue evaluateAShr(llvm::APInt A, llvm::APInt B); EvalValue evaluateSingleInst(Inst *I, std::vector &Args); public: - ConcreteInterpreter() {} - ConcreteInterpreter(ValueCache &Input) : Cache(Input) {} - ConcreteInterpreter(Inst *I, ValueCache &Input) : Cache(Input) { + ConcreteInterpreter() : Cache() {} + ConcreteInterpreter(ValueCache Input) : Cache(Input) {} + ConcreteInterpreter(Inst *I, ValueCache Input) : Cache(Input) { CacheWritable = true; evaluateInst(I); CacheWritable = false; } void setEvalPhiFirstBranch() {EvalPhiFirstBranch = true;}; EvalValue evaluateInst(Inst *Root); + + void printCache(llvm::raw_ostream &Out); + }; } diff --git a/include/souper/Infer/Preconditions.h b/include/souper/Infer/Preconditions.h index 7aabc735c..aecfff26e 100644 --- a/include/souper/Infer/Preconditions.h +++ b/include/souper/Infer/Preconditions.h @@ -3,13 +3,23 @@ #include "souper/Inst/Inst.h" 
#include "llvm/Support/KnownBits.h" +#include "llvm/IR/ConstantRange.h" extern unsigned DebugLevel; namespace souper { class SMTLIBSolver; class Solver; + +std::pair>, +std::vector>> +inferAbstractPreconditions(SynthesisContext &SC, Inst *RHS, + Solver *S, bool &FoundWeakest); + std::vector> inferAbstractKBPreconditions(SynthesisContext &SC, Inst *RHS, - SMTLIBSolver *SMTSolver, Solver *S, bool &FoundWeakest); + Solver *S, bool &FoundWeakest); +std::vector> + inferAbstractCRPreconditions(SynthesisContext &SC, Inst *RHS, + Solver *S, bool &FoundWeakest); } #endif // SOUPER_PRECONDITIONS_H diff --git a/include/souper/Infer/SynthUtils.h b/include/souper/Infer/SynthUtils.h new file mode 100644 index 000000000..b31c321a3 --- /dev/null +++ b/include/souper/Infer/SynthUtils.h @@ -0,0 +1,148 @@ +#ifndef SOUPER_SYNTH_UTILS_H +#define SOUPER_SYNTH_UTILS_H + +#include "souper/Inst/Inst.h" +#include "souper/Infer/EnumerativeSynthesis.h" +#include "souper/Infer/ConstantSynthesis.h" +#include "souper/Parser/Parser.h" +#include "souper/Infer/Pruning.h" + +namespace souper { + +// TODO: Lazy construction instead of eager. 
+// eg: Instead of Builder(I, IC).Add(1)() +// we could do Builder(I).Add(1)(IC) +class Builder { +public: + Builder(Inst *I_, InstContext &IC_) : I(I_), IC(IC_) {} + Builder(InstContext &IC_, Inst *I_) : I(I_), IC(IC_) {} + Builder(InstContext &IC_, llvm::APInt Value) : IC(IC_) { + I = IC.getConst(Value); + } + Builder(Inst *I_, InstContext &IC_, uint64_t Value) : IC(IC_) { + I = IC.getConst(llvm::APInt(I_->Width, Value)); + } + + Inst *operator()() { + assert(I); + return I; + } + +#define BINOP(K) \ + template Builder K(T t) { \ + auto L = I; auto R = i(t, *this); \ + return Builder(IC.getInst(Inst::K, L->Width, {L, R}), IC); \ + } + + BINOP(Add) BINOP(Sub) BINOP(Mul) + BINOP(And) BINOP(Xor) BINOP(Or) + BINOP(Shl) BINOP(LShr) BINOP(UDiv) + BINOP(SDiv) +#undef BINOP + + template Builder Ugt(T t) { \ + auto L = I; auto R = i(t, *this); \ + return Builder(IC.getInst(Inst::Ult, 1, {R, L}), IC); \ + } + +#define BINOPW(K) \ + template Builder K(T t) { \ + auto L = I; auto R = i(t, *this); \ + return Builder(IC.getInst(Inst::K, 1, {L, R}), IC); \ + } + BINOPW(Slt) BINOPW(Ult) BINOPW(Sle) BINOPW(Ule) + BINOPW(Eq) BINOPW(Ne) +#undef BINOPW + +#define UNOP(K) \ + Builder K() { \ + auto L = I; \ + return Builder(IC.getInst(Inst::K, L->Width, {L}), IC); \ + } + UNOP(LogB) UNOP(BitReverse) UNOP(BSwap) UNOP(Cttz) UNOP(Ctlz) + UNOP(BitWidth) UNOP(CtPop) +#undef UNOP + + Builder Flip() { + auto L = I; + auto AllOnes = IC.getConst(llvm::APInt::getAllOnesValue(L->Width)); + return Builder(IC.getInst(Inst::Xor, L->Width, {L, AllOnes}), IC); + } + Builder Negate() { + auto L = I; + auto Zero = IC.getConst(llvm::APInt(L->Width, 0)); + return Builder(IC.getInst(Inst::Sub, L->Width, {Zero, L}), IC); + } + +#define UNOPW(K) \ + Builder K(size_t W) { \ + auto L = I; \ + return Builder(IC.getInst(Inst::K, W, {L}), IC); \ + } + UNOPW(ZExt) UNOPW(SExt) UNOPW(Trunc) +#undef UNOPW + +private: + Inst *I = nullptr; + InstContext &IC; + + Inst *i(Builder A, Inst *I) { + assert(A.I); + return 
A.I; + } + + template + Inst *i(N Number, Builder B) { + return B.IC.getConst(llvm::APInt(B.I->Width, Number, false)); + } + + template<> + Inst *i(Inst *I, Builder B) { + assert(I); + return I; + } + + template<> + Inst *i(Builder A, Builder B) { + assert(A.I); + return A.I; + } + + template<> + Inst *i(std::string Number, Builder B) { + return B.IC.getConst(llvm::APInt(B.I->Width, Number, 10)); + } + + template<> + Inst *i(llvm::APInt Number, Builder B) { + return B.IC.getConst(Number); + } +}; + +Inst *Replace(Inst *R, InstContext &IC, std::map &M); +ParsedReplacement Replace(ParsedReplacement I, InstContext &IC, + std::map &M); + +Inst *Clone(Inst *R, InstContext &IC); + +InstMapping Clone(InstMapping In, InstContext &IC); + +ParsedReplacement Clone(ParsedReplacement In, InstContext &IC); + +// Also Synthesizes given constants +// Returns clone if verified, nullptrs if not +std::optional Verify(ParsedReplacement Input, InstContext &IC, Solver *S); +// bool IsValid(ParsedReplacement Input, InstContext &IC, Solver *S); + +std::map findOneConstSet(ParsedReplacement Input, const std::set &SymCS, InstContext &IC, Solver *S); + +std::vector> findValidConsts(ParsedReplacement Input, const std::set &Insts, InstContext &IC, Solver *S, size_t MaxCount); + +ValueCache GetCEX(const ParsedReplacement &Input, InstContext &IC, Solver *S); + +std::vector GetMultipleCEX(ParsedReplacement Input, InstContext &IC, Solver *S, size_t MaxCount); + +int profit(const ParsedReplacement &P); + +} +#endif diff --git a/include/souper/Inst/Inst.h b/include/souper/Inst/Inst.h index b854c570d..97326eb15 100644 --- a/include/souper/Inst/Inst.h +++ b/include/souper/Inst/Inst.h @@ -100,6 +100,8 @@ struct Inst : llvm::FoldingSetNode { BSwap, Cttz, Ctlz, + LogB, + BitWidth, BitReverse, FShl, FShr, @@ -125,6 +127,11 @@ struct Inst : llvm::FoldingSetNode { ReservedConst, ReservedInst, + KnownOnesP, + KnownZerosP, + RangeP, + DemandedMask, + None, } Kind; @@ -269,6 +276,7 @@ class InstContext { 
llvm::APInt DemandedBits, bool Available); std::vector getVariables() const; + std::vector getVariablesFor(Inst *Root) const; }; struct SynthesisContext { @@ -282,10 +290,11 @@ struct SynthesisContext { unsigned Timeout; }; -int cost(Inst *I, bool IgnoreDepsWithExternalUses = false); +int cost(Inst *I, bool IgnoreDepsWithExternalUses = false, std::set Ignore = {}); +int backendCost(Inst *I, bool IgnoreDepsWithExternalUses = false); int countHelper(Inst *I, std::set &Visited); int instCount(Inst *I); -int benefit(Inst *LHS, Inst *RHS); +int benefit(Inst *LHS, Inst *RHS, bool IgnoreDepsWithExternalUses = true); void PrintReplacement(llvm::raw_ostream &Out, const BlockPCs &BPCs, const std::vector &PCs, InstMapping Mapping, @@ -315,7 +324,7 @@ Inst *getInstCopy(Inst *I, InstContext &IC, std::map &InstCache, std::map &BlockCache, std::map *ConstMap, - bool CloneVars); + bool CloneVars, bool CloneBlocks = true); Inst *instJoin(Inst *I, Inst *Reserved, Inst *NewInst, std::map &InstCache, InstContext &IC); diff --git a/include/souper/Tool/CandidateMapUtils.h b/include/souper/Tool/CandidateMapUtils.h index 1699d5942..5c298a616 100644 --- a/include/souper/Tool/CandidateMapUtils.h +++ b/include/souper/Tool/CandidateMapUtils.h @@ -35,6 +35,8 @@ typedef std::vector CandidateMap; void AddToCandidateMap(CandidateMap &M, const CandidateReplacement &CR); +void HarvestAndPrintOpts(InstContext &IC, ExprBuilderContext &EBC, llvm::Module *M, Solver *S); + void AddModuleToCandidateMap(InstContext &IC, ExprBuilderContext &EBC, CandidateMap &CandMap, llvm::Module *M); diff --git a/include/souper/Tool/GetSolver.h.in b/include/souper/Tool/GetSolver.h.in index 8d98ae492..a2f7b9d3e 100644 --- a/include/souper/Tool/GetSolver.h.in +++ b/include/souper/Tool/GetSolver.h.in @@ -62,7 +62,8 @@ static std::unique_ptr GetUnderlyingSolver() { static std::unique_ptr GetSolver(KVStore *&KV) { std::unique_ptr US = GetUnderlyingSolver(); - if (!US) return NULL; + if (!US) + return NULL; std::unique_ptr S = 
createBaseSolver (std::move(US), SolverTimeout); if (ExternalCache) { KV = new KVStore; diff --git a/lib/Codegen/Codegen.cpp b/lib/Codegen/Codegen.cpp index 75fa3e52f..de0d5622a 100644 --- a/lib/Codegen/Codegen.cpp +++ b/lib/Codegen/Codegen.cpp @@ -321,4 +321,62 @@ llvm::Value *Codegen::getValue(Inst *I) { Inst::getKindName(I->K) + " in Codegen::getValue()"); } +static std::vector +GetInputArgumentTypes(const InstContext &IC, llvm::LLVMContext &Context, Inst *Root) { + const std::vector AllVariables = IC.getVariablesFor(Root); + + std::vector ArgTypes; + ArgTypes.reserve(AllVariables.size()); + for (const Inst *const Var : AllVariables) { + llvm::errs() << "arg with width " << Var->Width << " and number " << Var->Number << "\n"; + ArgTypes.emplace_back(Type::getIntNTy(Context, Var->Width)); + } + + return ArgTypes; +} + +static std::map GetArgsMapping(const InstContext &IC, + Function *F, Inst *Root) { + std::map Args; + + const std::vector AllVariables = IC.getVariablesFor(Root); + for (auto zz : llvm::zip(AllVariables, F->args())) + Args[std::get<0>(zz)] = &(std::get<1>(zz)); + + return Args; +}; + +/// If there are no errors, the function returns false. If an error is found, +/// a message describing the error is written to OS (if non-null) and true is +/// returned. 
+bool genModule(InstContext &IC, souper::Inst *I, llvm::Module &Module) { + llvm::LLVMContext &Context = Module.getContext(); + const std::vector ArgTypes = GetInputArgumentTypes(IC, Context, I); + const auto FT = llvm::FunctionType::get( + /*Result=*/Codegen::GetInstReturnType(Context, I), + /*Params=*/ArgTypes, /*isVarArg=*/false); + + Function *F = Function::Create(FT, Function::ExternalLinkage, "fun", &Module); + + const std::map Args = GetArgsMapping(IC, F, I); + + BasicBlock *BB = BasicBlock::Create(Context, "entry", F); + + llvm::IRBuilder<> Builder(Context); + Builder.SetInsertPoint(BB); + + Value *RetVal = Codegen(Context, &Module, Builder, /*DT*/ nullptr, + /*ReplacedInst*/ nullptr, Args) + .getValue(I); + + Builder.CreateRet(RetVal); + + // Validate the generated code, checking for consistency. + if (verifyFunction(*F, &llvm::errs())) + return true; + if (verifyModule(Module, &llvm::errs())) + return true; + return false; +} + } // namespace souper diff --git a/lib/Codegen/MachineCost.cpp b/lib/Codegen/MachineCost.cpp new file mode 100644 index 000000000..b394e988f --- /dev/null +++ b/lib/Codegen/MachineCost.cpp @@ -0,0 +1,173 @@ +// Copyright 2014 The Souper Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "souper/Codegen/Codegen.h" +#include "souper/Inst/Inst.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Value.h" +#include "llvm/Object/ObjectFile.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/SmallVectorMemoryBuffer.h" +#include "llvm/Support/TargetRegistry.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetOptions.h" +#include + +#define DEBUG_TYPE "souper" + +using namespace llvm; + +namespace souper { + +void optimizeModule(llvm::Module &M) { + llvm::LoopAnalysisManager LAM; + llvm::FunctionAnalysisManager FAM; + llvm::CGSCCAnalysisManager CGAM; + llvm::ModuleAnalysisManager MAM; + + llvm::PassBuilder PB; + PB.registerModuleAnalyses(MAM); + PB.registerCGSCCAnalyses(CGAM); + PB.registerFunctionAnalyses(FAM); + PB.registerLoopAnalyses(LAM); + PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); + + llvm::FunctionPassManager FPM = + PB.buildFunctionSimplificationPipeline(llvm::PassBuilder::OptimizationLevel::O2, + ThinOrFullLTOPhase::None); + llvm::ModulePassManager MPM; + MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM))); + MPM.run(M, MAM); +} + +long getCodeSize(Module &M, TargetMachine *TM) { + M.setDataLayout(TM->createDataLayout()); + SmallVector DotO; + raw_svector_ostream dest(DotO); + + legacy::PassManager pass; + if (TM->addPassesToEmitFile(pass, dest, nullptr, CGFT_ObjectFile)) { + errs() << "Target machine can't emit a file of this type"; + report_fatal_error("oops"); + } + pass.run(M); + + SmallVectorMemoryBuffer Buf(std::move(DotO)); + auto ObjOrErr = object::ObjectFile::createObjectFile(Buf); + if (!ObjOrErr) + report_fatal_error("createObjectFile() failed"); + object::ObjectFile *OF = ObjOrErr.get().get(); + auto SecList = OF->sections(); + long 
Size = 0; + for (auto &S : SecList) { + if (S.isText()) + Size += S.getSize(); + } + if (Size > 0) + return Size; + else + report_fatal_error("no text segment found"); +} + +struct TargetInfo { + std::string Trip, CPU; +}; + +std::vector Targets { + { "x86_64", "skylake" }, + { "aarch64", "apple-a12" }, +}; + +bool Init = false; + +void getBackendCost(InstContext &IC, souper::Inst *I, BackendCost &BC) { + // TODO is this better than just forcing all clients of this code to + // do the init themselves? + if (!Init) { + InitializeAllTargetInfos(); + InitializeAllTargets(); + InitializeAllTargetMCs(); + InitializeAllAsmParsers(); + InitializeAllAsmPrinters(); + Init = true; + } + + llvm::LLVMContext C; + llvm::Module M("", C); + if (genModule(IC, I, M)) + llvm::report_fatal_error("codegen error in getBackendCost()"); + + optimizeModule(M); + + llvm::errs() << M; + + BackendCost Cost; + for (auto &T : Targets) { + std::string Error; + auto Target = TargetRegistry::lookupTarget(T.Trip, Error); + if (!Target) { + errs() << Error; + report_fatal_error("can't lookup target"); + } + + auto Features = ""; + TargetOptions Opt; + auto RM = Optional(); + auto TM = Target->createTargetMachine(T.Trip, T.CPU, Features, Opt, RM); + + Cost.C.push_back(getCodeSize(M, TM)); + } + + llvm::errs() << "cost vector: "; + for (auto I : Cost.C) { + llvm::errs() << I << " "; + } + llvm::errs() << "\n"; +} + +int threeWayCompare(int A, int B) { + if (A < B) + return -1; + if (A > B) + return 1; + return 0; +} + +// "The value returned indicates whether the element passed as first +// argument is considered to go before the second" +bool compareCosts(const BackendCost &C1, const BackendCost &C2) { + assert(C1.C.size() == C2.C.size()); + + int Count = 0; + for (int i = 0; i < C1.C.size(); ++i) + Count += threeWayCompare(C1.C[i], C2.C[i]); + if (Count < 0) + return true; + if (Count > 0) + return false; + + // break ties using souper cost? + // break final ties how? 
we want a canonical winner for all cases + // FIXME -- not finished + return false; +} + +} // namespace souper diff --git a/lib/Extractor/Candidates.cpp b/lib/Extractor/Candidates.cpp index 4c98f4de5..867079a64 100644 --- a/lib/Extractor/Candidates.cpp +++ b/lib/Extractor/Candidates.cpp @@ -76,6 +76,10 @@ static llvm::cl::opt PrintSignBitsAtReturn( "print-sign-bits-at-return", llvm::cl::desc("Print sign bits dfa in each value returned from a function (default=false)"), llvm::cl::init(false)); +static llvm::cl::opt NoExternalUses( + "no-external-uses", + llvm::cl::desc("Do not mark external uses. (default=false)"), + llvm::cl::init(false)); static llvm::cl::opt PrintRangeAtReturn( "print-range-at-return", llvm::cl::desc("Print range inforation in each value returned from a function (default=false)"), @@ -294,6 +298,9 @@ Inst *ExprBuilder::buildGEP(Inst *Ptr, gep_type_iterator begin, #endif void ExprBuilder::markExternalUses (Inst *I) { + if (NoExternalUses) { + return; + } std::map UsesCount; std::unordered_set Visited; std::vector Stack; @@ -1022,7 +1029,7 @@ void ExtractExprCandidates(Function &F, const LoopInfo *LI, DemandedBits *DB, In->HarvestFrom = nullptr; EB.markExternalUses(In); BCS->Replacements.emplace_back(&I, InstMapping(In, 0)); - assert(EB.get(&I)->hasOrigin(&I)); + assert(EB.get(&I)->K == Inst::Const || EB.get(&I)->hasOrigin(&I)); } if (!BCS->Replacements.empty()) { std::unordered_set VisitedBlocks; diff --git a/lib/Extractor/ExprBuilder.cpp b/lib/Extractor/ExprBuilder.cpp index facf92826..54016f5eb 100644 --- a/lib/Extractor/ExprBuilder.cpp +++ b/lib/Extractor/ExprBuilder.cpp @@ -344,6 +344,7 @@ Inst *ExprBuilder::getDataflowConditions(Inst *I) { Inst *OneBits = LIC->getInst(Inst::Eq, 1, {VarAndOnes, Ones}); Result = LIC->getInst(Inst::And, 1, {Result, OneBits}); } + if (I->NonZero) { Inst *NonZeroBits = LIC->getInst(Inst::Ne, 1, {I, Zero}); Result = LIC->getInst(Inst::And, 1, {Result, NonZeroBits}); diff --git a/lib/Extractor/KLEEBuilder.cpp 
b/lib/Extractor/KLEEBuilder.cpp index a94fb2e84..1ea8a86aa 100644 --- a/lib/Extractor/KLEEBuilder.cpp +++ b/lib/Extractor/KLEEBuilder.cpp @@ -340,6 +340,56 @@ class KLEEBuilder : public ExprBuilder { return SubExpr::create(klee::ConstantExpr::create(Width, Width), countOnes(Val)); } + case Inst::LogB: { + ref L = get(Ops[0]); + unsigned Width = L->getWidth(); + ref Val = L; + for (unsigned i=0, j=0; j L = get(Ops[0]); + unsigned Width = L->getWidth(); + return klee::ConstantExpr::create(Width, Width); + } + + case Inst::KnownOnesP: { + auto VarAndOnes = klee::AndExpr::create(get(Ops[0]), get(Ops[1])); + return klee::EqExpr::create(VarAndOnes, get(Ops[1])); + } + case Inst::KnownZerosP: { + auto NotZeros = klee::NotExpr::create(get(Ops[1])); + auto VarNotZero = klee::OrExpr::create(get(Ops[0]), NotZeros); + return klee::EqExpr::create(VarNotZero, NotZeros); + } + case Inst::RangeP: { + auto Var = get(Ops[0]); + auto Lower = get(Ops[1]); + auto Upper = get(Ops[2]); + auto GELower = klee::SgeExpr::create(Var, Lower); + auto LTUpper = klee::SltExpr::create(Var, Upper); + auto Ordinary = klee::AndExpr::create(GELower, LTUpper); + + auto GEUpper = klee::SgeExpr::create(Var, Upper); + auto LTLower = klee::SltExpr::create(Var, Lower); + auto Wrapped = klee::OrExpr::create(GEUpper, LTLower); + + auto Cond = klee::SgtExpr::create(Upper, Lower); + + return klee::SelectExpr::create(Cond, Ordinary, Wrapped); + } + + case Inst::DemandedMask: { + return klee::AndExpr::create(get(Ops[0]), get(Ops[1])); + } + case Inst::FShl: case Inst::FShr: { unsigned IWidth = I->Width; diff --git a/lib/Extractor/Solver.cpp b/lib/Extractor/Solver.cpp index d4d940799..435eb9135 100644 --- a/lib/Extractor/Solver.cpp +++ b/lib/Extractor/Solver.cpp @@ -21,6 +21,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" +#include "souper/Codegen/Codegen.h" #include "souper/Extractor/Solver.h" #include "souper/Infer/AliveDriver.h" #include 
"souper/Infer/ConstantSynthesis.h" @@ -265,39 +266,14 @@ class BaseSolver : public Solver { std::error_code abstractPrecondition(const BlockPCs &BPCs, const std::vector &PCs, - InstMapping &Mapping, InstContext &IC, - bool &FoundWeakest) override { + InstMapping &Mapping, InstContext &IC, bool &FoundWeakest, + std::vector> &KBResults, + std::vector> &CRResults) override { SynthesisContext SC{IC, SMTSolver.get(), Mapping.LHS, /*LHSUB*/nullptr, PCs, BPCs, /*CheckAllGuesses=*/false, Timeout}; - std::vector> Results = - inferAbstractKBPreconditions(SC, Mapping.RHS, SMTSolver.get(), this, FoundWeakest); - - ReplacementContext RC; - auto LHSStr = RC.printInst(Mapping.LHS, llvm::outs(), true); - llvm::outs() << "infer " << LHSStr << "\n"; - auto RHSStr = RC.printInst(Mapping.RHS, llvm::outs(), true); - llvm::outs() << "result " << RHSStr << "\n"; - for (size_t i = 0; i < Results.size(); ++i) { - for (auto It = Results[i].begin(); It != Results[i].end(); ++It) { - auto &&P = *It; - std::string dummy; - llvm::raw_string_ostream str(dummy); - auto VarStr = RC.printInst(P.first, str, false); - llvm::outs() << VarStr << " -> " << Inst::getKnownBitsString(P.second.Zero, P.second.One); - - auto Next = It; - Next++; - if (Next != Results[i].end()) { - llvm::outs() << " (and) "; - } - } - if (i == Results.size() - 1) { - llvm::outs() << "\n"; - } else { - llvm::outs() << "\n(or)\n"; - } - } + std::tie(KBResults, CRResults) = + inferAbstractPreconditions(SC, Mapping.RHS, this, FoundWeakest); return {}; } @@ -403,10 +379,10 @@ class BaseSolver : public Solver { return std::error_code(); } - std::error_code infer(const BlockPCs &BPCs, - const std::vector &PCs, - Inst *LHS, std::vector &RHSs, - bool AllowMultipleRHSs, InstContext &IC) override { + std::error_code inferHelper(const BlockPCs &BPCs, + const std::vector &PCs, + Inst *LHS, std::vector &RHSs, + bool AllowMultipleRHSs, InstContext &IC) { std::error_code EC; // FIXME -- it's a bit messy to have this custom logic here @@ 
-457,10 +433,38 @@ class BaseSolver : public Solver { return EC; } - RHSs.clear(); +// RHSs.clear(); + return EC; + } + + std::error_code infer(const BlockPCs &BPCs, + const std::vector &PCs, + Inst *LHS, std::vector &RHSs, + bool AllowMultipleRHSs, InstContext &IC) override { + auto EC = inferHelper(BPCs, PCs, LHS, RHSs, AllowMultipleRHSs, IC); + if (RHSs.size() <= 1) + return EC; + + for (auto &RHS : RHSs) { + BackendCost BC; + getBackendCost(IC, RHS, BC); + // FIXME sort the list + } + return EC; } + SMTLIBSolver *getSMTLIBSolver() override { + return SMTSolver.get(); + } + + std::error_code isSatisfiable(llvm::StringRef Query, bool &Result, + unsigned NumModels, + std::vector *Models, + unsigned Timeout = 0) override { + return SMTSolver->isSatisfiable(Query, Result, NumModels, Models, Timeout); + } + std::error_code isValid(InstContext &IC, const BlockPCs &BPCs, const std::vector &PCs, InstMapping Mapping, bool &IsValid, @@ -512,13 +516,12 @@ class BaseSolver : public Solver { // case LHS to evaluate to UB std::vector Inputs; findVars(LHS, Inputs); - PruningManager Pruner(SC, Inputs, DebugLevel); - Pruner.init(); - ConstantSynthesis CS{&Pruner}; +// PruningManager Pruner(SC, Inputs, DebugLevel); +// Pruner.init(); + ConstantSynthesis CS{nullptr}; std::error_code EC = CS.synthesize(SMTSolver.get(), BPCs, PCs, InstMapping(LHS, RHS), ConstSet, ResultMap, IC, MaxConstantSynthesisTries, Timeout, /*AvoidNops=*/false); - if (EC || ResultMap.empty()) return EC; @@ -677,6 +680,11 @@ class MemCachingSolver : public Solver { return ent->second.first; } } + + SMTLIBSolver *getSMTLIBSolver() override { + return UnderlyingSolver->getSMTLIBSolver(); + } + std::error_code inferConst(const BlockPCs &BPCs, const std::vector &PCs, Inst *LHS, Inst *&RHS, @@ -717,6 +725,13 @@ class MemCachingSolver : public Solver { } } + std::error_code isSatisfiable(llvm::StringRef Query, bool &Result, + unsigned NumModels, + std::vector *Models, + unsigned Timeout = 0) override { + return 
UnderlyingSolver->isSatisfiable(Query, Result, NumModels, Models, Timeout); + } + std::string getName() override { return UnderlyingSolver->getName() + " + internal cache"; } @@ -745,9 +760,11 @@ class MemCachingSolver : public Solver { std::error_code abstractPrecondition(const BlockPCs &BPCs, const std::vector &PCs, - InstMapping &Mapping, InstContext &IC, - bool &FoundWeakest) override { - return UnderlyingSolver->abstractPrecondition(BPCs, PCs, Mapping, IC, FoundWeakest); + InstMapping &Mapping, InstContext &IC, bool &FoundWeakest, + std::vector> &KBResults, + std::vector> &CRResults) override { + return UnderlyingSolver->abstractPrecondition(BPCs, PCs, Mapping, IC, FoundWeakest, + KBResults, CRResults); } std::error_code knownBits(const BlockPCs &BPCs, @@ -840,6 +857,10 @@ class ExternalCachingSolver : public Solver { } } + SMTLIBSolver *getSMTLIBSolver() override { + return UnderlyingSolver->getSMTLIBSolver(); + } + llvm::ConstantRange constantRange(const BlockPCs &BPCs, const std::vector &PCs, Inst *LHS, @@ -847,6 +868,13 @@ class ExternalCachingSolver : public Solver { return UnderlyingSolver->constantRange(BPCs, PCs, LHS, IC); } + std::error_code isSatisfiable(llvm::StringRef Query, bool &Result, + unsigned NumModels, + std::vector *Models, + unsigned Timeout = 0) override { + return UnderlyingSolver->isSatisfiable(Query, Result, NumModels, Models, Timeout); + } + std::error_code isValid(InstContext &IC, const BlockPCs &BPCs, const std::vector &PCs, InstMapping Mapping, bool &IsValid, @@ -885,9 +913,11 @@ class ExternalCachingSolver : public Solver { std::error_code abstractPrecondition(const BlockPCs &BPCs, const std::vector &PCs, - InstMapping &Mapping, InstContext &IC, - bool &FoundWeakest) override { - return UnderlyingSolver->abstractPrecondition(BPCs, PCs, Mapping, IC, FoundWeakest); + InstMapping &Mapping, InstContext &IC, bool &FoundWeakest, + std::vector> &KBResults, + std::vector> &CRResults) override { + return 
UnderlyingSolver->abstractPrecondition(BPCs, PCs, Mapping, IC, FoundWeakest, + KBResults, CRResults); } std::error_code knownBits(const BlockPCs &BPCs, diff --git a/lib/Generalize/Reducer.cpp b/lib/Generalize/Reducer.cpp new file mode 100644 index 000000000..a1ddfc8a8 --- /dev/null +++ b/lib/Generalize/Reducer.cpp @@ -0,0 +1,1037 @@ +#include "llvm/Support/KnownBits.h" +#include "llvm/IR/ConstantRange.h" +#include "souper/Generalize/Reducer.h" +#include "souper/Infer/SynthUtils.h" + +#define _LIBCPP_DISABLE_DEPRECATION_WARNINGS + +namespace souper { + +void collectInsts(Inst *I, std::set &Results) { + std::vector Stack{I}; + while (!Stack.empty()) { + auto Current = Stack.back(); + Stack.pop_back(); + + Results.insert(Current); + + for (auto Child : Current->Ops) { + if (Results.find(Child) == Results.end()) { + Stack.push_back(Child); + } + } + } +} + +void collectInstsToDepth(Inst *I, size_t Depth, std::set &Results) { + std::vector Stack{I}; + std::map DepthMap; + DepthMap[I] = 0; + while (!Stack.empty()) { + auto Current = Stack.back(); + Stack.pop_back(); + + if (DepthMap[Current] > Depth) { + continue; + } + Results.insert(Current); + + for (auto Child : Current->Ops) { + DepthMap[Child] = DepthMap[Current] + 1; + if (Results.find(Child) == Results.end()) { + Stack.push_back(Child); + } + } + } +} + +bool IsReductionCostEffective(Inst *LHS, Inst *RHS) { + return souper::cost(RHS) < souper::cost(LHS); +} + +ParsedReplacement Reducer::ReducePairsGreedy(ParsedReplacement Input) { + size_t Depth = 4, Passes = 5; + bool Changed = false; + while (Passes-- ) { + // Try to remove two instructions at a time + std::set Insts; + + collectInstsToDepth(Input.Mapping.LHS, Depth, Insts); + + for (auto &&I : Insts) { + if (!safeToRemove(I, Input)) { + continue; + } + for (auto &&J : Insts) { + if (I != J) { + if (!safeToRemove(J, Input)) { + continue; + } + + auto Copy = Input; + Eliminate(Input, I); + Eliminate(Input, J); + + if (!IsReductionCostEffective(Input.Mapping.LHS, 
Input.Mapping.RHS)) { + Input = Copy; + continue; + } + + // Input.print(llvm::errs(), true); + + if (!VerifyInput(Input)) { + Input = Copy; + continue; + } + Changed = true; + + } + } + } + if (!Changed) { + break; + } + } + + return Input; +} + +ParsedReplacement Reducer::ReduceTriplesGreedy(ParsedReplacement Input) { + size_t Depth = 4, Passes = 2; + bool Changed = false; + while (Passes-- ) { + // Try to remove two instructions at a time + std::set Insts; + + collectInstsToDepth(Input.Mapping.LHS, Depth, Insts); + + for (auto &&I : Insts) { + if (!safeToRemove(I, Input)) { + continue; + } + for (auto &&J : Insts) { + if (I != J) { + if (!safeToRemove(J, Input)) { + continue; + } + + for (auto &&K : Insts) { + if (!safeToRemove(K, Input)) { + continue; + } + if (I != K && J != K) { + + auto Copy = Input; + Eliminate(Input, I); + Eliminate(Input, J); + Eliminate(Input, K); + + if (!IsReductionCostEffective(Input.Mapping.LHS, Input.Mapping.RHS)) { + Input = Copy; + continue; + } + + if (!VerifyInput(Input)) { + Input = Copy; + continue; + } + Changed = true; + } + } + } + } + } + if (!Changed) { + break; + } + } + + return Input; +} + +ParsedReplacement Reducer::ReduceGreedy(ParsedReplacement Input) { + std::set Insts; + collectInsts(Input.Mapping.LHS, Insts); + // TODO: topological sort, to reduce number of solver calls + // Try to remove one instruction at a time + int failcount = 0; + std::set Visited; + do { + auto It = Insts.begin(); + auto I = *It; + Insts.erase(It); + if (Visited.find(I) != Visited.end()) { + continue; + } + Visited.insert(I); + if (!safeToRemove(I, Input)) { + continue; + } + auto Copy = Input; + Eliminate(Input, I); + + if (!IsReductionCostEffective(Input.Mapping.LHS, Input.Mapping.RHS)) { + Input = Copy; + failcount++; + if (failcount >= Insts.size()) { + break; + } + continue; + } + + if (!VerifyInput(Input)) { + Input = Copy; + failcount++; + if (failcount >= Insts.size()) { + break; + } + continue; + } + Insts.clear(); + 
collectInsts(Input.Mapping.LHS, Insts); + } while (!Insts.empty()); + return Input; +} + +// Eventually replace the functions in Preconditions{.h/.cpp} with this. +// Does not produce exhaustive result. TODO Have an option to wrap in a cegis loop. +bool Reducer::inferKBPrecondition(ParsedReplacement &Input, std::vector Targets) { + assert(Targets.size() == 1 && "Multiple targets unimplemented"); + std::map Result; + std::set SymConsts; + std::map InstCache; + std::map BlockCache; + std::map ConstMap; + size_t ConstID = 0; + for (auto V : Targets) { + auto C = IC.createSynthesisConstant(V->Width, ConstID++); + InstCache[V] = C; + SymConsts.insert(C); + } + auto Copy = Input; + InstMapping Rep; + Rep.LHS = getInstCopy(Input.Mapping.LHS, IC, InstCache, BlockCache, &ConstMap, false, false); + Rep.RHS = getInstCopy(Input.Mapping.RHS, IC, InstCache, BlockCache, &ConstMap, false, false); + + ConstantSynthesis CS; + +// llvm::errs() << "Constant synthesis problem: \n"; +// Input.print(llvm::errs(), true); +// llvm::errs() << "....end.... 
\n"; + + if (auto EC = CS.synthesize(S->getSMTLIBSolver(), Input.BPCs, Input.PCs, + Rep, SymConsts, ConstMap, IC, 30, 60, false)) { + llvm::errs() << "Constant Synthesis internal error : " << EC.message(); + } + + if (ConstMap.empty()) { + if (DebugLevel > 3) { + llvm::errs() << "Constant Synthesis failed, moving on.\n"; + } + Input = Copy; + } else { + InstCache.clear(); + + // TODO: Generalize before allowing multiple targets + auto NewVar = Targets[0]; + auto C = InstCache[NewVar]; + + InstCache[C] = NewVar; + + NewVar->KnownOnes = ConstMap[C]; + NewVar->KnownZeros = ~ConstMap[C]; + + // Give up if can't be weakened 'too much' + const size_t WeakeningThreshold = NewVar->Width/2; + size_t BitsWeakened = 0; + + for (size_t i = 0; i < NewVar->Width; ++i) { + auto SaveZero = NewVar->KnownZeros; + auto SaveOne = NewVar->KnownOnes; + + NewVar->KnownZeros.clearBit(i); + NewVar->KnownOnes.clearBit(i); + + if (!VerifyInput(Input)) { + NewVar->KnownZeros = SaveZero; + NewVar->KnownOnes = SaveOne; + } else { + BitsWeakened++; + } + } + if (BitsWeakened < WeakeningThreshold) { + Input = Copy; // Reset to old state + return false; + } else { + return true; + } + } + return false; +} + +ParsedReplacement Reducer::ReduceGreedyKBIFY(ParsedReplacement Input) { + std::set Insts; + collectInsts(Input.Mapping.LHS, Insts); + // TODO: topological sort, to reduce number of solver calls + // Try to remove one instruction at a time + size_t failcount = 0; + std::set Visited; + do { + + if (souper::cost(Input.Mapping.LHS) - souper::cost(Input.Mapping.LHS) <= 1) { + break; + } + + auto It = Insts.begin(); + auto I = *It; + Insts.erase(It); + if (Visited.find(I) != Visited.end()) { + continue; + } + Visited.insert(I); + if (!safeToRemove(I, Input) || I->Width == 1 /*1 bit KB doesn't make sense*/) { + continue; + } + auto Copy = Input; + auto NewVar = Eliminate(Input, I); + + // Input won't verify because ReduceGreedy has been called before this. 
+ + // Try to replace NewVar with a symbolic constant and do constant synthesis. + Inst *C = IC.createSynthesisConstant(NewVar->Width, 0); + std::map InstCache = {{NewVar, C}}; + std::map BlockCache; + std::map ConstMap; + + InstMapping Rep; + Rep.LHS = getInstCopy(Input.Mapping.LHS, IC, InstCache, BlockCache, &ConstMap, false, false); + Rep.RHS = getInstCopy(Input.Mapping.RHS, IC, InstCache, BlockCache, &ConstMap, false, false); + + ConstantSynthesis CS; + std::set ConstSet{C}; + +// llvm::errs() << "Constant synthesis problem: \n"; +// Input.print(llvm::errs(), true); +// llvm::errs() << "....end.... \n"; + + if (auto EC = CS.synthesize(S->getSMTLIBSolver(), Input.BPCs, Input.PCs, + Rep, ConstSet, ConstMap, IC, 30, 60, false)) { + llvm::errs() << "Constant Synthesis internal error : " << EC.message(); + } + + if (ConstMap.empty()) { + if (DebugLevel > 3) { + llvm::errs() << "Constant Synthesis failed, moving on.\n"; + } + Input = Copy; + failcount++; + if (failcount >= Insts.size()) { + break; + } + } else { + InstCache.clear(); + InstCache[C] = NewVar; + + NewVar->KnownOnes = ConstMap[C]; + NewVar->KnownZeros = ~ConstMap[C]; + + // Give up if can't be weakened 'too much' + const size_t WeakeningThreshold = NewVar->Width/2; + size_t BitsWeakened = 0; + + for (size_t i = 0; i < NewVar->Width; ++i) { + auto SaveZero = NewVar->KnownZeros; + auto SaveOne = NewVar->KnownOnes; + + NewVar->KnownZeros.clearBit(i); + NewVar->KnownOnes.clearBit(i); + + if (!VerifyInput(Input)) { + NewVar->KnownZeros = SaveZero; + NewVar->KnownOnes = SaveOne; + } else { + BitsWeakened++; + } + } + if (BitsWeakened < WeakeningThreshold) { + Input = Copy; + failcount++; + if (failcount >= Insts.size()) { + break; + } + } + } + Insts.clear(); + collectInsts(Input.Mapping.LHS, Insts); + // ^ Can this be skipped? 
+ } while (!Insts.empty()); + return Input; +} + +ParsedReplacement Reducer::ReduceRedundantPhis(ParsedReplacement Input) { + std::set Insts; + std::vector Phis; + + auto Collect = [&] () { + Insts.clear(); + Phis.clear(); + collectInsts(Input.Mapping.LHS, Insts); + collectInsts(Input.Mapping.RHS, Insts); + for (auto &&I : Insts) { + if (I->K == Inst::Phi) { + Phis.push_back(I); + } + } + }; + Collect(); + + size_t NumPhis = Phis.size(); + while (NumPhis --) { + std::map ICache; +// bool Done = false; + for (auto &&I : Phis) { + if (I->Ops.size() == 1) { + ICache[I] = I->Ops[0]; + } else if (I->Ops.size() > 1) { + bool allEq = true; + for (size_t i = 0; i < I->Ops.size(); ++i) { + if (I->Ops[i] != I->Ops[0]) { + allEq = false; + break; + } + } + if (allEq) { + ICache[I] = I->Ops[0]; + } else { +// Done = true; + } + } else { +// Done = true; + } + } + +// if (Done || instCount(Input.Mapping.LHS) - instCount(Input.Mapping.RHS) <= 1) { +// break; +// } + if (souper::cost(Input.Mapping.LHS) <= souper::cost(Input.Mapping.RHS)) { + break; + } + + Input.Mapping.LHS = Replace(Input.Mapping.LHS, IC, ICache); + Input.Mapping.RHS = Replace(Input.Mapping.RHS, IC, ICache); + for (auto &PC : Input.PCs) { + PC.LHS = Replace(PC.LHS, IC, ICache); + PC.RHS = Replace(PC.RHS, IC, ICache); + } + if (NumPhis) { + Collect(); + } + } + return Input; +} +size_t WeakenSingleCR(ParsedReplacement Input, InstContext &IC, Solver *S, + Inst *Target, std::optional Val) { + if (Target->Width <= 8) return 0; // hack + if (!Val.has_value()) { + // Synthesize a value + Inst *C = IC.createVar(Target->Width, "reservedconst_1"); + C->SynthesisConstID = 1; + std::map InstCache = {{Target, C}}; + + auto Copy = Input; + + auto Rep = Replace(Input, IC, InstCache); + + std::set ConstSet{C}; + + std::map ConstMap; + ConstantSynthesis CS; + +// Rep.print(llvm::errs(), true); + + if (auto EC = CS.synthesize(S->getSMTLIBSolver(), Rep.BPCs, + Rep.PCs, Rep.Mapping, ConstSet, ConstMap, IC, 30, 60, false)) { + 
llvm::errs() << "Constant Synthesis internal error : " << EC.message(); + } + + if (!ConstMap.empty()) { + Val = ConstMap[C]; + } + } + + if (!Val.has_value()) { + return 0; // fail + } + + auto Restore = Target->Range; + + // Binary search to extend upper and lower boundaries + llvm::ConstantRange R(Val.value()); + +// llvm::errs() << "R " << R << " " << Val.value() <<"\n"; + + auto Full = R.getFull(R.getBitWidth()); + + auto L = R.getLower(); + auto U = R.getUpper(); + + size_t inc = 1; + while (inc && U.slt(Full.getUpper())) { + +// llvm::errs() << "L " << L << " " << "U " << U << " inc " << inc <<"\n"; + + auto Backup = Target->Range; + auto Attempt = U + inc; + if (Attempt.sge(Full.getUpper())) { + Attempt = Full.getLower(); + } + Target->Range = llvm::ConstantRange(L, Attempt); + if (Verify(Input, IC, S)) { + U = Attempt; +// llvm::errs() << "U " << Attempt << '\n'; + inc *= 2; + } else { + inc /= 2; + Target->Range = Backup; + } + } + + size_t dec = 1; + while (dec && L.slt(0)) { +// llvm::errs() << "L " << L << " " << "U " << U << " inc " << dec <<"\n"; + auto Backup = Target->Range; + auto Attempt = L - dec; + if (Attempt.sle(Full.getLower())) { + Attempt = Full.getLower(); + } + Target->Range = llvm::ConstantRange(Attempt, U); + if (Verify(Input, IC, S)) { + L = Attempt; +// llvm::errs() << "L " << Attempt << '\n'; + dec *= 2; + } else { + dec /= 2; + Target->Range = Backup; + } + } + +// llvm::errs() << "HERE " << L << " " << U << "\n"; + + if ((U - L).sgt(1 << (Target->Width - 2))) { // Heuristic + return (U - L).getLimitedValue(); + } else { + Target->Range = Restore; + return 0; + }; + +} + +size_t WeakenSingleKB(ParsedReplacement Input, InstContext &IC, Solver *S, + Inst *Target, std::optional Val) { + size_t BitsWeakened = 0; + + if (Target->Width < 8) return 0; // hack + + if (!Val.has_value()) { + // Synthesize a value + Inst *C = IC.createVar(Target->Width, "reservedconst_1"); + C->SynthesisConstID = 1; + std::map InstCache = {{Target, C}}; + + 
auto Copy = Input; + + auto Rep = Replace(Input, IC, InstCache); + + std::set ConstSet{C}; + + std::map ConstMap; + ConstantSynthesis CS; + +// Rep.print(llvm::errs(), true); + + if (auto EC = CS.synthesize(S->getSMTLIBSolver(), Rep.BPCs, + Rep.PCs, Rep.Mapping, ConstSet, ConstMap, IC, 30, 60, false)) { + llvm::errs() << "Constant Synthesis internal error : " << EC.message(); + } + + if (!ConstMap.empty()) { + Val = ConstMap[C]; + } + } + + if (!Val.has_value()) { + return 0; // No bits weakened + } + + llvm::APInt RestoreZero = Target->KnownZeros; + llvm::APInt RestoreOne = Target->KnownOnes; + + Target->KnownOnes = Val.value(); + Target->KnownZeros = ~Val.value(); + + for (size_t i = 0; i < Target->Width; ++i) { + llvm::APInt OriZ = Target->KnownZeros; + llvm::APInt OriO = Target->KnownOnes; + + if (OriO[i] == 0 && OriZ[i] == 0) { + continue; + } + + if (OriO[i] == 1) Target->KnownOnes.clearBit(i); + if (OriZ[i] == 1) Target->KnownZeros.clearBit(i); + + if (!Verify(Input, IC, S)) { + Target->KnownZeros = OriZ; + Target->KnownOnes = OriO; + } else { + BitsWeakened++; + } + } + + if (BitsWeakened < Target->Width / 2) { + Target->KnownOnes = RestoreOne; + Target->KnownZeros = RestoreZero; + BitsWeakened = 0; + } + + return BitsWeakened; +} + +ParsedReplacement Reducer::ReducePCsToDF(ParsedReplacement Input) { + std::vector FoundVars; + for (auto &&PC : Input.PCs) { + findVars(PC.LHS, FoundVars); + findVars(PC.RHS, FoundVars); + } + std::set Vars; + for (auto &&V : FoundVars) { + if (!V->Name.starts_with("sym") && !V->Name.starts_with("const")) { + Vars.insert(V); + } + } + + auto Backup = Input.PCs; + Input.PCs.clear(); + + bool Succ = false; + + for (auto &&V : Vars) { + auto RangeSize = WeakenSingleCR(Input, IC, S, V, {}); + Succ |= (RangeSize > 0); + } + + if (!Succ) { + for (auto &&V : Vars) { + auto BitsWeakened = WeakenSingleKB(Input, IC, S, V, {}); + Succ |= (BitsWeakened != 0); + } + } + + if (!Succ) { + Input.PCs = Backup; + } + + return Input; +} + +// 
Assumes Input is valid +ParsedReplacement Reducer::ReducePCs(ParsedReplacement Input) { + + for (size_t i = 0; i < Input.PCs.size(); ++i) { + auto Result = Input; + Result.PCs.clear(); + for (size_t j = 0; j < Input.PCs.size(); ++j) { + if (i != j) { + Result.PCs.push_back(Input.PCs[j]); + } + } + + auto Clone = Verify(Result, IC, S); + if (Clone) { + return ReducePCs(Result); + } + } + + return Input; +} + +// Assumes Input is valid +ParsedReplacement Reducer::WeakenKB(ParsedReplacement Input) { + std::vector Vars; + findVars(Input.Mapping.LHS, Vars); + for (auto &&V : Vars) { + auto OriZero = V->KnownZeros; + auto OriOne = V->KnownOnes; + if (OriZero == 0 && OriOne == 0) { + continue; // this var doesn't have a knownbits condition + } + if (OriZero.getBitWidth() != V->Width || OriOne.getBitWidth() != V->Width) { + continue; // this var doesn't have a well formed knownbits condition + } + + // Try to remove KB + V->KnownOnes = llvm::APInt(V->Width, 0); + V->KnownZeros = llvm::APInt(V->Width, 0); + if (VerifyInput(Input)) { + continue; // Removed KB from this var + } + V->KnownOnes = OriOne; + V->KnownZeros = OriZero; + + // Try resetting bitwise KB + + for (size_t i = 0; i < V->Width; ++i) { + auto Ones = V->KnownOnes; + if (Ones[i]) { + V->KnownOnes.setBitVal(i, false); + if (!VerifyInput(Input)) { + V->KnownOnes = Ones; + } + } + auto Zeros = V->KnownZeros; + if (Zeros[i]) { + V->KnownZeros.setBitVal(i, false); + if (!VerifyInput(Input)) { + V->KnownZeros = Zeros; + } + } + } + } + return Input; +} + +// Assumes Input is valid +ParsedReplacement Reducer::WeakenCR(ParsedReplacement Input) { + std::vector Vars; + findVars(Input.Mapping.LHS, Vars); + + for (auto &&V : Vars) { + auto Ori = V->Range; + if (V->Range.isFullSet()) { + continue; + } + V->Range = llvm::ConstantRange(V->Width, true); + if (!VerifyInput(Input)) { + V->Range = Ori; + } + + auto R = V->Range; + + if (!R.isWrappedSet()) { + auto Full = R.getFull(R.getBitWidth()); + + auto L = R.getLower(); + 
auto U = R.getUpper(); + + size_t inc = 1; + while (inc && U.slt(Full.getUpper())) { + auto Backup = V->Range; + auto Attempt = U + inc; + + if (Attempt.sge(Full.getUpper())) { + Attempt = Full.getLower(); + } + + V->Range = llvm::ConstantRange(L, Attempt); + if (VerifyInput(Input)) { + U = Attempt; + // llvm::errs() << "U " << Attempt << '\n'; + inc *= 2; + } else { + inc /= 2; + V->Range = Backup; + } + } + + size_t dec = 1; + while (dec && L.slt(0)) { + auto Backup = V->Range; + auto Attempt = L - dec; + if (Attempt.sle(Full.getLower())) { + Attempt = Full.getLower(); + } + V->Range = llvm::ConstantRange(Attempt, U); + if (VerifyInput(Input)) { + L = Attempt; + // llvm::errs() << "L " << Attempt << '\n'; + dec *= 2; + } else { + dec /= 2; + V->Range = Backup; + } + } + } + } + + return Input; +} + +// Assumes Input is valid +ParsedReplacement Reducer::WeakenDB(ParsedReplacement Input) { + auto Ori = Input.Mapping.LHS->DemandedBits; + auto Width = Input.Mapping.LHS->Width; + if (Ori.getBitWidth() != Width || Ori.isAllOnesValue()) { + return Input; + } + // Try replacing with all ones. 
+ Input.Mapping.LHS->DemandedBits.setAllBits(); + if (VerifyInput(Input)) { + return Input; + } + Input.Mapping.LHS->DemandedBits = Ori; + + for (size_t i = 0; i < Width; ++i) { + auto Last = Input.Mapping.LHS->DemandedBits; + if (!Last[i]) { + Input.Mapping.LHS->DemandedBits.setBitVal(i, true); + if (!VerifyInput(Input)) { + Input.Mapping.LHS->DemandedBits = Last; + } + } + } + + return Input; +} + +// Assumes Input is valid +ParsedReplacement Reducer::WeakenOther(ParsedReplacement Input) { + std::vector Vars; + findVars(Input.Mapping.LHS, Vars); + + for (auto &&V : Vars) { +#define WEAKEN(X) \ +if (V->X) { \ + V->X = false; \ + if (!VerifyInput(Input)) {\ + V->X = true;}} + + WEAKEN(NonZero) + WEAKEN(NonNegative) + WEAKEN(PowOfTwo) + WEAKEN(Negative) + +#undef WEAKEN + + while (V->NumSignBits) { + V->NumSignBits--; + if (!VerifyInput(Input)) { + V->NumSignBits++; + break; + } + } + } + + return Input; +} + +bool Reducer::VerifyInput(ParsedReplacement &Input) { + std::vector> Models; + bool Valid; + if (std::error_code EC = S->isValid(IC, Input.BPCs, Input.PCs, Input.Mapping, Valid, &Models)) { + llvm::errs() << EC.message() << '\n'; + } + numSolverCalls++; + return Valid; +} + +bool Reducer::safeToRemove(Inst *I, ParsedReplacement &Input) { + if (I == Input.Mapping.LHS || I->K == Inst::Var || I->K == Inst::Const || + I->K == Inst::UMulWithOverflow || I->K == Inst::UMulO || + I->K == Inst::SMulWithOverflow || I->K == Inst::SMulO || + I->K == Inst::UAddWithOverflow || I->K == Inst::UAddO || + I->K == Inst::SAddWithOverflow || I->K == Inst::SAddO || + I->K == Inst::USubWithOverflow || I->K == Inst::USubO || + I->K == Inst::SSubWithOverflow || I->K == Inst::SSubO) { + return false; + } + return true; +} + +Inst *Reducer::Eliminate(ParsedReplacement &Input, Inst *I) { + // Try to replace I with a new Var. 
+ Inst *NewVar = IC.createVar(I->Width, "newvar" + std::to_string(varnum++)); + + std::map ICache; + ICache[I] = NewVar; + + std::map BCache; + std::map CMap; + + ParsedReplacement NewInst = Input; + + Input.Mapping.LHS = getInstCopy(Input.Mapping.LHS, IC, ICache, + BCache, &CMap, false); + + Input.Mapping.RHS = getInstCopy(Input.Mapping.RHS, IC, ICache, + BCache, &CMap, false); + + for (auto &M : Input.PCs) { + M.LHS = getInstCopy(M.LHS, IC, ICache, BCache, &CMap, false); + M.RHS = getInstCopy(M.RHS, IC, ICache, BCache, &CMap, false); + } + for (auto &BPC : Input.BPCs) { + BPC.PC.LHS = getInstCopy(BPC.PC.LHS, IC, ICache, BCache, &CMap, false); + BPC.PC.RHS = getInstCopy(BPC.PC.RHS, IC, ICache, BCache, &CMap, false); + } + return NewVar; +} + +void Reducer::ReduceRec(ParsedReplacement Input_, std::vector &Results) { + + if (souper::cost(Input_.Mapping.LHS) - souper::cost(Input_.Mapping.LHS) <= 1) { + return; + } + + // Try to remove subsets of instructions recursively, and store all valid results + ReplacementContext RC; + std::string Str; + llvm::raw_string_ostream SStr(Str); + RC.printInst(Input_.Mapping.LHS, SStr, false); + + Inst *Ante = IC.getConst(llvm::APInt(1, true)); + for (auto PC : Input_.PCs ) { + // Inst *Eq = IC.getInst(Inst::Eq, 1, {PC.LHS, PC.RHS}); + // Ante = IC.getInst(Inst::And, 1, {Ante, Eq}); + Ante = Builder(PC.LHS, IC).Eq(PC.RHS).And(Ante)(); + } + + RC.printInst(Ante, SStr, false); + SStr.flush(); + +// llvm::errs() << Str << "\n"; + +// auto Str = Input_.getString(false); + if (DNR.find(Str) != DNR.end()) { + return; + } else { + DNR.insert(Str); + } + + std::set Insts; + collectInsts(Input_.Mapping.LHS, Insts); + collectInsts(Input_.Mapping.RHS, Insts); + +// for (auto &&PC : Input_.PCs) { +// collectInsts(PC.LHS, Insts); +// collectInsts(PC.RHS, Insts); +// } + +// for (auto &&BPC : Input_.BPCs) { +// collectInsts(BPC.PC.LHS, Insts); +// collectInsts(BPC.PC.RHS, Insts); +// } + + if (Insts.size() <= 1) { + return; // Base case + } + + // 
Remove at least one instruction and call recursively for valid opts + for (auto I : Insts) { + ParsedReplacement Input = Input_; + + if (!safeToRemove(I, Input)) { + continue; + } + + Eliminate(Input, I); + + if (VerifyInput(Input)) { + Results.push_back(Input); + ReduceRec(Input, Results); + } else { + if (DebugLevel >= 2) { + llvm::outs() << "Invalid attempt.\n"; + Input.print(llvm::outs(), true); + } + } + } +} + +Inst *NonPoisonReplacement(Inst *I, InstContext &IC) { + Inst::Kind K = I->K; + switch (I->K) { + case Inst::AddNW: + case Inst::AddNUW: + case Inst::AddNSW: + K = Inst::Add; + break; + case Inst::SubNW: + case Inst::SubNUW: + case Inst::SubNSW: + K = Inst::Sub; + break; + case Inst::MulNW: + case Inst::MulNUW: + case Inst::MulNSW: + K = Inst::Mul; + break; + case Inst::ShlNW: + case Inst::ShlNUW: + case Inst::ShlNSW: + K = Inst::Shl; + break; + case Inst::UDivExact: + K = Inst::UDiv; + break; + case Inst::SDivExact: + K = Inst::SDiv; + break; + default: + llvm_unreachable("Expected instruction with poison flag."); + } + + auto Ret = IC.getInst(K, I->Width, I->Ops); + Ret->Name = I->Name; + Ret->DemandedBits = I->DemandedBits; + return Ret; +} + +void CollectPoisonInsts(Inst *I, std::set &PoisonInsts, + std::set &Visited) { + if (Visited.find(I) != Visited.end()) { + return; + } + Visited.insert(I); + + if (I->K == Inst::AddNSW || I->K == Inst::AddNUW || I->K == Inst::AddNW || + I->K == Inst::SubNSW || I->K == Inst::SubNUW || I->K == Inst::SubNW || + I->K == Inst::MulNSW || I->K == Inst::MulNUW || I->K == Inst::MulNW || + I->K == Inst::ShlNSW || I->K == Inst::ShlNUW || I->K == Inst::ShlNW || + I->K == Inst::UDivExact || I->K == Inst::SDivExact) { + PoisonInsts.insert(I); + } + + for (auto &&Op : I->Ops) { + CollectPoisonInsts(Op, PoisonInsts, Visited); + } +} + +ParsedReplacement Reducer::ReducePoison(ParsedReplacement Input) { + std::set PoisonInsts; + std::set Visited; + CollectPoisonInsts(Input.Mapping.LHS, PoisonInsts, Visited); + 
CollectPoisonInsts(Input.Mapping.RHS, PoisonInsts, Visited); + for (auto &&PC : Input.PCs) { + CollectPoisonInsts(PC.LHS, PoisonInsts, Visited); + CollectPoisonInsts(PC.RHS, PoisonInsts, Visited); + } + + for (auto I : PoisonInsts) { + auto Rep = NonPoisonReplacement(I, IC); + if (!Rep) { + continue; + } + std::map Cache = {{I, NonPoisonReplacement(I, IC)}}; + + auto Cand = Replace(Input, IC, Cache); + + if (VerifyInput(Cand)) { + Input = Cand; + } + } + + return Input; +} + + +} diff --git a/lib/Infer/AliveDriver.cpp b/lib/Infer/AliveDriver.cpp index f4a360e0f..1e010dd2a 100644 --- a/lib/Infer/AliveDriver.cpp +++ b/lib/Infer/AliveDriver.cpp @@ -37,12 +37,21 @@ namespace { static llvm::cl::opt DisableUndefInput("alive-disable-undef-input", llvm::cl::desc("Assume inputs can not be undef (default = false)"), - llvm::cl::init(false)); + llvm::cl::init(true)); static llvm::cl::opt SkipAliveSolver("alive-skip-solver", llvm::cl::desc("Omit Alive solver calls for performance testing (default = false)"), llvm::cl::init(false)); +static llvm::cl::opt WidthIndepOpt("alive-all-widths", + llvm::cl::desc("Ignore Souper type widths and verify for all widths."), + llvm::cl::init(false)); + +static llvm::cl::opt ShowValidWidths("show-valid-widths", + llvm::cl::desc("Show widths for which the input is valid."), + llvm::cl::init(false)); + + class FunctionBuilder { public: FunctionBuilder(IR::Function &F_) : F(F_) {} @@ -53,6 +62,15 @@ class FunctionBuilder { (std::make_unique(t, std::move(name), *toValue(t, a))); } + template + IR::Value *width(IR::Type &t, std::string name, A a) { + auto fc = std::make_unique(IR::ConstantFn(t, "width", {toValue(t, a)})); + // make_unique fails template type deduction + auto ret = fc.get(); + F.addConstant(std::move(fc)); + return ret; + } + IR::Value *undef(IR::Type &t, std::string name) { auto undef = std::make_unique(t); auto undef_ptr = undef.get(); @@ -179,6 +197,14 @@ class FunctionBuilder { identifiers[x] = ptr; return ptr; } + if 
(x.find("var_sym") != std::string::npos) { + auto i = std::make_unique(t, std::move(x)); + auto ptr = i.get(); + F.addInput(std::move(i)); + // FIXME: force non poison + identifiers[x] = ptr; + return ptr; + } auto i = std::make_unique(t, std::move(x)); auto ptr = i.get(); F.addInput(std::move(i)); @@ -225,136 +251,73 @@ std::map performCegisFirstQuery(tools::Transform &t, std::map &SouperConsts, smt::expr &TriedExpr) { - IR::State SrcState(t.src, true); - IR::State TgtState(t.tgt, false); - util::sym_exec(SrcState); - util::sym_exec(TgtState); - - auto &&Sv = SrcState.returnVal(); - auto &&Tv = TgtState.returnVal(); - - std::map SynthesisResult; - SynthesisResult.clear(); - - std::set Vars; - std::map SMTConsts; - for (auto &[Var, Val, Pred] : TgtState.getValues()) { - auto &Name = Var->getName(); - if (startsWith("%reservedconst", Name)) { - SMTConsts[Name] = Val.first.value; - } - } - - if (SkipAliveSolver) - return SynthesisResult; - - // TODO: implement synthesis with refinement - smt::Solver::check({{(Sv.first.value == Tv.first.value) && (TriedExpr), - [&](const smt::Result &R) { - - // no more guesses, stop immediately - if (R.isUnsat()) { - if (DebugLevel > 3) - llvm::errs()<<"No more new possible guesses\n"; - return; - } else if (R.isSat()) { - auto &&Model = R.getModel(); - smt::expr TriedAnte(false); - - for (auto &[name, expr] : SMTConsts) { - TriedAnte |= (expr != smt::expr::mkUInt(Model.getInt(expr), expr.bits())); - } - TriedExpr &= TriedAnte; - - for (auto &[name, expr] : SMTConsts) { - auto *I = SouperConsts[name]; - SynthesisResult[I] = llvm::APInt(I->Width, Model.getInt(expr)); - } - } - }}}); - - return SynthesisResult; + llvm::errs() << "Constant synthesis through alive unimplemented."; + return {}; +// IR::State SrcState(t.src, true); +// IR::State TgtState(t.tgt, false); +// util::sym_exec(SrcState); +// util::sym_exec(TgtState); +// +// auto &&Sv = SrcState.returnVal(); +// auto &&Tv = TgtState.returnVal(); + +// std::map SynthesisResult; 
+// SynthesisResult.clear(); +// +// std::set Vars; +// std::map SMTConsts; +// for (auto &[Var, Val] : TgtState.getValues()) { +// auto &Name = Var->getName(); +// if (startsWith("%reservedconst", Name)) { +// SMTConsts[Name] = Val.first.value; +// } +// } +// +// if (SkipAliveSolver) +// return SynthesisResult; +// +// auto R = smt::check_expr((Sv.first.value == Tv.first.value) && (TriedExpr)); +// // no more guesses, stop immediately +// if (R.isUnsat()) { +// if (DebugLevel > 3) +// llvm::errs()<<"No more new possible guesses\n"; +// return {}; +// } else if (R.isSat()) { +// auto &&Model = R.getModel(); +// smt::expr TriedAnte(false); +// +// for (auto &[name, expr] : SMTConsts) { +// TriedAnte |= (expr != smt::expr::mkUInt(Model.getInt(expr), expr.bits())); +// } +// TriedExpr &= TriedAnte; +// +// for (auto &[name, expr] : SMTConsts) { +// auto *I = SouperConsts[name]; +// SynthesisResult[I] = llvm::APInt(I->Width, Model.getInt(expr)); +// } +// } +// return SynthesisResult; } std::map synthesizeConstantUsingSolver(tools::Transform &t, std::map &SouperConsts) { + // Removed because implementation was bit-rotting. + // The combination of the options requiring this is + // not currently used. 
+ llvm::errs() << "Direct solver based constant synthesis through alive unimplemented."; return {}; - - IR::State SrcState(t.src, true), tgt_state(t.tgt, false); - util::sym_exec(SrcState); - util::sym_exec(tgt_state); - - util::Errors Errs; - - auto &&SrcRet = SrcState.returnVal(); - auto &&TgtRet = tgt_state.returnVal(); - - auto QVars = SrcState.getQuantVars(); - QVars.insert(SrcRet.second.begin(), SrcRet.second.end()); - - auto ErrF = [&](const smt::Result &r, bool print_var, const char *msg) { -// tools::error(Errs, SrcState, tgt_state, r, print_var, nullptr, -// SrcRet.first, TgtRet.first, msg,false); - //FIXME: temporarily disabled, find a way to pass a Type to tools::error - std::cerr << msg << "\n"; - tools::TransformPrintOpts Opts; - Opts.print_fn_header = true; - t.print(std::cerr, Opts); - }; - - std::set Vars; - std::map SMTConsts; - - for (auto &[var, val, Pred] : SrcState.getValues()) { - auto &name = var->getName(); - if (startsWith("%var", name)) { - auto app = val.first.value.isApp(); - assert(app); - Vars.insert(Z3_get_app_arg(smt::ctx(), app, 1)); - } - } - for (auto &[var, val, Pred] : tgt_state.getValues()) { - auto &name = var->getName(); - if (startsWith("%reserved", name)) { - auto app = val.first.value.isApp(); - assert(app); - SMTConsts[name] = (Z3_get_app_arg(smt::ctx(), app, 1)); - } - } - - auto SimpleConstExistsCheck = - smt::expr::mkForAll(Vars, SrcRet.first.value == TgtRet.first.value); - - std::map SynthesisResult; - - if (SkipAliveSolver) - return SynthesisResult; - - smt::Solver::check({{preprocess(t, QVars, SrcRet.second, - std::move(SimpleConstExistsCheck)), - [&] (const smt::Result &R) { - if (R.isUnsat()) { - ErrF(R, true, "Value mismatch"); - } else if (R.isSat()) { - auto &&Model = R.getModel(); - for (auto &[name, expr] : SMTConsts) { - auto *I = SouperConsts[name]; - SynthesisResult[I] = llvm::APInt(I->Width, Model.getInt(expr)); - } - return; - } else { - ErrF(R, true, "Unknown/Invalid Result, investigate."); - } - }}}); 
- - return SynthesisResult; } souper::AliveDriver::AliveDriver(Inst *LHS_, Inst *PreCondition_, InstContext &IC_, - std::vector ExtraInputs) + const std::vector &ExtraInputs, bool WidthIndep) : LHS(LHS_), PreCondition(PreCondition_), IC(IC_) { + smt::set_query_timeout(std::to_string(10000)); // milliseconds IsLHS = true; + WidthIndependentMode = WidthIndep; + if (WidthIndepOpt) { + WidthIndependentMode = true; + } InstNumbers = 101; //FIXME: Magic number. 101 is chosen arbitrarily. //This should go away once non-input variable names are not discarded @@ -482,6 +445,7 @@ void souper::AliveDriver::copyInputs(souper::AliveDriver::Cache &To, bool souper::AliveDriver::verify (Inst *RHS, Inst *RHSAssumptions) { RExprCache.clear(); + ValidTypings.clear(); IR::Function RHSF; copyInputs(RExprCache, RHSF); if (!translateRoot(RHS, RHSAssumptions, RHSF, RExprCache)) { @@ -503,25 +467,73 @@ bool souper::AliveDriver::verify (Inst *RHS, Inst *RHSAssumptions) { t.tgt = std::move(RHSF); tools::TransformVerify tv(t, /*check_each_var=*/false); + auto types = tv.getTypings(); + + if (!types.hasSingleTyping()) { + unsigned i = 0; + size_t correct = 0; + size_t incorrect = 0; + for (; types; ++types) { + tv.fixupTypes(types); + if (auto errs = tv.verify()) { + if (DebugLevel > 4) { + llvm::errs() << "Invalid typing: \n"; + for (auto &&P : Inputs) { + llvm::errs() << P.first->Name << ' ' << P.second->bits() << "\n"; + } + } + incorrect++; + } else { + std::map Typing; + for (auto &&P : Inputs) { + Typing[P.first] = P.second->bits(); + } + ValidTypings.push_back(Typing); + correct++; + } + } + if (!incorrect) { + return true; + } else if (!correct) { + return false; + } else { + if (ShowValidWidths) { + llvm::outs() << "; Partially Valid.\n"; + std::sort(ValidTypings.begin(), ValidTypings.end(), + [](const auto &A, const auto &B) {return A.begin()->second < B.begin()->second;}); + for (auto &&P : ValidTypings) { + llvm::outs() << "; "; + for (auto &&I : P) { + llvm::outs() << 
I.first->Name << ' ' << I.second << '\t'; + } + llvm::outs() << '\n'; + } + } + return false; + } + } + if (SkipAliveSolver) return false; if (auto errs = tv.verify()) { - if (DebugLevel >= 1) { + if (DebugLevel >= 2) { std::ostringstream os; os << errs << "\n"; llvm::errs() << os.str(); + llvm::errs() << "RHS rejected by Alive2\n"; } return false; // TODO: Encode errs into ErrorCode } else { - if (DebugLevel > 2) - llvm::errs() << "RHS proved valid.\n"; + if (DebugLevel >= 2) + llvm::errs() << "RHS verified by Alive2\n"; return true; } } bool souper::AliveDriver::translateRoot(const souper::Inst *I, const Inst *PC, IR::Function &F, Cache &ExprCache) { + if (!translateAndCache(I, F, ExprCache)) { return false; } @@ -533,9 +545,7 @@ bool souper::AliveDriver::translateRoot(const souper::Inst *I, const Inst *PC, FunctionBuilder Builder(F); if (PC) { - auto Zero = Builder.val(getType(I->Width), llvm::APInt(I->Width, 0)); - ExprCache[I] = Builder.select(getType(I->Width), "%ifpc", - ExprCache[PC], ExprCache[I], Zero); + Builder.assume(ExprCache[PC]); } Builder.ret(getType(I->Width), ExprCache[I]); F.setType(getType(I->Width)); @@ -583,6 +593,23 @@ bool souper::AliveDriver::translateAndCache(const souper::Inst *I, return true; // Already translated } + if (I->K == Inst::KnownOnesP) { + auto *VarAndOnes = IC.getInst(Inst::And, I->Width, {I->Ops[0], I->Ops[1]}); + auto *Eq = IC.getInst(Inst::Eq, 1, {VarAndOnes, I->Ops[1]}); + auto Ret = translateAndCache(Eq, F, ExprCache); + ExprCache[I] = ExprCache[Eq]; + return Ret; + } + if (I->K == Inst::KnownZerosP) { + auto *FlipZeros = IC.getInst(Inst::Xor, I->Width, {I->Ops[1], + IC.getConst(llvm::APInt::getAllOnesValue(I->Width))}); + auto *VarNotZeros = IC.getInst(Inst::Or, I->Width, {I->Ops[0], FlipZeros}); + auto *Eq = IC.getInst(Inst::Eq, 1, {VarNotZeros, FlipZeros}); + auto Ret = translateAndCache(Eq, F, ExprCache); + ExprCache[I] = ExprCache[Eq]; + return Ret; + } + auto Ops = I->Ops; if 
(souper::Inst::isOverflowIntrinsicMain(I->K)) { Ops = Ops[0]->Ops; @@ -624,6 +651,7 @@ bool souper::AliveDriver::translateAndCache(const souper::Inst *I, switch (I->K) { case souper::Inst::Var: { ExprCache[I] = Builder.var(t, Name); + // llvm::errs() << "Var: " << Name << "\n"; if (IsLHS) { Inputs.push_back({I, ExprCache[I]}); } @@ -711,6 +739,7 @@ bool souper::AliveDriver::translateAndCache(const souper::Inst *I, BINOPF(MulNUW, Mul, NUW); BINOPF(MulNW, Mul, NSW | IR::BinOp::NUW); BINOP(And, And); + BINOP(DemandedMask, And); BINOP(Or, Or); BINOP(Xor, Xor); BINOP(Shl, Shl); @@ -782,8 +811,16 @@ bool souper::AliveDriver::translateAndCache(const souper::Inst *I, UNARYOP(BSwap, BSwap); UNARYOP(BitReverse, BitReverse); + case souper::Inst::BitWidth: { + ExprCache[I] = Builder.width(t, Name /*is ignored*/, + ExprCache[I->Ops[0]]); + return true; + } + + // TODO: Desugar log2. Alive2 only supports log2 for concrete constants. + default:{ - llvm::outs() << "Unsupported Instruction Kind : " << I->getKindName(I->K) << "\n"; + llvm::errs() << "Unsupported Instruction Kind : " << I->getKindName(I->K) << "\n"; return false; } } @@ -829,9 +866,14 @@ souper::AliveDriver::translateDemandedBits(const souper::Inst* I, IR::Type &souper::AliveDriver::getType(int Width) { std::string n = "i" + std::to_string(Width); if (TypeCache.find(n) == TypeCache.end()) { - TypeCache[n] = new IR::IntType(std::move(n), Width); + if (WidthIndependentMode) { + auto t = new IR::SymbolicType("symty_", (1 << IR::SymbolicType::Int)); + TypeCache[""] = t; + } else { + TypeCache[n] = new IR::IntType(std::move(n), Width); + } } - return *TypeCache[n]; + return WidthIndependentMode ? 
*TypeCache[""] : *TypeCache[n]; } IR::Type &souper::AliveDriver::getOverflowType(int Width) { diff --git a/lib/Infer/ConstantSynthesis.cpp b/lib/Infer/ConstantSynthesis.cpp index 67046464a..16882fdac 100644 --- a/lib/Infer/ConstantSynthesis.cpp +++ b/lib/Infer/ConstantSynthesis.cpp @@ -43,7 +43,7 @@ Inst *getUBConstraint(Inst::Kind K, unsigned OpNum, Inst *C, case Inst::AShr: // right operand has to be < Width return (OpNum == 0) ? - IC.getConst(llvm::APInt(1, true)) : + IC.getConst(llvm::APInt(1, true)) : IC.getInst(Inst::Ult, 1, { C, IC.getConst(llvm::APInt(C->Width, C->Width)) }); case Inst::UDiv: @@ -81,7 +81,7 @@ Inst *getConstConstraint(Inst::Kind K, unsigned OpNum, Inst *C, case Inst::USubSat: // left operand cannot be 0, right operand cannot be 0 or -1 return (OpNum == 0) ? - IC.getInst(Inst::Ne, 1, { IC.getConst(llvm::APInt(C->Width, 0)), C }) : + IC.getInst(Inst::Ne, 1, { IC.getConst(llvm::APInt(C->Width, 0)), C }) : IC.getInst(Inst::And, 1, { IC.getInst(Inst::Ne, 1, { IC.getConst(llvm::APInt(C->Width, 0)), C }), IC.getInst(Inst::Ne, 1, { IC.getConst(llvm::APInt::getAllOnesValue(C->Width)), C }) @@ -94,6 +94,7 @@ Inst *getConstConstraint(Inst::Kind K, unsigned OpNum, Inst *C, IC.getInst(Inst::Ne, 1, { IC.getConst(llvm::APInt(C->Width, 1)), C }) }); + case Inst::DemandedMask: case Inst::And: case Inst::Or: // neither operand can be 0 or -1 @@ -147,7 +148,7 @@ Inst *getConstConstraint(Inst::Kind K, unsigned OpNum, Inst *C, IC.getInst(Inst::Ult, 1, { IC.getConst(llvm::APInt(C->Width, 2)), C }), IC.getInst(Inst::Ne, 1, { IC.getConst(llvm::APInt::getAllOnesValue(C->Width)), C }) }); - + case Inst::SDiv: case Inst::SRem: case Inst::URem: @@ -193,7 +194,7 @@ Inst *getConstConstraint(Inst::Kind K, unsigned OpNum, Inst *C, IC.getInst(Inst::And, 1, { IC.getInst(Inst::Ult, 1, { C, IC.getConst(llvm::APInt::getAllOnesValue(C->Width) - 1) }), IC.getInst(Inst::Ne, 1, { IC.getConst(llvm::APInt(C->Width, 0)), C }) - }); + }); case Inst::Slt: // we don't want: @@ -231,20 
+232,10 @@ Inst *getConstConstraint(Inst::Kind K, unsigned OpNum, Inst *C, IC.getInst(Inst::Ne, 1, { IC.getConst(llvm::APInt::getSignedMinValue(C->Width)), C }) }); - case Inst::Eq: - case Inst::Ne: - case Inst::SAddO: - case Inst::UAddO: - case Inst::SSubO: - case Inst::USubO: - case Inst::SMulO: - case Inst::UMulO: case Inst::Select: // handled elsewhere: 2nd and 3rd arguments can't be same constant - // no constraint - return IC.getConst(llvm::APInt(1, true)); - default: - llvm::report_fatal_error("unmatched: " + (std::string)Inst::getKindName(K)); + // no constraint + return IC.getConst(llvm::APInt(1, true)); } } @@ -256,7 +247,7 @@ void addComplexConstraints(Inst *I, // --x // ~~x // 2 * x / 2 - + // first and second arguments to funnel shift can't both be zero if (I->K == Inst::FShl || I->K == Inst::FShr) { if (ConstSet.find(I->Ops[0]) != ConstSet.end() && @@ -345,6 +336,7 @@ ConstantSynthesis::synthesize(SMTLIBSolver *SMTSolver, auto ConstConstraints = TrueConst; std::set Visited; + visitConstants(Mapping.LHS, Visited, ConstConstraints, ConstSet, IC, AvoidNops); visitConstants(Mapping.RHS, Visited, ConstConstraints, ConstSet, IC, AvoidNops); for (int I = 0; I < MaxTries; ++I) { @@ -409,6 +401,7 @@ ConstantSynthesis::synthesize(SMTLIBSolver *SMTSolver, std::map InstCache; std::map BlockCache; Inst *RHSCopy = getInstCopy(Mapping.RHS, IC, InstCache, BlockCache, &ConstMap, false); + Inst *LHSCopy = getInstCopy(Mapping.LHS, IC, InstCache, BlockCache, &ConstMap, false); std::vector Blocks = getBlocksFromPhis(Mapping.LHS); for (auto Block : Blocks) { @@ -426,7 +419,7 @@ ConstantSynthesis::synthesize(SMTLIBSolver *SMTSolver, std::vector ModelInstsSecondQuery; std::vector ModelValsSecondQuery; - Query = BuildQuery(IC, BPCs, PCs, InstMapping(Mapping.LHS, RHSCopy), + Query = BuildQuery(IC, BPCs, PCs, InstMapping(LHSCopy, RHSCopy), &ModelInstsSecondQuery, 0); if (Query.empty()) diff --git a/lib/Infer/EnumerativeSynthesis.cpp b/lib/Infer/EnumerativeSynthesis.cpp index 
66bce8352..a57701030 100644 --- a/lib/Infer/EnumerativeSynthesis.cpp +++ b/lib/Infer/EnumerativeSynthesis.cpp @@ -33,8 +33,8 @@ extern unsigned DebugLevel; using namespace souper; using namespace llvm; -static const std::vector UnaryOperators = { - Inst::CtPop, Inst::BSwap, Inst::BitReverse, Inst::Cttz, Inst::Ctlz, Inst::Freeze +static std::vector UnaryOperators = { + Inst::CtPop, Inst::BSwap, Inst::BitReverse, Inst::Cttz, Inst::Ctlz }; static const std::vector BinaryOperators = { @@ -91,6 +91,12 @@ namespace { static cl::opt IgnoreCost("souper-enumerative-synthesis-ignore-cost", cl::desc("Ignore cost of RHSes -- just generate them. (default=false)"), cl::init(false)); + static cl::opt SynFreeze("souper-synthesize-freeze", + cl::desc("Generate Freeze (default=true)"), + cl::init(true)); + static cl::opt SynLog("souper-synthesize-log", + cl::desc("Generate LogB (default=false)"), + cl::init(false)); static cl::opt MaxLHSCands("souper-max-lhs-cands", cl::desc("Gather at most this many values from a LHS to use as synthesis inputs (default=8)"), cl::init(8)); @@ -100,6 +106,9 @@ namespace { static cl::opt OnlyInferIN("souper-only-infer-iN", cl::desc("Only infer integer constants (default=false)"), cl::init(false)); + static cl::opt TryShrinkConsts("souper-shrink-consts", + cl::desc("Try to shrink constants (defaults=false)"), + cl::init(false)); } // TODO @@ -642,9 +651,9 @@ std::error_code synthesizeWithAlive(SynthesisContext &SC, std::vector &R } assert (RHS); RHSs.emplace_back(RHS); - if (!SC.CheckAllGuesses) { + if (!SC.CheckAllGuesses) return EC; - } + if (DebugLevel > 3) { llvm::outs() << "; result " << RHSs.size() << ":\n"; ReplacementContext RC; @@ -722,7 +731,7 @@ std::error_code synthesizeWithKLEE(SynthesisContext &SC, std::vector &RH if (!ResultConstMap.empty()) { std::map InstCache; std::map BlockCache; - RHS = getInstCopy(I, SC.IC, InstCache, BlockCache, &ResultConstMap, false); + RHS = getInstCopy(I, SC.IC, InstCache, BlockCache, &ResultConstMap, false, 
false); } else { continue; } @@ -747,22 +756,23 @@ std::error_code synthesizeWithKLEE(SynthesisContext &SC, std::vector &RH RHS = nullptr; } } - - // FIXME shrink constants properly, this is a placeholder where we - // just see if we can replace every constant with zero - if (RHS && !ResultConstMap.empty() && DoubleCheckWithAlive) { - std::map ZeroConstMap; - for (auto it : ResultConstMap) { - auto I = it.first; - ZeroConstMap[I] = llvm::APInt(I->Width, 0); + if (TryShrinkConsts) { + // FIXME shrink constants properly, this is a placeholder where we + // just see if we can replace every constant with zero + // TODO(manasij) : Implement binary search, involve alive only when we find a solution + if (RHS && !ResultConstMap.empty() && DoubleCheckWithAlive) { + std::map ZeroConstMap; + for (auto it : ResultConstMap) { + auto I = it.first; + ZeroConstMap[I] = llvm::APInt(I->Width, 0); + } + std::map InstCache; + std::map BlockCache; + auto newRHS = getInstCopy(I, SC.IC, InstCache, BlockCache, &ZeroConstMap, false, false); + if (isTransformationValid(SC.LHS, newRHS, SC.PCs, SC.BPCs, SC.IC)) + RHS = newRHS; } - std::map InstCache; - std::map BlockCache; - auto newRHS = getInstCopy(I, SC.IC, InstCache, BlockCache, &ZeroConstMap, false); - if (isTransformationValid(SC.LHS, newRHS, SC.PCs, SC.BPCs, SC.IC)) - RHS = newRHS; } - if (RHS) { RHSs.emplace_back(RHS); if (!SC.CheckAllGuesses) @@ -881,3 +891,44 @@ EnumerativeSynthesis::synthesize(SMTLIBSolver *SMTSolver, return EC; } + +EnumerativeSynthesis::EnumerativeSynthesis() { + if (SynFreeze) { + UnaryOperators.push_back(Inst::Freeze); + } + if (SynLog) { + UnaryOperators.push_back(Inst::LogB); + } +} + +std::vector +EnumerativeSynthesis::generateExprs(InstContext &IC, size_t CountLimit, + std::vector Vars, size_t Width) { + MaxNumInstructions = CountLimit; + + std::set Visited; + std::vector PruneFuncs = { [&Visited](Inst *I, std::vector &ReservedInsts) { + return CountPrune(I, ReservedInsts, Visited); + }}; + auto 
PruneCallback = MkPruneFunc(PruneFuncs); + + std::vector Guesses; + + int TooExpensive = CountLimit + 1; + + for (auto I : Vars) { + if (I->Width == Width) + addGuess(I, Width, IC, TooExpensive, Guesses, TooExpensive); + } + + auto Generate = [&Guesses](Inst *Guess) { + Guesses.push_back(Guess); + return true; + }; + + getGuesses(Vars, Width, TooExpensive, IC, nullptr, + nullptr, TooExpensive, PruneCallback, Generate); + + return Guesses; +} + diff --git a/lib/Infer/Interpreter.cpp b/lib/Infer/Interpreter.cpp index b3e3c3f82..4220da1f6 100644 --- a/lib/Infer/Interpreter.cpp +++ b/lib/Infer/Interpreter.cpp @@ -444,6 +444,27 @@ namespace souper { return Args[0]; } + case Inst::LogB: { + return {llvm::APInt(Inst->Width, ARG0.logBase2())}; + } + + case Inst::BitWidth: { + return {llvm::APInt(Inst->Width, Inst->Width)}; + // Is the result always of this width? + } + + case Inst::KnownOnesP: { + auto A = ARG0; + return A & ARG1 == A; + } + case Inst::KnownZerosP: { + auto Z = ARG1; + return ARG0 | ~Z == Z; + } + case Inst::DemandedMask: { + return ARG0 & ARG1; + } + default: llvm::report_fatal_error("unimplemented instruction kind " + std::string(Inst::getKindName(Inst->K)) + @@ -459,6 +480,10 @@ namespace souper { if (Cache.find(Root) != Cache.end()) return Cache[Root]; + if (Root->K == Inst::BitWidth) { + return {llvm::APInt(Root->Width, Root->Width)}; + } + // TODO SmallVector std::vector EvaluatedArgs; for (auto &&I : Root->Ops) @@ -468,4 +493,11 @@ namespace souper { Cache[Root] = Result; return Result; } + + void ConcreteInterpreter::printCache(llvm::raw_ostream &Out) { + for (auto &&KV : Cache) { + Out << KV.first->Name << " = " << KV.second.getValue() << '\n'; + } + } + } diff --git a/lib/Infer/Preconditions.cpp b/lib/Infer/Preconditions.cpp index 3df2bedc2..ef9151a4d 100644 --- a/lib/Infer/Preconditions.cpp +++ b/lib/Infer/Preconditions.cpp @@ -4,10 +4,175 @@ #include using llvm::APInt; +static llvm::cl::opt FixItNoVar("fixit-no-restrict-vars", + 
llvm::cl::desc("Do not restrict input variables, only constants." + "(default=false)"), + llvm::cl::init(false)); + +static llvm::cl::opt GenCR("gencr", + llvm::cl::desc("Generate a CR precondition." + "(default=false)"), + llvm::cl::init(false)); + +static llvm::cl::opt GenKB("genkb", + llvm::cl::desc("Generate a KB precondition." + "(default=true)"), +llvm::cl::init(true)); + + namespace souper { + +std::pair>, +std::vector>> +inferAbstractPreconditions(SynthesisContext &SC, Inst *RHS, + Solver *S, bool &FoundWeakest) { + + std::vector> CRResults; + std::vector> KBResults; + if (GenKB) KBResults = inferAbstractKBPreconditions(SC, RHS, S, FoundWeakest); + if (GenCR) CRResults = inferAbstractCRPreconditions(SC, RHS, S, FoundWeakest); + return std::make_pair(KBResults, CRResults); +} + +std::vector> + inferAbstractCRPreconditions(SynthesisContext &SC, Inst *RHS, + Solver *S, bool &FoundWeakest) { + + InstMapping Mapping(SC.LHS, RHS); + bool Valid; + if (DebugLevel >= 3) { + PrintReplacement(llvm::outs(), SC.BPCs, SC.PCs, Mapping); + } + std::vector> Models; + + if (std::error_code EC = S->isValid(SC.IC, SC.BPCs, SC.PCs, Mapping, Valid, &Models)) { + llvm::errs() << EC.message() << '\n'; + } + std::vector PCCopy = SC.PCs; + if (Valid) { + FoundWeakest = true; + if (DebugLevel > 1) { + llvm::errs() << "Already valid.\n"; + } + return {}; + } + + std::vector Vars; + findVars(Mapping.LHS, Vars); + std::set FilteredVars; + + for (auto Var : Vars) { + std::string NamePrefix = Var->Name; + NamePrefix.resize(4); + if (!FixItNoVar || Var->K != Inst::Var || NamePrefix == "fake") { + FilteredVars.insert(Var); + } + } + + // std::map OriginalState; + + // for (auto V : Vars) { + // OriginalState[V] = V->Range; + // } + + std::vector> Results; + Inst *Precondition = SC.IC.getConst(llvm::APInt(1, true)); + + std::vector> ValidInputs; + // unlike for known bits, this is not guaranteed to terminate quickly. 
+ size_t Iterations = 0; + while (true) { + if (Iterations >= 10000) { + break; + // TODO adjust this threshold + } else { + Iterations++; + } + + std::vector ModelInsts; + std::vector ModelVals; + + if (!ValidInputs.empty()) { + auto LastInputSet = ValidInputs.back(); + // for (auto V : Vars) { + // V->KnownOnes = OriginalState[V].OriginalOne; + // V->KnownZeros = OriginalState[V].OriginalZero; + // } + Inst *NewPre = nullptr; + for (size_t i = 0; i < Vars.size(); ++i) { + auto &I = Vars[i]; + auto W = I->Width; +// auto Zero = SC.IC.getConst(llvm::APInt(W, 0)); + + auto VarConstraint = SC.IC.getInst(Inst::Ne, 1, {I, SC.IC.getConst(LastInputSet[I])}); + + if (NewPre) { + NewPre = SC.IC.getInst(Inst::Or, 1, {NewPre, VarConstraint}); + } else { + NewPre = VarConstraint; + } + + } + + // Do not find an input belonging to a derived abstract set. + if (NewPre) { + Precondition = SC.IC.getInst(Inst::And, 1, {Precondition, NewPre}); + } + } + + // Find one input for which the given transformation is valid + Models.clear(); + std::string Query = BuildQuery(SC.IC, SC.BPCs, PCCopy, Mapping, + &ModelInsts, Precondition, true); + + + S->isSatisfiable(Query, FoundWeakest, ModelInsts.size(), + &ModelVals, SC.Timeout); + + std::map CurrentCE; + if (FoundWeakest) { + for (unsigned J = 0; J < ModelInsts.size(); ++J) { + if (FilteredVars.find(ModelInsts[J]) != FilteredVars.end()) { + CurrentCE[ModelInsts[J]] = ModelVals[J]; + } else { + auto Zero = llvm::APInt(ModelInsts[J]->Width, 0); + CurrentCE[ModelInsts[J]] = Zero; + } + + if (DebugLevel >= 3) { + llvm::outs() << "Starting with : " << ModelVals[J] << "\n"; + } + } + ValidInputs.push_back(CurrentCE); + } else { + if (ValidInputs.empty()) { + if (DebugLevel >= 3) { + llvm::outs() << "Transformation is not valid for any input.\n"; + } + return {}; + } else { + FoundWeakest = true; + if (DebugLevel >= 3) { + llvm::outs() << "Exhausted search space.\n"; + } + break; + } + } + + // Widen CurrentCE into the largest possible CR which 
maintains validity + // How to do this?? + + + + + } + + return {}; +} + std::vector> inferAbstractKBPreconditions(SynthesisContext &SC, Inst *RHS, - SMTLIBSolver *SMTSolver, Solver *S, bool &FoundWeakest) { + Solver *S, bool &FoundWeakest) { InstMapping Mapping(SC.LHS, RHS); bool Valid; if (DebugLevel >= 3) { @@ -20,7 +185,10 @@ std::vector> } std::vector PCCopy = SC.PCs; if (Valid) { - llvm::outs() << "Already valid.\n"; + FoundWeakest = true; + if (DebugLevel > 1) { + llvm::errs() << "Already valid.\n"; + } return {}; } @@ -31,6 +199,15 @@ std::vector> std::vector Vars; findVars(Mapping.LHS, Vars); + std::set FilteredVars; + + for (auto Var : Vars) { + std::string NamePrefix = Var->Name; + NamePrefix.resize(4); + if (!FixItNoVar || Var->K != Inst::Var || NamePrefix == "fake") { + FilteredVars.insert(Var); + } + } std::map OriginalState; @@ -42,7 +219,7 @@ std::vector> std::vector> Results; Inst *Precondition = SC.IC.getConst(llvm::APInt(1, true)); - while (true) { // guaranteed to terminate + while (true) { // guaranteed to terminate in O(Width) if (!Results.empty()) { bool foundNonTop = false;; for (auto R : Results) { @@ -70,49 +247,69 @@ std::vector> V->KnownOnes = OriginalState[V].OriginalOne; V->KnownZeros = OriginalState[V].OriginalZero; } - + Inst *NewPre = nullptr; for (size_t i = 0; i < Vars.size(); ++i) { auto &I = Vars[i]; auto W = I->Width; - auto Zero = SC.IC.getConst(llvm::APInt(W, 0)); +// auto Zero = SC.IC.getConst(llvm::APInt(W, 0)); auto AllOnes = SC.IC.getConst(llvm::APInt::getAllOnesValue(W)); - auto A = SC.IC.getInst(Inst::And, W, {I, SC.IC.getConst(KB[I].One)}); + auto KnownOne = SC.IC.getConst(KB[I].One); + auto KnownZero = SC.IC.getConst(KB[I].Zero); + auto A = SC.IC.getInst(Inst::And, W, {I, KnownOne}); auto B = SC.IC.getInst(Inst::And, W, {SC.IC.getInst(Inst::Xor, W, {I, AllOnes}), - SC.IC.getConst(KB[I].Zero)}); + KnownZero}); - auto New = SC.IC.getInst(Inst::And, 1, - {SC.IC.getInst(Inst::Eq, 1, {A, Zero}), - SC.IC.getInst(Inst::Eq, 
1, {B, Zero})}); + auto VarConstraint = SC.IC.getInst(Inst::Or, 1, + {SC.IC.getInst(Inst::Ne, 1, {A, KnownOne}), + SC.IC.getInst(Inst::Ne, 1, {B, KnownZero})}); - // Do not find an input belonging to a derived abstract set. - Precondition = SC.IC.getInst(Inst::And, 1, {Precondition, New}); + if (NewPre) { + NewPre = SC.IC.getInst(Inst::Or, 1, {NewPre, VarConstraint}); + } else { + NewPre = VarConstraint; + } + + } + // Do not find an input belonging to a derived abstract set. + if (NewPre) { + Precondition = SC.IC.getInst(Inst::And, 1, {Precondition, NewPre}); } } + // Find one input for which the given transformation is valid Models.clear(); std::string Query = BuildQuery(SC.IC, SC.BPCs, PCCopy, Mapping, &ModelInsts, Precondition, true); - SMTSolver->isSatisfiable(Query, FoundWeakest, ModelInsts.size(), + S->isSatisfiable(Query, FoundWeakest, ModelInsts.size(), &ModelVals, SC.Timeout); std::map Known; if (FoundWeakest) { for (unsigned J = 0; J < ModelInsts.size(); ++J) { llvm::KnownBits KBCurrent(ModelInsts[J]->Width); - Known[ModelInsts[J]].One = ModelVals[J]; + if (FilteredVars.find(ModelInsts[J]) != FilteredVars.end()) { + Known[ModelInsts[J]].One = ModelVals[J]; + Known[ModelInsts[J]].Zero = ~ModelVals[J]; + } else { + auto Zero = llvm::APInt(ModelInsts[J]->Width, 0); + Known[ModelInsts[J]].One = Zero; + Known[ModelInsts[J]].Zero = Zero; + } + if (DebugLevel >= 3) { llvm::outs() << "Starting with : " << ModelVals[J] << "\n"; } - Known[ModelInsts[J]].Zero = ~ModelVals[J]; } } else { if (Results.empty()) { - llvm::outs() << "Transformation is not valid for any input.\n"; + if (DebugLevel >= 3) { + llvm::outs() << "Transformation is not valid for any input.\n"; + } return {}; } else { FoundWeakest = true; @@ -122,6 +319,7 @@ std::vector> break; } } + for (unsigned J = 0; J < Vars.size(); ++J) { Vars[J]->KnownZeros = Known[Vars[J]].Zero; Vars[J]->KnownOnes = Known[Vars[J]].One; diff --git a/lib/Infer/Pruning.cpp b/lib/Infer/Pruning.cpp index b5e55ff7c..fdfb2dc0a 
100644 --- a/lib/Infer/Pruning.cpp +++ b/lib/Infer/Pruning.cpp @@ -37,7 +37,7 @@ namespace { static llvm::cl::opt EnableFB("souper-dataflow-pruning-fb", llvm::cl::desc("Prune with forced-bits analysis (default=true)"), - llvm::cl::init(true)); + llvm::cl::init(false)); static llvm::cl::opt EnableRB("souper-dataflow-pruning-rb", llvm::cl::desc("Prune with required-bits analysis (default=true)"), @@ -45,7 +45,7 @@ namespace { static llvm::cl::opt EnableBB("souper-dataflow-pruning-bb", llvm::cl::desc("Prune with bivalent-bits analysis (default=true)"), - llvm::cl::init(true)); + llvm::cl::init(false)); } namespace souper { @@ -294,7 +294,7 @@ bool PruningManager::isInfeasible(souper::Inst *RHS, if (C.hasValue()) { auto Val = C.getValue(); if (StatsLevel > 2) - llvm::errs() << " LHS value = " << Val << "\n"; + llvm::errs() << " LHS value = " << Val <<" - " < 2) @@ -414,7 +414,12 @@ bool PruningManager::isInfeasible(souper::Inst *RHS, } else { auto RHSV = ConcreteInterpreters[I].evaluateInst(RHS); if (RHSV.hasValue()) { - if (Val != RHSV.getValue()) { + auto RVal = RHSV.getValue(); + if (SC.LHS->DemandedBits != 0) { + Val &= SC.LHS->DemandedBits; + RVal &= SC.LHS->DemandedBits; + } + if (Val != RVal) { if (StatsLevel > 2) { llvm::errs() << " RHS value = " << RHSV.getValue() << "\n"; llvm::errs() << " pruned using concrete interpreter!\n"; @@ -600,7 +605,8 @@ void PruningManager::init() { }; } - ConcreteInterpreter BlankCI; + ValueCache C; + ConcreteInterpreter BlankCI(C); LHSKnownBitsNoSpec = KnownBitsAnalysis().findKnownBits(SC.LHS, BlankCI, false); LHSMustDemandedBits = MustDemandedBitsAnalysis().findMustDemandedBits(SC.LHS); improveMustDemandedBits(LHSMustDemandedBits); @@ -681,28 +687,38 @@ bool PruningManager::isInputValid(ValueCache &Cache) { void PruningManager::improveMustDemandedBits(InputVarInfo &IVI) { for (size_t i = 0; i < InputVals.size(); ++i) { - for (size_t j = 0; j < InputVals.size(); ++j) { - for (auto &Pair : IVI) { - auto Var = Pair.first; - auto 
&I1 = InputVals[i][Var]; - auto &I2 = InputVals[j][Var]; - if (I1.hasValue() && I1.hasValue() && - I1.getValue() != I2.getValue()) { - auto &MDB = Pair.second; - for (size_t k = 0; k < Var->Width; ++k) { - if (!MDB[k] && I1.getValue()[k] != I2.getValue()[k]) { - auto V1 = ConcreteInterpreters[i].evaluateInst(SC.LHS); - auto V2 = ConcreteInterpreters[j].evaluateInst(SC.LHS); - if (V1.hasValue() && V2.hasValue() && - V1.getValue() != V2.getValue()) { - MDB.setBit(k); - // If two input values of a variable differing in the - // k'th bit can produce differing outputs, the k'th - // is required/must demanded/important. - } - } - } + auto VC = InputVals[i]; + for (auto &Pair : IVI) { + auto Var = Pair.first; + + for (size_t k = 0; k < Var->Width; ++k) { + llvm::APInt Val = VC[Var].getValue(); + auto Val2 = Val; + if (Val[k]) { + Val2.clearBit(k); + } else { + Val2.setBit(k); + } + + ValueCache VC2 = VC; + VC2[Var] = EvalValue(Val2); + auto &MDB = Pair.second; + if (!isInputValid(VC2) || MDB[k]) { + continue; } + + auto V1 = ConcreteInterpreters[i].evaluateInst(SC.LHS); + auto V2 = ConcreteInterpreter(SC.LHS, VC2).evaluateInst(SC.LHS); + + if (V1.hasValue() && V2.hasValue() && + V1.getValue() != V2.getValue()) { + MDB.setBit(k); + // If two input values of a variable differing in the + // k'th bit BUT EQUAL IN ALL OTHER BITS can produce + // differing outputs, the k'th is required/must + // demanded/important. 
+ } + } } } diff --git a/lib/Infer/SynthUtils.cpp b/lib/Infer/SynthUtils.cpp new file mode 100644 index 000000000..bbc0fa446 --- /dev/null +++ b/lib/Infer/SynthUtils.cpp @@ -0,0 +1,348 @@ +#include "souper/Infer/SynthUtils.h" +#include "souper/Infer/Pruning.h" + +namespace souper { + +Inst *Replace(Inst *R, InstContext &IC, std::map &M) { + std::map BlockCache; + std::map ConstMap; + return getInstCopy(R, IC, M, BlockCache, &ConstMap, false); +} + +ParsedReplacement Replace(ParsedReplacement I, InstContext &IC, + std::map &M) { + std::map BlockCache; + std::map ConstMap; + + I.Mapping.LHS = getInstCopy(I.Mapping.LHS, IC, M, BlockCache, &ConstMap, false); + I.Mapping.RHS = getInstCopy(I.Mapping.RHS, IC, M, BlockCache, &ConstMap, false); + + for (auto &PC : I.PCs) { + PC.LHS = getInstCopy(PC.LHS, IC, M, BlockCache, &ConstMap, false, false); + PC.RHS = getInstCopy(PC.RHS, IC, M, BlockCache, &ConstMap, false, false); + } + + return I; +} + +Inst *Clone(Inst *R, InstContext &IC) { + std::map BlockCache; + std::map ConstMap; + std::map InstCache; + return getInstCopy(R, IC, InstCache, BlockCache, &ConstMap, true, false); +} + +InstMapping Clone(InstMapping In, InstContext &IC) { + std::map BlockCache; + std::map ConstMap; + std::map InstCache; + InstMapping Out; + Out.LHS = getInstCopy(In.LHS, IC, InstCache, BlockCache, &ConstMap, true, false); + Out.RHS = getInstCopy(In.RHS, IC, InstCache, BlockCache, &ConstMap, true, false); + return Out; +} + +ParsedReplacement Clone(ParsedReplacement In, InstContext &IC) { + std::map BlockCache; + std::map ConstMap; + std::map InstCache; + std::vector RHSVars; + findVars(In.Mapping.RHS, RHSVars); + In.Mapping.LHS = getInstCopy(In.Mapping.LHS, IC, InstCache, BlockCache, &ConstMap, true, false); + In.Mapping.RHS = getInstCopy(In.Mapping.RHS, IC, InstCache, BlockCache, &ConstMap, true, false); + for (auto &PC : In.PCs) { + PC.LHS = getInstCopy(PC.LHS, IC, InstCache, BlockCache, &ConstMap, false, false); + PC.RHS = getInstCopy(PC.RHS, 
IC, InstCache, BlockCache, &ConstMap, false, false); + } + + return In; +} + +// bool IsValid(ParsedReplacement Input, InstContext &IC, Solver *S) { +// if (Input.PCs.empty()) { +// SynthesisContext SC{IC, S->getSMTLIBSolver(), Input.Mapping.LHS, nullptr, +// Input.PCs,Input.BPCs, false, 15}; +// std::vector Vars; +// findVars(Input.Mapping.LHS, Vars); + +// PruningManager Pruner(SC, Vars, 0); +// Pruner.init(); + +// if (Pruner.isInfeasible(Input.Mapping.RHS, 0)) { +// return false; +// } +// } + +// bool IsValid; +// if (auto EC = S->isValid(IC, Input.BPCs, Input.PCs, Input.Mapping, IsValid, nullptr)) { +// llvm::errs() << EC.message() << '\n'; +// } +// return IsValid; +// } + +//std::map ConstantSynthesis + +// Also Synthesizes given constants +// Returns clone if verified, nullptrs if not +std::optional Verify(ParsedReplacement Input, InstContext &IC, Solver *S) { + + // if (Input.PCs.empty()) { + // SynthesisContext SC{IC, S->getSMTLIBSolver(), Input.Mapping.LHS, nullptr, + // Input.PCs,Input.BPCs, false, 15}; + // std::vector Vars; + // findVars(Input.Mapping.LHS, Vars); + + // PruningManager Pruner(SC, Vars, 0); + // Pruner.init(); + + // if (Pruner.isInfeasible(Input.Mapping.RHS, 0)) { + // Input.Mapping.LHS = nullptr; + // Input.Mapping.RHS = nullptr; + // return Input; + // } + // } + // Input.print(llvm::errs(), true); + Input = Clone(Input, IC); + std::set ConstSet; + souper::getConstants(Input.Mapping.RHS, ConstSet); + souper::getConstants(Input.Mapping.LHS, ConstSet); + if (!ConstSet.empty()) { + std::map ResultConstMap; + ConstantSynthesis CS; + auto SMTSolver = S->getSMTLIBSolver(); + + auto EC = CS.synthesize(SMTSolver, Input.BPCs, Input.PCs, + Input.Mapping, ConstSet, + ResultConstMap, IC, /*MaxTries=*/30, 10, + /*AvoidNops=*/true); + if (!ResultConstMap.empty()) { + std::map InstCache; + std::map BlockCache; + auto LHSCopy = getInstCopy(Input.Mapping.LHS, IC, InstCache, BlockCache, &ResultConstMap, true); + auto RHS = 
getInstCopy(Input.Mapping.RHS, IC, InstCache, BlockCache, &ResultConstMap, true); + Input.Mapping = InstMapping(LHSCopy, RHS); + for (auto &PC : Input.PCs) { + PC.LHS = getInstCopy(PC.LHS, IC, InstCache, BlockCache, &ResultConstMap, true); + PC.RHS = getInstCopy(PC.RHS, IC, InstCache, BlockCache, &ResultConstMap, true); + } + return Input; + } else { + if (DebugLevel > 2) { + llvm::errs() << "Constant Synthesis ((no Dataflow Preconditions)) failed. \n"; + } + } + return std::nullopt; + } + std::vector> Models; + bool IsValid; + if (auto EC = S->isValid(IC, Input.BPCs, Input.PCs, Input.Mapping, IsValid, &Models)) { + llvm::errs() << EC.message() << '\n'; + } + if (IsValid) { + return Input; + } else { + static int C = 0; +// llvm::errs() << "C " << C++ << '\n'; + return std::nullopt; + // TODO: Better failure indication? + } +} + +std::map findOneConstSet(ParsedReplacement Input, const std::set &SymCS, InstContext &IC, Solver *S) { + + std::map InstCache; + + std::set SynthCS; + + size_t cid = 0; + for (auto C : SymCS) { + InstCache[C] = IC.createSynthesisConstant(C->Width, cid++); + SynthCS.insert(InstCache[C]); + } + Input = Replace(Input, IC, InstCache); + + std::map ResultConstMap; + ConstantSynthesis CS; + auto EC = CS.synthesize(S->getSMTLIBSolver(), Input.BPCs, Input.PCs, + Input.Mapping, SynthCS, + ResultConstMap, IC, /*MaxTries=*/30, 10, + /*AvoidNops=*/true); + + std::map Result; + if (!ResultConstMap.empty()) { + for (auto C : SymCS) { + Result[C] = ResultConstMap[InstCache[C]]; + } + } + + return Result; + +} + +std::vector> findValidConsts(ParsedReplacement Input, const std::set &Insts, InstContext &IC, Solver *S, size_t MaxCount = 1) { + + // FIXME: Ignores Count + std::vector> Results; + + Inst *T = IC.getConst(llvm::APInt(1, 1)); // true + Inst *F = IC.getConst(llvm::APInt(1, 0)); // false + + while (MaxCount-- ) { + auto &&Result = findOneConstSet(Input, Insts, IC, S); + if (Result.empty()) { + break; + } else { + Results.push_back(Result); + for 
(auto I : Result) { + Input.PCs.push_back({Builder(I.first, IC).Ne(I.second)(), T}); + } + } + } + + return Results; +} + +// Find a single counterexample +ValueCache GetCEX(const ParsedReplacement &Input, InstContext &IC, Solver *S) { + std::vector Vars; + findVars(Input.Mapping.LHS, Vars); + findVars(Input.Mapping.RHS, Vars); + std::vector> Models; + bool IsValid; + if (auto EC = S->isValid(IC, Input.BPCs, Input.PCs, Input.Mapping, IsValid, &Models)) { + llvm::errs() << EC.message() << '\n'; + } + if (IsValid) { + return ValueCache(); + } else { + ValueCache Result; + for (auto &V : Vars) { + for (auto &M : Models) { + if (M.first == V) { + Result[V] = M.second; + } + } + } + return Result; + } +} + +// ParsedReplacement MakeDummyConstexprs(ParsedReplacement Input, InstContext &IC) { +// std::map InstCache; + +// std::vector Stack{Input.Mapping.RHS}; + +// std::set Visited; + +// size_t DummyConstID = 0; + +// // DFS to find all RHS constants +// while (!Stack.empty()) { +// auto I = Stack.back(); +// Stack.pop_back(); +// Visited.insert(I); + +// if (I->K == Inst::Const) { +// if (InstCache.find(I) == InstCache.end()) { +// InstCache[I] = IC.createVar(I->Width, "dummy" + std::to_string(DummyConstID++)); +// } +// } else { +// for (auto &&Op : I->Ops) { +// if (Visited.find(Op) == Visited.end()) { +// Stack.push_back(Op); +// } +// } +// } +// } +// if (!InstCache.empty()) { +// Input.Mapping.RHS = Replace(Input.Mapping.RHS, IC, InstCache); +// } +// return Input; +// } + +bool hasRHSConsts(ParsedReplacement Input) { + std::vector Stack{Input.Mapping.RHS}; + + std::set Visited; + + // DFS to find all RHS constants + while (!Stack.empty()) { + auto I = Stack.back(); + Stack.pop_back(); + Visited.insert(I); + + if (I->K == Inst::Const) { + return true; + } else { + for (auto &&Op : I->Ops) { + if (Visited.find(Op) == Visited.end()) { + Stack.push_back(Op); + } + } + } + } + return false; +} + +std::vector GetMultipleCEX(ParsedReplacement Input, InstContext &IC, 
Solver *S, size_t MaxCount = 2) { + // auto Input = MakeDummyConstexprs(Original, IC); + + // Is there a way to get a CEX when there are RHS constants? + + if (hasRHSConsts(Input)) { + return {}; + } + + std::vector Results; + while (MaxCount--) { + auto &&Result = GetCEX(Input, IC, S); + if (Result.empty()) { + return Results; + } + for (auto &&CEX : Result) { + if (!CEX.second.hasValue()) { + return Results; + } + } + Results.push_back(Result); + for (auto &&CEX : Result) { + Input.PCs.push_back({Builder(IC, CEX.first).Ne(CEX.second.getValue())(), IC.getConst(llvm::APInt(1, 1))}); + } + } + return Results; +} + +void tagConstExprs(Inst *I, std::set &Set) { + if (I->K == Inst::Const || (I->K == Inst::Var && I->Name.starts_with("sym"))) { + Set.insert(I); + } else { + for (auto Op : I->Ops) { + tagConstExprs(Op, Set); + } + } + + if (I->Ops.size() > 0) { + bool foundNonConst = false; + for (auto Op : I->Ops) { + if (Set.find(Op) == Set.end()) { + foundNonConst = true; + break; + } + } + if (!foundNonConst) { + Set.insert(I); + } + } +} + +size_t constAwareCost(Inst *I) { + std::set ConstExprs; + tagConstExprs(I, ConstExprs); + return souper::cost(I, false, ConstExprs); +} + +int profit(const ParsedReplacement &P) { + return constAwareCost(P.Mapping.LHS) - constAwareCost(P.Mapping.RHS); +} + +} diff --git a/lib/Inst/Inst.cpp b/lib/Inst/Inst.cpp index 5ee10945d..4d2382f12 100644 --- a/lib/Inst/Inst.cpp +++ b/lib/Inst/Inst.cpp @@ -20,6 +20,7 @@ #include #include +#include using namespace souper; @@ -116,6 +117,9 @@ std::string ReplacementContext::printInst(Inst *I, llvm::raw_ostream &Out, std::string ReplacementContext::printInstImpl(Inst *I, llvm::raw_ostream &Out, bool printNames, Inst *OrigI) { + if (printNames && I->Name.length() != 0 && std::isdigit(I->Name[0])) { + I->Name = "v" + I->Name; + } std::string Str; llvm::raw_string_ostream SS(Str); @@ -168,7 +172,12 @@ std::string ReplacementContext::printInstImpl(Inst *I, llvm::raw_ostream &Out, } } - std::string 
InstName = std::to_string(InstNames.size() + BlockNames.size()); + std::string InstName; + if (printNames && !I->Name.empty()) { + InstName = I->Name; + } else { + InstName = std::to_string(InstNames.size() + BlockNames.size()); + } assert(InstNames.find(I) == InstNames.end()); assert(NameToBlock.find(InstName) == NameToBlock.end()); setInst(InstName, I); @@ -189,6 +198,7 @@ std::string ReplacementContext::printInstImpl(Inst *I, llvm::raw_ostream &Out, if (I->KnownZeros.getBoolValue() || I->KnownOnes.getBoolValue()) Out << " (knownBits=" << Inst::getKnownBitsString(I->KnownZeros, I->KnownOnes) << ")"; + if (I->NonNegative) Out << " (nonNegative)"; if (I->Negative) @@ -418,6 +428,18 @@ const char *Inst::getKindName(Kind K) { return "cttz"; case Ctlz: return "ctlz"; + case LogB: + return "logb"; + case BitWidth: + return "width"; + case KnownOnesP: + return "knownones"; + case KnownZerosP: + return "knownzeros"; + case RangeP: + return "range"; + case DemandedMask: + return "demandedmask"; case FShl: return "fshl"; case FShr: @@ -512,6 +534,12 @@ Inst::Kind Inst::getKind(std::string Name) { .Case("bitreverse", Inst::BitReverse) .Case("cttz", Inst::Cttz) .Case("ctlz", Inst::Ctlz) + .Case("logb", Inst::LogB) + .Case("width", Inst::BitWidth) + .Case("knownones", Inst::KnownOnesP) + .Case("knownzeros", Inst::KnownZerosP) + .Case("range", Inst::RangeP) + .Case("demandedmask", Inst::DemandedMask) .Case("fshl", Inst::FShl) .Case("fshr", Inst::FShr) .Case("sadd.with.overflow", Inst::SAddWithOverflow) @@ -787,6 +815,21 @@ std::vector InstContext::getVariables() const { return AllVariables; }; +std::vector InstContext::getVariablesFor(Inst *I) const { + std::vector AllVariables; + findVars(I, AllVariables); + + std::sort(AllVariables.begin(), AllVariables.end(), + [](const Inst *LHS, const Inst *RHS) { + if (LHS->Width == RHS->Width) + return LHS->Number < RHS->Number; + else + return LHS->Width < RHS->Width; + }); + + return AllVariables; +}; + bool 
Inst::isCommutative(Inst::Kind K) { switch (K) { case Add: @@ -951,12 +994,11 @@ static int costHelper(Inst *I, Inst *Root, std::set &Visited, return Cost; } -int souper::cost(Inst *I, bool IgnoreDepsWithExternalUses) { - std::set Visited; +int souper::cost(Inst *I, bool IgnoreDepsWithExternalUses, std::set Ignore) { + std::set Visited = Ignore; return costHelper(I, I, Visited, IgnoreDepsWithExternalUses); } - int souper::countHelper(Inst *I, std::set &Visited) { if (!Visited.insert(I).second) return 0; @@ -981,8 +1023,8 @@ int souper::instCount(Inst *I) { return countHelper(I, Visited); } -int souper::benefit(Inst *LHS, Inst *RHS) { - return cost(LHS, /*IgnoreDepsWithExternalUses=*/true) - cost(RHS); +int souper::benefit(Inst *LHS, Inst *RHS, bool IgnoreDepsWithExternalUses) { + return cost(LHS, IgnoreDepsWithExternalUses) - cost(RHS); } void souper::PrintReplacement(llvm::raw_ostream &Out, @@ -1047,7 +1089,7 @@ std::string souper::GetReplacementLHSString(const BlockPCs &BPCs, Inst *LHS, ReplacementContext &Context, bool printNames) { std::string Str; llvm::raw_string_ostream SS(Str); - PrintReplacementLHS(SS, BPCs, PCs, LHS, Context); + PrintReplacementLHS(SS, BPCs, PCs, LHS, Context, printNames); return SS.str(); } @@ -1125,7 +1167,7 @@ void souper::findInsts(Inst *Root, std::vector &Insts, std::function &Visited, std::set &ConstSet) { - if (I->K == Inst::Var && I->SynthesisConstID != 0) { + if (I->K == Inst::Var && (I->SynthesisConstID != 0 || I->Name.starts_with("reserved"))) { ConstSet.insert(I); } else { if (Visited.insert(I).second) @@ -1172,7 +1214,7 @@ Inst *souper::getInstCopy(Inst *I, InstContext &IC, std::map &InstCache, std::map &BlockCache, std::map *ConstMap, - bool CloneVars) { + bool CloneVars, bool CloneBlocks) { if (InstCache.count(I)) return InstCache.at(I); @@ -1197,21 +1239,26 @@ Inst *souper::getInstCopy(Inst *I, InstContext &IC, } } if (!Copy) { - if (CloneVars && I->SynthesisConstID == 0) + if (CloneVars && I->SynthesisConstID == 0) { Copy 
= IC.createVar(I->Width, I->Name, I->Range, I->KnownZeros, I->KnownOnes, I->NonZero, I->NonNegative, I->PowOfTwo, I->Negative, I->NumSignBits, I->DemandedBits, I->SynthesisConstID); + } else { Copy = I; } } } else if (I->K == Inst::Phi) { if (!BlockCache.count(I->B)) { - auto BlockCopy = IC.createBlock(I->B->Preds); - BlockCache[I->B] = BlockCopy; - Copy = IC.getPhi(BlockCopy, Ops, I->DemandedBits); + if (CloneBlocks) { + auto BlockCopy = IC.createBlock(I->B->Preds); + BlockCache[I->B] = BlockCopy; + Copy = IC.getPhi(BlockCopy, Ops, I->DemandedBits); + } else { + Copy = IC.getPhi(I->B, Ops, I->DemandedBits); + } } else { Copy = IC.getPhi(BlockCache.at(I->B), Ops, I->DemandedBits); } @@ -1222,6 +1269,7 @@ Inst *souper::getInstCopy(Inst *I, InstContext &IC, } assert(Copy); InstCache[I] = Copy; + Copy->Name = I->Name; return Copy; } diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp index f53e5791f..5be83e303 100644 --- a/lib/Parser/Parser.cpp +++ b/lib/Parser/Parser.cpp @@ -544,6 +544,9 @@ bool Parser::typeCheckInst(Inst::Kind IK, unsigned &Width, case Inst::UAddSat: case Inst::SSubSat: case Inst::USubSat: + case Inst::KnownOnesP: + case Inst::KnownZerosP: + case Inst::DemandedMask: MinOps = MaxOps = 2; break; @@ -590,11 +593,14 @@ bool Parser::typeCheckInst(Inst::Kind IK, unsigned &Width, case Inst::BitReverse: case Inst::Cttz: case Inst::Ctlz: + case Inst::LogB: + case Inst::BitWidth: case Inst::Freeze: MaxOps = MinOps = 1; break; case Inst::FShl: case Inst::FShr: + case Inst::RangeP: MaxOps = MinOps = 3; break; @@ -658,6 +664,9 @@ bool Parser::typeCheckInst(Inst::Kind IK, unsigned &Width, case Inst::Slt: case Inst::Ule: case Inst::Sle: + case Inst::KnownOnesP: + case Inst::KnownZerosP: + case Inst::RangeP: ExpectedWidth = 1; break; diff --git a/lib/Pass/Pass.cpp b/lib/Pass/Pass.cpp index 997f5299e..785f07c63 100644 --- a/lib/Pass/Pass.cpp +++ b/lib/Pass/Pass.cpp @@ -18,8 +18,10 @@ #include "llvm/Analysis/DemandedBits.h" #include 
"llvm/Analysis/LazyValueInfo.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/CodeGen/UnreachableBlockElim.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" @@ -33,6 +35,7 @@ #include "llvm/IR/Verifier.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Transforms/Scalar/ADCE.h" #include "llvm/Transforms/Scalar/DCE.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -197,6 +200,29 @@ struct SouperPass : public ModulePass { .getValue(I); } + bool hasNewPhi(InstMapping Cand) { + auto PhiCheck = [](Inst *I) {return I->K == Inst::Phi;}; + std::vector Insts; + + findInsts(Cand.LHS, Insts, PhiCheck); + + std::set LHSPhis; + for (auto &I : Insts) { + LHSPhis.insert(I); + } + Insts.clear(); + + findInsts(Cand.RHS, Insts, PhiCheck); + + for (auto &&I : Insts) { + if (LHSPhis.find(I) == LHSPhis.end()) { + return true; + } + } + + return false; + } + bool runOnFunction(Function *F) { std::string FunctionName; if (F->hasLocalLinkage()) { @@ -233,6 +259,17 @@ struct SouperPass : public ModulePass { if (!TLI) report_fatal_error("getTLI() failed"); + // Run UnreachableBlockElim and ADCE locally + // TODO: In the long run, switch this tool to the new pass manager. 
+ FunctionPassManager FM; + FunctionAnalysisManager FAM; + FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); + FAM.registerPass([&] { return DominatorTreeAnalysis(); }); + FAM.registerPass([&] { return PostDominatorTreeAnalysis(); }); + FM.addPass(UnreachableBlockElimPass()); + FM.addPass(ADCEPass()); + FM.run(*F, FAM); + FunctionCandidateSet CS = ExtractCandidatesFromPass(F, LI, DB, LVI, SE, TLI, IC, EBC); if (DebugLevel > 3) @@ -277,7 +314,12 @@ struct SouperPass : public ModulePass { EC == std::errc::value_too_large) { continue; } else { - report_fatal_error("Unable to query solver: " + EC.message() + "\n"); + llvm::errs() << "[FIXME: Crash commented out]\nUnable to query solver: " + EC.message() + "\n"; + continue; + // TODO: This is a temporary workaround to suppress a protocol error which is encountered + // once in SPEC 2017. This workaround does not have a negative effect other than maybe + // missing one potential transformation. + //report_fatal_error("Unable to query solver: " + EC.message() + "\n"); } } if (RHSs.empty()) @@ -285,8 +327,12 @@ struct SouperPass : public ModulePass { Cand.Mapping.RHS = RHSs.front(); + if (hasNewPhi(Cand.Mapping)) { + continue; + } + Instruction *I = Cand.Origin; - assert(Cand.Mapping.LHS->hasOrigin(I)); + assert(Cand.Mapping.LHS->K == Inst::Const || Cand.Mapping.LHS->hasOrigin(I)); IRBuilder<> Builder(I); Value *NewVal = getValue(Cand.Mapping.RHS, I, EBC, DT, diff --git a/lib/Tool/CandidateMapUtils.cpp b/lib/Tool/CandidateMapUtils.cpp index df3b3fa7d..b1787eb9f 100644 --- a/lib/Tool/CandidateMapUtils.cpp +++ b/lib/Tool/CandidateMapUtils.cpp @@ -14,7 +14,7 @@ #include "souper/Tool/CandidateMapUtils.h" #include "souper/Util/DfaUtils.h" - +#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/LLVMContext.h" @@ -24,13 +24,124 @@ #include "llvm/Support/raw_ostream.h" #include "souper/KVStore/KVStore.h" #include "souper/SMTLIB2/Solver.h" - +#include
"souper/Infer/SynthUtils.h" +#include "llvm/Transforms/Scalar/ADCE.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" void souper::AddToCandidateMap(CandidateMap &M, const CandidateReplacement &CR) { M.emplace_back(CR); } +// limited implementation to filter out specific buggy harvests +bool isWellTyped(souper::Inst *I) { + if (I->K == souper::Inst::Kind::Select) { + return I->Ops[0]->Width == 1 + && (I->Ops[1]->Width == I->Ops[2]->Width) + && isWellTyped(I->Ops[1]) + && isWellTyped(I->Ops[2]); + } + + if (I->Ops.size() == 2) { + return (I->Ops[0]->Width == I->Ops[1]->Width) + && isWellTyped(I->Ops[0]) + && isWellTyped(I->Ops[1]); + } + return true; +} + +bool isProfitable(souper::ParsedReplacement &R) { + if (R.Mapping.LHS->Width == 1 || R.Mapping.RHS->K == souper::Inst::Select) { + return true; + // TODO: Improved cost model for Select + } else { + return souper::benefit(R.Mapping.LHS, R.Mapping.RHS, false) >= 0; + } +} + +void souper::HarvestAndPrintOpts(InstContext &IC, ExprBuilderContext &EBC, llvm::Module *M, Solver *S) { + legacy::FunctionPassManager P(M); + P.add(createInstructionCombiningPass()); + P.doInitialization(); + for (auto &F : *M) { + std::vector<Inst *> LHSs, RHSs; + FunctionCandidateSet CS1 = ExtractCandidates(&F, IC, EBC); + for (auto &B : CS1.Blocks) { + for (auto &R : B->Replacements) { + LHSs.push_back(R.Mapping.LHS); + } + } + +// F.dump(); + P.run(F); +// F.dump(); + + FunctionCandidateSet CS2 = ExtractCandidates(&F, IC, EBC); + for (auto &B : CS2.Blocks) { + for (auto &R : B->Replacements) { + RHSs.push_back(R.Mapping.LHS); + } + } + + for (auto LHS : LHSs) { + std::vector<Inst *> LHSVars; + findVars(LHS, LHSVars); + std::set<Inst *> LHSVarSet; + for (auto V : LHSVars) { + LHSVarSet.insert(V); + } + + if (!isWellTyped(LHS)) { + continue; + } + + for (auto RHS : RHSs) { + if (!isWellTyped(RHS)) { + continue; + } + + if (LHS != RHS && LHS->Width == RHS->Width) { + if ((RHS->K == Inst::ZExt || RHS->K == Inst::Freeze) + && RHS->Ops[0] == LHS) { + continue; +
} + + std::vector<Inst *> RHSVars; + findVars(RHS, RHSVars); + std::set<Inst *> RHSVarSet; + for (auto RV : RHSVars) { + if (LHSVarSet.find(RV) == LHSVarSet.end()) { + continue; + } + RHSVarSet.insert(RV); + } + if (LHSVarSet.size() != RHSVarSet.size()) { + continue; + } + ParsedReplacement Rep; + Rep.Mapping.LHS = LHS; + Rep.Mapping.RHS = RHS; + +// Rep.print(llvm::outs(), true); + + if (Verify(Rep, IC, S)) { + if (isProfitable(Rep)) { + BlockPCs BPCs; + std::vector<InstMapping> PCs; + ReplacementContext RC; + souper::PrintReplacementLHS(llvm::outs(), BPCs, PCs, Rep.Mapping.LHS, RC, true); + souper::PrintReplacementRHS(llvm::outs(), Rep.Mapping.RHS, RC, true); + llvm::outs() << "\n"; + llvm::outs().flush(); + } + } + } + } + } + } + P.doFinalization(); +} + void souper::AddModuleToCandidateMap(InstContext &IC, ExprBuilderContext &EBC, CandidateMap &CandMap, llvm::Module *M) { for (auto &F : *M) { @@ -41,6 +152,7 @@ void souper::AddModuleToCandidateMap(InstContext &IC, ExprBuilderContext &EBC, } } } + } namespace souper { diff --git a/test/Codegen/argument-order-XFAIL.opt b/test/Codegen/argument-order.opt similarity index 71% rename from test/Codegen/argument-order-XFAIL.opt rename to test/Codegen/argument-order.opt index c11b5c7ba..8a5b616a6 100644 --- a/test/Codegen/argument-order-XFAIL.opt +++ b/test/Codegen/argument-order.opt @@ -10,10 +10,10 @@ result %4 ; CHECK: ; ModuleID = 'souper.ll' ; CHECK-NEXT: source_filename = "souper.ll" -; CHECK: define i8 @fun(i16 %0, i4 %1) { +; CHECK: define i8 @fun(i4 %0, i16 %1) { ; CHECK-NEXT: entry: -; CHECK-NEXT: %2 = zext i4 %1 to i8 -; CHECK-NEXT: %3 = trunc i16 %0 to i8 +; CHECK-NEXT: %2 = zext i4 %0 to i8 +; CHECK-NEXT: %3 = trunc i16 %1 to i8 ; CHECK-NEXT: %4 = and i8 %2, %3 ; CHECK-NEXT: ret i8 %4 ; CHECK-NEXT: } diff --git a/test/Codegen/inst-bitreverse.opt b/test/Codegen/inst-bitreverse.opt index 71e2387ed..8ab07c0e1 100644 --- a/test/Codegen/inst-bitreverse.opt +++ b/test/Codegen/inst-bitreverse.opt @@ -10,7 +10,8 @@ result %1 ; CHECK-NEXT: ret
i8 %1 ; CHECK-NEXT: } -; CHECK: ; Function Attrs: nounwind readnone speculatable +; CHECK: ; Function Attrs: nofree nosync nounwind readnone speculatable willreturn ; CHECK-NEXT: declare i8 @llvm.bitreverse.i8(i8) #0 -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } +; CHECK: attributes #0 = { nofree nosync nounwind readnone speculatable willreturn } + diff --git a/test/Codegen/inst-bswap.opt b/test/Codegen/inst-bswap.opt index 23d6dbc76..4446a6482 100644 --- a/test/Codegen/inst-bswap.opt +++ b/test/Codegen/inst-bswap.opt @@ -10,7 +10,7 @@ result %1 ; CHECK-NEXT: ret i16 %1 ; CHECK-NEXT: } -; CHECK: ; Function Attrs: nounwind readnone speculatable +; CHECK: ; Function Attrs: nofree nosync nounwind readnone speculatable willreturn ; CHECK-NEXT: declare i16 @llvm.bswap.i16(i16) #0 -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } +; CHECK: attributes #0 = { nofree nosync nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-ctlz.opt b/test/Codegen/inst-ctlz.opt index fbef3443b..91009c007 100644 --- a/test/Codegen/inst-ctlz.opt +++ b/test/Codegen/inst-ctlz.opt @@ -10,7 +10,7 @@ result %1 ; CHECK-NEXT: ret i8 %1 ; CHECK-NEXT: } -; CHECK: ; Function Attrs: nounwind readnone speculatable +; CHECK: ; Function Attrs: nofree nosync nounwind readnone speculatable willreturn ; CHECK-NEXT: declare i8 @llvm.ctlz.i8(i8, i1 immarg) #0 -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } +; CHECK: attributes #0 = { nofree nosync nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-ctpop.opt b/test/Codegen/inst-ctpop.opt index 221a0a138..a4b1fabac 100644 --- a/test/Codegen/inst-ctpop.opt +++ b/test/Codegen/inst-ctpop.opt @@ -10,7 +10,7 @@ result %1 ; CHECK-NEXT: ret i8 %1 ; CHECK-NEXT: } -; CHECK: ; Function Attrs: nounwind readnone speculatable +; CHECK: ; Function Attrs: nofree nosync nounwind readnone speculatable willreturn ; CHECK-NEXT: declare i8 @llvm.ctpop.i8(i8) #0 -; CHECK:
attributes #0 = { nounwind readnone speculatable willreturn } +; CHECK: attributes #0 = { nofree nosync nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-cttz.opt b/test/Codegen/inst-cttz.opt index e10fd102a..29c02bed7 100644 --- a/test/Codegen/inst-cttz.opt +++ b/test/Codegen/inst-cttz.opt @@ -10,7 +10,7 @@ result %1 ; CHECK-NEXT: ret i8 %1 ; CHECK-NEXT: } -; CHECK: ; Function Attrs: nounwind readnone speculatable +; CHECK: ; Function Attrs: nofree nosync nounwind readnone speculatable willreturn ; CHECK-NEXT: declare i8 @llvm.cttz.i8(i8, i1 immarg) #0 -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } +; CHECK: attributes #0 = { nofree nosync nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-fshl.opt b/test/Codegen/inst-fshl.opt index c68e4247c..7c7932a30 100644 --- a/test/Codegen/inst-fshl.opt +++ b/test/Codegen/inst-fshl.opt @@ -12,7 +12,7 @@ result %3 ; CHECK-NEXT: ret i8 %3 ; CHECK-NEXT: } -; CHECK: ; Function Attrs: nounwind readnone speculatable +; CHECK: ; Function Attrs: nofree nosync nounwind readnone speculatable willreturn ; CHECK-NEXT: declare i8 @llvm.fshl.i8(i8, i8, i8) #0 -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } +; CHECK: attributes #0 = { nofree nosync nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-fshr.opt b/test/Codegen/inst-fshr.opt index e1b718c62..795d869c0 100644 --- a/test/Codegen/inst-fshr.opt +++ b/test/Codegen/inst-fshr.opt @@ -12,7 +12,7 @@ result %3 ; CHECK-NEXT: ret i8 %3 ; CHECK-NEXT: } -; CHECK: ; Function Attrs: nounwind readnone speculatable +; CHECK: ; Function Attrs: nofree nosync nounwind readnone speculatable willreturn ; CHECK-NEXT: declare i8 @llvm.fshr.i8(i8, i8, i8) #0 -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } +; CHECK: attributes #0 = { nofree nosync nounwind readnone speculatable willreturn } diff --git 
a/test/Codegen/inst-sadd_with_overflow-extract-overflow-XFAIL.opt b/test/Codegen/inst-sadd_with_overflow-extract-overflow.opt similarity index 73% rename from test/Codegen/inst-sadd_with_overflow-extract-overflow-XFAIL.opt rename to test/Codegen/inst-sadd_with_overflow-extract-overflow.opt index 41d0855d8..8a1eb4601 100644 --- a/test/Codegen/inst-sadd_with_overflow-extract-overflow-XFAIL.opt +++ b/test/Codegen/inst-sadd_with_overflow-extract-overflow.opt @@ -14,7 +14,7 @@ result %3 ; CHECK-NEXT: ret i1 %4 ; CHECK-NEXT: } -; CHECK: ; Function Attrs: nounwind readnone speculatable +; CHECK: ; Function Attrs: nofree nosync nounwind readnone speculatable willreturn ; CHECK-NEXT: declare { i8, i1 } @llvm.sadd.with.overflow.i8(i8, i8) #0 -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } +; CHECK: attributes #0 = { nofree nosync nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-sadd_with_overflow-extract-val-XFAIL.opt b/test/Codegen/inst-sadd_with_overflow-extract-val.opt similarity index 73% rename from test/Codegen/inst-sadd_with_overflow-extract-val-XFAIL.opt rename to test/Codegen/inst-sadd_with_overflow-extract-val.opt index 29459432a..4a71cc90b 100644 --- a/test/Codegen/inst-sadd_with_overflow-extract-val-XFAIL.opt +++ b/test/Codegen/inst-sadd_with_overflow-extract-val.opt @@ -14,7 +14,7 @@ result %3 ; CHECK-NEXT: ret i8 %4 ; CHECK-NEXT: } -; CHECK: ; Function Attrs: nounwind readnone speculatable +; CHECK: ; Function Attrs: nofree nosync nounwind readnone speculatable willreturn ; CHECK-NEXT: declare { i8, i1 } @llvm.sadd.with.overflow.i8(i8, i8) #0 -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } +; CHECK: attributes #0 = { nofree nosync nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-sadd_with_overflow-XFAIL.opt b/test/Codegen/inst-sadd_with_overflow.opt similarity index 71% rename from test/Codegen/inst-sadd_with_overflow-XFAIL.opt rename to 
test/Codegen/inst-sadd_with_overflow.opt index 7b594b304..d519597a4 100644 --- a/test/Codegen/inst-sadd_with_overflow-XFAIL.opt +++ b/test/Codegen/inst-sadd_with_overflow.opt @@ -12,7 +12,7 @@ result %2 ; CHECK-NEXT: ret { i8, i1 } %3 ; CHECK-NEXT: } -; CHECK: ; Function Attrs: nounwind readnone speculatable +; CHECK: ; Function Attrs: nofree nosync nounwind readnone speculatable willreturn ; CHECK-NEXT: declare { i8, i1 } @llvm.sadd.with.overflow.i8(i8, i8) #0 -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } +; CHECK: attributes #0 = { nofree nosync nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-smul_with_overflow-extract-overflow-XFAIL.opt b/test/Codegen/inst-smul_with_overflow-extract-overflow.opt similarity index 73% rename from test/Codegen/inst-smul_with_overflow-extract-overflow-XFAIL.opt rename to test/Codegen/inst-smul_with_overflow-extract-overflow.opt index c3e63ddc9..cde076cb9 100644 --- a/test/Codegen/inst-smul_with_overflow-extract-overflow-XFAIL.opt +++ b/test/Codegen/inst-smul_with_overflow-extract-overflow.opt @@ -14,7 +14,7 @@ result %3 ; CHECK-NEXT: ret i1 %4 ; CHECK-NEXT: } -; CHECK: ; Function Attrs: nounwind readnone speculatable +; CHECK: ; Function Attrs: nofree nosync nounwind readnone speculatable willreturn ; CHECK-NEXT: declare { i8, i1 } @llvm.smul.with.overflow.i8(i8, i8) #0 -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } +; CHECK: attributes #0 = { nofree nosync nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-smul_with_overflow-extract-val-XFAIL.opt b/test/Codegen/inst-smul_with_overflow-extract-val.opt similarity index 66% rename from test/Codegen/inst-smul_with_overflow-extract-val-XFAIL.opt rename to test/Codegen/inst-smul_with_overflow-extract-val.opt index 6e64a426d..70929eddb 100644 --- a/test/Codegen/inst-smul_with_overflow-extract-val-XFAIL.opt +++ b/test/Codegen/inst-smul_with_overflow-extract-val.opt @@ -13,8 +13,3 @@ result
%3 ; CHECK-NEXT: %4 = extractvalue { i8, i1 } %3, 0 ; CHECK-NEXT: ret i8 %4 ; CHECK-NEXT: } - -; CHECK: ; Function Attrs: nounwind readnone speculatable -; CHECK-NEXT: declare { i8, i1 } @llvm.smul.with.overflow.i8(i8, i8) #0 - -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-smul_with_overflow-XFAIL.opt b/test/Codegen/inst-smul_with_overflow.opt similarity index 71% rename from test/Codegen/inst-smul_with_overflow-XFAIL.opt rename to test/Codegen/inst-smul_with_overflow.opt index 8462dc034..5dec61884 100644 --- a/test/Codegen/inst-smul_with_overflow-XFAIL.opt +++ b/test/Codegen/inst-smul_with_overflow.opt @@ -12,7 +12,7 @@ result %2 ; CHECK-NEXT: ret { i8, i1 } %3 ; CHECK-NEXT: } -; CHECK: ; Function Attrs: nounwind readnone speculatable +; CHECK: ; Function Attrs: nofree nosync nounwind readnone speculatable willreturn ; CHECK-NEXT: declare { i8, i1 } @llvm.smul.with.overflow.i8(i8, i8) #0 -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } +; CHECK: attributes #0 = { nofree nosync nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-ssub_with_overflow-extract-overflow-XFAIL.opt b/test/Codegen/inst-ssub_with_overflow-extract-overflow.opt similarity index 66% rename from test/Codegen/inst-ssub_with_overflow-extract-overflow-XFAIL.opt rename to test/Codegen/inst-ssub_with_overflow-extract-overflow.opt index 2096c817c..ba8234f8d 100644 --- a/test/Codegen/inst-ssub_with_overflow-extract-overflow-XFAIL.opt +++ b/test/Codegen/inst-ssub_with_overflow-extract-overflow.opt @@ -13,8 +13,3 @@ result %3 ; CHECK-NEXT: %4 = extractvalue { i8, i1 } %3, 1 ; CHECK-NEXT: ret i1 %4 ; CHECK-NEXT: } - -; CHECK: ; Function Attrs: nounwind readnone speculatable -; CHECK-NEXT: declare { i8, i1 } @llvm.ssub.with.overflow.i8(i8, i8) #0 - -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-ssub_with_overflow-extract-val-XFAIL.opt 
b/test/Codegen/inst-ssub_with_overflow-extract-val.opt similarity index 66% rename from test/Codegen/inst-ssub_with_overflow-extract-val-XFAIL.opt rename to test/Codegen/inst-ssub_with_overflow-extract-val.opt index b1d9345d7..faaef7f90 100644 --- a/test/Codegen/inst-ssub_with_overflow-extract-val-XFAIL.opt +++ b/test/Codegen/inst-ssub_with_overflow-extract-val.opt @@ -13,8 +13,3 @@ result %3 ; CHECK-NEXT: %4 = extractvalue { i8, i1 } %3, 0 ; CHECK-NEXT: ret i8 %4 ; CHECK-NEXT: } - -; CHECK: ; Function Attrs: nounwind readnone speculatable -; CHECK-NEXT: declare { i8, i1 } @llvm.ssub.with.overflow.i8(i8, i8) #0 - -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-ssub_with_overflow-XFAIL.opt b/test/Codegen/inst-ssub_with_overflow.opt similarity index 62% rename from test/Codegen/inst-ssub_with_overflow-XFAIL.opt rename to test/Codegen/inst-ssub_with_overflow.opt index 10b1432e1..dbffff03c 100644 --- a/test/Codegen/inst-ssub_with_overflow-XFAIL.opt +++ b/test/Codegen/inst-ssub_with_overflow.opt @@ -11,8 +11,3 @@ result %2 ; CHECK-NEXT: %3 = call { i8, i1 } @llvm.ssub.with.overflow.i8(i8 %0, i8 %1) ; CHECK-NEXT: ret { i8, i1 } %3 ; CHECK-NEXT: } - -; CHECK: ; Function Attrs: nounwind readnone speculatable -; CHECK-NEXT: declare { i8, i1 } @llvm.ssub.with.overflow.i8(i8, i8) #0 - -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-uadd_with_overflow-extract-overflow-XFAIL.opt b/test/Codegen/inst-uadd_with_overflow-extract-overflow.opt similarity index 66% rename from test/Codegen/inst-uadd_with_overflow-extract-overflow-XFAIL.opt rename to test/Codegen/inst-uadd_with_overflow-extract-overflow.opt index f7fa211f5..8c353d122 100644 --- a/test/Codegen/inst-uadd_with_overflow-extract-overflow-XFAIL.opt +++ b/test/Codegen/inst-uadd_with_overflow-extract-overflow.opt @@ -13,8 +13,3 @@ result %3 ; CHECK-NEXT: %4 = extractvalue { i8, i1 } %3, 1 ; CHECK-NEXT: ret i1 %4 ; 
CHECK-NEXT: } - -; CHECK: ; Function Attrs: nounwind readnone speculatable -; CHECK-NEXT: declare { i8, i1 } @llvm.uadd.with.overflow.i8(i8, i8) #0 - -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-uadd_with_overflow-extract-val-XFAIL.opt b/test/Codegen/inst-uadd_with_overflow-extract-val.opt similarity index 66% rename from test/Codegen/inst-uadd_with_overflow-extract-val-XFAIL.opt rename to test/Codegen/inst-uadd_with_overflow-extract-val.opt index 3a93ca85a..5f1fec59e 100644 --- a/test/Codegen/inst-uadd_with_overflow-extract-val-XFAIL.opt +++ b/test/Codegen/inst-uadd_with_overflow-extract-val.opt @@ -13,8 +13,3 @@ result %3 ; CHECK-NEXT: %4 = extractvalue { i8, i1 } %3, 0 ; CHECK-NEXT: ret i8 %4 ; CHECK-NEXT: } - -; CHECK: ; Function Attrs: nounwind readnone speculatable -; CHECK-NEXT: declare { i8, i1 } @llvm.uadd.with.overflow.i8(i8, i8) #0 - -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-uadd_with_overflow-XFAIL.opt b/test/Codegen/inst-uadd_with_overflow.opt similarity index 62% rename from test/Codegen/inst-uadd_with_overflow-XFAIL.opt rename to test/Codegen/inst-uadd_with_overflow.opt index 5fcf26938..3385f136b 100644 --- a/test/Codegen/inst-uadd_with_overflow-XFAIL.opt +++ b/test/Codegen/inst-uadd_with_overflow.opt @@ -11,8 +11,3 @@ result %2 ; CHECK-NEXT: %3 = call { i8, i1 } @llvm.uadd.with.overflow.i8(i8 %0, i8 %1) ; CHECK-NEXT: ret { i8, i1 } %3 ; CHECK-NEXT: } - -; CHECK: ; Function Attrs: nounwind readnone speculatable -; CHECK-NEXT: declare { i8, i1 } @llvm.uadd.with.overflow.i8(i8, i8) #0 - -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-umul_with_overflow-extract-overflow-XFAIL.opt b/test/Codegen/inst-umul_with_overflow-extract-overflow.opt similarity index 66% rename from test/Codegen/inst-umul_with_overflow-extract-overflow-XFAIL.opt rename to 
test/Codegen/inst-umul_with_overflow-extract-overflow.opt index 11ef0ded5..5855f89e9 100644 --- a/test/Codegen/inst-umul_with_overflow-extract-overflow-XFAIL.opt +++ b/test/Codegen/inst-umul_with_overflow-extract-overflow.opt @@ -13,8 +13,3 @@ result %3 ; CHECK-NEXT: %4 = extractvalue { i8, i1 } %3, 1 ; CHECK-NEXT: ret i1 %4 ; CHECK-NEXT: } - -; CHECK: ; Function Attrs: nounwind readnone speculatable -; CHECK-NEXT: declare { i8, i1 } @llvm.umul.with.overflow.i8(i8, i8) #0 - -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-umul_with_overflow-extract-val-XFAIL.opt b/test/Codegen/inst-umul_with_overflow-extract-val.opt similarity index 66% rename from test/Codegen/inst-umul_with_overflow-extract-val-XFAIL.opt rename to test/Codegen/inst-umul_with_overflow-extract-val.opt index 3a29ca915..c181e5d2b 100644 --- a/test/Codegen/inst-umul_with_overflow-extract-val-XFAIL.opt +++ b/test/Codegen/inst-umul_with_overflow-extract-val.opt @@ -13,8 +13,3 @@ result %3 ; CHECK-NEXT: %4 = extractvalue { i8, i1 } %3, 0 ; CHECK-NEXT: ret i8 %4 ; CHECK-NEXT: } - -; CHECK: ; Function Attrs: nounwind readnone speculatable -; CHECK-NEXT: declare { i8, i1 } @llvm.umul.with.overflow.i8(i8, i8) #0 - -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-umul_with_overflow-XFAIL.opt b/test/Codegen/inst-umul_with_overflow.opt similarity index 62% rename from test/Codegen/inst-umul_with_overflow-XFAIL.opt rename to test/Codegen/inst-umul_with_overflow.opt index c2502efd5..c16e22326 100644 --- a/test/Codegen/inst-umul_with_overflow-XFAIL.opt +++ b/test/Codegen/inst-umul_with_overflow.opt @@ -11,8 +11,3 @@ result %2 ; CHECK-NEXT: %3 = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %0, i8 %1) ; CHECK-NEXT: ret { i8, i1 } %3 ; CHECK-NEXT: } - -; CHECK: ; Function Attrs: nounwind readnone speculatable -; CHECK-NEXT: declare { i8, i1 } @llvm.umul.with.overflow.i8(i8, i8) #0 - -; CHECK: attributes #0 = 
{ nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-usub_with_overflow-extract-overflow-XFAIL.opt b/test/Codegen/inst-usub_with_overflow-extract-overflow.opt similarity index 66% rename from test/Codegen/inst-usub_with_overflow-extract-overflow-XFAIL.opt rename to test/Codegen/inst-usub_with_overflow-extract-overflow.opt index 3ac11e5de..f56c94012 100644 --- a/test/Codegen/inst-usub_with_overflow-extract-overflow-XFAIL.opt +++ b/test/Codegen/inst-usub_with_overflow-extract-overflow.opt @@ -13,8 +13,3 @@ result %3 ; CHECK-NEXT: %4 = extractvalue { i8, i1 } %3, 1 ; CHECK-NEXT: ret i1 %4 ; CHECK-NEXT: } - -; CHECK: ; Function Attrs: nounwind readnone speculatable -; CHECK-NEXT: declare { i8, i1 } @llvm.usub.with.overflow.i8(i8, i8) #0 - -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-usub_with_overflow-extract-val-XFAIL.opt b/test/Codegen/inst-usub_with_overflow-extract-val.opt similarity index 66% rename from test/Codegen/inst-usub_with_overflow-extract-val-XFAIL.opt rename to test/Codegen/inst-usub_with_overflow-extract-val.opt index 1de03c099..85e244d41 100644 --- a/test/Codegen/inst-usub_with_overflow-extract-val-XFAIL.opt +++ b/test/Codegen/inst-usub_with_overflow-extract-val.opt @@ -13,8 +13,3 @@ result %3 ; CHECK-NEXT: %4 = extractvalue { i8, i1 } %3, 0 ; CHECK-NEXT: ret i8 %4 ; CHECK-NEXT: } - -; CHECK: ; Function Attrs: nounwind readnone speculatable -; CHECK-NEXT: declare { i8, i1 } @llvm.usub.with.overflow.i8(i8, i8) #0 - -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } diff --git a/test/Codegen/inst-usub_with_overflow-XFAIL.opt b/test/Codegen/inst-usub_with_overflow.opt similarity index 62% rename from test/Codegen/inst-usub_with_overflow-XFAIL.opt rename to test/Codegen/inst-usub_with_overflow.opt index 8da3dafd2..a4a225d28 100644 --- a/test/Codegen/inst-usub_with_overflow-XFAIL.opt +++ b/test/Codegen/inst-usub_with_overflow.opt @@ -11,8 +11,3 @@ result 
%2 ; CHECK-NEXT: %3 = call { i8, i1 } @llvm.usub.with.overflow.i8(i8 %0, i8 %1) ; CHECK-NEXT: ret { i8, i1 } %3 ; CHECK-NEXT: } - -; CHECK: ; Function Attrs: nounwind readnone speculatable -; CHECK-NEXT: declare { i8, i1 } @llvm.usub.with.overflow.i8(i8, i8) #0 - -; CHECK: attributes #0 = { nounwind readnone speculatable willreturn } diff --git a/test/Codegen/return-const-unused-arg.opt b/test/Codegen/return-const-unused-arg.opt index a32a7fddf..1cb431eae 100644 --- a/test/Codegen/return-const-unused-arg.opt +++ b/test/Codegen/return-const-unused-arg.opt @@ -6,7 +6,7 @@ result 42:i8 ; CHECK: ; ModuleID = 'souper.ll' ; CHECK-NEXT: source_filename = "souper.ll" -; CHECK: define i8 @fun(i8 %0) { +; CHECK: define i8 @fun() { ; CHECK-NEXT: entry: ; CHECK-NEXT: ret i8 42 ; CHECK-NEXT: } diff --git a/test/Codegen/return-named-arg-XFAIL.opt b/test/Codegen/return-named-arg.opt similarity index 100% rename from test/Codegen/return-named-arg-XFAIL.opt rename to test/Codegen/return-named-arg.opt diff --git a/test/Dataflow/precision-multiple.ll b/test/Dataflow/precision-multiple.ll index ad249eba9..f4ad5c725 100644 --- a/test/Dataflow/precision-multiple.ll +++ b/test/Dataflow/precision-multiple.ll @@ -1,7 +1,7 @@ ; RUN: %llvm-as -o %t %s -; RUN: %souper -infer-range -infer-non-zero %t > %t2 || true +; RUN: %souper -infer-range -infer-non-zero -souper-max-constant-synthesis-tries=60 %t > %t2 || true ; RUN: %FileCheck %s < %t2 define i8 @foo(i8 %x1, i64 %_phiinput) { diff --git a/test/Dataflow/precision-multiple.opt b/test/Dataflow/precision-multiple.opt index fdfac5c10..c792ae97d 100644 --- a/test/Dataflow/precision-multiple.opt +++ b/test/Dataflow/precision-multiple.opt @@ -1,5 +1,5 @@ -; RUN: %souper-check -infer-range -infer-non-zero %s | %FileCheck %s +; RUN: %souper-check -infer-range -infer-non-zero -souper-max-constant-synthesis-tries=60 %s | %FileCheck %s ; CHECK: nonZero from souper: true ; CHECK: range from souper: [1,0) diff --git a/test/Dataflow/precision-range-1.ll 
b/test/Dataflow/precision-range-1.ll index c285fb3b9..748aaab5c 100644 --- a/test/Dataflow/precision-range-1.ll +++ b/test/Dataflow/precision-range-1.ll @@ -1,7 +1,7 @@ ; RUN: %llvm-as -o %t %s -; RUN: %souper -infer-range %t > %t2 || true +; RUN: %souper -infer-range -souper-max-constant-synthesis-tries=60 %t > %t2 || true ; RUN: %FileCheck %s < %t2 define i8 @foo(i8 %x1, i64 %_phiinput) { diff --git a/test/Dataflow/precision-range-1.opt b/test/Dataflow/precision-range-1.opt index 5a97d2525..9f09ebaea 100644 --- a/test/Dataflow/precision-range-1.opt +++ b/test/Dataflow/precision-range-1.opt @@ -1,5 +1,5 @@ -; RUN: %souper-check -infer-range %s | %FileCheck %s +; RUN: %souper-check -infer-range -souper-max-constant-synthesis-tries=60 %s | %FileCheck %s ; CHECK: range from souper: [1,0) diff --git a/test/Generalize/addc.opt b/test/Generalize/addc.opt new file mode 100644 index 000000000..a14289991 --- /dev/null +++ b/test/Generalize/addc.opt @@ -0,0 +1,11 @@ +; RUN: %generalize -basic -no-width -souper-debug-level=2 %s > /dev/null 2>%t +; RUN: %FileCheck %s < %t + +%x:i8 = var +%foo = add %x, 2 +%bar = sub %foo, %x +infer %bar +result 2:i8 +; CHECK:(x:i8 + C1:i8) - x +; CHECK: => +; CHECK:C1 diff --git a/test/Generalize/addshlsub.opt b/test/Generalize/addshlsub.opt new file mode 100644 index 000000000..583d36604 --- /dev/null +++ b/test/Generalize/addshlsub.opt @@ -0,0 +1,16 @@ +; RUN: %generalize -basic -no-width -souper-debug-level=2 %s > /dev/null 2>%t +; RUN: %FileCheck %s < %t + +%v0:i32 = var ; v0 +%1:i32 = add 4:i32, %v0 +%2:i32 = shl %1, 2:i32 +%3:i32 = sub %2, 16:i32 +infer %3 +%4:i32 = shl %v0, 2:i32 +result %4 + +; CHECK: C3:i32 == (C1:i32 << C2:i32) +; CHECK: |= +; CHECK: ((v0:i32 + C1) << C2) - C3 +; CHECK: => +; CHECK: v0 << C2 \ No newline at end of file diff --git a/test/Generalize/addsub.opt b/test/Generalize/addsub.opt new file mode 100644 index 000000000..28c1ec4c5 --- /dev/null +++ b/test/Generalize/addsub.opt @@ -0,0 +1,12 @@ +; RUN: %generalize 
-basic -no-width -souper-debug-level=2 %s > /dev/null 2>%t +; RUN: %FileCheck %s < %t + +%x:i8 = var +%foo = add %x, 42 +%bar = sub %foo, 31 +infer %bar +%xip = add %x, 11 +result %xip +; CHECK: (x:i8 + C1:i8) - C2:i8 +; CHECK: => +; CHECK: x + (C1 - C2) diff --git a/test/Generalize/ctpop.opt b/test/Generalize/ctpop.opt new file mode 100644 index 000000000..471aa58a2 --- /dev/null +++ b/test/Generalize/ctpop.opt @@ -0,0 +1,16 @@ +; RUN: %generalize -basic -no-width -souper-debug-level=2 %s > /dev/null 2>%t +; RUN: %FileCheck %s < %t + +%v0:i8 = var ; v0 +%1:i8 = ctpop %v0 +%2:i1 = ult 7:i8, %1 +infer %2 +%3:i1 = eq 255:i8, %v0 +result %3 + +; CHECK: C1:i8 == (width(C1) - 1) +; CHECK: |= +; CHECK: C1 <u ctpop(v0:i8) +; CHECK: => +; CHECK: v0 == 0xFF +; TODO: synthesize independent fn instead of 255 \ No newline at end of file diff --git a/test/Generalize/maskoff.opt b/test/Generalize/maskoff.opt new file mode 100644 index 000000000..1403806a7 --- /dev/null +++ b/test/Generalize/maskoff.opt @@ -0,0 +1,13 @@ +; RUN: %generalize -basic -no-width -souper-debug-level=2 %s > /dev/null 2>%t +; RUN: %FileCheck %s < %t + +%0:i32 = var +%1:i32 = shl %0, 12:i32 +%2:i32 = lshr %1, 12:i32 +%3:i32 = and 1048575:i32, %0 +cand %2 %3 + + +; CHECK: (v0:i32 << C1:i32) >>l C1 +; CHECK: => +; CHECK: v0 & (sext(1) >>l C1) \ No newline at end of file diff --git a/test/Generalize/multoshl.opt b/test/Generalize/multoshl.opt new file mode 100644 index 000000000..458a65d71 --- /dev/null +++ b/test/Generalize/multoshl.opt @@ -0,0 +1,11 @@ +; RUN: %generalize -basic -no-width -souper-debug-level=2 %s > /dev/null 2>%t +; RUN: %FileCheck %s < %t + +%x:i8 = var +%y = mul %x, 2 +%z = shl %x, 1 +cand %y %z + +; CHECK: x:i8 * C1:i8 (powerOfTwo) +; CHECK: => +; CHECK: x << logb(C1) \ No newline at end of file diff --git a/test/Generalize/nonneg.opt b/test/Generalize/nonneg.opt new file mode 100644 index 000000000..05a433bcf --- /dev/null +++ b/test/Generalize/nonneg.opt @@ -0,0 +1,13 @@ +; RUN: %generalize -basic -no-width
-souper-debug-level=2 %s > /dev/null 2>%t +; RUN: %FileCheck %s < %t + +%0:i4 = var +%1 = lshr %0, 1 +%2 = udiv %1, 2 +infer %2 +%3 = udiv %0, 4 +result %3 + +; CHECK: (v0:i4 >>l 1) /u C2:i4 (nonNegative) +; CHECK: => +; CHECK: v0 /u (C2 + C2) \ No newline at end of file diff --git a/test/Generalize/shiftmask.opt b/test/Generalize/shiftmask.opt new file mode 100644 index 000000000..ef463bcac --- /dev/null +++ b/test/Generalize/shiftmask.opt @@ -0,0 +1,16 @@ +; RUN: %generalize -basic -souper-debug-level=2 %s > /dev/null 2>%t +; RUN: %FileCheck %s < %t + +%x:i8 = var +%y = shl %x, 2 +%z = lshr %y, 4 +%t = shl %z, 2 +infer %t +%foo = and 60, %x +result %foo + +; CHECK: C2 == (C1 * 2) +; CHECK: |= +; CHECK: ((x << C1) >>l C2) << C1 +; CHECK: => +; CHECK: x & ((0xFF << C2) >>l C1) \ No newline at end of file diff --git a/test/Generalize/shrink.opt b/test/Generalize/shrink.opt new file mode 100644 index 000000000..fe11d6870 --- /dev/null +++ b/test/Generalize/shrink.opt @@ -0,0 +1,13 @@ +; RUN: %generalize -basic -souper-debug-level=2 %s > /dev/null 2>%t +; RUN: %FileCheck %s < %t + +%0:i64 = var +%1 = lshr %0, 1 +%2 = udiv %1, 100 +infer %2 +%3 = udiv %0, 200 +result %3 + +; CHECK: (v0 >>l 1) /u C2 (nonNegative) +; CHECK: => +; CHECK: v0 /u (C2 + C2) \ No newline at end of file diff --git a/test/Generalize/smax.opt b/test/Generalize/smax.opt new file mode 100644 index 000000000..cebe6e5ac --- /dev/null +++ b/test/Generalize/smax.opt @@ -0,0 +1,15 @@ +; RUN: %generalize -basic -souper-debug-level=2 %s > /dev/null 2>%t +; RUN: %FileCheck %s < %t + +%v0:i8 = var ; v0 +%1:i32 = sext %v0 +%2:i1 = eq 42:i32, %1 +infer %2 +%3:i1 = eq 42:i8, %v0 +result %3 + +; CHECK: C1 <=u ((1 << (zext(width(v0)) - 1)) - 1) +; CHECK: |= +; CHECK: C1 == sext(v0) +; CHECK: => +; CHECK: v0 == trunc(C1) \ No newline at end of file diff --git a/test/Generalize/widthchange.opt b/test/Generalize/widthchange.opt new file mode 100644 index 000000000..f99e4ca09 --- /dev/null +++ 
b/test/Generalize/widthchange.opt @@ -0,0 +1,14 @@ +; RUN: %generalize -basic -souper-debug-level=2 %s > /dev/null 2>%t +; RUN: %FileCheck %s < %t + +%v0:i16 = var +%x:i32 = sext %v0 +%y:i32 = sub %x, 20:i32; hex 0x14 +%yt:i16 = trunc %y +infer %yt +%z = add %v0, -20:i16; hex 0xFFEC +result %z + +; CHECK: trunc((sext(v0) - C1)) +; CHECK: => +; CHECK: v0 + trunc((0 - C1)) diff --git a/test/Infer/pruning/dontprune.opt b/test/Infer/pruning/dontprune.opt new file mode 100644 index 000000000..c2010064f --- /dev/null +++ b/test/Infer/pruning/dontprune.opt @@ -0,0 +1,18 @@ +; REQUIRES: synthesis + +; RUN: %souper-check -try-dataflow-pruning %s > %t +; RUN: %FileCheck %s < %t + +; CHECK: Pruning failed + +%0 = block 2 +%1:i64 = var +%2:i64 = lshr %1, 16:i64 +%3:i64 = var +%4:i64 = lshr %3, 16:i64 +%5:i64 = phi %0, %2, %4 +%6:i16 = trunc %5 +%7:i32 = zext %6 +infer %7 (demandedBits=00000000000000001111111100000000) +%8:i32 = trunc %5 +result %8 diff --git a/test/Infer/syn-double-insts/syn-ctpop.opt b/test/Infer/syn-double-insts/syn-ctpop.opt index 3762bf0cc..eb1cf8dd5 100644 --- a/test/Infer/syn-double-insts/syn-ctpop.opt +++ b/test/Infer/syn-double-insts/syn-ctpop.opt @@ -2,15 +2,12 @@ ; RUN: %souper-check -infer-rhs -souper-enumerative-synthesis-max-instructions=2 %s > %t1 ; RUN: %FileCheck %s < %t1 -; synthesis ctpop - -; Need 2min~ +; enumerator stress test %0:i8 = var %1:i8 = add %0, 1:i8 %2:i8 = add %1, 3:i8 -%3:i8 = add %2, 5:i8 -%4:i8 = ctpop %3 -infer %4 -; CHECK: %5:i8 = add 9:i8, %0 -; CHECK: %6:i8 = ctpop %5 +%3:i8 = ctpop %2 +infer %3 +; CHECK: %4:i8 = add 4:i8, %0 +; CHECK: %5:i8 = ctpop %4 diff --git a/test/LLVM/clang1.c b/test/LLVM/clang1.c new file mode 100644 index 000000000..257b159c0 --- /dev/null +++ b/test/LLVM/clang1.c @@ -0,0 +1,19 @@ +// REQUIRES: solver + +// RUN: %clang -O2 -S -o - %s -emit-llvm | %FileCheck -check-prefix=TEST1 %s +// TEST1: %or = or i32 %b, %a + +// RUN: %clang -O2 -S -o - %s -emit-llvm -mllvm -disable-all-peepholes | %FileCheck 
-check-prefix=TEST2 %s +// TEST2: %xor = xor i32 %a, %b +// TEST2-NEXT: %or = or i32 %xor, %a + +// RUN: SOUPER_SOLVER=%solver SOUPER_NO_INFER=1 SOUPER_NO_EXTERNAL_CACHE=1 %sclang -O2 -S -o - %s -emit-llvm | %FileCheck -check-prefix=TEST3 %s +// TEST3: %or = or i32 %b, %a + +// RUN: LLVM_DISABLE_PEEPHOLES=1 SOUPER_SOLVER=%solver SOUPER_NO_INFER=1 SOUPER_NO_EXTERNAL_CACHE=1 %sclang -O2 -S -o - %s -emit-llvm | %FileCheck -check-prefix=TEST4 %s +// TEST4: %xor = xor i32 %a, %b +// TEST4-NEXT: %or = or i32 %xor, %a + +unsigned foo(unsigned a, unsigned b) { + return (a ^ b) | a; +} diff --git a/test/LLVM/clang2.c b/test/LLVM/clang2.c new file mode 100644 index 000000000..8525e8546 --- /dev/null +++ b/test/LLVM/clang2.c @@ -0,0 +1,19 @@ +// REQUIRES: solver + +// RUN: %clang -O2 -S -o - %s -emit-llvm | %FileCheck -check-prefix=TEST1 %s +// TEST1: and i64 %x, 1 + +// RUN: %clang -O2 -S -o - %s -emit-llvm -mllvm -disable-all-peepholes | %FileCheck -check-prefix=TEST2 %s +// TEST2: shl i64 %x, 63 +// TEST2-NEXT: lshr i64 %shl, 63 + +// RUN: SOUPER_SOLVER=%solver SOUPER_NO_INFER=1 SOUPER_NO_EXTERNAL_CACHE=1 %sclang -O2 -S -o - %s -emit-llvm | %FileCheck -check-prefix=TEST3 %s +// TEST3: and i64 %x, 1 + +// RUN: LLVM_DISABLE_PEEPHOLES=1 SOUPER_SOLVER=%solver SOUPER_NO_INFER=1 SOUPER_NO_EXTERNAL_CACHE=1 %sclang -O2 -S -o - %s -emit-llvm | %FileCheck -check-prefix=TEST4 %s +// TEST4: shl i64 %x, 63 +// TEST4-NEXT: lshr i64 %shl, 63 + +unsigned long foo(unsigned long x) { + return ((x << 63) >> 63) + 1; +} diff --git a/test/Pass/issue-821.ll b/test/Pass/issue-821.ll new file mode 100644 index 000000000..ae592e158 --- /dev/null +++ b/test/Pass/issue-821.ll @@ -0,0 +1,22 @@ +; RUN: %opt -load %pass -souper -S -o - %s 2>&1 | %FileCheck %s + +; CHECK-NOT: ptrtoint +; CHECK-NOT: trunc +; CHECK-NOT: shl +; CHECK: store + +; ModuleID = 'foo.ll' +source_filename = "sqlite3.c" +target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = 
"x86_64-apple-macosx10.15.0" + +define internal fastcc void @yy_reduce() unnamed_addr { +sw.bb115: + %sub.ptr.rhs.cast139 = ptrtoint i8* undef to i64 + %sub.ptr.sub140 = sub i64 undef, %sub.ptr.rhs.cast139 + %conv141 = trunc i64 %sub.ptr.sub140 to i32 + %bf.value145 = and i32 %conv141, 2147483647 + %bf.shl146 = shl i32 %bf.value145, 1 + store i32 %bf.shl146, i32* undef, align 8 + ret void +} diff --git a/test/Solver/alive-pc.opt b/test/Solver/alive-pc.opt new file mode 100644 index 000000000..e529d3ecd --- /dev/null +++ b/test/Solver/alive-pc.opt @@ -0,0 +1,13 @@ +; RUN: %souper-check -souper-use-alive %s > %t 2>&1 +; RUN: %FileCheck %s < %t + +%0:i64 = var +%1:i64 = udiv 6148914691236517204:i64, %0 +%2:i64 = var +%3:i1 = ule %1, %2 +pc %3 0:i1 +%4:i64 = udiv %2, 2:i64 +infer %4 +%5:i64 = ashr %2, 1:i64 +result %5 +;CHECK: LGTM diff --git a/test/Solver/div-by-zero1.ll b/test/Solver/div-by-zero1.ll index 33e5237ea..1aa56e8a6 100644 --- a/test/Solver/div-by-zero1.ll +++ b/test/Solver/div-by-zero1.ll @@ -1,7 +1,7 @@ ; RUN: %llvm-as -o %t %s -; RUN: %souper -check -souper-only-infer-iN -souper-double-check %t +; RUN: %souper -check -souper-only-infer-iN -souper-double-check -souper-shrink-consts=true %t define void @fn1() { entry: diff --git a/test/Solver/div-by-zero2.ll b/test/Solver/div-by-zero2.ll index d3227a627..538857194 100644 --- a/test/Solver/div-by-zero2.ll +++ b/test/Solver/div-by-zero2.ll @@ -1,7 +1,7 @@ ; RUN: %llvm-as -o %t %s -; RUN: %souper -check -souper-only-infer-i1 -souper-double-check %t +; RUN: %souper -check -souper-only-infer-i1 -souper-double-check -souper-shrink-consts=true %t define void @fn1() { entry: diff --git a/test/Solver/div-by-zero3.ll b/test/Solver/div-by-zero3.ll index 57f2aacb1..3c59b6573 100644 --- a/test/Solver/div-by-zero3.ll +++ b/test/Solver/div-by-zero3.ll @@ -1,7 +1,7 @@ ; RUN: %llvm-as -o %t %s -; RUN: %souper -check -souper-only-infer-i1 -souper-double-check %t +; RUN: %souper -check -souper-only-infer-i1 
-souper-double-check -souper-shrink-consts=true %t define void @fn1() { entry: diff --git a/test/Unit/codegen_tests.ll b/test/Unit/codegen_tests.ll new file mode 100644 index 000000000..26dcc6a4b --- /dev/null +++ b/test/Unit/codegen_tests.ll @@ -0,0 +1 @@ +; RUN: %builddir/codegen_tests diff --git a/test/lit.cfg b/test/lit.cfg index a8878a934..c20888050 100644 --- a/test/lit.cfg +++ b/test/lit.cfg @@ -26,6 +26,7 @@ else: config.substitutions.append(('%pass', config.builddir + '/libsouperPass.so')) config.substitutions.append(('%souper', config.builddir + '/souper')) config.substitutions.append(('%souper-check', config.builddir + '/souper-check')) +config.substitutions.append(('%generalize', config.builddir + '/generalize')) config.substitutions.append(('%souper2llvm', config.builddir + '/souper2llvm')) config.substitutions.append(('%sclang', config.builddir + '/sclang')) config.substitutions.append(('%sclang\+\+', config.builddir + '/sclang++')) diff --git a/tools/generalize.cpp b/tools/generalize.cpp new file mode 100644 index 000000000..a52b66e24 --- /dev/null +++ b/tools/generalize.cpp @@ -0,0 +1,2752 @@ +#define _LIBCPP_DISABLE_DEPRECATION_WARNINGS + +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/GraphWriter.h" +#include "llvm/Support/KnownBits.h" + +#include "souper/Infer/AliveDriver.h" +#include "souper/Infer/Preconditions.h" +#include "souper/Infer/EnumerativeSynthesis.h" +#include "souper/Infer/ConstantSynthesis.h" +#include "souper/Infer/Pruning.h" +#include "souper/Infer/SynthUtils.h" +#include "souper/Inst/InstGraph.h" +#include "souper/Parser/Parser.h" +#include "souper/Generalize/Reducer.h" +#include "souper/Tool/GetSolver.h" +#include "souper/Util/DfaUtils.h" +#include +#include +#include + +using namespace llvm; +using namespace souper; + +unsigned DebugLevel; + +static cl::opt +DebugFlagParser("souper-debug-level", + cl::desc("Control the verbose level of debug output (default=1). 
" + "The larger the number is, the more fine-grained debug " + "information will be printed."), + cl::location(DebugLevel), cl::init(1)); + +static cl::opt +InputFilename(cl::Positional, cl::desc(""), + cl::init("-")); + +static llvm::cl::opt ReduceKBIFY("reduce-kbify", + llvm::cl::desc("Try to reduce the number of instructions by introducing known bits constraints." + "(default=false)"), + llvm::cl::init(true)); + +static llvm::cl::opt FindConstantRelations("relational", + llvm::cl::desc("Find constant relations." + "(default=true)"), + llvm::cl::init(true)); + +static llvm::cl::opt SymbolizeNumInsts("symbolize-num-insts", + llvm::cl::desc("Number of instructions to synthesize" + "(default=1)"), + llvm::cl::init(1)); + +static llvm::cl::opt SymbolizeConstSynthesis("symbolize-constant-synthesis", + llvm::cl::desc("Allow concrete constants in the generated code."), + llvm::cl::init(false)); + +static llvm::cl::opt SymbolizeHackersDelight("symbolize-bit-hacks", + llvm::cl::desc("Include bit hacks in the components."), + llvm::cl::init(true)); + +static llvm::cl::opt FixIt("fixit", + llvm::cl::desc("Given an invalid optimization, generate a valid one." + "(default=false)"), + llvm::cl::init(false)); + +static cl::opt NumResults("generalization-num-results", + cl::desc("Number of Generalization Results"), + cl::init(1)); + +static cl::opt JustReduce("just-reduce", + cl::desc("JustReduce"), + cl::init(false)); + +static cl::opt Basic("basic", + cl::desc("Run all fast techniques."), + cl::init(false)); + +static cl::opt OnlyWidth("only-width", + cl::desc("Only infer width checks, no synthesis."), + cl::init(false)); + +static cl::opt NoWidth("no-width", + cl::desc("No width independence checks."), + cl::init(false)); + + +static cl::opt Advanced("advanced", + cl::desc("Just run more advanced stuff. 
Assume -basic."), + cl::init(false)); + +static cl::opt SymbolicDF("symbolic-df", + cl::desc("Generalize with symbolic dataflow facts."), + cl::init(false)); + +// This can probably be done more efficiently, but likely not the bottleneck anywhere +std::vector> GetCombinations(std::vector Counts) { + if (Counts.size() == 1) { + std::vector> Result; + for (int i = 0; i < Counts[0]; ++i) { + Result.push_back({i}); + } + return Result; + } + + auto Last = Counts.back(); + Counts.pop_back(); + auto Partial = GetCombinations(Counts); + + std::vector> Result; + for (int i = 0; i < Last; ++i) { + for (auto Copy : Partial) { + Copy.push_back(i); + Result.push_back(Copy); + } + } + return Result; +} + +template +bool All(const C &c, F f) { + for (auto &&m : c) { + if (!f(m)) { + return false; + } + } + return true; +} + +size_t InferWidth(Inst::Kind K, const std::vector &Ops) { + switch (K) { + case Inst::KnownOnesP: + case Inst::KnownZerosP: + case Inst::Eq: + case Inst::Ne: + case Inst::Slt: + case Inst::Sle: + case Inst::Ult: + case Inst::Ule: return 1; + case Inst::Select: return Ops[1]->Width; + default: return Ops[0]->Width; + } +} + +struct InfixPrinter { + InfixPrinter(ParsedReplacement P_, bool ShowImplicitWidths = true) + : P(P_), ShowImplicitWidths(ShowImplicitWidths) { + varnum = 0; + std::vector NewPCs; + for (auto &&PC : P.PCs) { + countUses(PC.LHS); + countUses(PC.RHS); + if (!registerSymDFVars(PC.LHS)) { + NewPCs.push_back(PC); + } + countUses(P.Mapping.LHS); + countUses(P.Mapping.RHS); + } + P.PCs = NewPCs; + registerSymDBVar(); + registerWidthConstraints(); + } + + void registerWidthConstraints() { + for (auto &&PC : P.PCs) { + if (PC.LHS->K == Inst::Eq && PC.LHS->Ops[0]->K == Inst::BitWidth) { + // PC.LHS looks like (width %x) == 32 + WidthConstraints[PC.LHS->Ops[0]->Ops[0]] = PC.LHS->Ops[1]->Val.getZExtValue(); + } + } + } + + void registerSymDBVar() { + if (P.Mapping.LHS->K == Inst::DemandedMask) { + Syms[P.Mapping.LHS->Ops[1]] = "@db"; + 
assert(P.Mapping.RHS->K == Inst::DemandedMask && "Expected RHS to be a demanded mask."); + assert(P.Mapping.LHS->Ops[1] == P.Mapping.RHS->Ops[1] && "Expected same mask."); + P.Mapping.LHS = P.Mapping.LHS->Ops[0]; + P.Mapping.RHS = P.Mapping.RHS->Ops[0]; + } + } + + bool registerSymDFVars(Inst *I) { + if (I->K == Inst::KnownOnesP && I->Ops[0]->K == Inst::Var && + I->Ops[1]->Name.starts_with("symDF_K")) { + Syms[I->Ops[1]] = I->Ops[0]->Name + ".k1"; + // VisitedVars.insert(I->Ops[1]->Name); + return true; + } + if (I->K == Inst::KnownZerosP && I->Ops[0]->K == Inst::Var && + I->Ops[1]->Name.starts_with("symDF_K")) { + Syms[I->Ops[1]] = I->Ops[0]->Name + ".k0"; + // VisitedVars.insert(I->Ops[1]->Name); + return true; + } + return false; + } + + void countUses(Inst *I) { + for (auto &&Op : I->Ops) { + if (Op->K != Inst::Var && Op->K != Inst::Const) { + UseCount[Op]++; + } + countUses(Op); + } + } + + template + void operator()(Stream &S) { + if (!P.PCs.empty()) { + printPCs(S); + S << "\n |= \n"; + } + S << printInst(P.Mapping.LHS, S, true); + if (!P.Mapping.LHS->DemandedBits.isAllOnesValue()) { + S << " (" << "demandedBits=" + << Inst::getDemandedBitsString(P.Mapping.LHS->DemandedBits) + << ")"; + } + S << "\n =>\n"; + + S << printInst(P.Mapping.RHS, S, true) << "\n"; + } + + template + std::string printInst(Inst *I, Stream &S, bool Root = false) { + if (Syms.count(I)) { + return Syms[I]; + } + + std::ostringstream OS; + + if (UseCount[I] > 1) { + std::string Name = "var" + std::to_string(varnum++); + Syms[I] = Name; + OS << "let " << Name << " = "; + } + + // x ^ -1 => ~x + if (I->K == Inst::Xor && I->Ops[1]->K == Inst::Const && + I->Ops[1]->Val.isAllOnesValue()) { + return "~" + printInst(I->Ops[0], S); + } + if (I->K == Inst::Xor && I->Ops[0]->K == Inst::Const && + I->Ops[0]->Val.isAllOnesValue()) { + return "~" + printInst(I->Ops[1], S); + } + + if (I->K == Inst::Const) { + if (I->Val.ule(16)) { + return I->Val.toString(10, false); + } else { + return "0x" + 
I->Val.toString(16, false); + } + } else if (I->K == Inst::Var) { + auto Name = I->Name; + if (isDigit(Name[0])) { + Name = "x" + Name; + } + if (I->Name.starts_with("symconst_")) { + Name = "C" + I->Name.substr(9); + } + if (VisitedVars.count(I->Name)) { + return Name; + } else { + VisitedVars.insert(I->Name); + Inst::getKnownBitsString(I->KnownZeros, I->KnownOnes); + + std::string Buf; + llvm::raw_string_ostream Out(Buf); + + if (I->KnownZeros.getBoolValue() || I->KnownOnes.getBoolValue()) + Out << " (knownBits=" << Inst::getKnownBitsString(I->KnownZeros, I->KnownOnes) + << ")"; + if (I->NonNegative) + Out << " (nonNegative)"; + if (I->Negative) + Out << " (negative)"; + if (I->NonZero) + Out << " (nonZero)"; + if (I->PowOfTwo) + Out << " (powerOfTwo)"; + if (I->NumSignBits > 1) + Out << " (signBits=" << I->NumSignBits << ")"; + if (!I->Range.isFullSet()) + Out << " (range=[" << I->Range.getLower() + << "," << I->Range.getUpper() << "))"; + + std::string W = ShowImplicitWidths ? ":i" + std::to_string(I->Width) : ""; + + if (WidthConstraints.count(I)) { + W = ":i" + std::to_string(WidthConstraints[I]); + } + + return Name + W + Out.str(); + } + } else { + std::string Op; + switch (I->K) { + case Inst::Add: Op = "+"; break; + case Inst::AddNSW: Op = "+nsw"; break; + case Inst::AddNUW: Op = "+nuw"; break; + case Inst::AddNW: Op = "+nw"; break; + case Inst::Sub: Op = "-"; break; + case Inst::SubNSW: Op = "-nsw"; break; + case Inst::SubNUW: Op = "-nuw"; break; + case Inst::SubNW: Op = "-nw"; break; + case Inst::Mul: Op = "*"; break; + case Inst::MulNSW: Op = "*nsw"; break; + case Inst::MulNUW: Op = "*nuw"; break; + case Inst::MulNW: Op = "*nw"; break; + case Inst::UDiv: Op = "/u"; break; + case Inst::SDiv: Op = "/s"; break; + case Inst::URem: Op = "\%u"; break; + case Inst::SRem: Op = "\%s"; break; + case Inst::And: Op = "&"; break; + case Inst::Or: Op = "|"; break; + case Inst::Xor: Op = "^"; break; + case Inst::Shl: Op = "<<"; break; + case Inst::ShlNSW: Op = "<K); 
break; + } + + std::string Result; + + std::vector Ops = I->orderedOps(); + + if (Inst::isCommutative(I->K)) { + std::sort(Ops.begin(), Ops.end(), [](Inst *A, Inst *B) { + if (A->K == Inst::Const) { + return false; // c OP expr + } else if (B->K == Inst::Const) { + return true; // expr OP c + } else if (A->K == Inst::Var && B->K != Inst::Var) { + return true; // var OP expr + } else if (A->K != Inst::Var && B->K == Inst::Var) { + return false; // expr OP var + } else if (A->K == Inst::Var && B->K == Inst::Var) { + return A->Name > B->Name; // Tends to put vars before symconsts + } else { + return A->K < B->K; // expr OP expr + } + }); + } + + if (Ops.size() == 2) { + auto Meat = printInst(Ops[0], S) + " " + Op + " " + printInst(Ops[1], S); + Result = Root ? Meat : "(" + Meat + ")"; + } else if (Ops.size() == 1) { + Result = Op + "(" + printInst(Ops[0], S) + ")"; + } + else { + std::string Ret = Root ? "" : "("; + Ret += Op; + Ret += " "; + for (auto &&Op : Ops) { + Ret += printInst(Op, S) + " "; + } + while (Ret.back() == ' ') { + Ret.pop_back(); + } + if (!Root) { + Ret += ")"; + } + Result = Ret; + } + if (UseCount[I] > 1) { + OS << Result << ";\n"; + S << OS.str(); + return Syms[I]; + } else { + return Result; + } + } + } + + template + void printPCs(Stream &S) { + bool first = true; + for (auto &&PC : P.PCs) { + // if (PC.LHS->K == Inst::KnownOnesP || PC.LHS->K == Inst::KnownZerosP) { + // continue; + // } + if (first) { + first = false; + } else { + S << " && \n"; + } + if (PC.RHS->K == Inst::Const && PC.RHS->Val == 0) { + S << "!(" << printInst(PC.LHS, S, true) << ")"; + } else if (PC.RHS->K == Inst::Const && PC.RHS->Val == 1) { + S << printInst(PC.LHS, S, true); + } else { + S << printInst(PC.LHS, S, true) << " == " << printInst(PC.RHS, S); + } + } + } + + ParsedReplacement P; + std::set VisitedVars; + std::map Syms; + size_t varnum; + std::map UseCount; + std::map WidthConstraints; + bool ShowImplicitWidths; +}; + +using ConstMapT = std::vector>; +std::pair 
+AugmentForSymKBDB(ParsedReplacement Original, InstContext &IC) { + auto Input = Clone(Original, IC); + std::vector> ConstMap; + if (Input.Mapping.LHS->DemandedBits.getBitWidth() == Input.Mapping.LHS->Width && + !Input.Mapping.LHS->DemandedBits.isAllOnesValue()) { + auto DB = Input.Mapping.LHS->DemandedBits; + auto SymDFVar = IC.createVar(DB.getBitWidth(), "symDF_DB"); + // SymDFVar->Name = "symDF_DB"; + + SymDFVar->KnownOnes = llvm::APInt(DB.getBitWidth(), 0); + SymDFVar->KnownZeros = llvm::APInt(DB.getBitWidth(), 0); + // SymDFVar->Val = DB; + + Input.Mapping.LHS->DemandedBits.setAllBits(); + Input.Mapping.RHS->DemandedBits.setAllBits(); + + auto W = Input.Mapping.LHS->Width; + + Input.Mapping.LHS = IC.getInst(Inst::DemandedMask, W, {Input.Mapping.LHS, SymDFVar}); + Input.Mapping.RHS = IC.getInst(Inst::DemandedMask, W, {Input.Mapping.RHS, SymDFVar}); + + ConstMap.push_back({SymDFVar, DB}); + } + + std::vector Inputs; + findVars(Input.Mapping.LHS, Inputs); + + for (auto &&I : Inputs) { + auto Width = I->Width; + if (I->KnownZeros.getBitWidth() == I->Width && + I->KnownOnes.getBitWidth() == I->Width && + !(I->KnownZeros == 0 && I->KnownOnes == 0)) { + if (I->KnownZeros != 0) { + Inst *Zeros = IC.createVar(Width, "symDF_K0"); + + // Inst *AllOnes = IC.getConst(llvm::APInt::getAllOnesValue(Width)); + // Inst *NotZeros = IC.getInst(Inst::Xor, Width, + // {Zeros, AllOnes}); + // Inst *VarNotZero = IC.getInst(Inst::Or, Width, {I, NotZeros}); + // Inst *ZeroBits = IC.getInst(Inst::Eq, 1, {VarNotZero, NotZeros}); + Inst *ZeroBits = IC.getInst(Inst::KnownZerosP, 1, {I, Zeros}); + Input.PCs.push_back({ZeroBits, IC.getConst(llvm::APInt(1, 1))}); + ConstMap.push_back({Zeros, I->KnownZeros}); + I->KnownZeros = llvm::APInt(I->Width, 0); + } + + if (I->KnownOnes != 0) { + Inst *Ones = IC.createVar(Width, "symDF_K1"); + // Inst *VarAndOnes = IC.getInst(Inst::And, Width, {I, Ones}); + // Inst *OneBits = IC.getInst(Inst::Eq, 1, {VarAndOnes, Ones}); + Inst *OneBits = 
IC.getInst(Inst::KnownOnesP, 1, {I, Ones}); + Input.PCs.push_back({OneBits, IC.getConst(llvm::APInt(1, 1))}); + ConstMap.push_back({Ones, I->KnownOnes}); + I->KnownOnes = llvm::APInt(I->Width, 0); + } + } + } + + return {ConstMap, Input}; +} + +bool typeCheck(Inst *I) { + if (I->Ops.size() == 2) { + if (I->Ops[0]->Width != I->Ops[1]->Width) { + if (DebugLevel > 4) llvm::errs() << "Operands must have the same width\n"; + return false; + } + } + if (I->K == Inst::Select) { + if (I->Ops[0]->Width != 1) { + if (DebugLevel > 4) llvm::errs() << "Select condition must be 1 bit wide\n"; + return false; + } + if (I->Ops[1]->Width != I->Ops[2]->Width) { + if (DebugLevel > 4) llvm::errs() << "Select operands must have the same width\n"; + return false; + } + } + if (Inst::isCmp(I->K)) { + if (I->Width != 1) { + if (DebugLevel > 4) llvm::errs() << "Comparison must be 1 bit wide\n"; + return false; + } + } + if (I->K == Inst::Trunc) { + if (I->Ops[0]->Width <= I->Width) { + if (DebugLevel > 4) llvm::errs() << "Trunc operand must be wider than result\n"; + return false; + } + } + if (I->K == Inst::ZExt || I->K == Inst::SExt) { + if (I->Ops[0]->Width >= I->Width) { + if (DebugLevel > 4) llvm::errs() << "Ext operand must be narrower than result\n"; + return false; + } + } + return true; +} +bool typeCheck(ParsedReplacement &R) { + if (R.Mapping.LHS->Width != R.Mapping.RHS->Width) { + if (DebugLevel > 4) llvm::errs() << "LHS and RHS must have the same width\n"; + return false; + } + + if (!typeCheck(R.Mapping.LHS)) { + return false; + } + if (!typeCheck(R.Mapping.RHS)) { + return false; + } + + for (auto &&PC : R.PCs) { + if (!typeCheck(PC.LHS)) { + return false; + } + if (!typeCheck(PC.RHS)) { + return false; + } + } + return true; +} + +struct ShrinkWrap { + ShrinkWrap(InstContext &IC, Solver *S, ParsedReplacement Input, + size_t TargetWidth = 8) : IC(IC), S(S), Input(Input), + TargetWidth(TargetWidth) {} + InstContext &IC; + Solver *S; + ParsedReplacement Input; + size_t 
TargetWidth; + + std::map InstCache; + + Inst *ShrinkInst(Inst *I, Inst *Parent, size_t ResultWidth) { + if (InstCache.count(I)) { + return InstCache[I]; + } + if (I->K == Inst::Var) { + if (I->Width == 1) { + return I; + } + auto V = IC.createVar(ResultWidth, I->Name); + InstCache[I] = V; + return V; + } else if (I->K == Inst::Const) { + if (I->Width == 1) { + return I; + } + // Treat 0, 1, and -1 specially + if (I->Val.getLimitedValue() == 0) { + auto C = IC.getConst(APInt(ResultWidth, 0)); + InstCache[I] = C; + return C; + } else if (I->Val.getLimitedValue() == 1) { + auto C = IC.getConst(APInt(ResultWidth, 1)); + InstCache[I] = C; + return C; + } else if (I->Val.isAllOnesValue()) { + auto C = IC.getConst(APInt::getAllOnesValue(ResultWidth)); + InstCache[I] = C; + return C; + } else { + auto C = IC.createSynthesisConstant(ResultWidth, I->Val.getLimitedValue()); + InstCache[I] = C; + return C; + } + } else { + if (I->K == Inst::Trunc) { + size_t Target = 0; + // llvm::errs() << "HERE: " << I->Width << " " << I->Ops[0]->Width << '\n'; + if (I->Ops[0]->Width == I->Width + 1) { + // llvm::errs() << "a\n"; + Target = ResultWidth + 1; + } else if (I->Ops[0]->Width == 2 * I->Width) { + Target = ResultWidth * 2; + // llvm::errs() << "b\n"; + } else if (I->Width == 1 && I->Ops[0]->Width != 1) { + // llvm::errs() << "c\n"; + Target = TargetWidth; + ResultWidth = 1; + } else { + // Maintain ratio + // llvm::errs() << "d\n"; + Target = ResultWidth * I->Ops[0]->Width * 1.0 / I->Width; + } + // llvm::errs() << "HERE: " << ResultWidth << " " << Target << '\n'; + return IC.getInst(Inst::Trunc, ResultWidth, { ShrinkInst(I->Ops[0], I, Target)}); + } + if (I->K == Inst::ZExt || I->K == Inst::SExt) { + size_t Target = 0; + if (I->Ops[0]->Width == I->Width - 1) { + Target = ResultWidth - 1; + } else if (I->Ops[0]->Width == I->Width / 2) { + Target = ResultWidth / 2; + } else if (I->Ops[0]->Width == 1) { + Target = 1; + } else { + // Maintain ratio + Target = ResultWidth * 
I->Ops[0]->Width * 1.0 / I->Width; + } + return IC.getInst(I->K, ResultWidth, { ShrinkInst(I->Ops[0], I, Target)}); + } + + if (I->K == Inst::Eq || I->K == Inst::Ne || + I->K == Inst::Ult || I->K == Inst::Slt || + I->K == Inst::Ule || I->K == Inst::Sle || + I->K == Inst::KnownOnesP || I->K == Inst::KnownZerosP) { + ResultWidth = TargetWidth; + } + + + // print kind name + // llvm::errs() << "Par " << Inst::getKindName(I->K) << " " << I->Width << " " << ResultWidth << '\n'; + + std::map OpMap; + std::vector OriginalOps = I->Ops; + + std::sort(OriginalOps.begin(), OriginalOps.end(), [&](Inst *A, Inst *B) { + if (InstCache.find(A) != InstCache.end()) { + return true; + } + if (A->K == Inst::Const) { + return false; + } + + if (A->K == Inst::Var && !(B->K == Inst::Var || B->K == Inst::Const)) { + return false; + } + return A->Width > B->Width; + }); + + for (auto Op : OriginalOps) { + OpMap[Op] = ShrinkInst(Op, I, ResultWidth); + if (Op->Width != 1) { + ResultWidth = OpMap[Op]->Width; + } + } + + std::vector Ops; + for (auto Op : I->Ops) { + Ops.push_back(OpMap[Op]); + } + + return IC.getInst(I->K, InferWidth(I->K, Ops), Ops); + } + } + + std::optional operator()() { + + auto [CM, Aug] = AugmentForSymKBDB(Input, IC); + + if (!CM.empty()) { + Input = Aug; + } + + // Abort if inputs are of <= Target width + std::vector Inputs; + // TODO: Is there a better decision here? 
+ findVars(Input.Mapping.LHS, Inputs); + for (auto I : Inputs) { + if (I->Width <= TargetWidth) { + return {}; + } + if (!I->Range.isFullSet()) { + return {}; + } + } + + ParsedReplacement New; + New.Mapping.LHS = ShrinkInst(Input.Mapping.LHS, nullptr, TargetWidth); + New.Mapping.RHS = ShrinkInst(Input.Mapping.RHS, nullptr, TargetWidth); + for (auto PC : Input.PCs) { + New.PCs.push_back({ShrinkInst(PC.LHS, nullptr, TargetWidth), + ShrinkInst(PC.RHS, nullptr, TargetWidth)}); + } + + // New.print(llvm::errs(), true); + if (!typeCheck(New)) { + llvm::errs() << "Type check failed\n"; + return {}; + } + return Verify(New, IC, S); + } +}; + +std::vector findConcreteConsts(const ParsedReplacement &Input) { + std::vector Consts; + auto Pred = [](Inst *I) { + return I->K == Inst::Const && I->Name.find("sym") == std::string::npos; + }; + + findInsts(Input.Mapping.LHS, Consts, Pred); + findInsts(Input.Mapping.RHS, Consts, Pred); + std::set ResultSet; // For deduplication + for (auto &&C : Consts) { + ResultSet.insert(C); + } + std::vector Result; + for (auto &&C : ResultSet) { + Result.push_back(C); + } + return Result; +} + +std::vector FilterExprsByValue(const std::vector &Exprs, + llvm::APInt TargetVal, const std::vector> &CMap) { + std::unordered_map ValueCache; + for (auto &&[I, V] : CMap) { + ValueCache[I] = EvalValue(V); + } + std::vector FilteredExprs; + ConcreteInterpreter CPos(ValueCache); + for (auto &&E : Exprs) { + auto Result = CPos.evaluateInst(E); + if (!Result.hasValue()) { + // Don't want to drop a candidate just because we couldn't evaluate it + FilteredExprs.push_back(E); + } else { + if (Result.getValue() == TargetVal) { + FilteredExprs.push_back(E); + } + } + } + return FilteredExprs; +} + +std::vector FilterRelationsByValue(const std::vector &Relations, + const std::vector> &CMap, + std::vector CEXs) { + std::unordered_map ValueCache; + for (auto &&[I, V] : CMap) { + ValueCache[I] = EvalValue(V); + } + + ConcreteInterpreter CPos(ValueCache); + 
std::vector CNegs; + for (auto &&CEX : CEXs) { + CNegs.push_back(CEX); + } + + std::vector FilteredRelations; + for (auto &&R : Relations) { + auto Result = CPos.evaluateInst(R); + // Positive example + if (Result.hasValue() && !Result.getValue().isAllOnesValue()) { + continue; + } + + // Negative examples + bool foundUnsound = false; + for (auto &&CNeg : CNegs) { + auto ResultNeg = CNeg.evaluateInst(R); + if (ResultNeg.hasValue() && !ResultNeg.getValue().isNullValue()) { + foundUnsound = true; + break; + } + } + if (foundUnsound) { + continue; + } + FilteredRelations.push_back(R); + } + return FilteredRelations; +} + +std::vector InferConstantLimits( + const std::vector> &CMap, + InstContext &IC, const ParsedReplacement &Input, + std::vector CEXs) { + std::vector Results; + if (!FindConstantRelations) { + return Results; + } + auto ConcreteConsts = findConcreteConsts(Input); + std::sort(ConcreteConsts.begin(), ConcreteConsts.end(), + [](auto A, auto B) { + if (A->Width == B->Width) { + return A->Val.ugt(B->Val); + } else { + return A->Width < B->Width; + } + }); + + std::vector Vars; + findVars(Input.Mapping.LHS, Vars); + + for (auto V : Vars) { + for (auto &&[XI, XC] : CMap) { + if (XI->Width == 1) { + continue; + } + // X < Width, X <= Width + auto Width = Builder(V, IC).BitWidth()(); + + if (XI->Width < V->Width) { + Width = Builder(Width, IC).Trunc(XI->Width)(); + } else if (XI->Width > V->Width) { + Width = Builder(Width, IC).ZExt(XI->Width)(); + } + + Results.push_back(Builder(XI, IC).Ult(Width)()); + Results.push_back(Builder(XI, IC).Ule(Width)()); + + // x ule UMAX + if (V->Width < XI->Width) { + auto UMax = Builder(IC, llvm::APInt(XI->Width, 1)).Shl(Width).Sub(1); + Results.push_back(Builder(XI, IC).Ule(UMax)()); + } + + // X ule SMAX + auto WM1 = Builder(Width, IC).Sub(1); + auto SMax = Builder(IC, llvm::APInt(XI->Width, 1)).Shl(WM1).Sub(1)(); + Results.push_back(Builder(XI, IC).Ule(SMax)()); + + auto gZ = Builder(XI, IC).Ugt(0)(); + 
Results.push_back(Builder(XI, IC).Ult(Width).And(gZ)()); + Results.push_back(Builder(XI, IC).Ule(Width).And(gZ)()); + + // 2 * X < C, 2 * X >= C + for (auto C : ConcreteConsts) { + if (C->Width != XI->Width) { + continue; + } + auto Sum = Builder(XI, IC).Add(XI)(); + Results.push_back(Builder(Sum, IC).Ult(C->Val)()); + Results.push_back(Builder(Sum, IC).Ugt(C->Val)()); + } + } + } + + + for (auto &&[XI, XC] : CMap) { + for (auto &&[YI, YC] : CMap) { + if (XI == YI) { + continue; + } + if (XI->Width != YI->Width) { + continue; + } + auto Sum = Builder(XI, IC).Add(YI)(); + // // Sum related to width + // auto Width = Builder(Sum, IC).BitWidth(); + // Results.push_back(Builder(Sum, IC).Ult(Width)()); + // Results.push_back(Builder(Sum, IC).Ule(Width)()); + // Results.push_back(Builder(Sum, IC).Eq(Width)()); + + // Sum less than const, Sum greater= than const + for (auto C : ConcreteConsts) { + if (Sum->Width != C->Width) { + continue; + } + Results.push_back(Builder(Sum, IC).Ult(C->Val)()); + Results.push_back(Builder(Sum, IC).Ugt(C->Val)()); + } + } + } + + return FilterRelationsByValue(Results, CMap, CEXs); +} + +// Enforce commutativity to prune search space +bool comm(Inst *A, Inst *B, Inst *C) { + return A > B && B > C; +} +bool comm(Inst *A, Inst *B) { + return A > B; +} + +std::vector BitFuncs(Inst *I, InstContext &IC) { + std::vector Results; + Results.push_back(Builder(I, IC).CtPop()()); + Results.push_back(Builder(I, IC).Ctlz()()); + Results.push_back(Builder(I, IC).Cttz()()); + + auto Copy = Results; + for (auto &&C : Copy) { + Results.push_back(Builder(C, IC).BitWidth().Sub(C)()); + } + + return Results; +} + +// This was originally intended to find relational constraints +// but we also use to fine some ad hoc constraints now. 
+// TODO: Filter relations by concrete interpretation +#define C2 comm(XI, YI) +#define C3 comm(XI, YI, ZI) + +std::vector InferPotentialRelations( + const std::vector> &CMap, + InstContext &IC, const ParsedReplacement &Input, std::vector CEXs, + bool LatticeChecks = false) { + std::vector Results; + if (!FindConstantRelations) { + return Results; + } + + + // if (DebugLevel) { + // llvm::errs() << "Symconsts for rels: " << CMap.size() << "\n"; + // } + // Triple rels + if (CMap.size() >= 3) { + for (auto &&[XI, XC] : CMap) { + for (auto &&[YI, YC] : CMap) { + for (auto &&[ZI, ZC] : CMap) { + if (XI == YI || XI == ZI || YI == ZI) { + continue; + } + if (XC.getBitWidth() != YC.getBitWidth() || + XC.getBitWidth() != ZC.getBitWidth()) { + continue; + } + + if (C3 && (XC | YC | ZC).isAllOnesValue()) { + Results.push_back(Builder(XI, IC).Or(YI).Or(ZI) + .Eq(llvm::APInt::getAllOnesValue(XI->Width))()); + } + + if (C3 && (XC & YC & ZC) == 0) { + Results.push_back(Builder(XI, IC).And(YI).And(ZI) + .Eq(llvm::APInt(XI->Width, 0))()); + } + + // TODO Make width independent by using bitwidth insts + if (C2 && (XC | YC | ~ZC).isAllOnesValue()) { + Results.push_back(Builder(XI, IC).Or(YI).Or(Builder(ZI, IC).Flip()) + .Eq(llvm::APInt::getAllOnesValue(XI->Width))()); + } + + if (XC << YC == ZC) { + Results.push_back(Builder(XI, IC).Shl(YI).Eq(ZI)()); + } + + if (XC.lshr(YC) == ZC) { + Results.push_back(Builder(XI, IC).LShr(YI).Eq(ZI)()); + } + + // if (C2 && (XC & YC).eq(ZC)) { + // Results.push_back(Builder(XI, IC).And(YI).Eq(ZI)()); + // } + + // if (C2 && (XC | YC).eq(ZC)) { + // Results.push_back(Builder(XI, IC).Or(YI).Eq(ZI)()); + // } + + // if (C2 && (XC ^ YC).eq(ZC)) { + // Results.push_back(Builder(XI, IC).Xor(YI).Eq(ZI)()); + // } + + // if (C2 && (XC != 0 && YC != 0) && (XC + YC).eq(ZC)) { + // Results.push_back(Builder(XI, IC).Add(YI).Eq(ZI)()); + // } + + } + } + } + } + + // Pairwise relations + for (auto &&[XI, XC] : CMap) { + // llvm::errs() << "HERE: " << XC << 
"\n"; + for (auto &&[YI, YC] : CMap) { + if (XI == YI || XC.getBitWidth() != YC.getBitWidth()) { + continue; + } + + if (~XC == YC) { + Results.push_back(Builder(XI, IC).Flip().Eq(YI)()); + } + + // if (C2 && XC == YC) { + // Results.push_back(Builder(XI, IC).Eq(YI)()); + // } + + // if ((XC & YC) == XC) { + // Results.push_back(Builder(XI, IC).And(YI).Eq(XI)()); + + // } + + // if ((XC & YC) == YC) { + // auto W = XI->Width; + // Results.push_back(IC.getInst(Inst::KnownOnesP, W, {XI, YI})); + // } + + // TODO guard + // Results.back()->Print(); + + // Results.push_back(IC.getInst(Inst::KnownZerosP, W, {XI, YI})); + + // todo knownzerosp + + // if ((XC | YC) == XC) { + // Results.push_back(Builder(XI, IC).Or(YI).Eq(XI)()); + // } + + // if ((XC | YC) == YC) { + // Results.push_back(Builder(XI, IC).Or(YI).Eq(YI)()); + // } + + // Mul C + if (C2 && YC!= 0 && XC.urem(YC) == 0) { + auto Fact = XC.udiv(YC); + if (Fact != 1 && Fact != 0) { + Results.push_back(Builder(YI, IC).Mul(Fact).Eq(XI)()); + } + } + + // Add C + // auto Diff = XC - YC; + // if (Diff != 0) { + // Results.push_back(Builder(XI, IC).Sub(Diff).Eq(YI)()); + // } + + if (C2 && XC != 0 && YC.urem(XC) == 0) { + auto Fact = YC.udiv(XC); + if (Fact != 1 && Fact != 0) { + Results.push_back(Builder(XI, IC).Mul(Fact).Eq(YI)()); + } + } + + auto One = llvm::APInt(XC.getBitWidth(), 1); + + auto GENComps = [&] (Inst *A, llvm::APInt AVal, Inst *B, llvm::APInt BVal) { + if (AVal.sle(BVal)) Results.push_back(Builder(A, IC).Sle(B)()); + if (AVal.ule(BVal)) Results.push_back(Builder(A, IC).Ule(B)()); + if (AVal.slt(BVal)) Results.push_back(Builder(A, IC).Slt(B)()); + if (AVal.ult(BVal)) Results.push_back(Builder(A, IC).Ult(B)()); + }; + + GENComps(XI, XC, YI, YC); + // GENComps(Builder(IC, One).Shl(XI)() , One.shl(XC), YI, YC); + // GENComps(XI, XC, Builder(IC, One).Shl(YI)() , One.shl(YC)); + + + // auto XBits = BitFuncs(XI, IC); + // auto YBits = BitFuncs(YI, IC); + + // for (auto &&XBit : XBits) { + // for (auto 
&&YBit : YBits) { + // Results.push_back(Builder(XBit, IC).Ule(YBit)()); + // Results.push_back(Builder(XBit, IC).Ult(YBit)()); + // } + // } + + // No example yet where this is useful + // for (auto &&XBit : XBits) { + // for (auto &&YBit : YBits) { + // Results.push_back(Builder(XBit, IC).Ne(YBit)()); + // Results.push_back(Builder(XBit, IC).Eq(YBit)()); + // } + // } + + } + Results.push_back(Builder(XI, IC).Eq(Builder(XI, IC).BitWidth().Sub(1))()); + // Results.push_back(Builder(XI, IC).Eq(Builder(XI, IC).BitWidth().UDiv(2))()); + // Results.push_back(Builder(XI, IC).Eq(Builder(XI, IC).BitWidth())()); + } + + // TODO: Make sure this works. + for (auto &&[XI, XC] : CMap) { + for (auto &&[YI, YC] : CMap) { + if (XI == YI || XC.getBitWidth() == YC.getBitWidth()) { + continue; + } + + // llvm::errs() << "HERE: " << XI->Name << ' ' << YI->Name << ' ' << XC.getLimitedValue() << ' ' << YC.getLimitedValue() << '\n'; + + // llvm::errs() << "HERE: " << XC.getLimitedValue() << ' ' << YC.getLimitedValue() << '\n'; + if (XC.getLimitedValue() == YC.getLimitedValue()) { + if (XI->Width > YI->Width) { + // Builder(YI, IC).ZExt(XI->Width).Eq(XI)()->Print(); + Results.push_back(Builder(YI, IC).ZExt(XI->Width).Eq(XI)()); + } else { + Results.push_back(Builder(XI, IC).ZExt(YI->Width).Eq(YI)()); + } + } + } + } + + // for (auto R : InferConstantLimits(CMap, IC, Input)) { + // Results.push_back(R); + // } + // llvm::errs() << "HERE: " << Results.size() << '\n'; + Results = FilterRelationsByValue(Results, CMap, CEXs); + + if (LatticeChecks) { + // TODO Less brute force + for (auto &&[XI, XC] : CMap) { + for (auto &&[YI, YC] : CMap) { + if (XI == YI || XC.getBitWidth() != YC.getBitWidth()) { + continue; + } + Results.push_back(IC.getInst(Inst::KnownOnesP, 1, {XI, YI})); + Results.push_back(IC.getInst(Inst::KnownZerosP, 1, {XI, YI})); + } + } + } + + return Results; +} + +std::set findConcreteConsts(Inst *I) { + std::vector Results; + std::set Ret; + auto Pred = [](Inst *I) {return 
I->K == Inst::Const;}; + findInsts(I, Results, Pred); + for (auto R : Results) { + Ret.insert(R); + } + return Ret; +} + +std::optional DFPreconditionsAndVerifyGreedy( + ParsedReplacement Input, InstContext &IC, Solver *S, + std::map SymCS) { + + std::map> Restore; + + size_t BitsWeakened = 0; + + auto Clone = souper::Clone(Input, IC); + + for (auto &&C : SymCS) { + if (C.first->Width < 8) continue; + Restore[C.first] = {C.first->KnownZeros, C.first->KnownOnes}; + C.first->KnownZeros = ~C.second; + C.first->KnownOnes = C.second; + } + + std::optional Ret; + auto SOLVE = [&]() -> bool { + Ret = Verify(Input, IC, S); + if (Ret) { + return true; + } else { + return false; + } + }; + + for (auto &&C : SymCS) { + if (C.first->Width < 8) continue; + for (size_t i = 0; i < C.first->Width; ++i) { + llvm::APInt OriZ = C.first->KnownZeros; + llvm::APInt OriO = C.first->KnownOnes; + + if (OriO[i] == 0 && OriZ[i] == 0) { + continue; + } + + if (OriO[i] == 1) C.first->KnownOnes.clearBit(i); + if (OriZ[i] == 1) C.first->KnownZeros.clearBit(i); + + if (!SOLVE()) { + C.first->KnownZeros = OriZ; + C.first->KnownOnes = OriO; + } else { + BitsWeakened++; + } + } + } + +// llvm::errs() << "HERE " << BitsWeakened << "\n"; + if (BitsWeakened >= 32) { // compute better threshold somehow + return Input; + } else { + for (auto &&P : Restore) { + P.first->KnownZeros = P.second.first; + P.first->KnownOnes = P.second.second; + } + return Ret; + } + +} + +std::optional SimplePreconditionsAndVerifyGreedy( + ParsedReplacement Input, InstContext &IC, + Solver *S, std::map SymCS) { + // Assume Input is not valid + std::map NonBools; + for (auto &&C : SymCS) { + if (C.first->Width != 1) { + NonBools.insert(C); + } + } + std::swap(SymCS, NonBools); + + std::optional Clone = std::nullopt; + + auto SOLVE = [&]() -> bool { + Clone = Verify(Input, IC, S); + if (Clone) { + return true; + } else { + return false; + } + }; + + std::vector Insts; + findVars(Input.Mapping.LHS, Insts); + + std::vector> 
Inputs; + Inputs.push_back({}); + for (auto &&P : SymCS) { + Inputs.back()[P.first] = P.second; + } + + std::map> CVals; + + for (auto &&I : Inputs) { + for (auto &&P: I) { + CVals[P.first].push_back(P.second); + } + } + +#define DF(Fact, Check) \ +if (All(CVals[C], [](auto Val) { return Check;})) { \ +C->Fact = true; auto s = SOLVE(); C->Fact = false; \ +if(s) return Clone;}; + + for (auto &&P : SymCS) { + auto C = P.first; + DF(PowOfTwo, Val.isPowerOf2()); // Invoke solver only if Val is a power of 2 + DF(NonNegative, Val.uge(0)); + DF(NonZero, Val != 0); + DF(Negative, Val.slt(0)); + } +#undef DF + + return Clone; +} + +size_t BruteForceModelCount(Inst *Pred) { + if (Pred->Width >= 8) { + llvm::errs() << "Too wide for brute force model counting.\n"; + return 0; + } + + std::vector Inputs; + findVars(Pred, Inputs); + + ValueCache Cache; + for (auto I : Inputs) { + Cache[I] = EvalValue(llvm::APInt(I->Width, 0)); + } + + auto Update = [&]() { + for (auto I : Inputs) { + if (Cache[I].getValue() == llvm::APInt(I->Width, -1)) { + continue; + } else { + Cache[I] = EvalValue(Cache[I].getValue() + 1); + return true; + } + } + return false; + }; + + size_t ModelCount = 0; + + do { + ConcreteInterpreter CI(Cache); + + if (CI.evaluateInst(Pred).hasValue() && + CI.evaluateInst(Pred).getValue().getBoolValue()) { + ++ModelCount; + } + } while (Update()); + + return ModelCount; +} + +void SortPredsByModelCount(std::vector &Preds) { + std::unordered_map ModelCounts; + for (auto P : Preds) { + ModelCounts[P] = BruteForceModelCount(P); + } + std::sort(Preds.begin(), Preds.end(), [&](Inst *A, Inst *B) { + return ModelCounts[A] > ModelCounts[B]; + }); +} + +std::optional VerifyWithRels(InstContext &IC, Solver *S, + ParsedReplacement Input, + std::vector &Rels) { + std::vector ValidRels; + + ParsedReplacement FirstValidResult = Input; + + for (auto Rel : Rels) { + Input.PCs.push_back({Rel, IC.getConst(llvm::APInt(1, 1))}); + auto Clone = Verify(Input, IC, S); + + // InfixPrinter 
IP(Input); + // IP(llvm::errs()); + + // llvm::errs() << "RESULT: " << Clone.has_value() << "\n"; + + if (Clone) { + ValidRels.push_back(Rel); + FirstValidResult = Clone.value(); + if (Rels.size() > 10) { + return Clone; + } + } + Input.PCs.pop_back(); + } + + if (ValidRels.empty()) { + return std::nullopt; + } + + std::vector Inputs; + findVars(Input.Mapping.LHS, Inputs); + + size_t MaxWidth = 0; + for (auto I : Inputs) { + if (I->Width > MaxWidth) { + MaxWidth = I->Width; + } + } + + if (MaxWidth > 8) { + if (DebugLevel > 4) { + llvm::errs() << "Too wide for brute force model counting.\n"; + } + // TODO: Use approximate model counting? + return FirstValidResult; + } + + SortPredsByModelCount(ValidRels); + // TODO: Construct WP + // For now, return the weakest valid result + Input.PCs.push_back({ValidRels[0], IC.getConst(llvm::APInt(1, 1))}); + return Input; +} + + +std::optional +FirstValidCombination(ParsedReplacement Input, + const std::vector &Targets, + const std::vector> &Candidates, + std::map InstCache, + InstContext &IC, Solver *S, + std::map SymCS, + bool GEN, + bool SDF, + bool DFF, + std::vector Rels = {}) { + std::vector Counts; + for (auto &&Cand : Candidates) { + Counts.push_back(Cand.size()); + } + + auto Combinations = GetCombinations(Counts); + + size_t IterLimit = 2000; + size_t CurIter = 0; + + for (auto &&Comb : Combinations) { + if (CurIter >= IterLimit) { + break; + } else { + CurIter++; + } + + static int SymExprCount = 0; + auto InstCacheRHS = InstCache; + + std::vector VarsFound; + + for (int i = 0; i < Targets.size(); ++i) { + InstCacheRHS[Targets[i]] = Candidates[i][Comb[i]]; + findVars(Candidates[i][Comb[i]], VarsFound); + if (Candidates[i][Comb[i]]->K != Inst::Var) { + Candidates[i][Comb[i]]->Name = std::string("constexpr_") + std::to_string(SymExprCount++); + } + } + + std::set SymsInCurrent; + for (auto &&V : VarsFound) { + if (V->Name.starts_with("sym")) { + SymsInCurrent.insert(V); + } + } + + std::map ReverseMap; + + for (auto 
&&[C, Val] : SymCS) { + if (SymsInCurrent.find(C) == SymsInCurrent.end()) { + ReverseMap[C] = Builder(IC, Val)(); + } + } + + std::optional Clone = std::nullopt; + + auto SOLVE = [&](ParsedReplacement P) -> bool { + // InfixPrinter IP(P); + // IP(llvm::errs()); + // llvm::errs() << "\n"; + + if (GEN) { + Clone = Verify(P, IC, S); + if (Clone) { + return true; + } + } + + if (!Rels.empty()) { + auto Result = VerifyWithRels(IC, S, P, Rels); + if (Result) { + Clone = *Result; + return true; + } + } + + if (SDF) { + Clone = SimplePreconditionsAndVerifyGreedy(P, IC, S, SymCS); + + if (Clone) { + return true; + } + } + + if (DFF) { + Clone = DFPreconditionsAndVerifyGreedy(P, IC, S, SymCS); + if (Clone) { + return true; + } + } + + return false; + }; + + auto Copy = Input; + Copy.Mapping.LHS = Replace(Input.Mapping.LHS, IC, InstCacheRHS); + Copy.Mapping.RHS = Replace(Input.Mapping.RHS, IC, InstCacheRHS); + + // Copy.PCs = Input.PCs; + + // Copy.print(llvm::errs(), true); + // llvm::errs() << "\n"; + + if (SOLVE(Copy)) { + return Clone; + } + + if (!ReverseMap.empty()) { + Copy.Mapping.LHS = Replace(Copy.Mapping.LHS, IC, ReverseMap); + Copy.Mapping.RHS = Replace(Copy.Mapping.RHS, IC, ReverseMap); + if (SOLVE(Copy)) { + return Clone; + } + } + + } + + return std::nullopt; +} + + + +std::vector IOSynthesize(llvm::APInt Target, +const std::vector> &ConstMap, +InstContext &IC, size_t Threshold, bool ConstMode, Inst *ParentConst = nullptr) { + + std::vector Results; + + // Handle width changes + for (const auto &[I, Val] : ConstMap) { + if (I == ParentConst) { + continue; + } + if (Target.getBitWidth() == I->Width || !Threshold ) { + continue; + } + + llvm::APInt NewTarget = Target; + if (Target.getBitWidth() < I->Width) { + NewTarget = Target.sgt(0) ? 
Target.zext(I->Width) : Target.sext(I->Width); + } else { + NewTarget = Target.trunc(I->Width); + } + for (auto X : IOSynthesize(NewTarget, ConstMap, IC, Threshold - 1, ConstMode, nullptr)) { + // ReplacementContext RC; + // RC.printInst(X, llvm::errs(), true); + if (NewTarget.getBitWidth() < Target.getBitWidth()) { + Results.push_back(Builder(IC, X).ZExt(Target.getBitWidth())()); + Results.push_back(Builder(IC, X).SExt(Target.getBitWidth())()); + } else { + Results.push_back(Builder(IC, X).Trunc(Target.getBitWidth())()); + } + } + } + + for (const auto &[I, Val] : ConstMap) { + if (I == ParentConst) { + continue; + } + if (I->Width != Target.getBitWidth()) { + continue; + } + if (!ConstMode) { + if (Val == Target) { + Results.push_back(I); + } else if (Val == 0 - Target) { + Results.push_back(Builder(IC, I).Negate()()); + } else if (Val == ~Target) { + Results.push_back(Builder(IC, I).Flip()()); + } + + // llvm::errs() << "Trying to synthesize " << Target << " from " << Val << "\n"; + + auto One = llvm::APInt(I->Width, 1); + // llvm::errs() << "1: " << One.shl(Val) << "\n"; + if (One.shl(Val) == Target) { + Results.push_back(Builder(IC, One).Shl(I)()); + } + auto MinusOneVal = llvm::APInt::getAllOnesValue(I->Width); + + auto OneBitOne = llvm::APInt(1, 1); + auto MinusOne = Builder(IC, OneBitOne).SExt(I->Width)(); + + // llvm::errs() << "2: " << MinusOne.shl(Val) << "\n"; + if (MinusOneVal.shl(Val) == Target) { + Results.push_back(Builder(IC, MinusOne).Shl(I)()); + } + // llvm::errs() << "3: " << MinusOne.lshr(Val) << "\n"; + if (MinusOneVal.lshr(Val) == Target) { + Results.push_back(Builder(IC, MinusOne).LShr(I)()); + } + } else { + if (ParentConst) { + Results.push_back(Builder(IC, Target)()); + } + } + } + + if (!Threshold) { + return Results; + } + + // Recursive formulation + + #define for_no_nop(X, x) \ + if (Target != x) for (auto X : \ + IOSynthesize(x, ConstMap, IC, Threshold - 1, ConstMode, ParentConst)) + + for (const auto &[I, Val] : ConstMap) { + if 
(I->Width != Target.getBitWidth()) { + continue; + } + + if (I == ParentConst) { + continue; + } + ParentConst = I; + + // Binary operators + + // C + X == Target + for_no_nop(X, Target - Val) { + Results.push_back(Builder(I, IC).Add(X)()); + } + + // C - X == Target + for_no_nop(X, Val - Target) { + Results.push_back(Builder(I, IC).Sub(X)()); + } + + // X - C == Target + for_no_nop(X, Target + Val) { + Results.push_back(Builder(X, IC).Sub(I)()); + } + + // C * X == Target + if (Val.isNegative() || Target.isNegative()) { + if (Val != 0 && Target.srem(Val) == 0) { + for_no_nop(X, Target.sdiv(Val)) { + Results.push_back(Builder(X, IC).Mul(I)()); + } + } + } else { + if (Val != 0 && Target.urem(Val) == 0) { + for_no_nop(X, Target.udiv(Val)) { + Results.push_back(Builder(X, IC).Mul(I)()); + } + } + } + + // C / X == Target + if (Val.isNegative() || Target.isNegative()) { + if (Target != 0 && Val.srem(Target) == 0) { + for_no_nop(X, Val.sdiv(Target)) { + Results.push_back(Builder(I, IC).SDiv(X)()); + } + } + } else { + if (Target != 0 && Val.urem(Target) == 0) { + for_no_nop(X, Val.udiv(Target)) { + Results.push_back(Builder(I, IC).UDiv(X)()); + } + } + } + + // X / C == Target + + if (Val.isNegative() || Target.isNegative()) { + if (Val != 0 && Target.srem(Val) == 0) { + for_no_nop(X, Val * Target) { + Results.push_back(Builder(X, IC).SDiv(I)()); + } + } + } else { + if (Val != 0 && Target.urem(Val) == 0) { + for_no_nop(X, Val * Target) { + Results.push_back(Builder(X, IC).UDiv(I)()); + } + } + } + + // Shifts? 
+ + // Unary operators (no recursion required) + if (Target == Val.logBase2()) { + Results.push_back(Builder(I, IC).LogB()()); + } + + if (Target == Val.reverseBits()) { + Results.push_back(Builder(I, IC).BitReverse()()); + } + // TODO Add others + + // bit flip + llvm::APInt D = Val; + D.flipAllBits(); + if (Target == D) { + Results.push_back(Builder(I, IC).Xor(llvm::APInt::getAllOnesValue(I->Width))()); + } + + if (Target == D + 1) { + Results.push_back(Builder(I, IC).Xor(llvm::APInt::getAllOnesValue(I->Width)).Add(1)()); + } + + // neg + D = Val; + D.negate(); + if (Target == D && D != Val) { + Results.push_back(Builder(IC, llvm::APInt::getAllOnesValue(I->Width)).Sub(I)()); + } + + for (const auto &[I2, Val2] : ConstMap) { + if (I == I2 || I->Width != I2->Width || I2 == ParentConst) { + continue; + } + if ((Val & Val2) == Target && !Val.isAllOnesValue() && !Val2.isAllOnesValue()) { + Results.push_back(Builder(I, IC).And(I2)()); + } + if ((Val | Val2) == Target && Val != 0 && Val2 != 0) { + Results.push_back(Builder(I, IC).Or(I2)()); + } + if ((Val ^ Val2) == Target && Val != Target && Val2 != Target) { + Results.push_back(Builder(I, IC).Xor(I2)()); + } + } + } + + return Results; +} + +void CountUses(Inst *I, std::map &Count) { + std::vector Stack{I}; + std::set Visited; + while (!Stack.empty()) { + auto *I = Stack.back(); + Stack.pop_back(); + if (Visited.count(I)) { + continue; + } + Visited.insert(I); + for (auto *U : I->Ops) { + if (U->K == Inst::Var) { + Count[U]++; + } + Stack.push_back(U); + } + } +} + +// // Filter candidates to rule out NOPs as much as possible +// std::vector FilterCand(std::vector Cands, +// const std::vector> &ConstMap) { +// return Cands; +// std::vector Results; +// for (auto &&C : Cands) { +// std::map VarCount = CountUses(C); + +// C->Print(); +// for (auto &[I, Count] : VarCount) { +// llvm::errs() << I->Name << " " << Count << "\t"; +// } +// llvm::errs() << "\n\n"; + + +// bool hasDupe = false; +// for (auto &[_, Count] : 
VarCount) { +// if (Count > 4) { +// hasDupe = true; +// break; +// } +// } +// if (hasDupe) { +// continue; +// } + +// Results.push_back(C); +// } +// return Results; +// } + +std::vector> +InferSpecialConstExprsAllSym(std::vector RHS, +const std::vector> &ConstMap, + InstContext &IC, int depth = 3) { + std::vector> Results; + for (auto R : RHS) { + auto Cands = IOSynthesize(R->Val, ConstMap, IC, depth, false); + std::set Temp; + for (auto C : Cands) { + Temp.insert(C); + } + std::vector DedupedCands; + for (auto C : Temp) { + DedupedCands.push_back(C); + } + Results.push_back(DedupedCands); + std::sort(Results.back().begin(), Results.back().end(), + [](Inst *A, Inst *B) { return instCount(A) < instCount(B);}); + } + return Results; +} + +std::pair +AugmentForSymDB(ParsedReplacement Original, InstContext &IC) { + auto Input = Clone(Original, IC); + std::vector> ConstMap; + if (Input.Mapping.LHS->DemandedBits.getBitWidth() == Input.Mapping.LHS->Width && + !Input.Mapping.LHS->DemandedBits.isAllOnesValue()) { + auto DB = Input.Mapping.LHS->DemandedBits; + auto SymDFVar = IC.createVar(DB.getBitWidth(), "symDF_DB"); + // SymDFVar->Name = "symDF_DB"; + + SymDFVar->KnownOnes = llvm::APInt(DB.getBitWidth(), 0); + SymDFVar->KnownZeros = llvm::APInt(DB.getBitWidth(), 0); + // SymDFVar->Val = DB; + + Input.Mapping.LHS->DemandedBits.setAllBits(); + Input.Mapping.RHS->DemandedBits.setAllBits(); + + auto W = Input.Mapping.LHS->Width; + + Input.Mapping.LHS = IC.getInst(Inst::DemandedMask, W, {Input.Mapping.LHS, SymDFVar}); + Input.Mapping.RHS = IC.getInst(Inst::DemandedMask, W, {Input.Mapping.RHS, SymDFVar}); + + ConstMap.push_back({SymDFVar, DB}); + } + return {ConstMap, Input}; +} + +std::pair +AugmentForSymKB(ParsedReplacement Original, InstContext &IC) { + auto Input = Clone(Original, IC); + ConstMapT ConstMap; + std::vector Inputs; + findVars(Input.Mapping.LHS, Inputs); + + for (auto &&I : Inputs) { + auto Width = I->Width; + if (I->KnownZeros.getBitWidth() == I->Width && 
+ I->KnownOnes.getBitWidth() == I->Width && + !(I->KnownZeros == 0 && I->KnownOnes == 0)) { + if (I->KnownZeros != 0) { + Inst *Zeros = IC.createVar(Width, "symDF_K0"); + + // Inst *AllOnes = IC.getConst(llvm::APInt::getAllOnesValue(Width)); + // Inst *NotZeros = IC.getInst(Inst::Xor, Width, + // {Zeros, AllOnes}); + // Inst *VarNotZero = IC.getInst(Inst::Or, Width, {I, NotZeros}); + // Inst *ZeroBits = IC.getInst(Inst::Eq, 1, {VarNotZero, NotZeros}); + Inst *ZeroBits = IC.getInst(Inst::KnownZerosP, 1, {I, Zeros}); + Input.PCs.push_back({ZeroBits, IC.getConst(llvm::APInt(1, 1))}); + ConstMap.push_back({Zeros, I->KnownZeros}); + I->KnownZeros = llvm::APInt(I->Width, 0); + } + + if (I->KnownOnes != 0) { + Inst *Ones = IC.createVar(Width, "symDF_K1"); + // Inst *VarAndOnes = IC.getInst(Inst::And, Width, {I, Ones}); + // Inst *OneBits = IC.getInst(Inst::Eq, 1, {VarAndOnes, Ones}); + Inst *OneBits = IC.getInst(Inst::KnownOnesP, 1, {I, Ones}); + Input.PCs.push_back({OneBits, IC.getConst(llvm::APInt(1, 1))}); + ConstMap.push_back({Ones, I->KnownOnes}); + I->KnownOnes = llvm::APInt(I->Width, 0); + } + } + } + return {ConstMap, Input}; +} + +// // Harvest synthesis sketch from LHS +// std::function)> GetSketch(Inst *LHS) { +// std::vector Inputs; +// findVars(LHS, Inputs); + + +// } + +std::vector> +InferSpecialConstExprsWithConcretes(std::vector RHS, +const std::vector> &ConstMap, + InstContext &IC, int depth = 3) { + std::vector> Results; + for (auto R : RHS) { + auto Cands = IOSynthesize(R->Val, ConstMap, IC, depth, true); + std::vector Filtered; + for (auto Cand : Cands) { + if (Cand->K != Inst::Const) { + Filtered.push_back(Cand); + } + } + Results.push_back(Filtered); + } + return Results; +} + +std::vector> Enumerate(std::vector RHSConsts, + std::set AtomicComps, InstContext &IC, + const std::vector> &ConstMap, + size_t NumInsts = 1) { + std::vector> Candidates; + + std::vector Components; + for (auto &&C : AtomicComps) { + Components.push_back(C); + // 
Components.push_back(Builder(C, IC).BSwap()()); + Components.push_back(Builder(C, IC).LogB()()); + Components.push_back(Builder(C, IC).Sub(1)()); + Components.push_back(Builder(C, IC).Xor(-1)()); + if (SymbolizeHackersDelight) { + Components.push_back(Builder(IC, llvm::APInt::getAllOnesValue(C->Width)).Shl(C)()); + Components.push_back(Builder(IC, llvm::APInt(C->Width, 1)).Shl(C)()); + Components.push_back(Builder(IC, C).BitWidth().Sub(1)()); + Components.push_back(Builder(IC, C).BitWidth().Sub(C)()); + // TODO: Add a few more, we can afford to run generalization longer + } + } + + for (auto &&Target : RHSConsts) { + std::vector CandsForTarget; + EnumerativeSynthesis ES; + auto Guesses = ES.generateExprs(IC, NumInsts, Components, + Target->Width); + for (auto &&Guess : Guesses) { + std::set ConstSet; + souper::getConstants(Guess, ConstSet); + if (!ConstSet.empty()) { + if (SymbolizeConstSynthesis) { + CandsForTarget.push_back(Guess); + } + } else { + CandsForTarget.push_back(Guess); + } + } + // Filter by value + Candidates.push_back(FilterExprsByValue(CandsForTarget, Target->Val, ConstMap)); + } + return Candidates; +} + +void findDangerousConstants(Inst *I, std::set &Results) { + std::set Visited; + std::vector Stack{I}; + while (!Stack.empty()) { + auto Cur = Stack.back(); + Stack.pop_back(); + Visited.insert(Cur); + + // if (Cur->K == Inst::Const && Cur->Val == 0) { + // // Don't try to 'generalize' zero! 
+ // Results.insert(Cur); + // } + + if (Visited.find(Cur) == Visited.end()) { + continue; + } + for (auto Child : Cur->Ops) { + if (Cur->K == Inst::ExtractValue) { + if (Child->K == Inst::Const) { + // Constant operands of ExtractValue instructions + Results.insert(Child); + } + } + Stack.push_back(Child); + } + } +} + +// TODO: memoize +bool hasMultiArgumentPhi(Inst *I) { + if (I->K == Inst::Phi) { + return I->Ops.size() > 1; + } + for (auto Op : I->Ops) { + if (hasMultiArgumentPhi(Op)) { + return true; + } + } + return false; +} + +ParsedReplacement ReduceBasic(InstContext &IC, + Solver *S, ParsedReplacement Input) { + static Reducer R(IC, S); + Input = R.ReducePCs(Input); + Input = R.ReduceRedundantPhis(Input); + Input = R.ReduceGreedy(Input); + Input = R.ReducePairsGreedy(Input); + Input = R.ReduceTriplesGreedy(Input); + Input = R.WeakenKB(Input); + Input = R.WeakenCR(Input); + Input = R.WeakenDB(Input); + Input = R.WeakenOther(Input); + if (ReduceKBIFY) { + Input = R.ReduceGreedyKBIFY(Input); + } + Input = R.ReducePCs(Input); + Input = R.ReducePCsToDF(Input); + Input = R.ReducePoison(Input); + return Input; +} + +ParsedReplacement DeAugment(InstContext &IC, + Solver *S, ParsedReplacement Augmented) { + auto Result = ReduceBasic(IC, S, Augmented); + Inst *SymDBVar = nullptr; + if (Result.Mapping.LHS->K == Inst::DemandedMask) { + SymDBVar = Result.Mapping.LHS->Ops[1]; + } + + if (!SymDBVar) { + return Result; + } + + std::map LHSCount, RHSCount; + CountUses(Result.Mapping.LHS, LHSCount); + for (auto M : Result.PCs) { + CountUses(M.LHS, LHSCount); + CountUses(M.RHS, LHSCount); + } + CountUses(Result.Mapping.RHS, RHSCount); + + + + + if (LHSCount[SymDBVar] == 1 && RHSCount[SymDBVar] == 1) { + // We can remove the SymDBVar + Result.Mapping.LHS = Result.Mapping.LHS->Ops[0]; + Result.Mapping.RHS = Result.Mapping.RHS->Ops[0]; + return Result; + } else { + return Result; + } +} + +// Assuming the input has leaves pruned and preconditions weakened +std::optional 
SuccessiveSymbolize(InstContext &IC, + Solver *S, ParsedReplacement Input, bool &Changed, + std::vector> ConstMap = {}) { + + // Print first successful result and exit, no result sorting. + // Prelude + bool Nested = !ConstMap.empty(); + auto Original = Input; + + auto Fresh = Input; + size_t ticks = std::clock(); + auto Refresh = [&] (auto Msg) { + // Input = Clone(Fresh, IC); + Input = Fresh; + if (DebugLevel > 2) { + auto now = std::clock(); + llvm::errs() << "POST " << Msg << " - " << (now - ticks)*1000/CLOCKS_PER_SEC << " ms\n"; + ticks = now; + } + Changed = true; + }; + + auto LHSConsts = findConcreteConsts(Input.Mapping.LHS); + + auto RHSConsts = findConcreteConsts(Input.Mapping.RHS); + + std::set ConstsBlackList; + findDangerousConstants(Input.Mapping.LHS, ConstsBlackList); + findDangerousConstants(Input.Mapping.RHS, ConstsBlackList); + + for (auto &&C : ConstsBlackList) { + LHSConsts.erase(C); + RHSConsts.erase(C); + } + + ParsedReplacement Result = Input; + + std::map SymConstMap; + + std::map InstCache; + + std::map SymCS; + + static int i = 1; + for (auto I : LHSConsts) { + auto Name = "symconst_" + std::to_string(i++); + SymConstMap[I] = IC.createVar(I->Width, Name); + + // llvm::errs() << "HERE : " << Name << '\t' << SymConstMap[I]->Name << "\n"; + + InstCache[I] = SymConstMap[I]; + SymCS[SymConstMap[I]] = I->Val; + } + for (auto I : RHSConsts) { + if (SymConstMap.find(I) != SymConstMap.end()) { + continue; + } + auto Name = "symconst_" + std::to_string(i++); + SymConstMap[I] = IC.createVar(I->Width, Name); + InstCache[I] = SymConstMap[I]; +// SymCS[SymConstMap[I]] = I->Val; + } + + std::vector RHSFresh; // RHSConsts - LHSConsts + + for (auto C : RHSConsts) { + if (LHSConsts.find(C) == LHSConsts.end()) { + RHSFresh.push_back(C); + } + } + + Refresh("Prelude"); + // Step 1 : Just direct symbolize for common consts, no constraints + + std::map CommonConsts; + for (auto C : LHSConsts) { + CommonConsts[C] = SymConstMap[C]; + + // llvm::errs() << "Common 
Const: " << C->Val << "\t" << SymConstMap[C]->Name << "\n"; + + } + if (!CommonConsts.empty()) { + Result = Replace(Result, IC, CommonConsts); + auto Clone = Verify(Result, IC, S); + if (Clone) { + return Clone; + } + + Clone = SimplePreconditionsAndVerifyGreedy(Result, IC, S, SymCS); + if (Clone) { + return Clone; + } + +// Clone = DFPreconditionsAndVerifyGreedy(Result, IC, S, SymCS); +// if (Clone.Mapping.LHS && Clone.Mapping.RHS) { +// return Clone; +// } + + } + + Refresh("Direct Symbolize for common consts"); + + // Step 1.5 : Direct symbolize, simple rel constraints on LHS + + for (auto &&C : LHSConsts) { + ConstMap.push_back({SymConstMap[C], C->Val}); + } + auto CounterExamples = GetMultipleCEX(Result, IC, S, 3); + if (Nested) { + CounterExamples = {}; + // FIXME : Figure out how to get CEX for symbolic dataflow + } + auto Relations = InferPotentialRelations(ConstMap, IC, Input, CounterExamples, Nested); + + std::map JustLHSSymConstMap; + + for (auto &&C : LHSConsts) { + JustLHSSymConstMap[C] = SymConstMap[C]; + } + + auto Copy = Replace(Input, IC, JustLHSSymConstMap); + // for (auto &&R : Relations) { + // Copy.PCs.push_back({R, IC.getConst(llvm::APInt(1, 1))}); + // // Copy.print(llvm::errs(), true); + // auto Clone = Verify(Copy, IC, S); + // if (Clone.Mapping.LHS && Clone.Mapping.RHS) { + // return Clone; + // } + // Copy.PCs.pop_back(); + // } + + // llvm::errs() << "Relations : " << Relations.size() << "\n"; + + if (auto RelV = VerifyWithRels(IC, S, Copy, Relations)) { + return RelV; + } + + Refresh("Direct + simple rel constraints"); + + // Step 2 : Symbolize LHS Consts with KB, CR, SimpleDF constrains + if (RHSFresh.empty()) { + auto Copy = Replace(Input, IC, JustLHSSymConstMap); + + auto Clone = SimplePreconditionsAndVerifyGreedy(Copy, IC, S, SymCS); + if (Clone) { + return Clone; + } + + Refresh("LHS Constraints"); + + Clone = DFPreconditionsAndVerifyGreedy(Copy, IC, S, SymCS); + if (Clone) { + return Clone; + } + } + + Refresh("All LHS 
Constraints"); + + auto ConstantLimits = InferConstantLimits(ConstMap, IC, Input, CounterExamples); + + // Step 3 : Special RHS constant exprs, no constants + + if (!RHSFresh.empty()) { + + std::vector> UnitaryCandidates = + InferSpecialConstExprsAllSym(RHSFresh, ConstMap, IC, /*depth*/0); + + // llvm::errs() << "Unitary candidates: " << UnitaryCandidates[0].size() << "\n"; + + if (!UnitaryCandidates.empty()) { + // if (Nested && DebugLevel > 4) { + // llvm::errs() << "Rels " << Relations.size() << "\n"; + // llvm::errs() << "Unitary candidates: " << UnitaryCandidates[0].size() << "\n"; + // llvm::errs() << "FOO: " << UnitaryCandidates[0][0]->Name << "\n"; + // } + + auto Clone = FirstValidCombination(Input, RHSFresh, UnitaryCandidates, + InstCache, IC, S, SymCS, true, false, false, Relations); + if (Clone) { + return Clone; + } + Refresh("Unitary cands, rel constraints"); + } + + std::vector> SimpleCandidates = + InferSpecialConstExprsAllSym(RHSFresh, ConstMap, IC, /*depth=*/ 2); + + if (!SimpleCandidates.empty()) { + if (DebugLevel > 4) { + llvm::errs() << "InferSpecialConstExprsAllSym candidates: " << SimpleCandidates[0].size() << " x " << ConstantLimits.size() << "\n"; + } + auto Clone = FirstValidCombination(Input, RHSFresh, SimpleCandidates, + InstCache, IC, S, SymCS, + true, false, false); + if (Clone) { + return Clone; + } + } + Refresh("Special expressions, no constants"); + + // Step 4 : Enumerated expressions + + std::set Components; + for (auto C : ConstMap) { + Components.insert(C.first); + } + + auto EnumeratedCandidates = Enumerate(RHSFresh, Components, IC, ConstMap); + // if (DebugLevel > 4) { + // llvm::errs() << "RHSFresh: " << RHSFresh.size() << "\n"; + // llvm::errs() << "Components: " << Components.size() << "\n"; + // llvm::errs() << "EnumeratedCandidates: " << EnumeratedCandidates[0].size() << "\n"; + // } + + if (!EnumeratedCandidates.empty()) { + auto Clone = FirstValidCombination(Input, RHSFresh, EnumeratedCandidates, + InstCache, IC, S, 
SymCS, true, false, false); + if (Clone) { + return Clone; + } + Refresh("Enumerated cands, no constraints"); + } + + + // Step 4.75 : Enumerate 2 instructions when single RHS Constant. + std::vector> EnumeratedCandidatesTwoInsts; + if (RHSFresh.size() == 1) { + EnumeratedCandidatesTwoInsts = Enumerate(RHSFresh, Components, IC, ConstMap, 2); + + // llvm::errs() << "Guesses: " << EnumeratedCandidatesTwoInsts[0].size() << "\n"; + + auto Clone = FirstValidCombination(Input, RHSFresh, EnumeratedCandidatesTwoInsts, + InstCache, IC, S, SymCS, true, false, false); + if (Clone) { + return Clone; + } + } + Refresh("Enumerated 2 insts for single RHS const cases"); + + if (!EnumeratedCandidates.empty()) { + auto Clone = FirstValidCombination(Input, RHSFresh, EnumeratedCandidates, + InstCache, IC, S, SymCS, false, true, true); + if (Clone) { + return Clone; + } + + // Enumerated Expressions with some relational constraints + if (ConstMap.size() == 2) { + // llvm::errs() << "Relations: " << Relations.size() << "\n"; + // llvm::errs() << "Guesses: " << EnumeratedCandidates[0].size() << "\n"; + + auto Clone = FirstValidCombination(Input, RHSFresh, EnumeratedCandidates, + InstCache, IC, S, SymCS, true, false, false, Relations); + if (Clone) { + return Clone; + } + } + Refresh("Relational constraints for enumerated cands."); + + } + Refresh("Enumerated exprs with constraints"); + + if (RHSFresh.size() == 1 && !Nested) { + // Enumerated Expressions with some relational constraints + if (ConstMap.size() == 2) { + // llvm::errs() << "Enum2 : " << EnumeratedCandidatesTwoInsts.back().size() + // << "\tRels: " << Relations.size() << "\n"; + auto Clone = FirstValidCombination(Input, RHSFresh, EnumeratedCandidatesTwoInsts, + InstCache, IC, S, SymCS, true, false, false, Relations); + if (Clone) { + return Clone; + } + } + } + Refresh("Enumerated 2 insts exprs with relations"); + + // Step 4.8 : Special RHS constant exprs, with constants + + std::vector> SimpleCandidatesWithConsts = + 
InferSpecialConstExprsWithConcretes(RHSFresh, ConstMap, IC, /*depth=*/ 2); + + if (!SimpleCandidatesWithConsts.empty() && !Nested) { + auto Clone = FirstValidCombination(Input, RHSFresh, SimpleCandidatesWithConsts, + InstCache, IC, S, SymCS, + true, false, false); + if (Clone) { + return Clone; + } + } + + Refresh("Special expressions, with constants"); + + // Enumerated exprs with constraints + + if (!EnumeratedCandidates.empty() && !Nested) { + auto Clone = FirstValidCombination(Input, RHSFresh, EnumeratedCandidates, + InstCache, IC, S, SymCS, true, true, true, Relations); + if (Clone) { + return Clone; + } + Refresh("Enumerated exprs with constraints and relations"); + } + + // Step 5 : Simple exprs with constraints + + if (!SimpleCandidates.empty()) { + auto Clone = FirstValidCombination(Input, RHSFresh, SimpleCandidates, + InstCache, IC, S, SymCS, false, true, true); + if (Clone) { + return Clone; + } + Refresh("Simple cands with constraints"); + + Clone = FirstValidCombination(Input, RHSFresh, SimpleCandidates, + InstCache, IC, S, SymCS, true, false, false, Relations); + if (Clone) { + return Clone; + } + Refresh("Simple cands with constraints and relations"); + } + + // Step 5.5 : Simple exprs with constraints + + if (!SimpleCandidatesWithConsts.empty() && !Nested) { + auto Clone = FirstValidCombination(Input, RHSFresh, SimpleCandidatesWithConsts, + InstCache, IC, S, SymCS, false, true, true); + if (Clone) { + return Clone; + } + Refresh("Simple cands+consts with constraints"); + + Clone = FirstValidCombination(Input, RHSFresh, SimpleCandidatesWithConsts, + InstCache, IC, S, SymCS, true, true, true, Relations); + if (Clone) { + return Clone; + } + + Refresh("Simple cands+consts with constraints and relations"); + } + + // { + // if (!RHSFresh.empty()) { + // std::vector> SimpleCandidatesMoreInsts = + // InferSpecialConstExprsAllSym(RHSFresh, ConstMap, IC, /*depth =*/ 5); + + // if (!SimpleCandidates.empty()) { + // auto Clone = FirstValidCombination(Input, 
RHSFresh, SimpleCandidatesMoreInsts, + // InstCache, IC, S, SymCS, + // true, false, false); + // if (Clone.Mapping.LHS && Clone.Mapping.RHS) { + // return Clone; + // } + // } + + // Refresh("Special expressions, no constants"); + // } + // } + + if (!EnumeratedCandidates.empty()) { + auto Clone = FirstValidCombination(Input, RHSFresh, EnumeratedCandidates, + InstCache, IC, S, SymCS, true, true, false, ConstantLimits); + if (Clone) { + return Clone; + } + Refresh("Enumerated expressions+consts and constant limits"); + } + + if (!SimpleCandidates.empty()) { + auto Clone = FirstValidCombination(Input, RHSFresh, SimpleCandidates, + InstCache, IC, S, SymCS, true, false, false, ConstantLimits); + if (Clone) { + return Clone; + } + } + if (!SimpleCandidatesWithConsts.empty()) { + auto Clone = FirstValidCombination(Input, RHSFresh, SimpleCandidatesWithConsts, + InstCache, IC, S, SymCS, true, false, false, ConstantLimits); + if (Clone) { + return Clone; + } + Refresh("Simple expressions+consts and constant limits"); + } + + } + + { + auto Copy = Replace(Input, IC, JustLHSSymConstMap); + if (auto VRel = VerifyWithRels(IC, S, Copy, ConstantLimits)) { + return VRel.value(); + } + Refresh("Constant limit constraints on LHS"); + } + + Refresh("END"); + Changed = false; + return Input; +} + +Inst *CloneInst(InstContext &IC, Inst *I, std::map &Vars) { + if (I->K == Inst::Var) { + return Vars[I]; + } else if (I->K == Inst::Const) { + // llvm_unreachable("Const"); + auto Goal = Vars.begin()->second->Width; // TODO Infer. + auto NewVal = I->Val.isSignBitSet() ? 
I->Val.sextOrTrunc(Goal) : I->Val.zextOrTrunc(Goal); + return IC.getConst(NewVal); + } else { + std::vector Ops; + for (auto Op : I->Ops) { + Ops.push_back(CloneInst(IC, Op, Vars)); + } + return IC.getInst(I->K, InferWidth(I->K, Ops), Ops); + } +} + +// Builds the precondition BitWidth(I) == Width, mapped to the 1-bit constant 1. +InstMapping GetEqWidthConstraint(Inst *I, size_t Width, InstContext &IC) { + return {Builder(I, IC).BitWidth().Eq(Width)(), IC.getConst(llvm::APInt(1, 1))}; +} + +// Builds the precondition BitWidth(I) <= Width (note: Ule, despite the name), +// mapped to the 1-bit constant 1. +InstMapping GetLessThanWidthConstraint(Inst *I, size_t Width, InstContext &IC) { + // Don't need to check for >0. + return {Builder(I, IC).BitWidth().Ule(Width)(), IC.getConst(llvm::APInt(1, 1))}; +} + +// Builds the precondition Min <= BitWidth(I) <= Max, mapped to the 1-bit constant 1. +// The lower bound compares the constant Min itself with BitWidth(I). Wrapping the +// constant in BitWidth() (as was done before) presumably compares the constant's +// width instead of its value -- NOTE(review): confirm against Builder::BitWidth. +InstMapping GetWidthRangeConstraint(Inst *I, size_t Min, size_t Max, InstContext &IC) { + auto Right = Builder(I, IC).BitWidth().Ule(Max); + auto Left = Builder(IC, llvm::APInt(I->Width, Min)).Ule(Builder(I, IC).BitWidth()); + return {Left.And(Right)(), IC.getConst(llvm::APInt(1, 1))}; +} + +// TODO: More as needed. + +// Folds all path conditions into one 1-bit antecedent: true && (LHS0 == RHS0) && ... +Inst *CombinePCs(const std::vector &PCs, InstContext &IC) { + Inst *Ante = IC.getConst(llvm::APInt(1, true)); + for (auto PC : PCs ) { + Inst *Eq = IC.getInst(Inst::Eq, 1, {PC.LHS, PC.RHS}); + Ante = IC.getInst(Inst::And, 1, {Ante, Eq}); + } + return Ante; +} + +bool IsStaticallyWidthIndependent(ParsedReplacement Input) { + + if (Input.Mapping.LHS->Width == 1) { + return false; + } + + std::vector Consts; + auto Pred = [](Inst *I) {return I->K == Inst::Const;}; + findInsts(Input.Mapping.LHS, Consts, Pred); + findInsts(Input.Mapping.RHS, Consts, Pred); + for (auto M : Input.PCs) { + findInsts(M.LHS, Consts, Pred); + findInsts(M.RHS, Consts, Pred); + } + + std::vector WidthChanges; + auto WPred = [](Inst *I) {return I->K == Inst::Trunc || I->K == Inst::SExt + || I->K == Inst::ZExt;}; + + findInsts(Input.Mapping.LHS, WidthChanges, WPred); + findInsts(Input.Mapping.RHS, WidthChanges, WPred); + for (auto M : Input.PCs) { + findInsts(M.LHS, WidthChanges, WPred); + findInsts(M.RHS, WidthChanges, WPred); + } + + if (!WidthChanges.empty()) { + 
return false; + } + + // False if a constant is neither zero nor -1 (all ones). + // The two tests must be joined with &&: with || the condition held for every + // constant (zero is not all-ones, and all-ones is not zero), so any constant + // at all made this function return false. + for (auto &&C : Consts) { + if (C->K == Inst::Const && (C->Val != 0 && !C->Val.isAllOnesValue())) { + return false; + } + } + + // TODO Set up constant synthesis problem to see if subexpressions + // simplify to non zero consts + + return true; +} + +// Returns the (possibly modified) replacement and a flag that is true only when +// no width checks are needed: either the static check above succeeds or Alive +// verifies the replacement with symbolic width. +std::pair +InstantiateWidthChecks(InstContext &IC, + Solver *S, ParsedReplacement Input) { + + if (IsStaticallyWidthIndependent(Input)) { + return {Input, true}; + } + + if (!NoWidth && !hasMultiArgumentPhi(Input.Mapping.LHS)) { + // Instantiate Alive driver with Symbolic width. + AliveDriver Alive(Input.Mapping.LHS, + Input.PCs.empty() ? nullptr : CombinePCs(Input.PCs, IC), + IC, {}, true); + + // Find set of valid widths. + if (Alive.verify(Input.Mapping.RHS)) { + if (DebugLevel > 4) { + llvm::errs() << "WIDTH: Generalized opt is valid for all widths.\n"; + } + // Completely width independent. No width checks needed. + return {Input, true}; + } + + auto &&ValidTypings = Alive.getValidTypings(); + + if (ValidTypings.empty()) { + // Something went wrong, generalized opt is not valid at any width. 
+ if (DebugLevel > 4) { + llvm::errs() << "WIDTH: Generalized opt is not valid for any width.\n"; + } + Input.Mapping.LHS = nullptr; + Input.Mapping.RHS = nullptr; + return {Input, false}; + } + + // Abstract width to a range or relational precondition + // TODO: Abstraction + + std::vector Inputs; + findVars(Input.Mapping.LHS, Inputs); + if (Inputs.size() == 1 && ValidTypings.size() > 1) { + auto I = Inputs[0]; + auto Width = I->Width; + + std::vector Widths; + for (auto &&V : ValidTypings) { + Widths.push_back(V[I]); + } + + size_t MaxWidth = *std::max_element(Widths.begin(), Widths.end()); + size_t MinWidth = *std::min_element(Widths.begin(), Widths.end()); + + if (ValidTypings.size() == (MaxWidth - MinWidth + 1)) { + Input.PCs.push_back(GetWidthRangeConstraint(I, MinWidth, MaxWidth, IC)); + return {Input, false}; + } + } + + } + + // If abstraction fails, insert checks for existing widths. + std::vector Inputs; + findVars(Input.Mapping.LHS, Inputs); + for (auto &&I : Inputs) { + Input.PCs.push_back(GetEqWidthConstraint(I, I->Width, IC)); + } + return {Input, false}; +} + +std::optional GeneralizeShrinked( + ParsedReplacement Input, InstContext &IC, Solver *S) { + + if (hasMultiArgumentPhi(Input.Mapping.LHS)) { + return std::nullopt; + } + + ShrinkWrap Shrink(IC, S, Input, 8); + auto Smol = Shrink(); + + if (Smol) { + if (DebugLevel > 2) { + llvm::errs() << "Shrinked: \n"; + InfixPrinter P(Smol.value()); + P(llvm::errs()); + Smol->print(llvm::errs(), true); + llvm::errs() << "\n"; + if (DebugLevel > 4) { + Smol.value().print(llvm::errs(), true); + } + } + Input = Smol.value(); + + // Input.print(llvm::errs(), true); + + } else { + if (DebugLevel > 2) { + llvm::errs() << "Shrinking failed\n"; + } + return std::nullopt; + } + + bool Changed = false; + + auto Gen = SuccessiveSymbolize(IC, S, Smol.value(), Changed); + + if (!Changed || !Gen) { + if (DebugLevel > 2) { + llvm::errs() << "Shrinking failed\n"; + } + return std::nullopt; // Generalization failed. 
+ } + + auto [GenWidth, WidthChanged] = InstantiateWidthChecks(IC, S, Gen.value()); + + if (!WidthChanged) { + return std::nullopt; // Width independence check failed. + } + return Gen; +} + +template +Stream &operator<<(Stream &S, InfixPrinter IP) { + IP(S); + return S; +} + +void PrintInputAndResult(ParsedReplacement Input, ParsedReplacement Result) { + ReplacementContext RC; + Result.printLHS(llvm::outs(), RC, true); + Result.printRHS(llvm::outs(), RC, true); + llvm::outs() << "\n"; + + if (DebugLevel > 1) { + llvm::errs() << "IR Input: \n"; + ReplacementContext RC; + Input.printLHS(llvm::errs(), RC, true); + Input.printRHS(llvm::errs(), RC, true); + llvm::errs() << "\n"; + llvm::errs() << "\n\tInput (profit=" << profit(Input) << "):\n\n" + << InfixPrinter(Input) + << "\n\tGeneralized (profit=" << profit(Result) << "):\n\n" + << InfixPrinter(Result, NoWidth) << "\n"; + // Result.print(llvm::errs(), true); + } + llvm::outs().flush(); +} + +int main(int argc, char **argv) { + cl::ParseCommandLineOptions(argc, argv); + KVStore *KV = 0; + + std::unique_ptr S = 0; + S = GetSolver(KV); + + auto MB = MemoryBuffer::getFileOrSTDIN(InputFilename); + if (!MB) { + llvm::errs() << MB.getError().message() << '\n'; + return 1; + } + + InstContext IC; + std::string ErrStr; + + auto &&Data = (*MB)->getMemBufferRef(); + auto Inputs = ParseReplacements(IC, Data.getBufferIdentifier(), + Data.getBuffer(), ErrStr); + + if (!ErrStr.empty()) { + llvm::errs() << ErrStr << '\n'; + return 1; + } + + // TODO: Write default action which chooses what to do based on input structure + + for (auto &&Input: Inputs) { + if (Input.Mapping.LHS == Input.Mapping.RHS) { + if (DebugLevel > 2) llvm::errs() << "Input == Output\n"; + continue; + } else if (profit(Input) < 0) { + if (DebugLevel > 2) llvm::errs() << "Not an optimization\n"; + continue; + } + if (Basic) { + ParsedReplacement Result = ReduceBasic(IC, S.get(), Input); + if (!JustReduce) { + + bool Changed = false; + size_t MaxTries = 1; // 
Increase this if we ever run with 10/100x timeout. + bool FirstTime = true; + do { + if (!OnlyWidth) { + if (Changed) { + Result = ReduceBasic(IC, S.get(), Result); + } + + std::optional Opt; + if (!NoWidth) { + Opt = GeneralizeShrinked(Result, IC, S.get()); + } + + if (!Opt) { + Opt = SuccessiveSymbolize(IC, S.get(), Result, Changed); + } else { + Changed = true; + } + + if (Opt) { + Result = *Opt; + } + + if (SymbolicDF) { + // Refresh("PUSH SYMDF_KB_DB"); + auto [CM, Aug] = AugmentForSymKBDB(Input, IC); + // auto [CM2, Aug2] = AugmentForSymKB(Aug1, IC); + if (!CM.empty()) { + bool SymDFChanged = false; + + // auto Clone = Verify(Aug, IC, S); + // if (Clone) { + // // Symbolic db+kb can be unconstrained + // // Is this actually possible in practice? + // return Clone; + // } + + // Aug.print(llvm::errs(), true); + + // llvm::errs() << "\n\n"; + + auto Generalized = SuccessiveSymbolize(IC, S.get(), Aug, SymDFChanged, CM); + if (Generalized) { + Result = DeAugment(IC, S.get(), Generalized.value()); + Changed = true; + } + } + // Refresh("POP SYMDF_KB_DB"); + } + + } + bool Indep = false; + if (!NoWidth) { + std::tie(Result, Indep) = InstantiateWidthChecks(IC, S.get(), Result); + } +// Result.print(llvm::errs(), true); + if (!Result.Mapping.LHS && !NoWidth) { + Result = Input; + if (MaxTries == 1) MaxTries++; + NoWidth = true; + continue; // Retry with no width checks + } + + if (!Indep && Result.Mapping.LHS && !NoWidth) { + if (MaxTries == 1) MaxTries++; + NoWidth = true; + PrintInputAndResult(Input, Result); + Result = Input; + continue; // Retry with no width checks + } + // Result = DeAugment(IC, S.get(), Result); + + if ((Changed || FirstTime) && Result.Mapping.LHS && Result.Mapping.RHS) { + PrintInputAndResult(Input, Result); + } + if (FirstTime) FirstTime = false; + } while (--MaxTries && Changed); + } else { + if (Result.Mapping.LHS && Result.Mapping.RHS) { + PrintInputAndResult(Input, Result); + } + } + } + } + return 0; +} diff --git 
a/tools/matcher-gen.cpp b/tools/matcher-gen.cpp new file mode 100644 index 000000000..e13f9d0e6 --- /dev/null +++ b/tools/matcher-gen.cpp @@ -0,0 +1,1272 @@ +#include "llvm/Support/MemoryBuffer.h" + +#include "souper/Infer/Preconditions.h" +#include "souper/Infer/EnumerativeSynthesis.h" +#include "souper/Infer/SynthUtils.h" +#include "souper/Parser/Parser.h" +#include "souper/Tool/GetSolver.h" + +#include + +using namespace llvm; +using namespace souper; + +unsigned DebugLevel; + +static cl::opt +DebugFlagParser("souper-debug-level", + cl::desc("Control the verbose level of debug output (default=1). " + "The larger the number is, the more fine-grained debug " + "information will be printed."), + cl::location(DebugLevel), cl::init(1)); + +static cl::opt +InputFilename(cl::Positional, cl::desc(""), + cl::init("-")); + +static llvm::cl::opt IgnorePCs("ignore-pcs", + llvm::cl::desc("Ignore inputs which have souper path conditions." + "(default=false)"), + llvm::cl::init(false)); + +static llvm::cl::opt IgnoreDF("ignore-df", + llvm::cl::desc("Ignore inputs with dataflow constraints." + "(default=false)"), + llvm::cl::init(false)); + +static llvm::cl::opt NoDispatch("no-dispatch", + llvm::cl::desc("Do not generate code to dispatch on root instruction kind." + "(default=false)"), + llvm::cl::init(false)); + +static llvm::cl::opt OnlyExplicitWidths("explicit-width-checks", + llvm::cl::desc("Only generate width checks when explicitly specified." 
+ "(default=false)"), + llvm::cl::init(false)); + +static llvm::cl::opt Sort("sortf", + llvm::cl::desc("Sort matchers according to listfile" + "(default=false)"), + llvm::cl::init(false)); + +static llvm::cl::opt ListFile("listfile", + llvm::cl::desc("List of optimization indexes to include.\n" + "(default=empty-string)"), + llvm::cl::init("")); + + +static const std::map MatchOps = { + {Inst::Add, "m_c_Add("}, {Inst::Sub, "m_Sub("}, + {Inst::Mul, "m_c_Mul("}, + + {Inst::Shl, "m_Shl("}, {Inst::LShr, "m_LShr("}, + {Inst::AShr, "m_AShr("}, + + {Inst::AddNSW, "m_NSWAdd("}, {Inst::SubNSW, "m_NSWSub("}, // add _c_ too? + {Inst::MulNSW, "m_NSWMul("}, {Inst::ShlNSW, "m_NSWShl("}, + {Inst::AddNUW, "m_NUWAdd("}, {Inst::SubNUW, "m_NUWSub("}, + {Inst::MulNUW, "m_NUWMul("}, {Inst::ShlNUW, "m_NUWShl("}, + {Inst::AddNW, "m_NWAdd("}, {Inst::SubNW, "m_NWSub("}, + {Inst::MulNW, "m_NWMul("}, {Inst::ShlNW, "m_NWShl("}, + + {Inst::SDiv, "m_SDiv("}, {Inst::UDiv, "m_UDiv("}, + {Inst::SRem, "m_SRem("}, {Inst::URem, "m_URem("}, + + {Inst::AShrExact, "m_AShrExact("}, {Inst::LShrExact, "m_LShrExact("}, + {Inst::UDivExact, "m_UDivExact("}, {Inst::SDivExact, "m_SDivExact("}, + + {Inst::And, "m_c_And("}, {Inst::Or, "m_c_Or("}, + {Inst::Xor, "m_c_Xor("}, + + {Inst::Eq, "m_c_ICmp("}, + {Inst::Ne, "m_c_ICmp("}, + {Inst::Ule, "m_ICmp("}, + {Inst::Ult, "m_ICmp("}, + {Inst::Sle, "m_ICmp("}, + {Inst::Slt, "m_ICmp("}, + + {Inst::SExt, "m_SExt("}, + {Inst::ZExt, "m_ZExt("}, + {Inst::Trunc, "m_Trunc("}, + {Inst::Select, "m_Select("}, + {Inst::Phi, "m_Phi("}, +}; + +static const std::map CreateOps = { + {Inst::Shl, "CreateShl("}, {Inst::AShr, "CreateAShr("}, {Inst::LShr, "CreateLShr("}, + {Inst::Add, "CreateAdd("}, {Inst::Mul, "CreateMul("}, {Inst::Sub, "CreateSub("}, + {Inst::SDiv, "CreateSDiv("}, {Inst::UDiv, "CreateUDiv("}, {Inst::SRem, "CreateSRem("}, + {Inst::URem, "CreateURem("}, + {Inst::Or, "CreateOr("}, {Inst::And, "CreateAnd("}, {Inst::Xor, "CreateXor("}, + {Inst::AShrExact, 
"CreateAShrExact("},// {Inst::LShrExact, "CreateExactLShr("}, + // {Inst::UDivExact, "CreateExactUDiv("}, {Inst::SDivExact, "CreateExactSDiv("}, + + // FakeOps + {Inst::LogB, "CreateLogB("}, + + {Inst::Eq, "CreateCmp(ICmpInst::ICMP_EQ, "}, + {Inst::Ne, "CreateCmp(ICmpInst::ICMP_NE, "}, + {Inst::Ule, "CreateCmp(ICmpInst::ICMP_ULE, "}, + {Inst::Ult, "CreateCmp(ICmpInst::ICMP_ULT, "}, + {Inst::Sle, "CreateCmp(ICmpInst::ICMP_SLE, "}, + {Inst::Slt, "CreateCmp(ICmpInst::ICMP_SLT, "}, + + {Inst::Trunc, "CreateTrunc("}, + {Inst::SExt, "CreateSExt("}, + {Inst::ZExt, "CreateZExt("}, + + {Inst::Select, "CreateSelect("}, + + {Inst::FShl, "CreateFShl("}, + {Inst::FShr, "CreateFShr("}, + {Inst::BSwap, "CreateBSwap("}, + + {Inst::Const, "dummy"}, +}; + +static const std::map PredNames = { + {Inst::Eq, "ICmpInst::ICMP_EQ"}, + {Inst::Ne, "ICmpInst::ICMP_NE"}, + {Inst::Ule, "ICmpInst::ICMP_ULE"}, + {Inst::Ult, "ICmpInst::ICMP_ULT"}, + {Inst::Sle, "ICmpInst::ICMP_SLE"}, + {Inst::Slt, "ICmpInst::ICMP_SLT"}, +}; + +struct Constraint { + virtual std::string print() = 0; +}; + +struct VarEq : public Constraint { + VarEq(std::string LHS_, std::string RHS_) : LHS(LHS_), RHS(RHS_) {} + std::string LHS; + std::string RHS; + std::string print() override { + return LHS + " == " + RHS; + } +}; + +struct PredEq : public Constraint { + PredEq(std::string P_, std::string K_) : P(P_), K(K_) {} + std::string P; + std::string K; + std::string print() override { + return P + " == " + K; + } +}; + +struct WidthEq : public Constraint { + WidthEq(std::string Name_, size_t W_) : Name(Name_) , W(W_){} + std::string Name; + size_t W; + std::string print() override { + return "util::check_width(" + Name + ',' + std::to_string(W) + ")"; + } +}; + +struct DomCheck : public Constraint { + DomCheck(std::string Name_) : Name(Name_) {} + std::string Name; + + std::string print() override { + return "util::dc(DT, I, " + Name + ")"; + } +}; + +struct VC : public Constraint { + VC(std::string Cons_, std::string 
Name_) : Name(Name_), Cons(Cons_) {} + // Initializers are listed in member declaration order (Name, then Cons). + // Emits "util::<Cons>(<Name>)". + std::string print() override { + return "util::" + Cons + "(" + Name + ")"; + } + std::string Name; + std::string Cons; +}; + +// Emits "(L == R)" for a translated path condition. +struct PC : public Constraint { + PC(std::string LHS, std::string RHS) : L(LHS), R(RHS) {} + std::string print() override { + return "(" + L + " == " + R + ")"; + } + std::string L, R; +}; + +// Emits a util::vdb(DB, I, "<Val>", W) call. +struct DB : public Constraint { + DB(std::string Val_, size_t W_) : Val(Val_), W(W_) {} + std::string print() override { + return "util::vdb(DB, I, \"" + Val + "\", " + std::to_string(W) + ")"; + } + std::string Val; + size_t W; +}; + +// Emits a util::symdb(DB, I, Name, B) call. +struct SymDB : public Constraint { + SymDB(std::string Name_) : Name(Name_) {} + std::string print() override { + return "util::symdb(DB, I, " + Name + ", B)"; + } + std::string Name; +}; + +// Emits a util::k0(Name, "<Val>", W) call. +struct K0 : public Constraint { + K0(std::string Name_, std::string Val_, size_t W_) : Name(Name_), Val(Val_), W(W_) {} + std::string print() override { + return "util::k0(" + Name + ", \"" + Val + "\", " + std::to_string(W) + ")"; + } + std::string Name, Val; + size_t W; +}; + +// Emits a util::k1(Name, "<Val>", W) call. +struct K1 : public Constraint { + K1(std::string Name_, std::string Val_, size_t W_) : Name(Name_), Val(Val_), W(W_) {} + std::string print() override { + return "util::k1(" + Name + ", \"" + Val + "\", " + std::to_string(W) + ")"; + } + std::string Name, Val; + size_t W; +}; + +// Emits a util::symk0bind(Name, Bind, B) call. +struct SymK0Bind : public Constraint { + SymK0Bind(std::string Name_, std::string Bind_) : Name(Name_), Bind(Bind_) {} + std::string print() override { + return "util::symk0bind(" + Name + ", " + Bind + ", B)"; + } + std::string Name, Bind; +}; + +// Emits a util::symk1bind(Name, Bind, B) call. +struct SymK1Bind : public Constraint { + SymK1Bind(std::string Name_, std::string Bind_) : Name(Name_), Bind(Bind_) {} + std::string print() override { + return "util::symk1bind(" + Name + ", " + Bind + ", B)"; + } + std::string Name, Bind; +}; + +struct SymK0Test : public Constraint { + SymK0Test(std::string Name_, std::string Name2_) : Name(Name_), Name2(Name2_) {} + std::string print() override { + 
return "util::symk0test(" + Name + ", " + Name2 + ")"; + } + std::string Name, Name2; +}; + +struct SymK1Test : public Constraint { + SymK1Test(std::string Name_, std::string Name2_) : Name(Name_), Name2(Name2_) {} + std::string print() override { + return "util::symk1test(" + Name + ", " + Name2 + ")"; + } + std::string Name, Name2; +}; + + +struct CR : public Constraint { + CR(std::string Name_, std::string L_, std::string H_) : Name(Name_), L(L_), H(H_) {} + std::string print() override { + return "util::cr(" + Name + ", \"" + L + "\", \"" + H + "\")"; + } + std::string Name, L, H; +}; + +struct SymbolTable : public std::map { + std::deque Constraints; + + std::map Preds; + std::vector Vars; + std::set Consts, ConstRefs; + + std::set Used; + + void RegisterPred(Inst *I) { + if (PredNames.find(I->K) == PredNames.end()) { + return; // not a predicate + } + if (Preds.find(I) != Preds.end()) { + return; // already registered + } + auto Name = "P" + std::to_string(Preds.size()); + Preds[I] = Name; + Constraints.push_back(new PredEq(Name, PredNames.at(I->K))); + } + + bool exists(Inst *I) { + if (find(I) == end()) { + return false; + } + return !at(I).empty(); + } + + template + void PrintPreds(Stream &Out) { + if (Preds.empty()) { + return; + } + Out << "ICmpInst::Predicate "; + bool first = true; + for (auto &&P : Preds) { + if (first) { + first = false; + } else { + Out << ", "; + } + Out << P.second; + } + Out << ";\n"; + } + + // Try to translate Souper expressions to APInt operations. + std::pair Translate(souper::Inst *I) { + std::vector> Children; + + if (I->K == Inst::BitWidth) { + if (at(I->Ops[0]).empty()) { + return {"", false}; + } + auto Sym = at(I->Ops[0]); + return {"util::W(" + Sym + ")", true}; + } + + for (auto Op : I ->Ops) { + Children.push_back(Translate(Op)); + if (!Children.back().second) { + return {"", false}; + } + } + + auto MET = [&](auto Str) { + return Children[0].first + "." 
+ Str + "(" + Children[1].first + ")"; + }; + + auto WC = [&](auto Str) { + return Children[0].first + "." + Str + "(" + std::to_string(I->Width) + ")"; + }; + + auto OP = [&](auto Str) { + return "(" + Children[0].first + " " + Str + " " + Children[1].first + ")"; + }; + + auto FUN = [&](auto Str) { + return std::string(Str) + "(" + Children[0].first + ", " + Children[1].first + ")"; + }; + + switch (I->K) { + case Inst::Var : + if (exists(I)) { + return {"util::V(" + at(I) + ")", true}; + } else { + return {"", false}; + } + case Inst::Const : + if (I->Width <= 64) { + return {I->Val.toString(10, false), true}; + } else { + return {"util::V(" + std::to_string(I->Width) + + ", \"" + I->Val.toString(10, false) + "\")", true}; + } + + case Inst::AddNW : + case Inst::AddNUW : + case Inst::AddNSW : + case Inst::Add : return {OP("+"), true}; + + case Inst::SubNW : + case Inst::SubNUW : + case Inst::SubNSW : + case Inst::Sub : return {OP("-"), true}; + + case Inst::MulNW : + case Inst::MulNUW : + case Inst::MulNSW : + case Inst::Mul : return {OP("*"), true}; + + case Inst::Shl : { + if (isdigit(Children[0].first[0])) { + return {FUN("shl"), true}; + } else { + return {MET("shl"), true}; + } + } + case Inst::LShr : return {MET("lshr"), true}; + case Inst::AShr : return {MET("ashr"), true}; + + case Inst::And : return {OP("&"), true}; + case Inst::Or : return {OP("|"), true}; + case Inst::Xor : return {OP("^"), true}; + + case Inst::URem : return {MET("urem"), true}; + case Inst::SRem : return {MET("srem"), true}; + case Inst::UDiv : return {MET("udiv"), true}; + case Inst::SDiv : return {MET("sdiv"), true}; + + case Inst::Slt : { + if (isdigit(Children[0].first[0])) { + return {OP("<"), true}; + } else { + return {MET("slt"), true}; + } + } + case Inst::Sle : { + if (isdigit(Children[0].first[0])) { + return {OP("<="), true}; + } else { + return {MET("sle"), true}; + } + } + case Inst::Ult : { + if (isdigit(Children[0].first[0])) { + return {OP("<"), true}; + } else { + 
return {MET("ult"), true}; + } + } + case Inst::Ule : { + if (isdigit(Children[0].first[0])) { + return {OP("<="), true}; + } else { + return {MET("ule"), true}; + } + } + case Inst::Eq : { + if (isdigit(Children[0].first[0])) { + return {OP("=="), true}; + } else { + return {MET("eq"), true}; + } + } + case Inst::Ne : { + if (isdigit(Children[0].first[0])) { + return {OP("!="), true}; + } else { + return {MET("ne"), true}; + } + } + case Inst::ZExt : return {WC("zext"), true}; + case Inst::SExt : return {WC("sext"), true}; + case Inst::Trunc : return {WC("trunc"), true}; + + default: { + llvm::errs() << "Unimplemented op in PC: " << Inst::getKindName(I->K) << "\n"; + return {"", false}; + } + } + + } + + Constraint *ConvertPCToWidthConstraint(InstMapping PC) { + if (PC.LHS->K != Inst::Eq) + return nullptr; + if (PC.LHS->Ops[0]->K == Inst::BitWidth) { + return new WidthEq(this->at(PC.LHS->Ops[0]->Ops[0]), + PC.LHS->Ops[1]->Val.getLimitedValue()); + } + if (PC.LHS->Ops.size() > 1 && PC.LHS->Ops[1]->K == Inst::BitWidth) { + return new WidthEq(this->at(PC.LHS->Ops[1]->Ops[0]), + PC.LHS->Ops[0]->Val.getLimitedValue()); + } + return nullptr; + } + + bool GenPCConstraints(std::vector PCs) { + for (auto M : PCs) { + if (M.LHS->K == Inst::KnownZerosP) { + if (M.LHS->Ops[0]->K == Inst::Var && M.LHS->Ops[1]->Name.starts_with("symDF_K")) { + auto C = new SymK0Bind(this->at(M.LHS->Ops[0]), + this->at(M.LHS->Ops[1])); + Constraints.push_front(C); + // Binds have side effects, have to go in front. + } else { + auto C = new SymK0Test(this->at(M.LHS->Ops[0]), + this->at(M.LHS->Ops[1])); + Constraints.push_back(C); + } + } else if (M.LHS->K == Inst::KnownOnesP) { + if (M.LHS->Ops[0]->K == Inst::Var && M.LHS->Ops[1]->Name.starts_with("symDF_K")) { + auto C = new SymK1Bind(this->at(M.LHS->Ops[0]), + this->at(M.LHS->Ops[1])); + Constraints.push_front(C); + // Binds have side effects, have to go in front. 
+ } else { + auto C = new SymK1Test(this->at(M.LHS->Ops[0]), + this->at(M.LHS->Ops[1])); + Constraints.push_back(C); + } + } else if (auto WC = ConvertPCToWidthConstraint(M)) { + Constraints.push_back(WC); + } else { + auto L = Translate(M.LHS); + auto R = Translate(M.RHS); + if (!L.second || !R.second) { + return false; + } + Constraints.push_back(new PC(L.first, R.first)); + } + } + return true; + } + + // For each non-constant operand of RHS that has a symbol-table entry and has + // not been seen yet, appends a DomCheck for it and recurses into that operand. + // Seen is the visited set shared by one top-level call; external callers omit + // it and get a fresh set. A function-local static set was used here before, + // which kept entries alive across unrelated replacements. + void GenDomConstraints(Inst *RHS, std::set<Inst *> *Seen = nullptr) { + std::set<Inst *> Local; + if (!Seen) { + Seen = &Local; + } + Seen->insert(RHS); + for (auto Op : RHS->Ops) { + if (Op->K == Inst::Const) { + continue; + // TODO: Find other cases + } + auto It = find(Op); + if (It != end()) { + if (Seen->find(Op) == Seen->end()) { + Constraints.push_back(new DomCheck(It->second)); + GenDomConstraints(Op, Seen); + } + } + } + } + + void GenDFConstraints(Inst *LHS) { + if (LHS->DemandedBits.getBitWidth() + == LHS->Width && !LHS->DemandedBits.isAllOnesValue()) { + Constraints.push_back(new DB(LHS->DemandedBits.toString(2, false), LHS->Width)); + } + + std::vector Vars; + findVars(LHS, Vars); + + std::set VarSet; + for (auto &&V : Vars) { + VarSet.insert(V); + } + + for (auto &&V : VarSet) { + auto Name = this->at(V); + if (V->KnownOnes.getBitWidth() == V->Width && + V->KnownOnes != 0) { + Constraints.push_back(new K1(Name, V->KnownOnes.toString(2, false), V->Width)); + } + if (V->KnownZeros.getBitWidth() == V->Width && + V->KnownZeros != 0) { + Constraints.push_back(new K0(Name, V->KnownZeros.toString(2, false), V->Width)); + } + + if (!V->Range.isFullSet()) { + Constraints.push_back(new CR(Name, V->Range.getLower().toString(10, false), V->Range.getUpper().toString(10, false))); + } + } + } + + void GenVarPropConstraints(Inst *LHS, bool WidthIndependent) { + std::vector Vars; + findVars(LHS, Vars); + + for (auto V : Vars) { + auto Name = this->at(V); + + if (!WidthIndependent || V->Width == 1) { + Constraints.push_back(new WidthEq(Name, V->Width)); + } + + if (V->PowOfTwo) { + Constraints.push_back(new VC("pow2", Name)); + } + if 
(V->NonZero) { + Constraints.push_back(new VC("nz", Name)); + } + if (V->NonNegative) { + Constraints.push_back(new VC("nn", Name)); + } + if (V->Negative) { + Constraints.push_back(new VC("neg", Name)); + } + + } + } + + template + void PrintConstraintsPre(Stream &Out) { + if (Constraints.empty()) { + return; + } + Out << "if ("; + bool first = true; + for (auto &&C : Constraints) { + if (first) { + first = false; + } else { + Out << " && "; + } + Out << C->print(); + } + Out << ") {\n"; + } + template + void PrintConstraintsPost(Stream &Out) { + if (Constraints.empty()) { + return; + } + Out << "}\n"; + } + + // Consts = consts found in LHS + // ConstRefs = consts found in RHS + template + void PrintConstDecls(Stream &Out) { + size_t varnum = 0; + + auto Print = [&](SymbolTable &Syms, Inst *C){ + auto Name = "C" + std::to_string(varnum++); + if (C->Width <= 64) { + Out << " auto " << Name << " = C(" + << C->Val.getBitWidth() <<", " + << C->Val << ", B);\n"; + } else { + Out << " auto " << Name << " = C(" + << "APInt(" << C->Val.getBitWidth() << ", " + << "\"" << C->Val.toString(10, false) << "\", 10), B);\n"; + } + Syms[C] = Name; + }; + + for (auto C : ConstRefs) { + // if (Consts.find(C) == Consts.end()) { + Print(*this, C); + // } + } + } +}; + +template +bool GenLHSMatcher(Inst *I, Stream &Out, SymbolTable &Syms, bool IsRoot = false) { + if (!IsRoot) { + if (I->K != souper::Inst::Var && Syms.Used.find(I) != Syms.Used.end()) { + Out << "&" << Syms[I] << " <<= "; + } + } + + static std::set MatchedVals; + if (IsRoot && I->K == Inst::Var) { + if (MatchedVals.find(I) == MatchedVals.end()) { + MatchedVals.insert(I); + Out << "m_Value(" << Syms[I] << ")"; + return true; + } else { + Out << "m_Deferred(" << Syms[I] << ")"; + } + return true; + } + + auto It = MatchOps.find(I->K); + if (It == MatchOps.end()) { + llvm::errs() << "\nUnimplemented matcher:" << Inst::getKindName(I->K) << "\n"; + return false; + } + + auto Op = It->second; + + Out << Op; + + if 
(!OnlyExplicitWidths) { + if (I->K == Inst::SExt || I->K == Inst::ZExt || I->K == Inst::Trunc) { + Out << I->Width << ", "; + } + } + + if (PredNames.find(I->K) != PredNames.end()) { + Out << Syms.Preds[I] << ", "; + } + + bool first = true; + for (auto Child : I->Ops) { + if (first) { + first = false; + } else { + Out << ", "; + } + + if (Child->K == Inst::Const) { + if (Child->K != souper::Inst::Var && Syms.Used.find(Child) != Syms.Used.end()) { + Out << "&" << Syms[Child] << " <<= "; + } + auto Str = Child->Val.toString(10, false); + if (OnlyExplicitWidths) { + Out << "m_ExtInt(\"" << Str << "\", " << Child->Width << ")"; + } else { + Out << "m_SpecificInt( " << Child->Width << ", \"" << Str << "\")"; + } + } else if (Child->K == Inst::Var) { + if (Child->Name.starts_with("symconst")) { + Out << "m_Constant(&" << Syms[Child] << ")"; + } else if (Child->Name.starts_with("constexpr")) { + llvm::errs() << "FOUND A CONSTEXPR\n"; + return false; + } else { + if (MatchedVals.find(Child) == MatchedVals.end()) { + MatchedVals.insert(Child); + Out << "m_Value(" << Syms[Child] << ")"; + } else { + Out << "m_Deferred(" << Syms[Child] << ")"; + } + } + // Syms[Child].pop_back(); + } else { + if (!GenLHSMatcher(Child, Out, Syms)) { + return false; + } + } + + } + Out << ")"; + return true; +} + +Inst *getSibling(Inst *Child, Inst *Parent) { + if (!Child || !Parent) { + return nullptr; + } + + for (auto Op : Parent->Ops) { + if (Op != Child && Op->Width == Child->Width) { + return Op; + } + } + return nullptr; +} + +template +bool GenRHSCreator(Inst *I, Stream &Out, SymbolTable &Syms, Inst *Parent = nullptr) { + auto It = CreateOps.find(I->K); + if (It == CreateOps.end()) { + llvm::errs() << "\nUnimplemented creator:" << Inst::getKindName(I->K) << "\n"; + return false; + } + auto Op = It->second; + + Out << "B->" << Op; + bool first = true; + for (auto Child : I->Ops) { + if (first) { + first = false; + } else { + Out << ", "; + } + if (Syms.find(Child) != Syms.end()) { + Out 
<< Syms[Child]; + if (Child->K == Inst::Const && Syms[Child].starts_with("C")) { + auto Sib = getSibling(Child, I); + std::string S; + if (Syms.find(Sib) != Syms.end()) { + if (Syms[Sib].starts_with("x")) { + S = Syms[Sib]; + } + } + Out << "(" << S << ")"; // Ad-hoc type inference + } + } else { + if (!GenRHSCreator(Child, Out, Syms, I)) { + return false; + } + } + + } + if (I->K == Inst::Trunc || I->K == Inst::SExt || I->K == Inst::ZExt) { + auto Cousin = getSibling(I, Parent); + std::string S; + if (Syms.find(Cousin) != Syms.end()) { + if (Syms[Cousin].starts_with("x")) { + S = Syms[Cousin]; + } + } + if (!S.empty()) { + Out << ", T(" << S << ")"; // Ad-hoc type inference + } else { + Out << ", T(" << I->Width << ", B)"; + } + } + Out << ")"; + + return true; +} + +template +bool InitSymbolTable(ParsedReplacement Input, Stream &Out, SymbolTable &Syms) { + auto Root = Input.Mapping.LHS; + auto RHS = Input.Mapping.RHS; + std::set LHSInsts; + std::set Visited; + + std::vector Stack{Root}; + for (auto M : Input.PCs) { + Stack.push_back(M.LHS); + Stack.push_back(M.RHS); + } + + int varnum = 0; + while (!Stack.empty()) { + auto I = Stack.back(); + Stack.pop_back(); + LHSInsts.insert(I); + Visited.insert(I); + if (I->K == Inst::Var) { + if (Syms.find(I) == Syms.end()) { + Syms[I] = ("x" + std::to_string(varnum++)); + // llvm::errs() << "Var1: " << I->Name << " -> " << Syms[I] << "\n"; + } + } + if (I->K == Inst::Const) { + Syms.Consts.insert(I); + } + for (int i = 0; i < I->Ops.size(); ++i) { + if (Visited.find(I->Ops[i]) == Visited.end()) { + Stack.push_back(I->Ops[i]); + } + } + } + + Visited.clear(); + Stack = {Root}; + while (!Stack.empty()) { + auto I = Stack.back(); + Stack.pop_back(); + Visited.insert(I); + Syms.RegisterPred(I); + for (int i = 0; i < I->Ops.size(); ++i) { + if (Visited.find(I->Ops[i]) == Visited.end()) { + Stack.push_back(I->Ops[i]); + } + } + } + + Visited.clear(); + Stack.push_back(RHS); + + while (!Stack.empty()) { + auto I = Stack.back(); + 
Stack.pop_back(); + Visited.insert(I); + if (I->K == Inst::Const) { + Syms.ConstRefs.insert(I); + } + + if (LHSInsts.find(I) != LHSInsts.end()) { + if (Syms.Used.insert(I).second && Syms.find(I) == Syms.end()) { + Syms[I] = ("x" + std::to_string(varnum++)); + // llvm::errs() << "Var0: " << I->Name << " -> " << Syms[I] << "\n"; + } + } + for (auto Child : I->Ops) { + if (Visited.find(Child) == Visited.end()) { + Stack.push_back(Child); + } + } + + } + + if (!Syms.empty()) { + Out << "llvm::Value "; + bool first = true; + for (auto &&S : Syms) { + if (first) { + first = false; + } else { + Out << ", "; + } + Out << "*" << S.second; + } + Out << ";\n"; + } + +// varnum = 0; +// for (auto &&P : Paths) { +// if (P.first == Root || P.first->K == Inst::Var +// || LHSRefs.find(P.first) == LHSRefs.end()) { +// continue; +// } +//// std::string Name = "I"; +//// for (auto idx : P.second) { +//// auto NewName = "y" + std::to_string(varnum++); +//// Out << "auto " << NewName << " = cast(" << Name; +//// Out << ")->getOperand(" << idx << ");\n"; +//// std::swap(Name, NewName); +//// } +//// Syms[P.first].push_back(Name); +// +// auto Name = "y" + std::to_string(varnum++); +// Out << "auto " << Name << " = util::node(I, "; +// printPath(Out, P.second); +// Out << ");\n"; +// Syms[P.first].push_back(Name); +// } + // Syms[Root].push_back("I"); + Syms.PrintPreds(Out); + return true; +} + +template +bool GenMatcher(ParsedReplacement Input, Stream &Out, size_t OptID, bool WidthIndependent) { + SymbolTable Syms; + Out << "{\n"; + + int prof = profit(Input); + size_t LHSSize = souper::instCount(Input.Mapping.LHS); + if (prof <= 0 || LHSSize > 15) { + llvm::errs() << "Skipping replacement profit < 0 or LHS size > 15\n"; + return false; + } + + if (!InitSymbolTable(Input, Out, Syms)) { + return false; + } +// Out << " llvm::errs() << \"NOW \" << " << OptID << "<< \"\\n\";\n"; + + auto F = "util::filter(F, " + std::to_string(OptID) + ") && "; + Out << "if (" << F << "match(I, "; + + 
SymbolTable SymsCopy = Syms; + if (Input.Mapping.LHS->K == Inst::DemandedMask) { + if (!GenLHSMatcher(Input.Mapping.LHS->Ops[0], Out, SymsCopy, /*IsRoot = */true)) { + return false; + } + } else { + if (!GenLHSMatcher(Input.Mapping.LHS, Out, SymsCopy, /*IsRoot = */true)) { + return false; + } + } + Out << ")) {\n"; + +// Input.print(llvm::errs(), true); + Inst *DemandedMask = nullptr; + if (Input.Mapping.LHS->K == Inst::DemandedMask) { + DemandedMask = Input.Mapping.LHS->Ops[1]; + Syms.Constraints.push_back(new SymDB(Syms[Input.Mapping.LHS->Ops[1]])); + } + Syms.GenVarPropConstraints(Input.Mapping.LHS, WidthIndependent); + Syms.GenDomConstraints(Input.Mapping.RHS); + Syms.GenDFConstraints(Input.Mapping.LHS); + if (!Syms.GenPCConstraints(Input.PCs)) { + llvm::errs() << "Failed to generate PC constraints.\n"; + return false; + } + Syms.PrintConstraintsPre(Out); + + Syms.PrintConstDecls(Out); + + Out << " auto ret"; + + if (Syms.find(Input.Mapping.RHS) != Syms.end()) { + Out << " = " << Syms[Input.Mapping.RHS]; + if (Syms[Input.Mapping.RHS].starts_with("C")) { + Out << "(I)"; + } + Out << ";"; + } else if (Input.Mapping.RHS->K == Inst::DemandedMask && Syms.find(Input.Mapping.RHS->Ops[0]) != Syms.end()) { + assert(DemandedMask == Input.Mapping.RHS->Ops[1] && "DemandedMask mismatch"); + Out << " = " << Syms[Input.Mapping.RHS->Ops[0]] << ";"; + } else if (Input.Mapping.RHS->K == Inst::Const) { + Out << " APInt Result(" + << Input.Mapping.RHS->Width <<", " + << Input.Mapping.RHS->Val << ");\n"; + Out << " = ConstantInt::get(TheContext, Result);"; + } else { + Out << " = "; + if (Input.Mapping.RHS->K == Inst::DemandedMask) { + assert(DemandedMask == Input.Mapping.RHS->Ops[1] && "DemandedMask mismatch"); + if (!GenRHSCreator(Input.Mapping.RHS->Ops[0], Out, Syms)) { + return false; + } + } else { + if (!GenRHSCreator(Input.Mapping.RHS, Out, Syms)) { + return false; + } + } + Out << ";"; + } + Out << "\nif (util::check_width(ret, I)) {\n"; + Out << " St.hit(" << OptID << ", " 
<< prof << ");\n"; + Out << " return ret;\n"; + Out << "\n}\n}\n}"; + + Syms.PrintConstraintsPost(Out); + + return true; +} + +std::string getLLVMInstKindName(Inst::Kind K) { + StringRef str = MatchOps.find(K)->second; + str.consume_front("m_"); + str.consume_back("("); +// str.consume_front("NSW"); +// str.consume_front("NUW"); +// str.consume_front("NW"); + return str.str(); +} + +bool PCHasVar(const ParsedReplacement &Input) { + std::vector Stack; + for (auto &&PC : Input.PCs) { + // if (PC.LHS->K == Inst::KnownOnesP || PC.LHS->K == Inst::KnownZerosP) + // continue; + + // if (PC.LHS->K == Inst::Eq && (PC.LHS->Ops[0]->K == Inst::BitWidth || + // PC.LHS->Ops[1]->K == Inst::BitWidth )) { + // continue; + // } + Stack.push_back(PC.LHS); + Stack.push_back(PC.RHS); + } + + std::set Visited; + std::vector Vars; + + while (!Stack.empty()) { + Inst *I = Stack.back(); + Stack.pop_back(); + + if (Visited.count(I)) { + continue; + } + Visited.insert(I); + + if (I->K == Inst::Var) { + Vars.push_back(I); + } + + for (auto &&Op : I->Ops) { + if (I->K == Inst::KnownOnesP || I->K == Inst::KnownZerosP || I->K == Inst::BitWidth) + continue; + Stack.push_back(Op); + } + } + + + for (auto &&V : Vars) { + // llvm::errs() << V->Name << "\n"; + if (!V->Name.starts_with("sym")) { + return true; + } + } + + return false; +} + + +int main(int argc, char **argv) { + cl::ParseCommandLineOptions(argc, argv); + KVStore *KV = 0; + + std::unique_ptr S = 0; + S = GetSolver(KV); + + std::unordered_set optnumbers; + std::vector Ordered; + if (ListFile != "") { + std::ifstream in(ListFile); + size_t num; + while (in >> num) { + optnumbers.insert(num); + if (Sort) { + Ordered.push_back(num); + } + } + } + + auto MB = MemoryBuffer::getFileOrSTDIN(InputFilename); + if (!MB) { + llvm::errs() << MB.getError().message() << '\n'; + return 1; + } + + InstContext IC; + std::string ErrStr; + + auto &&Data = (*MB)->getMemBufferRef(); + auto Inputs = ParseReplacements(IC, Data.getBufferIdentifier(), + 
Data.getBuffer(), ErrStr); + + std::set Kinds; + std::sort(Inputs.begin(), Inputs.end(), + [&Kinds](const ParsedReplacement& A, const ParsedReplacement &B) { + Kinds.insert(A.Mapping.LHS->K); + Kinds.insert(B.Mapping.LHS->K); + +// if (A.Mapping.LHS->K < B.Mapping.LHS->K) { +// return true; +// } else if (A.Mapping.LHS->K == B.Mapping.LHS->K) { +// return profitability(A) > profitability(B); +// } else { +// return false; +// } + return A.Mapping.LHS->K < B.Mapping.LHS->K; + }); + + if (!ErrStr.empty()) { + llvm::errs() << ErrStr << '\n'; + return 1; + } + + size_t optnumber = 0; + + Inst::Kind Last = Inst::Kind::None; + + bool first = true; + bool outputs = false; + + std::map Results; + + for (auto &&Input: Inputs) { + auto SKIP = [&] (auto Msg) { + Input.print(llvm::errs(), true); + llvm::errs() << Msg << "\n\n\n"; + llvm::errs().flush(); + }; + + if (PCHasVar(Input)) { + SKIP("SKIP PC has var."); + continue; + } + + if (Input.Mapping.LHS == Input.Mapping.RHS) { + SKIP("SKIP LHS = RHS."); + continue; + } + if (IgnoreDF) { + if (Input.Mapping.LHS->DemandedBits.getBitWidth() + == Input.Mapping.LHS->Width && !Input.Mapping.LHS->DemandedBits.isAllOnesValue()) { + + continue; + } + std::vector Vars; + findVars(Input.Mapping.LHS, Vars); + findVars(Input.Mapping.RHS, Vars); + bool found = false; + for (auto V : Vars) { + if (V->KnownOnes.getBitWidth() == V->Width && V->KnownOnes != 0) { + found = true; + break; + } + + if (V->KnownZeros.getBitWidth() == V->Width && V->KnownZeros != 0) { + found = true; + break; + } +// if (!V->Range.isFullSet() || !V->Range.isEmptySet()) { +// continue; +// } + } + if (found) { + SKIP("SKIP Unsupported DF."); + continue; + } + } + + if (Input.Mapping.LHS->K != Last && !NoDispatch) { + if (!first) { + llvm::outs() << "}\n"; + } + first = false; + llvm::outs() << "if ("; + switch (Input.Mapping.LHS->K) { + case Inst::AddNW: + case Inst::AddNUW: + case Inst::AddNSW: + case Inst::Add: llvm::outs() + << "I->getOpcode() == Instruction::Add"; 
break; + + case Inst::SubNW: + case Inst::SubNUW: + case Inst::SubNSW: + case Inst::Sub: llvm::outs() + << "I->getOpcode() == Instruction::Sub"; break; + + case Inst::MulNW: + case Inst::MulNUW: + case Inst::MulNSW: + case Inst::Mul: llvm::outs() + << "I->getOpcode() == Instruction::Mul"; break; + + case Inst::ShlNW: + case Inst::ShlNUW: + case Inst::ShlNSW: + case Inst::Shl: llvm::outs() + << "I->getOpcode() == Instruction::Shl"; break; + + case Inst::And: llvm::outs() + << "I->getOpcode() == Instruction::And"; break; + case Inst::Or: llvm::outs() + << "I->getOpcode() == Instruction::Or"; break; + case Inst::Xor: llvm::outs() + << "I->getOpcode() == Instruction::Xor"; break; + case Inst::SRem: llvm::outs() + << "I->getOpcode() == Instruction::SRem"; break; + case Inst::URem: llvm::outs() + << "I->getOpcode() == Instruction::URem"; break; + case Inst::SDiv: llvm::outs() + << "I->getOpcode() == Instruction::SDiv"; break; + case Inst::UDiv: llvm::outs() + << "I->getOpcode() == Instruction::UDiv"; break; + case Inst::ZExt: llvm::outs() + << "I->getOpcode() == Instruction::ZExt"; break; + case Inst::SExt: llvm::outs() + << "I->getOpcode() == Instruction::SExt"; break; + case Inst::Trunc: llvm::outs() + << "I->getOpcode() == Instruction::Trunc"; break; + case Inst::Select: llvm::outs() + << "I->getOpcode() == Instruction::Select"; break; + case Inst::Phi: llvm::outs() + << "isa(I)"; break; + case Inst::Eq: + case Inst::Ne: + case Inst::Ult: + case Inst::Slt: + case Inst::Ule: + case Inst::Sle: llvm::outs() + << "I->getOpcode() == Instruction::ICmp"; break; + + default: llvm::outs() << "true"; + } + llvm::outs() << ") {\n"; + outputs = true; + } + Last = Input.Mapping.LHS->K; + + std::string Str; + llvm::raw_string_ostream Out(Str); + + if (GenMatcher(Input, Out, optnumber, OnlyExplicitWidths)) { + auto current = optnumber++; + if (!optnumbers.empty() + && optnumbers.find(current) == optnumbers.end()) { + Out.flush(); + Str.clear(); + llvm::errs() << "Opt " << current << 
" skipped on demand.\n"; + SKIP("SKIP Filtered."); + continue; + } + ReplacementContext RC; + std::string IRComment = + "/* Opt : " + + std::to_string(current) + "\n" + + Input.getLHSString(RC, true) + + Input.getRHSString(RC, true) + "*/\n"; + + if (NoDispatch && Sort && !Ordered.empty()) { + Results[current] = IRComment + Str + "\n"; + } else { + llvm::outs() << IRComment << Str << "\n"; + llvm::outs().flush(); + outputs= true; + } + + } else { + SKIP("SKIP Failed to generate matcher."); + } + } + if (outputs) { + llvm::outs() << "}\n"; + } + + if (NoDispatch && Sort && !Ordered.empty()) { + for (auto N : Ordered) { + llvm::outs() << Results[N]; + } + } + +// llvm::outs() << "end:\n"; + + return 0; +} diff --git a/tools/pass-generator/CMakeLists.txt b/tools/pass-generator/CMakeLists.txt new file mode 100644 index 000000000..8c78bd769 --- /dev/null +++ b/tools/pass-generator/CMakeLists.txt @@ -0,0 +1,18 @@ +cmake_minimum_required(VERSION 3.7) +project(souper-combine) + +cmake_policy(SET CMP0074 NEW) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") + +find_package(LLVM 14.0 REQUIRED CONFIG) +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") +include(HandleLLVMOptions) +include(AddLLVM) + +message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") + +add_definitions(${LLVM_DEFINITIONS}) +include_directories(${LLVM_INCLUDE_DIRS}) + +add_subdirectory(src) diff --git a/tools/pass-generator/scripts/bisection.py b/tools/pass-generator/scripts/bisection.py new file mode 100644 index 000000000..8111d498f --- /dev/null +++ b/tools/pass-generator/scripts/bisection.py @@ -0,0 +1,46 @@ +import sys +import re +import os + +def split_on_empty_lines(s): + blank_line_regex = r"(?:\r?\n){2,}" + return re.split(blank_line_regex, s.strip()) + +def split_list(a_list): + half = len(a_list)//2 + return a_list[:half], a_list[half:] + +# Return true if the command crashes +def call(opts, cmd): + if len(opts) == 0: + return False + open("/tmp/scratch.inc", 'w+').write('\n'.join(opts)) + 
print("HERE: ", cmd + " /tmp/scratch.inc") + return os.system(cmd + " /tmp/scratch.inc > /dev/null") != 0 + +# Find a small range of optimizations which causes cmd to crash +def bsearch(opts, cmd): + if len(opts) == 1: + return opts + elif len(opts) == 0: + return () + else: + left, right = split_list(opts) + if call(left, cmd) == True: + return bsearch(left, cmd) + elif call(right, cmd) == True: + return bsearch(right, cmd) + else: + print("Rangefinder failed, find combinations") + return opts + +def main(): + # First argument is the autogenerated matcher file + # Second argument is a file with the command to run whatever you want to run with the matchers as input + all = split_on_empty_lines(open(sys.argv[1]).read()) + cmd = sys.argv[2] + print('\n'.join(bsearch(all, cmd))) + return 0 + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tools/pass-generator/src/CMakeLists.txt b/tools/pass-generator/src/CMakeLists.txt new file mode 100644 index 000000000..c2c67d52d --- /dev/null +++ b/tools/pass-generator/src/CMakeLists.txt @@ -0,0 +1 @@ +add_llvm_library(SouperCombine MODULE template.cpp) diff --git a/tools/pass-generator/src/template.cpp b/tools/pass-generator/src/template.cpp new file mode 100644 index 000000000..8ad66ac56 --- /dev/null +++ b/tools/pass-generator/src/template.cpp @@ -0,0 +1,855 @@ +#include "llvm/Pass.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/ConstantRange.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/NoFolder.h" +#include "llvm/InitializePasses.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/DemandedBits.h" +#include 
"llvm/Analysis/LazyValueInfo.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Transforms/IPO/PassManagerBuilder.h" +#include "llvm/Support/Debug.h" +#define DEBUG_TYPE "" +#include "llvm/Transforms/Utils/InstructionWorklist.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Support/CommandLine.h" + +#include +#include + +using namespace llvm; +using namespace llvm::PatternMatch; + +static llvm::cl::opt ListFile("listfile", + llvm::cl::desc("List of optimization indexes to include.\n" + "(default=empty-string)"), + llvm::cl::init("")); + +static llvm::cl::opt Low("low", + llvm::cl::desc("Low"), + llvm::cl::init(-1)); + +static llvm::cl::opt High("high", + llvm::cl::desc("High"), + llvm::cl::init(-1)); + + +// TODO Match trees +// TODO Make the commutative operations work +// This is critical because commutative operations +// sometimes have the arguments inverted for canonicalization + +namespace { + +// Custom Creators + +class IRBuilder : public llvm::IRBuilder { +public: + IRBuilder(llvm::LLVMContext &C) : llvm::IRBuilder(C) {} + + llvm::Value *CreateLogB(llvm::Value *V) { + if (ConstantInt *Con = llvm::dyn_cast(V)) { + auto Result = Con->getValue().logBase2(); + return ConstantInt::get(Con->getType(), Result); + } else { + llvm_unreachable("Panic, has to be guarded in advance!"); + } + } + + + + // TODO Verify that these work, the mangling argument is weird + llvm::Value *CreateFShl(llvm::Value *A, llvm::Value *B, llvm::Value *C) { + return CreateIntrinsic(Intrinsic::fshl, {A->getType()}, {A, B, C}); + } + llvm::Value *CreateFShr(llvm::Value *A, llvm::Value *B, llvm::Value *C) { + return CreateIntrinsic(Intrinsic::fshr, {A->getType()}, {A, B, C}); + } + llvm::Value *CreateBSwap(llvm::Value *A) { + return CreateIntrinsic(Intrinsic::bswap, {A->getType()}, {A}); + } + + llvm::Value *CreateAShrExact(llvm::Value *A, llvm::Value *B) { + return llvm::BinaryOperator::CreateExact(Instruction::AShr, A, B); + } + llvm::Value *CreateLShrExact(llvm::Value 
*A, llvm::Value *B) { + return llvm::BinaryOperator::CreateExact(Instruction::LShr, A, B); + } + llvm::Value *CreateUDivExact(llvm::Value *A, llvm::Value *B) { + return llvm::BinaryOperator::CreateExact(Instruction::UDiv, A, B); + } + llvm::Value *CreateSDivExact(llvm::Value *A, llvm::Value *B) { + return llvm::BinaryOperator::CreateExact(Instruction::SDiv, A, B); + } +}; + + +// Custom Matchers + +static constexpr auto NWFlag = OverflowingBinaryOperator::NoSignedWrap + | OverflowingBinaryOperator::NoUnsignedWrap; +#define NWT(OP) OverflowingBinaryOp_match +#define NWM(OP) \ +template NWT(OP) \ +m_NW##OP(const LHS &L, const RHS &R) { \ + return NWT(OP)(L, R); \ +} + +NWM(Add) +NWM(Sub) +NWM(Mul) +NWM(Shl) + +#undef NWM +#undef NWT + +template +struct phi_match { + phi_match(Args... args) : Matchers{args...} {}; + std::tuple Matchers; + + // cpp weirdness for accessing specific tuple element in runtime + template + void runtime_get(Func func, Tuple& tup, size_t idx) { + if (N == idx) { + std::invoke(func, std::get(tup)); + return; + } + + if constexpr (N + 1 < std::tuple_size_v) { + return runtime_get(func, tup, idx); + } + } + + bool match_nth(size_t n, Value *V) { + bool Ret = false; + auto F = [&](auto M) {Ret = M.match(V);}; + runtime_get(F, Matchers, n); + return Ret; + } + + bool check(const Value *V) { + if (auto Phi = dyn_cast(V)) { + for (size_t i =0; i < Phi->getNumOperands(); ++i) { + if (!match_nth(i, Phi->getOperand(i))) { + return false; + } + } + return true; + } + return false; + } + + template bool match(I *V) { + return check(V); + } +}; + +template +phi_match m_Phi(Args... 
args) { + return phi_match(args...); +} + +template +inline Exact_match> +m_AShrExact(const LHS &L, const RHS &R) { + return Exact_match>( + BinaryOp_match(L, R)); +} + +template +inline Exact_match> +m_LShrExact(const LHS &L, const RHS &R) { + return Exact_match>( + BinaryOp_match(L, R)); +} + +template +inline Exact_match> +m_UDivExact(const LHS &L, const RHS &R) { + return Exact_match>( + BinaryOp_match(L, R)); +} + +template +inline Exact_match> +m_SDivExact(const LHS &L, const RHS &R) { + return Exact_match>( + BinaryOp_match(L, R)); +} + +struct bind_apint { + APInt &VR; + + bind_apint(APInt &V) : VR(V) {} + + template bool match(ITy *V) { + if (const auto *CV = dyn_cast(V)) { + VR = CV->getValue(); + return true; + } else { + return false; + } + return false; + } +}; + +struct width_specific_intval : public specific_intval { + size_t Width; + width_specific_intval(llvm::APInt V, size_t W) : specific_intval(V), Width(W) {} + + template bool match(ITy *V) { + if (V->getType()->getScalarSizeInBits() != Width) { + return false; + } + return specific_intval::match(V); + } +}; + +inline width_specific_intval m_SpecificInt(size_t W, uint64_t V) { + return width_specific_intval(APInt(64, V), W); +} + +inline width_specific_intval m_SpecificInt(size_t W, std::string S) { + return width_specific_intval(APInt(W, S, 10), W); +} + +struct specific_ext_intval { + llvm::APInt Val; + + specific_ext_intval(std::string S, size_t W) : Val(llvm::APInt(W, S, 10)) {} + + template bool match(ITy *V) { + const auto *CI = dyn_cast(V); + if (!CI && V->getType()->isVectorTy()) + if (const auto *C = dyn_cast(V)) + CI = dyn_cast_or_null(C->getSplatValue(true)); + + if (!CI) + return false; + + auto TargetVal = CI->getValue(); + auto TargetWidth = TargetVal.getBitWidth(); + + return llvm::APInt::isSameValue(TargetVal, Val.zextOrTrunc(TargetWidth)) || + llvm::APInt::isSameValue(TargetVal, Val.sextOrTrunc(TargetWidth)); + + } +}; + +inline specific_ext_intval m_ExtInt(std::string S, size_t 
W) { + return specific_ext_intval(S, W); +} + +struct constant_matcher { + llvm::Value** Captured; + constant_matcher(llvm::Value** C) : Captured(C) {} + template bool match(OpTy *V) { + if (isa(V)) { + *Captured = dyn_cast(V); + return *Captured != nullptr; + } + return false; + } +}; + +inline constant_matcher m_Constant(Value **V) { + return constant_matcher(V); +} + +// Tested, matches APInts +inline bind_apint m_APInt(APInt &V) { return bind_apint(V); } + +// TODO: Match (arbitrarily) constrained APInts + + +template struct CastClass_match_width { + size_t Width; + Op_t Op; + + CastClass_match_width(size_t W, const Op_t &OpMatch) : Width(W), Op(OpMatch) {} + + template bool match(OpTy *V) { + if (V->getType()->getScalarSizeInBits() != Width) { + return false; + } + if (auto *O = dyn_cast(V)) + return O->getOpcode() == Opcode && Op.match(O->getOperand(0)); + return false; + } +}; + +template +struct Capture { + Value **Captured; + Matcher M; + + template + explicit Capture(Value **V, CArgs ...C) : Captured(V), M(C...) 
{} + + template bool match(OpTy *V) { + if (M.match(V)) { + *Captured = dyn_cast(V); + if (!*Captured) { + llvm::errs() << "Capture failed.\n"; + return false; + } + return true; + } else { + *Captured = nullptr; + return false; + } + } +}; + +template +Capture Cap(Value **V, Matcher &&M) { + return Capture(V, M); +} + +// Equivalent to the Cap function +template +Capture operator<<=(Value **V, Matcher &&M) { + return Capture(V, M); +} + +template +inline CastClass_match_width m_ZExt(size_t W, const OpTy &Op) { + return CastClass_match_width(W, Op); +} + +template +inline CastClass_match_width m_SExt(size_t W, const OpTy &Op) { + return CastClass_match_width(W, Op); +} + +template +inline CastClass_match_width m_Trunc(size_t W, const OpTy &Op) { + return CastClass_match_width(W, Op); +} + +namespace util { + bool dc(llvm::DominatorTree *DT, llvm::Instruction *I, llvm::Value *V) { + if (auto Def = dyn_cast(V)) { + if (I->getParent() == Def->getParent()) { + return true; + } + return DT->dominates(Def, I->getParent()); + } + return true; + } + + bool check_width(llvm::Value *V, size_t W) { + if (V && V->getType() && V->getType()->isIntegerTy()) { + return V->getType()->getScalarSizeInBits() == W; + } else { + return false; + } + } + + bool check_width(llvm::Value *V, Instruction *I) { + if (V && V->getType() && V->getType()->isIntegerTy()) { + return V->getType()->getScalarSizeInBits() == I->getType()->getScalarSizeInBits(); + } else { + return false; + } + } + + template + bool check_related(Out Result, FT F, Args... 
args) { + return Result == F(args...); + } + + bool pow2(llvm::Value *V) { + if (ConstantInt *Con = llvm::dyn_cast(V)) { + if (Con->getValue().isPowerOf2()) { + return true; + } + } + return false; + } + + bool KnownBitImplies(llvm::APInt Big, llvm::APInt Small) { + + if (Big.getBitWidth() != Small.getBitWidth()) { + return false; + } + +// auto P = [](llvm::APInt A, auto S) { +// llvm::SmallVector Foo; +// A.toString(Foo, 2, false); +// llvm::errs() << "\n" << Foo << " <--" << S << "\n"; +// }; +// +// auto Val = (~Big | Small); +// +// P(Big, "BIG"); +// P(Small, "SMALL"); +// P(~Big, "FLIP"); +// +// P(Small, "OR"); +// P(~Big | Small, "RES"); + + return (~Big | Small).isAllOnes(); + } + + bool k0(llvm::Value *V, std::string Val, size_t ExpectedWidth) { + if (!V || !V->getType() || !V->getType()->isIntegerTy() ) { + return false; + } + auto W = V->getType()->getIntegerBitWidth(); + + if (W != ExpectedWidth) { + return false; + } + + llvm::APInt Value(W, Val, 2); + if (ConstantInt *Con = llvm::dyn_cast(V)) { + auto X = Con->getUniqueInteger(); + return KnownBitImplies(Value, ~X); + } + auto Analyzed = llvm::KnownBits(W); + if (Instruction *I = llvm::dyn_cast(V)) { + DataLayout DL(I->getParent()->getParent()->getParent()); + computeKnownBits(V, Analyzed, DL, 4); + + // llvm::SmallVector Result; + // Analyzed.Zero.toString(Result, 2, false); + + auto b = KnownBitImplies(Value, Analyzed.Zero); +// llvm::errs() << "HERE: " << Result << ' ' << Val +// << ' ' << b << "\n\n"; + return b; + } + return false; + } + + bool k1(llvm::Value *V, std::string Val, size_t ExpectedWidth) { + if (!V || !V->getType() || !V->getType()->isIntegerTy()) { + return false; + } + auto W = V->getType()->getIntegerBitWidth(); + + if (ExpectedWidth != W) { + return false; + } + + llvm::APInt Value(W, Val, 2); + if (ConstantInt *Con = llvm::dyn_cast(V)) { + auto X = Con->getUniqueInteger(); + return KnownBitImplies(Value, X); + } + auto Analyzed = llvm::KnownBits(W); + if (Instruction *I = 
llvm::dyn_cast(V)) { + DataLayout DL(I->getParent()->getParent()->getParent()); + computeKnownBits(V, Analyzed, DL, 4); + return KnownBitImplies(Value, Analyzed.One); + } + return false; + } + + bool cr(llvm::Value *V, std::string L, std::string H) { + if (!V || !V->getType() || !V->getType()->isIntegerTy()) { + return false; + } + auto W = V->getType()->getIntegerBitWidth(); + llvm::ConstantRange R(llvm::APInt(W, L, 10), llvm::APInt(W, H, 10)); + if (ConstantInt *Con = llvm::dyn_cast(V)) { + return R.contains(Con->getUniqueInteger()); + } + auto CR = computeConstantRange(V, true); + return R.contains(CR); + } + + bool vdb(llvm::DemandedBits *DB, llvm::Instruction *I, std::string DBUnderApprox, size_t ExpectedWidth) { + + if (I->getType()->getIntegerBitWidth() != ExpectedWidth) { + return false; + } + + llvm::APInt V = llvm::APInt(I->getType()->getIntegerBitWidth(), DBUnderApprox, 2); + auto ComputedDB = DB->getDemandedBits(I); + +// llvm::errs() << DBUnderApprox << ' ' << llvm::toString(ComputedDB, 2, false) << ' ' +// << (V | ~ComputedDB).isAllOnes() << "\n"; + + // 0 in DBUnderApprox implies 0 in ComputedDB + return (V | ~ComputedDB).isAllOnes(); + } + + bool symk0bind(llvm::Value *V, llvm::Value *&Bind, IRBuilder *B) { + if (!V || !V->getType() || !V->getType()->isIntegerTy() ) { + return false; + } + + auto W = V->getType()->getIntegerBitWidth(); + + auto Analyzed = llvm::KnownBits(W); + if (Instruction *I = llvm::dyn_cast(V)) { + DataLayout DL(I->getParent()->getParent()->getParent()); + computeKnownBits(V, Analyzed, DL, 4); + if (Analyzed.Zero == 0) { + return false; + } + Bind = B->getInt(Analyzed.Zero); + return true; + } + + return false; + } + + bool symk1bind(llvm::Value *V, llvm::Value *&Bind, IRBuilder *B) { + if (!V || !V->getType() || !V->getType()->isIntegerTy() ) { + return false; + } + + auto W = V->getType()->getIntegerBitWidth(); + + auto Analyzed = llvm::KnownBits(W); + if (Instruction *I = llvm::dyn_cast(V)) { + DataLayout 
DL(I->getParent()->getParent()->getParent()); + computeKnownBits(V, Analyzed, DL, 4); + if (Analyzed.One == 0) { + return false; + } + Bind = B->getInt(Analyzed.One); + return true; + } + + return false; + } + + bool symk0test(llvm::Value *Bound, llvm::Value *OtherSymConst) { + llvm::Constant *BoundC = llvm::dyn_cast(Bound); + llvm::Constant *OtherC = llvm::dyn_cast(OtherSymConst); + + if (!BoundC || !OtherC) { + return false; + } + + // Width sanity check + if (BoundC->getType()->getIntegerBitWidth() != OtherC->getType()->getIntegerBitWidth()) { + return false; + } + + return KnownBitImplies(OtherC->getUniqueInteger(), ~BoundC->getUniqueInteger()); + } + + bool symk1test(llvm::Value *Bound, llvm::Value *OtherSymConst) { + llvm::Constant *BoundC = llvm::dyn_cast(Bound); + llvm::Constant *OtherC = llvm::dyn_cast(OtherSymConst); + + if (!BoundC || !OtherC) { + return false; + } + + // Width sanity check + if (BoundC->getType()->getIntegerBitWidth() != OtherC->getType()->getIntegerBitWidth()) { + return false; + } + + // llvm::errs() << "SymK1Test: " << llvm::toString(OtherC->getUniqueInteger(), 2, false) << ' ' + // << llvm::toString(BoundC->getUniqueInteger(), 2, false) << "\n"; + + // llvm::errs() << "Result: " << KnownBitImplies(OtherC->getUniqueInteger(), BoundC->getUniqueInteger()) << "\n"; + + return KnownBitImplies(OtherC->getUniqueInteger(), BoundC->getUniqueInteger()); + } + + bool symdb(llvm::DemandedBits *DB, llvm::Instruction *I, llvm::Value *&V, IRBuilder *B) { + auto ComputedDB = DB->getDemandedBits(I); + // Are there other non trivial failure modes? 
+ if (ComputedDB == 0) { + return false; + } + V = B->getInt(ComputedDB); + llvm::errs() << "SymDB: " << llvm::toString(ComputedDB, 2, false) << "\n"; + return true; + } + + bool nz(llvm::Value *V) { + if (ConstantInt *Con = llvm::dyn_cast(V)) { + return !Con->getValue().isZero(); + } +// llvm::errs() << "NZ called on NC.\n"; + return false; + } + + bool nn(llvm::Value *V) { + if (ConstantInt *Con = llvm::dyn_cast(V)) { + return Con->getValue().isNonNegative(); + } +// llvm::errs() << "NN called on NC.\n"; + return false; + } + + bool neg(llvm::Value *V) { + if (ConstantInt *Con = llvm::dyn_cast(V)) { + return Con->getValue().isNegative(); + } +// llvm::errs() << "Neg called on NC.\n"; + return false; + } + + bool filter(const std::set &F, size_t id) { + if (Low != -1 && High != -1) { + llvm::errs() << Low << " " << id << " " << High << " " << (Low <= id && id < High) << "\n"; + return Low <= id && id < High; + } + if (F.empty()) return true; + return F.find(id) != F.end(); +// return true; + } + + struct Stats { + void hit(size_t opt, int cost) { + Hits[opt]++; + Cost[opt] = cost; + } +// void dcmiss(size_t opt) { +// DCMiss[opt]++; +// } + std::map Hits; + std::map Cost; +// std::map DCMiss; + void print() { + std::vector> Copy(Hits.size(), std::make_pair(0, 0)); + std::copy(Hits.begin(), Hits.end(), Copy.begin()); + std::sort(Copy.begin(), Copy.end(), + [](auto &A, auto &B) {return A.second > B.second;}); + llvm::errs() << "Hits begin:\n"; + size_t sum = 0; + for (auto &&P : Copy) { + sum += P.second; + int64_t cost; + if (Cost.find(P.first) == Cost.end()) { + cost = 1; + } else { + cost = Cost[P.first]; + } + llvm::errs() << "OptID " << P.first << " matched " << P.second << " time(s). Cost " << int(P.second) * cost << "\n"; + } + llvm::errs() << "Hits end. 
Total = " << sum << ".\n"; + } + }; + bool nc(llvm::Value *a, llvm::Value *b) { + if (llvm::isa(a) || llvm::isa(b)) return false; + return true; + } + + llvm::APInt V(llvm::Value *V) { + return llvm::dyn_cast(V)->getValue(); + } + llvm::APInt V(size_t Width, size_t Val) { + return llvm::APInt(Width, Val); + } + llvm::APInt V(size_t Width, std::string Val) { + return llvm::APInt(Width, Val, 2); + } + llvm::APInt V(llvm::Value *Ctx, std::string Val) { + return llvm::APInt(Ctx->getType()->getIntegerBitWidth(), Val, 2); + } + + llvm::APInt W(llvm::Value *Ctx) { + return llvm::APInt(Ctx->getType()->getIntegerBitWidth(), Ctx->getType()->getIntegerBitWidth()); + } + llvm::APInt W(llvm::Value *Ctx, size_t WidthOfWidth) { + return llvm::APInt(WidthOfWidth, Ctx->getType()->getIntegerBitWidth()); + } +} + +bool operator < (int x, const llvm::APInt &B) { + return llvm::APInt(B.getBitWidth(), x).ult(B); +} + +llvm::APInt shl(int A, llvm::APInt B) { + return llvm::APInt(B.getBitWidth(), A).shl(B); +} + +struct SouperCombine : public FunctionPass { + static char ID; + SouperCombine() : FunctionPass(ID) { + if (ListFile != "") { + std::ifstream in(ListFile); + size_t num; + while (in >> num) { + F.insert(num); + } + } + } + ~SouperCombine() { + St.print(); + } + + virtual void getAnalysisUsage(AnalysisUsage &Info) const override { +// Info.addRequired(); + Info.addRequired(); + Info.addRequired(); + Info.addRequired(); +// Info.addRequired(); +// Info.addRequired(); + } + + + bool runOnFunction(Function &F) override { + llvm::errs() << "SouperCombine: " << F.getName() << "\n"; + AssumptionCache AC(F); + + DT = new DominatorTree(F); + DB = new DemandedBits(F, AC, *DT); +//// LVI = +// auto DL = new DataLayout(F.getParent()); +// auto TLI = new TargetLibraryInfo(); +// new LazyValueInfo + + W.reserve(F.getInstructionCount()); + for (auto &BB : F) { + for (auto &&I : BB) { + if (I.getNumOperands() && + !isa(&I) && + !isa(&I) && + !isa(&I) && + !isa(&I) && + 
I.getType()->isIntegerTy()) { + W.push(&I); + } + } + } + IRBuilder Builder(F.getContext()); + // llvm::errs() << "Before:\n" << F; + auto r = run(Builder); + // llvm::errs() << "After:\n" << F; +// delete DB; +// delete DT; + return r; + } + + bool processInst(Instruction *I, IRBuilder &Builder) { + Builder.SetInsertPoint(I); +// llvm::errs() << "HERE0\n"; + if (auto V = getReplacement(I, &Builder)) { +// llvm::errs() << "HERE1\n"; + replace(I, V, Builder); +// llvm::errs() << "BAR\n"; + return true; + } +// llvm::errs() << "HERE2\n"; + return false; + } + void replace(Instruction *I, Value *V, IRBuilder &Builder) { + W.pushUsersToWorkList(*I); + I->replaceAllUsesWith(V); + } + bool run(IRBuilder &Builder) { + bool Changed = false; + while (auto I = W.removeOne()) { +// llvm::errs() << "FOO\n"; +// I->print(llvm::errs()); + Changed = processInst(I, Builder) || Changed; + } + return Changed; + } + + Value *getReplacement(llvm::Instruction *I, IRBuilder *B) { +// if (!I->hasOneUse()) { +// return nullptr; +// } +// llvm::errs() << "\nHERE REPL\n"; + // Interestingly commenting out ^this block + // slightly improves results. 
+ // Implying this situation can be improved further + + // Autogenerated transforms + #include "gen.cpp.inc" + + return nullptr; + } + + struct SymConst { + SymConst(size_t Width, size_t Value, IRBuilder *B) : Width(Width), Value(Value), B(B) {} + size_t Width; + size_t Value; // TODO: APInt + IRBuilder *B; + + llvm::Value *operator()() { + return B->getIntN(Width, Value); + } + + llvm::Value *operator()(llvm::Value *Ctx) { + return B->getIntN(Ctx->getType()->getIntegerBitWidth(), Value); + } + }; + + SymConst C(size_t Width, size_t Value, IRBuilder *B) { + return SymConst(Width, Value, B); + } + + // Value *C(size_t Width, size_t Value, IRBuilder *B) { + // return B->getIntN(Width, Value); + // } + + + Value *C(llvm::APInt Value, IRBuilder *B) { + return ConstantInt::get(B->getIntNTy(Value.getBitWidth()), Value); +// return B->getIntN(Value.getBitWidth(), Value.getLimitedValue()); + } + + Type *T(size_t W, IRBuilder *B) { + return B->getIntNTy(W); + } + + Type *T(llvm::Value *V) { + return V->getType(); + } + + InstructionWorklist W; + util::Stats St; + DominatorTree *DT; + DemandedBits *DB; + LazyValueInfo *LVI; + std::set F; +}; +} + +char SouperCombine::ID = 0; +namespace llvm { +void initializeSouperCombinePass(llvm::PassRegistry &); +} + +INITIALIZE_PASS_BEGIN(SouperCombine, "souper-combine", "Souper super-optimizer pass", + false, false) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_END(SouperCombine, "souper-combine", "Souper super-optimizer pass", false, + false) + +static struct Register { + Register() { + initializeSouperCombinePass(*llvm::PassRegistry::getPassRegistry()); + } +} X; + +static void registerSouperCombine( + const llvm::PassManagerBuilder &Builder, 
llvm::legacy::PassManagerBase &PM) { + PM.add(new SouperCombine); +} + +static llvm::RegisterStandardPasses +RegisterSouperOptimizer(llvm::PassManagerBuilder::EP_Peephole, + registerSouperCombine); diff --git a/tools/souper-check.cpp b/tools/souper-check.cpp index f45adffec..f28523d95 100644 --- a/tools/souper-check.cpp +++ b/tools/souper-check.cpp @@ -88,6 +88,74 @@ static cl::opt CheckAllGuesses("souper-check-all-guesses", cl::desc("Continue even after a valid RHS is found. (default=false)"), cl::init(false)); +static cl::opt Hash("hash", + cl::desc("Hash a trasnformation. (default=false)"), + cl::init(false)); + +static cl::opt FilterRedundant("filter-redundant", + cl::desc("Filter redundant transformations based on static hashing (default=false)"), + cl::init(false)); + + +size_t HashInt(size_t x) { + x = (x ^ (x >> 30)) * UINT64_C(0xbf58476d1ce4e5b9); + x = (x ^ (x >> 27)) * UINT64_C(0x94d049bb133111eb); + x = x ^ (x >> 31); + return x; +} + +size_t HashInst(Inst *I, std::map &M, std::set &SeenVars) { + if (M.find(I) != M.end()) { + return M[I]; + } + + size_t Result = 0; + +// if (I->Name != "") { +// Result ^= std::hash()(I->Name); +// } + + Result ^= HashInt(I->K); + + if (I->K == Inst::Var) { + SeenVars.insert(I); + Result ^= HashInt(SeenVars.size()); + // TODO: DF attributes + M[I] = Result; + } + + if (I->K == Inst::Const) { + Result ^= HashInt(I->Val.getLimitedValue()); + } + + for (size_t i = 0; i < I->Ops.size(); ++i) { + size_t Weight = Inst::isCommutative(I->K) ? 0 : HashInt(i); + + Result ^= (Weight + HashInst(I->Ops[i], M, SeenVars)); + } + + M[I] = Result; + return Result; +} + +size_t HashRep(ParsedReplacement Rep) { + std::map M; + std::set SeenVars; + auto Result = HashInst(Rep.Mapping.LHS, M, SeenVars); + Result ^= 7* HashInst(Rep.Mapping.RHS, M, SeenVars); + // Just ^ produces weird conflicts for very different trees + + // Is this needed? 
+ Result ^= HashInt(Rep.Mapping.LHS->Width); + + for (auto PC : Rep.PCs) { + Result ^= HashInst(PC.LHS, M, SeenVars); + Result ^= HashInst(PC.RHS, M, SeenVars); + } + + return Result; +} + int SolveInst(const MemoryBufferRef &MB, Solver *S) { InstContext IC; std::string ErrStr; @@ -120,7 +188,43 @@ int SolveInst(const MemoryBufferRef &MB, Solver *S) { unsigned Index = 0; int Ret = 0; int Success = 0, Fail = 0, Error = 0; + + std::unordered_set Hashes; + for (auto Rep : Reps) { + if (Hash) { + llvm::outs() << HashRep(Rep) << '\n'; + continue; + } + + if (FilterRedundant) { + auto Hash = HashRep(Rep); + if (Hashes.find(Hash) == Hashes.end()) { + Hashes.insert(Hash); + ReplacementContext RC; + Rep.printLHS(llvm::outs(), RC, true); + Rep.printRHS(llvm::outs(), RC, true); + } else { + llvm::outs() << "; Skipping redundant transformation.\n"; + std::string S; + llvm::raw_string_ostream Str(S); + ReplacementContext RC; + Rep.printLHS(Str, RC, true); + Rep.printRHS(Str, RC, true); + Str.flush(); + llvm::outs() << ';'; + for (size_t i = 0; i < S.length(); ++i) { + auto c = S[i]; + if ((c == '\n' || c == '\r') && i != S.length() - 1) { + llvm::outs() << c << ';'; + } else { + llvm::outs() << c; + } + } + } + continue; + } + if (isInferDFA()) { if (InferNeg) { bool Negative; @@ -254,11 +358,11 @@ int SolveInst(const MemoryBufferRef &MB, Solver *S) { } if (CheckAllGuesses) { - for (unsigned RI = 0 ; RI < RHSs.size() ; RI++) { - llvm::outs()<<"; result " << (RI + 1) <<":\n"; + for (unsigned RI = 0 ; RI < RHSs.size(); RI++) { + llvm::outs() << "; result " << (RI + 1) << ":\n"; ReplacementContext RC; PrintReplacementRHS(llvm::outs(), RHSs[RI], RC); - llvm::outs()<<"\n"; + llvm::outs() << "\n"; } } else { if (PrintRepl) { @@ -289,6 +393,7 @@ int SolveInst(const MemoryBufferRef &MB, Solver *S) { std::set ConstSet; souper::getConstants(Rep.Mapping.RHS, ConstSet); + souper::getConstants(Rep.Mapping.LHS, ConstSet); if (ConstSet.empty()) { llvm::outs() << "; No reservedconst found in 
RHS\n"; } else { @@ -301,13 +406,17 @@ int SolveInst(const MemoryBufferRef &MB, Solver *S) { } if (!ResultConstMap.empty()) { - ReplacementContext Context; - llvm::outs() << "; RHS inferred successfully\n"; - PrintReplacementRHS(llvm::outs(), Rep.Mapping.RHS, Context); + std::map InstCache; + std::map BlockCache; + Rep.Mapping.LHS = + getInstCopy(Rep.Mapping.LHS, IC, InstCache, BlockCache, &ResultConstMap, false, false); + Rep.Mapping.RHS = + getInstCopy(Rep.Mapping.RHS, IC, InstCache, BlockCache, &ResultConstMap, false, false); + Rep.print(llvm::outs(), true); ++Success; } else { ++Fail; - llvm::outs() << "; Failed to infer RHS\n"; + llvm::outs() << "; Failed to synthesize constant(s)\n"; } } } else if (TryDataflowPruning) { @@ -325,11 +434,37 @@ int SolveInst(const MemoryBufferRef &MB, Solver *S) { } } else if (InferAP) { bool FoundWeakest = false; - S->abstractPrecondition(Rep.BPCs, Rep.PCs, Rep.Mapping, IC, FoundWeakest); + std::vector> KBResults; + std::vector> CRResults; + S->abstractPrecondition(Rep.BPCs, Rep.PCs, Rep.Mapping, IC, FoundWeakest, KBResults, CRResults); if (!FoundWeakest) { llvm::outs() << "Failed to find WP.\n"; } - + ReplacementContext RC; + auto LHSStr = RC.printInst(Rep.Mapping.LHS, llvm::outs(), true); + llvm::outs() << "infer " << LHSStr << "\n"; + auto RHSStr = RC.printInst(Rep.Mapping.RHS, llvm::outs(), true); + llvm::outs() << "result " << RHSStr << "\n"; + for (size_t i = 0; i < KBResults.size(); ++i) { + for (auto It = KBResults[i].begin(); It != KBResults[i].end(); ++It) { + auto &&P = *It; + std::string dummy; + llvm::raw_string_ostream str(dummy); + auto VarStr = RC.printInst(P.first, str, false); + llvm::outs() << VarStr << " -> " << Inst::getKnownBitsString(P.second.Zero, P.second.One); + + auto Next = It; + Next++; + if (Next != KBResults[i].end()) { + llvm::outs() << " (and) "; + } + } + if (i == KBResults.size() - 1) { + llvm::outs() << "\n"; + } else { + llvm::outs() << "\n(or)\n"; + } + } } else { bool Valid; std::vector> 
Models; diff --git a/tools/souper.cpp b/tools/souper.cpp index bf783cf9e..059b1e85a 100644 --- a/tools/souper.cpp +++ b/tools/souper.cpp @@ -53,6 +53,10 @@ OutputFilename("o", cl::desc("Override output filename"), static cl::opt StaticProfile("souper-static-profile", cl::init(false), cl::desc("Static profiling of Souper optimizations (default=false)")); +static cl::opt HarvestOpts("harvest-opts", cl::init(false), + cl::desc("Harvest optimizations performed by InstCombine (default=false)")); + + static cl::opt Check("check", cl::desc("Check input for expected results"), cl::init(false)); @@ -106,6 +110,11 @@ int main(int argc, char **argv) { ExprBuilderContext EBC; CandidateMap CandMap; + if (HarvestOpts) { + HarvestAndPrintOpts(IC, EBC, M.get(), S.get()); + return 0; + } + AddModuleToCandidateMap(IC, EBC, CandMap, M.get()); if (Check) { diff --git a/tools/souper2llvm.cpp b/tools/souper2llvm.cpp index acac0017e..9e2ba3e31 100644 --- a/tools/souper2llvm.cpp +++ b/tools/souper2llvm.cpp @@ -41,29 +41,6 @@ static cl::opt OutputFilename( "o", cl::desc(""), cl::init("-")); -static std::vector -GetInputArgumentTypes(const InstContext &IC, llvm::LLVMContext &Context) { - const std::vector AllVariables = IC.getVariables(); - - std::vector ArgTypes; - ArgTypes.reserve(AllVariables.size()); - for (const Inst *const Var : AllVariables) - ArgTypes.emplace_back(Type::getIntNTy(Context, Var->Width)); - - return ArgTypes; -} - -static std::map GetArgsMapping(const InstContext &IC, - Function *F) { - std::map Args; - - const std::vector AllVariables = IC.getVariables(); - for (auto zz : llvm::zip(AllVariables, F->args())) - Args[std::get<0>(zz)] = &(std::get<1>(zz)); - - return Args; -}; - int Work(const MemoryBufferRef &MB) { InstContext IC; ReplacementContext RC; @@ -79,31 +56,7 @@ int Work(const MemoryBufferRef &MB) { llvm::LLVMContext Context; llvm::Module Module("souper.ll", Context); - - const std::vector ArgTypes = GetInputArgumentTypes(IC, Context); - const auto FT = 
llvm::FunctionType::get( - /*Result=*/Codegen::GetInstReturnType(Context, RepRHS.Mapping.RHS), - /*Params=*/ArgTypes, /*isVarArg=*/false); - - Function *F = Function::Create(FT, Function::ExternalLinkage, "fun", &Module); - - const std::map Args = GetArgsMapping(IC, F); - - BasicBlock *BB = BasicBlock::Create(Context, "entry", F); - - llvm::IRBuilder<> Builder(Context); - Builder.SetInsertPoint(BB); - - Value *RetVal = Codegen(Context, &Module, Builder, /*DT*/ nullptr, - /*ReplacedInst*/ nullptr, Args) - .getValue(RepRHS.Mapping.RHS); - - Builder.CreateRet(RetVal); - - // Validate the generated code, checking for consistency. - if (verifyFunction(*F, &llvm::errs())) - return 1; - if (verifyModule(Module, &llvm::errs())) + if (genModule(IC, RepRHS.Mapping.RHS, Module)) return 1; std::error_code EC; diff --git a/tools/souperweb-backend.cpp b/tools/souperweb-backend.cpp deleted file mode 100644 index 72c409f5b..000000000 --- a/tools/souperweb-backend.cpp +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2014 The Souper Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "llvm/AsmParser/Parser.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/SourceMgr.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/raw_ostream.h" -#include "souper/Extractor/Solver.h" -#include "souper/SMTLIB2/Solver.h" -#include "souper/Parser/Parser.h" -#include "souper/Tool/CandidateMapUtils.h" -#include "souper/Tool/GetSolver.h" - -using namespace llvm; -using namespace souper; - -unsigned DebugLevel; - -extern "C" int boolector_main(int argc, char **argv); - -LLVMContext Context; - -void SolveIR(std::unique_ptr MB, Solver *S) { - SMDiagnostic Err; - if (std::unique_ptr M = - parseAssembly(MB->getMemBufferRef(), Err, Context)) { - InstContext IC; - ExprBuilderContext EBC; - CandidateMap CandMap; - - AddModuleToCandidateMap(IC, EBC, CandMap, M.get()); - - SolveCandidateMap(llvm::outs(), CandMap, S, IC, 0); - } else { - Err.print(0, llvm::errs(), false); - } -} - -void SolveInst(std::unique_ptr MB, Solver *S) { - InstContext IC; - std::string ErrStr; - - ParsedReplacement Rep = - ParseReplacement(IC, "", MB->getBuffer(), ErrStr); - if (!ErrStr.empty()) { - llvm::errs() << ErrStr << '\n'; - return; - } - - bool Valid; - std::vector> Models; - if (std::error_code EC = S->isValid(IC, Rep.BPCs, Rep.PCs, - Rep.Mapping, Valid, &Models)) { - llvm::errs() << EC.message() << '\n'; - return; - } - - if (Valid) { - llvm::outs() << "LGTM\n"; - } else { - llvm::outs() << "Invalid"; - if (!Models.empty()) { - llvm::outs() << ", e.g.\n\n"; - std::sort(Models.begin(), Models.end(), - [](const std::pair &A, - const std::pair &B) { - return A.first->Name < B.first->Name; - }); - for (const auto &M : Models) { - llvm::outs() << '%' << M.first->Name << " = " << M.second << '\n'; - } - } - } -} - -static llvm::cl::opt Action("action", llvm::cl::init("")); - -int main(int argc, char **argv) { - cl::ParseCommandLineOptions(argc, argv); - KVStore *KV; - std::unique_ptr S = 
GetSolver(KV); - - auto MB = MemoryBuffer::getSTDIN(); - if (MB) { - if (Action == "ir") { - SolveIR(std::move(*MB), S.get()); - } else { - SolveInst(std::move(*MB), S.get()); - } - } else { - llvm::errs() << MB.getError().message() << '\n'; - } -} diff --git a/tools/souperweb.go b/tools/souperweb.go deleted file mode 100644 index dc39b2064..000000000 --- a/tools/souperweb.go +++ /dev/null @@ -1,682 +0,0 @@ -// Copyright 2014 The Souper Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package main - -import ( - "bytes" - "crypto/sha1" - "encoding/json" - "html" - "io" - "math/rand" - "net" - "net/http" - "os" - "os/exec" - "regexp" - "sort" - "strings" - "syscall" - "text/template" - "time" - "fmt" - - "github.com/gomodule/redigo/redis" -) - -type solverQuery struct { - IsIR bool - Req string -} - -type solverReq struct { - solverQuery - dest io.Writer - result chan<- solverResp -} - -type solverResp struct { - result string - err error -} - -type solveHTMLResp struct { - HTML string - Key string -} - -type friendlyError struct { - msg string - err error -} - -func (e friendlyError) Error() string { - s := e.msg - if e.err != nil { - s += " (" + e.err.Error() + ")" - } - return s -} - -type rateReq struct { - ip string - ok chan<- bool -} - -type context struct { - homepage *template.Template - solverch chan solverReq - ratech chan rateReq - pool redis.Pool -} - -func (ctx *context) init() { - ctx.solverch = make(chan solverReq) - ctx.ratech = make(chan rateReq) - ctx.pool = redis.Pool{ - MaxIdle: 3, - IdleTimeout: 240 * time.Second, - Dial: func() (redis.Conn, error) { - return redis.Dial("tcp", "127.0.0.1:6379") - }, - TestOnBorrow: func(c redis.Conn, t time.Time) error { - _, err := c.Do("PING") - return err - }, - } - - go ctx.solverWorker() - go ctx.solverWorker() - go ctx.rateWorker() - - ctx.homepage = ctx.buildHomePage() -} - -func randFromQuery(query string) *rand.Rand { - hash := sha1.Sum([]byte(query)) - var hash8 int64 - for i, h := range hash[0:8] { - hash8 |= int64(h) << (8 * uint(i)) - } - source := rand.NewSource(hash8) - return rand.New(source) -} - -// avoid English words and transcription errors by excluding 01AEIOU -const keyChars = "23456789BCDFGHJKLMNPQRSTVWXYZ" - -func keyFromRand(r *rand.Rand) string { - var key string - for i := 0; i != 6; i++ { - key += string(keyChars[r.Int31n(int32(len(keyChars)))]) - } - return key -} - -var keyRE = regexp.MustCompile("^[" + keyChars + "]{6}$") - -func isKey(k string) bool { - 
return keyRE.MatchString(k) -} - -func (ctx *context) lookupByKey(conn redis.Conn, key string) (solverQuery, error) { - _, err := conn.Do("SELECT", 0) - if err != nil { - return solverQuery{}, err - } - - jsonquery, err := redis.Bytes(conn.Do("GET", key)) - if err != nil { - return solverQuery{}, friendlyError{"Invalid key", err} - } - - var query solverQuery - err = json.Unmarshal(jsonquery, &query) - return query, err -} - -func (ctx *context) getKey(conn redis.Conn, query solverQuery) (string, error) { - json, err := json.Marshal(query) - if err != nil { - return "", err - } - - rand := randFromQuery(query.Req) - - _, err = conn.Do("SELECT", 0) - if err != nil { - return "", err - } - - for { - _, err := conn.Do("WATCH", "key:"+string(json)) - if err != nil { - return "", err - } - - key, err := redis.String(conn.Do("GET", "key:"+string(json))) - if err != nil && err != redis.ErrNil { - return "", err - } - if key != "" { - _, err := conn.Do("UNWATCH") - if err != nil { - return "", err - } - return key, nil - } - - _, err = conn.Do("WATCH", key) - if err != nil { - return "", err - } - - key = keyFromRand(rand) - len, err := redis.Int(conn.Do("STRLEN", key)) - if len != 0 { - _, err := conn.Do("UNWATCH") - if err != nil { - return "", err - } - continue - } - - _, err = conn.Do("MULTI") - if err != nil { - return "", err - } - - _, err = conn.Do("SET", key, json) - if err != nil { - return "", err - } - - _, err = conn.Do("SET", "key:"+string(json), key) - if err != nil { - return "", err - } - - exec, err := conn.Do("EXEC") - if err != nil { - return "", err - } - if exec == nil { - continue - } - - return key, nil - } -} - -func (ctx *context) solverWorker() { - conn := ctx.pool.Get() - defer conn.Close() - - for r := range ctx.solverch { - var arg string - if r.IsIR { - arg = "-action=ir" - } else { - arg = "-action=inst" - } - cmd := exec.Command(os.Args[0]+"-backend", arg, os.Args[1]) - - cmd.Stdin = strings.NewReader(r.Req) - - var outb, errb bytes.Buffer 
- cmd.Stdout = &outb - cmd.Stderr = &errb - - var sys syscall.SysProcAttr - sys.Setpgid = true - cmd.SysProcAttr = &sys - - err := cmd.Start() - if err != nil { - os.Stdout.Write([]byte("Error invoking solver: " + err.Error() + "\n")) - r.result <- solverResp{"", friendlyError{"Error invoking solver", err}} - continue - } - - timeout := false - timer := time.AfterFunc(10*time.Second, func() { - syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) - timeout = true - }) - - err = cmd.Wait() - timer.Stop() - if timeout { - os.Stdout.Write([]byte("Solver timeout\n")) - r.result <- solverResp{"", friendlyError{"Solver timeout", nil}} - } else if err != nil { - r.result <- solverResp{"", friendlyError{"Error invoking solver", err}} - } else if errb.Len() != 0 { - r.result <- solverResp{"", friendlyError{errb.String(), nil}} - } else { - r.result <- solverResp{outb.String(), nil} - } - } -} - -func (ctx *context) rateWorker() { - visits := make(map[string][3]int64) - - daily := make(chan bool) - go func() { - for { - time.Sleep(86400 * time.Second) - daily <- true - } - }() - - for { - select { - case r := <-ctx.ratech: - now := time.Now().Unix() - vs := visits[r.ip] - ok := vs[2] <= now-10 - vs[0], vs[1], vs[2] = now, vs[0], vs[1] - visits[r.ip] = vs - r.ok <- ok - - case <-daily: - visits = make(map[string][3]int64) - } - } -} - -func (ctx *context) buildHomePage() *template.Template { - conn := ctx.pool.Get() - defer conn.Close() - - examples := map[string]solverQuery{ - "addnsw (IR)": {true, `define i32 @foo(i32 %x) { -entry: - %add = add nsw i32 %x, 1 - %cmp = icmp sgt i32 %add, %x - %conv = zext i1 %cmp to i32 - ret i32 %conv -} -`}, - "addnsw (inst)": {false, `%x:i32 = var -%add = addnsw %x, 1 -%cmp = slt %x, %add -cand %cmp 1 -`}, - "instcombine1": {false, `%a:i1 = var -%b:i32 = var -%ax:i32 = zext %a -%c = add %ax, %b - -%b1 = add %b, 1 -%c2 = select %a, %b1, %b - -cand %c %c2 -`}, - "simple": {false, `%a:i32 = var -%aa = add %a, %a -%2a = mul %a, 2 -cand %aa %2a -`}, 
- "simple-pc": {false, `%x:i32 = var -%2lx = slt 2, %x -pc %2lx 1 -%1lx = slt 1, %x -cand %1lx 1 -`}, - "simple-pc-invalid": {false, `%x:i32 = var -%2lx = slt 2, %x -pc %2lx 1 -%3lx = slt 3, %x -cand %3lx 1 -`}, - } - - exampleNames := make([]string, len(examples)) - i := 0 - for name, _ := range examples { - exampleNames[i] = name - i++ - } - sort.Strings(exampleNames) - - var b bytes.Buffer - b.WriteString(` - - - - - - -

souperweb

-Enter Souper instructions or LLVM IR into the box below and hit Submit.
-`) - - b.WriteString(` -
-Souper inst -LLVM IR
- - - - - -
-
-
-Examples:
-`) - for _, n := range exampleNames { - key, err := ctx.getKey(conn, examples[n]) - if err != nil { - panic(err.Error()) - } - b.WriteString(``) - b.WriteString(n) - b.WriteString(`
`) - } - b.WriteString(` -
- -
- -
{{.resp.HTML}}
- - -`) - return template.Must(template.New("homepage").Parse(b.String())) -} - -func (ctx *context) rootHandler(w http.ResponseWriter, r *http.Request) { - if len(r.URL.Path) > 1 { - key := r.URL.Path[1:] - isjson := false - if strings.HasSuffix(key, "/json") { - key = key[:len(key)-5] - isjson = true - } - - if !isKey(key) { - w.WriteHeader(http.StatusNotFound) - w.Write([]byte("Not Found\n")) - return - } - - conn := ctx.pool.Get() - defer conn.Close() - - query, resp := ctx.buildLookupResp(conn, key, r) - if isjson { - json, _ := json.Marshal(map[string]interface{}{"query": query, "resp": resp}) - w.Write(json) - } else { - ctx.writeHomePage(w, query, resp) - } - } else { - ctx.writeHomePage(w, solverQuery{}, solveHTMLResp{}) - } -} - -func (ctx *context) getErrorResp(err error) (resp solveHTMLResp) { - var w bytes.Buffer - w.Write([]byte(`
`))
-	if ferr, ok := err.(friendlyError); ok {
-		if ferr.err != nil {
-			os.Stdout.Write([]byte(ferr.err.Error() + "\n"))
-		}
-		w.Write([]byte(html.EscapeString(ferr.msg)))
-	} else {
-		os.Stdout.Write([]byte(err.Error() + "\n"))
-		w.Write([]byte("Internal error"))
-	}
-	w.Write([]byte(`
`)) - resp.HTML = w.String() - return resp -} - -func (ctx *context) solve(conn redis.Conn, w io.Writer, query solverQuery, r *http.Request) error { - _, err := conn.Do("SELECT", 1) - if err != nil { - return err - } - - json, err := json.Marshal(query) - if err != nil { - return err - } - - result, err := redis.String(conn.Do("GET", json)) - if err != nil && err != redis.ErrNil { - return err - } - if err != redis.ErrNil { - w.Write([]byte("
"))
-		w.Write([]byte(html.EscapeString(result)))
-		w.Write([]byte("
")) - return nil - } - - os.Stdout.WriteString(time.Now().String() + " " + r.RemoteAddr + "\n") - - ch := make(chan bool) - - host, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return err - } - - ctx.ratech <- rateReq{host, ch} - - if <-ch { - respch := make(chan solverResp) - ctx.solverch <- solverReq{query, w, respch} - resp := <-respch - - if resp.err != nil { - return resp.err - } - - _, err := conn.Do("SET", json, resp.result) - if err != nil { - return err - } - - w.Write([]byte("
"))
-		w.Write([]byte(html.EscapeString(resp.result)))
-		w.Write([]byte("
")) - - return nil - } else { - return friendlyError{"Rate limit exceeded", nil} - } - - return nil -} - -func (ctx *context) buildLookupResp(conn redis.Conn, key string, r *http.Request) (solverQuery, solveHTMLResp) { - query, err := ctx.lookupByKey(conn, key) - if err != nil { - return solverQuery{}, ctx.getErrorResp(err) - } - - var html bytes.Buffer - err = ctx.solve(conn, &html, query, r) - if err != nil { - return query, ctx.getErrorResp(err) - } - - var resp solveHTMLResp - resp.HTML = html.String() - resp.Key = key - return query, resp -} - -func (ctx *context) buildSolveResp(conn redis.Conn, query solverQuery, r *http.Request) solveHTMLResp { - var html bytes.Buffer - err := ctx.solve(conn, &html, query, r) - if err != nil { - return ctx.getErrorResp(err) - } - - key, err := ctx.getKey(conn, query) - if err != nil { - return ctx.getErrorResp(err) - } - - var resp solveHTMLResp - resp.HTML = html.String() - resp.Key = key - return resp -} - -func (ctx *context) solveHandler(w http.ResponseWriter, r *http.Request) { - conn := ctx.pool.Get() - defer conn.Close() - - r.ParseForm() - - typeForm := r.Form["type"] - if len(typeForm) == 0 { - return - } - - dataForm := r.Form["data"] - if len(dataForm) == 0 { - return - } - - data := strings.Replace(dataForm[0], "\r\n", "\n", -1) - query := solverQuery{typeForm[0] == "ir", data} - - resp := ctx.buildSolveResp(conn, query, r) - - if strings.HasSuffix(r.URL.Path, "/json") { - json, _ := json.Marshal(resp) - w.Write(json) - } else if resp.Key != "" { - http.Redirect(w, r, "/"+resp.Key, http.StatusFound) - } else { - ctx.writeHomePage(w, query, resp) - } -} - -func (ctx *context) writeHomePage(w io.Writer, query solverQuery, resp solveHTMLResp) { - ctx.homepage.Execute(w, map[string]interface{}{"query": query, "resp": resp}) -} - -func main() { - var ctx context - ctx.init() - - http.HandleFunc("/", ctx.rootHandler) - http.HandleFunc("/solve", ctx.solveHandler) - http.HandleFunc("/solve/json", ctx.solveHandler) - - 
fmt.Println("Listening on port :8080") - http.ListenAndServe(":8080", nil) -} diff --git a/unittests/Codegen/CodegenTests.cpp b/unittests/Codegen/CodegenTests.cpp new file mode 100644 index 000000000..6039512b5 --- /dev/null +++ b/unittests/Codegen/CodegenTests.cpp @@ -0,0 +1,55 @@ +// Copyright 2014 The Souper Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "llvm/Support/raw_ostream.h" +#include "souper/Codegen/Codegen.h" +#include "gtest/gtest.h" + +unsigned DebugLevel; + +using namespace souper; + +const BackendCost C5a { .C = { 1, 2, 3, 4, 5 }}; +const BackendCost C5b { .C = { 0, 0, 0, 0, 0 }}; +const BackendCost C5c { .C = { 1, 2, 3, 4, 4 }}; +const BackendCost C5d { .C = { 1, 2, 3, 4, 6 }}; + +TEST(CodegenTest, Compare) { + const struct { + BackendCost L, R; + } Tests[] = { + { C5b, C5a }, + { C5b, C5c }, + { C5b, C5d }, + }; + + for (const auto &T : Tests) { + EXPECT_EQ(compareCosts(T.L, T.R), true); + EXPECT_EQ(compareCosts(T.R, T.L), false); + } +} + +TEST(CodegenTest, Sort) { + const struct { + std::vector Costs; + BackendCost Best; + } Tests[] = { + { { C5a, C5b, C5c, C5d }, C5b }, + }; + + for (const auto &T : Tests) { + //std::sort(T.Costs.begin(), T.Costs.end(), compareCosts); + //EXPECT_EQ(T.WantError, ErrStr); + } +} diff --git a/utils/cache_infer.in b/utils/cache_infer.in index 452b53b27..a424c07d7 100755 --- a/utils/cache_infer.in +++ b/utils/cache_infer.in @@ -33,8 +33,10 @@ my $RAM_LIMIT = 4 * 
1024 * 1024 * 1024; sub usage() { print <<"END"; Options: - -n number of CPUs to use (default=$NPROCS) - -tag add this tag to cache entries, and skip entries with it + -n number of CPUs to use (default=$NPROCS) + -tag add this tag to cache entries, and skip entries with it + -separate-files put each souper invocation's output into its own output file + -souper-debug-level pass this integer debug level to Souper -verbose END exit -1; @@ -42,18 +44,26 @@ END my $tag = "x"; my $VERBOSE = 0; +my $SAVE_TEMPS; +my $SOUPER_DEBUG = -1; GetOptions( "n=i" => \$NPROCS, "tag=s" => \$tag, "verbose" => \$VERBOSE, + "separate-files" => \$SAVE_TEMPS, + "souper-debug-level=i" => \$SOUPER_DEBUG, ) or usage(); my $OPTS = ""; +$OPTS .= "-souper-external-cache "; $OPTS .= "-souper-double-check "; $OPTS .= "-souper-dataflow-pruning "; $OPTS .= "-souper-enumerative-synthesis-max-instructions=1 "; +if ($SOUPER_DEBUG != -1) { + $OPTS .= "-souper-debug-level=${SOUPER_DEBUG} "; +} my $check = "@CMAKE_BINARY_DIR@/souper-check -solver-timeout=15"; @@ -70,11 +80,22 @@ sub infer($) { print STDERR "$cmd\n"; $fh->flush(); open(my $foo, '>>', 'fail.txt'); - open INF, "$cmd < $tmpfn |" or print $foo "$k\n";; + my $INF; + my $OFN = "tmp_$$.log"; + if ($SAVE_TEMPS) { + system "$cmd < $tmpfn > $OFN 2>&1"; + open my $OF, ">>$OFN" or die; + print $OF "\n\n$cmd\n\n"; + print $OF "$k\n\n"; + close $OF; + open $INF, "<$OFN" or die; + } else { + open $INF, "$cmd < $tmpfn |" or print $foo "$k\n"; + } my $ok = 0; my $failed = 0; my $output = ""; - while (my $line = ) { + while (my $line = <$INF>) { if ($line =~ /Failed/) { $failed = 1; next; @@ -85,7 +106,7 @@ sub infer($) { } $output .= $line; } - close INF; + close $INF; close $fh; unlink $tmpfn; # exit 1 unless $ok || $failed; @@ -93,7 +114,10 @@ sub infer($) { $red->ping || die "no server?"; $red->hset($k, "cache-infer-tag" => $tag); exit 1 unless $ok; + + ## FIXME -- we should have souper-check do this $red->hset($k, "result" => $output); + exit 0; } @@ 
-127,7 +151,7 @@ sub reset_status($) { sub status() { print "."; $status_cnt++; - my $pct = int(100.0*$status_cnt/$status_total); + my $pct = int(100.0 * $status_cnt/$status_total); if ($pct > $status_opct) { $status_opct = $pct; print "$pct %\n"; diff --git a/utils/cache_to_dir.py b/utils/cache_to_dir.py new file mode 100644 index 000000000..70cd30d8a --- /dev/null +++ b/utils/cache_to_dir.py @@ -0,0 +1,20 @@ +import redis +import sys +r = redis.Redis() +n = 0 +dir = sys.argv[1] +fails = 0 +for key in r.keys(): + try: + val = r.hgetall(key)[b'result'] + if val != b"": + s = key.decode('utf-8') + val.decode('utf-8') + f = open(dir + "/" + str(n) + '.opt', "w") + n = n + 1 + f.write(s) + f.close() + except KeyError: + fails = fails + 1 + +print("Number of failures = ", fails) + diff --git a/utils/magic.sh b/utils/magic.sh new file mode 100755 index 000000000..9166d671c --- /dev/null +++ b/utils/magic.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Last argument is assumed to be a file with multiple inputs +# separated by empty lines + +mkdir -p /tmp/scratch/ +rm -f /tmp/scratch/* + +infile=${@: -1} # Last argument +cmd=${*%${!#}} # All but the last argument + +csplit --quiet --prefix=/tmp/scratch/opt --suffix-format=%02d.txt $infile '/^cand/ +1' '{*}' + +for i in `ls -v /tmp/scratch/*`; do echo "echo \";$i \"&&" $cmd $i "&& echo";done > /tmp/cmdfile.txt +parallel --will-cite -k < /tmp/cmdfile.txt diff --git a/utils/mclang.in b/utils/mclang.in new file mode 100755 index 000000000..1a410ea3f --- /dev/null +++ b/utils/mclang.in @@ -0,0 +1,297 @@ +#!/usr/bin/env perl + +# Copyright 2014 The Souper Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +use warnings; +use strict; +use File::Temp; + +sub compiling() { + foreach my $arg (@ARGV) { + return 1 + if ($arg =~ /\.c$|\.cpp$|\.CC$|\.c\+\+$|\.cc$|\.cxx$|\.C$|\.c\+$/); + } + return 0; +} + +sub linkp() { + foreach my $arg (@ARGV) { + return 0 if ($arg eq "-S" || $arg eq "-c" || $arg eq "-shared"); + } + return 1; +} + +if ($0 =~ /clang$/) { + unshift @ARGV, "@LLVM_BINDIR@/clang"; +} elsif ($0 =~ /clang\+\+$/) { + unshift @ARGV, "@LLVM_BINDIR@/clang++"; +} else { + die "Didn't expect sclang to be invoked as '$0'"; +} + +foreach my $arg (@ARGV) { + if ($arg eq "-help" || $arg eq "--help") { + print <new(UNLINK => 1, DIR => $bitcodedir)->filename; + push @ARGV, "-c", "-emit-llvm"; + push @ARGV, "-o", "${tmp}.bc"; + my $ofn = "${tmp}.cmd"; + open OUTF, ">$ofn" or die; + foreach my $a (@ARGV) { + print OUTF "$a "; + } + print OUTF "\n"; + close OUTF; + open STDOUT, '>/dev/null'; + open STDERR, '>/dev/null'; + exec @ARGV; + die "bailing, exec failed"; +} + +my $souper = 1; +# this environment variable is a comma-separated list of source files that +# souper should avoid processing, for example because they trigger known bugs +my $skip_files = getenv("SOUPER_SKIP_FILES"); +if (defined $skip_files) { + my %skips; + foreach my $f (split(',', $skip_files)) { + $skips{$f} = 1; + } + foreach my $a (@ARGV) { + $souper = 0 if ($skips{$a}); + } +} + +$souper = 0 if getenv("SOUPER_NO_SOUPER"); +$souper = 0 unless compiling(); + +if ($souper) { + push @ARGV, ( + "-Xclang", "-load", + "-mllvm", "-solver-timeout=15", + ); +} + +if ($souper) { + push @ARGV, 
("-Xclang", getenv("SOUPER_MATCHER_LIB")); + # push @ARGV, ("-Xclang --souper-combine")); +} + +# if (getenv("SOUPER_DYNAMIC_PROFILE_ALL")) { +# if ($souper) { +# push @ARGV, ("-Xclang", "@SOUPER_PASS_PROFILE_ALL@"); +# } +# } else { +# if ($souper) { +# push @ARGV, ("-Xclang", "@SOUPER_PASS@"); +# } +# } + +# if (getenv("SOUPER_DEBUG") && $souper) { +# push @ARGV, ("-mllvm", "-souper-debug-level=".getenv("SOUPER_DEBUG")); +# } + +# if (getenv("SOUPER_EXPLOIT_BLOCKPCS") && $souper) { +# push @ARGV, ("-mllvm", "-souper-exploit-blockpcs"); +# } + +# if (!getenv("SOUPER_NO_EXTERNAL_CACHE") && $souper) { +# push @ARGV, ("-mllvm", "-souper-external-cache"); +# } + +# if (getenv("SOUPER_NO_INFER") && $souper) { +# push @ARGV, ("-mllvm", "-souper-no-infer"); +# } + +# if (getenv("SOUPER_FIRST_OPT") && $souper) { +# push @ARGV, ("-mllvm", "-souper-first-opt=".$ENV{"SOUPER_FIRST_OPT"}); +# } + +# if (getenv("SOUPER_LAST_OPT") && $souper) { +# push @ARGV, ("-mllvm", "-souper-last-opt=".$ENV{"SOUPER_LAST_OPT"}); +# } + +# if (getenv("SOUPER_STATIC_PROFILE") && $souper) { +# push @ARGV, ("-mllvm", "-souper-static-profile"); +# } + +# if (getenv("SOUPER_DYNAMIC_PROFILE") || +# getenv("SOUPER_DYNAMIC_PROFILE_ALL") && $souper) { +# push @ARGV, ("-g", "-mllvm", "-souper-dynamic-profile"); +# } + +# if (getenv("SOUPER_USE_ALIVE") && $souper) { +# push @ARGV, ("-mllvm", "-souper-use-alive"); +# } + +# if (getenv("SOUPER_DOUBLE_CHECK") && $souper) { +# push @ARGV, ("-mllvm", "-souper-double-check"); +# } + +# if (getenv("SOUPER_INFER_INST") && $souper) { +# push @ARGV, ("-mllvm", "-souper-infer-inst"); +# } + +# if (getenv("SOUPER_ENUMERATIVE_SYNTHESIS_MAX_INSTS") && $souper) { +# push @ARGV, ("-mllvm", "-souper-enumerative-synthesis-max-instructions=".$ENV{"SOUPER_ENUMERATIVE_SYNTHESIS_MAX_INSTS"}); +# } + +# if (getenv("SOUPER_DATAFLOW_PRUNING") && $souper) { +# push @ARGV, ("-mllvm", "-souper-dataflow-pruning"); +# } + +# if (getenv("SOUPER_REDIS_PORT") && $souper) { +# push 
@ARGV, ("-mllvm", "-souper-redis-port=".$ENV{"SOUPER_REDIS_PORT"});
+# }
+
+# if (getenv("SOUPER_STATS") && $souper) {
+#     push @ARGV, ("-mllvm", "-stats");
+# }
+
+# if (getenv("SOUPER_TIME_REPORT") && $souper) {
+#     push @ARGV, ("-ftime-report");
+# }
+
+# if (getenv("SOUPER_VERIFY") && $souper) {
+#     push @ARGV, ("-mllvm", "-souper-verify");
+# }
+
+# if (getenv("SOUPER_NO_HARVEST_DATAFLOW_FACTS") && $souper) {
+#     push @ARGV, ("-mllvm", "-souper-harvest-dataflow-facts=false");
+# }
+
+# if (getenv("SOUPER_HARVEST_USES") && $souper) {
+#     push @ARGV, ("-mllvm", "-souper-harvest-uses");
+# }
+
+if (getenv("SOUPER_DISABLE_LLVM_PEEPHOLES") && $souper) {
+    push @ARGV, ("-mllvm", "-disable-peepholes");
+}
+
+# if ((getenv("SOUPER_DYNAMIC_PROFILE") ||
+#      getenv("SOUPER_DYNAMIC_PROFILE_ALL")) && linkp() && $souper) {
+#     push @ARGV, ("@PROFILE_LIBRARY@", "@HIREDIS_LIBRARY@");
+# }
+
+if (getenv("SOUPER_DEBUG") && $souper) {
+    foreach my $arg (@ARGV) {
+        print STDERR "$arg ";
+    }
+    print STDERR "\n";
+}
+
+foreach my $e (keys %ENV) {
+    next unless $e =~ /^SOUPER_/;
+    die "unexpected Souper-related environment variable '${e}'" unless $whitelist{$e};
+}
+
+exec @ARGV;
diff --git a/utils/parallel.sh b/utils/parallel.sh
new file mode 100755
index 000000000..864ead744
--- /dev/null
+++ b/utils/parallel.sh
@@ -0,0 +1,62 @@
+#!/bin/bash
+# Run a command over every file of a directory with GNU parallel.
+#
+# Usage: parallel.sh CMD [ARG...] INDIR
+#
+# The last argument is the input directory; everything before it is the
+# command.  Each input FILE becomes the job "timeout 300 CMD ARG... FILE".
+# Three sibling directories, named by appending a letter to INDIR, are
+# created if needed and emptied on every run:
+#   INDIRr  starts as a copy of the inputs; a file is replaced by the job's
+#           stdout only when that job succeeds
+#   INDIRt  stdout of every job
+#   INDIRd  stderr of every job
+
+if [ "$#" -lt 2 ]; then
+  echo "usage: $0 CMD [ARG...] INDIR" >&2
+  exit 1
+fi
+
+indir=${!#}        # Last argument
+indir=${indir%/}   # "dir/" would otherwise put the outputs inside the input dir
+cmd="${*:1:$#-1}"  # All but the last argument, joined with spaces
+
+if [ ! -d "$indir" ]; then
+  echo "$0: '$indir' is not a directory" >&2
+  exit 1
+fi
+
+resdir="${indir}r"
+outdir="${indir}t"
+errdir="${indir}d"
+
+mkdir -p "$resdir" "$outdir" "$errdir"
+# -f: an already-empty directory is not an error
+rm -f "$resdir"/* "$outdir"/* "$errdir"/*
+
+# Collect the inputs with a glob instead of parsing `ls`, so that names
+# containing whitespace survive.
+shopt -s nullglob
+inputs=("$indir"/*)
+shopt -u nullglob
+if [ "${#inputs[@]}" -eq 0 ]; then
+  echo "$0: no input files in '$indir', nothing to do" >&2
+  exit 0
+fi
+
+cp "${inputs[@]}" "$resdir"
+
+# Private job file: a fixed /tmp name would be shared by concurrent runs.
+cmdfile=$(mktemp "${TMPDIR:-/tmp}/cmdfile.XXXXXX") || exit 1
+trap 'rm -f "$cmdfile"' EXIT
+
+# One job per input, in version-sort order (the order `ls -v` gave).  File
+# names are shell-quoted with %q; $cmd is emitted verbatim so that callers
+# can still pass shell syntax in it.
+while IFS= read -r -d '' i; do
+  base=${i##*/}
+  printf 'timeout 300 %s %q > %q 2> %q && cp %q %q\n' \
+    "$cmd" "$i" "$outdir/$base" "$errdir/$base" "$outdir/$base" "$resdir/"
+done < <(printf '%s\0' "${inputs[@]}" | sort -zV) > "$cmdfile"
+
+parallel --will-cite < "$cmdfile"