Skip to content

Commit

Permalink
Merge pull request #3059 from stan-dev/arm64-tests
Browse files Browse the repository at this point in the history
Fixes for ARM64
  • Loading branch information
WardBrian authored May 6, 2024
2 parents 9202f1f + 7326216 commit b2b2ad8
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 18 deletions.
19 changes: 15 additions & 4 deletions make/compiler_flags
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@ endif

## Set OS specific library filename extensions
ifeq ($(OS),Windows_NT)
WINARM64 := $(shell echo | $(CXX) -E -dM - | findstr __aarch64__)
LIBRARY_SUFFIX ?= .dll
STR_SEARCH ?= findstr
endif

ifeq ($(OS),Darwin)
LIBRARY_SUFFIX ?= .dylib
STR_SEARCH ?= grep
endif

ifeq ($(OS),Linux)
LIBRARY_SUFFIX ?= .so
STR_SEARCH ?= grep
endif

## Set default compiler
Expand All @@ -42,6 +44,11 @@ ifeq (default,$(origin CXX))
endif
endif

ARM64_CHECK := $(shell echo | $(CXX) -E -dM - | $(STR_SEARCH) __aarch64__)
ifneq ($(ARM64_CHECK),)
ARM64 = true
endif

# Detect compiler type
# - CXX_TYPE: {gcc, clang, mingw32-gcc, other}
# - CXX_MAJOR: major version of CXX
Expand Down Expand Up @@ -164,7 +171,7 @@ ifeq ($(OS),Windows_NT)

make/ucrt:
pound := \#
UCRT_STRING := $(shell echo '$(pound)include <windows.h>' | $(CXX) -E -dM - | findstr _UCRT)
UCRT_STRING := $(shell echo '$(pound)include <windows.h>' | $(CXX) -E -dM - | $(STR_SEARCH) _UCRT)
ifneq (,$(UCRT_STRING))
IS_UCRT ?= true
else
Expand Down Expand Up @@ -211,6 +218,10 @@ endif
## makes reentrant version lgamma_r available from cmath
CXXFLAGS_OS += -D_REENTRANT

ifeq ($(ARM64), true)
CXXFLAGS_OS += -ffp-contract=off
endif

## silence warnings occuring due to the TBB and Eigen libraries
CXXFLAGS_WARNINGS += -Wno-ignored-attributes

Expand Down Expand Up @@ -275,7 +286,7 @@ endif
LDFLAGS_TBB ?= -Wl,-L,"$(TBB_LIB)" -Wl,--disable-new-dtags

# Windows LLVM/Clang does not support -rpath, but is not needed on Windows anyway
ifeq ($(WINARM64),)
ifneq ($(OS), Windows_NT)
LDFLAGS_TBB += -Wl,-rpath,"$(TBB_LIB)"
endif

Expand All @@ -299,7 +310,7 @@ CXXFLAGS_TBB ?= -I $(TBB)/include
LDFLAGS_TBB ?= -Wl,-L,"$(TBB_BIN_ABSOLUTE_PATH)" $(LDFLAGS_FLTO_FLTO) $(LDFLAGS_OPTIM_TBB)

# Windows LLVM/Clang does not support -rpath, but is not needed on Windows anyway
ifeq ($(WINARM64),)
ifneq ($(OS), Windows_NT)
LDFLAGS_TBB += -Wl,-rpath,"$(TBB_BIN_ABSOLUTE_PATH)"
endif
LDLIBS_TBB ?= -ltbb
Expand Down
9 changes: 5 additions & 4 deletions make/libraries
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,11 @@ ifeq (Windows_NT, $(OS))
TBB_CXXFLAGS += -D_UCRT
endif
# TBB does not have assembly code for Windows ARM64, so we need to use GCC builtins
ifneq ($(WINARM64),)
TBB_CXXFLAGS += -DTBB_USE_GCC_BUILTINS
CXXFLAGS_TBB += -DTBB_USE_GCC_BUILTINS
endif
ifeq ($(ARM64),true)
TBB_CXXFLAGS += -DTBB_USE_GCC_BUILTINS
CXXFLAGS_TBB += -DTBB_USE_GCC_BUILTINS
WINARM64 = true
endif
SH_CHECK := $(shell command -v sh 2>/dev/null)
ifdef SH_CHECK
WINDOWS_HAS_SH ?= true
Expand Down
6 changes: 6 additions & 0 deletions stan/math/prim/fun/inv_sqrt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,14 @@ inline auto inv_sqrt(const Container& x) {
template <typename Container, require_not_var_matrix_t<Container>* = nullptr,
require_container_st<std::is_arithmetic, Container>* = nullptr>
inline auto inv_sqrt(const Container& x) {
// Eigen 3.4.0 has precision issues on ARM64 with vectorised rsqrt
// Resolved in current master branch, below can be removed on next release
#ifdef __aarch64__
return apply_scalar_unary<inv_sqrt_fun, Container>::apply(x);
#else
return apply_vector_unary<Container>::apply(
x, [](const auto& v) { return v.array().rsqrt(); });
#endif
}

} // namespace math
Expand Down
6 changes: 4 additions & 2 deletions test/unit/math/fwd/core/std_numeric_limits_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,10 @@ TEST(AgradFwdNumericLimits, All_Fvar) {
EXPECT_FALSE(std::numeric_limits<fvar<double> >::traps);
EXPECT_FALSE(std::numeric_limits<fvar<fvar<double> > >::traps);

EXPECT_FALSE(std::numeric_limits<fvar<double> >::tinyness_before);
EXPECT_FALSE(std::numeric_limits<fvar<fvar<double> > >::tinyness_before);
EXPECT_EQ(std::numeric_limits<fvar<double> >::tinyness_before,
std::numeric_limits<double>::tinyness_before);
EXPECT_EQ(std::numeric_limits<fvar<fvar<double> > >::tinyness_before,
std::numeric_limits<double>::tinyness_before);

EXPECT_TRUE(std::numeric_limits<fvar<double> >::round_style);
EXPECT_TRUE(std::numeric_limits<fvar<fvar<double> > >::round_style);
Expand Down
6 changes: 4 additions & 2 deletions test/unit/math/mix/core/std_numeric_limits_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ TEST(AgradMixNumericLimits, All_Fvar) {
EXPECT_FALSE(std::numeric_limits<fvar<var> >::traps);
EXPECT_FALSE(std::numeric_limits<fvar<fvar<var> > >::traps);

EXPECT_FALSE(std::numeric_limits<fvar<var> >::tinyness_before);
EXPECT_FALSE(std::numeric_limits<fvar<fvar<var> > >::tinyness_before);
EXPECT_EQ(std::numeric_limits<fvar<var> >::tinyness_before,
std::numeric_limits<double>::tinyness_before);
EXPECT_EQ(std::numeric_limits<fvar<fvar<var> > >::tinyness_before,
std::numeric_limits<double>::tinyness_before);

EXPECT_TRUE(std::numeric_limits<fvar<var> >::round_style);
EXPECT_TRUE(std::numeric_limits<fvar<fvar<var> > >::round_style);
Expand Down
4 changes: 2 additions & 2 deletions test/unit/math/prim/fun/offset_multiplier_transform_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ TEST(prob_transform, offset_multiplier_constrain_matrix) {
EXPECT_FLOAT_EQ(result(i), stan::math::offset_multiplier_constrain(
x(i), offsetd, sigma(i), lp1));
}
EXPECT_EQ(lp0, lp1);
EXPECT_FLOAT_EQ(lp0, lp1);
auto x_free = stan::math::offset_multiplier_free(result, offsetd, sigma);
for (size_t i = 0; i < x.size(); ++i) {
EXPECT_FLOAT_EQ(x.coeff(i), x_free.coeff(i));
Expand All @@ -211,7 +211,7 @@ TEST(prob_transform, offset_multiplier_constrain_matrix) {
EXPECT_FLOAT_EQ(result(i), stan::math::offset_multiplier_constrain(
x(i), offset(i), sigma(i), lp1));
}
EXPECT_EQ(lp0, lp1);
EXPECT_FLOAT_EQ(lp0, lp1);
auto x_free = stan::math::offset_multiplier_free(result, offset, sigma);
for (size_t i = 0; i < x.size(); ++i) {
EXPECT_FLOAT_EQ(x.coeff(i), x_free.coeff(i));
Expand Down
7 changes: 5 additions & 2 deletions test/unit/math/prim/prob/neg_binomial_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,11 @@ TEST(ProbDistributionsNegBinomial, chiSquareGoodnessFitTest3) {

double chi = 0;

for (int j = 0; j < K; j++)
chi += ((bin[j] - expect[j]) * (bin[j] - expect[j]) / expect[j]);
for (int j = 0; j < K; j++) {
if (expect[j] != 0) {
chi += ((bin[j] - expect[j]) * (bin[j] - expect[j]) / expect[j]);
}
}

EXPECT_LT(chi, boost::math::quantile(boost::math::complement(mydist, 1e-6)));
}
Expand Down
4 changes: 2 additions & 2 deletions test/unit/math/test_ad.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1991,7 +1991,7 @@ void expect_common_unary_vectorized(const F& f) {
for (double x1 : args)
stan::test::expect_ad_vectorized<ComplexSupport>(tols, f, x1);
auto int_args = internal::common_int_args();
for (int x1 : args)
for (int x1 : int_args)
stan::test::expect_ad_vectorized<ComplexSupport>(tols, f, x1);
}

Expand Down Expand Up @@ -2022,7 +2022,7 @@ void expect_common_unary_vectorized(const F& f) {
for (double x1 : args)
stan::test::expect_ad_vectorized<ComplexSupport>(tols, f, x1);
auto int_args = internal::common_int_args();
for (int x1 : args)
for (int x1 : int_args)
stan::test::expect_ad_vectorized<ComplexSupport>(tols, f, x1);
for (auto x1 : common_complex())
stan::test::expect_ad_vectorized<ComplexSupport>(tols, f, x1);
Expand Down

0 comments on commit b2b2ad8

Please sign in to comment.