diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..0423987 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,82 @@ +name: CI + +# Trigger the workflow on push or pull request, but only for the master branch +on: + pull_request: + push: + branches: [master] + +jobs: + build: + name: Building on ${{ matrix.os }} with ghc-${{ matrix.ghc }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + include: + - os: ubuntu-latest + cabal: latest + ghc: "8.10.7" + - os: macos-latest + cabal: latest + ghc: "8.10.7" + steps: + - uses: actions/checkout@v2 + - uses: haskell/actions/setup@v1 + name: Setup Haskell + with: + ghc-version: ${{ matrix.ghc }} + cabal-version: ${{ matrix.cabal }} + - uses: actions/cache@v3 + name: Cache ~/.cabal/store + with: + path: ~/.cabal/store + key: ${{ runner.os }}-${{ matrix.ghc }}-cabal + + - name: Install system dependencies (Linux) + if: matrix.os == 'ubuntu-18.04' || matrix.os == 'ubuntu-20.04' || matrix.os == 'ubuntu-latest' + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends \ + cmake ninja-build g++ libboost-all-dev + - name: Install system dependencies (MacOS) + if: matrix.os == 'macos-latest' + run: | + brew install cmake ninja boost + + - uses: actions/cache@v3 + name: Cache /opt/symengine + id: cache-symengine + with: + path: /opt/symengine + key: ${{ runner.os }}-symengine + + - name: Build C++ code + if: steps.cache-symengine.outputs.cache-hit != 'true' + run: | + cd $GITHUB_WORKSPACE + git clone --depth=1 https://github.com/symengine/symengine + cd symengine + cmake -B build -G Ninja \ + -DCMAKE_INSTALL_PREFIX=/opt/symengine \ + -DCMAKE_BUILD_TYPE=Debug \ + -DBUILD_SHARED_LIBS=ON \ + -DWITH_SYMENGINE_THREAD_SAFE=ON \ + -DBUILD_TESTS=OFF \ + -DBUILD_BENCHMARKS=OFF \ + -DINTEGER_CLASS=boostmp + cmake --build build + cmake --build build --target test + sudo cmake --build build --target install + + - name: Build Haskell code + run: | + echo "package symengine" >> cabal.project.local + echo " extra-lib-dirs: /opt/symengine/lib" >> cabal.project.local + echo " extra-include-dirs: /opt/symengine/include" >> cabal.project.local + cabal build + + - name: Test + run: | + export LD_LIBRARY_PATH=/opt/symengine/lib:$LD_LIBRARY_PATH + export DYLD_LIBRARY_PATH=/opt/symengine/lib:$DYLD_LIBRARY_PATH + cabal test --test-show-details=direct diff --git a/.gitignore b/.gitignore index d01581c..6393707 100644 --- a/.gitignore +++ b/.gitignore @@ -37,4 +37,6 @@ tags /*.iml /src/highlight.js /src/style.css -/_site/ \ No newline at end of file +/_site/.ghc.environment.* +result +result-1 diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index bf75be4..0000000 --- a/.travis.yml +++ /dev/null @@ -1,280 +0,0 @@ -# Copy these contents into the root directory of your Github project in a file -# named .travis.yml - -# Use new container infrastructure to enable caching -sudo: false - -# Choose a lightweight base image; we provide our own build tools. -language: c - -addons: - apt: - sources: - - ubuntu-toolchain-r-test - packages: - - libgmp-dev - - libmpfr-dev - - libmpc-dev - - binutils-dev - - g++-4.7 - - gcc - -# Caching so the next build will be fast too. -cache: - directories: - - $HOME/.ghc - - $HOME/.cabal - - $HOME/.stack - -# The different configurations we want to test. We have BUILD=cabal which uses -# cabal-install, and BUILD=stack which uses Stack. More documentation on each -# of those below. -# -# We set the compiler values here to tell Travis to use a different -# cache file per set of arguments. -# -# If you need to have different apt packages for each combination in the -# matrix, you can use a line such as: -# addons: {apt: {packages: [libfcgi-dev,libgmp-dev]}} -matrix: - include: - # We grab the appropriate GHC and cabal-install versions from hvr's PPA. See: - # https://github.com/hvr/multi-ghc-travis - #- env: BUILD=cabal GHCVER=7.0.4 CABALVER=1.16 HAPPYVER=1.19.5 ALEXVER=3.1.7 - # compiler: ": #GHC 7.0.4" - # addons: {apt: {packages: [cabal-install-1.16,ghc-7.0.4,happy-1.19.5,alex-3.1.7], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - #- env: BUILD=cabal GHCVER=7.2.2 CABALVER=1.16 HAPPYVER=1.19.5 ALEXVER=3.1.7 - # compiler: ": #GHC 7.2.2" - # addons: {apt: {packages: [cabal-install-1.16,ghc-7.2.2,happy-1.19.5,alex-3.1.7], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - env: BUILD=cabal GHCVER=7.4.2 CABALVER=1.16 HAPPYVER=1.19.5 ALEXVER=3.1.7 - compiler: ": #GHC 7.4.2" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, cabal-install-1.16,ghc-7.4.2,happy-1.19.5,alex-3.1.7], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - env: BUILD=cabal GHCVER=7.6.3 CABALVER=1.16 HAPPYVER=1.19.5 ALEXVER=3.1.7 - compiler: ": #GHC 7.6.3" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, cabal-install-1.16,ghc-7.6.3,happy-1.19.5,alex-3.1.7], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - env: BUILD=cabal GHCVER=7.8.4 CABALVER=1.18 HAPPYVER=1.19.5 ALEXVER=3.1.7 - compiler: ": #GHC 7.8.4" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, cabal-install-1.18,ghc-7.8.4,happy-1.19.5,alex-3.1.7], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - env: BUILD=cabal GHCVER=7.10.3 CABALVER=1.22 HAPPYVER=1.19.5 ALEXVER=3.1.7 - compiler: ": #GHC 7.10.3" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, cabal-install-1.22,ghc-7.10.3,happy-1.19.5,alex-3.1.7], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - # Build with the newest GHC and cabal-install. This is an accepted failure, - # see below. - - env: BUILD=cabal GHCVER=head CABALVER=head HAPPYVER=1.19.5 ALEXVER=3.1.7 - compiler: ": #GHC HEAD" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, cabal-install-head,ghc-head,happy-1.19.5,alex-3.1.7], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - # The Stack builds. We can pass in arbitrary Stack arguments via the ARGS - # variable, such as using --stack-yaml to point to a different file. - - env: BUILD=stack ARGS=" " - compiler: ": #stack default" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, ghc-7.10.3], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - - env: BUILD=stack ARGS="--resolver lts-2" - compiler: ": #stack 7.8.4" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, ghc-7.8.4], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - - env: BUILD=stack ARGS="--resolver lts-3" - compiler: ": #stack 7.10.2" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, ghc-7.10.2], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - - env: BUILD=stack ARGS="--resolver lts-5" - compiler: ": #stack 7.10.3" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, ghc-7.10.3], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - # Nightly builds are allowed to fail - - env: BUILD=stack ARGS="--resolver nightly" - compiler: ": #stack nightly" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, libgmp-dev]}} - - # Build on OS X in addition to Linux - - env: BUILD=stack ARGS=" " - compiler: ": #stack default osx" - os: osx - - - env: BUILD=stack ARGS="--resolver lts-2" - compiler: ": #stack 7.8.4 osx" - os: osx - - - env: BUILD=stack ARGS="--resolver lts-3" - compiler: ": #stack 7.10.2 osx" - os: osx - - - env: BUILD=stack ARGS="--resolver lts-5" - compiler: ": #stack 7.10.3 osx" - os: osx - - - env: BUILD=stack ARGS="--resolver nightly" - compiler: ": #stack nightly osx" - os: osx - - allow_failures: - - env: BUILD=cabal GHCVER=head CABALVER=head HAPPYVER=1.19.5 ALEXVER=3.1.7 - - env: BUILD=stack ARGS="--resolver nightly" - - -install: -# SYMENGINE INSTALL PHASE -# ----------------------- -# Download and install SymEngine -- cd $HOME && git clone https://github.com/symengine/symengine.git - -# Setup C compiler variables -# The reason we need to do this is because our build system is haskell, so none of the -# C variables are set. -- | - if [ `uname` = "Darwin" ] - then - export CC="clang" && export CXX="clang++" - else - export CC="gcc" && export CXX="g++-4.7" - fi - -- | - set -ex - export TEST_CPP="no" - cd $HOME/symengine - source bin/install_travis.sh - bin/test_travis.sh - -# EXPORT PHASE -# ------------ -# Export environment variables related to SymEngine's library and includes -- | - # $our_install_dir is exported by test_travis.sh from symengine - set -ex - export SYMENGINE_LIB_ARGS="--extra-lib-dirs=$our_install_dir/lib/" - export SYMENGINE_INCLUDE_ARGS="--extra-include-dirs=$our_install_dir/include/" - cd $TRAVIS_BUILD_DIR - -# GHC INSTALL PHASE -# ----------------- -# Install Stack if needed, then install GHC - -# Using compiler above sets CC to an invalid value, so unset it -- unset CC -# We want to always allow newer versions of packages when building on GHC HEAD -- CABALARGS="" -- if [ "x$GHCVER" = "xhead" ]; then CABALARGS=--allow-newer; fi - -# Download and unpack the stack executable -- export PATH=/opt/ghc/$GHCVER/bin:/opt/cabal/$CABALVER/bin:$HOME/.local/bin:/opt/alex/$ALEXVER/bin:/opt/happy/$HAPPYVER/bin:$HOME/.cabal/bin:$PATH -- mkdir -p ~/.local/bin -- | - if [ `uname` = "Darwin" ] - then - travis_retry curl --insecure -L https://www.stackage.org/stack/osx-x86_64 | tar xz --strip-components=1 --include '*/stack' -C ~/.local/bin - else - travis_retry curl -L https://www.stackage.org/stack/linux-x86_64 | tar xz --wildcards --strip-components=1 -C ~/.local/bin '*/stack' - fi - # Use the more reliable S3 mirror of Hackage - mkdir -p $HOME/.cabal - echo 'remote-repo: hackage.haskell.org:http://hackage.fpcomplete.com/' > $HOME/.cabal/config - echo 'remote-repo-cache: $HOME/.cabal/packages' >> $HOME/.cabal/config - - if [ "$CABALVER" != "1.16" ] - then - echo 'jobs: $ncpus' >> $HOME/.cabal/config - fi - -# Get the list of packages from the stack.yaml file -- PACKAGES=$(stack --install-ghc query locals | grep '^ *path' | sed 's@^ *path:@@') - -- echo "$(ghc --version) [$(ghc --print-project-git-commit-id 2> /dev/null || echo '?')]" -- if [ -f configure.ac ]; then autoreconf -i; fi -- | - set -ex - case "$BUILD" in - stack) - stack --no-terminal --install-ghc $ARGS test --only-dependencies - ;; - cabal) - cabal --version - travis_retry cabal update - cabal install --only-dependencies --enable-tests --enable-benchmarks --force-reinstalls --ghc-options=-O0 --reorder-goals --max-backjumps=-1 $CABALARGS $PACKAGES $SYMENGINE_LIB_ARGS $SYMENGINE_INCLUDE_ARGS - ;; - esac - set +ex - - -script: -- | - set -ex - case "$BUILD" in - stack) - stack --no-terminal $ARGS $SYMENGINE_LIB_ARGS $SYMENGINE_INCLUDE_ARGS test --haddock --no-haddock-deps - ;; - cabal) - cabal update - cabal configure --enable-tests --enable-benchmarks --ghc-options=-O0 $CABALARGS $SYMENGINE_LIB_ARGS $SYMENGINE_INCLUDE_ARGS - cabal build - # run the test suite - cabal test --show-details=always - - # install after building the library - # cabal install --enable-tests --enable-benchmarks --force-reinstalls --ghc-options=-O0 --reorder-goals --max-backjumps=-1 $CABALARGS $PACKAGES $SYMENGINE_LIB_ARGS $SYMENGINE_INCLUDE_ARGS - - ORIGDIR=$(pwd) - for dir in $PACKAGES - do - cd $dir - cabal check || [ "$CABALVER" == "1.16" ] - cabal sdist - SRC_TGZ=$(cabal info . | awk '{print $2;exit}').tar.gz && \ - (cd dist && cabal install --force-reinstalls "$SRC_TGZ") - cd $ORIGDIR - done - ;; - esac - set +ex diff --git a/Setup.hs b/Setup.hs deleted file mode 100644 index 9a994af..0000000 --- a/Setup.hs +++ /dev/null @@ -1,2 +0,0 @@ -import Distribution.Simple -main = defaultMain diff --git a/cabal.project.local b/cabal.project.local new file mode 100644 index 0000000..7f529ad --- /dev/null +++ b/cabal.project.local @@ -0,0 +1,5 @@ +ignore-project: False +write-ghc-environment-files: always +tests: True +test-options: "--color" +test-show-details: streaming diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..92c0573 --- /dev/null +++ b/flake.lock @@ -0,0 +1,77 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1689068808, + "narHash": "sha256-6ixXo3wt24N/melDWjq70UuHQLxGV8jZvooRanIHXw0=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "919d646de7be200f3bf08cb76ae1f09402b6f9b4", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nix-filter": { + "locked": { + "lastModified": 1687178632, + "narHash": "sha256-HS7YR5erss0JCaUijPeyg2XrisEb959FIct3n2TMGbE=", + "owner": "numtide", + "repo": "nix-filter", + "rev": "d90c75e8319d0dd9be67d933d8eb9d0894ec9174", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "nix-filter", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1692174805, + "narHash": "sha256-xmNPFDi/AUMIxwgOH/IVom55Dks34u1g7sFKKebxUm0=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "caac0eb6bdcad0b32cb2522e03e4002c8975c62e", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nix-filter": "nix-filter", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..516f393 --- /dev/null +++ b/flake.nix @@ -0,0 +1,91 @@ +{ + description = "symengine/symengine.hs: SymEngine symbolic mathematics engine for Haskell"; + + inputs = { + nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + nix-filter.url = "github:numtide/nix-filter"; + }; + + outputs = { nixpkgs, flake-utils, nix-filter, ... }: + let + src = nix-filter.lib { + root = ./.; + include = [ + "src" + "test" + "symengine.cabal" + "README.md" + "LICENSE" + ]; + }; + overlay = self: super: { + symengine = super.symengine.overrideAttrs (attrs: rec { + version = "0.10.1"; + src = self.fetchFromGitHub { + owner = attrs.pname; + repo = attrs.pname; + rev = "v${version}"; + sha256 = "sha256-qTu0vS9K6rrr/0SXKpGC9P1QSN/AN7hyO/4DrGvhxWM="; + }; + cmakeFlags = (attrs.cmakeFlags or [ ]) ++ [ + "-DCMAKE_BUILD_TYPE=Debug" + "-DBUILD_SHARED_LIBS=ON" + ]; + }); + + haskell = super.haskell // { + packageOverrides = hself: hsuper: { + symengine = (hself.callCabal2nix "symengine" src { + inherit (self) symengine; + mpc = self.libmpc; + }); + }; + }; + }; + + pkgsFor = system: import nixpkgs { + inherit system; + overlays = [ overlay ]; + config.allowBroken = true; + }; + in + { + packages = flake-utils.lib.eachDefaultSystemMap (system: + with (pkgsFor system); { + default = haskellPackages.symengine; + symengine = haskellPackages.symengine; + haskell = haskell.packages; + }); + + devShells = flake-utils.lib.eachDefaultSystemMap (system: + with (pkgsFor system); { + default = haskellPackages.shellFor { + packages = ps: with ps; [ symengine ]; + withHoogle = true; + nativeBuildInputs = with pkgs; with haskellPackages; [ + # Building and testing + cabal-install + # Language servers + haskell-language-server + nil + # Formatters + fourmolu + # cabal-fmt + nixpkgs-fmt + # Previewing markdown files + python3Packages.grip + ]; + shellHook = '' + LD_LIBRARY_PATH=${pkgs.symengine}/lib:${pkgs.flint}/lib:${pkgs.libmpc}/lib:${pkgs.mpfr}/lib:$LD_LIBRARY_PATH + SYMENGINE_PATH=${pkgs.symengine} + ''; + }; + # The formatter to use for .nix files (but not .hs files) + # Allows us to run `nix fmt` to reformat nix files. + formatter = pkgs.nixpkgs-fmt; + } + ); + overlays.default = overlay; + }; +} diff --git a/fourmolu.yaml b/fourmolu.yaml new file mode 100644 index 0000000..e9f82b2 --- /dev/null +++ b/fourmolu.yaml @@ -0,0 +1,14 @@ +indentation: 2 +function-arrows: leading +comma-style: leading +import-export-style: leading +indent-wheres: true +record-brace-space: true +newlines-between-decls: 1 +haddock-style: single-line +haddock-style-module: single-line +let-style: auto +in-style: right-align +unicode: never +respectful: true +fixities: [] diff --git a/src/Symengine.hs b/src/Symengine.hs index d0671ff..00bb99f 100644 --- a/src/Symengine.hs +++ b/src/Symengine.hs @@ -1,281 +1,826 @@ -{-# LANGUAGE RecordWildCards #-} - -{-| -Module : Symengine -Description : Symengine bindings to Haskell --} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -Wno-unused-matches #-} + +-- | +-- Module : Symengine +-- Description : Symengine bindings to Haskell module Symengine - ( - ascii_art_str, - zero, - one, - im, - Symengine.pi, - e, - minus_one, - rational, - complex, - symbol, - BasicSym, - ) where - + ( Basic (..) + , DenseMatrix (..) + , symbol + , parse + , e + , infinity + , nan + , diff + , evalf + , subs + , inverse + , identityMatrix + , zeroMatrix + , allocaCxxInteger + , peekCxxInteger + , withCxxInteger + , EvalDomain (..) + , InverseMethod (..) + , toAST + , fromAST + , AST (..) + , BasicKey (..) + ) where + +import Control.Exception (bracket, bracket_) +import Control.Monad +import Data.Hashable (Hashable (..)) +import Data.Text (Text, pack, unpack) +import Data.Text.Encoding qualified as T +import Data.Vector (Vector) +import Data.Vector qualified as V import Foreign.C.Types -import Foreign.Ptr -import Foreign.C.String -import Foreign.Storable -import Foreign.Marshal.Array -import Foreign.Marshal.Alloc import Foreign.ForeignPtr -import Control.Applicative +import Foreign.Marshal (allocaBytesAligned, toBool, withArrayLen) +import Foreign.Ptr +import GHC.Exts (IsString (..)) +import GHC.Int +import GHC.Num.BigNat +import GHC.Num.Integer +import GHC.Real (Ratio (..)) +import Language.C.Inline qualified as C +import Language.C.Inline.Unsafe qualified as CU +import Symengine.Context +import Symengine.Internal import System.IO.Unsafe -import Control.Monad -import GHC.Real - -data BasicStruct = BasicStruct { - data_ptr :: Ptr () -} - -instance Storable BasicStruct where - alignment _ = 8 - sizeOf _ = sizeOf nullPtr - peek basic_ptr = BasicStruct <$> peekByteOff basic_ptr 0 - poke basic_ptr BasicStruct{..} = pokeByteOff basic_ptr 0 data_ptr - - --- |represents a symbol exported by SymEngine. create this using the functions --- 'zero', 'one', 'minus_one', 'e', 'im', 'rational', 'complex', and also by --- constructing a number and converting it to a Symbol --- --- >>> 3.5 :: BasicSym --- 7/2 --- --- >>> rational 2 10 --- 1 /5 --- --- >>> complex 1 2 --- 1 + 2*I -data BasicSym = BasicSym { fptr :: ForeignPtr BasicStruct } - -withBasicSym :: BasicSym -> (Ptr BasicStruct -> IO a) -> IO a -withBasicSym p f = withForeignPtr (fptr p ) f - -withBasicSym2 :: BasicSym -> BasicSym -> (Ptr BasicStruct -> Ptr BasicStruct -> IO a) -> IO a -withBasicSym2 p1 p2 f = withBasicSym p1 (\p1 -> withBasicSym p2 (\p2 -> f p1 p2)) - -withBasicSym3 :: BasicSym -> BasicSym -> BasicSym -> (Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO a) -> IO a -withBasicSym3 p1 p2 p3 f = withBasicSym p1 (\p1 -> withBasicSym p2 (\p2 -> withBasicSym p3 (\p3 -> f p1 p2 p3))) - - --- | constructor for 0 -zero :: BasicSym -zero = basic_obj_constructor basic_const_zero_ffi --- | constructor for 1 -one :: BasicSym -one = basic_obj_constructor basic_const_one_ffi +-- | Basic building block of SymEngine expressions. +newtype Basic = Basic (ForeignPtr CxxBasic) --- | constructor for -1 -minus_one :: BasicSym -minus_one = basic_obj_constructor basic_const_minus_one_ffi +data DenseMatrix a = DenseMatrix {dmRows :: !Int, dmCols :: !Int, dmData :: !(Vector a)} --- | constructor for i = sqrt(-1) -im :: BasicSym -im = basic_obj_constructor basic_const_I_ffi +data CxxBasic --- | the ratio of the circumference of a circle to its radius -pi :: BasicSym -pi = basic_obj_constructor basic_const_pi_ffi +data CxxInteger --- | The base of the natural logarithm -e :: BasicSym -e = basic_obj_constructor basic_const_E_ffi +data CxxString -expand :: BasicSym -> BasicSym -expand = basic_unaryop basic_expand_ffi +data CxxMapBasicBasic +importSymengine -eulerGamma :: BasicSym -eulerGamma = basic_obj_constructor basic_const_EulerGamma_ffi - -basic_obj_constructor :: (Ptr BasicStruct -> IO ()) -> BasicSym -basic_obj_constructor init_fn = unsafePerformIO $ do - basic_ptr <- create_basic_ptr - withBasicSym basic_ptr init_fn - return basic_ptr - -basic_str :: BasicSym -> String -basic_str basic_ptr = unsafePerformIO $ withBasicSym basic_ptr (basic_str_ffi >=> peekCString) - -integerToCLong :: Integer -> CLong -integerToCLong i = CLong (fromInteger i) - +-- | Convert a pointer to @std::string@ into a string. +-- +-- It properly handles unicode characters. +-- peekCxxString :: Ptr CxxString -> IO Text +-- peekCxxString p = +-- fmap T.decodeUtf8 $ +-- packCString +-- =<< [CU.exp| char const* { $(const std::string* p)->c_str() } |] + +-- | Call 'peekCxxString' and @delete@ the pointer. +-- peekAndDeleteCxxString :: Ptr CxxString -> IO Text +-- peekAndDeleteCxxString p = do +-- s <- peekCxxString p +-- [CU.exp| void { delete $(const std::string* p) } |] +-- pure s +constructBasic :: (Ptr CxxBasic -> IO ()) -> IO Basic +constructBasic construct = + fmap Basic $ constructWithDeleter size deleter $ \ptr -> do + [CU.block| void { new ($(Object* ptr)) Object{}; } |] + construct ptr + where + size = fromIntegral [CU.pure| size_t { sizeof(Object) } |] + deleter = [C.funPtr| void deleteBasic(Object* ptr) { ptr->~Object(); } |] + +constructWithDeleter :: Int -> FinalizerPtr a -> (Ptr a -> IO ()) -> IO (ForeignPtr a) +constructWithDeleter size deleter constructor = do + fp <- mallocForeignPtrBytes size + withForeignPtr fp constructor + addForeignPtrFinalizer deleter fp + pure fp + +-- newtype VecBasic = VecBasic (ForeignPtr CxxVecBasic) + +-- constructVecBasic :: (Ptr CxxVecBasic -> IO ()) -> IO VecBasic +-- constructVecBasic construct = +-- fmap VecBasic $ constructWithDeleter size deleter $ \ptr -> do +-- [CU.block| void { new ($(SymEngine::vec_basic* ptr)) SymEngine::vec_basic{}; } |] +-- construct ptr +-- where +-- size = fromIntegral [CU.pure| size_t { sizeof(SymEngine::vec_basic) } |] +-- deleter = [C.funPtr| void deleteBasic(SymEngine::vec_basic* ptr) { ptr->~vector(); } |] + +withBasic :: Basic -> (Ptr CxxBasic -> IO a) -> IO a +withBasic (Basic fp) = withForeignPtr fp + +-- withVecBasic :: VecBasic -> (Ptr CxxVecBasic -> IO a) -> IO a +-- withVecBasic (VecBasic fp) = withForeignPtr fp + +-- vecBasicToList :: VecBasic -> [Basic] +-- vecBasicToList v = unsafePerformIO $ +-- withVecBasic v $ \v' -> do +-- size <- [CU.exp| size_t { $(const SymEngine::vec_basic* v')->size() } |] +-- forM [0 .. size - 1] $ \i -> +-- constructBasic $ \dest -> +-- [CU.exp| void { CONSTRUCT_BASIC($(Object* dest), +-- $(const SymEngine::vec_basic* v')->at($(size_t i))) } |] + +cxxVectorSize :: Ptr (Vector Basic) -> IO Int +cxxVectorSize ptr = fromIntegral <$> [CU.exp| size_t { $(const Vector* ptr)->size() } |] + +cxxVectorIndex :: Ptr (Vector Basic) -> Int -> IO Basic +cxxVectorIndex ptr (fromIntegral -> i) = + $(constructBasicFrom "$(const Vector* ptr)->at($(size_t i))") + +cxxVectorPushBack :: Ptr (Vector Basic) -> Basic -> IO () +cxxVectorPushBack ptr basic = + withBasic basic $ \x -> + [CU.exp| void { $(Vector* ptr)->push_back(*$(const Object* x)) } |] + +peekVector :: Ptr (Vector Basic) -> IO (Vector Basic) +peekVector ptr = do + size <- cxxVectorSize ptr + V.forM (V.enumFromStepN 0 1 size) (cxxVectorIndex ptr) + +allocaVector :: (Ptr (Vector Basic) -> IO a) -> IO a +allocaVector action = + allocaBytesAligned sizeBytes alignmentBytes $ \v -> + let construct = [CU.exp| void { new ($(Vector* v)) Vector{} } |] + destruct = [CU.exp| void { $(Vector* v)->~Vector() } |] + in bracket_ construct destruct (action v) + where + sizeBytes = fromIntegral [CU.pure| size_t { sizeof(Vector) } |] + alignmentBytes = fromIntegral [CU.pure| size_t { alignof(Vector) } |] + +withVector :: Vector Basic -> (Ptr (Vector Basic) -> IO a) -> IO a +withVector v action = do + allocaVector $ \ptr -> do + V.forM_ v $ cxxVectorPushBack ptr + action ptr + +-- \$ \dest -> +-- [CU.exp| void { CONSTRUCT_BASIC($(Object* dest), +-- $(const Vector* ptr)->at($(size_t i))) } |] + +allocaDenseMatrix :: Int -> Int -> (Ptr (DenseMatrix Basic) -> IO a) -> IO a +allocaDenseMatrix (fromIntegral -> nrows) (fromIntegral -> ncols) action = do + allocaBytesAligned sizeBytes alignmentBytes $ \v -> + let construct = + [CU.exp| void { new ($(DenseMatrix * v)) DenseMatrix{ + $(unsigned nrows), $(unsigned ncols)} } |] + destruct = [CU.exp| void { $(DenseMatrix* v)->~DenseMatrix() } |] + in bracket_ construct destruct (action v) + where + sizeBytes = fromIntegral [CU.pure| size_t { sizeof(DenseMatrix) } |] + alignmentBytes = fromIntegral [CU.pure| size_t { alignof(DenseMatrix) } |] + +withDenseMatrix :: DenseMatrix Basic -> (Ptr (DenseMatrix Basic) -> IO a) -> IO a +withDenseMatrix matrix action = + allocaDenseMatrix 0 0 $ \ptr -> + withVector (dmData matrix) $ \v -> do + let n = fromIntegral $ dmRows matrix + m = fromIntegral $ dmCols matrix + [CU.block| void { + *$(DenseMatrix* ptr) = DenseMatrix{$(unsigned n), $(unsigned m), *$(const Vector* v)}; + } |] + action ptr + +peekDenseMatrix :: Ptr (DenseMatrix Basic) -> IO (DenseMatrix Basic) +peekDenseMatrix ptr = do + n <- fromIntegral <$> [CU.exp| unsigned { $(const DenseMatrix* ptr)->nrows() } |] + m <- fromIntegral <$> [CU.exp| unsigned { $(const DenseMatrix* ptr)->ncols() } |] + v <- + allocaVector $ \v -> do + [CU.block| void { *$(Vector* v) = $(const DenseMatrix* ptr)->as_vec_basic(); } |] + peekVector v + pure $ DenseMatrix n m v + +instance Show Basic where + show basic = unpack . unsafePerformIO $ + withBasic basic $ \basic' -> + $(constructStringFrom "SymEngine::str(**$(Object* basic'))") + +deriving stock instance Show (DenseMatrix Basic) + +instance Eq Basic where + a == b = unsafePerformIO $ + withBasic a $ \a' -> + withBasic b $ \b' -> + toBool + <$> [CU.exp| bool { eq(**$(const Object* a'), **$(const Object* b')) } |] + +instance Hashable Basic where + hashWithSalt s = hashWithSalt s . hashInternal + where + hashInternal x = unsafePerformIO $ withBasic x $ \p -> [CU.exp| uint64_t { (*$(Object const* p))->hash() } |] + +newtype BasicKey = BasicKey {unBasicKey :: Basic} + +instance Eq BasicKey where + (BasicKey a) == (BasicKey b) + | hash a /= hash b = False + | otherwise = a == b + +instance Ord BasicKey where + compare (BasicKey a) (BasicKey b) = + case compare hashA hashB of + LT -> LT + GT -> GT + EQ -> case compareInternal of + -1 -> LT + 0 -> EQ + 1 -> GT + x -> error $ "__cmp__ returned invalid value: " <> show x + where + hashA = hash a + hashB = hash b + compareInternal = + unsafePerformIO $ withBasic a $ \aPtr -> withBasic b $ \bPtr -> + [CU.exp| int { (*$(Object const* aPtr))->__cmp__(**$(Object const* bPtr)) } |] + +parse :: Text -> Basic +parse (T.encodeUtf8 -> name) = + unsafePerformIO $ $(constructBasicFrom "parse($bs-cstr:name)") + +instance IsString Basic where + fromString = parse . pack + +symbol :: Text -> Basic +symbol (T.encodeUtf8 -> name) = + unsafePerformIO $ + $(constructBasicFrom "symbol(std::string{$bs-ptr:name, static_cast($bs-len:name)})") + +-- constructBasic $ \dest -> +-- [CU.exp| void { new ($(Object* dest)) Object{} } |] + +-- pureUnaryOp :: (Ptr CxxBasic -> Ptr CxxBasic -> IO ()) -> Basic -> Basic +-- pureUnaryOp f a = unsafePerformIO $ +-- withBasic a $ \a' -> +-- constructBasic $ \dest -> +-- f dest a' + +-- pureBinaryOp :: (Ptr CxxBasic -> Ptr CxxBasic -> Ptr CxxBasic -> IO ()) -> Basic -> Basic -> Basic +-- pureBinaryOp f a b = unsafePerformIO $ +-- withBasic a $ \a' -> +-- withBasic b $ \b' -> +-- constructBasic $ \dest -> +-- f dest a' b' + +allocaCxxInteger :: (Ptr CxxInteger -> IO a) -> IO a +allocaCxxInteger f = + allocaBytesAligned sizeBytes alignmentBytes $ \i -> + let construct = + [CU.exp| void { new ($(integer_class * i)) integer_class{} } |] + destruct = [CU.exp| void { $(integer_class * i)->~integer_class() } |] + in bracket_ construct destruct (f i) + where + sizeBytes = fromIntegral [CU.pure| size_t { sizeof(integer_class) } |] + alignmentBytes = fromIntegral [CU.pure| size_t { alignof(integer_class) } |] + +integerToWords :: Integer -> [Word] +integerToWords (IP b) = bigNatToWordList b +integerToWords (IN b) = bigNatToWordList b +integerToWords (IS n) = [fromIntegral (abs (I# n))] + +withCxxInteger :: Integer -> (Ptr CxxInteger -> IO a) -> IO a +withCxxInteger n action = + allocaCxxInteger $ \i -> + withArrayLen (fromIntegral <$> integerToWords n) $ + \(fromIntegral -> numWords) wordsPtr -> do + [CU.block| void { + auto const numWords = $(int numWords); + auto const* words = $(const uint64_t* wordsPtr); + if (numWords > 0) { + integer_class x{words[0]}; + for (int k = 1; k < numWords; ++k) { + x <<= 64; + x += words[k]; + } + *$(integer_class* i) = x; + } + } |] + when (n < 0) $ do + [CU.block| void { + auto& i = *$(integer_class* i); + i = -i; + } |] + action i + +peekCxxInteger :: Ptr CxxInteger -> IO Integer +peekCxxInteger i = do + allocaCxxInteger $ \j -> do + isNegative <- + toBool + <$> [CU.block| bool { + auto const& i = *$(integer_class const* i); + auto& j = *$(integer_class* j); + j = mp_abs(i); + return i < 0; + } |] + let go acc = do + w <- + [CU.block| uint64_t { + auto const& j = *$(integer_class const* j); + return mp_get_ui(j); + } |] + continue <- + toBool + <$> [CU.block| bool { + auto& j = *$(integer_class* j); + j >>= 64; + return j != 0; + } |] + if continue + then go $ w : acc + else pure $ w : acc + integerFromWordList isNegative . fmap fromIntegral <$> go [] + +instance Num Basic where + fromInteger n = unsafePerformIO $ + withCxxInteger n $ \i -> + $(constructBasicFrom "integer(*$(const integer_class* i))") + (+) = $(mkBinaryFunction "add(a, b)") + (-) = $(mkBinaryFunction "sub(a, b)") + (*) = $(mkBinaryFunction "mul(a, b)") + abs = $(mkUnaryFunction "abs(a)") + signum = $(mkUnaryFunction "sign(a)") + +instance Fractional Basic where + (/) = $(mkBinaryFunction "div(a, b)") + fromRational (numer :% denom) = + unsafePerformIO $ + withCxxInteger numer $ \numer' -> + withCxxInteger denom $ \denom' -> + $( constructBasicFrom + "Rational::from_two_ints(\ + \Integer(*$(const integer_class* numer')),\ + \Integer(*$(const integer_class* denom')))" + ) + +e :: Basic +e = unsafePerformIO $ $(constructBasicFrom "E") +{-# NOINLINE e #-} + +infinity :: Basic +infinity = unsafePerformIO $ $(constructBasicFrom "Inf") +{-# NOINLINE infinity #-} + +nan :: Basic +nan = unsafePerformIO $ $(constructBasicFrom "Nan") +{-# NOINLINE nan #-} + +instance Floating Basic where + pi = unsafePerformIO $ $(constructBasicFrom "pi") + exp = $(mkUnaryFunction "exp(a)") + log = $(mkUnaryFunction "log(a)") + sqrt = $(mkUnaryFunction "sqrt(a)") + (**) = $(mkBinaryFunction "pow(a, b)") + sin = $(mkUnaryFunction "sin(a)") + cos = $(mkUnaryFunction "cos(a)") + tan = $(mkUnaryFunction "tan(a)") + asin = $(mkUnaryFunction "asin(a)") + acos = $(mkUnaryFunction "acos(a)") + atan = $(mkUnaryFunction "atan(a)") + sinh = $(mkUnaryFunction "sinh(a)") + cosh = $(mkUnaryFunction "cosh(a)") + tanh = $(mkUnaryFunction "tanh(a)") + asinh = $(mkUnaryFunction "asinh(a)") + acosh = $(mkUnaryFunction "acosh(a)") + atanh = $(mkUnaryFunction "atanh(a)") + +diff :: Basic -> Basic -> Basic +diff f x + | basicTypeCode x == [CU.pure| int { static_cast(SYMENGINE_SYMBOL) } |] = + $(mkBinaryFunction "a->diff(rcp_static_cast(b))") f x + | otherwise = error "can only differentiate with respect to symbols" + +data EvalDomain = EvalComplex | EvalReal | EvalSymbolic + deriving stock (Show, Eq) + +evalDomainToCInt :: EvalDomain -> CInt +evalDomainToCInt EvalComplex = [CU.pure| int { static_cast(SymEngine::EvalfDomain::Complex) } |] +evalDomainToCInt EvalReal = [CU.pure| int { static_cast(SymEngine::EvalfDomain::Real) } |] +evalDomainToCInt EvalSymbolic = [CU.pure| int { static_cast(SymEngine::EvalfDomain::Symbolic) } |] + +evalf :: EvalDomain -> Int -> Basic -> Basic +evalf (evalDomainToCInt -> domain) (fromIntegral -> bits) x = unsafePerformIO $ + withBasic x $ \x' -> + $(constructBasicFrom "evalf(**$(const Object* x'), $(int bits), static_cast($(int domain)))") + +withCxxMapBasicBasic :: [(Basic, Basic)] -> (Ptr CxxMapBasicBasic -> IO a) -> IO a +withCxxMapBasicBasic pairs action = + bracket allocate destroy $ \p -> do + forM_ pairs $ \(from, to) -> + withBasic from $ \fromPtr -> withBasic to $ \toPtr -> + [CU.exp| void { + $(map_basic_basic* p)->emplace(*$(Object const* fromPtr), *$(Object const* toPtr)) } |] + action p + where + allocate = [CU.exp| map_basic_basic* { new map_basic_basic } |] + destroy p = [CU.exp| void { delete $(map_basic_basic* p) } |] + +subs :: [(Basic, Basic)] -> Basic -> Basic +subs replacements expr = + unsafePerformIO $ + withCxxMapBasicBasic replacements $ \replacementsPtr -> + withBasic expr $ \exprPtr -> + $(constructBasicFrom "subs(*$(Object const* exprPtr), *$(map_basic_basic const* replacementsPtr))") + +generateDenseMatrix :: Int -> Int -> (Int -> Int -> Basic) -> DenseMatrix Basic +generateDenseMatrix nrows ncols f = + DenseMatrix nrows ncols $ + V.generate (nrows * ncols) $ \i -> + let (!r, !c) = i `divMod` ncols + in f r c + +identityMatrix :: Int -> DenseMatrix Basic +identityMatrix n = generateDenseMatrix n n (\i j -> if i == j then 1 else 0) + +zeroMatrix :: Int -> Int -> DenseMatrix Basic +zeroMatrix n m = generateDenseMatrix n m (\_ _ -> 0) + +data InverseMethod + = InverseDefault + | InverseFractionFreeLU + | InverseLU + | InversePivotedLU + | InverseGaussJordan + deriving stock (Show, Eq) + +inverse :: InverseMethod -> DenseMatrix Basic -> DenseMatrix Basic +inverse InverseDefault m = unsafePerformIO $ withDenseMatrix m $ \a -> + $( createDenseMatrixVia + "auto const& a = *$(const DenseMatrix* a);\ + \out.resize(a.nrows(), a.ncols());\ + \a.inv(out);" + ) + +data AST + = SymengineInteger Integer + | SymengineRational Rational + | SymengineInfinity + | SymengineNaN + | SymengineConstant Basic + | SymengineSymbol Text + | SymengineMul (Vector Basic) + | SymengineAdd (Vector Basic) + | SymenginePow Basic Basic + | SymengineLog Basic + | SymengineSign Basic + | SymengineFunction Text (Vector Basic) + | SymengineDerivative Basic (Vector Basic) + deriving stock (Show, Eq) + +basicTypeCode :: Basic -> CInt +basicTypeCode x = unsafePerformIO $ + withBasic x $ + \x' -> [CU.exp| int { static_cast((*$(const Object* x'))->get_type_code()) } |] + +forceOneArg :: (Basic -> a) -> Vector Basic -> a +forceOneArg f v = case V.toList v of + [a] -> f a + _ -> error "expected a one-element vector" + +forceTwoArgs :: (Basic -> Basic -> a) -> Vector Basic -> a +forceTwoArgs f v = case V.toList v of + [a, b] -> f a b + _ -> error "expected a two-element vector" + +unsafeIntegerToAST :: Basic -> AST +unsafeIntegerToAST x = SymengineInteger n + where + n = unsafePerformIO $ + withBasic x $ \x' -> + allocaCxxInteger $ \i -> do + [CU.exp| void { + *$(integer_class* i) = + down_cast(**$(const Object* x')).as_integer_class() + } |] + peekCxxInteger i + +unsafeRationalToAST :: Basic -> AST +unsafeRationalToAST x = SymengineRational q + where + q = unsafePerformIO $ + withBasic x $ \x' -> + allocaCxxInteger $ \m -> + allocaCxxInteger $ \n -> do + [CU.block| void { + auto const& x = + down_cast(**$(const Object* x')).as_rational_class(); + *$(integer_class* m) = x.get_num(); + *$(integer_class* n) = x.get_den(); + } |] + (:%) <$> peekCxxInteger m <*> peekCxxInteger n + +unsafeSymbolToAST :: Basic -> AST +unsafeSymbolToAST x = SymengineSymbol . unsafePerformIO $ do + withBasic x $ \x' -> + $(constructStringFrom "down_cast(**$(const Object* x')).get_name()") + +toAST :: Basic -> AST +toAST x + | tp == [CU.pure| int { static_cast(SYMENGINE_INTEGER) } |] = unsafeIntegerToAST x + | tp == [CU.pure| int { static_cast(SYMENGINE_RATIONAL) } |] = unsafeRationalToAST x + | tp == [CU.pure| int { static_cast(SYMENGINE_INFTY) } |] = SymengineInfinity + | tp == [CU.pure| int { static_cast(SYMENGINE_NOT_A_NUMBER) } |] = SymengineNaN + | tp == [CU.pure| int { static_cast(SYMENGINE_CONSTANT) } |] = SymengineConstant x + | tp == [CU.pure| int { static_cast(SYMENGINE_SYMBOL) } |] = unsafeSymbolToAST x + | tp == [CU.pure| int { static_cast(SYMENGINE_ADD) } |] = + unsafePerformIO $ SymengineAdd . V.reverse <$> $(unpackFunction "Add") x + | tp == [CU.pure| int { static_cast(SYMENGINE_MUL) } |] = + unsafePerformIO $ SymengineMul <$> $(unpackFunction "Mul") x + | tp == [CU.pure| int { static_cast(SYMENGINE_POW) } |] = + unsafePerformIO $ forceTwoArgs SymenginePow <$> $(unpackFunction "Pow") x + | tp == [CU.pure| int { static_cast(SYMENGINE_LOG) } |] = + unsafePerformIO $ forceOneArg SymengineLog <$> $(unpackFunction "Log") x + | tp == [CU.pure| int { static_cast(SYMENGINE_SIGN) } |] = + unsafePerformIO $ forceOneArg SymengineSign <$> $(unpackFunction "Sign") x + | tp == [CU.pure| int { static_cast(SYMENGINE_FUNCTIONSYMBOL) } |] = + unsafePerformIO $ do + name <- withBasic x $ \x' -> + $(constructStringFrom "down_cast(**$(const Object* x')).get_name()") + args <- $(unpackFunction "FunctionSymbol") x + pure $ SymengineFunction name args + | tp == [CU.pure| int { static_cast(SYMENGINE_DERIVATIVE) } |] = + unsafePerformIO $ do + args <- $(unpackFunction "Derivative") x + pure $ SymengineDerivative (V.head args) (V.tail args) + | otherwise = error $ "unknown type code: " <> show tp + where + tp = basicTypeCode x + +fromAST :: AST -> Basic +fromAST = \case + SymengineInteger x -> fromInteger x + SymengineRational x -> fromRational x + SymengineInfinity -> infinity + SymengineNaN -> nan + SymengineConstant x -> x + SymengineSymbol x -> symbol x + SymengineAdd v -> V.foldl' (+) 0 v + SymengineMul v -> V.foldl' (*) 0 v + SymenginePow a b -> a ** b + SymengineLog x -> log x + SymengineSign x -> signum x + SymengineDerivative f v -> V.foldl' diff f v + SymengineFunction (T.encodeUtf8 -> s) v -> unsafePerformIO $ + withVector v $ \args -> + $(constructBasicFrom "function_symbol(std::string{$bs-ptr:s, static_cast($bs-len:s)}, *$(const Vector* args))") + +{- +-- | Convert a C string into a Haskell string properly handling unicode characters. +peekCString :: CString -> IO Text +peekCString = fmap T.decodeUtf8 . packCString + +withTempCString :: IO CString -> (CString -> IO a) -> IO a +withTempCString allocate = bracket allocate destroy + where + destroy p = [CU.exp| void { basic_str_free($(char* p)) } |] + +asciiArt :: IO Text +asciiArt = withTempCString [CU.exp| char* { ascii_art_str() } |] peekCString +-} -intToCLong :: Int -> CLong -intToCLong i = integerToCLong (toInteger i) - -basic_int_signed :: Int -> BasicSym -basic_int_signed i = unsafePerformIO $ do - iptr <- create_basic_ptr - withBasicSym iptr (\iptr -> integer_set_si_ffi iptr (intToCLong i) ) - return iptr - - -basic_from_integer :: Integer -> BasicSym -basic_from_integer i = unsafePerformIO $ do - iptr <- create_basic_ptr - withBasicSym iptr (\iptr -> integer_set_si_ffi iptr (fromInteger i)) - return iptr - --- |The `ascii_art_str` function prints SymEngine in ASCII art. --- this is useful as a sanity check -ascii_art_str :: IO String -ascii_art_str = ascii_art_str_ffi >>= peekCString - --- Unexported ffi functions------------------------ - --- |Create a basic object that represents all other objects through --- the FFI -create_basic_ptr :: IO BasicSym -create_basic_ptr = do - basic_ptr <- newArray [BasicStruct { data_ptr = nullPtr }] - basic_new_heap_ffi basic_ptr - finalized_ptr <- newForeignPtr ptr_basic_free_heap_ffi basic_ptr - return $ BasicSym { fptr = finalized_ptr } - -basic_binaryop :: (Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO ()) -> BasicSym -> BasicSym -> BasicSym -basic_binaryop f a b = unsafePerformIO $ do - s <- create_basic_ptr - withBasicSym3 s a b f - return s - -basic_unaryop :: (Ptr BasicStruct -> Ptr BasicStruct -> IO ()) -> BasicSym -> BasicSym -basic_unaryop f a = unsafePerformIO $ do - s <- create_basic_ptr - withBasicSym2 s a f - return s - - -basic_pow :: BasicSym -> BasicSym -> BasicSym -basic_pow = basic_binaryop basic_pow_ffi - --- |Create a rational number with numerator and denominator -rational :: BasicSym -> BasicSym -> BasicSym -rational = basic_binaryop rational_set_ffi - --- |Create a complex number a + b * im -complex :: BasicSym -> BasicSym -> BasicSym -complex a b = (basic_binaryop complex_set_ffi) a b - -basic_rational_from_integer :: Integer -> Integer -> BasicSym -basic_rational_from_integer i j = unsafePerformIO $ do - s <- create_basic_ptr - withBasicSym s (\s -> rational_set_si_ffi s (integerToCLong i) (integerToCLong j)) - return s - --- |Create a symbol with the given name -symbol :: String -> BasicSym -symbol name = unsafePerformIO $ do - s <- create_basic_ptr - cname <- newCString name - withBasicSym s (\s -> symbol_set_ffi s cname) - free cname - return s - --- |Differentiate an expression with respect to a symbol -diff :: BasicSym -> BasicSym -> BasicSym -diff expr symbol = (basic_binaryop basic_diff_ffi) expr symbol - -instance Show BasicSym where - show = basic_str - -instance Eq BasicSym where - (==) a b = unsafePerformIO $ do - i <- withBasicSym2 a b basic_eq_ffi - return $ i == 1 - - -instance Num BasicSym where - (+) = basic_binaryop basic_add_ffi - (-) = basic_binaryop basic_sub_ffi - (*) = basic_binaryop basic_mul_ffi - negate = basic_unaryop basic_neg_ffi - abs = basic_unaryop basic_abs_ffi - signum = undefined - fromInteger = basic_from_integer - -instance Fractional BasicSym where - (/) = basic_binaryop basic_div_ffi - fromRational (num :% denom) = basic_rational_from_integer num denom - recip r = one / r - -instance Floating BasicSym where - pi = Symengine.pi - exp x = e ** x - log = undefined - sqrt x = x ** 1/2 - (**) = basic_pow - logBase = undefined - sin = basic_unaryop basic_sin_ffi - cos = basic_unaryop basic_cos_ffi - tan = basic_unaryop basic_tan_ffi - asin = basic_unaryop basic_asin_ffi - acos = basic_unaryop basic_acos_ffi - atan = basic_unaryop basic_atan_ffi - sinh = basic_unaryop basic_sinh_ffi - cosh = basic_unaryop basic_cosh_ffi - tanh = basic_unaryop basic_tanh_ffi - asinh = basic_unaryop basic_asinh_ffi - acosh = basic_unaryop basic_acosh_ffi - atanh = basic_unaryop basic_atanh_ffi - -foreign import ccall "symengine/cwrapper.h ascii_art_str" ascii_art_str_ffi :: IO CString -foreign import ccall "symengine/cwrapper.h basic_new_heap" basic_new_heap_ffi :: Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h &basic_free_heap" ptr_basic_free_heap_ffi :: FunPtr(Ptr BasicStruct -> IO ()) - --- constants -foreign import ccall "symengine/cwrapper.h basic_const_zero" basic_const_zero_ffi :: Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_const_one" basic_const_one_ffi :: Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_const_minus_one" basic_const_minus_one_ffi :: Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_const_I" basic_const_I_ffi :: Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_const_pi" basic_const_pi_ffi :: Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_const_E" basic_const_E_ffi :: Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_const_EulerGamma" basic_const_EulerGamma_ffi :: Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_str" basic_str_ffi :: Ptr BasicStruct -> IO CString -foreign import ccall "symengine/cwrapper.h basic_eq" basic_eq_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO Int - -foreign import ccall "symengine/cwrapper.h symbol_set" symbol_set_ffi :: Ptr BasicStruct -> CString -> IO () -foreign import ccall "symengine/cwrapper.h basic_diff" basic_diff_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () - -foreign import ccall "symengine/cwrapper.h integer_set_si" integer_set_si_ffi :: Ptr BasicStruct -> CLong -> IO () - -foreign import ccall "symengine/cwrapper.h rational_set" rational_set_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h rational_set_si" rational_set_si_ffi :: Ptr BasicStruct -> CLong -> CLong -> IO () - -foreign import ccall "symengine/cwrapper.h complex_set" complex_set_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () - -foreign import ccall "symengine/cwrapper.h basic_expand" basic_expand_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () - - -foreign import ccall "symengine/cwrapper.h basic_add" basic_add_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_sub" basic_sub_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_mul" basic_mul_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_div" basic_div_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_pow" basic_pow_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_neg" basic_neg_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_abs" basic_abs_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () - -foreign import ccall "symengine/cwrapper.h basic_sin" basic_sin_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_cos" basic_cos_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_tan" basic_tan_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () - -foreign import ccall "symengine/cwrapper.h basic_asin" basic_asin_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_acos" basic_acos_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_atan" basic_atan_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () - -foreign import ccall "symengine/cwrapper.h basic_sinh" basic_sinh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_cosh" basic_cosh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_tanh" basic_tanh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () - -foreign import ccall "symengine/cwrapper.h basic_asinh" basic_asinh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_acosh" basic_acosh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_atanh" basic_atanh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- newtype BasicStruct = BasicStruct +-- { data_ptr :: Ptr () +-- } +-- +-- instance Storable BasicStruct where +-- alignment _ = 8 +-- sizeOf _ = sizeOf nullPtr +-- peek basic_ptr = BasicStruct <$> peekByteOff basic_ptr 0 +-- poke basic_ptr BasicStruct {..} = pokeByteOff basic_ptr 0 data_ptr +-- +-- -- |represents a symbol exported by SymEngine. create this using the functions +-- -- 'zero', 'one', 'minus_one', 'e', 'im', 'rational', 'complex', and also by +-- -- constructing a number and converting it to a Symbol +-- -- +-- -- >>> 3.5 :: BasicSym +-- -- 7/2 +-- -- +-- -- >>> rational 2 10 +-- -- 1 /5 +-- -- +-- -- >>> complex 1 2 +-- -- 1 + 2*I +-- data BasicSym = BasicSym {fptr :: ForeignPtr BasicStruct} +-- +-- withBasicSym :: BasicSym -> (Ptr BasicStruct -> IO a) -> IO a +-- withBasicSym p f = withForeignPtr (fptr p) f +-- +-- withBasicSym2 :: BasicSym -> BasicSym -> (Ptr BasicStruct -> Ptr BasicStruct -> IO a) -> IO a +-- withBasicSym2 p1 p2 f = withBasicSym p1 (\p1 -> withBasicSym p2 (\p2 -> f p1 p2)) +-- +-- withBasicSym3 :: BasicSym -> BasicSym -> BasicSym -> (Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO a) -> IO a +-- withBasicSym3 p1 p2 p3 f = withBasicSym p1 (\p1 -> withBasicSym p2 (\p2 -> withBasicSym p3 (\p3 -> f p1 p2 p3))) +-- +-- -- | constructor for 0 +-- zero :: BasicSym +-- zero = basic_obj_constructor basic_const_zero_ffi +-- +-- -- | constructor for 1 +-- one :: BasicSym +-- one = basic_obj_constructor basic_const_one_ffi +-- +-- -- | constructor for -1 +-- minus_one :: BasicSym +-- minus_one = basic_obj_constructor basic_const_minus_one_ffi +-- +-- -- | constructor for i = sqrt(-1) +-- im :: BasicSym +-- im = basic_obj_constructor basic_const_I_ffi +-- +-- -- | the ratio of the circumference of a circle to its radius +-- pi :: BasicSym +-- pi = basic_obj_constructor basic_const_pi_ffi +-- +-- -- | The base of the natural logarithm +-- e :: BasicSym +-- e = basic_obj_constructor basic_const_E_ffi +-- +-- expand :: BasicSym -> BasicSym +-- expand = basic_unaryop basic_expand_ffi +-- +-- eulerGamma :: BasicSym +-- eulerGamma = basic_obj_constructor basic_const_EulerGamma_ffi +-- +-- basic_obj_constructor :: (Ptr BasicStruct -> IO ()) -> BasicSym +-- basic_obj_constructor init_fn = unsafePerformIO $ do +-- basic_ptr <- create_basic_ptr +-- withBasicSym basic_ptr init_fn +-- return basic_ptr +-- +-- basic_str :: BasicSym -> String +-- basic_str basic_ptr = unsafePerformIO $ withBasicSym basic_ptr (basic_str_ffi >=> peekCString) +-- +-- integerToCLong :: Integer -> CLong +-- integerToCLong i = CLong (fromInteger i) +-- +-- intToCLong :: Int -> CLong +-- intToCLong i = integerToCLong (toInteger i) +-- +-- basic_int_signed :: Int -> BasicSym +-- basic_int_signed i = unsafePerformIO $ do +-- iptr <- create_basic_ptr +-- withBasicSym iptr (\iptr -> integer_set_si_ffi iptr (intToCLong i)) +-- return iptr +-- +-- basic_from_integer :: Integer -> BasicSym +-- basic_from_integer i = unsafePerformIO $ do +-- iptr <- create_basic_ptr +-- withBasicSym iptr (\iptr -> integer_set_si_ffi iptr (fromInteger i)) +-- return iptr +-- +-- -- |The `ascii_art_str` function prints SymEngine in ASCII art. +-- -- this is useful as a sanity check +-- ascii_art_str :: IO String +-- ascii_art_str = ascii_art_str_ffi >>= peekCString +-- +-- -- Unexported ffi functions------------------------ +-- +-- -- |Create a basic object that represents all other objects through +-- -- the FFI +-- create_basic_ptr :: IO BasicSym +-- create_basic_ptr = do +-- basic_ptr <- newArray [BasicStruct {data_ptr = nullPtr}] +-- basic_new_heap_ffi basic_ptr +-- finalized_ptr <- newForeignPtr ptr_basic_free_heap_ffi basic_ptr +-- return $ BasicSym {fptr = finalized_ptr} +-- +-- basic_binaryop :: (Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO ()) -> BasicSym -> BasicSym -> BasicSym +-- basic_binaryop f a b = unsafePerformIO $ do +-- s <- create_basic_ptr +-- withBasicSym3 s a b f +-- return s +-- +-- basic_unaryop :: (Ptr BasicStruct -> Ptr BasicStruct -> IO ()) -> BasicSym -> BasicSym +-- basic_unaryop f a = unsafePerformIO $ do +-- s <- create_basic_ptr +-- withBasicSym2 s a f +-- return s +-- +-- basic_pow :: BasicSym -> BasicSym -> BasicSym +-- basic_pow = basic_binaryop basic_pow_ffi +-- +-- -- |Create a rational number with numerator and denominator +-- rational :: BasicSym -> BasicSym -> BasicSym +-- rational = basic_binaryop rational_set_ffi +-- +-- -- |Create a complex number a + b * im +-- complex :: BasicSym -> BasicSym -> BasicSym +-- complex a b = (basic_binaryop complex_set_ffi) a b +-- +-- basic_rational_from_integer :: Integer -> Integer -> BasicSym +-- basic_rational_from_integer i j = unsafePerformIO $ do +-- s <- create_basic_ptr +-- withBasicSym s (\s -> rational_set_si_ffi s (integerToCLong i) (integerToCLong j)) +-- return s +-- +-- -- |Create a symbol with the given name +-- symbol :: String -> BasicSym +-- symbol name = unsafePerformIO $ do +-- s <- create_basic_ptr +-- cname <- newCString name +-- withBasicSym s (\s -> symbol_set_ffi s cname) +-- free cname +-- return s +-- +-- -- |Differentiate an expression with respect to a symbol +-- diff :: BasicSym -> BasicSym -> BasicSym +-- diff expr symbol = (basic_binaryop basic_diff_ffi) expr symbol +-- +-- instance Show BasicSym where +-- show = basic_str +-- +-- instance Eq BasicSym where +-- (==) a b = unsafePerformIO $ do +-- i <- withBasicSym2 a b basic_eq_ffi +-- return $ i == 1 +-- +-- instance Num BasicSym where +-- (+) = basic_binaryop basic_add_ffi +-- (-) = basic_binaryop basic_sub_ffi +-- (*) = basic_binaryop basic_mul_ffi +-- negate = basic_unaryop basic_neg_ffi +-- abs = basic_unaryop basic_abs_ffi +-- signum = undefined +-- fromInteger = basic_from_integer +-- +-- instance Fractional BasicSym where +-- (/) = basic_binaryop basic_div_ffi +-- fromRational (num :% denom) = basic_rational_from_integer num denom +-- recip r = one / r +-- +-- instance Floating BasicSym where +-- pi = Symengine.pi +-- exp x = e ** x +-- log = undefined +-- sqrt x = x ** 1 / 2 +-- (**) = basic_pow +-- logBase = undefined +-- sin = basic_unaryop basic_sin_ffi +-- cos = basic_unaryop basic_cos_ffi +-- tan = basic_unaryop basic_tan_ffi +-- asin = basic_unaryop basic_asin_ffi +-- acos = basic_unaryop basic_acos_ffi +-- atan = basic_unaryop basic_atan_ffi +-- sinh = basic_unaryop basic_sinh_ffi +-- cosh = basic_unaryop basic_cosh_ffi +-- tanh = basic_unaryop basic_tanh_ffi +-- asinh = basic_unaryop basic_asinh_ffi +-- acosh = basic_unaryop basic_acosh_ffi +-- atanh = basic_unaryop basic_atanh_ffi +-- +-- foreign import ccall "symengine/cwrapper.h ascii_art_str" ascii_art_str_ffi :: IO CString +-- foreign import ccall "symengine/cwrapper.h basic_new_heap" basic_new_heap_ffi :: Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h &basic_free_heap" ptr_basic_free_heap_ffi :: FunPtr (Ptr BasicStruct -> IO ()) +-- +-- -- constants +-- foreign import ccall "symengine/cwrapper.h basic_const_zero" basic_const_zero_ffi :: Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_const_one" basic_const_one_ffi :: Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_const_minus_one" basic_const_minus_one_ffi :: Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_const_I" basic_const_I_ffi :: Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_const_pi" basic_const_pi_ffi :: Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_const_E" basic_const_E_ffi :: Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_const_EulerGamma" basic_const_EulerGamma_ffi :: Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_str" basic_str_ffi :: Ptr BasicStruct -> IO CString +-- foreign import ccall "symengine/cwrapper.h basic_eq" basic_eq_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO Int +-- +-- foreign import ccall "symengine/cwrapper.h symbol_set" symbol_set_ffi :: Ptr BasicStruct -> CString -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_diff" basic_diff_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h integer_set_si" integer_set_si_ffi :: Ptr BasicStruct -> CLong -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h rational_set" rational_set_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h rational_set_si" rational_set_si_ffi :: Ptr BasicStruct -> CLong -> CLong -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h complex_set" complex_set_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h basic_expand" basic_expand_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h basic_add" basic_add_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_sub" basic_sub_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_mul" basic_mul_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_div" basic_div_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_pow" basic_pow_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_neg" basic_neg_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_abs" basic_abs_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h basic_sin" basic_sin_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_cos" basic_cos_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_tan" basic_tan_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h basic_asin" basic_asin_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_acos" basic_acos_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_atan" basic_atan_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h basic_sinh" basic_sinh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_cosh" basic_cosh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_tanh" basic_tanh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h basic_asinh" basic_asinh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_acosh" basic_acosh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_atanh" basic_atanh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () diff --git a/src/Symengine/Context.hs b/src/Symengine/Context.hs new file mode 100644 index 0000000..4958b66 --- /dev/null +++ b/src/Symengine/Context.hs @@ -0,0 +1,89 @@ +{-# LANGUAGE OverloadedStrings #-} + +-- | +-- Module : Symengine.Context +-- Description : Helpers to setup inline-c for Symengine +-- Copyright : (c) Tom Westerhout, 2023 +-- +-- This module defines a Template Haskell function 'importSymengine' that sets up everything you need +-- to call SymEngine functions from 'Language.C.Inline' quasiquotes. +module Symengine.Context + ( importSymengine + ) +where + +import Data.Map.Strict qualified as Map +import Language.C.Inline qualified as C +import Language.C.Inline.Context (Context (ctxTypesTable)) +import Language.C.Inline.Cpp qualified as Cpp +import Language.C.Types (CIdentifier, TypeSpecifier (..)) +import Language.Haskell.TH (DecsQ, Q, TypeQ, lookupTypeName) +import Language.Haskell.TH.Syntax (Type (..)) + +-- | One stop function to include all the neccessary machinery to call SymEngine functions via +-- inline-c. +-- +-- Put @importSymengine@ somewhere at the beginning of the file and enjoy using the C interface of +-- SymEngine via inline-c quasiquotes. +importSymengine :: DecsQ +importSymengine = + concat + <$> sequence + [ C.context =<< symengineCxt + , C.include "" + , C.include "" + , C.include "" + , C.include "" + , C.include "" + , C.include "" + , C.include "" + , C.include "" + , C.include "" + , C.include "" + , C.include "" + , C.include "" + , defineCxxUtils + ] + +symengineCxt :: Q C.Context +symengineCxt = do + typePairs <- Map.fromList <$> symengineTypePairs + pure $ + C.funCtx <> C.fptrCtx <> C.bsCtx <> Cpp.cppCtx <> C.baseCtx <> mempty {ctxTypesTable = typePairs} + +symengineTypePairs :: Q [(TypeSpecifier, TypeQ)] +symengineTypePairs = + optionals + [ ("Object", "CxxBasic") + , ("Vector", "Vector Basic") + , ("DenseMatrix", "DenseMatrix Basic") + , ("integer_class", "CxxInteger") + , ("std::string", "CxxString") + , ("map_basic_basic", "CxxMapBasicBasic") + ] + where + optional :: (CIdentifier, String) -> Q [(TypeSpecifier, TypeQ)] + optional (cName, hsName) = do + hsType <- case words hsName of + [x] -> fmap ConT <$> lookupTypeName x + -- TODO: generalize to multiple arguments + [f, x] -> do + con <- fmap ConT <$> lookupTypeName f + arg <- fmap ConT <$> lookupTypeName x + pure $ AppT <$> con <*> arg + _ -> pure Nothing + pure $ maybe [] (\x -> [(TypeName cName, pure x)]) hsType + optionals :: [(CIdentifier, String)] -> Q [(TypeSpecifier, TypeQ)] + optionals pairs = concat <$> mapM optional pairs + +defineCxxUtils :: DecsQ +defineCxxUtils = + C.verbatim + "\ + \using Object = SymEngine::RCP; \n\ + \using Vector = SymEngine::vec_basic; \n\ + \using namespace SymEngine; \n\ + \ \n\ + \#define CONSTRUCT_BASIC(dest, expr) new (dest) Object{expr} \n\ + \ \n\ + \" diff --git a/src/Symengine/Internal.hs b/src/Symengine/Internal.hs new file mode 100644 index 0000000..bd2fdf4 --- /dev/null +++ b/src/Symengine/Internal.hs @@ -0,0 +1,124 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskellQuotes #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -Wno-unused-matches #-} + +-- | +-- Module : Symengine.Internal +-- Description : Symengine bindings to Haskell +module Symengine.Internal + ( constructBasicFrom + , constructStringFrom + , createDenseMatrixVia + , mkUnaryFunction + , mkBinaryFunction + , unpackFunction + ) where + +import Control.Exception (bracket_) +import Data.ByteString (packCString) +import Data.Text.Encoding qualified as T +import Foreign.Marshal (allocaBytes) +import Language.C.Inline qualified as C +import Language.C.Inline.Unsafe qualified as CU +import Language.Haskell.TH (Exp, Q) +import System.IO.Unsafe +import Language.C.Inline.Cpp.Exception qualified as C + +constructBasicFrom :: String -> Q Exp +constructBasicFrom expr = + C.substitute + [("expr", const expr)] + [| + constructBasic $ \dest -> + [C.throwBlock| void { + using namespace SymEngine; + new ($(Object* dest)) Object{@expr()}; + } |] + |] + +constructStringFrom :: String -> Q Exp +constructStringFrom expr = + C.substitute + [("expr", const expr)] + [| + let size = fromIntegral [CU.pure| size_t { sizeof(std::string) } |] + construct s = + [CU.block| void { + using namespace SymEngine; + new ($(std::string* s)) std::string{@expr()}; + } |] + destruct s = [CU.exp| void { $(std::string* s)->~basic_string() } |] + in allocaBytes size $ \s -> + bracket_ (construct s) (destruct s) $ + fmap T.decodeUtf8 $ + packCString + =<< [CU.exp| char const* { $(const std::string* s)->c_str() } |] + |] + +createDenseMatrixVia :: String -> Q Exp +createDenseMatrixVia expr = + C.substitute + [("expr", const expr)] + [| + allocaDenseMatrix 0 0 $ \ptr -> do + [CU.block| void { + auto& out = *$(DenseMatrix* ptr); + @expr() + } |] + peekDenseMatrix ptr + |] + +mkUnaryFunction :: String -> Q Exp +mkUnaryFunction expr = + C.substitute + [ ("expr", const expr) + ] + [| + \a' -> + unsafePerformIO $ + withBasic a' $ \a -> + constructBasic $ \dest -> + [CU.block| void { + using namespace SymEngine; + auto const& a = *$(const Object* a); + new ($(Object* dest)) Object{@expr()}; + } |] + |] + +mkBinaryFunction :: String -> Q Exp +mkBinaryFunction expr = + C.substitute + [ ("expr", const expr) + ] + [| + \a' b' -> + unsafePerformIO $ + withBasic a' $ \a -> + withBasic b' $ \b -> + constructBasic $ \dest -> + [CU.block| void { + using namespace SymEngine; + auto const& a = *$(const Object* a); + auto const& b = *$(const Object* b); + new ($(Object* dest)) Object{@expr()}; + } |] + |] + +unpackFunction :: String -> Q Exp +unpackFunction className = + C.substitute + [ ("class", const className) + ] + [| + \f' -> + withBasic f' $ \f -> + allocaVector $ \v -> do + [CU.block| void { + using namespace SymEngine; + auto const& f = down_cast<@class() const&>(**$(const Object* f)); + *$(Vector* v) = f.get_args(); + } |] + peekVector v + |] diff --git a/stack.yaml b/stack.yaml deleted file mode 100644 index 7b5a9de..0000000 --- a/stack.yaml +++ /dev/null @@ -1,29 +0,0 @@ -# For more information, see: https://github.com/commercialhaskell/stack/blob/master/doc/yaml_configuration.md - -# Specifies the GHC version and set of packages available (e.g., lts-3.5, nightly-2015-09-21, ghc-7.10.2) -resolver: lts-3.2 - -# Local packages, usually specified by relative directory name -packages: -- '.' - -# Packages to be pulled from upstream that are not in the resolver (e.g., acme-missiles-0.3) -extra-deps: [] - -# Override default flag values for local packages and extra-deps -flags: {} - -# Control whether we use the GHC we find on the path -# system-ghc: true - -# Require a specific version of stack, using version ranges -# require-stack-version: -any # Default -# require-stack-version: >= 0.1.4.0 - -# Override the architecture used by stack, especially useful on Windows -# arch: i386 -# arch: x86_64 - -# Extra directories used by stack for building -# extra-include-dirs: [/path/to/dir] -# extra-lib-dirs: [/path/to/dir] diff --git a/symengine.cabal b/symengine.cabal index 0f33da5..ebe9573 100644 --- a/symengine.cabal +++ b/symengine.cabal @@ -1,40 +1,91 @@ -name: symengine -version: 0.1.2.0 -synopsis: SymEngine symbolic mathematics engine for Haskell -description: Please see README.md -homepage: http://github.com/symengine/symengine.hs#readme -license: MIT -license-file: LICENSE -author: Siddharth Bhat -maintainer: siddu.druid@gmail.com -copyright: 2016 Siddharth Bhat -category: FFI, Math, Symbolic Computation -build-type: Simple --- extra-source-files: -cabal-version: >=1.10 +cabal-version: 3.0 +name: symengine +version: 0.2.0.0 +synopsis: SymEngine symbolic mathematics engine for Haskell +description: Please see README.md +homepage: https://github.com/symengine/symengine.hs +license: MIT +license-file: LICENSE +author: Siddharth Bhat +maintainer: siddu.druid@gmail.com +copyright: + 2016 Siddharth Bhat + 2023 Tom Westerhout + +category: FFI, Math, Symbolic Computation +build-type: Simple +tested-with: GHC ==9.2.7 + +flag no-flint + description: disable linking with Flint + manual: True + default: False + +flag no-mpc + description: disable linking with MPC + manual: True + default: False + +flag no-mpfr + description: disable linking with MPFR + manual: True + default: False + +common common-options + build-depends: base >=4.16.0.0 && <5 + ghc-options: + -Weverything -Wno-unsafe -Wno-all-missed-specialisations + -Wno-missing-safe-haskell-mode -Wno-implicit-prelude + -Wno-missing-import-lists -Wno-missing-kind-signatures + -Wno-monomorphism-restriction + + default-language: GHC2021 + default-extensions: DerivingStrategies library - hs-source-dirs: src - exposed-modules: Symengine - build-depends: base >= 4.5.0 && <= 5 - default-language: Haskell2010 + import: common-options + hs-source-dirs: src + exposed-modules: Symengine + other-modules: + Symengine.Context + Symengine.Internal + + build-depends: + , bytestring + , containers + , ghc-bignum + , hashable + , inline-c + , inline-c-cpp + , template-haskell + , text + , vector + + extra-libraries: symengine + + if !flag(no-flint) + extra-libraries: flint + + if !flag(no-mpc) + extra-libraries: mpc + + if !flag(no-mpfr) + extra-libraries: mpfr + + if os(linux) + extra-libraries: stdc++ + + if os(osx) + extra-libraries: c++ test-suite symengine-test - type: exitcode-stdio-1.0 - hs-source-dirs: test, src - main-is: Spec.hs - build-depends: base >= 4.5.0 && <= 5 - , symengine >= 0.1.1 && <= 0.2 - , tasty >= 0.10.0 && <= 0.13 - , tasty-hunit >= 0.9.0 && <= 1.5 - , tasty-quickcheck >= 0.8.0 && <= 1.5 - ghc-options: -threaded -rtsopts -with-rtsopts=-N - extra-libraries: symengine stdc++ gmpxx gmp - - other-modules: Symengine - - default-language: Haskell2010 - -source-repository head - type: git - location: https://github.com/symengine/symengine.hs + import: common-options + type: exitcode-stdio-1.0 + hs-source-dirs: test + main-is: Spec.hs + build-depends: + , hspec + , symengine + , text + + ghc-options: -threaded -rtsopts -with-rtsopts=-N diff --git a/test/Spec.hs b/test/Spec.hs index e934667..513fe2d 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -1,53 +1,69 @@ -import Test.Tasty -import Test.Tasty.QuickCheck as QC -import Test.Tasty.HUnit as HU +{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE OverloadedStrings #-} -import Data.List -import Data.Ord -import Data.Monoid +import Control.Monad (unless) +import Data.Ratio +import Data.Text (pack) +import Symengine +import Test.Hspec +import Test.Hspec.QuickCheck -import Symengine as Sym -import Prelude hiding (pi) +main :: IO () +main = hspec $ do + describe "Num" $ do + prop "Integer" $ \(a :: Integer) (b :: Integer) -> do + show (fromIntegral @_ @Basic a) `shouldBe` show a + show (fromIntegral @_ @Basic a + fromIntegral b) `shouldBe` show (a + b) + show (fromIntegral @_ @Basic a - fromIntegral b) `shouldBe` show (a - b) + show (fromIntegral @_ @Basic a * fromIntegral b) `shouldBe` show (a * b) + show (abs (fromIntegral @_ @Basic a)) `shouldBe` show (abs a) + show (negate (fromIntegral @_ @Basic a)) `shouldBe` show (negate a) + show (signum (fromIntegral @_ @Basic a)) `shouldBe` show (signum a) + describe "AST" $ do + prop "SymengineInteger" $ \(x :: Integer) -> do + toAST (fromInteger x) `shouldBe` SymengineInteger x + prop "SymengineRational" $ \(x :: Rational) -> do + if denominator x == 1 + then toAST (fromRational x) `shouldBe` SymengineInteger (numerator x) + else toAST (fromRational x) `shouldBe` SymengineRational x + it "SymengineConstant" $ do + toAST (pi :: Basic) `shouldBe` SymengineConstant pi + toAST e `shouldBe` SymengineConstant e + prop "SymengineSymbol" $ \(x :: String) -> do + unless ('\NUL' `elem` x) $ + toAST (symbol (pack x)) `shouldBe` SymengineSymbol (pack x) + it "SymengineInfinity" $ do + toAST infinity `shouldBe` SymengineInfinity + it "SymengineNaN" $ do + toAST nan `shouldBe` SymengineNaN + it "SymengineAdd" $ do + toAST (symbol "x" + symbol "z" + symbol "y") `shouldBe` SymengineAdd [symbol "x", symbol "z", symbol "y"] + toAST (symbol "x" + symbol "y") `shouldBe` SymengineAdd [symbol "x", symbol "y"] + it "SymengineMul" $ do + toAST (2 * symbol "y") `shouldBe` SymengineMul [2, symbol "y"] + toAST (-symbol "x") `shouldBe` SymengineMul [-1, symbol "x"] + it "SymenginePow" $ do + toAST (sqrt (symbol "y")) `shouldBe` SymenginePow (symbol "y") 0.5 + toAST (exp (symbol "y")) `shouldBe` SymenginePow e (symbol "y") + it "SymengineLog" $ do + toAST (log (symbol "y")) `shouldBe` SymengineLog (symbol "y") + it "SymengineSign" $ do + toAST (signum (symbol "y")) `shouldBe` SymengineSign (symbol "y") + it "SymengineFunction" $ do + toAST (parse "f(1, x, y + 2)") `shouldBe` SymengineFunction "f" [1, symbol "x", 2 + symbol "y"] + show (fromAST (SymengineFunction "f" [1, symbol "x", 2 + symbol "y"])) `shouldBe` "f(1, x, 2 + y)" + it "SymengineDerivative" $ do + toAST (diff (parse "f(x)") (symbol "x")) `shouldBe` SymengineDerivative (parse "f(x)") [symbol "x"] + toAST (diff (symbol "x" ** 2) (symbol "x")) `shouldBe` SymengineMul [2, symbol "x"] -main = defaultMain tests - -tests :: TestTree -tests = testGroup "Tests" [unitTests] - - --- These are used to check invariants that can be tested by creating --- random members of the type and then checking invariants on them - --- properties :: TestTree --- properties = testGroup "Properties" [qcProps] - -unitTests = testGroup "Unit tests" - [ HU.testCase "FFI Sanity Check - ASCII Art should be non-empty" $ - do - ascii_art <- Sym.ascii_art_str - HU.assertBool "ASCII art from ascii_art_str is empty" (not . null $ ascii_art) - - - , HU.testCase "Basic Constructors" $ - do - "0" @?= (show zero) - "1" @?= (show one) - "-1" @?= (show minus_one) - , HU.testCase "Basic Trignometric Functions" $ - do - let pi_over_3 = pi / 3 :: BasicSym - let pi_over_2 = pi / 2 :: BasicSym - - sin zero @?= zero - cos zero @?= one - - sin (pi / 6) @?= 1 / 2 - sin (pi / 3) @?= (3 ** (1/2)) / 2 - - cos (pi / 6) @?= (3 ** (1/2)) / 2 - cos (pi / 3) @?= 1 / 2 - - sin pi_over_2 @?= one - cos pi_over_2 @?= zero - - ] + describe "subs" $ do + it "" $ do + subs [("x", 1)] "a + f(x) / x" `shouldBe` "a + f(1)" + subs [("k", "c")] "a + b" `shouldBe` "a + b" + subs [] "a + b" `shouldBe` "a + b" + subs [("a + b", "c")] "a + b" `shouldBe` "c" + describe "Misc" $ do + it "" $ do + print $ parse "a + f(x) / x - 4**2" + print $ evalf EvalSymbolic 20 $ parse "a + 8/3 * f(x) / x - 4**2" + print $ inverse InverseDefault (identityMatrix 3)