diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 58460db..f78a92c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,6 +37,8 @@ jobs: java-version: 17 distribution: temurin - uses: gradle/gradle-build-action@v3 + - name: Install roboRIO toolchain + run: ./gradlew installRoboRioToolchain - uses: actions/setup-python@v5 with: python-version: 3.8 diff --git a/.styleguide b/.styleguide index 18b500a..03ca580 100644 --- a/.styleguide +++ b/.styleguide @@ -16,4 +16,5 @@ modifiableFileExclude { gradlew.bat src/main/java/com src/main/java/frc/robot/URCL.java + sleipnir/ } diff --git a/build.gradle b/build.gradle index 14d168a..4decb3b 100644 --- a/build.gradle +++ b/build.gradle @@ -1,6 +1,8 @@ plugins { id "java" + id "cpp" id "edu.wpi.first.GradleRIO" version "2024.3.2" + id 'edu.wpi.first.GradleJni' version '0.10.1' id 'com.github.spotbugs' version '6.0.18' apply false id 'com.diffplug.spotless' version '6.25.0' apply false } @@ -12,6 +14,15 @@ java { def ROBOT_MAIN_CLASS = "frc.robot.Main" +// Set up exports properly +nativeUtils { + exportsConfigs { + // Only export explicit symbols from driver library + ShooterTrajoptJNI { + } + } +} + // Define my targets (RoboRIO) and artifacts (deployable files) // This is added by GradleRIO's backing project DeployUtils. deploy { @@ -88,6 +99,37 @@ if (!project.hasProperty('noSimGUI')) { wpi.sim.addDriverstation() +model { + components { + ShooterTrajoptJNI(JniNativeLibrarySpec) { + targetPlatform wpi.platforms.roborio + + enableCheckTask true + javaCompileTasks << compileJava + jniCrossCompileOptions << JniCrossCompileOptions(wpi.platforms.roborio) + + sources { + cpp { + source { + srcDirs 'src/main/native/cpp' + include '**/*.cpp' + } + + exportedHeaders { + srcDir 'src/main/native/include' + } + } + } + + binaries.all { + lib project: ':sleipnir', library: 'Sleipnir', linkage: 'shared' + } + + nativeUtils.useRequiredLibrary(it, "driver_shared") + } + } +} + // Setting up my Jar File. In this case, adding all libraries into the main jar ('fat jar') // in order to make them all available at runtime. Also adding the manifest so WPILib // knows where to look for our Robot Class. diff --git a/settings.gradle b/settings.gradle index d94f73c..a7c0f3e 100644 --- a/settings.gradle +++ b/settings.gradle @@ -28,3 +28,6 @@ pluginManagement { Properties props = System.getProperties(); props.setProperty("org.gradle.internal.native.headers.unresolved.dependencies.ignore", "true"); + +include 'sleipnir' +rootProject.name = '2024-Offseason' diff --git a/sleipnir/LICENSE_small_vector.txt b/sleipnir/LICENSE_small_vector.txt new file mode 100644 index 0000000..4cd586a --- /dev/null +++ b/sleipnir/LICENSE_small_vector.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Gene Harvey + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/sleipnir/SLEIPNIR_LICENSE.txt b/sleipnir/SLEIPNIR_LICENSE.txt new file mode 100644 index 0000000..91bc360 --- /dev/null +++ b/sleipnir/SLEIPNIR_LICENSE.txt @@ -0,0 +1,11 @@ +Copyright (c) Sleipnir contributors + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/sleipnir/build.gradle b/sleipnir/build.gradle new file mode 100644 index 0000000..ef4de5f --- /dev/null +++ b/sleipnir/build.gradle @@ -0,0 +1,31 @@ +import edu.wpi.first.toolchain.NativePlatforms + +apply plugin: 'cpp' +apply plugin: 'edu.wpi.first.GradleRIO' + +model { + components { + Sleipnir(NativeLibrarySpec) { + targetPlatform NativePlatforms.desktop + targetPlatform NativePlatforms.roborio + + sources { + cpp { + source { + srcDir 'src/cpp' + include '**/*.cpp', '**/*.cc' + } + + exportedHeaders { + srcDirs 'src/include', 'src/cpp/' + } + } + } + binaries.all { + cppCompiler.args "-I${file('src/main/cpp').absolutePath}" + } + + wpi.cpp.deps.wpilib(it) + } + } +} diff --git a/sleipnir/src/cpp/autodiff/VariableMatrix.cpp b/sleipnir/src/cpp/autodiff/VariableMatrix.cpp new file mode 100644 index 0000000..811062a --- /dev/null +++ b/sleipnir/src/cpp/autodiff/VariableMatrix.cpp @@ -0,0 +1,115 @@ +// Copyright (c) Sleipnir contributors + +#include "sleipnir/autodiff/VariableMatrix.hpp" + +#include + +namespace Eigen { + +template <> +struct NumTraits : NumTraits { + using Real = sleipnir::Variable; + using NonInteger = sleipnir::Variable; + using Nested = sleipnir::Variable; + + enum { + IsComplex = 0, + IsInteger = 0, + IsSigned = 1, + RequireInitialization = 1, + ReadCost = 1, + AddCost = 3, + MulCost = 3 + }; +}; + +} // namespace Eigen + +// For Variable equality operator +#include "sleipnir/optimization/Constraints.hpp" + +namespace sleipnir { + +VariableMatrix Solve(const VariableMatrix& A, const VariableMatrix& B) { + // m x n * n x p = m x p + Assert(A.Rows() == B.Rows()); + + if (A.Rows() == 1 && A.Cols() == 1) { + // Compute optimal inverse instead of using Eigen's general solver + return B(0, 0) / A(0, 0); + } else if (A.Rows() == 2 && A.Cols() == 2) { + // Compute optimal inverse instead of using Eigen's general solver + // + // [a b]⁻¹ ___1___ [ d −b] + // [c d] = ad − bc [−c a] + + const auto& a = A(0, 0); + const auto& b = A(0, 1); + const auto& c = A(1, 0); + const auto& d = A(1, 1); + + sleipnir::VariableMatrix Ainv{{d, -b}, {-c, a}}; + auto detA = a * d - b * c; + Ainv /= detA; + + return Ainv * B; + } else if (A.Rows() == 3 && A.Cols() == 3) { + // Compute optimal inverse instead of using Eigen's general solver + // + // [a b c]⁻¹ + // [d e f] + // [g h i] + // 1 [ei − fh ch − bi bf − ce] + // = --------------------------------- [fg − di ai − cg cd − af] + // aei − afh − bdi + bfg + cdh − ceg [dh − eg bg − ah ae − bd] + + const auto& a = A(0, 0); + const auto& b = A(0, 1); + const auto& c = A(0, 2); + const auto& d = A(1, 0); + const auto& e = A(1, 1); + const auto& f = A(1, 2); + const auto& g = A(2, 0); + const auto& h = A(2, 1); + const auto& i = A(2, 2); + + sleipnir::VariableMatrix Ainv{ + {e * i - f * h, c * h - b * i, b * f - c * e}, + {f * g - d * i, a * i - c * g, c * d - a * f}, + {d * h - e * g, b * g - a * h, a * e - b * d}}; + auto detA = + a * e * i - a * f * h - b * d * i + b * f * g + c * d * h - c * e * g; + Ainv /= detA; + + return Ainv * B; + } else { + using MatrixXv = Eigen::Matrix; + + MatrixXv eigenA{A.Rows(), A.Cols()}; + for (int row = 0; row < A.Rows(); ++row) { + for (int col = 0; col < A.Cols(); ++col) { + eigenA(row, col) = A(row, col); + } + } + + MatrixXv eigenB{B.Rows(), B.Cols()}; + for (int row = 0; row < B.Rows(); ++row) { + for (int col = 0; col < B.Cols(); ++col) { + eigenB(row, col) = B(row, col); + } + } + + MatrixXv eigenX = eigenA.householderQr().solve(eigenB); + + VariableMatrix X{A.Cols(), B.Cols()}; + for (int row = 0; row < X.Rows(); ++row) { + for (int col = 0; col < X.Cols(); ++col) { + X(row, col) = eigenX(row, col); + } + } + + return X; + } +} + +} // namespace sleipnir diff --git a/sleipnir/src/cpp/optimization/Inertia.hpp b/sleipnir/src/cpp/optimization/Inertia.hpp new file mode 100644 index 0000000..8ef2ebe --- /dev/null +++ b/sleipnir/src/cpp/optimization/Inertia.hpp @@ -0,0 +1,58 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include + +#include "sleipnir/util/Concepts.hpp" + +namespace sleipnir { + +/** + * Represents the inertia of a matrix (the number of positive, negative, and + * zero eigenvalues). + */ +class Inertia { + public: + size_t positive = 0; + size_t negative = 0; + size_t zero = 0; + + constexpr Inertia() = default; + + /** + * Constructs the Inertia type with the given number of positive, negative, + * and zero eigenvalues. + * + * @param positive The number of positive eigenvalues. + * @param negative The number of negative eigenvalues. + * @param zero The number of zero eigenvalues. + */ + constexpr Inertia(size_t positive, size_t negative, size_t zero) + : positive{positive}, negative{negative}, zero{zero} {} + + /** + * Constructs the Inertia type with the inertia of the given LDLT + * decomposition. + * + * @tparam Solver Eigen sparse linear system solver. + * @param solver The LDLT decomposition of which to compute the inertia. + */ + template + explicit Inertia(const Solver& solver) { + const auto& D = solver.vectorD(); + for (int row = 0; row < D.rows(); ++row) { + if (D(row) > 0.0) { + ++positive; + } else if (D(row) < 0.0) { + ++negative; + } else { + ++zero; + } + } + } + + bool operator==(const Inertia&) const = default; +}; + +} // namespace sleipnir diff --git a/sleipnir/src/cpp/optimization/RegularizedLDLT.hpp b/sleipnir/src/cpp/optimization/RegularizedLDLT.hpp new file mode 100644 index 0000000..d2488b7 --- /dev/null +++ b/sleipnir/src/cpp/optimization/RegularizedLDLT.hpp @@ -0,0 +1,179 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include +#include + +#include +#include +#include + +#include "optimization/Inertia.hpp" + +// See docs/algorithms.md#Works_cited for citation definitions + +namespace sleipnir { + +/** + * Solves systems of linear equations using a regularized LDLT factorization. + */ +class RegularizedLDLT { + public: + using Solver = Eigen::SimplicialLDLT, + Eigen::Lower, Eigen::AMDOrdering>; + + /** + * Constructs a RegularizedLDLT instance. + */ + RegularizedLDLT() = default; + + /** + * Reports whether previous computation was successful. + */ + Eigen::ComputationInfo Info() { return m_info; } + + /** + * Computes the regularized LDLT factorization of a matrix. + * + * @param lhs Left-hand side of the system. + * @param numEqualityConstraints The number of equality constraints in the + * system. + * @param μ The barrier parameter for the current interior-point iteration. + */ + void Compute(const Eigen::SparseMatrix& lhs, + size_t numEqualityConstraints, double μ) { + // The regularization procedure is based on algorithm B.1 of [1] + m_numDecisionVariables = lhs.rows() - numEqualityConstraints; + m_numEqualityConstraints = numEqualityConstraints; + + const Inertia idealInertia{m_numDecisionVariables, m_numEqualityConstraints, + 0}; + Inertia inertia; + + double δ = 0.0; + double γ = 0.0; + + AnalyzePattern(lhs); + m_solver.factorize(lhs); + + if (m_solver.info() == Eigen::Success) { + inertia = Inertia{m_solver}; + + // If the inertia is ideal, don't regularize the system + if (inertia == idealInertia) { + m_info = Eigen::Success; + return; + } + } + + // If the decomposition succeeded and the inertia has some zero eigenvalues, + // or the decomposition failed, regularize the equality constraints + if ((m_solver.info() == Eigen::Success && inertia.zero > 0) || + m_solver.info() != Eigen::Success) { + γ = 1e-8 * std::pow(μ, 0.25); + } + + // Also regularize the Hessian. If the Hessian wasn't regularized in a + // previous run of Compute(), start at a small value of δ. Otherwise, + // attempt a δ half as big as the previous run so δ can trend downwards over + // time. + if (m_δOld == 0.0) { + δ = 1e-4; + } else { + δ = m_δOld / 2.0; + } + + while (true) { + // Regularize lhs by adding a multiple of the identity matrix + // + // lhs = [H + AᵢᵀΣAᵢ + δI Aₑᵀ] + // [ Aₑ −γI ] + Eigen::SparseMatrix lhsReg = lhs + Regularization(δ, γ); + AnalyzePattern(lhsReg); + m_solver.factorize(lhsReg); + inertia = Inertia{m_solver}; + + // If the inertia is ideal, store that value of δ and return. + // Otherwise, increase δ by an order of magnitude and try again. + if (inertia == idealInertia) { + m_δOld = δ; + m_info = Eigen::Success; + return; + } else { + δ *= 10.0; + + // If the Hessian perturbation is too high, report failure. This can + // happen due to a rank-deficient equality constraint Jacobian with + // linearly dependent constraints. + if (δ > 1e20) { + m_info = Eigen::NumericalIssue; + return; + } + } + } + } + + /** + * Solve the system of equations using a regularized LDLT factorization. + * + * @param rhs Right-hand side of the system. + */ + template + auto Solve(const Eigen::MatrixBase& rhs) { + return m_solver.solve(rhs); + } + + private: + Solver m_solver; + + Eigen::ComputationInfo m_info = Eigen::Success; + + /// The number of decision variables in the system. + size_t m_numDecisionVariables = 0; + + /// The number of equality constraints in the system. + size_t m_numEqualityConstraints = 0; + + /// The value of δ from the previous run of Compute(). + double m_δOld = 0.0; + + // Number of non-zeros in LHS. + int m_nonZeros = -1; + + /** + * Reanalize LHS matrix's sparsity pattern if it changed. + * + * @param lhs Matrix to analyze. + */ + void AnalyzePattern(const Eigen::SparseMatrix& lhs) { + int nonZeros = lhs.nonZeros(); + if (m_nonZeros != nonZeros) { + m_solver.analyzePattern(lhs); + m_nonZeros = nonZeros; + } + } + + /** + * Returns regularization matrix. + * + * @param δ The Hessian regularization factor. + * @param γ The equality constraint Jacobian regularization factor. + */ + Eigen::SparseMatrix Regularization(double δ, double γ) { + Eigen::VectorXd vec{m_numDecisionVariables + m_numEqualityConstraints}; + size_t row = 0; + while (row < m_numDecisionVariables) { + vec(row) = δ; + ++row; + } + while (row < m_numDecisionVariables + m_numEqualityConstraints) { + vec(row) = -γ; + ++row; + } + + return Eigen::SparseMatrix{vec.asDiagonal()}; + } +}; + +} // namespace sleipnir diff --git a/sleipnir/src/cpp/optimization/solver/InteriorPoint.cpp b/sleipnir/src/cpp/optimization/solver/InteriorPoint.cpp new file mode 100644 index 0000000..817e3d1 --- /dev/null +++ b/sleipnir/src/cpp/optimization/solver/InteriorPoint.cpp @@ -0,0 +1,829 @@ +// Copyright (c) Sleipnir contributors + +#include "sleipnir/optimization/solver/InteriorPoint.hpp" + +#include +#include +#include +#include +#include + +#include + +#include "optimization/RegularizedLDLT.hpp" +#include "optimization/solver/util/ErrorEstimate.hpp" +#include "optimization/solver/util/FeasibilityRestoration.hpp" +#include "optimization/solver/util/Filter.hpp" +#include "optimization/solver/util/FractionToTheBoundaryRule.hpp" +#include "optimization/solver/util/IsLocallyInfeasible.hpp" +#include "optimization/solver/util/KKTError.hpp" +#include "sleipnir/autodiff/Gradient.hpp" +#include "sleipnir/autodiff/Hessian.hpp" +#include "sleipnir/autodiff/Jacobian.hpp" +#include "sleipnir/optimization/SolverExitCondition.hpp" +#include "sleipnir/util/Print.hpp" +#include "sleipnir/util/Spy.hpp" +#include "sleipnir/util/small_vector.hpp" +#include "util/ScopeExit.hpp" +#include "util/ToMilliseconds.hpp" + +// See docs/algorithms.md#Works_cited for citation definitions. +// +// See docs/algorithms.md#Interior-point_method for a derivation of the +// interior-point method formulation being used. + +namespace sleipnir { + +void InteriorPoint(std::span decisionVariables, + std::span equalityConstraints, + std::span inequalityConstraints, Variable& f, + function_ref callback, + const SolverConfig& config, bool feasibilityRestoration, + Eigen::VectorXd& x, Eigen::VectorXd& s, + SolverStatus* status) { + const auto solveStartTime = std::chrono::system_clock::now(); + + // Map decision variables and constraints to VariableMatrices for Lagrangian + VariableMatrix xAD{decisionVariables}; + xAD.SetValue(x); + VariableMatrix c_eAD{equalityConstraints}; + VariableMatrix c_iAD{inequalityConstraints}; + + // Create autodiff variables for s, y, and z for Lagrangian + VariableMatrix sAD(inequalityConstraints.size()); + sAD.SetValue(s); + VariableMatrix yAD(equalityConstraints.size()); + for (auto& y : yAD) { + y.SetValue(0.0); + } + VariableMatrix zAD(inequalityConstraints.size()); + for (auto& z : zAD) { + z.SetValue(1.0); + } + + // Lagrangian L + // + // L(xₖ, sₖ, yₖ, zₖ) = f(xₖ) − yₖᵀcₑ(xₖ) − zₖᵀ(cᵢ(xₖ) − sₖ) + auto L = f - (yAD.T() * c_eAD)(0) - (zAD.T() * (c_iAD - sAD))(0); + + // Equality constraint Jacobian Aₑ + // + // [∇ᵀcₑ₁(xₖ)] + // Aₑ(x) = [∇ᵀcₑ₂(xₖ)] + // [ ⋮ ] + // [∇ᵀcₑₘ(xₖ)] + Jacobian jacobianCe{c_eAD, xAD}; + Eigen::SparseMatrix A_e = jacobianCe.Value(); + + // Inequality constraint Jacobian Aᵢ + // + // [∇ᵀcᵢ₁(xₖ)] + // Aᵢ(x) = [∇ᵀcᵢ₂(xₖ)] + // [ ⋮ ] + // [∇ᵀcᵢₘ(xₖ)] + Jacobian jacobianCi{c_iAD, xAD}; + Eigen::SparseMatrix A_i = jacobianCi.Value(); + + // Gradient of f ∇f + Gradient gradientF{f, xAD}; + Eigen::SparseVector g = gradientF.Value(); + + // Hessian of the Lagrangian H + // + // Hₖ = ∇²ₓₓL(xₖ, sₖ, yₖ, zₖ) + Hessian hessianL{L, xAD}; + Eigen::SparseMatrix H = hessianL.Value(); + + Eigen::VectorXd y = yAD.Value(); + Eigen::VectorXd z = zAD.Value(); + Eigen::VectorXd c_e = c_eAD.Value(); + Eigen::VectorXd c_i = c_iAD.Value(); + + // Check for overconstrained problem + if (equalityConstraints.size() > decisionVariables.size()) { + if (config.diagnostics) { + sleipnir::println("The problem has too few degrees of freedom."); + sleipnir::println( + "Violated constraints (cₑ(x) = 0) in order of declaration:"); + for (int row = 0; row < c_e.rows(); ++row) { + if (c_e(row) < 0.0) { + sleipnir::println(" {}/{}: {} = 0", row + 1, c_e.rows(), c_e(row)); + } + } + } + + status->exitCondition = SolverExitCondition::kTooFewDOFs; + return; + } + + // Check whether initial guess has finite f(xₖ), cₑ(xₖ), and cᵢ(xₖ) + if (!std::isfinite(f.Value()) || !c_e.allFinite() || !c_i.allFinite()) { + status->exitCondition = + SolverExitCondition::kNonfiniteInitialCostOrConstraints; + return; + } + + // Sparsity pattern files written when spy flag is set in SolverConfig + std::ofstream H_spy; + std::ofstream A_e_spy; + std::ofstream A_i_spy; + if (config.spy) { + A_e_spy.open("A_e.spy"); + A_i_spy.open("A_i.spy"); + H_spy.open("H.spy"); + } + + if (config.diagnostics && !feasibilityRestoration) { + sleipnir::println("Error tolerance: {}\n", config.tolerance); + } + + std::chrono::system_clock::time_point iterationsStartTime; + + int iterations = 0; + + // Prints final diagnostics when the solver exits + scope_exit exit{[&] { + status->cost = f.Value(); + + if (config.diagnostics && !feasibilityRestoration) { + auto solveEndTime = std::chrono::system_clock::now(); + + sleipnir::println("\nSolve time: {:.3f} ms", + ToMilliseconds(solveEndTime - solveStartTime)); + sleipnir::println(" ↳ {:.3f} ms (solver setup)", + ToMilliseconds(iterationsStartTime - solveStartTime)); + if (iterations > 0) { + sleipnir::println( + " ↳ {:.3f} ms ({} solver iterations; {:.3f} ms average)", + ToMilliseconds(solveEndTime - iterationsStartTime), iterations, + ToMilliseconds((solveEndTime - iterationsStartTime) / iterations)); + } + sleipnir::println(""); + + sleipnir::println("{:^8} {:^10} {:^14} {:^6}", "autodiff", + "setup (ms)", "avg solve (ms)", "solves"); + sleipnir::println("{:=^47}", ""); + constexpr auto format = "{:^8} {:10.3f} {:14.3f} {:6}"; + sleipnir::println(format, "∇f(x)", + gradientF.GetProfiler().SetupDuration(), + gradientF.GetProfiler().AverageSolveDuration(), + gradientF.GetProfiler().SolveMeasurements()); + sleipnir::println(format, "∇²ₓₓL", hessianL.GetProfiler().SetupDuration(), + hessianL.GetProfiler().AverageSolveDuration(), + hessianL.GetProfiler().SolveMeasurements()); + sleipnir::println(format, "∂cₑ/∂x", + jacobianCe.GetProfiler().SetupDuration(), + jacobianCe.GetProfiler().AverageSolveDuration(), + jacobianCe.GetProfiler().SolveMeasurements()); + sleipnir::println(format, "∂cᵢ/∂x", + jacobianCi.GetProfiler().SetupDuration(), + jacobianCi.GetProfiler().AverageSolveDuration(), + jacobianCi.GetProfiler().SolveMeasurements()); + sleipnir::println(""); + } + }}; + + // Barrier parameter minimum + const double μ_min = config.tolerance / 10.0; + + // Barrier parameter μ + double μ = 0.1; + + // Fraction-to-the-boundary rule scale factor minimum + constexpr double τ_min = 0.99; + + // Fraction-to-the-boundary rule scale factor τ + double τ = τ_min; + + Filter filter{f, μ}; + + // This should be run when the error estimate is below a desired threshold for + // the current barrier parameter + auto UpdateBarrierParameterAndResetFilter = [&] { + // Barrier parameter linear decrease power in "κ_μ μ". Range of (0, 1). + constexpr double κ_μ = 0.2; + + // Barrier parameter superlinear decrease power in "μ^(θ_μ)". Range of (1, + // 2). + constexpr double θ_μ = 1.5; + + // Update the barrier parameter. + // + // μⱼ₊₁ = max(εₜₒₗ/10, min(κ_μ μⱼ, μⱼ^θ_μ)) + // + // See equation (7) of [2]. + μ = std::max(μ_min, std::min(κ_μ * μ, std::pow(μ, θ_μ))); + + // Update the fraction-to-the-boundary rule scaling factor. + // + // τⱼ = max(τₘᵢₙ, 1 − μⱼ) + // + // See equation (8) of [2]. + τ = std::max(τ_min, 1.0 - μ); + + // Reset the filter when the barrier parameter is updated + filter.Reset(μ); + }; + + // Kept outside the loop so its storage can be reused + small_vector> triplets; + + RegularizedLDLT solver; + + // Variables for determining when a step is acceptable + constexpr double α_red_factor = 0.5; + int acceptableIterCounter = 0; + + int fullStepRejectedCounter = 0; + int stepTooSmallCounter = 0; + + // Error estimate + double E_0 = std::numeric_limits::infinity(); + + iterationsStartTime = std::chrono::system_clock::now(); + + while (E_0 > config.tolerance && + acceptableIterCounter < config.maxAcceptableIterations) { + auto innerIterStartTime = std::chrono::system_clock::now(); + + // Check for local equality constraint infeasibility + if (IsEqualityLocallyInfeasible(A_e, c_e)) { + if (config.diagnostics) { + sleipnir::println( + "The problem is locally infeasible due to violated equality " + "constraints."); + sleipnir::println( + "Violated constraints (cₑ(x) = 0) in order of declaration:"); + for (int row = 0; row < c_e.rows(); ++row) { + if (c_e(row) < 0.0) { + sleipnir::println(" {}/{}: {} = 0", row + 1, c_e.rows(), c_e(row)); + } + } + } + + status->exitCondition = SolverExitCondition::kLocallyInfeasible; + return; + } + + // Check for local inequality constraint infeasibility + if (IsInequalityLocallyInfeasible(A_i, c_i)) { + if (config.diagnostics) { + sleipnir::println( + "The problem is infeasible due to violated inequality " + "constraints."); + sleipnir::println( + "Violated constraints (cᵢ(x) ≥ 0) in order of declaration:"); + for (int row = 0; row < c_i.rows(); ++row) { + if (c_i(row) < 0.0) { + sleipnir::println(" {}/{}: {} ≥ 0", row + 1, c_i.rows(), c_i(row)); + } + } + } + + status->exitCondition = SolverExitCondition::kLocallyInfeasible; + return; + } + + // Check for diverging iterates + if (x.lpNorm() > 1e20 || !x.allFinite() || + s.lpNorm() > 1e20 || !s.allFinite()) { + status->exitCondition = SolverExitCondition::kDivergingIterates; + return; + } + + // Write out spy file contents if that's enabled + if (config.spy) { + // Gap between sparsity patterns + if (iterations > 0) { + A_e_spy << "\n"; + A_i_spy << "\n"; + H_spy << "\n"; + } + + Spy(H_spy, H); + Spy(A_e_spy, A_e); + Spy(A_i_spy, A_i); + } + + // Call user callback + if (callback({iterations, x, s, g, H, A_e, A_i})) { + status->exitCondition = SolverExitCondition::kCallbackRequestedStop; + return; + } + + // [s₁ 0 ⋯ 0 ] + // S = [0 ⋱ ⋮ ] + // [⋮ ⋱ 0 ] + // [0 ⋯ 0 sₘ] + const auto S = s.asDiagonal(); + Eigen::SparseMatrix Sinv; + Sinv = s.cwiseInverse().asDiagonal(); + + // [z₁ 0 ⋯ 0 ] + // Z = [0 ⋱ ⋮ ] + // [⋮ ⋱ 0 ] + // [0 ⋯ 0 zₘ] + const auto Z = z.asDiagonal(); + Eigen::SparseMatrix Zinv; + Zinv = z.cwiseInverse().asDiagonal(); + + // Σ = S⁻¹Z + const Eigen::SparseMatrix Σ = Sinv * Z; + + // lhs = [H + AᵢᵀΣAᵢ Aₑᵀ] + // [ Aₑ 0 ] + // + // Don't assign upper triangle because solver only uses lower triangle. + const Eigen::SparseMatrix topLeft = + H.triangularView() + + (A_i.transpose() * Σ * A_i).triangularView(); + triplets.clear(); + triplets.reserve(topLeft.nonZeros() + A_e.nonZeros()); + for (int col = 0; col < H.cols(); ++col) { + // Append column of H + AᵢᵀΣAᵢ lower triangle in top-left quadrant + for (Eigen::SparseMatrix::InnerIterator it{topLeft, col}; it; + ++it) { + triplets.emplace_back(it.row(), it.col(), it.value()); + } + // Append column of Aₑ in bottom-left quadrant + for (Eigen::SparseMatrix::InnerIterator it{A_e, col}; it; ++it) { + triplets.emplace_back(H.rows() + it.row(), it.col(), it.value()); + } + } + Eigen::SparseMatrix lhs( + decisionVariables.size() + equalityConstraints.size(), + decisionVariables.size() + equalityConstraints.size()); + lhs.setFromSortedTriplets(triplets.begin(), triplets.end(), + [](const auto&, const auto& b) { return b; }); + + const Eigen::VectorXd e = Eigen::VectorXd::Ones(s.rows()); + + // rhs = −[∇f − Aₑᵀy + Aᵢᵀ(S⁻¹(Zcᵢ − μe) − z)] + // [ cₑ ] + Eigen::VectorXd rhs{x.rows() + y.rows()}; + rhs.segment(0, x.rows()) = + -(g - A_e.transpose() * y + + A_i.transpose() * (Sinv * (Z * c_i - μ * e) - z)); + rhs.segment(x.rows(), y.rows()) = -c_e; + + // Solve the Newton-KKT system + solver.Compute(lhs, equalityConstraints.size(), μ); + Eigen::VectorXd step{x.rows() + y.rows()}; + if (solver.Info() == Eigen::Success) { + step = solver.Solve(rhs); + } else { + // The regularization procedure failed due to a rank-deficient equality + // constraint Jacobian with linearly dependent constraints. Set the step + // length to zero and let second-order corrections attempt to restore + // feasibility. + step.setZero(); + } + + // step = [ pₖˣ] + // [−pₖʸ] + Eigen::VectorXd p_x = step.segment(0, x.rows()); + Eigen::VectorXd p_y = -step.segment(x.rows(), y.rows()); + + // pₖᶻ = −Σcᵢ + μS⁻¹e − ΣAᵢpₖˣ + Eigen::VectorXd p_z = -Σ * c_i + μ * Sinv * e - Σ * A_i * p_x; + + // pₖˢ = μZ⁻¹e − s − Z⁻¹Spₖᶻ + Eigen::VectorXd p_s = μ * Zinv * e - s - Zinv * S * p_z; + + // αᵐᵃˣ = max(α ∈ (0, 1] : sₖ + αpₖˢ ≥ (1−τⱼ)sₖ) + const double α_max = FractionToTheBoundaryRule(s, p_s, τ); + double α = α_max; + + // αₖᶻ = max(α ∈ (0, 1] : zₖ + αpₖᶻ ≥ (1−τⱼ)zₖ) + double α_z = FractionToTheBoundaryRule(z, p_z, τ); + + // Loop until a step is accepted. If a step becomes acceptable, the loop + // will exit early. + while (1) { + Eigen::VectorXd trial_x = x + α * p_x; + Eigen::VectorXd trial_y = y + α_z * p_y; + Eigen::VectorXd trial_z = z + α_z * p_z; + + xAD.SetValue(trial_x); + + Eigen::VectorXd trial_c_e = c_eAD.Value(); + Eigen::VectorXd trial_c_i = c_iAD.Value(); + + // If f(xₖ + αpₖˣ), cₑ(xₖ + αpₖˣ), or cᵢ(xₖ + αpₖˣ) aren't finite, reduce + // step size immediately + if (!std::isfinite(f.Value()) || !trial_c_e.allFinite() || + !trial_c_i.allFinite()) { + // Reduce step size + α *= α_red_factor; + continue; + } + + Eigen::VectorXd trial_s; + if (config.feasibleIPM && c_i.cwiseGreater(0.0).all()) { + // If the inequality constraints are all feasible, prevent them from + // becoming infeasible again. + // + // See equation (19.30) in [1]. + trial_s = trial_c_i; + } else { + trial_s = s + α * p_s; + } + + // Check whether filter accepts trial iterate + auto entry = filter.MakeEntry(trial_s, trial_c_e, trial_c_i); + if (filter.TryAdd(entry)) { + // Accept step + break; + } + + double prevConstraintViolation = c_e.lpNorm<1>() + (c_i - s).lpNorm<1>(); + double nextConstraintViolation = + trial_c_e.lpNorm<1>() + (trial_c_i - trial_s).lpNorm<1>(); + + // Second-order corrections + // + // If first trial point was rejected and constraint violation stayed the + // same or went up, apply second-order corrections + if (nextConstraintViolation >= prevConstraintViolation) { + // Apply second-order corrections. See section 2.4 of [2]. + Eigen::VectorXd p_x_cor = p_x; + Eigen::VectorXd p_y_soc = p_y; + Eigen::VectorXd p_z_soc = p_z; + Eigen::VectorXd p_s_soc = p_s; + + double α_soc = α; + Eigen::VectorXd c_e_soc = c_e; + + bool stepAcceptable = false; + for (int soc_iteration = 0; soc_iteration < 5 && !stepAcceptable; + ++soc_iteration) { + // Rebuild Newton-KKT rhs with updated constraint values. + // + // rhs = −[∇f − Aₑᵀy + Aᵢᵀ(S⁻¹(Zcᵢ − μe) − z)] + // [ cₑˢᵒᶜ ] + // + // where cₑˢᵒᶜ = αc(xₖ) + c(xₖ + αpₖˣ) + c_e_soc = α_soc * c_e_soc + trial_c_e; + rhs.bottomRows(y.rows()) = -c_e_soc; + + // Solve the Newton-KKT system + step = solver.Solve(rhs); + + p_x_cor = step.segment(0, x.rows()); + p_y_soc = -step.segment(x.rows(), y.rows()); + + // pₖᶻ = −Σcᵢ + μS⁻¹e − ΣAᵢpₖˣ + p_z_soc = -Σ * c_i + μ * Sinv * e - Σ * A_i * p_x_cor; + + // pₖˢ = μZ⁻¹e − s − Z⁻¹Spₖᶻ + p_s_soc = μ * Zinv * e - s - Zinv * S * p_z_soc; + + // αˢᵒᶜ = max(α ∈ (0, 1] : sₖ + αpₖˢ ≥ (1−τⱼ)sₖ) + α_soc = FractionToTheBoundaryRule(s, p_s_soc, τ); + trial_x = x + α_soc * p_x_cor; + trial_s = s + α_soc * p_s_soc; + + // αₖᶻ = max(α ∈ (0, 1] : zₖ + αpₖᶻ ≥ (1−τⱼ)zₖ) + double α_z_soc = FractionToTheBoundaryRule(z, p_z_soc, τ); + trial_y = y + α_z_soc * p_y_soc; + trial_z = z + α_z_soc * p_z_soc; + + xAD.SetValue(trial_x); + + trial_c_e = c_eAD.Value(); + trial_c_i = c_iAD.Value(); + + // Check whether filter accepts trial iterate + entry = filter.MakeEntry(trial_s, trial_c_e, trial_c_i); + if (filter.TryAdd(entry)) { + p_x = p_x_cor; + p_y = p_y_soc; + p_z = p_z_soc; + p_s = p_s_soc; + α = α_soc; + α_z = α_z_soc; + stepAcceptable = true; + } + } + + if (stepAcceptable) { + // Accept step + break; + } + } + + // If we got here and α is the full step, the full step was rejected. + // Increment the full-step rejected counter to keep track of how many full + // steps have been rejected in a row. + if (α == α_max) { + ++fullStepRejectedCounter; + } + + // If the full step was rejected enough times in a row, reset the filter + // because it may be impeding progress. + // + // See section 3.2 case I of [2]. + if (fullStepRejectedCounter >= 4 && + filter.maxConstraintViolation > entry.constraintViolation / 10.0) { + filter.maxConstraintViolation *= 0.1; + filter.Reset(μ); + continue; + } + + // Reduce step size + α *= α_red_factor; + + // Safety factor for the minimal step size + constexpr double α_min_frac = 0.05; + + // If step size hit a minimum, check if the KKT error was reduced. If it + // wasn't, invoke feasibility restoration. + if (α < α_min_frac * Filter::γConstraint) { + double currentKKTError = KKTError(g, A_e, c_e, A_i, c_i, s, y, z, μ); + + Eigen::VectorXd trial_x = x + α_max * p_x; + Eigen::VectorXd trial_s = s + α_max * p_s; + + Eigen::VectorXd trial_y = y + α_z * p_y; + Eigen::VectorXd trial_z = z + α_z * p_z; + + // Upate autodiff + xAD.SetValue(trial_x); + sAD.SetValue(trial_s); + yAD.SetValue(trial_y); + zAD.SetValue(trial_z); + + Eigen::VectorXd trial_c_e = c_eAD.Value(); + Eigen::VectorXd trial_c_i = c_iAD.Value(); + + double nextKKTError = KKTError(gradientF.Value(), jacobianCe.Value(), + trial_c_e, jacobianCi.Value(), trial_c_i, + trial_s, trial_y, trial_z, μ); + + // If the step using αᵐᵃˣ reduced the KKT error, accept it anyway + if (nextKKTError <= 0.999 * currentKKTError) { + α = α_max; + + // Accept step + break; + } + + // If the step direction was bad and feasibility restoration is + // already running, running it again won't help + if (feasibilityRestoration) { + status->exitCondition = SolverExitCondition::kLocallyInfeasible; + return; + } + + auto initialEntry = filter.MakeEntry(s, c_e, c_i); + + // Feasibility restoration phase + Eigen::VectorXd fr_x = x; + Eigen::VectorXd fr_s = s; + SolverStatus fr_status; + FeasibilityRestoration( + decisionVariables, equalityConstraints, inequalityConstraints, μ, + [&](const SolverIterationInfo& info) { + Eigen::VectorXd trial_x = + info.x.segment(0, decisionVariables.size()); + xAD.SetValue(trial_x); + + Eigen::VectorXd trial_s = + info.s.segment(0, inequalityConstraints.size()); + sAD.SetValue(trial_s); + + Eigen::VectorXd trial_c_e = c_eAD.Value(); + Eigen::VectorXd trial_c_i = c_iAD.Value(); + + // If current iterate is acceptable to normal filter and + // constraint violation has sufficiently reduced, stop + // feasibility restoration + auto entry = filter.MakeEntry(trial_s, trial_c_e, trial_c_i); + if (filter.IsAcceptable(entry) && + entry.constraintViolation < + 0.9 * initialEntry.constraintViolation) { + return true; + } + + return false; + }, + config, fr_x, fr_s, &fr_status); + + if (fr_status.exitCondition == + SolverExitCondition::kCallbackRequestedStop) { + p_x = fr_x - x; + p_s = fr_s - s; + + // Lagrange mutliplier estimates + // + // [y] = (ÂÂᵀ)⁻¹Â[ ∇f] + // [z] [−μe] + // + // where  = [Aₑ 0] + // [Aᵢ −S] + // + // See equation (19.37) of [1]. + { + xAD.SetValue(fr_x); + sAD.SetValue(c_iAD.Value()); + + A_e = jacobianCe.Value(); + A_i = jacobianCi.Value(); + g = gradientF.Value(); + + //  = [Aₑ 0] + // [Aᵢ −S] + triplets.clear(); + triplets.reserve(A_e.nonZeros() + A_i.nonZeros() + s.rows()); + for (int col = 0; col < A_e.cols(); ++col) { + // Append column of Aₑ in top-left quadrant + for (Eigen::SparseMatrix::InnerIterator it{A_e, col}; it; + ++it) { + triplets.emplace_back(it.row(), it.col(), it.value()); + } + // Append column of Aᵢ in bottom-left quadrant + for (Eigen::SparseMatrix::InnerIterator it{A_i, col}; it; + ++it) { + triplets.emplace_back(A_e.rows() + it.row(), it.col(), + it.value()); + } + } + // Append −S in bottom-right quadrant + for (int i = 0; i < s.rows(); ++i) { + triplets.emplace_back(A_e.rows() + i, A_e.cols() + i, -s(i)); + } + Eigen::SparseMatrix Ahat{A_e.rows() + A_i.rows(), + A_e.cols() + s.rows()}; + Ahat.setFromSortedTriplets( + triplets.begin(), triplets.end(), + [](const auto&, const auto& b) { return b; }); + + // lhs = ÂÂᵀ + Eigen::SparseMatrix lhs = Ahat * Ahat.transpose(); + + // rhs = Â[ ∇f] + // [−μe] + Eigen::VectorXd rhsTemp{g.rows() + e.rows()}; + rhsTemp.block(0, 0, g.rows(), 1) = g; + rhsTemp.block(g.rows(), 0, s.rows(), 1) = -μ * e; + Eigen::VectorXd rhs = Ahat * rhsTemp; + + Eigen::SimplicialLDLT> yzEstimator{lhs}; + Eigen::VectorXd sol = yzEstimator.solve(rhs); + + p_y = y - sol.block(0, 0, y.rows(), 1); + p_z = z - sol.block(y.rows(), 0, z.rows(), 1); + } + + α = 1.0; + α_z = 1.0; + + // Accept step + break; + } else if (fr_status.exitCondition == SolverExitCondition::kSuccess) { + status->exitCondition = SolverExitCondition::kLocallyInfeasible; + x = fr_x; + return; + } else { + status->exitCondition = + SolverExitCondition::kFeasibilityRestorationFailed; + x = fr_x; + return; + } + } + } + + // If full step was accepted, reset full-step rejected counter + if (α == α_max) { + fullStepRejectedCounter = 0; + } + + // Handle very small search directions by letting αₖ = αₖᵐᵃˣ when + // max(|pₖˣ(i)|/(1 + |xₖ(i)|)) < 10ε_mach. + // + // See section 3.9 of [2]. + double maxStepScaled = 0.0; + for (int row = 0; row < x.rows(); ++row) { + maxStepScaled = std::max(maxStepScaled, + std::abs(p_x(row)) / (1.0 + std::abs(x(row)))); + } + if (maxStepScaled < 10.0 * std::numeric_limits::epsilon()) { + α = α_max; + ++stepTooSmallCounter; + } else { + stepTooSmallCounter = 0; + } + + // xₖ₊₁ = xₖ + αₖpₖˣ + // sₖ₊₁ = sₖ + αₖpₖˢ + // yₖ₊₁ = yₖ + αₖᶻpₖʸ + // zₖ₊₁ = zₖ + αₖᶻpₖᶻ + x += α * p_x; + s += α * p_s; + y += α_z * p_y; + z += α_z * p_z; + + // A requirement for the convergence proof is that the "primal-dual barrier + // term Hessian" Σₖ does not deviate arbitrarily much from the "primal + // Hessian" μⱼSₖ⁻². We ensure this by resetting + // + // zₖ₊₁⁽ⁱ⁾ = max(min(zₖ₊₁⁽ⁱ⁾, κ_Σ μⱼ/sₖ₊₁⁽ⁱ⁾), μⱼ/(κ_Σ sₖ₊₁⁽ⁱ⁾)) + // + // for some fixed κ_Σ ≥ 1 after each step. See equation (16) of [2]. + { + // Barrier parameter scale factor for inequality constraint Lagrange + // multiplier safeguard + constexpr double κ_Σ = 1e10; + + for (int row = 0; row < z.rows(); ++row) { + z(row) = + std::max(std::min(z(row), κ_Σ * μ / s(row)), μ / (κ_Σ * s(row))); + } + } + + // Update autodiff for Jacobians and Hessian + xAD.SetValue(x); + sAD.SetValue(s); + yAD.SetValue(y); + zAD.SetValue(z); + A_e = jacobianCe.Value(); + A_i = jacobianCi.Value(); + g = gradientF.Value(); + H = hessianL.Value(); + + c_e = c_eAD.Value(); + c_i = c_iAD.Value(); + + // Update the error estimate + E_0 = ErrorEstimate(g, A_e, c_e, A_i, c_i, s, y, z, 0.0); + if (E_0 < config.acceptableTolerance) { + ++acceptableIterCounter; + } else { + acceptableIterCounter = 0; + } + + // Update the barrier parameter if necessary + if (E_0 > config.tolerance) { + // Barrier parameter scale factor for tolerance checks + constexpr double κ_ε = 10.0; + + // While the error estimate is below the desired threshold for this + // barrier parameter value, decrease the barrier parameter further + double E_μ = ErrorEstimate(g, A_e, c_e, A_i, c_i, s, y, z, μ); + while (μ > μ_min && E_μ <= κ_ε * μ) { + UpdateBarrierParameterAndResetFilter(); + E_μ = ErrorEstimate(g, A_e, c_e, A_i, c_i, s, y, z, μ); + } + } + + const auto innerIterEndTime = std::chrono::system_clock::now(); + + // Diagnostics for current iteration + if (config.diagnostics) { + if (iterations % 20 == 0) { + sleipnir::println("{:^4} {:^9} {:^13} {:^13} {:^13}", "iter", + "time (ms)", "error", "cost", "infeasibility"); + sleipnir::println("{:=^61}", ""); + } + + sleipnir::println("{:4}{} {:9.3f} {:13e} {:13e} {:13e}", iterations, + feasibilityRestoration ? "r" : " ", + ToMilliseconds(innerIterEndTime - innerIterStartTime), + E_0, f.Value(), + c_e.lpNorm<1>() + (c_i - s).lpNorm<1>()); + } + + ++iterations; + + // Check for max iterations + if (iterations >= config.maxIterations) { + status->exitCondition = SolverExitCondition::kMaxIterationsExceeded; + return; + } + + // Check for max wall clock time + if (innerIterEndTime - solveStartTime > config.timeout) { + status->exitCondition = SolverExitCondition::kTimeout; + return; + } + + // Check for solve to acceptable tolerance + if (E_0 > config.tolerance && + acceptableIterCounter == config.maxAcceptableIterations) { + status->exitCondition = SolverExitCondition::kSolvedToAcceptableTolerance; + return; + } + + // The search direction has been very small twice, so assume the problem has + // been solved as well as possible given finite precision and reduce the + // barrier parameter. + // + // See section 3.9 of [2]. + if (stepTooSmallCounter >= 2 && μ > μ_min) { + UpdateBarrierParameterAndResetFilter(); + continue; + } + } +} // NOLINT(readability/fn_size) + +} // namespace sleipnir diff --git a/sleipnir/src/cpp/optimization/solver/util/ErrorEstimate.hpp b/sleipnir/src/cpp/optimization/solver/util/ErrorEstimate.hpp new file mode 100644 index 0000000..6b757df --- /dev/null +++ b/sleipnir/src/cpp/optimization/solver/util/ErrorEstimate.hpp @@ -0,0 +1,80 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include + +#include +#include + +// See docs/algorithms.md#Works_cited for citation definitions + +namespace sleipnir { + +/** + * Returns the error estimate using the KKT conditions for the interior-point + * method. + * + * @param g Gradient of the cost function ∇f. + * @param A_e The problem's equality constraint Jacobian Aₑ(x) evaluated at the + * current iterate. + * @param c_e The problem's equality constraints cₑ(x) evaluated at the current + * iterate. + * @param A_i The problem's inequality constraint Jacobian Aᵢ(x) evaluated at + * the current iterate. + * @param c_i The problem's inequality constraints cᵢ(x) evaluated at the + * current iterate. + * @param s Inequality constraint slack variables. + * @param y Equality constraint dual variables. + * @param z Inequality constraint dual variables. + * @param μ Barrier parameter. + */ +inline double ErrorEstimate(const Eigen::VectorXd& g, + const Eigen::SparseMatrix& A_e, + const Eigen::VectorXd& c_e, + const Eigen::SparseMatrix& A_i, + const Eigen::VectorXd& c_i, + const Eigen::VectorXd& s, const Eigen::VectorXd& y, + const Eigen::VectorXd& z, double μ) { + int numEqualityConstraints = A_e.rows(); + int numInequalityConstraints = A_i.rows(); + + // Update the error estimate using the KKT conditions from equations (19.5a) + // through (19.5d) of [1]. + // + // ∇f − Aₑᵀy − Aᵢᵀz = 0 + // Sz − μe = 0 + // cₑ = 0 + // cᵢ − s = 0 + // + // The error tolerance is the max of the following infinity norms scaled by + // s_d and s_c (see equation (5) of [2]). + // + // ‖∇f − Aₑᵀy − Aᵢᵀz‖_∞ / s_d + // ‖Sz − μe‖_∞ / s_c + // ‖cₑ‖_∞ + // ‖cᵢ − s‖_∞ + + // s_d = max(sₘₐₓ, (‖y‖₁ + ‖z‖₁) / (m + n)) / sₘₐₓ + constexpr double s_max = 100.0; + double s_d = + std::max(s_max, (y.lpNorm<1>() + z.lpNorm<1>()) / + (numEqualityConstraints + numInequalityConstraints)) / + s_max; + + // s_c = max(sₘₐₓ, ‖z‖₁ / n) / sₘₐₓ + double s_c = + std::max(s_max, z.lpNorm<1>() / numInequalityConstraints) / s_max; + + const auto S = s.asDiagonal(); + const Eigen::VectorXd e = Eigen::VectorXd::Ones(s.rows()); + + return std::max({(g - A_e.transpose() * y - A_i.transpose() * z) + .lpNorm() / + s_d, + (S * z - μ * e).lpNorm() / s_c, + c_e.lpNorm(), + (c_i - s).lpNorm()}); +} + +} // namespace sleipnir diff --git a/sleipnir/src/cpp/optimization/solver/util/FeasibilityRestoration.hpp b/sleipnir/src/cpp/optimization/solver/util/FeasibilityRestoration.hpp new file mode 100644 index 0000000..c324ef0 --- /dev/null +++ b/sleipnir/src/cpp/optimization/solver/util/FeasibilityRestoration.hpp @@ -0,0 +1,238 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include +#include +#include +#include + +#include + +#include "sleipnir/autodiff/Variable.hpp" +#include "sleipnir/autodiff/VariableMatrix.hpp" +#include "sleipnir/optimization/SolverConfig.hpp" +#include "sleipnir/optimization/SolverIterationInfo.hpp" +#include "sleipnir/optimization/SolverStatus.hpp" +#include "sleipnir/optimization/solver/InteriorPoint.hpp" +#include "sleipnir/util/FunctionRef.hpp" +#include "sleipnir/util/small_vector.hpp" + +namespace sleipnir { + +/** + * Finds the iterate that minimizes the constraint violation while not deviating + * too far from the starting point. This is a fallback procedure when the normal + * interior-point method fails to converge to a feasible point. + * + * @param[in] decisionVariables The list of decision variables. + * @param[in] equalityConstraints The list of equality constraints. + * @param[in] inequalityConstraints The list of inequality constraints. + * @param[in] μ Barrier parameter. + * @param[in] callback The user callback. + * @param[in] config Configuration options for the solver. + * @param[in,out] x The current iterate from the normal solve. + * @param[in,out] s The current inequality constraint slack variables from the + * normal solve. + * @param[out] status The solver status. + */ +inline void FeasibilityRestoration( + std::span decisionVariables, + std::span equalityConstraints, + std::span inequalityConstraints, double μ, + function_ref callback, + const SolverConfig& config, Eigen::VectorXd& x, Eigen::VectorXd& s, + SolverStatus* status) { + // Feasibility restoration + // + // min ρ Σ (pₑ + nₑ + pᵢ + nᵢ) + ζ/2 (x - x_R)ᵀD_R(x - x_R) + // x + // pₑ,nₑ + // pᵢ,nᵢ + // + // s.t. cₑ(x) - pₑ + nₑ = 0 + // cᵢ(x) - s - pᵢ + nᵢ = 0 + // pₑ ≥ 0 + // nₑ ≥ 0 + // pᵢ ≥ 0 + // nᵢ ≥ 0 + // + // where ρ = 1000, ζ = √μ where μ is the barrier parameter, x_R is original + // iterate before feasibility restoration, and D_R is a scaling matrix defined + // by + // + // D_R = diag(min(1, 1/|x_R⁽¹⁾|), …, min(1, 1/|x_R|⁽ⁿ⁾) + + constexpr double ρ = 1000.0; + + small_vector fr_decisionVariables; + fr_decisionVariables.reserve(decisionVariables.size() + + 2 * equalityConstraints.size() + + 2 * inequalityConstraints.size()); + + // Assign x + fr_decisionVariables.assign(decisionVariables.begin(), + decisionVariables.end()); + + // Allocate pₑ, nₑ, pᵢ, and nᵢ + for (size_t row = 0; + row < 2 * equalityConstraints.size() + 2 * inequalityConstraints.size(); + ++row) { + fr_decisionVariables.emplace_back(); + } + + auto it = fr_decisionVariables.cbegin(); + + VariableMatrix xAD{std::span{it, it + decisionVariables.size()}}; + it += decisionVariables.size(); + + VariableMatrix p_e{std::span{it, it + equalityConstraints.size()}}; + it += equalityConstraints.size(); + + VariableMatrix n_e{std::span{it, it + equalityConstraints.size()}}; + it += equalityConstraints.size(); + + VariableMatrix p_i{std::span{it, it + inequalityConstraints.size()}}; + it += inequalityConstraints.size(); + + VariableMatrix n_i{std::span{it, it + inequalityConstraints.size()}}; + + // Set initial values for pₑ, nₑ, pᵢ, and nᵢ. + // + // + // From equation (33) of [2]: + // ______________________ + // μ − ρ c(x) /(μ − ρ c(x))² μ c(x) + // n = −−−−−−−−−− + / (−−−−−−−−−−) + −−−−−− (1) + // 2ρ √ ( 2ρ ) 2ρ + // + // The quadratic formula: + // ________ + // -b + √b² - 4ac + // x = −−−−−−−−−−−−−− (2) + // 2a + // + // Rearrange (1) to fit the quadratic formula better: + // _________________________ + // μ - ρ c(x) + √(μ - ρ c(x))² + 2ρ μ c(x) + // n = −−−−−−−−−−−−−−−−−−−−−−−−−−−−−−−−−−−−−−− + // 2ρ + // + // Solve for coefficients: + // + // a = ρ (3) + // b = ρ c(x) - μ (4) + // + // -4ac = μ c(x) 2ρ + // -4(ρ)c = 2ρ μ c(x) + // -4c = 2μ c(x) + // c = -μ c(x)/2 (5) + // + // p = c(x) + n (6) + for (int row = 0; row < p_e.Rows(); ++row) { + double c_e = equalityConstraints[row].Value(); + + constexpr double a = 2 * ρ; + double b = ρ * c_e - μ; + double c = -μ * c_e / 2.0; + + double n = -b * std::sqrt(b * b - 4.0 * a * c) / (2.0 * a); + double p = c_e + n; + + p_e(row).SetValue(p); + n_e(row).SetValue(n); + } + for (int row = 0; row < p_i.Rows(); ++row) { + double c_i = inequalityConstraints[row].Value() - s(row); + + constexpr double a = 2 * ρ; + double b = ρ * c_i - μ; + double c = -μ * c_i / 2.0; + + double n = -b * std::sqrt(b * b - 4.0 * a * c) / (2.0 * a); + double p = c_i + n; + + p_i(row).SetValue(p); + n_i(row).SetValue(n); + } + + // cₑ(x) - pₑ + nₑ = 0 + small_vector fr_equalityConstraints; + fr_equalityConstraints.assign(equalityConstraints.begin(), + equalityConstraints.end()); + for (size_t row = 0; row < fr_equalityConstraints.size(); ++row) { + auto& constraint = fr_equalityConstraints[row]; + constraint = constraint - p_e(row) + n_e(row); + } + + // cᵢ(x) - s - pᵢ + nᵢ = 0 + small_vector fr_inequalityConstraints; + fr_inequalityConstraints.assign(inequalityConstraints.begin(), + inequalityConstraints.end()); + for (size_t row = 0; row < fr_inequalityConstraints.size(); ++row) { + auto& constraint = fr_inequalityConstraints[row]; + constraint = constraint - s(row) - p_i(row) + n_i(row); + } + + // pₑ ≥ 0 + std::copy(p_e.begin(), p_e.end(), + std::back_inserter(fr_inequalityConstraints)); + + // pᵢ ≥ 0 + std::copy(p_i.begin(), p_i.end(), + std::back_inserter(fr_inequalityConstraints)); + + // nₑ ≥ 0 + std::copy(n_e.begin(), n_e.end(), + std::back_inserter(fr_inequalityConstraints)); + + // nᵢ ≥ 0 + std::copy(n_i.begin(), n_i.end(), + std::back_inserter(fr_inequalityConstraints)); + + Variable J = 0.0; + + // J += ρ Σ (pₑ + nₑ + pᵢ + nᵢ) + for (auto& elem : p_e) { + J += elem; + } + for (auto& elem : p_i) { + J += elem; + } + for (auto& elem : n_e) { + J += elem; + } + for (auto& elem : n_i) { + J += elem; + } + J *= ρ; + + // D_R = diag(min(1, 1/|x_R⁽¹⁾|), …, min(1, 1/|x_R|⁽ⁿ⁾) + Eigen::VectorXd D_R{x.rows()}; + for (int row = 0; row < D_R.rows(); ++row) { + D_R(row) = std::min(1.0, 1.0 / std::abs(x(row))); + } + + // J += ζ/2 (x - x_R)ᵀD_R(x - x_R) + for (int row = 0; row < x.rows(); ++row) { + J += std::sqrt(μ) / 2.0 * D_R(row) * sleipnir::pow(xAD(row) - x(row), 2); + } + + Eigen::VectorXd fr_x = VariableMatrix{fr_decisionVariables}.Value(); + + // Set up initial value for inequality constraint slack variables + Eigen::VectorXd fr_s{fr_inequalityConstraints.size()}; + fr_s.segment(0, inequalityConstraints.size()) = s; + fr_s.segment(inequalityConstraints.size(), + fr_s.size() - inequalityConstraints.size()) + .setOnes(); + + InteriorPoint(fr_decisionVariables, fr_equalityConstraints, + fr_inequalityConstraints, J, callback, config, true, fr_x, fr_s, + status); + + x = fr_x.segment(0, decisionVariables.size()); + s = fr_s.segment(0, inequalityConstraints.size()); +} + +} // namespace sleipnir diff --git a/sleipnir/src/cpp/optimization/solver/util/Filter.hpp b/sleipnir/src/cpp/optimization/solver/util/Filter.hpp new file mode 100644 index 0000000..3fbb849 --- /dev/null +++ b/sleipnir/src/cpp/optimization/solver/util/Filter.hpp @@ -0,0 +1,188 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include +#include +#include +#include + +#include + +#include "sleipnir/autodiff/Variable.hpp" +#include "sleipnir/util/small_vector.hpp" + +namespace sleipnir { + +/** + * Filter entry consisting of cost and constraint violation. + */ +struct FilterEntry { + /// The cost function's value + double cost = 0.0; + + /// The constraint violation + double constraintViolation = 0.0; + + constexpr FilterEntry() = default; + + /** + * Constructs a FilterEntry. + * + * @param cost The cost function's value. + * @param constraintViolation The constraint violation. + */ + FilterEntry(double cost, double constraintViolation) + : cost{cost}, constraintViolation{constraintViolation} {} + + /** + * Constructs a FilterEntry. + * + * @param f The cost function. + * @param μ The barrier parameter. + * @param s The inequality constraint slack variables. + * @param c_e The equality constraint values (nonzero means violation). + * @param c_i The inequality constraint values (negative means violation). + */ + FilterEntry(Variable& f, double μ, const Eigen::VectorXd& s, + const Eigen::VectorXd& c_e, const Eigen::VectorXd& c_i) + : cost{f.Value() - μ * s.array().log().sum()}, + constraintViolation{c_e.lpNorm<1>() + (c_i - s).lpNorm<1>()} {} +}; + +/** + * Interior-point step filter. + */ +class Filter { + public: + static constexpr double γCost = 1e-8; + static constexpr double γConstraint = 1e-5; + + double maxConstraintViolation = 1e4; + + /** + * Construct an empty filter. + * + * @param f The cost function. + * @param μ The barrier parameter. + */ + explicit Filter(Variable& f, double μ) { + m_f = &f; + m_μ = μ; + + // Initial filter entry rejects constraint violations above max + m_filter.emplace_back(std::numeric_limits::infinity(), + maxConstraintViolation); + } + + /** + * Reset the filter. + * + * @param μ The new barrier parameter. + */ + void Reset(double μ) { + m_μ = μ; + m_filter.clear(); + + // Initial filter entry rejects constraint violations above max + m_filter.emplace_back(std::numeric_limits::infinity(), + maxConstraintViolation); + } + + /** + * Creates a new filter entry. + * + * @param s The inequality constraint slack variables. + * @param c_e The equality constraint values (nonzero means violation). + * @param c_i The inequality constraint values (negative means violation). + */ + FilterEntry MakeEntry(Eigen::VectorXd& s, const Eigen::VectorXd& c_e, + const Eigen::VectorXd& c_i) { + return FilterEntry{*m_f, m_μ, s, c_e, c_i}; + } + + /** + * Add a new entry to the filter. + * + * @param entry The entry to add to the filter. + */ + void Add(const FilterEntry& entry) { + // Remove dominated entries + erase_if(m_filter, [&](const auto& elem) { + return entry.cost <= elem.cost && + entry.constraintViolation <= elem.constraintViolation; + }); + + m_filter.push_back(entry); + } + + /** + * Add a new entry to the filter. + * + * @param entry The entry to add to the filter. + */ + void Add(FilterEntry&& entry) { + // Remove dominated entries + erase_if(m_filter, [&](const auto& elem) { + return entry.cost <= elem.cost && + entry.constraintViolation <= elem.constraintViolation; + }); + + m_filter.push_back(entry); + } + + /** + * Returns true if the given iterate is accepted by the filter. + * + * @param entry The entry to attempt adding to the filter. + */ + bool TryAdd(const FilterEntry& entry) { + if (IsAcceptable(entry)) { + Add(entry); + return true; + } else { + return false; + } + } + + /** + * Returns true if the given iterate is accepted by the filter. + * + * @param entry The entry to attempt adding to the filter. + */ + bool TryAdd(FilterEntry&& entry) { + if (IsAcceptable(entry)) { + Add(std::move(entry)); + return true; + } else { + return false; + } + } + + /** + * Returns true if the given entry is acceptable to the filter. + * + * @param entry The entry to check. + */ + bool IsAcceptable(const FilterEntry& entry) { + if (!std::isfinite(entry.cost) || + !std::isfinite(entry.constraintViolation)) { + return false; + } + + // If current filter entry is better than all prior ones in some respect, + // accept it + return std::all_of(m_filter.begin(), m_filter.end(), [&](const auto& elem) { + return entry.cost <= elem.cost - γCost * elem.constraintViolation || + entry.constraintViolation <= + (1.0 - γConstraint) * elem.constraintViolation; + }); + } + + private: + Variable* m_f = nullptr; + double m_μ = 0.0; + small_vector m_filter; +}; + +} // namespace sleipnir diff --git a/sleipnir/src/cpp/optimization/solver/util/FractionToTheBoundaryRule.hpp b/sleipnir/src/cpp/optimization/solver/util/FractionToTheBoundaryRule.hpp new file mode 100644 index 0000000..98a117f --- /dev/null +++ b/sleipnir/src/cpp/optimization/solver/util/FractionToTheBoundaryRule.hpp @@ -0,0 +1,45 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include + +// See docs/algorithms.md#Works_cited for citation definitions + +namespace sleipnir { + +/** + * Applies fraction-to-the-boundary rule to a variable and its iterate, then + * returns a fraction of the iterate step size within (0, 1]. + * + * @param x The variable. + * @param p The iterate on the variable. + * @param τ Fraction-to-the-boundary rule scaling factor within (0, 1]. + * @return Fraction of the iterate step size within (0, 1]. + */ +inline double FractionToTheBoundaryRule( + const Eigen::Ref& x, + const Eigen::Ref& p, double τ) { + // α = max(α ∈ (0, 1] : x + αp ≥ (1 − τ)x) + // + // where x and τ are positive. + // + // x + αp ≥ (1 − τ)x + // x + αp ≥ x − τx + // αp ≥ −τx + // + // If the inequality is false, p < 0 and α is too big. Find the largest value + // of α that makes the inequality true. + // + // α = −τ/p x + double α = 1.0; + for (int i = 0; i < x.rows(); ++i) { + if (α * p(i) < -τ * x(i)) { + α = -τ / p(i) * x(i); + } + } + + return α; +} + +} // namespace sleipnir diff --git a/sleipnir/src/cpp/optimization/solver/util/IsLocallyInfeasible.hpp b/sleipnir/src/cpp/optimization/solver/util/IsLocallyInfeasible.hpp new file mode 100644 index 0000000..dc3992e --- /dev/null +++ b/sleipnir/src/cpp/optimization/solver/util/IsLocallyInfeasible.hpp @@ -0,0 +1,63 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include +#include + +// See docs/algorithms.md#Works_cited for citation definitions + +namespace sleipnir { + +/** + * Returns true if the problem's equality constraints are locally infeasible. + * + * @param A_e The problem's equality constraint Jacobian Aₑ(x) evaluated at the + * current iterate. + * @param c_e The problem's equality constraints cₑ(x) evaluated at the current + * iterate. + */ +inline bool IsEqualityLocallyInfeasible(const Eigen::SparseMatrix& A_e, + const Eigen::VectorXd& c_e) { + // The equality constraints are locally infeasible if + // + // Aₑᵀcₑ → 0 + // ‖cₑ‖ > ε + // + // See "Infeasibility detection" in section 6 of [3]. + return A_e.rows() > 0 && (A_e.transpose() * c_e).norm() < 1e-6 && + c_e.norm() > 1e-2; +} + +/** + * Returns true if the problem's inequality constraints are locally infeasible. + * + * @param A_i The problem's inequality constraint Jacobian Aᵢ(x) evaluated at + * the current iterate. + * @param c_i The problem's inequality constraints cᵢ(x) evaluated at the + * current iterate. + */ +inline bool IsInequalityLocallyInfeasible( + const Eigen::SparseMatrix& A_i, const Eigen::VectorXd& c_i) { + // The inequality constraints are locally infeasible if + // + // Aᵢᵀcᵢ⁺ → 0 + // ‖cᵢ⁺‖ > ε + // + // where cᵢ⁺ = min(cᵢ, 0). + // + // See "Infeasibility detection" in section 6 of [3]. + // + // cᵢ⁺ is used instead of cᵢ⁻ from the paper to follow the convention that + // feasible inequality constraints are ≥ 0. + if (A_i.rows() > 0) { + Eigen::VectorXd c_i_plus = c_i.cwiseMin(0.0); + if ((A_i.transpose() * c_i_plus).norm() < 1e-6 && c_i_plus.norm() > 1e-6) { + return true; + } + } + + return false; +} + +} // namespace sleipnir diff --git a/sleipnir/src/cpp/optimization/solver/util/KKTError.hpp b/sleipnir/src/cpp/optimization/solver/util/KKTError.hpp new file mode 100644 index 0000000..03fa112 --- /dev/null +++ b/sleipnir/src/cpp/optimization/solver/util/KKTError.hpp @@ -0,0 +1,51 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include +#include + +// See docs/algorithms.md#Works_cited for citation definitions + +namespace sleipnir { + +/** + * Returns the KKT error for the interior-point method. + * + * @param g Gradient of the cost function ∇f. + * @param A_e The problem's equality constraint Jacobian Aₑ(x) evaluated at the + * current iterate. + * @param c_e The problem's equality constraints cₑ(x) evaluated at the current + * iterate. + * @param A_i The problem's inequality constraint Jacobian Aᵢ(x) evaluated at + * the current iterate. + * @param c_i The problem's inequality constraints cᵢ(x) evaluated at the + * current iterate. + * @param s Inequality constraint slack variables. + * @param y Equality constraint dual variables. + * @param z Inequality constraint dual variables. + * @param μ Barrier parameter. + */ +inline double KKTError(const Eigen::VectorXd& g, + const Eigen::SparseMatrix& A_e, + const Eigen::VectorXd& c_e, + const Eigen::SparseMatrix& A_i, + const Eigen::VectorXd& c_i, const Eigen::VectorXd& s, + const Eigen::VectorXd& y, const Eigen::VectorXd& z, + double μ) { + // Compute the KKT error as the 1-norm of the KKT conditions from equations + // (19.5a) through (19.5d) of [1]. + // + // ∇f − Aₑᵀy − Aᵢᵀz = 0 + // Sz − μe = 0 + // cₑ = 0 + // cᵢ − s = 0 + + const auto S = s.asDiagonal(); + const Eigen::VectorXd e = Eigen::VectorXd::Ones(s.rows()); + + return (g - A_e.transpose() * y - A_i.transpose() * z).lpNorm<1>() + + (S * z - μ * e).lpNorm<1>() + c_e.lpNorm<1>() + (c_i - s).lpNorm<1>(); +} + +} // namespace sleipnir diff --git a/sleipnir/src/cpp/util/Pool.cpp b/sleipnir/src/cpp/util/Pool.cpp new file mode 100644 index 0000000..33b137b --- /dev/null +++ b/sleipnir/src/cpp/util/Pool.cpp @@ -0,0 +1,12 @@ +// Copyright (c) Sleipnir contributors + +#include "sleipnir/util/Pool.hpp" + +namespace sleipnir { + +PoolResource& GlobalPoolResource() { + thread_local PoolResource pool{16384}; + return pool; +} + +} // namespace sleipnir diff --git a/sleipnir/src/cpp/util/ScopeExit.hpp b/sleipnir/src/cpp/util/ScopeExit.hpp new file mode 100644 index 0000000..dedc7ae --- /dev/null +++ b/sleipnir/src/cpp/util/ScopeExit.hpp @@ -0,0 +1,35 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include + +namespace sleipnir { + +template +class scope_exit { + public: + explicit scope_exit(F&& f) noexcept : m_f{std::forward(f)} {} + + ~scope_exit() { + if (m_active) { + m_f(); + } + } + + scope_exit(scope_exit&& rhs) noexcept + : m_f{std::move(rhs.m_f)}, m_active{rhs.m_active} { + rhs.release(); + } + + scope_exit(const scope_exit&) = delete; + scope_exit& operator=(const scope_exit&) = delete; + + void release() noexcept { m_active = false; } + + private: + F m_f; + bool m_active = true; +}; + +} // namespace sleipnir diff --git a/sleipnir/src/cpp/util/ToMilliseconds.hpp b/sleipnir/src/cpp/util/ToMilliseconds.hpp new file mode 100644 index 0000000..0e00418 --- /dev/null +++ b/sleipnir/src/cpp/util/ToMilliseconds.hpp @@ -0,0 +1,21 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include + +namespace sleipnir { + +/** + * Converts std::chrono::duration to a number of milliseconds rounded to three + * decimals. + */ +template > +constexpr double ToMilliseconds( + const std::chrono::duration& duration) { + using std::chrono::duration_cast; + using std::chrono::microseconds; + return duration_cast(duration).count() / 1e3; +} + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/autodiff/Expression.hpp b/sleipnir/src/include/sleipnir/autodiff/Expression.hpp new file mode 100644 index 0000000..6c4ae62 --- /dev/null +++ b/sleipnir/src/include/sleipnir/autodiff/Expression.hpp @@ -0,0 +1,1129 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +#include "sleipnir/autodiff/ExpressionType.hpp" +#include "sleipnir/util/IntrusiveSharedPtr.hpp" +#include "sleipnir/util/Pool.hpp" +#include "sleipnir/util/SymbolExports.hpp" +#include "sleipnir/util/small_vector.hpp" + +namespace sleipnir::detail { + +struct SLEIPNIR_DLLEXPORT Expression; + +inline constexpr void IntrusiveSharedPtrIncRefCount(Expression* expr); +inline constexpr void IntrusiveSharedPtrDecRefCount(Expression* expr); + +/** + * Typedef for intrusive shared pointer to Expression. + */ +using ExpressionPtr = IntrusiveSharedPtr; + +/** + * Creates an intrusive shared pointer to an expression from the global pool + * allocator. + * + * @param args Constructor arguments for Expression. + */ +template +static ExpressionPtr MakeExpressionPtr(Args&&... args) { + return AllocateIntrusiveShared(GlobalPoolAllocator(), + std::forward(args)...); +} + +/** + * An autodiff expression node. + */ +struct SLEIPNIR_DLLEXPORT Expression { + /** + * Binary function taking two doubles and returning a double. + */ + using BinaryFuncDouble = double (*)(double, double); + + /** + * Trinary function taking three doubles and returning a double. + */ + using TrinaryFuncDouble = double (*)(double, double, double); + + /** + * Trinary function taking three expressions and returning an expression. + */ + using TrinaryFuncExpr = ExpressionPtr (*)(const ExpressionPtr&, + const ExpressionPtr&, + const ExpressionPtr&); + + /// The value of the expression node. + double value = 0.0; + + /// The adjoint of the expression node used during autodiff. + double adjoint = 0.0; + + /// Tracks the number of instances of this expression yet to be encountered in + /// an expression tree. + uint32_t duplications = 0; + + /// This expression's row in wrt for autodiff gradient, Jacobian, or Hessian. + /// This is -1 if the expression isn't in wrt. + int32_t row = -1; + + /// The adjoint of the expression node used during gradient expression tree + /// generation. + ExpressionPtr adjointExpr; + + /// Expression argument type. + ExpressionType type = ExpressionType::kConstant; + + /// Reference count for intrusive shared pointer. + uint32_t refCount = 0; + + /// Either nullary operator with no arguments, unary operator with one + /// argument, or binary operator with two arguments. This operator is + /// used to update the node's value. + BinaryFuncDouble valueFunc = nullptr; + + /// Functions returning double adjoints of the children expressions. + /// + /// Parameters: + ///
    + ///
  • lhs: Left argument to binary operator.
  • + ///
  • rhs: Right argument to binary operator.
  • + ///
  • parentAdjoint: Adjoint of parent expression.
  • + ///
+ std::array gradientValueFuncs{nullptr, nullptr}; + + /// Functions returning Variable adjoints of the children expressions. + /// + /// Parameters: + ///
    + ///
  • lhs: Left argument to binary operator.
  • + ///
  • rhs: Right argument to binary operator.
  • + ///
  • parentAdjoint: Adjoint of parent expression.
  • + ///
+ std::array gradientFuncs{nullptr, nullptr}; + + /// Expression arguments. + std::array args{nullptr, nullptr}; + + /** + * Constructs a constant expression with a value of zero. + */ + constexpr Expression() = default; + + /** + * Constructs a nullary expression (an operator with no arguments). + * + * @param value The expression value. + * @param type The expression type. It should be either constant (the default) + * or linear. + */ + explicit constexpr Expression(double value, + ExpressionType type = ExpressionType::kConstant) + : value{value}, type{type} {} + + /** + * Constructs an unary expression (an operator with one argument). + * + * @param type The expression's type. + * @param valueFunc Unary operator that produces this expression's value. + * @param lhsGradientValueFunc Gradient with respect to the operand. + * @param lhsGradientFunc Gradient with respect to the operand. + * @param lhs Unary operator's operand. + */ + constexpr Expression(ExpressionType type, BinaryFuncDouble valueFunc, + TrinaryFuncDouble lhsGradientValueFunc, + TrinaryFuncExpr lhsGradientFunc, ExpressionPtr lhs) + : value{valueFunc(lhs->value, 0.0)}, + type{type}, + valueFunc{valueFunc}, + gradientValueFuncs{lhsGradientValueFunc, nullptr}, + gradientFuncs{lhsGradientFunc, nullptr}, + args{lhs, nullptr} {} + + /** + * Constructs a binary expression (an operator with two arguments). + * + * @param type The expression's type. + * @param valueFunc Unary operator that produces this expression's value. + * @param lhsGradientValueFunc Gradient with respect to the left operand. + * @param rhsGradientValueFunc Gradient with respect to the right operand. + * @param lhsGradientFunc Gradient with respect to the left operand. + * @param rhsGradientFunc Gradient with respect to the right operand. + * @param lhs Binary operator's left operand. + * @param rhs Binary operator's right operand. + */ + constexpr Expression(ExpressionType type, BinaryFuncDouble valueFunc, + TrinaryFuncDouble lhsGradientValueFunc, + TrinaryFuncDouble rhsGradientValueFunc, + TrinaryFuncExpr lhsGradientFunc, + TrinaryFuncExpr rhsGradientFunc, ExpressionPtr lhs, + ExpressionPtr rhs) + : value{valueFunc(lhs->value, rhs->value)}, + type{type}, + valueFunc{valueFunc}, + gradientValueFuncs{lhsGradientValueFunc, rhsGradientValueFunc}, + gradientFuncs{lhsGradientFunc, rhsGradientFunc}, + args{lhs, rhs} {} + + /** + * Returns true if the expression is the given constant. + * + * @param constant The constant. + */ + constexpr bool IsConstant(double constant) const { + return type == ExpressionType::kConstant && value == constant; + } + + /** + * Expression-Expression multiplication operator. + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + */ + friend SLEIPNIR_DLLEXPORT ExpressionPtr operator*(const ExpressionPtr& lhs, + const ExpressionPtr& rhs) { + using enum ExpressionType; + + // Prune expression + if (lhs->IsConstant(0.0)) { + // Return zero + return lhs; + } else if (rhs->IsConstant(0.0)) { + // Return zero + return rhs; + } else if (lhs->IsConstant(1.0)) { + return rhs; + } else if (rhs->IsConstant(1.0)) { + return lhs; + } + + // Evaluate constant + if (lhs->type == kConstant && rhs->type == kConstant) { + return MakeExpressionPtr(lhs->value * rhs->value); + } + + // Evaluate expression type + ExpressionType type; + if (lhs->type == kConstant) { + type = rhs->type; + } else if (rhs->type == kConstant) { + type = lhs->type; + } else if (lhs->type == kLinear && rhs->type == kLinear) { + type = kQuadratic; + } else { + type = kNonlinear; + } + + return MakeExpressionPtr( + type, [](double lhs, double rhs) { return lhs * rhs; }, + [](double, double rhs, double parentAdjoint) { + return parentAdjoint * rhs; + }, + [](double lhs, double, double parentAdjoint) { + return parentAdjoint * lhs; + }, + [](const ExpressionPtr&, const ExpressionPtr& rhs, + const ExpressionPtr& parentAdjoint) { return parentAdjoint * rhs; }, + [](const ExpressionPtr& lhs, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { return parentAdjoint * lhs; }, + lhs, rhs); + } + + /** + * Expression-Expression division operator. + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + */ + friend SLEIPNIR_DLLEXPORT ExpressionPtr operator/(const ExpressionPtr& lhs, + const ExpressionPtr& rhs) { + using enum ExpressionType; + + // Prune expression + if (lhs->IsConstant(0.0)) { + // Return zero + return lhs; + } else if (rhs->IsConstant(1.0)) { + return lhs; + } + + // Evaluate constant + if (lhs->type == kConstant && rhs->type == kConstant) { + return MakeExpressionPtr(lhs->value / rhs->value); + } + + // Evaluate expression type + ExpressionType type; + if (rhs->type == kConstant) { + type = lhs->type; + } else { + type = kNonlinear; + } + + return MakeExpressionPtr( + type, [](double lhs, double rhs) { return lhs / rhs; }, + [](double, double rhs, double parentAdjoint) { + return parentAdjoint / rhs; + }, + [](double lhs, double rhs, double parentAdjoint) { + return parentAdjoint * -lhs / (rhs * rhs); + }, + [](const ExpressionPtr&, const ExpressionPtr& rhs, + const ExpressionPtr& parentAdjoint) { return parentAdjoint / rhs; }, + [](const ExpressionPtr& lhs, const ExpressionPtr& rhs, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint * -lhs / (rhs * rhs); + }, + lhs, rhs); + } + + /** + * Expression-Expression addition operator. + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + */ + friend SLEIPNIR_DLLEXPORT ExpressionPtr operator+(const ExpressionPtr& lhs, + const ExpressionPtr& rhs) { + using enum ExpressionType; + + // Prune expression + if (lhs == nullptr || lhs->IsConstant(0.0)) { + return rhs; + } else if (rhs == nullptr || rhs->IsConstant(0.0)) { + return lhs; + } + + // Evaluate constant + if (lhs->type == kConstant && rhs->type == kConstant) { + return MakeExpressionPtr(lhs->value + rhs->value); + } + + return MakeExpressionPtr( + std::max(lhs->type, rhs->type), + [](double lhs, double rhs) { return lhs + rhs; }, + [](double, double, double parentAdjoint) { return parentAdjoint; }, + [](double, double, double parentAdjoint) { return parentAdjoint; }, + [](const ExpressionPtr&, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { return parentAdjoint; }, + [](const ExpressionPtr&, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { return parentAdjoint; }, + lhs, rhs); + } + + /** + * Expression-Expression subtraction operator. + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + */ + friend SLEIPNIR_DLLEXPORT ExpressionPtr operator-(const ExpressionPtr& lhs, + const ExpressionPtr& rhs) { + using enum ExpressionType; + + // Prune expression + if (lhs->IsConstant(0.0)) { + if (rhs->IsConstant(0.0)) { + // Return zero + return rhs; + } else { + return -rhs; + } + } else if (rhs->IsConstant(0.0)) { + return lhs; + } + + // Evaluate constant + if (lhs->type == kConstant && rhs->type == kConstant) { + return MakeExpressionPtr(lhs->value - rhs->value); + } + + return MakeExpressionPtr( + std::max(lhs->type, rhs->type), + [](double lhs, double rhs) { return lhs - rhs; }, + [](double, double, double parentAdjoint) { return parentAdjoint; }, + [](double, double, double parentAdjoint) { return -parentAdjoint; }, + [](const ExpressionPtr&, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { return parentAdjoint; }, + [](const ExpressionPtr&, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { return -parentAdjoint; }, + lhs, rhs); + } + + /** + * Unary minus operator. + * + * @param lhs Operand of unary minus. + */ + friend SLEIPNIR_DLLEXPORT ExpressionPtr operator-(const ExpressionPtr& lhs) { + using enum ExpressionType; + + // Prune expression + if (lhs->IsConstant(0.0)) { + // Return zero + return lhs; + } + + // Evaluate constant + if (lhs->type == kConstant) { + return MakeExpressionPtr(-lhs->value); + } + + return MakeExpressionPtr( + lhs->type, [](double lhs, double) { return -lhs; }, + [](double, double, double parentAdjoint) { return -parentAdjoint; }, + [](const ExpressionPtr&, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { return -parentAdjoint; }, + lhs); + } + + /** + * Unary plus operator. + * + * @param lhs Operand of unary plus. + */ + friend SLEIPNIR_DLLEXPORT ExpressionPtr operator+(const ExpressionPtr& lhs) { + return lhs; + } +}; + +SLEIPNIR_DLLEXPORT inline ExpressionPtr exp(const ExpressionPtr& x); +SLEIPNIR_DLLEXPORT inline ExpressionPtr sin(const ExpressionPtr& x); +SLEIPNIR_DLLEXPORT inline ExpressionPtr sinh(const ExpressionPtr& x); +SLEIPNIR_DLLEXPORT inline ExpressionPtr sqrt(const ExpressionPtr& x); + +/** + * Refcount increment for intrusive shared pointer. + * + * @param expr The shared pointer's managed object. + */ +inline constexpr void IntrusiveSharedPtrIncRefCount(Expression* expr) { + ++expr->refCount; +} + +/** + * Refcount decrement for intrusive shared pointer. + * + * @param expr The shared pointer's managed object. + */ +inline constexpr void IntrusiveSharedPtrDecRefCount(Expression* expr) { + // If a deeply nested tree is being deallocated all at once, calling the + // Expression destructor when expr's refcount reaches zero can cause a stack + // overflow. Instead, we iterate over its children to decrement their + // refcounts and deallocate them. + small_vector stack; + stack.emplace_back(expr); + + while (!stack.empty()) { + auto elem = stack.back(); + stack.pop_back(); + + // Decrement the current node's refcount. If it reaches zero, deallocate the + // node and enqueue its children so their refcounts are decremented too. + if (--elem->refCount == 0) { + if (elem->adjointExpr != nullptr) { + stack.emplace_back(elem->adjointExpr.Get()); + } + for (auto&& arg : elem->args) { + if (arg != nullptr) { + stack.emplace_back(arg.Get()); + } + } + + // Not calling the destructor here is safe because it only decrements + // refcounts, which was already done above. + auto alloc = GlobalPoolAllocator(); + std::allocator_traits::deallocate(alloc, elem, + sizeof(Expression)); + } + } +} + +/** + * std::abs() for Expressions. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr abs( // NOLINT + const ExpressionPtr& x) { + using enum ExpressionType; + + // Prune expression + if (x->IsConstant(0.0)) { + // Return zero + return x; + } + + // Evaluate constant + if (x->type == kConstant) { + return MakeExpressionPtr(std::abs(x->value)); + } + + return MakeExpressionPtr( + kNonlinear, [](double x, double) { return std::abs(x); }, + [](double x, double, double parentAdjoint) { + if (x < 0.0) { + return -parentAdjoint; + } else if (x > 0.0) { + return parentAdjoint; + } else { + return 0.0; + } + }, + [](const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { + if (x->value < 0.0) { + return -parentAdjoint; + } else if (x->value > 0.0) { + return parentAdjoint; + } else { + // Return zero + return MakeExpressionPtr(); + } + }, + x); +} + +/** + * std::acos() for Expressions. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr acos( // NOLINT + const ExpressionPtr& x) { + using enum ExpressionType; + + // Prune expression + if (x->IsConstant(0.0)) { + return MakeExpressionPtr(std::numbers::pi / 2.0); + } + + // Evaluate constant + if (x->type == kConstant) { + return MakeExpressionPtr(std::acos(x->value)); + } + + return MakeExpressionPtr( + kNonlinear, [](double x, double) { return std::acos(x); }, + [](double x, double, double parentAdjoint) { + return -parentAdjoint / std::sqrt(1.0 - x * x); + }, + [](const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { + return -parentAdjoint / + sleipnir::detail::sqrt(MakeExpressionPtr(1.0) - x * x); + }, + x); +} + +/** + * std::asin() for Expressions. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr asin( // NOLINT + const ExpressionPtr& x) { + using enum ExpressionType; + + // Prune expression + if (x->IsConstant(0.0)) { + // Return zero + return x; + } + + // Evaluate constant + if (x->type == kConstant) { + return MakeExpressionPtr(std::asin(x->value)); + } + + return MakeExpressionPtr( + kNonlinear, [](double x, double) { return std::asin(x); }, + [](double x, double, double parentAdjoint) { + return parentAdjoint / std::sqrt(1.0 - x * x); + }, + [](const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint / + sleipnir::detail::sqrt(MakeExpressionPtr(1.0) - x * x); + }, + x); +} + +/** + * std::atan() for Expressions. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr atan( // NOLINT + const ExpressionPtr& x) { + using enum ExpressionType; + + // Prune expression + if (x->IsConstant(0.0)) { + // Return zero + return x; + } + + // Evaluate constant + if (x->type == kConstant) { + return MakeExpressionPtr(std::atan(x->value)); + } + + return MakeExpressionPtr( + kNonlinear, [](double x, double) { return std::atan(x); }, + [](double x, double, double parentAdjoint) { + return parentAdjoint / (1.0 + x * x); + }, + [](const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint / (MakeExpressionPtr(1.0) + x * x); + }, + x); +} + +/** + * std::atan2() for Expressions. + * + * @param y The y argument. + * @param x The x argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr atan2( // NOLINT + const ExpressionPtr& y, const ExpressionPtr& x) { + using enum ExpressionType; + + // Prune expression + if (y->IsConstant(0.0)) { + // Return zero + return y; + } else if (x->IsConstant(0.0)) { + return MakeExpressionPtr(std::numbers::pi / 2.0); + } + + // Evaluate constant + if (y->type == kConstant && x->type == kConstant) { + return MakeExpressionPtr(std::atan2(y->value, x->value)); + } + + return MakeExpressionPtr( + kNonlinear, [](double y, double x) { return std::atan2(y, x); }, + [](double y, double x, double parentAdjoint) { + return parentAdjoint * x / (y * y + x * x); + }, + [](double y, double x, double parentAdjoint) { + return parentAdjoint * -y / (y * y + x * x); + }, + [](const ExpressionPtr& y, const ExpressionPtr& x, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint * x / (y * y + x * x); + }, + [](const ExpressionPtr& y, const ExpressionPtr& x, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint * -y / (y * y + x * x); + }, + y, x); +} + +/** + * std::cos() for Expressions. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr cos( // NOLINT + const ExpressionPtr& x) { + using enum ExpressionType; + + // Prune expression + if (x->IsConstant(0.0)) { + return MakeExpressionPtr(1.0); + } + + // Evaluate constant + if (x->type == kConstant) { + return MakeExpressionPtr(std::cos(x->value)); + } + + return MakeExpressionPtr( + kNonlinear, [](double x, double) { return std::cos(x); }, + [](double x, double, double parentAdjoint) { + return -parentAdjoint * std::sin(x); + }, + [](const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint * -sleipnir::detail::sin(x); + }, + x); +} + +/** + * std::cosh() for Expressions. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr cosh( // NOLINT + const ExpressionPtr& x) { + using enum ExpressionType; + + // Prune expression + if (x->IsConstant(0.0)) { + return MakeExpressionPtr(1.0); + } + + // Evaluate constant + if (x->type == kConstant) { + return MakeExpressionPtr(std::cosh(x->value)); + } + + return MakeExpressionPtr( + kNonlinear, [](double x, double) { return std::cosh(x); }, + [](double x, double, double parentAdjoint) { + return parentAdjoint * std::sinh(x); + }, + [](const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint * sleipnir::detail::sinh(x); + }, + x); +} + +/** + * std::erf() for Expressions. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr erf( // NOLINT + const ExpressionPtr& x) { + using enum ExpressionType; + + // Prune expression + if (x->IsConstant(0.0)) { + // Return zero + return x; + } + + // Evaluate constant + if (x->type == kConstant) { + return MakeExpressionPtr(std::erf(x->value)); + } + + return MakeExpressionPtr( + kNonlinear, [](double x, double) { return std::erf(x); }, + [](double x, double, double parentAdjoint) { + return parentAdjoint * 2.0 * std::numbers::inv_sqrtpi * + std::exp(-x * x); + }, + [](const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint * + MakeExpressionPtr(2.0 * std::numbers::inv_sqrtpi) * + sleipnir::detail::exp(-x * x); + }, + x); +} + +/** + * std::exp() for Expressions. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr exp( // NOLINT + const ExpressionPtr& x) { + using enum ExpressionType; + + // Prune expression + if (x->IsConstant(0.0)) { + return MakeExpressionPtr(1.0); + } + + // Evaluate constant + if (x->type == kConstant) { + return MakeExpressionPtr(std::exp(x->value)); + } + + return MakeExpressionPtr( + kNonlinear, [](double x, double) { return std::exp(x); }, + [](double x, double, double parentAdjoint) { + return parentAdjoint * std::exp(x); + }, + [](const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint * sleipnir::detail::exp(x); + }, + x); +} + +/** + * std::hypot() for Expressions. + * + * @param x The x argument. + * @param y The y argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr hypot( // NOLINT + const ExpressionPtr& x, const ExpressionPtr& y) { + using enum ExpressionType; + + // Prune expression + if (x->IsConstant(0.0)) { + return y; + } else if (y->IsConstant(0.0)) { + return x; + } + + // Evaluate constant + if (x->type == kConstant && y->type == kConstant) { + return MakeExpressionPtr(std::hypot(x->value, y->value)); + } + + return MakeExpressionPtr( + kNonlinear, [](double x, double y) { return std::hypot(x, y); }, + [](double x, double y, double parentAdjoint) { + return parentAdjoint * x / std::hypot(x, y); + }, + [](double x, double y, double parentAdjoint) { + return parentAdjoint * y / std::hypot(x, y); + }, + [](const ExpressionPtr& x, const ExpressionPtr& y, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint * x / sleipnir::detail::hypot(x, y); + }, + [](const ExpressionPtr& x, const ExpressionPtr& y, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint * y / sleipnir::detail::hypot(x, y); + }, + x, y); +} + +/** + * std::log() for Expressions. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr log( // NOLINT + const ExpressionPtr& x) { + using enum ExpressionType; + + // Prune expression + if (x->IsConstant(0.0)) { + // Return zero + return x; + } + + // Evaluate constant + if (x->type == kConstant) { + return MakeExpressionPtr(std::log(x->value)); + } + + return MakeExpressionPtr( + kNonlinear, [](double x, double) { return std::log(x); }, + [](double x, double, double parentAdjoint) { return parentAdjoint / x; }, + [](const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { return parentAdjoint / x; }, + x); +} + +/** + * std::log10() for Expressions. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr log10( // NOLINT + const ExpressionPtr& x) { + using enum ExpressionType; + + // Prune expression + if (x->IsConstant(0.0)) { + // Return zero + return x; + } + + // Evaluate constant + if (x->type == kConstant) { + return MakeExpressionPtr(std::log10(x->value)); + } + + return MakeExpressionPtr( + kNonlinear, [](double x, double) { return std::log10(x); }, + [](double x, double, double parentAdjoint) { + return parentAdjoint / (std::numbers::ln10 * x); + }, + [](const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint / (MakeExpressionPtr(std::numbers::ln10) * x); + }, + x); +} + +/** + * std::pow() for Expressions. + * + * @param base The base. + * @param power The power. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr pow( // NOLINT + const ExpressionPtr& base, const ExpressionPtr& power) { + using enum ExpressionType; + + // Prune expression + if (base->IsConstant(0.0)) { + // Return zero + return base; + } else if (base->IsConstant(1.0)) { + return base; + } + if (power->IsConstant(0.0)) { + return MakeExpressionPtr(1.0); + } else if (power->IsConstant(1.0)) { + return base; + } + + // Evaluate constant + if (base->type == kConstant && power->type == kConstant) { + return MakeExpressionPtr(std::pow(base->value, power->value)); + } + + return MakeExpressionPtr( + base->type == kLinear && power->IsConstant(2.0) ? kQuadratic : kNonlinear, + [](double base, double power) { return std::pow(base, power); }, + [](double base, double power, double parentAdjoint) { + return parentAdjoint * std::pow(base, power - 1) * power; + }, + [](double base, double power, double parentAdjoint) { + // Since x * std::log(x) -> 0 as x -> 0 + if (base == 0.0) { + return 0.0; + } else { + return parentAdjoint * std::pow(base, power - 1) * base * + std::log(base); + } + }, + [](const ExpressionPtr& base, const ExpressionPtr& power, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint * + sleipnir::detail::pow(base, power - MakeExpressionPtr(1.0)) * + power; + }, + [](const ExpressionPtr& base, const ExpressionPtr& power, + const ExpressionPtr& parentAdjoint) { + // Since x * std::log(x) -> 0 as x -> 0 + if (base->value == 0.0) { + // Return zero + return base; + } else { + return parentAdjoint * + sleipnir::detail::pow(base, power - MakeExpressionPtr(1.0)) * + base * sleipnir::detail::log(base); + } + }, + base, power); +} + +/** + * sign() for Expressions. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr sign(const ExpressionPtr& x) { + using enum ExpressionType; + + // Evaluate constant + if (x->type == kConstant) { + if (x->value < 0.0) { + return MakeExpressionPtr(-1.0); + } else if (x->value == 0.0) { + // Return zero + return x; + } else { + return MakeExpressionPtr(1.0); + } + } + + return MakeExpressionPtr( + kNonlinear, + [](double x, double) { + if (x < 0.0) { + return -1.0; + } else if (x == 0.0) { + return 0.0; + } else { + return 1.0; + } + }, + [](double, double, double) { return 0.0; }, + [](const ExpressionPtr&, const ExpressionPtr&, const ExpressionPtr&) { + // Return zero + return MakeExpressionPtr(); + }, + x); +} + +/** + * std::sin() for Expressions. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr sin( // NOLINT + const ExpressionPtr& x) { + using enum ExpressionType; + + // Prune expression + if (x->IsConstant(0.0)) { + // Return zero + return x; + } + + // Evaluate constant + if (x->type == kConstant) { + return MakeExpressionPtr(std::sin(x->value)); + } + + return MakeExpressionPtr( + kNonlinear, [](double x, double) { return std::sin(x); }, + [](double x, double, double parentAdjoint) { + return parentAdjoint * std::cos(x); + }, + [](const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint * sleipnir::detail::cos(x); + }, + x); +} + +/** + * std::sinh() for Expressions. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr sinh(const ExpressionPtr& x) { + using enum ExpressionType; + + // Prune expression + if (x->IsConstant(0.0)) { + // Return zero + return x; + } + + // Evaluate constant + if (x->type == kConstant) { + return MakeExpressionPtr(std::sinh(x->value)); + } + + return MakeExpressionPtr( + kNonlinear, [](double x, double) { return std::sinh(x); }, + [](double x, double, double parentAdjoint) { + return parentAdjoint * std::cosh(x); + }, + [](const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint * sleipnir::detail::cosh(x); + }, + x); +} + +/** + * std::sqrt() for Expressions. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr sqrt( // NOLINT + const ExpressionPtr& x) { + using enum ExpressionType; + + // Evaluate constant + if (x->type == kConstant) { + if (x->value == 0.0) { + // Return zero + return x; + } else if (x->value == 1.0) { + return x; + } else { + return MakeExpressionPtr(std::sqrt(x->value)); + } + } + + return MakeExpressionPtr( + kNonlinear, [](double x, double) { return std::sqrt(x); }, + [](double x, double, double parentAdjoint) { + return parentAdjoint / (2.0 * std::sqrt(x)); + }, + [](const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint / + (MakeExpressionPtr(2.0) * sleipnir::detail::sqrt(x)); + }, + x); +} + +/** + * std::tan() for Expressions. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr tan( // NOLINT + const ExpressionPtr& x) { + using enum ExpressionType; + + // Prune expression + if (x->IsConstant(0.0)) { + // Return zero + return x; + } + + // Evaluate constant + if (x->type == kConstant) { + return MakeExpressionPtr(std::tan(x->value)); + } + + return MakeExpressionPtr( + kNonlinear, [](double x, double) { return std::tan(x); }, + [](double x, double, double parentAdjoint) { + return parentAdjoint / (std::cos(x) * std::cos(x)); + }, + [](const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint / + (sleipnir::detail::cos(x) * sleipnir::detail::cos(x)); + }, + x); +} + +/** + * std::tanh() for Expressions. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline ExpressionPtr tanh(const ExpressionPtr& x) { + using enum ExpressionType; + + // Prune expression + if (x->IsConstant(0.0)) { + // Return zero + return x; + } + + // Evaluate constant + if (x->type == kConstant) { + return MakeExpressionPtr(std::tanh(x->value)); + } + + return MakeExpressionPtr( + kNonlinear, [](double x, double) { return std::tanh(x); }, + [](double x, double, double parentAdjoint) { + return parentAdjoint / (std::cosh(x) * std::cosh(x)); + }, + [](const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parentAdjoint) { + return parentAdjoint / + (sleipnir::detail::cosh(x) * sleipnir::detail::cosh(x)); + }, + x); +} + +} // namespace sleipnir::detail diff --git a/sleipnir/src/include/sleipnir/autodiff/ExpressionGraph.hpp b/sleipnir/src/include/sleipnir/autodiff/ExpressionGraph.hpp new file mode 100644 index 0000000..c614195 --- /dev/null +++ b/sleipnir/src/include/sleipnir/autodiff/ExpressionGraph.hpp @@ -0,0 +1,243 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include + +#include "sleipnir/autodiff/Expression.hpp" +#include "sleipnir/util/FunctionRef.hpp" +#include "sleipnir/util/SymbolExports.hpp" +#include "sleipnir/util/small_vector.hpp" + +namespace sleipnir::detail { + +/** + * This class is an adaptor type that performs value updates of an expression's + * computational graph in a way that skips duplicates. + */ +class SLEIPNIR_DLLEXPORT ExpressionGraph { + public: + /** + * Generates the deduplicated computational graph for the given expression. + * + * @param root The root node of the expression. + */ + explicit ExpressionGraph(ExpressionPtr& root) { + // If the root type is a constant, Update() is a no-op, so there's no work + // to do + if (root == nullptr || root->type == ExpressionType::kConstant) { + return; + } + + // Breadth-first search (BFS) is used as opposed to a depth-first search + // (DFS) to avoid counting duplicate nodes multiple times. A list of nodes + // ordered from parent to child with no duplicates is generated. + // + // https://en.wikipedia.org/wiki/Breadth-first_search + + // BFS list sorted from parent to child. + small_vector stack; + + stack.emplace_back(root.Get()); + + // Initialize the number of instances of each node in the tree + // (Expression::duplications) + while (!stack.empty()) { + auto currentNode = stack.back(); + stack.pop_back(); + + for (auto&& arg : currentNode->args) { + // Only continue if the node is not a constant and hasn't already been + // explored. + if (arg != nullptr && arg->type != ExpressionType::kConstant) { + // If this is the first instance of the node encountered (it hasn't + // been explored yet), add it to stack so it's recursed upon + if (arg->duplications == 0) { + stack.push_back(arg.Get()); + } + ++arg->duplications; + } + } + } + + stack.emplace_back(root.Get()); + + while (!stack.empty()) { + auto currentNode = stack.back(); + stack.pop_back(); + + // BFS lists sorted from parent to child. + m_rowList.emplace_back(currentNode->row); + m_adjointList.emplace_back(currentNode); + if (currentNode->valueFunc != nullptr) { + // Constants are skipped because they have no valueFunc and don't need + // to be updated + m_valueList.emplace_back(currentNode); + } + + for (auto&& arg : currentNode->args) { + // Only add node if it's not a constant and doesn't already exist in the + // tape. + if (arg != nullptr && arg->type != ExpressionType::kConstant) { + // Once the number of node visitations equals the number of + // duplications (the counter hits zero), add it to the stack. Note + // that this means the node is only enqueued once. + --arg->duplications; + if (arg->duplications == 0) { + stack.push_back(arg.Get()); + } + } + } + } + } + + /** + * Update the values of all nodes in this computational tree based on the + * values of their dependent nodes. + */ + void Update() { + // Traverse the BFS list backward from child to parent and update the value + // of each node. + for (auto it = m_valueList.rbegin(); it != m_valueList.rend(); ++it) { + auto& node = *it; + + auto& lhs = node->args[0]; + auto& rhs = node->args[1]; + + if (lhs != nullptr) { + if (rhs != nullptr) { + node->value = node->valueFunc(lhs->value, rhs->value); + } else { + node->value = node->valueFunc(lhs->value, 0.0); + } + } + } + } + + /** + * Returns the variable's gradient tree. + * + * @param wrt Variables with respect to which to compute the gradient. + */ + small_vector GenerateGradientTree( + std::span wrt) const { + // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation + // for background on reverse accumulation automatic differentiation. + + for (size_t row = 0; row < wrt.size(); ++row) { + wrt[row]->row = row; + } + + small_vector grad; + grad.reserve(wrt.size()); + for (size_t row = 0; row < wrt.size(); ++row) { + grad.emplace_back(MakeExpressionPtr()); + } + + // Zero adjoints. The root node's adjoint is 1.0 as df/df is always 1. + if (m_adjointList.size() > 0) { + m_adjointList[0]->adjointExpr = MakeExpressionPtr(1.0); + for (auto it = m_adjointList.begin() + 1; it != m_adjointList.end(); + ++it) { + auto& node = *it; + node->adjointExpr = MakeExpressionPtr(); + } + } + + // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y + // multiplied by dy/dx. If there are multiple "paths" from the root node to + // variable; the variable's adjoint is the sum of each path's adjoint + // contribution. + for (auto node : m_adjointList) { + auto& lhs = node->args[0]; + auto& rhs = node->args[1]; + + if (lhs != nullptr && !lhs->IsConstant(0.0)) { + lhs->adjointExpr = lhs->adjointExpr + + node->gradientFuncs[0](lhs, rhs, node->adjointExpr); + } + if (rhs != nullptr && !rhs->IsConstant(0.0)) { + rhs->adjointExpr = rhs->adjointExpr + + node->gradientFuncs[1](lhs, rhs, node->adjointExpr); + } + + // If variable is a leaf node, assign its adjoint to the gradient. + if (node->row != -1) { + grad[node->row] = node->adjointExpr; + } + } + + // Unlink adjoints to avoid circular references between them and their + // parent expressions. This ensures all expressions are returned to the free + // list. + for (auto node : m_adjointList) { + for (auto& arg : node->args) { + if (arg != nullptr) { + arg->adjointExpr = nullptr; + } + } + } + + for (size_t row = 0; row < wrt.size(); ++row) { + wrt[row]->row = -1; + } + + return grad; + } + + /** + * Updates the adjoints in the expression graph, effectively computing the + * gradient. + * + * @param func A function that takes two arguments: an int for the gradient + * row, and a double for the adjoint (gradient value). + */ + void ComputeAdjoints(function_ref func) { + // Zero adjoints. The root node's adjoint is 1.0 as df/df is always 1. + m_adjointList[0]->adjoint = 1.0; + for (auto it = m_adjointList.begin() + 1; it != m_adjointList.end(); ++it) { + auto& node = *it; + node->adjoint = 0.0; + } + + // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y + // multiplied by dy/dx. If there are multiple "paths" from the root node to + // variable; the variable's adjoint is the sum of each path's adjoint + // contribution. + for (size_t col = 0; col < m_adjointList.size(); ++col) { + auto& node = m_adjointList[col]; + auto& lhs = node->args[0]; + auto& rhs = node->args[1]; + + if (lhs != nullptr) { + if (rhs != nullptr) { + lhs->adjoint += node->gradientValueFuncs[0](lhs->value, rhs->value, + node->adjoint); + rhs->adjoint += node->gradientValueFuncs[1](lhs->value, rhs->value, + node->adjoint); + } else { + lhs->adjoint += + node->gradientValueFuncs[0](lhs->value, 0.0, node->adjoint); + } + } + + // If variable is a leaf node, assign its adjoint to the gradient. + int row = m_rowList[col]; + if (row != -1) { + func(row, node->adjoint); + } + } + } + + private: + // List that maps nodes to their respective row. + small_vector m_rowList; + + // List for updating adjoints + small_vector m_adjointList; + + // List for updating values + small_vector m_valueList; +}; + +} // namespace sleipnir::detail diff --git a/sleipnir/src/include/sleipnir/autodiff/ExpressionType.hpp b/sleipnir/src/include/sleipnir/autodiff/ExpressionType.hpp new file mode 100644 index 0000000..37825e3 --- /dev/null +++ b/sleipnir/src/include/sleipnir/autodiff/ExpressionType.hpp @@ -0,0 +1,27 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include + +namespace sleipnir { + +/** + * Expression type. + * + * Used for autodiff caching. + */ +enum class ExpressionType : uint8_t { + /// There is no expression. + kNone, + /// The expression is a constant. + kConstant, + /// The expression is composed of linear and lower-order operators. + kLinear, + /// The expression is composed of quadratic and lower-order operators. + kQuadratic, + /// The expression is composed of nonlinear and lower-order operators. + kNonlinear +}; + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/autodiff/Gradient.hpp b/sleipnir/src/include/sleipnir/autodiff/Gradient.hpp new file mode 100644 index 0000000..cf6a417 --- /dev/null +++ b/sleipnir/src/include/sleipnir/autodiff/Gradient.hpp @@ -0,0 +1,73 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include + +#include + +#include "sleipnir/autodiff/Jacobian.hpp" +#include "sleipnir/autodiff/Profiler.hpp" +#include "sleipnir/autodiff/Variable.hpp" +#include "sleipnir/autodiff/VariableMatrix.hpp" +#include "sleipnir/util/SymbolExports.hpp" + +namespace sleipnir { + +/** + * This class calculates the gradient of a a variable with respect to a vector + * of variables. + * + * The gradient is only recomputed if the variable expression is quadratic or + * higher order. + */ +class SLEIPNIR_DLLEXPORT Gradient { + public: + /** + * Constructs a Gradient object. + * + * @param variable Variable of which to compute the gradient. + * @param wrt Variable with respect to which to compute the gradient. + */ + Gradient(Variable variable, Variable wrt) noexcept + : Gradient{std::move(variable), VariableMatrix{wrt}} {} + + /** + * Constructs a Gradient object. + * + * @param variable Variable of which to compute the gradient. + * @param wrt Vector of variables with respect to which to compute the + * gradient. + */ + Gradient(Variable variable, const VariableMatrix& wrt) noexcept + : m_jacobian{variable, wrt} {} + + /** + * Returns the gradient as a VariableMatrix. + * + * This is useful when constructing optimization problems with derivatives in + * them. + */ + VariableMatrix Get() const { return m_jacobian.Get().T(); } + + /** + * Evaluates the gradient at wrt's value. + */ + const Eigen::SparseVector& Value() { + m_g = m_jacobian.Value(); + + return m_g; + } + + /** + * Returns the profiler. + */ + Profiler& GetProfiler() { return m_jacobian.GetProfiler(); } + + private: + Eigen::SparseVector m_g; + + Jacobian m_jacobian; +}; + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/autodiff/Hessian.hpp b/sleipnir/src/include/sleipnir/autodiff/Hessian.hpp new file mode 100644 index 0000000..4563aa2 --- /dev/null +++ b/sleipnir/src/include/sleipnir/autodiff/Hessian.hpp @@ -0,0 +1,79 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include + +#include +#include + +#include "sleipnir/autodiff/ExpressionGraph.hpp" +#include "sleipnir/autodiff/Jacobian.hpp" +#include "sleipnir/autodiff/Profiler.hpp" +#include "sleipnir/autodiff/Variable.hpp" +#include "sleipnir/autodiff/VariableMatrix.hpp" +#include "sleipnir/util/SymbolExports.hpp" +#include "sleipnir/util/small_vector.hpp" + +namespace sleipnir { + +/** + * This class calculates the Hessian of a variable with respect to a vector of + * variables. + * + * The gradient tree is cached so subsequent Hessian calculations are faster, + * and the Hessian is only recomputed if the variable expression is nonlinear. + */ +class SLEIPNIR_DLLEXPORT Hessian { + public: + /** + * Constructs a Hessian object. + * + * @param variable Variable of which to compute the Hessian. + * @param wrt Vector of variables with respect to which to compute the + * Hessian. + */ + Hessian(Variable variable, const VariableMatrix& wrt) noexcept + : m_jacobian{ + [&] { + small_vector wrtVec; + wrtVec.reserve(wrt.size()); + for (auto& elem : wrt) { + wrtVec.emplace_back(elem.expr); + } + + auto grad = + detail::ExpressionGraph{variable.expr}.GenerateGradientTree( + wrtVec); + + VariableMatrix ret{wrt.Rows()}; + for (int row = 0; row < ret.Rows(); ++row) { + ret(row) = Variable{std::move(grad[row])}; + } + return ret; + }(), + wrt} {} + + /** + * Returns the Hessian as a VariableMatrix. + * + * This is useful when constructing optimization problems with derivatives in + * them. + */ + VariableMatrix Get() const { return m_jacobian.Get(); } + + /** + * Evaluates the Hessian at wrt's value. + */ + const Eigen::SparseMatrix& Value() { return m_jacobian.Value(); } + + /** + * Returns the profiler. + */ + Profiler& GetProfiler() { return m_jacobian.GetProfiler(); } + + private: + Jacobian m_jacobian; +}; + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/autodiff/Jacobian.hpp b/sleipnir/src/include/sleipnir/autodiff/Jacobian.hpp new file mode 100644 index 0000000..ac00c11 --- /dev/null +++ b/sleipnir/src/include/sleipnir/autodiff/Jacobian.hpp @@ -0,0 +1,155 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include + +#include + +#include "sleipnir/autodiff/ExpressionGraph.hpp" +#include "sleipnir/autodiff/Profiler.hpp" +#include "sleipnir/autodiff/Variable.hpp" +#include "sleipnir/autodiff/VariableMatrix.hpp" +#include "sleipnir/util/SymbolExports.hpp" +#include "sleipnir/util/small_vector.hpp" + +namespace sleipnir { + +/** + * This class calculates the Jacobian of a vector of variables with respect to a + * vector of variables. + * + * The Jacobian is only recomputed if the variable expression is quadratic or + * higher order. + */ +class SLEIPNIR_DLLEXPORT Jacobian { + public: + /** + * Constructs a Jacobian object. + * + * @param variables Vector of variables of which to compute the Jacobian. + * @param wrt Vector of variables with respect to which to compute the + * Jacobian. + */ + Jacobian(const VariableMatrix& variables, const VariableMatrix& wrt) noexcept + : m_variables{std::move(variables)}, m_wrt{std::move(wrt)} { + m_profiler.StartSetup(); + + for (int row = 0; row < m_wrt.Rows(); ++row) { + m_wrt(row).expr->row = row; + } + + for (Variable variable : m_variables) { + m_graphs.emplace_back(variable.expr); + } + + // Reserve triplet space for 99% sparsity + m_cachedTriplets.reserve(m_variables.Rows() * m_wrt.Rows() * 0.01); + + for (int row = 0; row < m_variables.Rows(); ++row) { + if (m_variables(row).Type() == ExpressionType::kLinear) { + // If the row is linear, compute its gradient once here and cache its + // triplets. Constant rows are ignored because their gradients have no + // nonzero triplets. + m_graphs[row].ComputeAdjoints([&](int col, double adjoint) { + m_cachedTriplets.emplace_back(row, col, adjoint); + }); + } else if (m_variables(row).Type() > ExpressionType::kLinear) { + // If the row is quadratic or nonlinear, add it to the list of nonlinear + // rows to be recomputed in Value(). + m_nonlinearRows.emplace_back(row); + } + } + + for (int row = 0; row < m_wrt.Rows(); ++row) { + m_wrt(row).expr->row = -1; + } + + if (m_nonlinearRows.empty()) { + m_J.setFromTriplets(m_cachedTriplets.begin(), m_cachedTriplets.end()); + } + + m_profiler.StopSetup(); + } + + /** + * Returns the Jacobian as a VariableMatrix. + * + * This is useful when constructing optimization problems with derivatives in + * them. + */ + VariableMatrix Get() const { + VariableMatrix result{m_variables.Rows(), m_wrt.Rows()}; + + small_vector wrtVec; + wrtVec.reserve(m_wrt.size()); + for (auto& elem : m_wrt) { + wrtVec.emplace_back(elem.expr); + } + + for (int row = 0; row < m_variables.Rows(); ++row) { + auto grad = m_graphs[row].GenerateGradientTree(wrtVec); + for (int col = 0; col < m_wrt.Rows(); ++col) { + result(row, col) = Variable{std::move(grad[col])}; + } + } + + return result; + } + + /** + * Evaluates the Jacobian at wrt's value. + */ + const Eigen::SparseMatrix& Value() { + if (m_nonlinearRows.empty()) { + return m_J; + } + + m_profiler.StartSolve(); + + for (auto& graph : m_graphs) { + graph.Update(); + } + + // Copy the cached triplets so triplets added for the nonlinear rows are + // thrown away at the end of the function + auto triplets = m_cachedTriplets; + + // Compute each nonlinear row of the Jacobian + for (int row : m_nonlinearRows) { + m_graphs[row].ComputeAdjoints([&](int col, double adjoint) { + triplets.emplace_back(row, col, adjoint); + }); + } + + m_J.setFromTriplets(triplets.begin(), triplets.end()); + + m_profiler.StopSolve(); + + return m_J; + } + + /** + * Returns the profiler. + */ + Profiler& GetProfiler() { return m_profiler; } + + private: + VariableMatrix m_variables; + VariableMatrix m_wrt; + + small_vector m_graphs; + + Eigen::SparseMatrix m_J{m_variables.Rows(), m_wrt.Rows()}; + + // Cached triplets for gradients of linear rows + small_vector> m_cachedTriplets; + + // List of row indices for nonlinear rows whose graients will be computed in + // Value() + small_vector m_nonlinearRows; + + Profiler m_profiler; +}; + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/autodiff/Profiler.hpp b/sleipnir/src/include/sleipnir/autodiff/Profiler.hpp new file mode 100644 index 0000000..5c22d14 --- /dev/null +++ b/sleipnir/src/include/sleipnir/autodiff/Profiler.hpp @@ -0,0 +1,79 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include + +#include "sleipnir/util/SymbolExports.hpp" + +namespace sleipnir { + +/** + * Records the number of profiler measurements (start/stop pairs) and the + * average duration between each start and stop call. + */ +class SLEIPNIR_DLLEXPORT Profiler { + public: + /** + * Tell the profiler to start measuring setup time. + */ + void StartSetup() { m_setupStartTime = std::chrono::system_clock::now(); } + + /** + * Tell the profiler to stop measuring setup time. + */ + void StopSetup() { + m_setupDuration = std::chrono::system_clock::now() - m_setupStartTime; + } + + /** + * Tell the profiler to start measuring solve time. + */ + void StartSolve() { m_solveStartTime = std::chrono::system_clock::now(); } + + /** + * Tell the profiler to stop measuring solve time, increment the number of + * averages, and incorporate the latest measurement into the average. + */ + void StopSolve() { + auto now = std::chrono::system_clock::now(); + ++m_solveMeasurements; + m_averageSolveDuration = + (m_solveMeasurements - 1.0) / m_solveMeasurements * + m_averageSolveDuration + + 1.0 / m_solveMeasurements * (now - m_solveStartTime); + } + + /** + * The setup duration in milliseconds as a double. + */ + double SetupDuration() const { + using std::chrono::duration_cast; + using std::chrono::nanoseconds; + return duration_cast(m_setupDuration).count() / 1e6; + } + + /** + * The number of solve measurements taken. + */ + int SolveMeasurements() const { return m_solveMeasurements; } + + /** + * The average solve duration in milliseconds as a double. + */ + double AverageSolveDuration() const { + using std::chrono::duration_cast; + using std::chrono::nanoseconds; + return duration_cast(m_averageSolveDuration).count() / 1e6; + } + + private: + std::chrono::system_clock::time_point m_setupStartTime; + std::chrono::duration m_setupDuration{0.0}; + + int m_solveMeasurements = 0; + std::chrono::duration m_averageSolveDuration{0.0}; + std::chrono::system_clock::time_point m_solveStartTime; +}; + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/autodiff/Variable.hpp b/sleipnir/src/include/sleipnir/autodiff/Variable.hpp new file mode 100644 index 0000000..b176d91 --- /dev/null +++ b/sleipnir/src/include/sleipnir/autodiff/Variable.hpp @@ -0,0 +1,422 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include + +#include "sleipnir/autodiff/Expression.hpp" +#include "sleipnir/autodiff/ExpressionGraph.hpp" +#include "sleipnir/util/Print.hpp" +#include "sleipnir/util/SymbolExports.hpp" + +namespace sleipnir { + +// Forward declarations for friend declarations in Variable +class SLEIPNIR_DLLEXPORT Hessian; +class SLEIPNIR_DLLEXPORT Jacobian; + +/** + * An autodiff variable pointing to an expression node. + */ +class SLEIPNIR_DLLEXPORT Variable { + public: + /** + * Constructs a linear Variable with a value of zero. + */ + Variable() = default; + + /** + * Constructs a Variable from a double. + * + * @param value The value of the Variable. + */ + Variable(double value) : expr{detail::MakeExpressionPtr(value)} {} // NOLINT + + /** + * Constructs a Variable pointing to the specified expression. + * + * @param expr The autodiff variable. + */ + explicit Variable(const detail::ExpressionPtr& expr) : expr{expr} {} + + /** + * Constructs a Variable pointing to the specified expression. + * + * @param expr The autodiff variable. + */ + explicit Variable(detail::ExpressionPtr&& expr) : expr{std::move(expr)} {} + + /** + * Assignment operator for double. + * + * @param value The value of the Variable. + */ + Variable& operator=(double value) { + expr = detail::MakeExpressionPtr(value); + + return *this; + } + + /** + * Sets Variable's internal value. + * + * @param value The value of the Variable. + */ + void SetValue(double value) { + if (expr->IsConstant(0.0)) { + expr = detail::MakeExpressionPtr(value); + } else { + // We only need to check the first argument since unary and binary + // operators both use it + if (expr->args[0] != nullptr && !expr->args[0]->IsConstant(0.0)) { + sleipnir::println( + stderr, + "WARNING: {}:{}: Modified the value of a dependent variable", + __FILE__, __LINE__); + } + expr->value = value; + } + } + + /** + * Variable-Variable multiplication operator. + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + */ + friend SLEIPNIR_DLLEXPORT Variable operator*(const Variable& lhs, + const Variable& rhs) { + return Variable{lhs.expr * rhs.expr}; + } + + /** + * Variable-Variable compound multiplication operator. + * + * @param rhs Operator right-hand side. + */ + Variable& operator*=(const Variable& rhs) { + *this = *this * rhs; + return *this; + } + + /** + * Variable-Variable division operator. + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + */ + friend SLEIPNIR_DLLEXPORT Variable operator/(const Variable& lhs, + const Variable& rhs) { + return Variable{lhs.expr / rhs.expr}; + } + + /** + * Variable-Variable compound division operator. + * + * @param rhs Operator right-hand side. + */ + Variable& operator/=(const Variable& rhs) { + *this = *this / rhs; + return *this; + } + + /** + * Variable-Variable addition operator. + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + */ + friend SLEIPNIR_DLLEXPORT Variable operator+(const Variable& lhs, + const Variable& rhs) { + return Variable{lhs.expr + rhs.expr}; + } + + /** + * Variable-Variable compound addition operator. + * + * @param rhs Operator right-hand side. + */ + Variable& operator+=(const Variable& rhs) { + *this = *this + rhs; + return *this; + } + + /** + * Variable-Variable subtraction operator. + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + */ + friend SLEIPNIR_DLLEXPORT Variable operator-(const Variable& lhs, + const Variable& rhs) { + return Variable{lhs.expr - rhs.expr}; + } + + /** + * Variable-Variable compound subtraction operator. + * + * @param rhs Operator right-hand side. + */ + Variable& operator-=(const Variable& rhs) { + *this = *this - rhs; + return *this; + } + + /** + * Unary minus operator. + * + * @param lhs Operand for unary minus. + */ + friend SLEIPNIR_DLLEXPORT Variable operator-(const Variable& lhs) { + return Variable{-lhs.expr}; + } + + /** + * Unary plus operator. + * + * @param lhs Operand for unary plus. + */ + friend SLEIPNIR_DLLEXPORT Variable operator+(const Variable& lhs) { + return Variable{+lhs.expr}; + } + + /** + * Returns the value of this variable. + */ + double Value() { + // Updates the value of this variable based on the values of its dependent + // variables + detail::ExpressionGraph{expr}.Update(); + + return expr->value; + } + + /** + * Returns the type of this expression (constant, linear, quadratic, or + * nonlinear). + */ + ExpressionType Type() const { return expr->type; } + + private: + /// The expression node. + detail::ExpressionPtr expr = + detail::MakeExpressionPtr(0.0, ExpressionType::kLinear); + + friend SLEIPNIR_DLLEXPORT Variable abs(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable acos(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable asin(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable atan(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable atan2(const Variable& y, + const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable cos(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable cosh(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable erf(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable exp(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable hypot(const Variable& x, + const Variable& y); + friend SLEIPNIR_DLLEXPORT Variable log(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable log10(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable pow(const Variable& base, + const Variable& power); + friend SLEIPNIR_DLLEXPORT Variable sign(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable sin(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable sinh(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable sqrt(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable tan(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable tanh(const Variable& x); + friend SLEIPNIR_DLLEXPORT Variable hypot(const Variable& x, const Variable& y, + const Variable& z); + + friend class SLEIPNIR_DLLEXPORT Hessian; + friend class SLEIPNIR_DLLEXPORT Jacobian; +}; + +/** + * std::abs() for Variables. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline Variable abs(const Variable& x) { + return Variable{detail::abs(x.expr)}; +} + +/** + * std::acos() for Variables. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline Variable acos(const Variable& x) { + return Variable{detail::acos(x.expr)}; +} + +/** + * std::asin() for Variables. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline Variable asin(const Variable& x) { + return Variable{detail::asin(x.expr)}; +} + +/** + * std::atan() for Variables. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline Variable atan(const Variable& x) { + return Variable{detail::atan(x.expr)}; +} + +/** + * std::atan2() for Variables. + * + * @param y The y argument. + * @param x The x argument. + */ +SLEIPNIR_DLLEXPORT inline Variable atan2(const Variable& y, const Variable& x) { + return Variable{detail::atan2(y.expr, x.expr)}; +} + +/** + * std::cos() for Variables. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline Variable cos(const Variable& x) { + return Variable{detail::cos(x.expr)}; +} + +/** + * std::cosh() for Variables. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline Variable cosh(const Variable& x) { + return Variable{detail::cosh(x.expr)}; +} + +/** + * std::erf() for Variables. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline Variable erf(const Variable& x) { + return Variable{detail::erf(x.expr)}; +} + +/** + * std::exp() for Variables. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline Variable exp(const Variable& x) { + return Variable{detail::exp(x.expr)}; +} + +/** + * std::hypot() for Variables. + * + * @param x The x argument. + * @param y The y argument. + */ +SLEIPNIR_DLLEXPORT inline Variable hypot(const Variable& x, const Variable& y) { + return Variable{detail::hypot(x.expr, y.expr)}; +} + +/** + * std::pow() for Variables. + * + * @param base The base. + * @param power The power. + */ +SLEIPNIR_DLLEXPORT inline Variable pow(const Variable& base, + const Variable& power) { + return Variable{detail::pow(base.expr, power.expr)}; +} + +/** + * std::log() for Variables. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline Variable log(const Variable& x) { + return Variable{detail::log(x.expr)}; +} + +/** + * std::log10() for Variables. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline Variable log10(const Variable& x) { + return Variable{detail::log10(x.expr)}; +} + +/** + * sign() for Variables. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline Variable sign(const Variable& x) { + return Variable{detail::sign(x.expr)}; +} + +/** + * std::sin() for Variables. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline Variable sin(const Variable& x) { + return Variable{detail::sin(x.expr)}; +} + +/** + * std::sinh() for Variables. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline Variable sinh(const Variable& x) { + return Variable{detail::sinh(x.expr)}; +} + +/** + * std::sqrt() for Variables. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline Variable sqrt(const Variable& x) { + return Variable{detail::sqrt(x.expr)}; +} + +/** + * std::tan() for Variables. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline Variable tan(const Variable& x) { + return Variable{detail::tan(x.expr)}; +} + +/** + * std::tanh() for Variables. + * + * @param x The argument. + */ +SLEIPNIR_DLLEXPORT inline Variable tanh(const Variable& x) { + return Variable{detail::tanh(x.expr)}; +} + +/** + * std::hypot() for Variables. + * + * @param x The x argument. + * @param y The y argument. + * @param z The z argument. + */ +SLEIPNIR_DLLEXPORT inline Variable hypot(const Variable& x, const Variable& y, + const Variable& z) { + return Variable{sleipnir::sqrt(sleipnir::pow(x, 2) + sleipnir::pow(y, 2) + + sleipnir::pow(z, 2))}; +} + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/autodiff/VariableBlock.hpp b/sleipnir/src/include/sleipnir/autodiff/VariableBlock.hpp new file mode 100644 index 0000000..03a1eb6 --- /dev/null +++ b/sleipnir/src/include/sleipnir/autodiff/VariableBlock.hpp @@ -0,0 +1,627 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include +#include +#include + +#include + +#include "sleipnir/autodiff/Variable.hpp" +#include "sleipnir/util/Assert.hpp" +#include "sleipnir/util/FunctionRef.hpp" + +namespace sleipnir { + +/** + * A submatrix of autodiff variables with reference semantics. + * + * @tparam Mat The type of the matrix whose storage this class points to. + */ +template +class VariableBlock { + public: + VariableBlock(const VariableBlock& values) = default; + + /** + * Assigns a VariableBlock to the block. + * + * @param values VariableBlock of values. + */ + VariableBlock& operator=(const VariableBlock& values) { + if (this == &values) { + return *this; + } + + if (m_mat == nullptr) { + m_mat = values.m_mat; + m_rowOffset = values.m_rowOffset; + m_colOffset = values.m_colOffset; + m_blockRows = values.m_blockRows; + m_blockCols = values.m_blockCols; + } else { + Assert(Rows() == values.Rows()); + Assert(Cols() == values.Cols()); + + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + (*this)(row, col) = values(row, col); + } + } + } + + return *this; + } + + VariableBlock(VariableBlock&&) = default; + + /** + * Assigns a VariableBlock to the block. + * + * @param values VariableBlock of values. + */ + VariableBlock& operator=(VariableBlock&& values) { + if (this == &values) { + return *this; + } + + if (m_mat == nullptr) { + m_mat = values.m_mat; + m_rowOffset = values.m_rowOffset; + m_colOffset = values.m_colOffset; + m_blockRows = values.m_blockRows; + m_blockCols = values.m_blockCols; + } else { + Assert(Rows() == values.Rows()); + Assert(Cols() == values.Cols()); + + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + (*this)(row, col) = values(row, col); + } + } + } + + return *this; + } + + /** + * Constructs a Variable block pointing to all of the given matrix. + * + * @param mat The matrix to which to point. + */ + VariableBlock(Mat& mat) // NOLINT + : m_mat{&mat}, m_blockRows{mat.Rows()}, m_blockCols{mat.Cols()} {} + + /** + * Constructs a Variable block pointing to a subset of the given matrix. + * + * @param mat The matrix to which to point. + * @param rowOffset The block's row offset. + * @param colOffset The block's column offset. + * @param blockRows The number of rows in the block. + * @param blockCols The number of columns in the block. + */ + VariableBlock(Mat& mat, int rowOffset, int colOffset, int blockRows, + int blockCols) + : m_mat{&mat}, + m_rowOffset{rowOffset}, + m_colOffset{colOffset}, + m_blockRows{blockRows}, + m_blockCols{blockCols} {} + + /** + * Assigns a double to the block. + * + * This only works for blocks with one row and one column. + */ + VariableBlock& operator=(double value) { + Assert(Rows() == 1 && Cols() == 1); + + (*this)(0, 0) = value; + + return *this; + } + + /** + * Assigns a double to the block. + * + * This only works for blocks with one row and one column. + * + * @param value Value to assign. + */ + void SetValue(double value) { + Assert(Rows() == 1 && Cols() == 1); + + (*this)(0, 0).SetValue(value); + } + + /** + * Assigns an Eigen matrix to the block. + * + * @param values Eigen matrix of values to assign. + */ + template + VariableBlock& operator=(const Eigen::MatrixBase& values) { + Assert(Rows() == values.rows()); + Assert(Cols() == values.cols()); + + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + (*this)(row, col) = values(row, col); + } + } + + return *this; + } + + /** + * Sets block's internal values. + * + * @param values Eigen matrix of values. + */ + template + requires std::same_as + void SetValue(const Eigen::MatrixBase& values) { + Assert(Rows() == values.rows()); + Assert(Cols() == values.cols()); + + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + (*this)(row, col).SetValue(values(row, col)); + } + } + } + + /** + * Assigns a VariableMatrix to the block. + * + * @param values VariableMatrix of values. + */ + VariableBlock& operator=(const Mat& values) { + Assert(Rows() == values.Rows()); + Assert(Cols() == values.Cols()); + + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + (*this)(row, col) = values(row, col); + } + } + return *this; + } + + /** + * Assigns a VariableMatrix to the block. + * + * @param values VariableMatrix of values. + */ + VariableBlock& operator=(Mat&& values) { + Assert(Rows() == values.Rows()); + Assert(Cols() == values.Cols()); + + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + (*this)(row, col) = std::move(values(row, col)); + } + } + return *this; + } + + /** + * Returns a scalar subblock at the given row and column. + * + * @param row The scalar subblock's row. + * @param col The scalar subblock's column. + */ + Variable& operator()(int row, int col) + requires(!std::is_const_v) + { + Assert(row >= 0 && row < Rows()); + Assert(col >= 0 && col < Cols()); + return (*m_mat)(m_rowOffset + row, m_colOffset + col); + } + + /** + * Returns a scalar subblock at the given row and column. + * + * @param row The scalar subblock's row. + * @param col The scalar subblock's column. + */ + const Variable& operator()(int row, int col) const { + Assert(row >= 0 && row < Rows()); + Assert(col >= 0 && col < Cols()); + return (*m_mat)(m_rowOffset + row, m_colOffset + col); + } + + /** + * Returns a scalar subblock at the given row. + * + * @param row The scalar subblock's row. + */ + Variable& operator()(int row) + requires(!std::is_const_v) + { + Assert(row >= 0 && row < Rows() * Cols()); + return (*this)(row / Cols(), row % Cols()); + } + + /** + * Returns a scalar subblock at the given row. + * + * @param row The scalar subblock's row. + */ + const Variable& operator()(int row) const { + Assert(row >= 0 && row < Rows() * Cols()); + return (*this)(row / Cols(), row % Cols()); + } + + /** + * Returns a block slice of the variable matrix. + * + * @param rowOffset The row offset of the block selection. + * @param colOffset The column offset of the block selection. + * @param blockRows The number of rows in the block selection. + * @param blockCols The number of columns in the block selection. + */ + VariableBlock Block(int rowOffset, int colOffset, int blockRows, + int blockCols) { + Assert(rowOffset >= 0 && rowOffset <= Rows()); + Assert(colOffset >= 0 && colOffset <= Cols()); + Assert(blockRows >= 0 && blockRows <= Rows() - rowOffset); + Assert(blockCols >= 0 && blockCols <= Cols() - colOffset); + return VariableBlock{*m_mat, m_rowOffset + rowOffset, + m_colOffset + colOffset, blockRows, blockCols}; + } + + /** + * Returns a block slice of the variable matrix. + * + * @param rowOffset The row offset of the block selection. + * @param colOffset The column offset of the block selection. + * @param blockRows The number of rows in the block selection. + * @param blockCols The number of columns in the block selection. + */ + const VariableBlock Block(int rowOffset, int colOffset, + int blockRows, int blockCols) const { + Assert(rowOffset >= 0 && rowOffset <= Rows()); + Assert(colOffset >= 0 && colOffset <= Cols()); + Assert(blockRows >= 0 && blockRows <= Rows() - rowOffset); + Assert(blockCols >= 0 && blockCols <= Cols() - colOffset); + return VariableBlock{*m_mat, m_rowOffset + rowOffset, + m_colOffset + colOffset, blockRows, blockCols}; + } + + /** + * Returns a row slice of the variable matrix. + * + * @param row The row to slice. + */ + VariableBlock Row(int row) { + Assert(row >= 0 && row < Rows()); + return Block(row, 0, 1, Cols()); + } + + /** + * Returns a row slice of the variable matrix. + * + * @param row The row to slice. + */ + VariableBlock Row(int row) const { + Assert(row >= 0 && row < Rows()); + return Block(row, 0, 1, Cols()); + } + + /** + * Returns a column slice of the variable matrix. + * + * @param col The column to slice. + */ + VariableBlock Col(int col) { + Assert(col >= 0 && col < Cols()); + return Block(0, col, Rows(), 1); + } + + /** + * Returns a column slice of the variable matrix. + * + * @param col The column to slice. + */ + VariableBlock Col(int col) const { + Assert(col >= 0 && col < Cols()); + return Block(0, col, Rows(), 1); + } + + /** + * Compound matrix multiplication-assignment operator. + * + * @param rhs Variable to multiply. + */ + VariableBlock& operator*=(const VariableBlock& rhs) { + Assert(Cols() == rhs.Rows() && Cols() == rhs.Cols()); + + for (int i = 0; i < Rows(); ++i) { + for (int j = 0; j < rhs.Cols(); ++j) { + Variable sum; + for (int k = 0; k < Cols(); ++k) { + sum += (*this)(i, k) * rhs(k, j); + } + (*this)(i, j) = sum; + } + } + + return *this; + } + + /** + * Compound matrix multiplication-assignment operator (only enabled when lhs + * is a scalar). + * + * @param rhs Variable to multiply. + */ + VariableBlock& operator*=(double rhs) { + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + (*this)(row, col) *= rhs; + } + } + + return *this; + } + + /** + * Compound matrix division-assignment operator (only enabled when rhs + * is a scalar). + * + * @param rhs Variable to divide. + */ + VariableBlock& operator/=(const VariableBlock& rhs) { + Assert(rhs.Rows() == 1 && rhs.Cols() == 1); + + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + (*this)(row, col) /= rhs(0, 0); + } + } + + return *this; + } + + /** + * Compound matrix division-assignment operator (only enabled when rhs + * is a scalar). + * + * @param rhs Variable to divide. + */ + VariableBlock& operator/=(double rhs) { + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + (*this)(row, col) /= rhs; + } + } + + return *this; + } + + /** + * Compound addition-assignment operator. + * + * @param rhs Variable to add. + */ + VariableBlock& operator+=(const VariableBlock& rhs) { + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + (*this)(row, col) += rhs(row, col); + } + } + + return *this; + } + + /** + * Compound subtraction-assignment operator. + * + * @param rhs Variable to subtract. + */ + VariableBlock& operator-=(const VariableBlock& rhs) { + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + (*this)(row, col) -= rhs(row, col); + } + } + + return *this; + } + + /** + * Returns the transpose of the variable matrix. + */ + std::remove_cv_t T() const { + std::remove_cv_t result{Cols(), Rows()}; + + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + result(col, row) = (*this)(row, col); + } + } + + return result; + } + + /** + * Returns number of rows in the matrix. + */ + int Rows() const { return m_blockRows; } + + /** + * Returns number of columns in the matrix. + */ + int Cols() const { return m_blockCols; } + + /** + * Returns an element of the variable matrix. + * + * @param row The row of the element to return. + * @param col The column of the element to return. + */ + double Value(int row, int col) { + Assert(row >= 0 && row < Rows()); + Assert(col >= 0 && col < Cols()); + return (*m_mat)(m_rowOffset + row, m_colOffset + col).Value(); + } + + /** + * Returns a row of the variable column vector. + * + * @param index The index of the element to return. + */ + double Value(int index) { + Assert(index >= 0 && index < Rows() * Cols()); + return (*m_mat)(m_rowOffset + index / m_blockCols, + m_colOffset + index % m_blockCols) + .Value(); + } + + /** + * Returns the contents of the variable matrix. + */ + Eigen::MatrixXd Value() { + Eigen::MatrixXd result{Rows(), Cols()}; + + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + result(row, col) = Value(row, col); + } + } + + return result; + } + + /** + * Transforms the matrix coefficient-wise with an unary operator. + * + * @param unaryOp The unary operator to use for the transform operation. + */ + std::remove_cv_t CwiseTransform( + function_ref unaryOp) const { + std::remove_cv_t result{Rows(), Cols()}; + + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + result(row, col) = unaryOp((*this)(row, col)); + } + } + + return result; + } + + class iterator { + public: + using iterator_category = std::forward_iterator_tag; + using value_type = Variable; + using difference_type = std::ptrdiff_t; + using pointer = Variable*; + using reference = Variable&; + + iterator(VariableBlock* mat, int row, int col) + : m_mat{mat}, m_row{row}, m_col{col} {} + + iterator& operator++() { + ++m_col; + if (m_col == m_mat->Cols()) { + m_col = 0; + ++m_row; + } + return *this; + } + iterator operator++(int) { + iterator retval = *this; + ++(*this); + return retval; + } + bool operator==(const iterator&) const = default; + reference operator*() { return (*m_mat)(m_row, m_col); } + + private: + VariableBlock* m_mat; + int m_row; + int m_col; + }; + + class const_iterator { + public: + using iterator_category = std::forward_iterator_tag; + using value_type = Variable; + using difference_type = std::ptrdiff_t; + using pointer = Variable*; + using const_reference = const Variable&; + + const_iterator(const VariableBlock* mat, int row, int col) + : m_mat{mat}, m_row{row}, m_col{col} {} + + const_iterator& operator++() { + ++m_col; + if (m_col == m_mat->Cols()) { + m_col = 0; + ++m_row; + } + return *this; + } + const_iterator operator++(int) { + const_iterator retval = *this; + ++(*this); + return retval; + } + bool operator==(const const_iterator&) const = default; + const_reference operator*() const { return (*m_mat)(m_row, m_col); } + + private: + const VariableBlock* m_mat; + int m_row; + int m_col; + }; + + /** + * Returns begin iterator. + */ + iterator begin() { return iterator(this, 0, 0); } + + /** + * Returns end iterator. + */ + iterator end() { return iterator(this, Rows(), 0); } + + /** + * Returns begin iterator. + */ + const_iterator begin() const { return const_iterator(this, 0, 0); } + + /** + * Returns end iterator. + */ + const_iterator end() const { return const_iterator(this, Rows(), 0); } + + /** + * Returns begin iterator. + */ + const_iterator cbegin() const { return const_iterator(this, 0, 0); } + + /** + * Returns end iterator. + */ + const_iterator cend() const { return const_iterator(this, Rows(), 0); } + + /** + * Returns number of elements in matrix. + */ + size_t size() const { return m_blockRows * m_blockCols; } + + private: + Mat* m_mat = nullptr; + int m_rowOffset = 0; + int m_colOffset = 0; + int m_blockRows = 0; + int m_blockCols = 0; +}; + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/autodiff/VariableMatrix.hpp b/sleipnir/src/include/sleipnir/autodiff/VariableMatrix.hpp new file mode 100644 index 0000000..5fa7e9f --- /dev/null +++ b/sleipnir/src/include/sleipnir/autodiff/VariableMatrix.hpp @@ -0,0 +1,1032 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "sleipnir/autodiff/Variable.hpp" +#include "sleipnir/autodiff/VariableBlock.hpp" +#include "sleipnir/util/Assert.hpp" +#include "sleipnir/util/FunctionRef.hpp" +#include "sleipnir/util/SymbolExports.hpp" +#include "sleipnir/util/small_vector.hpp" + +namespace sleipnir { + +/** + * A matrix of autodiff variables. + */ +class SLEIPNIR_DLLEXPORT VariableMatrix { + public: + /** + * Constructs an empty VariableMatrix. + */ + VariableMatrix() = default; + + /** + * Constructs a VariableMatrix column vector with the given rows. + * + * @param rows The number of matrix rows. + */ + explicit VariableMatrix(int rows) : m_rows{rows}, m_cols{1} { + for (int row = 0; row < rows; ++row) { + m_storage.emplace_back(); + } + } + + /** + * Constructs a VariableMatrix with the given dimensions. + * + * @param rows The number of matrix rows. + * @param cols The number of matrix columns. + */ + VariableMatrix(int rows, int cols) : m_rows{rows}, m_cols{cols} { + for (int row = 0; row < rows; ++row) { + for (int col = 0; col < cols; ++col) { + m_storage.emplace_back(); + } + } + } + + /** + * Constructs a scalar VariableMatrix from a nested list of Variables. + * + * @param list The nested list of Variables. + */ + VariableMatrix( // NOLINT + std::initializer_list> list) { + // Get row and column counts for destination matrix + m_rows = list.size(); + m_cols = 0; + if (list.size() > 0) { + m_cols = list.begin()->size(); + } + + // Assert the first and latest column counts are the same + for ([[maybe_unused]] + const auto& row : list) { + Assert(list.begin()->size() == row.size()); + } + + m_storage.reserve(Rows() * Cols()); + for (const auto& row : list) { + std::copy(row.begin(), row.end(), std::back_inserter(m_storage)); + } + } + + /** + * Constructs a scalar VariableMatrix from a nested list of doubles. + * + * This overload is for Python bindings only. + * + * @param list The nested list of Variables. + */ + VariableMatrix(const std::vector>& list) { // NOLINT + // Get row and column counts for destination matrix + m_rows = list.size(); + m_cols = 0; + if (list.size() > 0) { + m_cols = list.begin()->size(); + } + + // Assert the first and latest column counts are the same + for ([[maybe_unused]] + const auto& row : list) { + Assert(list.begin()->size() == row.size()); + } + + m_storage.reserve(Rows() * Cols()); + for (const auto& row : list) { + std::copy(row.begin(), row.end(), std::back_inserter(m_storage)); + } + } + + /** + * Constructs a scalar VariableMatrix from a nested list of Variables. + * + * This overload is for Python bindings only. + * + * @param list The nested list of Variables. + */ + VariableMatrix(const std::vector>& list) { // NOLINT + // Get row and column counts for destination matrix + m_rows = list.size(); + m_cols = 0; + if (list.size() > 0) { + m_cols = list.begin()->size(); + } + + // Assert the first and latest column counts are the same + for ([[maybe_unused]] + const auto& row : list) { + Assert(list.begin()->size() == row.size()); + } + + m_storage.reserve(Rows() * Cols()); + for (const auto& row : list) { + std::copy(row.begin(), row.end(), std::back_inserter(m_storage)); + } + } + + /** + * Constructs a VariableMatrix from an Eigen matrix. + * + * @param values Eigen matrix of values. + */ + template + VariableMatrix(const Eigen::MatrixBase& values) // NOLINT + : m_rows{static_cast(values.rows())}, + m_cols{static_cast(values.cols())} { + m_storage.reserve(values.rows() * values.cols()); + for (int row = 0; row < values.rows(); ++row) { + for (int col = 0; col < values.cols(); ++col) { + m_storage.emplace_back(values(row, col)); + } + } + } + + /** + * Constructs a VariableMatrix from an Eigen diagonal matrix. + * + * @param values Diagonal matrix of values. + */ + template + VariableMatrix(const Eigen::DiagonalBase& values) // NOLINT + : m_rows{static_cast(values.rows())}, + m_cols{static_cast(values.cols())} { + m_storage.reserve(values.rows() * values.cols()); + for (int row = 0; row < values.rows(); ++row) { + for (int col = 0; col < values.cols(); ++col) { + if (row == col) { + m_storage.emplace_back(values.diagonal()(row)); + } else { + m_storage.emplace_back(0.0); + } + } + } + } + + /** + * Assigns an Eigen matrix to a VariableMatrix. + * + * @param values Eigen matrix of values. + */ + template + VariableMatrix& operator=(const Eigen::MatrixBase& values) { + Assert(Rows() == values.rows()); + Assert(Cols() == values.cols()); + + for (int row = 0; row < values.rows(); ++row) { + for (int col = 0; col < values.cols(); ++col) { + (*this)(row, col) = values(row, col); + } + } + + return *this; + } + + /** + * Sets the VariableMatrix's internal values. + * + * @param values Eigen matrix of values. + */ + template + requires std::same_as + void SetValue(const Eigen::MatrixBase& values) { + Assert(Rows() == values.rows()); + Assert(Cols() == values.cols()); + + for (int row = 0; row < values.rows(); ++row) { + for (int col = 0; col < values.cols(); ++col) { + (*this)(row, col).SetValue(values(row, col)); + } + } + } + + /** + * Constructs a scalar VariableMatrix from a Variable. + * + * @param variable Variable. + */ + VariableMatrix(const Variable& variable) // NOLINT + : m_rows{1}, m_cols{1} { + m_storage.emplace_back(variable); + } + + /** + * Constructs a scalar VariableMatrix from a Variable. + * + * @param variable Variable. + */ + VariableMatrix(Variable&& variable) : m_rows{1}, m_cols{1} { // NOLINT + m_storage.emplace_back(std::move(variable)); + } + + /** + * Constructs a VariableMatrix from a VariableBlock. + * + * @param values VariableBlock of values. + */ + VariableMatrix(const VariableBlock& values) // NOLINT + : m_rows{values.Rows()}, m_cols{values.Cols()} { + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + m_storage.emplace_back(values(row, col)); + } + } + } + + /** + * Constructs a VariableMatrix from a VariableBlock. + * + * @param values VariableBlock of values. + */ + VariableMatrix(const VariableBlock& values) // NOLINT + : m_rows{values.Rows()}, m_cols{values.Cols()} { + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + m_storage.emplace_back(values(row, col)); + } + } + } + + /** + * Constructs a column vector wrapper around a Variable array. + * + * @param values Variable array to wrap. + */ + explicit VariableMatrix(std::span values) + : m_rows{static_cast(values.size())}, m_cols{1} { + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + m_storage.emplace_back(values[row * Cols() + col]); + } + } + } + + /** + * Constructs a matrix wrapper around a Variable array. + * + * @param values Variable array to wrap. + * @param rows The number of matrix rows. + * @param cols The number of matrix columns. + */ + VariableMatrix(std::span values, int rows, int cols) + : m_rows{rows}, m_cols{cols} { + Assert(static_cast(values.size()) == Rows() * Cols()); + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + m_storage.emplace_back(values[row * Cols() + col]); + } + } + } + + /** + * Returns a block pointing to the given row and column. + * + * @param row The block row. + * @param col The block column. + */ + Variable& operator()(int row, int col) { + Assert(row >= 0 && row < Rows()); + Assert(col >= 0 && col < Cols()); + return m_storage[row * Cols() + col]; + } + + /** + * Returns a block pointing to the given row and column. + * + * @param row The block row. + * @param col The block column. + */ + const Variable& operator()(int row, int col) const { + Assert(row >= 0 && row < Rows()); + Assert(col >= 0 && col < Cols()); + return m_storage[row * Cols() + col]; + } + + /** + * Returns a block pointing to the given row. + * + * @param row The block row. + */ + Variable& operator()(int row) { + Assert(row >= 0 && row < Rows() * Cols()); + return m_storage[row]; + } + + /** + * Returns a block pointing to the given row. + * + * @param row The block row. + */ + const Variable& operator()(int row) const { + Assert(row >= 0 && row < Rows() * Cols()); + return m_storage[row]; + } + + /** + * Returns a block slice of the variable matrix. + * + * @param rowOffset The row offset of the block selection. + * @param colOffset The column offset of the block selection. + * @param blockRows The number of rows in the block selection. + * @param blockCols The number of columns in the block selection. + */ + VariableBlock Block(int rowOffset, int colOffset, + int blockRows, int blockCols) { + Assert(rowOffset >= 0 && rowOffset <= Rows()); + Assert(colOffset >= 0 && colOffset <= Cols()); + Assert(blockRows >= 0 && blockRows <= Rows() - rowOffset); + Assert(blockCols >= 0 && blockCols <= Cols() - colOffset); + return VariableBlock{*this, rowOffset, colOffset, blockRows, blockCols}; + } + + /** + * Returns a block slice of the variable matrix. + * + * @param rowOffset The row offset of the block selection. + * @param colOffset The column offset of the block selection. + * @param blockRows The number of rows in the block selection. + * @param blockCols The number of columns in the block selection. + */ + const VariableBlock Block(int rowOffset, int colOffset, + int blockRows, + int blockCols) const { + Assert(rowOffset >= 0 && rowOffset <= Rows()); + Assert(colOffset >= 0 && colOffset <= Cols()); + Assert(blockRows >= 0 && blockRows <= Rows() - rowOffset); + Assert(blockCols >= 0 && blockCols <= Cols() - colOffset); + return VariableBlock{*this, rowOffset, colOffset, blockRows, blockCols}; + } + + /** + * Returns a segment of the variable vector. + * + * @param offset The offset of the segment. + * @param length The length of the segment. + */ + VariableBlock Segment(int offset, int length) { + Assert(offset >= 0 && offset < Rows() * Cols()); + Assert(length >= 0 && length <= Rows() * Cols() - offset); + return Block(offset, 0, length, 1); + } + + /** + * Returns a segment of the variable vector. + * + * @param offset The offset of the segment. + * @param length The length of the segment. + */ + const VariableBlock Segment(int offset, + int length) const { + Assert(offset >= 0 && offset < Rows() * Cols()); + Assert(length >= 0 && length <= Rows() * Cols() - offset); + return Block(offset, 0, length, 1); + } + + /** + * Returns a row slice of the variable matrix. + * + * @param row The row to slice. + */ + VariableBlock Row(int row) { + Assert(row >= 0 && row < Rows()); + return Block(row, 0, 1, Cols()); + } + + /** + * Returns a row slice of the variable matrix. + * + * @param row The row to slice. + */ + const VariableBlock Row(int row) const { + Assert(row >= 0 && row < Rows()); + return Block(row, 0, 1, Cols()); + } + + /** + * Returns a column slice of the variable matrix. + * + * @param col The column to slice. + */ + VariableBlock Col(int col) { + Assert(col >= 0 && col < Cols()); + return Block(0, col, Rows(), 1); + } + + /** + * Returns a column slice of the variable matrix. + * + * @param col The column to slice. + */ + const VariableBlock Col(int col) const { + Assert(col >= 0 && col < Cols()); + return Block(0, col, Rows(), 1); + } + + /** + * Matrix multiplication operator. + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + */ + friend SLEIPNIR_DLLEXPORT VariableMatrix + operator*(const VariableMatrix& lhs, const VariableMatrix& rhs) { + Assert(lhs.Cols() == rhs.Rows()); + + VariableMatrix result{lhs.Rows(), rhs.Cols()}; + + for (int i = 0; i < lhs.Rows(); ++i) { + for (int j = 0; j < rhs.Cols(); ++j) { + Variable sum; + for (int k = 0; k < lhs.Cols(); ++k) { + sum += lhs(i, k) * rhs(k, j); + } + result(i, j) = sum; + } + } + + return result; + } + + /** + * Matrix-scalar multiplication operator. + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + */ + friend SLEIPNIR_DLLEXPORT VariableMatrix operator*(const VariableMatrix& lhs, + const Variable& rhs) { + VariableMatrix result{lhs.Rows(), lhs.Cols()}; + + for (int row = 0; row < result.Rows(); ++row) { + for (int col = 0; col < result.Cols(); ++col) { + result(row, col) = lhs(row, col) * rhs; + } + } + + return result; + } + + /** + * Matrix-scalar multiplication operator. + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + */ + friend SLEIPNIR_DLLEXPORT VariableMatrix operator*(const VariableMatrix& lhs, + double rhs) { + return lhs * Variable{rhs}; + } + + /** + * Scalar-matrix multiplication operator. + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + */ + friend SLEIPNIR_DLLEXPORT VariableMatrix + operator*(const Variable& lhs, const VariableMatrix& rhs) { + VariableMatrix result{rhs.Rows(), rhs.Cols()}; + + for (int row = 0; row < result.Rows(); ++row) { + for (int col = 0; col < result.Cols(); ++col) { + result(row, col) = rhs(row, col) * lhs; + } + } + + return result; + } + + /** + * Scalar-matrix multiplication operator. + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + */ + friend SLEIPNIR_DLLEXPORT VariableMatrix + operator*(double lhs, const VariableMatrix& rhs) { + return Variable{lhs} * rhs; + } + + /** + * Compound matrix multiplication-assignment operator. + * + * @param rhs Variable to multiply. + */ + VariableMatrix& operator*=(const VariableMatrix& rhs) { + Assert(Cols() == rhs.Rows() && Cols() == rhs.Cols()); + + for (int i = 0; i < Rows(); ++i) { + for (int j = 0; j < rhs.Cols(); ++j) { + Variable sum; + for (int k = 0; k < Cols(); ++k) { + sum += (*this)(i, k) * rhs(k, j); + } + (*this)(i, j) = sum; + } + } + + return *this; + } + + /** + * Binary division operator (only enabled when rhs is a scalar). + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + */ + friend SLEIPNIR_DLLEXPORT VariableMatrix operator/(const VariableMatrix& lhs, + const Variable& rhs) { + VariableMatrix result{lhs.Rows(), lhs.Cols()}; + + for (int row = 0; row < result.Rows(); ++row) { + for (int col = 0; col < result.Cols(); ++col) { + result(row, col) = lhs(row, col) / rhs; + } + } + + return result; + } + + /** + * Compound matrix division-assignment operator (only enabled when rhs + * is a scalar). + * + * @param rhs Variable to divide. + */ + VariableMatrix& operator/=(const Variable& rhs) { + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + (*this)(row, col) /= rhs; + } + } + + return *this; + } + + /** + * Binary addition operator. + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + */ + friend SLEIPNIR_DLLEXPORT VariableMatrix + operator+(const VariableMatrix& lhs, const VariableMatrix& rhs) { + VariableMatrix result{lhs.Rows(), lhs.Cols()}; + + for (int row = 0; row < result.Rows(); ++row) { + for (int col = 0; col < result.Cols(); ++col) { + result(row, col) = lhs(row, col) + rhs(row, col); + } + } + + return result; + } + + /** + * Compound addition-assignment operator. + * + * @param rhs Variable to add. + */ + VariableMatrix& operator+=(const VariableMatrix& rhs) { + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + (*this)(row, col) += rhs(row, col); + } + } + + return *this; + } + + /** + * Binary subtraction operator. + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + */ + friend SLEIPNIR_DLLEXPORT VariableMatrix + operator-(const VariableMatrix& lhs, const VariableMatrix& rhs) { + VariableMatrix result{lhs.Rows(), lhs.Cols()}; + + for (int row = 0; row < result.Rows(); ++row) { + for (int col = 0; col < result.Cols(); ++col) { + result(row, col) = lhs(row, col) - rhs(row, col); + } + } + + return result; + } + + /** + * Compound subtraction-assignment operator. + * + * @param rhs Variable to subtract. + */ + VariableMatrix& operator-=(const VariableMatrix& rhs) { + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + (*this)(row, col) -= rhs(row, col); + } + } + + return *this; + } + + /** + * Unary minus operator. + * + * @param lhs Operand for unary minus. + */ + friend SLEIPNIR_DLLEXPORT VariableMatrix + operator-(const VariableMatrix& lhs) { + VariableMatrix result{lhs.Rows(), lhs.Cols()}; + + for (int row = 0; row < result.Rows(); ++row) { + for (int col = 0; col < result.Cols(); ++col) { + result(row, col) = -lhs(row, col); + } + } + + return result; + } + + /** + * Implicit conversion operator from 1x1 VariableMatrix to Variable. + */ + operator Variable() const { // NOLINT + Assert(Rows() == 1 && Cols() == 1); + return (*this)(0, 0); + } + + /** + * Returns the transpose of the variable matrix. + */ + VariableMatrix T() const { + VariableMatrix result{Cols(), Rows()}; + + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + result(col, row) = (*this)(row, col); + } + } + + return result; + } + + /** + * Returns number of rows in the matrix. + */ + int Rows() const { return m_rows; } + + /** + * Returns number of columns in the matrix. + */ + int Cols() const { return m_cols; } + + /** + * Returns an element of the variable matrix. + * + * @param row The row of the element to return. + * @param col The column of the element to return. + */ + double Value(int row, int col) { + Assert(row >= 0 && row < Rows()); + Assert(col >= 0 && col < Cols()); + return m_storage[row * Cols() + col].Value(); + } + + /** + * Returns a row of the variable column vector. + * + * @param index The index of the element to return. + */ + double Value(int index) { + Assert(index >= 0 && index < Rows() * Cols()); + return m_storage[index].Value(); + } + + /** + * Returns the contents of the variable matrix. + */ + Eigen::MatrixXd Value() { + Eigen::MatrixXd result{Rows(), Cols()}; + + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + result(row, col) = Value(row, col); + } + } + + return result; + } + + /** + * Transforms the matrix coefficient-wise with an unary operator. + * + * @param unaryOp The unary operator to use for the transform operation. + */ + VariableMatrix CwiseTransform( + function_ref unaryOp) const { + VariableMatrix result{Rows(), Cols()}; + + for (int row = 0; row < Rows(); ++row) { + for (int col = 0; col < Cols(); ++col) { + result(row, col) = unaryOp((*this)(row, col)); + } + } + + return result; + } + + class iterator { + public: + using iterator_category = std::forward_iterator_tag; + using value_type = Variable; + using difference_type = std::ptrdiff_t; + using pointer = Variable*; + using reference = Variable&; + + iterator(VariableMatrix* mat, int row, int col) + : m_mat{mat}, m_row{row}, m_col{col} {} + + iterator& operator++() { + ++m_col; + if (m_col == m_mat->Cols()) { + m_col = 0; + ++m_row; + } + return *this; + } + iterator operator++(int) { + iterator retval = *this; + ++(*this); + return retval; + } + bool operator==(const iterator&) const = default; + reference operator*() { return (*m_mat)(m_row, m_col); } + + private: + VariableMatrix* m_mat; + int m_row; + int m_col; + }; + + class const_iterator { + public: + using iterator_category = std::forward_iterator_tag; + using value_type = Variable; + using difference_type = std::ptrdiff_t; + using pointer = Variable*; + using const_reference = const Variable&; + + const_iterator(const VariableMatrix* mat, int row, int col) + : m_mat{mat}, m_row{row}, m_col{col} {} + + const_iterator& operator++() { + ++m_col; + if (m_col == m_mat->Cols()) { + m_col = 0; + ++m_row; + } + return *this; + } + const_iterator operator++(int) { + const_iterator retval = *this; + ++(*this); + return retval; + } + bool operator==(const const_iterator&) const = default; + const_reference operator*() const { return (*m_mat)(m_row, m_col); } + + private: + const VariableMatrix* m_mat; + int m_row; + int m_col; + }; + + /** + * Returns begin iterator. + */ + iterator begin() { return iterator(this, 0, 0); } + + /** + * Returns end iterator. + */ + iterator end() { return iterator(this, Rows(), 0); } + + /** + * Returns begin iterator. + */ + const_iterator begin() const { return const_iterator(this, 0, 0); } + + /** + * Returns end iterator. + */ + const_iterator end() const { return const_iterator(this, Rows(), 0); } + + /** + * Returns begin iterator. + */ + const_iterator cbegin() const { return const_iterator(this, 0, 0); } + + /** + * Returns end iterator. + */ + const_iterator cend() const { return const_iterator(this, Rows(), 0); } + + /** + * Returns number of elements in matrix. + */ + size_t size() const { return m_rows * m_cols; } + + /** + * Returns a variable matrix filled with zeroes. + * + * @param rows The number of matrix rows. + * @param cols The number of matrix columns. + */ + static VariableMatrix Zero(int rows, int cols) { + VariableMatrix result{rows, cols}; + + for (auto& elem : result) { + elem = 0.0; + } + + return result; + } + + /** + * Returns a variable matrix filled with ones. + * + * @param rows The number of matrix rows. + * @param cols The number of matrix columns. + */ + static VariableMatrix Ones(int rows, int cols) { + VariableMatrix result{rows, cols}; + + for (auto& elem : result) { + elem = 1.0; + } + + return result; + } + + private: + small_vector m_storage; + int m_rows = 0; + int m_cols = 0; +}; + +/** + * Applies a coefficient-wise reduce operation to two matrices. + * + * @param lhs The left-hand side of the binary operator. + * @param rhs The right-hand side of the binary operator. + * @param binaryOp The binary operator to use for the reduce operation. + */ +SLEIPNIR_DLLEXPORT inline VariableMatrix CwiseReduce( + const VariableMatrix& lhs, const VariableMatrix& rhs, + function_ref binaryOp) { + Assert(lhs.Rows() == rhs.Rows()); + Assert(lhs.Rows() == rhs.Rows()); + + VariableMatrix result{lhs.Rows(), lhs.Cols()}; + + for (int row = 0; row < lhs.Rows(); ++row) { + for (int col = 0; col < lhs.Cols(); ++col) { + result(row, col) = binaryOp(lhs(row, col), rhs(row, col)); + } + } + + return result; +} + +/** + * Assemble a VariableMatrix from a nested list of blocks. + * + * Each row's blocks must have the same height, and the assembled block rows + * must have the same width. For example, for the block matrix [[A, B], [C]] to + * be constructible, the number of rows in A and B must match, and the number of + * columns in [A, B] and [C] must match. + * + * @param list The nested list of blocks. + */ +SLEIPNIR_DLLEXPORT inline VariableMatrix Block( + std::initializer_list> list) { + // Get row and column counts for destination matrix + int rows = 0; + int cols = -1; + for (const auto& row : list) { + if (row.size() > 0) { + rows += row.begin()->Rows(); + } + + // Get number of columns in this row + int latestCols = 0; + for (const auto& elem : row) { + // Assert the first and latest row have the same height + Assert(row.begin()->Rows() == elem.Rows()); + + latestCols += elem.Cols(); + } + + // If this is the first row, record the column count. Otherwise, assert the + // first and latest column counts are the same. + if (cols == -1) { + cols = latestCols; + } else { + Assert(cols == latestCols); + } + } + + VariableMatrix result{rows, cols}; + + int rowOffset = 0; + for (const auto& row : list) { + int colOffset = 0; + for (const auto& elem : row) { + result.Block(rowOffset, colOffset, elem.Rows(), elem.Cols()) = elem; + colOffset += elem.Cols(); + } + rowOffset += row.begin()->Rows(); + } + + return result; +} + +/** + * Assemble a VariableMatrix from a nested list of blocks. + * + * Each row's blocks must have the same height, and the assembled block rows + * must have the same width. For example, for the block matrix [[A, B], [C]] to + * be constructible, the number of rows in A and B must match, and the number of + * columns in [A, B] and [C] must match. + * + * This overload is for Python bindings only. + * + * @param list The nested list of blocks. + */ +SLEIPNIR_DLLEXPORT inline VariableMatrix Block( + const std::vector>& list) { + // Get row and column counts for destination matrix + int rows = 0; + int cols = -1; + for (const auto& row : list) { + if (row.size() > 0) { + rows += row.begin()->Rows(); + } + + // Get number of columns in this row + int latestCols = 0; + for (const auto& elem : row) { + // Assert the first and latest row have the same height + Assert(row.begin()->Rows() == elem.Rows()); + + latestCols += elem.Cols(); + } + + // If this is the first row, record the column count. Otherwise, assert the + // first and latest column counts are the same. + if (cols == -1) { + cols = latestCols; + } else { + Assert(cols == latestCols); + } + } + + VariableMatrix result{rows, cols}; + + int rowOffset = 0; + for (const auto& row : list) { + int colOffset = 0; + for (const auto& elem : row) { + result.Block(rowOffset, colOffset, elem.Rows(), elem.Cols()) = elem; + colOffset += elem.Cols(); + } + rowOffset += row.begin()->Rows(); + } + + return result; +} + +/** + * Solves the VariableMatrix equation AX = B for X. + * + * @param A The left-hand side. + * @param B The right-hand side. + * @return The solution X. + */ +SLEIPNIR_DLLEXPORT VariableMatrix Solve(const VariableMatrix& A, + const VariableMatrix& B); + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/control/OCPSolver.hpp b/sleipnir/src/include/sleipnir/control/OCPSolver.hpp new file mode 100644 index 0000000..adfca44 --- /dev/null +++ b/sleipnir/src/include/sleipnir/control/OCPSolver.hpp @@ -0,0 +1,471 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include + +#include +#include + +#include "sleipnir/autodiff/VariableMatrix.hpp" +#include "sleipnir/optimization/OptimizationProblem.hpp" +#include "sleipnir/util/Assert.hpp" +#include "sleipnir/util/Concepts.hpp" +#include "sleipnir/util/FunctionRef.hpp" +#include "sleipnir/util/SymbolExports.hpp" + +namespace sleipnir { + +/** + * Performs 4th order Runge-Kutta integration of dx/dt = f(t, x, u) for dt. + * + * @param f The function to integrate. It must take two arguments x and u. + * @param x The initial value of x. + * @param u The value u held constant over the integration period. + * @param t0 The initial time. + * @param dt The time over which to integrate. + */ +template +State RK4(F&& f, State x, Input u, Time t0, Time dt) { + auto halfdt = dt * 0.5; + State k1 = f(t0, x, u, dt); + State k2 = f(t0 + halfdt, x + k1 * halfdt, u, dt); + State k3 = f(t0 + halfdt, x + k2 * halfdt, u, dt); + State k4 = f(t0 + dt, x + k3 * dt, u, dt); + + return x + (k1 + k2 * 2.0 + k3 * 2.0 + k4) * (dt / 6.0); +} + +/** + * Enum describing an OCP transcription method. + */ +enum class TranscriptionMethod : uint8_t { + /// Each state is a decision variable constrained to the integrated dynamics + /// of the previous state. + kDirectTranscription, + /// The trajectory is modeled as a series of cubic polynomials where the + /// centerpoint slope is constrained. + kDirectCollocation, + /// States depend explicitly as a function of all previous states and all + /// previous inputs. + kSingleShooting +}; + +/** + * Enum describing a type of system dynamics constraints. + */ +enum class DynamicsType : uint8_t { + /// The dynamics are a function in the form dx/dt = f(t, x, u). + kExplicitODE, + /// The dynamics are a function in the form xₖ₊₁ = f(t, xₖ, uₖ). + kDiscrete +}; + +/** + * Enum describing the type of system timestep. + */ +enum class TimestepMethod : uint8_t { + /// The timestep is a fixed constant. + kFixed, + /// The timesteps are allowed to vary as independent decision variables. + kVariable, + /// The timesteps are equal length but allowed to vary as a single decision + /// variable. + kVariableSingle +}; + +/** + * This class allows the user to pose and solve a constrained optimal control + * problem (OCP) in a variety of ways. + * + * The system is transcripted by one of three methods (direct transcription, + * direct collocation, or single-shooting) and additional constraints can be + * added. + * + * In direct transcription, each state is a decision variable constrained to the + * integrated dynamics of the previous state. In direct collocation, the + * trajectory is modeled as a series of cubic polynomials where the centerpoint + * slope is constrained. In single-shooting, states depend explicitly as a + * function of all previous states and all previous inputs. + * + * Explicit ODEs are integrated using RK4. + * + * For explicit ODEs, the function must be in the form dx/dt = f(t, x, u). + * For discrete state transition functions, the function must be in the form + * xₖ₊₁ = f(t, xₖ, uₖ). + * + * Direct collocation requires an explicit ODE. Direct transcription and + * single-shooting can use either an ODE or state transition function. + * + * https://underactuated.mit.edu/trajopt.html goes into more detail on each + * transcription method. + */ +class SLEIPNIR_DLLEXPORT OCPSolver : public OptimizationProblem { + public: + /** + * Build an optimization problem using a system evolution function (explicit + * ODE or discrete state transition function). + * + * @param numStates The number of system states. + * @param numInputs The number of system inputs. + * @param dt The timestep for fixed-step integration. + * @param numSteps The number of control points. + * @param dynamics Function representing an explicit or implicit ODE, or a + * discrete state transition function. + * - Explicit: dx/dt = f(x, u, *) + * - Implicit: f([x dx/dt]', u, *) = 0 + * - State transition: xₖ₊₁ = f(xₖ, uₖ) + * @param dynamicsType The type of system evolution function. + * @param timestepMethod The timestep method. + * @param method The transcription method. + */ + OCPSolver( + int numStates, int numInputs, std::chrono::duration dt, + int numSteps, + function_ref + dynamics, + DynamicsType dynamicsType = DynamicsType::kExplicitODE, + TimestepMethod timestepMethod = TimestepMethod::kFixed, + TranscriptionMethod method = TranscriptionMethod::kDirectTranscription) + : OCPSolver{numStates, + numInputs, + dt, + numSteps, + [=]([[maybe_unused]] const VariableMatrix& t, + const VariableMatrix& x, const VariableMatrix& u, + [[maybe_unused]] + const VariableMatrix& dt) -> VariableMatrix { + return dynamics(x, u); + }, + dynamicsType, + timestepMethod, + method} {} + + /** + * Build an optimization problem using a system evolution function (explicit + * ODE or discrete state transition function). + * + * @param numStates The number of system states. + * @param numInputs The number of system inputs. + * @param dt The timestep for fixed-step integration. + * @param numSteps The number of control points. + * @param dynamics Function representing an explicit or implicit ODE, or a + * discrete state transition function. + * - Explicit: dx/dt = f(t, x, u, *) + * - Implicit: f(t, [x dx/dt]', u, *) = 0 + * - State transition: xₖ₊₁ = f(t, xₖ, uₖ, dt) + * @param dynamicsType The type of system evolution function. + * @param timestepMethod The timestep method. + * @param method The transcription method. + */ + OCPSolver( + int numStates, int numInputs, std::chrono::duration dt, + int numSteps, + function_ref + dynamics, + DynamicsType dynamicsType = DynamicsType::kExplicitODE, + TimestepMethod timestepMethod = TimestepMethod::kFixed, + TranscriptionMethod method = TranscriptionMethod::kDirectTranscription) + : m_numStates{numStates}, + m_numInputs{numInputs}, + m_dt{dt}, + m_numSteps{numSteps}, + m_transcriptionMethod{method}, + m_dynamicsType{dynamicsType}, + m_dynamicsFunction{std::move(dynamics)}, + m_timestepMethod{timestepMethod} { + // u is numSteps + 1 so that the final constraintFunction evaluation works + m_U = DecisionVariable(m_numInputs, m_numSteps + 1); + + if (m_timestepMethod == TimestepMethod::kFixed) { + m_DT = VariableMatrix{1, m_numSteps + 1}; + for (int i = 0; i < numSteps + 1; ++i) { + m_DT(0, i) = m_dt.count(); + } + } else if (m_timestepMethod == TimestepMethod::kVariableSingle) { + Variable DT = DecisionVariable(); + DT.SetValue(m_dt.count()); + + // Set the member variable matrix to track the decision variable + m_DT = VariableMatrix{1, m_numSteps + 1}; + for (int i = 0; i < numSteps + 1; ++i) { + m_DT(0, i) = DT; + } + } else if (m_timestepMethod == TimestepMethod::kVariable) { + m_DT = DecisionVariable(1, m_numSteps + 1); + for (int i = 0; i < numSteps + 1; ++i) { + m_DT(0, i).SetValue(m_dt.count()); + } + } + + if (m_transcriptionMethod == TranscriptionMethod::kDirectTranscription) { + m_X = DecisionVariable(m_numStates, m_numSteps + 1); + ConstrainDirectTranscription(); + } else if (m_transcriptionMethod == + TranscriptionMethod::kDirectCollocation) { + m_X = DecisionVariable(m_numStates, m_numSteps + 1); + ConstrainDirectCollocation(); + } else if (m_transcriptionMethod == TranscriptionMethod::kSingleShooting) { + // In single-shooting the states aren't decision variables, but instead + // depend on the input and previous states + m_X = VariableMatrix{m_numStates, m_numSteps + 1}; + ConstrainSingleShooting(); + } + } + + /** + * Utility function to constrain the initial state. + * + * @param initialState the initial state to constrain to. + */ + template + requires ScalarLike || MatrixLike + void ConstrainInitialState(const T& initialState) { + SubjectTo(InitialState() == initialState); + } + + /** + * Utility function to constrain the final state. + * + * @param finalState the final state to constrain to. + */ + template + requires ScalarLike || MatrixLike + void ConstrainFinalState(const T& finalState) { + SubjectTo(FinalState() == finalState); + } + + /** + * Set the constraint evaluation function. This function is called + * `numSteps+1` times, with the corresponding state and input + * VariableMatrices. + * + * @param callback The callback f(x, u) where x is the state and u is the + * input vector. + */ + void ForEachStep( + const function_ref + callback) { + for (int i = 0; i < m_numSteps + 1; ++i) { + auto x = X().Col(i); + auto u = U().Col(i); + callback(x, u); + } + } + + /** + * Set the constraint evaluation function. This function is called + * `numSteps+1` times, with the corresponding state and input + * VariableMatrices. + * + * @param callback The callback f(t, x, u, dt) where t is time, x is the state + * vector, u is the input vector, and dt is the timestep duration. + */ + void ForEachStep( + const function_ref + callback) { + Variable time = 0.0; + + for (int i = 0; i < m_numSteps + 1; ++i) { + auto x = X().Col(i); + auto u = U().Col(i); + auto dt = DT()(0, i); + callback(time, x, u, dt); + + time += dt; + } + } + + /** + * Convenience function to set a lower bound on the input. + * + * @param lowerBound The lower bound that inputs must always be above. Must be + * shaped (numInputs)x1. + */ + template + requires ScalarLike || MatrixLike + void SetLowerInputBound(const T& lowerBound) { + for (int i = 0; i < m_numSteps + 1; ++i) { + SubjectTo(U().Col(i) >= lowerBound); + } + } + + /** + * Convenience function to set an upper bound on the input. + * + * @param upperBound The upper bound that inputs must always be below. Must be + * shaped (numInputs)x1. + */ + template + requires ScalarLike || MatrixLike + void SetUpperInputBound(const T& upperBound) { + for (int i = 0; i < m_numSteps + 1; ++i) { + SubjectTo(U().Col(i) <= upperBound); + } + } + + /** + * Convenience function to set an upper bound on the timestep. + * + * @param maxTimestep The maximum timestep. + */ + void SetMaxTimestep(std::chrono::duration maxTimestep) { + SubjectTo(DT() <= maxTimestep.count()); + } + + /** + * Convenience function to set a lower bound on the timestep. + * + * @param minTimestep The minimum timestep. + */ + void SetMinTimestep(std::chrono::duration minTimestep) { + SubjectTo(DT() >= minTimestep.count()); + } + + /** + * Get the state variables. After the problem is solved, this will contain the + * optimized trajectory. + * + * Shaped (numStates)x(numSteps+1). + * + * @returns The state variable matrix. + */ + VariableMatrix& X() { return m_X; }; + + /** + * Get the input variables. After the problem is solved, this will contain the + * inputs corresponding to the optimized trajectory. + * + * Shaped (numInputs)x(numSteps+1), although the last input step is unused in + * the trajectory. + * + * @returns The input variable matrix. + */ + VariableMatrix& U() { return m_U; }; + + /** + * Get the timestep variables. After the problem is solved, this will contain + * the timesteps corresponding to the optimized trajectory. + * + * Shaped 1x(numSteps+1), although the last timestep is unused in + * the trajectory. + * + * @returns The timestep variable matrix. + */ + VariableMatrix& DT() { return m_DT; }; + + /** + * Convenience function to get the initial state in the trajectory. + * + * @returns The initial state of the trajectory. + */ + VariableMatrix InitialState() { return m_X.Col(0); } + + /** + * Convenience function to get the final state in the trajectory. + * + * @returns The final state of the trajectory. + */ + VariableMatrix FinalState() { return m_X.Col(m_numSteps); } + + private: + void ConstrainDirectCollocation() { + Assert(m_dynamicsType == DynamicsType::kExplicitODE); + + Variable time = 0.0; + + // Derivation at https://mec560sbu.github.io/2016/09/30/direct_collocation/ + for (int i = 0; i < m_numSteps; ++i) { + Variable h = DT()(0, i); + + auto& f = m_dynamicsFunction; + + auto t_begin = time; + auto t_end = t_begin + h; + + auto x_begin = X().Col(i); + auto x_end = X().Col(i + 1); + + auto u_begin = U().Col(i); + auto u_end = U().Col(i + 1); + + auto xdot_begin = f(t_begin, x_begin, u_begin, h); + auto xdot_end = f(t_end, x_end, u_end, h); + auto xdot_c = + -3 / (2 * h) * (x_begin - x_end) - 0.25 * (xdot_begin + xdot_end); + + auto t_c = t_begin + 0.5 * h; + auto x_c = 0.5 * (x_begin + x_end) + h / 8 * (xdot_begin - xdot_end); + auto u_c = 0.5 * (u_begin + u_end); + + SubjectTo(xdot_c == f(t_c, x_c, u_c, h)); + + time += h; + } + } + + void ConstrainDirectTranscription() { + Variable time = 0.0; + + for (int i = 0; i < m_numSteps; ++i) { + auto x_begin = X().Col(i); + auto x_end = X().Col(i + 1); + auto u = U().Col(i); + Variable dt = DT()(0, i); + + if (m_dynamicsType == DynamicsType::kExplicitODE) { + SubjectTo(x_end == RK4( + m_dynamicsFunction, x_begin, u, time, dt)); + } else if (m_dynamicsType == DynamicsType::kDiscrete) { + SubjectTo(x_end == m_dynamicsFunction(time, x_begin, u, dt)); + } + + time += dt; + } + } + + void ConstrainSingleShooting() { + Variable time = 0.0; + + for (int i = 0; i < m_numSteps; ++i) { + auto x_begin = X().Col(i); + auto x_end = X().Col(i + 1); + auto u = U().Col(i); + Variable dt = DT()(0, i); + + if (m_dynamicsType == DynamicsType::kExplicitODE) { + x_end = RK4(m_dynamicsFunction, x_begin, u, + time, dt); + } else if (m_dynamicsType == DynamicsType::kDiscrete) { + x_end = m_dynamicsFunction(time, x_begin, u, dt); + } + + time += dt; + } + } + + int m_numStates; + int m_numInputs; + std::chrono::duration m_dt; + int m_numSteps; + TranscriptionMethod m_transcriptionMethod; + + DynamicsType m_dynamicsType; + + function_ref + m_dynamicsFunction; + + TimestepMethod m_timestepMethod; + + VariableMatrix m_X; + VariableMatrix m_U; + VariableMatrix m_DT; +}; + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/optimization/Constraints.hpp b/sleipnir/src/include/sleipnir/optimization/Constraints.hpp new file mode 100644 index 0000000..80da66b --- /dev/null +++ b/sleipnir/src/include/sleipnir/optimization/Constraints.hpp @@ -0,0 +1,324 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include +#include +#include +#include +#include + +#include "sleipnir/autodiff/Variable.hpp" +#include "sleipnir/util/Assert.hpp" +#include "sleipnir/util/Concepts.hpp" +#include "sleipnir/util/SymbolExports.hpp" +#include "sleipnir/util/small_vector.hpp" + +namespace sleipnir { + +/** + * Make a list of constraints. + * + * The standard form for equality constraints is c(x) = 0, and the standard form + * for inequality constraints is c(x) ≥ 0. This function takes constraints of + * the form lhs = rhs or lhs ≥ rhs and converts them to lhs - rhs = 0 or + * lhs - rhs ≥ 0. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + */ +template + requires(ScalarLike> || MatrixLike>) && + (ScalarLike> || MatrixLike>) && + (!std::same_as, double> || + !std::same_as, double>) +small_vector MakeConstraints(LHS&& lhs, RHS&& rhs) { + small_vector constraints; + + if constexpr (ScalarLike> && + ScalarLike>) { + constraints.emplace_back(lhs - rhs); + } else if constexpr (ScalarLike> && + MatrixLike>) { + int rows; + int cols; + if constexpr (EigenMatrixLike>) { + rows = rhs.rows(); + cols = rhs.cols(); + } else { + rows = rhs.Rows(); + cols = rhs.Cols(); + } + + constraints.reserve(rows * cols); + + for (int row = 0; row < rows; ++row) { + for (int col = 0; col < cols; ++col) { + // Make right-hand side zero + constraints.emplace_back(lhs - rhs(row, col)); + } + } + } else if constexpr (MatrixLike> && + ScalarLike>) { + int rows; + int cols; + if constexpr (EigenMatrixLike>) { + rows = lhs.rows(); + cols = lhs.cols(); + } else { + rows = lhs.Rows(); + cols = lhs.Cols(); + } + + constraints.reserve(rows * cols); + + for (int row = 0; row < rows; ++row) { + for (int col = 0; col < cols; ++col) { + // Make right-hand side zero + constraints.emplace_back(lhs(row, col) - rhs); + } + } + } else if constexpr (MatrixLike> && + MatrixLike>) { + int lhsRows; + int lhsCols; + if constexpr (EigenMatrixLike>) { + lhsRows = lhs.rows(); + lhsCols = lhs.cols(); + } else { + lhsRows = lhs.Rows(); + lhsCols = lhs.Cols(); + } + + [[maybe_unused]] + int rhsRows; + [[maybe_unused]] + int rhsCols; + if constexpr (EigenMatrixLike>) { + rhsRows = rhs.rows(); + rhsCols = rhs.cols(); + } else { + rhsRows = rhs.Rows(); + rhsCols = rhs.Cols(); + } + + Assert(lhsRows == rhsRows && lhsCols == rhsCols); + constraints.reserve(lhsRows * lhsCols); + + for (int row = 0; row < lhsRows; ++row) { + for (int col = 0; col < lhsCols; ++col) { + // Make right-hand side zero + constraints.emplace_back(lhs(row, col) - rhs(row, col)); + } + } + } + + return constraints; +} + +/** + * A vector of equality constraints of the form cₑ(x) = 0. + */ +struct SLEIPNIR_DLLEXPORT EqualityConstraints { + /// A vector of scalar equality constraints. + small_vector constraints; + + /** + * Concatenates multiple equality constraints. + * + * @param equalityConstraints The list of EqualityConstraints to concatenate. + */ + EqualityConstraints( + std::initializer_list equalityConstraints) { + for (const auto& elem : equalityConstraints) { + constraints.insert(constraints.end(), elem.constraints.begin(), + elem.constraints.end()); + } + } + + /** + * Concatenates multiple equality constraints. + * + * This overload is for Python bindings only. + * + * @param equalityConstraints The list of EqualityConstraints to concatenate. + */ + explicit EqualityConstraints( + const std::vector& equalityConstraints) { + for (const auto& elem : equalityConstraints) { + constraints.insert(constraints.end(), elem.constraints.begin(), + elem.constraints.end()); + } + } + + /** + * Constructs an equality constraint from a left and right side. + * + * The standard form for equality constraints is c(x) = 0. This function takes + * a constraint of the form lhs = rhs and converts it to lhs - rhs = 0. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + */ + template + requires(ScalarLike> || MatrixLike>) && + (ScalarLike> || MatrixLike>) && + (!std::same_as, double> || + !std::same_as, double>) + EqualityConstraints(LHS&& lhs, RHS&& rhs) + : constraints{MakeConstraints(lhs, rhs)} {} + + /** + * Implicit conversion operator to bool. + */ + operator bool() { // NOLINT + return std::all_of( + constraints.begin(), constraints.end(), + [](auto& constraint) { return constraint.Value() == 0.0; }); + } +}; + +/** + * A vector of inequality constraints of the form cᵢ(x) ≥ 0. + */ +struct SLEIPNIR_DLLEXPORT InequalityConstraints { + /// A vector of scalar inequality constraints. + small_vector constraints; + + /** + * Concatenates multiple inequality constraints. + * + * @param inequalityConstraints The list of InequalityConstraints to + * concatenate. + */ + InequalityConstraints( + std::initializer_list inequalityConstraints) { + for (const auto& elem : inequalityConstraints) { + constraints.insert(constraints.end(), elem.constraints.begin(), + elem.constraints.end()); + } + } + + /** + * Concatenates multiple inequality constraints. + * + * This overload is for Python bindings only. + * + * @param inequalityConstraints The list of InequalityConstraints to + * concatenate. + */ + explicit InequalityConstraints( + const std::vector& inequalityConstraints) { + for (const auto& elem : inequalityConstraints) { + constraints.insert(constraints.end(), elem.constraints.begin(), + elem.constraints.end()); + } + } + + /** + * Constructs an inequality constraint from a left and right side. + * + * The standard form for inequality constraints is c(x) ≥ 0. This function + * takes a constraints of the form lhs ≥ rhs and converts it to lhs - rhs ≥ 0. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + */ + template + requires(ScalarLike> || MatrixLike>) && + (ScalarLike> || MatrixLike>) && + (!std::same_as, double> || + !std::same_as, double>) + InequalityConstraints(LHS&& lhs, RHS&& rhs) + : constraints{MakeConstraints(lhs, rhs)} {} + + /** + * Implicit conversion operator to bool. + */ + operator bool() { // NOLINT + return std::all_of( + constraints.begin(), constraints.end(), + [](auto& constraint) { return constraint.Value() >= 0.0; }); + } +}; + +/** + * Equality operator that returns an equality constraint for two Variables. + * + * @param lhs Left-hand side. + * @param rhs Left-hand side. + */ +template + requires(ScalarLike> || MatrixLike>) && + (ScalarLike> || MatrixLike>) && + (!std::same_as, double> || + !std::same_as, double>) +EqualityConstraints operator==(LHS&& lhs, RHS&& rhs) { + return EqualityConstraints{lhs, rhs}; +} + +/** + * Less-than comparison operator that returns an inequality constraint for two + * Variables. + * + * @param lhs Left-hand side. + * @param rhs Left-hand side. + */ +template + requires(ScalarLike> || MatrixLike>) && + (ScalarLike> || MatrixLike>) && + (!std::same_as, double> || + !std::same_as, double>) +InequalityConstraints operator<(LHS&& lhs, RHS&& rhs) { + return rhs >= lhs; +} + +/** + * Less-than-or-equal-to comparison operator that returns an inequality + * constraint for two Variables. + * + * @param lhs Left-hand side. + * @param rhs Left-hand side. + */ +template + requires(ScalarLike> || MatrixLike>) && + (ScalarLike> || MatrixLike>) && + (!std::same_as, double> || + !std::same_as, double>) +InequalityConstraints operator<=(LHS&& lhs, RHS&& rhs) { + return rhs >= lhs; +} + +/** + * Greater-than comparison operator that returns an inequality constraint for + * two Variables. + * + * @param lhs Left-hand side. + * @param rhs Left-hand side. + */ +template + requires(ScalarLike> || MatrixLike>) && + (ScalarLike> || MatrixLike>) && + (!std::same_as, double> || + !std::same_as, double>) +InequalityConstraints operator>(LHS&& lhs, RHS&& rhs) { + return lhs >= rhs; +} + +/** + * Greater-than-or-equal-to comparison operator that returns an inequality + * constraint for two Variables. + * + * @param lhs Left-hand side. + * @param rhs Left-hand side. + */ +template + requires(ScalarLike> || MatrixLike>) && + (ScalarLike> || MatrixLike>) && + (!std::same_as, double> || + !std::same_as, double>) +InequalityConstraints operator>=(LHS&& lhs, RHS&& rhs) { + return InequalityConstraints{lhs, rhs}; +} + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/optimization/Multistart.hpp b/sleipnir/src/include/sleipnir/optimization/Multistart.hpp new file mode 100644 index 0000000..8055713 --- /dev/null +++ b/sleipnir/src/include/sleipnir/optimization/Multistart.hpp @@ -0,0 +1,74 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include +#include +#include + +#include "sleipnir/optimization/SolverStatus.hpp" +#include "sleipnir/util/FunctionRef.hpp" +#include "sleipnir/util/small_vector.hpp" + +namespace sleipnir { + +/** + * The result of a multistart solve. + * + * @tparam DecisionVariables The type containing the decision variable initial + * guess. + */ +template +struct MultistartResult { + SolverStatus status; + DecisionVariables variables; +}; + +/** + * Solves an optimization problem from different starting points in parallel, + * then returns the solution with the lowest cost. + * + * Each solve is performed on a separate thread. Solutions from successful + * solves are always preferred over solutions from unsuccessful solves, and cost + * (lower is better) is the tiebreaker between successful solves. + * + * @tparam DecisionVariables The type containing the decision variable initial + * guess. + * @param solve A user-provided function that takes a decision variable initial + * guess and returns a MultistartResult. + * @param initialGuesses A list of decision variable initial guesses to try. + */ +template +MultistartResult Multistart( + function_ref( + const DecisionVariables& initialGuess)> + solve, + std::span initialGuesses) { + small_vector>> futures; + futures.reserve(initialGuesses.size()); + + for (const auto& initialGuess : initialGuesses) { + futures.emplace_back(std::async(std::launch::async, solve, initialGuess)); + } + + small_vector> results; + results.reserve(futures.size()); + + for (auto& future : futures) { + results.emplace_back(future.get()); + } + + return *std::min_element( + results.cbegin(), results.cend(), [](const auto& a, const auto& b) { + // Prioritize successful solve + if (a.status.exitCondition == SolverExitCondition::kSuccess && + b.status.exitCondition != SolverExitCondition::kSuccess) { + return true; + } + + // Otherwise prioritize solution with lower cost + return a.status.cost < b.status.cost; + }); +} + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/optimization/OptimizationProblem.hpp b/sleipnir/src/include/sleipnir/optimization/OptimizationProblem.hpp new file mode 100644 index 0000000..7d387e0 --- /dev/null +++ b/sleipnir/src/include/sleipnir/optimization/OptimizationProblem.hpp @@ -0,0 +1,381 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "sleipnir/autodiff/Variable.hpp" +#include "sleipnir/autodiff/VariableMatrix.hpp" +#include "sleipnir/optimization/Constraints.hpp" +#include "sleipnir/optimization/SolverConfig.hpp" +#include "sleipnir/optimization/SolverExitCondition.hpp" +#include "sleipnir/optimization/SolverIterationInfo.hpp" +#include "sleipnir/optimization/SolverStatus.hpp" +#include "sleipnir/optimization/solver/InteriorPoint.hpp" +#include "sleipnir/util/Print.hpp" +#include "sleipnir/util/SymbolExports.hpp" +#include "sleipnir/util/small_vector.hpp" + +namespace sleipnir { + +/** + * This class allows the user to pose a constrained nonlinear optimization + * problem in natural mathematical notation and solve it. + * + * This class supports problems of the form: +@verbatim + minₓ f(x) +subject to cₑ(x) = 0 + cᵢ(x) ≥ 0 +@endverbatim + * + * where f(x) is the scalar cost function, x is the vector of decision variables + * (variables the solver can tweak to minimize the cost function), cᵢ(x) are the + * inequality constraints, and cₑ(x) are the equality constraints. Constraints + * are equations or inequalities of the decision variables that constrain what + * values the solver is allowed to use when searching for an optimal solution. + * + * The nice thing about this class is users don't have to put their system in + * the form shown above manually; they can write it in natural mathematical form + * and it'll be converted for them. + */ +class SLEIPNIR_DLLEXPORT OptimizationProblem { + public: + /** + * Construct the optimization problem. + */ + OptimizationProblem() noexcept = default; + + /** + * Create a decision variable in the optimization problem. + */ + [[nodiscard]] + Variable DecisionVariable() { + m_decisionVariables.emplace_back(); + return m_decisionVariables.back(); + } + + /** + * Create a matrix of decision variables in the optimization problem. + * + * @param rows Number of matrix rows. + * @param cols Number of matrix columns. + */ + [[nodiscard]] + VariableMatrix DecisionVariable(int rows, int cols = 1) { + m_decisionVariables.reserve(m_decisionVariables.size() + rows * cols); + + VariableMatrix vars{rows, cols}; + + for (int row = 0; row < rows; ++row) { + for (int col = 0; col < cols; ++col) { + m_decisionVariables.emplace_back(); + vars(row, col) = m_decisionVariables.back(); + } + } + + return vars; + } + + /** + * Create a symmetric matrix of decision variables in the optimization + * problem. + * + * Variable instances are reused across the diagonal, which helps reduce + * problem dimensionality. + * + * @param rows Number of matrix rows. + */ + [[nodiscard]] + VariableMatrix SymmetricDecisionVariable(int rows) { + // We only need to store the lower triangle of an n x n symmetric matrix; + // the other elements are duplicates. The lower triangle has (n² + n)/2 + // elements. + // + // n + // Σ k = (n² + n)/2 + // k=1 + m_decisionVariables.reserve(m_decisionVariables.size() + + (rows * rows + rows) / 2); + + VariableMatrix vars{rows, rows}; + + for (int row = 0; row < rows; ++row) { + for (int col = 0; col <= row; ++col) { + m_decisionVariables.emplace_back(); + vars(row, col) = m_decisionVariables.back(); + vars(col, row) = m_decisionVariables.back(); + } + } + + return vars; + } + + /** + * Tells the solver to minimize the output of the given cost function. + * + * Note that this is optional. If only constraints are specified, the solver + * will find the closest solution to the initial conditions that's in the + * feasible set. + * + * @param cost The cost function to minimize. + */ + void Minimize(const Variable& cost) { + m_f = cost; + status.costFunctionType = m_f.value().Type(); + } + + /** + * Tells the solver to minimize the output of the given cost function. + * + * Note that this is optional. If only constraints are specified, the solver + * will find the closest solution to the initial conditions that's in the + * feasible set. + * + * @param cost The cost function to minimize. + */ + void Minimize(Variable&& cost) { + m_f = std::move(cost); + status.costFunctionType = m_f.value().Type(); + } + + /** + * Tells the solver to maximize the output of the given objective function. + * + * Note that this is optional. If only constraints are specified, the solver + * will find the closest solution to the initial conditions that's in the + * feasible set. + * + * @param objective The objective function to maximize. + */ + void Maximize(const Variable& objective) { + // Maximizing a cost function is the same as minimizing its negative + m_f = -objective; + status.costFunctionType = m_f.value().Type(); + } + + /** + * Tells the solver to maximize the output of the given objective function. + * + * Note that this is optional. If only constraints are specified, the solver + * will find the closest solution to the initial conditions that's in the + * feasible set. + * + * @param objective The objective function to maximize. + */ + void Maximize(Variable&& objective) { + // Maximizing a cost function is the same as minimizing its negative + m_f = -std::move(objective); + status.costFunctionType = m_f.value().Type(); + } + + /** + * Tells the solver to solve the problem while satisfying the given equality + * constraint. + * + * @param constraint The constraint to satisfy. + */ + void SubjectTo(const EqualityConstraints& constraint) { + // Get the highest order equality constraint expression type + for (const auto& c : constraint.constraints) { + status.equalityConstraintType = + std::max(status.equalityConstraintType, c.Type()); + } + + m_equalityConstraints.reserve(m_equalityConstraints.size() + + constraint.constraints.size()); + std::copy(constraint.constraints.begin(), constraint.constraints.end(), + std::back_inserter(m_equalityConstraints)); + } + + /** + * Tells the solver to solve the problem while satisfying the given equality + * constraint. + * + * @param constraint The constraint to satisfy. + */ + void SubjectTo(EqualityConstraints&& constraint) { + // Get the highest order equality constraint expression type + for (const auto& c : constraint.constraints) { + status.equalityConstraintType = + std::max(status.equalityConstraintType, c.Type()); + } + + m_equalityConstraints.reserve(m_equalityConstraints.size() + + constraint.constraints.size()); + std::copy(constraint.constraints.begin(), constraint.constraints.end(), + std::back_inserter(m_equalityConstraints)); + } + + /** + * Tells the solver to solve the problem while satisfying the given inequality + * constraint. + * + * @param constraint The constraint to satisfy. + */ + void SubjectTo(const InequalityConstraints& constraint) { + // Get the highest order inequality constraint expression type + for (const auto& c : constraint.constraints) { + status.inequalityConstraintType = + std::max(status.inequalityConstraintType, c.Type()); + } + + m_inequalityConstraints.reserve(m_inequalityConstraints.size() + + constraint.constraints.size()); + std::copy(constraint.constraints.begin(), constraint.constraints.end(), + std::back_inserter(m_inequalityConstraints)); + } + + /** + * Tells the solver to solve the problem while satisfying the given inequality + * constraint. + * + * @param constraint The constraint to satisfy. + */ + void SubjectTo(InequalityConstraints&& constraint) { + // Get the highest order inequality constraint expression type + for (const auto& c : constraint.constraints) { + status.inequalityConstraintType = + std::max(status.inequalityConstraintType, c.Type()); + } + + m_inequalityConstraints.reserve(m_inequalityConstraints.size() + + constraint.constraints.size()); + std::copy(constraint.constraints.begin(), constraint.constraints.end(), + std::back_inserter(m_inequalityConstraints)); + } + + /** + * Solve the optimization problem. The solution will be stored in the original + * variables used to construct the problem. + * + * @param config Configuration options for the solver. + */ + SolverStatus Solve(const SolverConfig& config = SolverConfig{}) { + // Create the initial value column vector + Eigen::VectorXd x{m_decisionVariables.size()}; + for (size_t i = 0; i < m_decisionVariables.size(); ++i) { + x(i) = m_decisionVariables[i].Value(); + } + + status.exitCondition = SolverExitCondition::kSuccess; + + // If there's no cost function, make it zero and continue + if (!m_f.has_value()) { + m_f = Variable(); + } + + if (config.diagnostics) { + constexpr std::array kExprTypeToName{"empty", "constant", "linear", + "quadratic", "nonlinear"}; + + // Print cost function and constraint expression types + sleipnir::println( + "The cost function is {}.", + kExprTypeToName[static_cast(status.costFunctionType)]); + sleipnir::println( + "The equality constraints are {}.", + kExprTypeToName[static_cast(status.equalityConstraintType)]); + sleipnir::println( + "The inequality constraints are {}.", + kExprTypeToName[static_cast(status.inequalityConstraintType)]); + sleipnir::println(""); + + // Print problem dimensionality + sleipnir::println("Number of decision variables: {}", + m_decisionVariables.size()); + sleipnir::println("Number of equality constraints: {}", + m_equalityConstraints.size()); + sleipnir::println("Number of inequality constraints: {}\n", + m_inequalityConstraints.size()); + } + + // If the problem is empty or constant, there's nothing to do + if (status.costFunctionType <= ExpressionType::kConstant && + status.equalityConstraintType <= ExpressionType::kConstant && + status.inequalityConstraintType <= ExpressionType::kConstant) { + return status; + } + + // Solve the optimization problem + Eigen::VectorXd s = Eigen::VectorXd::Ones(m_inequalityConstraints.size()); + InteriorPoint(m_decisionVariables, m_equalityConstraints, + m_inequalityConstraints, m_f.value(), m_callback, config, + false, x, s, &status); + + if (config.diagnostics) { + sleipnir::println("Exit condition: {}", ToMessage(status.exitCondition)); + } + + // Assign the solution to the original Variable instances + VariableMatrix{m_decisionVariables}.SetValue(x); + + return status; + } + + /** + * Sets a callback to be called at each solver iteration. + * + * The callback for this overload should return void. + * + * @param callback The callback. + */ + template + requires requires(F callback, const SolverIterationInfo& info) { + { callback(info) } -> std::same_as; + } + void Callback(F&& callback) { + m_callback = [=, callback = std::forward(callback)]( + const SolverIterationInfo& info) { + callback(info); + return false; + }; + } + + /** + * Sets a callback to be called at each solver iteration. + * + * The callback for this overload should return bool. + * + * @param callback The callback. Returning true from the callback causes the + * solver to exit early with the solution it has so far. + */ + template + requires requires(F callback, const SolverIterationInfo& info) { + { callback(info) } -> std::same_as; + } + void Callback(F&& callback) { + m_callback = std::forward(callback); + } + + private: + // The list of decision variables, which are the root of the problem's + // expression tree + small_vector m_decisionVariables; + + // The cost function: f(x) + std::optional m_f; + + // The list of equality constraints: cₑ(x) = 0 + small_vector m_equalityConstraints; + + // The list of inequality constraints: cᵢ(x) ≥ 0 + small_vector m_inequalityConstraints; + + // The user callback + std::function m_callback = + [](const SolverIterationInfo&) { return false; }; + + // The solver status + SolverStatus status; +}; + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/optimization/SolverConfig.hpp b/sleipnir/src/include/sleipnir/optimization/SolverConfig.hpp new file mode 100644 index 0000000..f7323f7 --- /dev/null +++ b/sleipnir/src/include/sleipnir/optimization/SolverConfig.hpp @@ -0,0 +1,54 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include +#include + +#include "sleipnir/util/SymbolExports.hpp" + +namespace sleipnir { + +/** + * Solver configuration. + */ +struct SLEIPNIR_DLLEXPORT SolverConfig { + /// The solver will stop once the error is below this tolerance. + double tolerance = 1e-8; + + /// The maximum number of solver iterations before returning a solution. + int maxIterations = 5000; + + /// The solver will stop once the error is below this tolerance for + /// `acceptableIterations` iterations. This is useful in cases where the + /// solver might not be able to achieve the desired level of accuracy due to + /// floating-point round-off. + double acceptableTolerance = 1e-6; + + /// The solver will stop once the error is below `acceptableTolerance` for + /// this many iterations. + int maxAcceptableIterations = 15; + + /// The maximum elapsed wall clock time before returning a solution. + std::chrono::duration timeout{ + std::numeric_limits::infinity()}; + + /// Enables the feasible interior-point method. When the inequality + /// constraints are all feasible, step sizes are reduced when necessary to + /// prevent them becoming infeasible again. This is useful when parts of the + /// problem are ill-conditioned in infeasible regions (e.g., square root of a + /// negative value). This can slow or prevent progress toward a solution + /// though, so only enable it if necessary. + bool feasibleIPM = false; + + /// Enables diagnostic prints. + bool diagnostics = false; + + /// Enables writing sparsity patterns of H, Aₑ, and Aᵢ to files named H.spy, + /// A_e.spy, and A_i.spy respectively during solve. + /// + /// Use tools/spy.py to plot them. + bool spy = false; +}; + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/optimization/SolverExitCondition.hpp b/sleipnir/src/include/sleipnir/optimization/SolverExitCondition.hpp new file mode 100644 index 0000000..7d14452 --- /dev/null +++ b/sleipnir/src/include/sleipnir/optimization/SolverExitCondition.hpp @@ -0,0 +1,80 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include + +#include + +#include "sleipnir/util/SymbolExports.hpp" + +namespace sleipnir { + +/** + * Solver exit condition. + */ +enum class SolverExitCondition : int8_t { + /// Solved the problem to the desired tolerance. + kSuccess = 0, + /// Solved the problem to an acceptable tolerance, but not the desired one. + kSolvedToAcceptableTolerance = 1, + /// The solver returned its solution so far after the user requested a stop. + kCallbackRequestedStop = 2, + /// The solver determined the problem to be overconstrained and gave up. + kTooFewDOFs = -1, + /// The solver determined the problem to be locally infeasible and gave up. + kLocallyInfeasible = -2, + /// The solver failed to reach the desired tolerance, and feasibility + /// restoration failed to converge. + kFeasibilityRestorationFailed = -3, + /// The solver encountered nonfinite initial cost or constraints and gave up. + kNonfiniteInitialCostOrConstraints = -4, + /// The solver encountered diverging primal iterates xₖ and/or sₖ and gave up. + kDivergingIterates = -5, + /// The solver returned its solution so far after exceeding the maximum number + /// of iterations. + kMaxIterationsExceeded = -6, + /// The solver returned its solution so far after exceeding the maximum + /// elapsed wall clock time. + kTimeout = -7 +}; + +/** + * Returns user-readable message corresponding to the exit condition. + * + * @param exitCondition Solver exit condition. + */ +SLEIPNIR_DLLEXPORT constexpr std::string_view ToMessage( + const SolverExitCondition& exitCondition) { + using enum SolverExitCondition; + + switch (exitCondition) { + case kSuccess: + return "solved to desired tolerance"; + case kSolvedToAcceptableTolerance: + return "solved to acceptable tolerance"; + case kCallbackRequestedStop: + return "callback requested stop"; + case kTooFewDOFs: + return "problem has too few degrees of freedom"; + case kLocallyInfeasible: + return "problem is locally infeasible"; + case kFeasibilityRestorationFailed: + return "solver failed to reach the desired tolerance, and feasibility " + "restoration failed to converge"; + case kNonfiniteInitialCostOrConstraints: + return "solver encountered nonfinite initial cost or constraints and " + "gave up"; + case kDivergingIterates: + return "solver encountered diverging primal iterates xₖ and/or sₖ and " + "gave up"; + case kMaxIterationsExceeded: + return "solution returned after maximum iterations exceeded"; + case kTimeout: + return "solution returned after maximum wall clock time exceeded"; + default: + return "unknown"; + } +} + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/optimization/SolverIterationInfo.hpp b/sleipnir/src/include/sleipnir/optimization/SolverIterationInfo.hpp new file mode 100644 index 0000000..bb915b8 --- /dev/null +++ b/sleipnir/src/include/sleipnir/optimization/SolverIterationInfo.hpp @@ -0,0 +1,36 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include +#include + +namespace sleipnir { + +/** + * Solver iteration information exposed to a user callback. + */ +struct SolverIterationInfo { + /// The solver iteration. + int iteration; + + /// The decision variables. + const Eigen::VectorXd& x; + + /// The inequality constraint slack variables. + const Eigen::VectorXd& s; + + /// The gradient of the cost function. + const Eigen::SparseVector& g; + + /// The Hessian of the Lagrangian. + const Eigen::SparseMatrix& H; + + /// The equality constraint Jacobian. + const Eigen::SparseMatrix& A_e; + + /// The inequality constraint Jacobian. + const Eigen::SparseMatrix& A_i; +}; + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/optimization/SolverStatus.hpp b/sleipnir/src/include/sleipnir/optimization/SolverStatus.hpp new file mode 100644 index 0000000..122941c --- /dev/null +++ b/sleipnir/src/include/sleipnir/optimization/SolverStatus.hpp @@ -0,0 +1,32 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include "sleipnir/autodiff/ExpressionType.hpp" +#include "sleipnir/optimization/SolverExitCondition.hpp" +#include "sleipnir/util/SymbolExports.hpp" + +namespace sleipnir { + +/** + * Return value of OptimizationProblem::Solve() containing the cost function and + * constraint types and solver's exit condition. + */ +struct SLEIPNIR_DLLEXPORT SolverStatus { + /// The cost function type detected by the solver. + ExpressionType costFunctionType = ExpressionType::kNone; + + /// The equality constraint type detected by the solver. + ExpressionType equalityConstraintType = ExpressionType::kNone; + + /// The inequality constraint type detected by the solver. + ExpressionType inequalityConstraintType = ExpressionType::kNone; + + /// The solver's exit condition. + SolverExitCondition exitCondition = SolverExitCondition::kSuccess; + + /// The solution's cost. + double cost = 0.0; +}; + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/optimization/solver/InteriorPoint.hpp b/sleipnir/src/include/sleipnir/optimization/solver/InteriorPoint.hpp new file mode 100644 index 0000000..51d8f97 --- /dev/null +++ b/sleipnir/src/include/sleipnir/optimization/solver/InteriorPoint.hpp @@ -0,0 +1,55 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include + +#include + +#include "sleipnir/autodiff/Variable.hpp" +#include "sleipnir/optimization/SolverConfig.hpp" +#include "sleipnir/optimization/SolverIterationInfo.hpp" +#include "sleipnir/optimization/SolverStatus.hpp" +#include "sleipnir/util/FunctionRef.hpp" +#include "sleipnir/util/SymbolExports.hpp" + +namespace sleipnir { + +/** +Finds the optimal solution to a nonlinear program using the interior-point +method. + +A nonlinear program has the form: + +@verbatim + min_x f(x) +subject to cₑ(x) = 0 + cᵢ(x) ≥ 0 +@endverbatim + +where f(x) is the cost function, cₑ(x) are the equality constraints, and cᵢ(x) +are the inequality constraints. + +@param[in] decisionVariables The list of decision variables. +@param[in] equalityConstraints The list of equality constraints. +@param[in] inequalityConstraints The list of inequality constraints. +@param[in] f The cost function. +@param[in] callback The user callback. +@param[in] config Configuration options for the solver. +@param[in] feasibilityRestoration Whether to use feasibility restoration instead + of the normal algorithm. +@param[in,out] x The initial guess and output location for the decision + variables. +@param[in,out] s The initial guess and output location for the inequality + constraint slack variables. +@param[out] status The solver status. +*/ +SLEIPNIR_DLLEXPORT void InteriorPoint( + std::span decisionVariables, + std::span equalityConstraints, + std::span inequalityConstraints, Variable& f, + function_ref callback, + const SolverConfig& config, bool feasibilityRestoration, Eigen::VectorXd& x, + Eigen::VectorXd& s, SolverStatus* status); + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/util/Assert.hpp b/sleipnir/src/include/sleipnir/util/Assert.hpp new file mode 100644 index 0000000..ba381ef --- /dev/null +++ b/sleipnir/src/include/sleipnir/util/Assert.hpp @@ -0,0 +1,25 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#ifdef JORMUNGANDR +#include +#include +/** + * Throw an exception in Python. + */ +#define Assert(condition) \ + do { \ + if (!(condition)) { \ + throw std::invalid_argument( \ + std::format("{}:{}: {}: Assertion `{}' failed.", __FILE__, __LINE__, \ + __func__, #condition)); \ + } \ + } while (0); +#else +#include +/** + * Abort in C++. + */ +#define Assert(condition) assert(condition) +#endif diff --git a/sleipnir/src/include/sleipnir/util/Concepts.hpp b/sleipnir/src/include/sleipnir/util/Concepts.hpp new file mode 100644 index 0000000..653200e --- /dev/null +++ b/sleipnir/src/include/sleipnir/util/Concepts.hpp @@ -0,0 +1,32 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include + +#include + +#include "sleipnir/autodiff/Variable.hpp" +#include "sleipnir/autodiff/VariableMatrix.hpp" + +namespace sleipnir { + +template +concept ScalarLike = std::same_as || std::same_as || + std::same_as; + +template +concept SleipnirMatrixLike = std::same_as || + std::same_as>; + +template +concept EigenMatrixLike = + std::derived_from>; + +template +concept EigenSolver = requires(T t) { t.solve(Eigen::VectorXd{}); }; + +template +concept MatrixLike = SleipnirMatrixLike || EigenMatrixLike; + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/util/FunctionRef.hpp b/sleipnir/src/include/sleipnir/util/FunctionRef.hpp new file mode 100644 index 0000000..14a4690 --- /dev/null +++ b/sleipnir/src/include/sleipnir/util/FunctionRef.hpp @@ -0,0 +1,100 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include +#include +#include +#include + +namespace sleipnir { + +/** + * An implementation of std::function_ref, a lightweight non-owning reference to + * a callable. + */ +template +class function_ref; + +template +class function_ref { + public: + constexpr function_ref() noexcept = delete; + + /** + * Creates a `function_ref` which refers to the same callable as `rhs`. + */ + constexpr function_ref(const function_ref& rhs) noexcept = + default; + + /** + * Constructs a `function_ref` referring to `f`. + */ + template + requires(!std::is_same_v, function_ref> && + std::is_invocable_r_v) + constexpr function_ref(F&& f) noexcept // NOLINT(google-explicit-constructor) + : obj_(const_cast( + reinterpret_cast(std::addressof(f)))) { + callback_ = [](void* obj, Args... args) -> R { + return std::invoke( + *reinterpret_cast::type>(obj), + std::forward(args)...); + }; + } + + /** + * Makes `*this` refer to the same callable as `rhs`. + */ + constexpr function_ref& operator=( + const function_ref& rhs) noexcept = default; + + /** + * Makes `*this` refer to `f`. + */ + template + requires std::is_invocable_r_v + constexpr function_ref& operator=(F&& f) noexcept { + obj_ = reinterpret_cast(std::addressof(f)); + callback_ = [](void* obj, Args... args) { + return std::invoke( + *reinterpret_cast::type>(obj), + std::forward(args)...); + }; + + return *this; + } + + /** + * Swaps the referred callables of `*this` and `rhs`. + */ + constexpr void swap(function_ref& rhs) noexcept { + std::swap(obj_, rhs.obj_); + std::swap(callback_, rhs.callback_); + } + + /** + * Call the stored callable with the given arguments. + */ + R operator()(Args... args) const { + return callback_(obj_, std::forward(args)...); + } + + private: + void* obj_ = nullptr; + R (*callback_)(void*, Args...) = nullptr; +}; + +/** + * Swaps the referred callables of `lhs` and `rhs`. + */ +template +constexpr void swap(function_ref& lhs, + function_ref& rhs) noexcept { + lhs.swap(rhs); +} + +template +function_ref(R (*)(Args...)) -> function_ref; + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/util/IntrusiveSharedPtr.hpp b/sleipnir/src/include/sleipnir/util/IntrusiveSharedPtr.hpp new file mode 100644 index 0000000..f1290e5 --- /dev/null +++ b/sleipnir/src/include/sleipnir/util/IntrusiveSharedPtr.hpp @@ -0,0 +1,216 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include +#include +#include + +namespace sleipnir { + +/** + * A custom intrusive shared pointer implementation without thread + * synchronization overhead. + * + * Types used with this class should have three things: + * + * 1. A zero-initialized public counter variable that serves as the shared + * pointer's reference count. + * 2. A free function `void IntrusiveSharedPtrIncRefCount(T*)` that increments + * the reference count. + * 3. A free function `void IntrusiveSharedPtrDecRefCount(T*)` that decrements + * the reference count and deallocates the pointed to object if the reference + * count reaches zero. + * + * @tparam T The type of the object to be reference counted. + */ +template +class IntrusiveSharedPtr { + public: + /** + * Constructs an empty intrusive shared pointer. + */ + constexpr IntrusiveSharedPtr() noexcept = default; + + /** + * Constructs an empty intrusive shared pointer. + */ + constexpr IntrusiveSharedPtr(std::nullptr_t) noexcept {} // NOLINT + + /** + * Constructs an intrusive shared pointer from the given pointer and takes + * ownership. + */ + explicit constexpr IntrusiveSharedPtr(T* ptr) noexcept : m_ptr{ptr} { + if (m_ptr != nullptr) { + IntrusiveSharedPtrIncRefCount(m_ptr); + } + } + + constexpr ~IntrusiveSharedPtr() { + if (m_ptr != nullptr) { + IntrusiveSharedPtrDecRefCount(m_ptr); + } + } + + /** + * Copy constructs from the given intrusive shared pointer. + */ + constexpr IntrusiveSharedPtr(const IntrusiveSharedPtr& rhs) noexcept + : m_ptr{rhs.m_ptr} { + if (m_ptr != nullptr) { + IntrusiveSharedPtrIncRefCount(m_ptr); + } + } + + /** + * Makes a copy of the given intrusive shared pointer. + */ + constexpr IntrusiveSharedPtr& operator=( // NOLINT + const IntrusiveSharedPtr& rhs) noexcept { + if (m_ptr == rhs.m_ptr) { + return *this; + } + + if (m_ptr != nullptr) { + IntrusiveSharedPtrDecRefCount(m_ptr); + } + + m_ptr = rhs.m_ptr; + + if (m_ptr != nullptr) { + IntrusiveSharedPtrIncRefCount(m_ptr); + } + + return *this; + } + + /** + * Move constructs from the given intrusive shared pointer. + */ + constexpr IntrusiveSharedPtr(IntrusiveSharedPtr&& rhs) noexcept + : m_ptr{std::exchange(rhs.m_ptr, nullptr)} {} + + /** + * Move assigns from the given intrusive shared pointer. + */ + constexpr IntrusiveSharedPtr& operator=( + IntrusiveSharedPtr&& rhs) noexcept { + if (m_ptr == rhs.m_ptr) { + return *this; + } + + std::swap(m_ptr, rhs.m_ptr); + + return *this; + } + + /** + * Returns the internal pointer. + */ + constexpr T* Get() const noexcept { return m_ptr; } + + /** + * Returns the object pointed to by the internal pointer. + */ + constexpr T& operator*() const noexcept { return *m_ptr; } + + /** + * Returns the internal pointer. + */ + constexpr T* operator->() const noexcept { return m_ptr; } + + /** + * Returns true if the internal pointer isn't nullptr. + */ + explicit constexpr operator bool() const noexcept { return m_ptr != nullptr; } + + /** + * Returns true if the given intrusive shared pointers point to the same + * object. + */ + friend constexpr bool operator==(const IntrusiveSharedPtr& lhs, + const IntrusiveSharedPtr& rhs) noexcept { + return lhs.m_ptr == rhs.m_ptr; + } + + /** + * Returns true if the given intrusive shared pointers point to different + * objects. + */ + friend constexpr bool operator!=(const IntrusiveSharedPtr& lhs, + const IntrusiveSharedPtr& rhs) noexcept { + return lhs.m_ptr != rhs.m_ptr; + } + + /** + * Returns true if the left-hand intrusive shared pointer points to nullptr. + */ + friend constexpr bool operator==(const IntrusiveSharedPtr& lhs, + std::nullptr_t) noexcept { + return lhs.m_ptr == nullptr; + } + + /** + * Returns true if the right-hand intrusive shared pointer points to nullptr. + */ + friend constexpr bool operator==(std::nullptr_t, + const IntrusiveSharedPtr& rhs) noexcept { + return nullptr == rhs.m_ptr; + } + + /** + * Returns true if the left-hand intrusive shared pointer doesn't point to + * nullptr. + */ + friend constexpr bool operator!=(const IntrusiveSharedPtr& lhs, + std::nullptr_t) noexcept { + return lhs.m_ptr != nullptr; + } + + /** + * Returns true if the right-hand intrusive shared pointer doesn't point to + * nullptr. + */ + friend constexpr bool operator!=(std::nullptr_t, + const IntrusiveSharedPtr& rhs) noexcept { + return nullptr != rhs.m_ptr; + } + + private: + T* m_ptr = nullptr; +}; + +/** + * Constructs an object of type T and wraps it in an intrusive shared pointer + * using args as the parameter list for the constructor of T. + * + * @tparam T Type of object for intrusive shared pointer. + * @tparam Args Types of constructor arguments. + * @param args Constructor arguments for T. + */ +template +IntrusiveSharedPtr MakeIntrusiveShared(Args&&... args) { + return IntrusiveSharedPtr{new T(std::forward(args)...)}; +} + +/** + * Constructs an object of type T and wraps it in an intrusive shared pointer + * using alloc as the storage allocator of T and args as the parameter list for + * the constructor of T. + * + * @tparam T Type of object for intrusive shared pointer. + * @tparam Alloc Type of allocator for T. + * @tparam Args Types of constructor arguments. + * @param alloc The allocator for T. + * @param args Constructor arguments for T. + */ +template +IntrusiveSharedPtr AllocateIntrusiveShared(Alloc alloc, Args&&... args) { + auto ptr = std::allocator_traits::allocate(alloc, sizeof(T)); + std::allocator_traits::construct(alloc, ptr, + std::forward(args)...); + return IntrusiveSharedPtr{ptr}; +} + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/util/Pool.hpp b/sleipnir/src/include/sleipnir/util/Pool.hpp new file mode 100644 index 0000000..441fa70 --- /dev/null +++ b/sleipnir/src/include/sleipnir/util/Pool.hpp @@ -0,0 +1,161 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include +#include + +#include "sleipnir/util/SymbolExports.hpp" +#include "sleipnir/util/small_vector.hpp" + +namespace sleipnir { + +/** + * This class implements a pool memory resource. + * + * The pool allocates chunks of memory and splits them into blocks managed by a + * free list. Allocations return pointers from the free list, and deallocations + * return pointers to the free list. + */ +class SLEIPNIR_DLLEXPORT PoolResource { + public: + /** + * Constructs a default PoolResource. + * + * @param blocksPerChunk Number of blocks per chunk of memory. + */ + explicit PoolResource(size_t blocksPerChunk) + : blocksPerChunk{blocksPerChunk} {} + + PoolResource(const PoolResource&) = delete; + PoolResource& operator=(const PoolResource&) = delete; + PoolResource(PoolResource&&) = default; + PoolResource& operator=(PoolResource&&) = default; + + /** + * Returns a block of memory from the pool. + * + * @param bytes Number of bytes in the block. + * @param alignment Alignment of the block (unused). + */ + [[nodiscard]] + void* allocate(size_t bytes, [[maybe_unused]] size_t alignment = + alignof(std::max_align_t)) { + if (m_freeList.empty()) { + AddChunk(bytes); + } + + auto ptr = m_freeList.back(); + m_freeList.pop_back(); + return ptr; + } + + /** + * Gives a block of memory back to the pool. + * + * @param p A pointer to the block of memory. + * @param bytes Number of bytes in the block (unused). + * @param alignment Alignment of the block (unused). + */ + void deallocate( + void* p, [[maybe_unused]] size_t bytes, + [[maybe_unused]] size_t alignment = alignof(std::max_align_t)) { + m_freeList.emplace_back(p); + } + + /** + * Returns true if this pool resource has the same backing storage as another. + */ + bool is_equal(const PoolResource& other) const noexcept { + return this == &other; + } + + /** + * Returns the number of blocks from this pool resource that are in use. + */ + size_t blocks_in_use() const noexcept { + return m_buffer.size() * blocksPerChunk - m_freeList.size(); + } + + private: + small_vector> m_buffer; + small_vector m_freeList; + size_t blocksPerChunk; + + /** + * Adds a memory chunk to the pool, partitions it into blocks with the given + * number of bytes, and appends pointers to them to the free list. + * + * @param bytesPerBlock Number of bytes in the block. + */ + void AddChunk(size_t bytesPerBlock) { + m_buffer.emplace_back(new std::byte[bytesPerBlock * blocksPerChunk]); + for (int i = blocksPerChunk - 1; i >= 0; --i) { + m_freeList.emplace_back(m_buffer.back().get() + bytesPerBlock * i); + } + } +}; + +/** + * This class is an allocator for the pool resource. + * + * @tparam T The type of object in the pool. + */ +template +class PoolAllocator { + public: + /** + * The type of object in the pool. + */ + using value_type = T; + + /** + * Constructs a pool allocator with the given pool memory resource. + * + * @param r The pool resource. + */ + explicit constexpr PoolAllocator(PoolResource* r) : m_memoryResource{r} {} + + constexpr PoolAllocator(const PoolAllocator& other) = default; + constexpr PoolAllocator& operator=(const PoolAllocator&) = default; + + /** + * Returns a block of memory from the pool. + * + * @param n Number of bytes in the block. + */ + [[nodiscard]] + constexpr T* allocate(size_t n) { + return static_cast(m_memoryResource->allocate(n)); + } + + /** + * Gives a block of memory back to the pool. + * + * @param p A pointer to the block of memory. + * @param n Number of bytes in the block. + */ + constexpr void deallocate(T* p, size_t n) { + m_memoryResource->deallocate(p, n); + } + + private: + PoolResource* m_memoryResource; +}; + +/** + * Returns a global pool memory resource. + */ +SLEIPNIR_DLLEXPORT PoolResource& GlobalPoolResource(); + +/** + * Returns an allocator for a global pool memory resource. + * + * @tparam T The type of object in the pool. + */ +template +PoolAllocator GlobalPoolAllocator() { + return PoolAllocator{&GlobalPoolResource()}; +} + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/util/Print.hpp b/sleipnir/src/include/sleipnir/util/Print.hpp new file mode 100644 index 0000000..8541ddf --- /dev/null +++ b/sleipnir/src/include/sleipnir/util/Print.hpp @@ -0,0 +1,55 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include +#include +#include + +namespace sleipnir { + +/** + * Wrapper around std::print() that squelches write failure exceptions. + */ +template +inline void print(fmt::format_string fmt, T&&... args) { + try { + fmt::print(fmt, std::forward(args)...); + } catch (const std::system_error&) { + } +} + +/** + * Wrapper around std::print() that squelches write failure exceptions. + */ +template +inline void print(std::FILE* f, fmt::format_string fmt, T&&... args) { + try { + fmt::print(f, fmt, std::forward(args)...); + } catch (const std::system_error&) { + } +} + +/** + * Wrapper around std::println() that squelches write failure exceptions. + */ +template +inline void println(fmt::format_string fmt, T&&... args) { + try { + fmt::println(fmt, std::forward(args)...); + } catch (const std::system_error&) { + } +} + +/** + * Wrapper around std::println() that squelches write failure exceptions. + */ +template +inline void println(std::FILE* f, fmt::format_string fmt, T&&... args) { + try { + fmt::println(f, fmt, std::forward(args)...); + } catch (const std::system_error&) { + } +} + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/util/Spy.hpp b/sleipnir/src/include/sleipnir/util/Spy.hpp new file mode 100644 index 0000000..cb9b4e1 --- /dev/null +++ b/sleipnir/src/include/sleipnir/util/Spy.hpp @@ -0,0 +1,89 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#include +#include +#include + +#include + +#include "sleipnir/util/SymbolExports.hpp" +#include "sleipnir/util/small_vector.hpp" + +namespace sleipnir { + +/** + * Write the sparsity pattern of a sparse matrix to a file. + * + * Each character represents an element with '.' representing zero, '+' + * representing positive, and '-' representing negative. Here's an example for a + * 3x3 identity matrix. + * + * "+.." + * ".+." + * "..+" + * + * @param[out] file A file stream. + * @param[in] mat The sparse matrix. + */ +SLEIPNIR_DLLEXPORT inline void Spy(std::ostream& file, + const Eigen::SparseMatrix& mat) { + const int cells_width = mat.cols() + 1; + const int cells_height = mat.rows(); + + small_vector cells; + + // Allocate space for matrix of characters plus trailing newlines + cells.reserve(cells_width * cells_height); + + // Initialize cell array + for (int row = 0; row < mat.rows(); ++row) { + for (int col = 0; col < mat.cols(); ++col) { + cells.emplace_back('.'); + } + cells.emplace_back('\n'); + } + + // Fill in non-sparse entries + for (int k = 0; k < mat.outerSize(); ++k) { + for (Eigen::SparseMatrix::InnerIterator it{mat, k}; it; ++it) { + if (it.value() < 0.0) { + cells[it.row() * cells_width + it.col()] = '-'; + } else if (it.value() > 0.0) { + cells[it.row() * cells_width + it.col()] = '+'; + } + } + } + + // Write cell array to file + for (const auto& c : cells) { + file << c; + } +} + +/** + * Write the sparsity pattern of a sparse matrix to a file. + * + * Each character represents an element with "." representing zero, "+" + * representing positive, and "-" representing negative. Here's an example for a + * 3x3 identity matrix. + * + * "+.." + * ".+." + * "..+" + * + * @param[in] filename The filename. + * @param[in] mat The sparse matrix. + */ +SLEIPNIR_DLLEXPORT inline void Spy(std::string_view filename, + const Eigen::SparseMatrix& mat) { + std::ofstream file{std::string{filename}}; + if (!file.is_open()) { + return; + } + + Spy(file, mat); +} + +} // namespace sleipnir diff --git a/sleipnir/src/include/sleipnir/util/SymbolExports.hpp b/sleipnir/src/include/sleipnir/util/SymbolExports.hpp new file mode 100644 index 0000000..f826663 --- /dev/null +++ b/sleipnir/src/include/sleipnir/util/SymbolExports.hpp @@ -0,0 +1,192 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +#ifdef _WIN32 +#ifdef _MSC_VER +#pragma warning(disable : 4251) +#endif + +#ifdef SLEIPNIR_EXPORTS +#ifdef __GNUC__ +#define SLEIPNIR_DLLEXPORT __attribute__((dllexport)) +#else +#define SLEIPNIR_DLLEXPORT __declspec(dllexport) +#endif + +#elif defined(SLEIPNIR_IMPORTS) + +#ifdef __GNUC__ +#define SLEIPNIR_DLLEXPORT __attribute__((dllimport)) +#else +#define SLEIPNIR_DLLEXPORT __declspec(dllimport) +#endif + +#else +#define SLEIPNIR_DLLEXPORT +#endif + +#else // _WIN32 + +#ifdef SLEIPNIR_EXPORTS +#define SLEIPNIR_DLLEXPORT __attribute__((visibility("default"))) +#else +#define SLEIPNIR_DLLEXPORT +#endif + +#endif // _WIN32 + +// Synopsis +// +// This header provides macros for using FOO_EXPORT macros with explicit +// template instantiation declarations and definitions. +// Generally, the FOO_EXPORT macros are used at declarations, +// and GCC requires them to be used at explicit instantiation declarations, +// but MSVC requires __declspec(dllexport) to be used at the explicit +// instantiation definitions instead. + +// Usage +// +// In a header file, write: +// +// extern template class EXPORT_TEMPLATE_DECLARE(FOO_EXPORT) foo; +// +// In a source file, write: +// +// template class EXPORT_TEMPLATE_DEFINE(FOO_EXPORT) foo; + +// Implementation notes +// +// The implementation of this header uses some subtle macro semantics to +// detect what the provided FOO_EXPORT value was defined as and then +// to dispatch to appropriate macro definitions. Unfortunately, +// MSVC's C preprocessor is rather non-compliant and requires special +// care to make it work. +// +// Issue 1. +// +// #define F(x) +// F() +// +// MSVC emits warning C4003 ("not enough actual parameters for macro +// 'F'), even though it's a valid macro invocation. This affects the +// macros below that take just an "export" parameter, because export +// may be empty. +// +// As a workaround, we can add a dummy parameter and arguments: +// +// #define F(x,_) +// F(,) +// +// Issue 2. +// +// #define F(x) G##x +// #define Gj() ok +// F(j()) +// +// The correct replacement for "F(j())" is "ok", but MSVC replaces it +// with "Gj()". As a workaround, we can pass the result to an +// identity macro to force MSVC to look for replacements again. (This +// is why EXPORT_TEMPLATE_STYLE_3 exists.) + +#define EXPORT_TEMPLATE_DECLARE(export) \ + EXPORT_TEMPLATE_INVOKE(DECLARE, EXPORT_TEMPLATE_STYLE(export, ), export) +#define EXPORT_TEMPLATE_DEFINE(export) \ + EXPORT_TEMPLATE_INVOKE(DEFINE, EXPORT_TEMPLATE_STYLE(export, ), export) + +// INVOKE is an internal helper macro to perform parameter replacements +// and token pasting to chain invoke another macro. E.g., +// EXPORT_TEMPLATE_INVOKE(DECLARE, DEFAULT, FOO_EXPORT) +// will export to call +// EXPORT_TEMPLATE_DECLARE_DEFAULT(FOO_EXPORT, ) +// (but with FOO_EXPORT expanded too). +#define EXPORT_TEMPLATE_INVOKE(which, style, export) \ + EXPORT_TEMPLATE_INVOKE_2(which, style, export) +#define EXPORT_TEMPLATE_INVOKE_2(which, style, export) \ + EXPORT_TEMPLATE_##which##_##style(export, ) + +// Default style is to apply the FOO_EXPORT macro at declaration sites. +#define EXPORT_TEMPLATE_DECLARE_DEFAULT(export, _) export +#define EXPORT_TEMPLATE_DEFINE_DEFAULT(export, _) + +// The "MSVC hack" style is used when FOO_EXPORT is defined +// as __declspec(dllexport), which MSVC requires to be used at +// definition sites instead. +#define EXPORT_TEMPLATE_DECLARE_MSVC_HACK(export, _) +#define EXPORT_TEMPLATE_DEFINE_MSVC_HACK(export, _) export + +// EXPORT_TEMPLATE_STYLE is an internal helper macro that identifies which +// export style needs to be used for the provided FOO_EXPORT macro definition. +// "", "__attribute__(...)", and "__declspec(dllimport)" are mapped +// to "DEFAULT"; while "__declspec(dllexport)" is mapped to "MSVC_HACK". +// +// It's implemented with token pasting to transform the __attribute__ and +// __declspec annotations into macro invocations. E.g., if FOO_EXPORT is +// defined as "__declspec(dllimport)", it undergoes the following sequence of +// macro substitutions: +// EXPORT_TEMPLATE_STYLE(FOO_EXPORT, ) +// EXPORT_TEMPLATE_STYLE_2(__declspec(dllimport), ) +// EXPORT_TEMPLATE_STYLE_3(EXPORT_TEMPLATE_STYLE_MATCH__declspec(dllimport)) +// EXPORT_TEMPLATE_STYLE_MATCH__declspec(dllimport) +// EXPORT_TEMPLATE_STYLE_MATCH_DECLSPEC_dllimport +// DEFAULT +#define EXPORT_TEMPLATE_STYLE(export, _) EXPORT_TEMPLATE_STYLE_2(export, ) +#define EXPORT_TEMPLATE_STYLE_2(export, _) \ + EXPORT_TEMPLATE_STYLE_3( \ + EXPORT_TEMPLATE_STYLE_MATCH_foj3FJo5StF0OvIzl7oMxA##export) +#define EXPORT_TEMPLATE_STYLE_3(style) style + +// Internal helper macros for EXPORT_TEMPLATE_STYLE. +// +// XXX: C++ reserves all identifiers containing "__" for the implementation, +// but "__attribute__" and "__declspec" already contain "__" and the token-paste +// operator can only add characters; not remove them. To minimize the risk of +// conflict with implementations, we include "foj3FJo5StF0OvIzl7oMxA" (a random +// 128-bit string, encoded in Base64) in the macro name. +#define EXPORT_TEMPLATE_STYLE_MATCH_foj3FJo5StF0OvIzl7oMxA DEFAULT +#define EXPORT_TEMPLATE_STYLE_MATCH_foj3FJo5StF0OvIzl7oMxA__attribute__(...) \ + DEFAULT +#define EXPORT_TEMPLATE_STYLE_MATCH_foj3FJo5StF0OvIzl7oMxA__declspec(arg) \ + EXPORT_TEMPLATE_STYLE_MATCH_DECLSPEC_##arg + +// Internal helper macros for EXPORT_TEMPLATE_STYLE. +#define EXPORT_TEMPLATE_STYLE_MATCH_DECLSPEC_dllexport MSVC_HACK +#define EXPORT_TEMPLATE_STYLE_MATCH_DECLSPEC_dllimport DEFAULT + +// Sanity checks. +// +// EXPORT_TEMPLATE_TEST uses the same macro invocation pattern as +// EXPORT_TEMPLATE_DECLARE and EXPORT_TEMPLATE_DEFINE do to check that they're +// working correctly. When they're working correctly, the sequence of macro +// replacements should go something like: +// +// EXPORT_TEMPLATE_TEST(DEFAULT, __declspec(dllimport)); +// +// static_assert(EXPORT_TEMPLATE_INVOKE(TEST_DEFAULT, +// EXPORT_TEMPLATE_STYLE(__declspec(dllimport), ), +// __declspec(dllimport)), "__declspec(dllimport)"); +// +// static_assert(EXPORT_TEMPLATE_INVOKE(TEST_DEFAULT, +// DEFAULT, __declspec(dllimport)), "__declspec(dllimport)"); +// +// static_assert(EXPORT_TEMPLATE_TEST_DEFAULT_DEFAULT( +// __declspec(dllimport)), "__declspec(dllimport)"); +// +// static_assert(true, "__declspec(dllimport)"); +// +// When they're not working correctly, a syntax error should occur instead. +#define EXPORT_TEMPLATE_TEST(want, export) \ + static_assert(EXPORT_TEMPLATE_INVOKE( \ + TEST_##want, EXPORT_TEMPLATE_STYLE(export, ), export), \ + #export) +#define EXPORT_TEMPLATE_TEST_DEFAULT_DEFAULT(...) true +#define EXPORT_TEMPLATE_TEST_MSVC_HACK_MSVC_HACK(...) true + +EXPORT_TEMPLATE_TEST(DEFAULT, ); +EXPORT_TEMPLATE_TEST(DEFAULT, __attribute__((visibility("default")))); +EXPORT_TEMPLATE_TEST(MSVC_HACK, __declspec(dllexport)); +EXPORT_TEMPLATE_TEST(DEFAULT, __declspec(dllimport)); + +#undef EXPORT_TEMPLATE_TEST +#undef EXPORT_TEMPLATE_TEST_DEFAULT_DEFAULT +#undef EXPORT_TEMPLATE_TEST_MSVC_HACK_MSVC_HACK diff --git a/sleipnir/src/include/sleipnir/util/small_vector.hpp b/sleipnir/src/include/sleipnir/util/small_vector.hpp new file mode 100644 index 0000000..fa10479 --- /dev/null +++ b/sleipnir/src/include/sleipnir/util/small_vector.hpp @@ -0,0 +1,4380 @@ +/** small_vector.hpp + * An implementation of `small_vector` (a vector with a small + * buffer optimization). + * + * Copyright © 2020-2021 Gene Harvey + * + * This software may be modified and distributed under the terms + * of the MIT license. See LICENSE_small_vector.txt for details. + * + * Source: https://github.com/gharveymn/small_vector + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace sleipnir { + +namespace concepts { + +template +concept Complete = requires { sizeof(T); }; + +// Note: this mirrors the named requirements, not the standard concepts, so we +// don't require the destructor to be noexcept for Destructible. +template +concept Destructible = std::is_destructible_v; + +template +concept TriviallyDestructible = std::is_trivially_destructible_v; + +template +concept NoThrowDestructible = std::is_nothrow_destructible_v; + +// Note: this mirrors the named requirements, not the standard library concepts, +// so we don't require Destructible here. + +template +concept ConstructibleFrom = std::is_constructible_v; + +template +concept NoThrowConstructibleFrom = std::is_nothrow_constructible_v; + +template +concept ConvertibleTo = + std::is_convertible_v && + requires(typename std::add_rvalue_reference_t (&f)()) { + static_cast(f()); + }; + +template +concept NoThrowConvertibleTo = + std::is_nothrow_convertible_v && + requires(typename std::add_rvalue_reference_t (&f)() noexcept) { + { static_cast(f()) } noexcept; + }; + +// Note: std::default_initializable requires std::destructible. +template +concept DefaultConstructible = ConstructibleFrom && requires { + T{}; +} && requires { ::new (static_cast(nullptr)) T; }; + +template +concept MoveAssignable = std::assignable_from; + +template +concept CopyAssignable = + MoveAssignable && std::assignable_from && + std::assignable_from && std::assignable_from; + +template +concept MoveConstructible = ConstructibleFrom && ConvertibleTo; + +template +concept NoThrowMoveConstructible = + NoThrowConstructibleFrom && NoThrowConvertibleTo; + +template +concept CopyConstructible = + MoveConstructible && ConstructibleFrom && ConvertibleTo && + ConstructibleFrom && ConvertibleTo && + ConstructibleFrom && ConvertibleTo; + +template +concept NoThrowCopyConstructible = + NoThrowMoveConstructible && NoThrowConstructibleFrom && + NoThrowConvertibleTo && NoThrowConstructibleFrom && + NoThrowConvertibleTo && NoThrowConstructibleFrom && + NoThrowConvertibleTo; + +template +concept Swappable = std::swappable; + +template +concept EqualityComparable = std::equality_comparable; + +// T is a type +// X is a Container +// A is an Allocator +// if X::allocator_type then +// std::same_as::template rebind_alloc> +// otherwise +// no condition; we use std::allocator regardless of A +// +// see [22.2.1].16 +template +concept EmplaceConstructible = + std::same_as && + // only perform this check if X is allocator-aware + ( + ( + requires { typename X::allocator_type; } && + std::same_as::template rebind_alloc> && + ( + requires (A m, T* p, Args&&... args) { + m.construct(p, std::forward(args)...); + } || + requires (T* p, Args&&... args) { + { std::construct_at(p, std::forward(args)...) } -> std::same_as; + } + ) + ) || + ( + !requires { typename X::allocator_type; } && + requires (T* p, Args&&... args) { + { std::construct_at(p, std::forward (args)...) } -> std::same_as; + } + ) + ); + +template >> +concept DefaultInsertable = EmplaceConstructible; + +template >> +concept MoveInsertable = EmplaceConstructible; + +template >> +concept CopyInsertable = + MoveInsertable && EmplaceConstructible && + EmplaceConstructible; + +// same method as with EmplaceConstructible +template >> +concept Erasable = + std::same_as && + ((requires { typename X::allocator_type; } // if X is allocator aware + && std::same_as< + typename X::allocator_type, + typename std::allocator_traits::template rebind_alloc> && + (requires(A m, T* p) { m.destroy(p); } || std::is_destructible_v)) || + (!requires { typename X::allocator_type; } && std::is_destructible_v)); + +template +concept ContextuallyConvertibleToBool = std::constructible_from; + +template +concept BoolConstant = std::derived_from || + std::derived_from; + +template +concept NullablePointer = + EqualityComparable && DefaultConstructible && CopyConstructible && + CopyAssignable && Destructible && + ConstructibleFrom && ConvertibleTo && + requires(T p, T q, std::nullptr_t np) { + T(np); + { p = np } -> std::same_as; + { p != q } -> ContextuallyConvertibleToBool; + { p == np } -> ContextuallyConvertibleToBool; + { np == p } -> ContextuallyConvertibleToBool; + { p != np } -> ContextuallyConvertibleToBool; + { np != p } -> ContextuallyConvertibleToBool; + }; + +static_assert(NullablePointer); +static_assert(!NullablePointer); + +template +concept AllocatorFor = + NoThrowCopyConstructible && + requires(A a, typename std::allocator_traits::template rebind_alloc b, + U xp, typename std::allocator_traits::pointer p, + typename std::allocator_traits::const_pointer cp, + typename std::allocator_traits::void_pointer vp, + typename std::allocator_traits::const_void_pointer cvp, + typename std::allocator_traits::value_type& r, + typename std::allocator_traits::size_type n) { + /** Inner types **/ + // A::pointer + requires NullablePointer::pointer>; + requires std::random_access_iterator< + typename std::allocator_traits::pointer>; + requires std::contiguous_iterator< + typename std::allocator_traits::pointer>; + + // A::const_pointer + requires NullablePointer< + typename std::allocator_traits::const_pointer>; + requires std::random_access_iterator< + typename std::allocator_traits::const_pointer>; + requires std::contiguous_iterator< + typename std::allocator_traits::const_pointer>; + + requires std::convertible_to< + typename std::allocator_traits::pointer, + typename std::allocator_traits::const_pointer>; + + // A::void_pointer + requires NullablePointer::void_pointer>; + + requires std::convertible_to< + typename std::allocator_traits::pointer, + typename std::allocator_traits::void_pointer>; + + requires std::same_as< + typename std::allocator_traits::void_pointer, + typename std::allocator_traits::void_pointer>; + + // A::const_void_pointer + requires NullablePointer< + typename std::allocator_traits::const_void_pointer>; + + requires std::convertible_to< + typename std::allocator_traits::pointer, + typename std::allocator_traits::const_void_pointer>; + + requires std::convertible_to< + typename std::allocator_traits::const_pointer, + typename std::allocator_traits::const_void_pointer>; + + requires std::convertible_to< + typename std::allocator_traits::void_pointer, + typename std::allocator_traits::const_void_pointer>; + + requires std::same_as< + typename std::allocator_traits::const_void_pointer, + typename std::allocator_traits::const_void_pointer>; + + // A::value_type + typename A::value_type; + requires std::same_as; + requires std::same_as::value_type>; + + // A::size_type + requires std::unsigned_integral< + typename std::allocator_traits::size_type>; + + // A::difference_type + requires std::signed_integral< + typename std::allocator_traits::difference_type>; + + // A::template rebind::other [optional] + requires !requires { + typename A::template rebind::other; + } || requires { + requires std::same_as::other>; + requires std::same_as::other>; + }; + + /** Operations on pointers **/ + { *p } -> std::same_as; + { *cp } -> std::same_as; + + // Language in the standard implies that `decltype (p)` must either + // be a raw pointer or implement `operator->`. There is no mention + // of `std::to_address` or `std::pointer_traits::to_address`. + requires std::same_as || requires { + { p.operator->() } -> std::same_as; + }; + + requires std::same_as || + requires { + { + cp.operator->() + } -> std::same_as; + }; + + { static_cast(vp) } -> std::same_as; + { static_cast(cvp) } -> std::same_as; + + { + std::pointer_traits::pointer_to(r) + } -> std::same_as; + + /** Storage and lifetime operations **/ + // a.allocate (n) + { a.allocate(n) } -> std::same_as; + + // a.allocate (n, cvp) [optional] + requires !requires { a.allocate(n, cvp); } || requires { + { a.allocate(n, cvp) } -> std::same_as; + }; + + // a.deallocate (p, n) + { a.deallocate(p, n) } -> std::convertible_to; + + // a.max_size () [optional] + requires !requires { a.max_size(); } || requires { + { a.max_size() } -> std::same_as; + }; + + // a.construct (xp, args) [optional] + requires !requires { a.construct(xp); } || requires { + { a.construct(xp) } -> std::convertible_to; + }; + + // a.destroy (xp) [optional] + requires !requires { a.destroy(xp); } || requires { + { a.destroy(xp) } -> std::convertible_to; + }; + + /** Relationship between instances **/ + requires NoThrowConstructibleFrom; + requires NoThrowConstructibleFrom; + + requires BoolConstant::is_always_equal>; + + /** Influence on container operations **/ + // a.select_on_container_copy_construction () [optional] + requires !requires { a.select_on_container_copy_construction(); } || + requires { + { + a.select_on_container_copy_construction() + } -> std::same_as; + }; + + requires BoolConstant::propagate_on_container_copy_assignment>; + + requires BoolConstant::propagate_on_container_move_assignment>; + + requires BoolConstant< + typename std::allocator_traits::propagate_on_container_swap>; + + { a == b } -> std::same_as; + { a != b } -> std::same_as; + } && + requires(A a1, A a2) { + { a1 == a2 } -> std::same_as; + { a1 != a2 } -> std::same_as; + }; + +static_assert( + AllocatorFor, int>, + "std::allocator failed to meet Allocator concept requirements."); + +template +concept Allocator = AllocatorFor; + +namespace small_vector { + +// Basically, these shut off the concepts if we have an incomplete type. +// This namespace is only needed because of issues on Clang +// preventing us from short-circuiting for incomplete types. + +template +concept Destructible = !concepts::Complete || concepts::Destructible; + +template +concept MoveAssignable = !concepts::Complete || concepts::MoveAssignable; + +template +concept CopyAssignable = !concepts::Complete || concepts::CopyAssignable; + +template +concept MoveConstructible = + !concepts::Complete || concepts::MoveConstructible; + +template +concept CopyConstructible = + !concepts::Complete || concepts::CopyConstructible; + +template +concept Swappable = !concepts::Complete || concepts::Swappable; + +template +concept DefaultInsertable = !concepts::Complete || + concepts::DefaultInsertable; + +template +concept MoveInsertable = + !concepts::Complete || concepts::MoveInsertable; + +template +concept CopyInsertable = + !concepts::Complete || concepts::CopyInsertable; + +template +concept Erasable = + !concepts::Complete || concepts::Erasable; + +template +concept EmplaceConstructible = + !concepts::Complete || + concepts::EmplaceConstructible; + +template +concept AllocatorFor = + !concepts::Complete || concepts::AllocatorFor; + +template +concept Allocator = AllocatorFor; + +} // namespace small_vector + +} // namespace concepts + +template + requires concepts::small_vector::Allocator +struct default_buffer_size; + +template >::value, + typename Allocator = std::allocator> + requires concepts::small_vector::AllocatorFor +class small_vector; + +template + requires concepts::small_vector::Allocator +struct default_buffer_size { + private: + template + struct is_complete : std::false_type {}; + + template + struct is_complete(sizeof(U)))> + : std::true_type {}; + + template + inline static constexpr bool is_complete_v = is_complete::value; + + public: + using allocator_type = Allocator; + using value_type = typename std::allocator_traits::value_type; + using empty_small_vector = small_vector; + + static_assert(is_complete_v, + "Calculation of a default number of elements requires that `T` " + "be complete."); + + static constexpr unsigned buffer_max = 256; + + static constexpr unsigned ideal_total = 64; + + static constexpr unsigned ideal_buffer = + ideal_total - sizeof(empty_small_vector); + + static_assert(sizeof(empty_small_vector) != 0, + "Empty `small_vector` should not have size 0."); + + static_assert(ideal_buffer < ideal_total, + "Empty `small_vector` is larger than ideal_total."); + + static constexpr unsigned value = (sizeof(value_type) <= ideal_buffer) + ? (ideal_buffer / sizeof(value_type)) + : 1; +}; + +template +inline constexpr unsigned default_buffer_size_v = + default_buffer_size::value; + +template +class small_vector_iterator { + public: + using difference_type = DifferenceType; + using value_type = typename std::iterator_traits::value_type; + using pointer = typename std::iterator_traits::pointer; + using reference = typename std::iterator_traits::reference; + using iterator_category = + typename std::iterator_traits::iterator_category; + using iterator_concept = std::contiguous_iterator_tag; + + small_vector_iterator(const small_vector_iterator&) = default; + small_vector_iterator(small_vector_iterator&&) noexcept = default; + small_vector_iterator& operator=(const small_vector_iterator&) = default; + small_vector_iterator& operator=(small_vector_iterator&&) noexcept = default; + ~small_vector_iterator() = default; + +#ifdef NDEBUG + small_vector_iterator() = default; +#else + constexpr small_vector_iterator() noexcept : m_ptr() {} +#endif + + constexpr explicit small_vector_iterator(const Pointer& p) noexcept + : m_ptr(p) {} + + template + requires std::is_convertible_v + constexpr small_vector_iterator( // NOLINT + const small_vector_iterator& other) noexcept + : m_ptr(other.base()) {} + + constexpr small_vector_iterator& operator++() noexcept { + ++m_ptr; + return *this; + } + + constexpr small_vector_iterator operator++(int) noexcept { + return small_vector_iterator(m_ptr++); + } + + constexpr small_vector_iterator& operator--() noexcept { + --m_ptr; + return *this; + } + + constexpr small_vector_iterator operator--(int) noexcept { + return small_vector_iterator(m_ptr--); + } + + constexpr small_vector_iterator& operator+=(difference_type n) noexcept { + m_ptr += n; + return *this; + } + + constexpr small_vector_iterator operator+(difference_type n) const noexcept { + return small_vector_iterator(m_ptr + n); + } + + constexpr small_vector_iterator& operator-=(difference_type n) noexcept { + m_ptr -= n; + return *this; + } + + constexpr small_vector_iterator operator-(difference_type n) const noexcept { + return small_vector_iterator(m_ptr - n); + } + + constexpr reference operator*() const noexcept { + return launder_and_dereference(m_ptr); + } + + constexpr pointer operator->() const noexcept { return get_pointer(m_ptr); } + + constexpr reference operator[](difference_type n) const noexcept { + return launder_and_dereference(m_ptr + n); + } + + constexpr const Pointer& base() const noexcept { return m_ptr; } + + private: + static constexpr pointer get_pointer(Pointer ptr) noexcept + requires std::is_pointer_v + { + return ptr; + } + + static constexpr pointer get_pointer(Pointer ptr) noexcept + requires(!std::is_pointer_v) + { + // Given the requirements for Allocator, Pointer must either be a raw + // pointer, or have a defined operator-> which returns a raw pointer. + return ptr.operator->(); + } + + static constexpr reference launder_and_dereference(Pointer ptr) noexcept + requires std::is_pointer_v + { + return *std::launder(ptr); + } + + static constexpr reference launder_and_dereference(Pointer ptr) noexcept + requires(!std::is_pointer_v) + { + return *ptr; + } + + Pointer m_ptr; +}; + +template +constexpr bool operator==( + const small_vector_iterator& lhs, + const small_vector_iterator& + rhs) noexcept(noexcept(lhs.base() == rhs.base())) + requires requires { + { lhs.base() == rhs.base() } -> std::convertible_to; + } +{ + return lhs.base() == rhs.base(); +} + +template +constexpr bool operator==( + const small_vector_iterator& lhs, + const small_vector_iterator& + rhs) noexcept(noexcept(lhs.base() == rhs.base())) + requires requires { + { lhs.base() == rhs.base() } -> std::convertible_to; + } +{ + return lhs.base() == rhs.base(); +} + +template + requires std::three_way_comparable_with +constexpr auto operator<=>( + const small_vector_iterator& lhs, + const small_vector_iterator& + rhs) noexcept(noexcept(lhs.base() <=> rhs.base())) { + return lhs.base() <=> rhs.base(); +} + +template + requires std::three_way_comparable +constexpr auto operator<=>( + const small_vector_iterator& lhs, + const small_vector_iterator& + rhs) noexcept(noexcept(lhs.base() <=> rhs.base())) { + return lhs.base() <=> rhs.base(); +} + +template +constexpr auto operator<=>( + const small_vector_iterator& lhs, + const small_vector_iterator& + rhs) noexcept(noexcept(lhs.base() < rhs.base()) && + noexcept(rhs.base() < lhs.base())) { + using ordering = std::weak_ordering; + return (lhs.base() < rhs.base()) ? ordering::less + : (rhs.base() < lhs.base()) ? ordering::greater + : ordering::equivalent; +} + +template +constexpr auto operator<=>( + const small_vector_iterator& lhs, + const small_vector_iterator& + rhs) noexcept(noexcept(lhs.base() < rhs.base()) && + noexcept(rhs.base() < lhs.base())) { + using ordering = std::weak_ordering; + return (lhs.base() < rhs.base()) ? ordering::less + : (rhs.base() < lhs.base()) ? ordering::greater + : ordering::equivalent; +} + +template +constexpr DifferenceType operator-( + const small_vector_iterator& lhs, + const small_vector_iterator& rhs) noexcept { + return static_cast(lhs.base() - rhs.base()); +} + +template +constexpr DifferenceType operator-( + const small_vector_iterator& lhs, + const small_vector_iterator& rhs) noexcept { + return static_cast(lhs.base() - rhs.base()); +} + +template +constexpr small_vector_iterator operator+( + DifferenceType n, + const small_vector_iterator& it) noexcept { + return it + n; +} + +namespace detail { + +template +class inline_storage { + public: + using value_ty = T; + + inline_storage() = default; + inline_storage(const inline_storage&) = delete; + inline_storage(inline_storage&&) noexcept = delete; + inline_storage& operator=(const inline_storage&) = delete; + inline_storage& operator=(inline_storage&&) noexcept = delete; + ~inline_storage() = default; + + [[nodiscard]] + constexpr value_ty* get_inline_ptr() noexcept { + return static_cast(static_cast(std::addressof(*m_data))); + } + + [[nodiscard]] + constexpr const value_ty* get_inline_ptr() const noexcept { + return static_cast( + static_cast(std::addressof(*m_data))); + } + + static constexpr size_t element_size() noexcept { return sizeof(value_ty); } + + static constexpr size_t alignment() noexcept { return alignof(value_ty); } + + static constexpr unsigned num_elements() noexcept { return InlineCapacity; } + + static constexpr size_t num_bytes() noexcept { + return num_elements() * element_size(); + } + + private: + alignas(alignment()) std::byte m_data[element_size()][num_elements()]; +}; + +template +class inline_storage { + public: + using value_ty = T; + + inline_storage() = default; + inline_storage(const inline_storage&) = delete; + inline_storage(inline_storage&&) noexcept = delete; + inline_storage& operator=(const inline_storage&) = delete; + inline_storage& operator=(inline_storage&&) noexcept = delete; + ~inline_storage() = default; + + [[nodiscard]] + constexpr value_ty* get_inline_ptr() noexcept { + return nullptr; + } + + [[nodiscard]] + constexpr const value_ty* get_inline_ptr() const noexcept { + return nullptr; + } + + static constexpr size_t element_size() noexcept { return sizeof(value_ty); } + + static constexpr size_t alignment() noexcept { return alignof(value_ty); } + + static constexpr unsigned num_elements() noexcept { return 0; } + + static constexpr size_t num_bytes() noexcept { return 0; } +}; + +template && !std::is_final_v> +class allocator_inliner; + +template +class allocator_inliner : private Allocator { + using alloc_traits = std::allocator_traits; + + static constexpr bool copy_assign_is_noop = + !alloc_traits::propagate_on_container_copy_assignment::value; + + static constexpr bool move_assign_is_noop = + !alloc_traits::propagate_on_container_move_assignment::value; + + static constexpr bool swap_is_noop = + !alloc_traits::propagate_on_container_swap::value; + + template + requires IsNoOp + constexpr void maybe_assign(const allocator_inliner&) noexcept {} + + template + requires(!IsNoOp) + constexpr void maybe_assign(const allocator_inliner& other) noexcept( + noexcept(std::declval().operator=(other))) { + Allocator::operator=(other); + } + + template + requires IsNoOp + constexpr void maybe_assign(allocator_inliner&&) noexcept {} + + template + requires(!IsNoOp) + constexpr void maybe_assign(allocator_inliner&& other) noexcept( + noexcept(std::declval().operator=(std::move(other)))) { + Allocator::operator=(std::move(other)); + } + + public: + allocator_inliner() = default; + allocator_inliner(const allocator_inliner&) = default; + allocator_inliner(allocator_inliner&&) noexcept = default; + ~allocator_inliner() = default; + + constexpr explicit allocator_inliner(const Allocator& alloc) noexcept + : Allocator(alloc) {} + + constexpr allocator_inliner& + operator=(const allocator_inliner& other) noexcept( + noexcept(std::declval().maybe_assign(other))) { + assert( + &other != this && + "`allocator_inliner` should not participate in self-copy-assignment."); + maybe_assign(other); + return *this; + } + + constexpr allocator_inliner& operator=(allocator_inliner&& other) noexcept( + noexcept( + std::declval().maybe_assign(std::move(other)))) { + assert( + &other != this && + "`allocator_inliner` should not participate in self-move-assignment."); + maybe_assign(std::move(other)); + return *this; + } + + constexpr Allocator& allocator_ref() noexcept { return *this; } + + constexpr const Allocator& allocator_ref() const noexcept { return *this; } + + template + requires IsNoOp + constexpr void swap(allocator_inliner&) {} + + template + requires(!IsNoOp) + constexpr void swap(allocator_inliner& other) { + using std::swap; + swap(static_cast(*this), static_cast(other)); + } +}; + +template +class allocator_inliner { + using alloc_traits = std::allocator_traits; + + static constexpr bool copy_assign_is_noop = + !alloc_traits::propagate_on_container_copy_assignment::value; + + static constexpr bool move_assign_is_noop = + !alloc_traits::propagate_on_container_move_assignment::value; + + static constexpr bool swap_is_noop = + !alloc_traits::propagate_on_container_swap::value; + + template + requires IsNoOp + constexpr void maybe_assign(const allocator_inliner&) noexcept {} + + template + requires(!IsNoOp) + constexpr void maybe_assign(const allocator_inliner& other) noexcept( + noexcept(std::declval() = other.m_alloc)) { + m_alloc = other.m_alloc; + } + + template + requires IsNoOp + constexpr void maybe_assign(allocator_inliner&&) noexcept {} + + template + requires(!IsNoOp) + constexpr void maybe_assign(allocator_inliner&& other) noexcept(noexcept( + std::declval() = std::move(other.m_alloc))) { + m_alloc = std::move(other.m_alloc); + } + + public: + allocator_inliner() = default; + allocator_inliner(const allocator_inliner&) = default; + allocator_inliner(allocator_inliner&&) noexcept = default; + ~allocator_inliner() = default; + + constexpr explicit allocator_inliner(const Allocator& alloc) noexcept + : m_alloc(alloc) {} + + constexpr allocator_inliner& + operator=(const allocator_inliner& other) noexcept( + noexcept(std::declval().maybe_assign(other))) { + assert( + &other != this && + "`allocator_inliner` should not participate in self-copy-assignment."); + maybe_assign(other); + return *this; + } + + constexpr allocator_inliner& operator=(allocator_inliner&& other) noexcept( + noexcept( + std::declval().maybe_assign(std::move(other)))) { + assert( + &other != this && + "`allocator_inliner` should not participate in self-move-assignment."); + maybe_assign(std::move(other)); + return *this; + } + + constexpr Allocator& allocator_ref() noexcept { return m_alloc; } + + constexpr const Allocator& allocator_ref() const noexcept { return m_alloc; } + + template + requires IsNoOp + constexpr void swap(allocator_inliner&) {} + + template + requires(!IsNoOp) + constexpr void swap(allocator_inliner& other) { + using std::swap; + swap(m_alloc, other.m_alloc); + } + + private: + Allocator m_alloc; +}; + +template +class allocator_interface : public allocator_inliner { + public: + template + inline static constexpr bool is_complete_v = requires { sizeof(U); }; + + using size_type = typename std::allocator_traits::size_type; + + // If difference_type is larger than size_type then we need + // to rectify that problem. + using difference_type = typename std::conditional_t< + (static_cast( + (std::numeric_limits::max)()) < // less-than + static_cast((std::numeric_limits::difference_type>::max)())), + typename std::make_signed_t, + typename std::allocator_traits::difference_type>; + + private: + using alloc_base = allocator_inliner; + + protected: + using alloc_ty = Allocator; + using alloc_traits = std::allocator_traits; + using value_ty = typename alloc_traits::value_type; + using ptr = typename alloc_traits::pointer; + using cptr = typename alloc_traits::const_pointer; + using vptr = typename alloc_traits::void_pointer; + using cvptr = typename alloc_traits::const_void_pointer; + + // Select the fastest types larger than the user-facing types. These are only + // intended for internal computations, and should not have any memory + // footprint visible to consumers. + using size_ty = typename std::conditional_t< + (sizeof(size_type) <= sizeof(uint8_t)), uint_fast8_t, + typename std::conditional_t< + (sizeof(size_type) <= sizeof(uint16_t)), uint_fast16_t, + typename std::conditional_t< + (sizeof(size_type) <= sizeof(uint32_t)), uint_fast32_t, + typename std::conditional_t<(sizeof(size_type) <= + sizeof(uint64_t)), + uint_fast64_t, size_type>>>>; + + using diff_ty = typename std::conditional_t< + (sizeof(difference_type) <= sizeof(int8_t)), int_fast8_t, + typename std::conditional_t< + (sizeof(difference_type) <= sizeof(int16_t)), int_fast16_t, + typename std::conditional_t< + (sizeof(difference_type) <= sizeof(int32_t)), int_fast32_t, + typename std::conditional_t<(sizeof(difference_type) <= + sizeof(int64_t)), + int_fast64_t, difference_type>>>>; + + using alloc_base::allocator_ref; + + private: + template + struct underlying_if_enum { + using type = T; + }; + + template + requires std::is_enum_v + struct underlying_if_enum : std::underlying_type {}; + + template + using underlying_if_enum_t = typename underlying_if_enum::type; + + template + inline static constexpr bool has_ptr_traits_to_address_v = + requires { std::pointer_traits

::to_address(std::declval

()); }; + + template + inline static constexpr bool has_alloc_construct_v = + is_complete_v && requires { + std::declval().construct(std::declval(), + std::declval()...); + }; + + template + inline static constexpr bool must_use_alloc_construct_v = + !std::is_same_v> && + has_alloc_construct_v; + + template + inline static constexpr bool has_alloc_destroy_v = + is_complete_v && + requires { std::declval().destroy(std::declval()); }; + + template + inline static constexpr bool must_use_alloc_destroy_v = + !std::is_same_v> && has_alloc_destroy_v; + + public: + allocator_interface() = default; + allocator_interface(allocator_interface&&) noexcept = default; + + constexpr allocator_interface& operator=(const allocator_interface&) = + default; + + constexpr allocator_interface& operator=(allocator_interface&&) noexcept = + default; + + ~allocator_interface() = default; + + constexpr allocator_interface(const allocator_interface& other) noexcept + : alloc_base(alloc_traits::select_on_container_copy_construction( + other.allocator_ref())) {} + + constexpr explicit allocator_interface(const alloc_ty& alloc) noexcept + : alloc_base(alloc) {} + + template + constexpr explicit allocator_interface(T&&, const alloc_ty& alloc) noexcept + : allocator_interface(alloc) {} + + template + inline static constexpr bool is_memcpyable_integral_v = + is_complete_v && + (sizeof(underlying_if_enum_t) == + sizeof(underlying_if_enum_t)) && + (std::is_same_v> == + std::is_same_v>) && + std::is_integral_v> && + std::is_integral_v>; + + template + inline static constexpr bool is_convertible_pointer_v = + std::is_pointer_v && std::is_pointer_v && + std::is_convertible_v; + + // Memcpyable assignment. + template + inline static constexpr bool is_memcpyable_v = + is_complete_v && !std::is_reference_v && + std::is_trivially_assignable_v && + std::is_trivially_copyable_v> && + (std::is_same_v>>, + std::remove_cv_t> || + is_memcpyable_integral_v< + std::remove_reference_t>, + std::remove_cv_t> || + is_convertible_pointer_v< + std::remove_reference_t>, + std::remove_cv_t>); + + // Memcpyable construction. + template + inline static constexpr bool is_uninitialized_memcpyable_v = + !std::is_reference_v && std::is_trivially_constructible_v && + std::is_trivially_copyable_v> && + (std::is_same_v< + std::remove_cv_t>>, + std::remove_cv_t> || + is_memcpyable_integral_v>, + std::remove_cv_t> || + is_convertible_pointer_v>, + std::remove_cv_t>) && + (!must_use_alloc_construct_v< + alloc_ty, value_ty, + std::remove_reference_t>> && + !must_use_alloc_destroy_v); + + template + struct is_small_vector_iterator : std::false_type {}; + + template + struct is_small_vector_iterator> + : std::true_type {}; + + template + inline static constexpr bool is_small_vector_iterator_v = + is_small_vector_iterator::value; + + template + inline static constexpr bool is_contiguous_iterator_v = + std::is_same_v || std::is_same_v || + is_small_vector_iterator_v || std::contiguous_iterator; + + template + struct is_memcpyable_iterator { + inline static constexpr bool value = + is_memcpyable_v())> && + is_contiguous_iterator_v; + }; + + // Unwrap move_iterators + template + struct is_memcpyable_iterator> + : is_memcpyable_iterator {}; + + template + inline static constexpr bool is_memcpyable_iterator_v = + is_memcpyable_iterator::value; + + template + struct is_uninitialized_memcpyable_iterator { + inline static constexpr bool value = + is_uninitialized_memcpyable_v())> && + is_contiguous_iterator_v; + }; + + // Unwrap move_iterators + template + struct is_uninitialized_memcpyable_iterator, V> + : is_uninitialized_memcpyable_iterator {}; + + template + inline static constexpr bool is_uninitialized_memcpyable_iterator_v = + is_uninitialized_memcpyable_iterator::value; + + [[noreturn]] + static constexpr void throw_range_length_error() { + throw std::length_error("The specified range is too long."); + } + + static constexpr value_ty* to_address(value_ty* p) noexcept { + static_assert(!std::is_function_v, + "value_ty is a function pointer."); + return p; + } + + static constexpr const value_ty* to_address(const value_ty* p) noexcept { + static_assert(!std::is_function_v, + "value_ty is a function pointer."); + return p; + } + + template + requires has_ptr_traits_to_address_v + static constexpr auto to_address(const Pointer& p) noexcept + -> decltype(std::pointer_traits::to_address(p)) { + return std::pointer_traits::to_address(p); + } + + template + requires(!has_ptr_traits_to_address_v) + static constexpr auto to_address(const Pointer& p) noexcept + -> decltype(to_address(p.operator->())) { + return to_address(p.operator->()); + } + + template + [[nodiscard]] + static consteval size_t numeric_max() noexcept { + static_assert(0 <= (std::numeric_limits::max)(), + "Integer is nonpositive."); + return static_cast((std::numeric_limits::max)()); + } + + [[nodiscard]] + static constexpr size_ty internal_range_length(cptr first, + cptr last) noexcept { + // This is guaranteed to be less than or equal to max size_ty. + return static_cast(last - first); + } + + template + [[nodiscard]] + static constexpr size_ty external_range_length_impl( + RandomIt first, RandomIt last, std::random_access_iterator_tag) { + assert(0 <= (last - first) && "Invalid range."); + const auto len = static_cast(last - first); +#ifndef NDEBUG + if (numeric_max() < len) + throw_range_length_error(); +#endif + return static_cast(len); + } + + template + [[nodiscard]] + static constexpr size_ty external_range_length_impl( + ForwardIt first, ForwardIt last, std::forward_iterator_tag) { + if (std::is_constant_evaluated()) { + // Make sure constexpr doesn't get broken by `using namespace + // std::rel_ops`. + typename std::iterator_traits::difference_type len = 0; + for (; !(first == last); ++first) { + ++len; + } + assert(static_cast(len) <= numeric_max()); + return static_cast(len); + } + + const auto len = static_cast(std::distance(first, last)); +#ifndef NDEBUG + if (numeric_max() < len) + throw_range_length_error(); +#endif + return static_cast(len); + } + + template ::difference_type> + requires(numeric_max() < numeric_max()) + [[nodiscard]] + static constexpr size_ty external_range_length(ForwardIt first, + ForwardIt last) { + using iterator_cat = + typename std::iterator_traits::iterator_category; + return external_range_length_impl(first, last, iterator_cat{}); + } + + template ::difference_type> + requires(!(numeric_max() < numeric_max())) + [[nodiscard]] + static constexpr size_ty external_range_length(ForwardIt first, + ForwardIt last) noexcept { + if (std::is_constant_evaluated()) { + // Make sure constexpr doesn't get broken by `using namespace + // std::rel_ops`. + size_ty len = 0; + for (; !(first == last); ++first) { + ++len; + } + return len; + } + + return static_cast(std::distance(first, last)); + } + + template ::difference_type, + typename Integer = IteratorDiffT> + [[nodiscard]] + static constexpr Iterator unchecked_next(Iterator pos, + Integer n = 1) noexcept { + unchecked_advance(pos, static_cast(n)); + return pos; + } + + template ::difference_type, + typename Integer = IteratorDiffT> + [[nodiscard]] + static constexpr Iterator unchecked_prev(Iterator pos, + Integer n = 1) noexcept { + unchecked_advance(pos, -static_cast(n)); + return pos; + } + + template ::difference_type, + typename Integer = IteratorDiffT> + static constexpr void unchecked_advance(Iterator& pos, Integer n) noexcept { + std::advance(pos, static_cast(n)); + } + + [[nodiscard]] + constexpr size_ty get_max_size() const noexcept { + // This is protected from max/min macros. + return (std::min)( + static_cast(alloc_traits::max_size(allocator_ref())), + static_cast(numeric_max())); + } + + [[nodiscard]] + constexpr ptr allocate(size_ty n) { + return alloc_traits::allocate(allocator_ref(), static_cast(n)); + } + + [[nodiscard]] + constexpr ptr allocate_with_hint(size_ty n, cptr hint) { + return alloc_traits::allocate(allocator_ref(), static_cast(n), + hint); + } + + constexpr void deallocate(ptr p, size_ty n) { + alloc_traits::deallocate(allocator_ref(), to_address(p), + static_cast(n)); + } + + template + requires is_uninitialized_memcpyable_v + constexpr void construct(ptr p, U&& val) noexcept { + if (std::is_constant_evaluated()) { + alloc_traits::construct(allocator_ref(), to_address(p), + std::forward(val)); + return; + } + std::memcpy(to_address(p), &val, sizeof(value_ty)); + } + + // basically alloc_traits::construct + // all this is so we can replicate C++20 behavior in the other overload + template + requires(sizeof...(Args) != 1 || + !is_uninitialized_memcpyable_v) && + has_alloc_construct_v + constexpr void construct(ptr p, Args&&... args) noexcept( + noexcept(alloc_traits::construct(std::declval(), + std::declval(), + std::forward(args)...))) { + alloc_traits::construct(allocator_ref(), to_address(p), + std::forward(args)...); + } + + template + requires(sizeof...(Args) != 1 || + !is_uninitialized_memcpyable_v) && + (!has_alloc_construct_v) && requires { + ::new (std::declval()) V(std::declval()...); + } + constexpr void construct(ptr p, Args&&... args) noexcept(noexcept( + ::new(std::declval()) value_ty(std::declval()...))) { + construct_at(to_address(p), std::forward(args)...); + } + + template + requires std::is_trivially_destructible_v && + (!must_use_alloc_destroy_v) + constexpr void destroy(ptr) const noexcept {} + + template + requires(!std::is_trivially_destructible_v || + must_use_alloc_destroy_v) && + has_alloc_destroy_v + constexpr void destroy(ptr p) noexcept { + alloc_traits::destroy(allocator_ref(), to_address(p)); + } + + // defined so we match C++20 behavior in all cases. + template + requires(!std::is_trivially_destructible_v || + must_use_alloc_destroy_v) && + (!has_alloc_destroy_v) + constexpr void destroy(ptr p) noexcept { + destroy_at(to_address(p)); + } + + template + requires std::is_trivially_destructible_v && + (!must_use_alloc_destroy_v) + constexpr void destroy_range(ptr, ptr) const noexcept {} + + template + requires(!std::is_trivially_destructible_v || + must_use_alloc_destroy_v) + constexpr void destroy_range(ptr first, ptr last) noexcept { + for (; !(first == last); ++first) { + destroy(first); + } + } + + // allowed if trivially copyable and we use the standard allocator + // and InputIt is a contiguous iterator + template + requires is_uninitialized_memcpyable_iterator_v + constexpr ptr uninitialized_copy(ForwardIt first, ForwardIt last, + ptr dest) noexcept { + static_assert(std::is_constructible_v, + "`value_type` must be copy constructible."); + + if (std::is_constant_evaluated()) { + return default_uninitialized_copy(first, last, dest); + } + + const size_ty num_copy = external_range_length(first, last); + if (num_copy != 0) { + std::memcpy(to_address(dest), to_address(first), + num_copy * sizeof(value_ty)); + } + return unchecked_next(dest, num_copy); + } + + template + requires is_uninitialized_memcpyable_iterator_v + constexpr ptr uninitialized_copy(std::move_iterator first, + std::move_iterator last, + ptr dest) noexcept { + return uninitialized_copy(first.base(), last.base(), dest); + } + + template + requires(!is_uninitialized_memcpyable_iterator_v) + constexpr ptr uninitialized_copy(InputIt first, InputIt last, ptr d_first) { + return default_uninitialized_copy(first, last, d_first); + } + + template + constexpr ptr default_uninitialized_copy(InputIt first, InputIt last, + ptr d_first) { + ptr d_last = d_first; + try { + for (; !(first == last); ++first, static_cast(++d_last)) { + construct(d_last, *first); + } + return d_last; + } catch (...) { + destroy_range(d_first, d_last); + throw; + } + } + + template + requires(std::is_trivially_constructible_v && + !must_use_alloc_construct_v) + constexpr ptr uninitialized_value_construct(ptr first, ptr last) { + if (std::is_constant_evaluated()) { + return default_uninitialized_value_construct(first, last); + } + std::fill(first, last, value_ty()); + return last; + } + + template + requires(!std::is_trivially_constructible_v || + must_use_alloc_construct_v) + constexpr ptr uninitialized_value_construct(ptr first, ptr last) { + return default_uninitialized_value_construct(first, last); + } + + constexpr ptr default_uninitialized_value_construct(ptr first, ptr last) { + ptr curr = first; + try { + for (; !(curr == last); ++curr) { + construct(curr); + } + return curr; + } catch (...) { + destroy_range(first, curr); + throw; + } + } + + constexpr ptr uninitialized_fill(ptr first, ptr last) { + return uninitialized_value_construct(first, last); + } + + constexpr ptr uninitialized_fill(ptr first, ptr last, const value_ty& val) { + ptr curr = first; + try { + for (; !(curr == last); ++curr) { + construct(curr, val); + } + return curr; + } catch (...) { + destroy_range(first, curr); + throw; + } + } + + private: + // If value_ty is an array, replicate C++20 behavior (I don't think that + // value_ty can actually be an array because of the Erasable requirement, but + // there shouldn't be any runtime cost for being defensive here). + template + requires std::is_array_v + static constexpr void destroy_at(value_ty* p) noexcept { + for (auto& e : *p) { + destroy_at(std::addressof(e)); + } + } + + template + requires(!std::is_array_v) + static constexpr void destroy_at(value_ty* p) noexcept { + p->~value_ty(); + } + + template + static constexpr auto construct_at(value_ty* p, Args&&... args) noexcept( + noexcept(::new(std::declval()) V(std::declval()...))) + -> decltype(::new(std::declval()) V(std::declval()...)) { + if (std::is_constant_evaluated()) { + return std::construct_at(p, std::forward(args)...); + } + void* vp = const_cast(static_cast(p)); + return ::new (vp) value_ty(std::forward(args)...); + } +}; + +template +class small_vector_data_base { + public: + using ptr = Pointer; + using size_ty = SizeT; + + small_vector_data_base() = default; + small_vector_data_base(const small_vector_data_base&) = default; + small_vector_data_base(small_vector_data_base&&) noexcept = default; + small_vector_data_base& operator=(const small_vector_data_base&) = default; + small_vector_data_base& operator=(small_vector_data_base&&) noexcept = + default; + ~small_vector_data_base() = default; + + constexpr ptr data_ptr() const noexcept { return m_data_ptr; } + constexpr size_ty capacity() const noexcept { return m_capacity; } + constexpr size_ty size() const noexcept { return m_size; } + + constexpr void set_data_ptr(ptr data_ptr) noexcept { m_data_ptr = data_ptr; } + constexpr void set_capacity(size_ty capacity) noexcept { + m_capacity = capacity; + } + constexpr void set_size(size_ty size) noexcept { m_size = size; } + + constexpr void set(ptr data_ptr, size_ty capacity, size_ty size) { + m_data_ptr = data_ptr; + m_capacity = capacity; + m_size = size; + } + + constexpr void swap_data_ptr(small_vector_data_base& other) noexcept { + using std::swap; + swap(m_data_ptr, other.m_data_ptr); + } + + constexpr void swap_capacity(small_vector_data_base& other) noexcept { + using std::swap; + swap(m_capacity, other.m_capacity); + } + + constexpr void swap_size(small_vector_data_base& other) noexcept { + using std::swap; + swap(m_size, other.m_size); + } + + constexpr void swap(small_vector_data_base& other) noexcept { + using std::swap; + swap(m_data_ptr, other.m_data_ptr); + swap(m_capacity, other.m_capacity); + swap(m_size, other.m_size); + } + + private: + ptr m_data_ptr; + size_ty m_capacity; + size_ty m_size; +}; + +template +class small_vector_data : public small_vector_data_base { + public: + using value_ty = T; + + small_vector_data() = default; + small_vector_data(const small_vector_data&) = delete; + small_vector_data(small_vector_data&&) noexcept = delete; + small_vector_data& operator=(const small_vector_data&) = delete; + small_vector_data& operator=(small_vector_data&&) noexcept = delete; + ~small_vector_data() = default; + + constexpr value_ty* storage() noexcept { return m_storage.get_inline_ptr(); } + + constexpr const value_ty* storage() const noexcept { + return m_storage.get_inline_ptr(); + } + + private: + inline_storage m_storage; +}; + +template +class small_vector_data + : public small_vector_data_base, + private inline_storage { + using base = inline_storage; + + public: + using value_ty = T; + + small_vector_data() = default; + small_vector_data(const small_vector_data&) = delete; + small_vector_data(small_vector_data&&) noexcept = delete; + small_vector_data& operator=(const small_vector_data&) = delete; + small_vector_data& operator=(small_vector_data&&) noexcept = delete; + ~small_vector_data() = default; + + constexpr value_ty* storage() noexcept { return base::get_inline_ptr(); } + + constexpr const value_ty* storage() const noexcept { + return base::get_inline_ptr(); + } +}; + +template +class small_vector_base : public allocator_interface { + public: + using size_type = typename allocator_interface::size_type; + using difference_type = + typename allocator_interface::difference_type; + + template + friend class small_vector_base; + + protected: + using alloc_interface = allocator_interface; + using alloc_traits = typename alloc_interface::alloc_traits; + using alloc_ty = Allocator; + + using value_ty = typename alloc_interface::value_ty; + using ptr = typename alloc_interface::ptr; + using cptr = typename alloc_interface::cptr; + using size_ty = typename alloc_interface::size_ty; + using diff_ty = typename alloc_interface::diff_ty; + + static_assert( + alloc_interface::template is_complete_v || InlineCapacity == 0, + "`value_type` must be complete for instantiation of a non-zero number " + "of inline elements."); + + template + inline static constexpr bool is_complete_v = + alloc_interface::template is_complete_v; + + using alloc_interface::allocator_ref; + using alloc_interface::construct; + using alloc_interface::deallocate; + using alloc_interface::destroy; + using alloc_interface::destroy_range; + using alloc_interface::external_range_length; + using alloc_interface::get_max_size; + using alloc_interface::internal_range_length; + using alloc_interface::to_address; + using alloc_interface::unchecked_advance; + using alloc_interface::unchecked_next; + using alloc_interface::unchecked_prev; + using alloc_interface::uninitialized_copy; + using alloc_interface::uninitialized_fill; + using alloc_interface::uninitialized_value_construct; + + template + [[nodiscard]] + static consteval size_t numeric_max() noexcept { + return alloc_interface::template numeric_max(); + } + + [[nodiscard]] + static consteval size_ty get_inline_capacity() noexcept { + return static_cast(InlineCapacity); + } + + template + inline static constexpr bool is_emplace_constructible_v = + is_complete_v && requires { + std::declval().construct(std::declval(), + std::declval()...); + }; + + template + inline static constexpr bool is_nothrow_emplace_constructible_v = + is_complete_v && requires { + noexcept(std::declval().construct( + std::declval(), std::declval()...)); + }; + + template + inline static constexpr bool is_explicitly_move_insertable_v = + is_emplace_constructible_v; + + template + inline static constexpr bool is_explicitly_nothrow_move_insertable_v = + is_nothrow_emplace_constructible_v; + + template + inline static constexpr bool is_explicitly_copy_insertable_v = + is_emplace_constructible_v && is_emplace_constructible_v; + + template + inline static constexpr bool is_explicitly_nothrow_copy_insertable_v = + is_nothrow_emplace_constructible_v && + is_nothrow_emplace_constructible_v; + + template + inline static constexpr bool relocate_with_move_v = + std::is_nothrow_move_constructible_v || + !is_explicitly_copy_insertable_v; + + template + inline static constexpr bool allocations_are_movable_v = + std::is_same_v, A> || + std::allocator_traits::propagate_on_container_move_assignment::value || + std::allocator_traits::is_always_equal::value; + + template + inline static constexpr bool allocations_are_swappable_v = + std::is_same_v, A> || + std::allocator_traits::propagate_on_container_swap::value || + std::allocator_traits::is_always_equal::value; + + template + inline static constexpr bool is_memcpyable_v = + alloc_interface::template is_memcpyable_v; + + template + inline static constexpr bool is_memcpyable_iterator_v = + alloc_interface::template is_memcpyable_iterator_v; + + [[noreturn]] + static constexpr void throw_overflow_error() { + throw std::overflow_error("The requested conversion would overflow."); + } + + [[noreturn]] + static constexpr void throw_index_error() { + throw std::out_of_range("The requested index was out of range."); + } + + [[noreturn]] + static constexpr void throw_increment_error() { + throw std::domain_error( + "The requested increment was outside of the allowed range."); + } + + [[noreturn]] + static constexpr void throw_allocation_size_error() { + throw std::length_error( + "The required allocation exceeds the maximum size."); + } + + [[nodiscard]] + constexpr ptr ptr_cast( + const small_vector_iterator& it) noexcept { + return unchecked_next(begin_ptr(), it.base() - begin_ptr()); + } + + private: + class stack_temporary { + public: + stack_temporary() = delete; + stack_temporary(const stack_temporary&) = delete; + stack_temporary(stack_temporary&&) noexcept = delete; + stack_temporary& operator=(const stack_temporary&) = delete; + stack_temporary& operator=(stack_temporary&&) noexcept = delete; + + template + constexpr explicit stack_temporary(alloc_interface& alloc_iface, + Args&&... args) + : m_interface(alloc_iface) { + m_interface.construct(get_pointer(), std::forward(args)...); + } + + constexpr ~stack_temporary() { m_interface.destroy(get_pointer()); } + + [[nodiscard]] + constexpr const value_ty& get() const noexcept { + return *get_pointer(); + } + + [[nodiscard]] + constexpr value_ty&& release() noexcept { + return std::move(*get_pointer()); + } + + private: + [[nodiscard]] + constexpr cptr get_pointer() const noexcept { + return static_cast( + static_cast(std::addressof(m_data))); + } + + [[nodiscard]] + constexpr ptr get_pointer() noexcept { + return static_cast(static_cast(std::addressof(m_data))); + } + + alloc_interface& m_interface; + alignas(value_ty) std::byte m_data[sizeof(value_ty)]; + }; + + class heap_temporary { + public: + heap_temporary() = delete; + heap_temporary(const heap_temporary&) = delete; + heap_temporary(heap_temporary&&) noexcept = delete; + heap_temporary& operator=(const heap_temporary&) = delete; + heap_temporary& operator=(heap_temporary&&) noexcept = delete; + + template + constexpr explicit heap_temporary(alloc_interface& alloc_iface, + Args&&... args) + : m_interface(alloc_iface), + m_data_ptr(alloc_iface.allocate(sizeof(value_ty))) { + try { + m_interface.construct(m_data_ptr, std::forward(args)...); + } catch (...) { + m_interface.deallocate(m_data_ptr, sizeof(value_ty)); + throw; + } + } + + constexpr ~heap_temporary() { + m_interface.destroy(m_data_ptr); + m_interface.deallocate(m_data_ptr, sizeof(value_ty)); + } + + [[nodiscard]] + constexpr const value_ty& get() const noexcept { + return *m_data_ptr; + } + + [[nodiscard]] + constexpr value_ty&& release() noexcept { + return std::move(*m_data_ptr); + } + + private: + alloc_interface& m_interface; + ptr m_data_ptr; + }; + + constexpr void wipe() { + destroy_range(begin_ptr(), end_ptr()); + if (has_allocation()) { + deallocate(data_ptr(), get_capacity()); + } + } + + constexpr void set_data_ptr(ptr data_ptr) noexcept { + m_data.set_data_ptr(data_ptr); + } + + constexpr void set_capacity(size_ty capacity) noexcept { + m_data.set_capacity(static_cast(capacity)); + } + + constexpr void set_size(size_ty size) noexcept { + m_data.set_size(static_cast(size)); + } + + constexpr void set_data(ptr data_ptr, size_ty capacity, + size_ty size) noexcept { + m_data.set(data_ptr, static_cast(capacity), + static_cast(size)); + } + + constexpr void swap_data_ptr(small_vector_base& other) noexcept { + m_data.swap_data_ptr(other.m_data); + } + + constexpr void swap_capacity(small_vector_base& other) noexcept { + m_data.swap_capacity(other.m_data); + } + + constexpr void swap_size(small_vector_base& other) noexcept { + m_data.swap_size(other.m_data); + } + + constexpr void swap_allocation(small_vector_base& other) noexcept { + m_data.swap(other.m_data); + } + + constexpr void reset_data(ptr data_ptr, size_ty capacity, size_ty size) { + wipe(); + m_data.set(data_ptr, static_cast(capacity), + static_cast(size)); + } + + constexpr void increase_size(size_ty n) noexcept { + m_data.set_size(get_size() + n); + } + + constexpr void decrease_size(size_ty n) noexcept { + m_data.set_size(get_size() - n); + } + + constexpr ptr unchecked_allocate(size_ty n) { + assert(InlineCapacity < n && + "Allocated capacity should be greater than InlineCapacity."); + return alloc_interface::allocate(n); + } + + constexpr ptr unchecked_allocate(size_ty n, cptr hint) { + assert(InlineCapacity < n && + "Allocated capacity should be greater than InlineCapacity."); + return alloc_interface::allocate_with_hint(n, hint); + } + + constexpr ptr checked_allocate(size_ty n) { + if (get_max_size() < n) { + throw_allocation_size_error(); + } + return unchecked_allocate(n); + } + + protected: + [[nodiscard]] + constexpr size_ty unchecked_calculate_new_capacity( + const size_ty minimum_required_capacity) const noexcept { + const size_ty current_capacity = get_capacity(); + + assert(current_capacity < minimum_required_capacity); + + if (get_max_size() - current_capacity <= current_capacity) { + return get_max_size(); + } + + // Note: This growth factor might be theoretically superior, but in testing + // it falls flat: size_ty new_capacity = current_capacity + + // (current_capacity / 2); + + const size_ty new_capacity = 2 * current_capacity; + if (new_capacity < minimum_required_capacity) { + return minimum_required_capacity; + } + return new_capacity; + } + + [[nodiscard]] + constexpr size_ty checked_calculate_new_capacity( + const size_ty minimum_required_capacity) const { + if (get_max_size() < minimum_required_capacity) { + throw_allocation_size_error(); + } + return unchecked_calculate_new_capacity(minimum_required_capacity); + } + + template + constexpr small_vector_base& copy_assign_default( + const small_vector_base& other) { + if (get_capacity() < other.get_size()) { + // Reallocate. + size_ty new_capacity = unchecked_calculate_new_capacity(other.get_size()); + ptr new_data_ptr = + unchecked_allocate(new_capacity, other.allocation_end_ptr()); + + try { + uninitialized_copy(other.begin_ptr(), other.end_ptr(), new_data_ptr); + } catch (...) { + deallocate(new_data_ptr, new_capacity); + throw; + } + + reset_data(new_data_ptr, new_capacity, other.get_size()); + } else { + if (get_size() < other.get_size()) { + // No reallocation, partially in uninitialized space. + std::copy_n(other.begin_ptr(), get_size(), begin_ptr()); + uninitialized_copy(unchecked_next(other.begin_ptr(), get_size()), + other.end_ptr(), end_ptr()); + } else { + destroy_range( + copy_range(other.begin_ptr(), other.end_ptr(), begin_ptr()), + end_ptr()); + } + + // data_ptr and capacity do not change in this case. + set_size(other.get_size()); + } + + alloc_interface::operator=(other); + return *this; + } + + template + requires(AT::propagate_on_container_copy_assignment::value && + !AT::is_always_equal::value) + constexpr small_vector_base& copy_assign( + const small_vector_base& other) { + if (other.allocator_ref() == allocator_ref()) { + return copy_assign_default(other); + } + + if (InlineCapacity < other.get_size()) { + alloc_interface new_alloc(other); + + const size_ty new_capacity = other.get_size(); + const ptr new_data_ptr = new_alloc.allocate_with_hint( + new_capacity, other.allocation_end_ptr()); + + try { + uninitialized_copy(other.begin_ptr(), other.end_ptr(), new_data_ptr); + } catch (...) { + new_alloc.deallocate(new_data_ptr, new_capacity); + throw; + } + + reset_data(new_data_ptr, new_capacity, other.get_size()); + alloc_interface::operator=(new_alloc); + } else { + if (has_allocation()) { + ptr new_data_ptr; + if (std::is_constant_evaluated()) { + alloc_interface new_alloc(other); + new_data_ptr = new_alloc.allocate(InlineCapacity); + } else { + new_data_ptr = storage_ptr(); + } + + uninitialized_copy(other.begin_ptr(), other.end_ptr(), new_data_ptr); + destroy_range(begin_ptr(), end_ptr()); + deallocate(data_ptr(), get_capacity()); + set_data_ptr(new_data_ptr); + set_capacity(InlineCapacity); + } else if (get_size() < other.get_size()) { + std::copy_n(other.begin_ptr(), get_size(), begin_ptr()); + uninitialized_copy(unchecked_next(other.begin_ptr(), get_size()), + other.end_ptr(), end_ptr()); + } else { + destroy_range( + copy_range(other.begin_ptr(), other.end_ptr(), begin_ptr()), + end_ptr()); + } + set_size(other.get_size()); + alloc_interface::operator=(other); + } + + return *this; + } + + template + requires(!AT::propagate_on_container_copy_assignment::value || + AT::is_always_equal::value) + constexpr small_vector_base& copy_assign( + const small_vector_base& other) { + return copy_assign_default(other); + } + + template + constexpr void move_allocation_pointer( + small_vector_base&& other) noexcept { + reset_data(other.data_ptr(), other.get_capacity(), other.get_size()); + other.set_default(); + } + + template + requires(N == 0) + constexpr small_vector_base& move_assign_default( + small_vector_base&& other) noexcept { + move_allocation_pointer(std::move(other)); + alloc_interface::operator=(std::move(other)); + return *this; + } + + template + requires(LessEqualI <= InlineCapacity) + constexpr small_vector_base& move_assign_default( + small_vector_base&& + other) noexcept(std::is_nothrow_move_assignable_v && + std::is_nothrow_move_constructible_v) { + // We only move the allocation pointer over if it has strictly greater + // capacity than the inline capacity of `*this` because allocations can + // never have a smaller capacity than the inline capacity. + if (InlineCapacity < other.get_capacity()) { + move_allocation_pointer(std::move(other)); + } else { + // We are guaranteed to have sufficient capacity to store the elements. + if (InlineCapacity < get_capacity()) { + ptr new_data_ptr; + if (std::is_constant_evaluated()) { + new_data_ptr = other.allocate(InlineCapacity); + } else { + new_data_ptr = storage_ptr(); + } + + uninitialized_move(other.begin_ptr(), other.end_ptr(), new_data_ptr); + destroy_range(begin_ptr(), end_ptr()); + deallocate(data_ptr(), get_capacity()); + set_data_ptr(new_data_ptr); + set_capacity(InlineCapacity); + } else if (get_size() < other.get_size()) { + // There are more elements in `other`. + // Overwrite the existing range and uninitialized move the rest. + ptr other_pivot = unchecked_next(other.begin_ptr(), get_size()); + std::move(other.begin_ptr(), other_pivot, begin_ptr()); + uninitialized_move(other_pivot, other.end_ptr(), end_ptr()); + } else { + // There are the same number or fewer elements in `other`. + // Overwrite part of the existing range and destroy the rest. + ptr new_end = + std::move(other.begin_ptr(), other.end_ptr(), begin_ptr()); + destroy_range(new_end, end_ptr()); + } + + set_size(other.get_size()); + + // Note: We do not need to deallocate any allocations in `other` because + // the value of + // an object meeting the Allocator named requirements does not + // change value after a move. + } + + alloc_interface::operator=(std::move(other)); + return *this; + } + + template + requires(InlineCapacity < GreaterI) + constexpr small_vector_base& move_assign_default( + small_vector_base&& other) { + if (other.has_allocation()) { + move_allocation_pointer(std::move(other)); + } else if (get_capacity() < other.get_size() || + (has_allocation() && + !(other.allocator_ref() == allocator_ref()))) { + // Reallocate. + + // The compiler should be able to optimize this. + size_ty new_capacity = + get_capacity() < other.get_size() + ? unchecked_calculate_new_capacity(other.get_size()) + : get_capacity(); + + ptr new_data_ptr = + other.allocate_with_hint(new_capacity, other.allocation_end_ptr()); + + try { + uninitialized_move(other.begin_ptr(), other.end_ptr(), new_data_ptr); + } catch (...) { + other.deallocate(new_data_ptr, new_capacity); + throw; + } + + reset_data(new_data_ptr, new_capacity, other.get_size()); + } else { + if (get_size() < other.get_size()) { + // There are more elements in `other`. + // Overwrite the existing range and uninitialized move the rest. + ptr other_pivot = unchecked_next(other.begin_ptr(), get_size()); + std::move(other.begin_ptr(), other_pivot, begin_ptr()); + uninitialized_move(other_pivot, other.end_ptr(), end_ptr()); + } else { + // fewer elements in other + // overwrite part of the existing range and destroy the rest + ptr new_end = + std::move(other.begin_ptr(), other.end_ptr(), begin_ptr()); + destroy_range(new_end, end_ptr()); + } + + // `data_ptr` and `capacity` do not change in this case. + set_size(other.get_size()); + } + + alloc_interface::operator=(std::move(other)); + return *this; + } + + template + constexpr small_vector_base& move_assign_unequal_no_propagate( + small_vector_base&& other) { + if (get_capacity() < other.get_size()) { + // Reallocate. + size_ty new_capacity = unchecked_calculate_new_capacity(other.get_size()); + ptr new_data_ptr = + unchecked_allocate(new_capacity, other.allocation_end_ptr()); + + try { + uninitialized_move(other.begin_ptr(), other.end_ptr(), new_data_ptr); + } catch (...) { + deallocate(new_data_ptr, new_capacity); + throw; + } + + reset_data(new_data_ptr, new_capacity, other.get_size()); + } else { + if (get_size() < other.get_size()) { + // There are more elements in `other`. + // Overwrite the existing range and uninitialized move the rest. + ptr other_pivot = unchecked_next(other.begin_ptr(), get_size()); + std::move(other.begin_ptr(), other_pivot, begin_ptr()); + uninitialized_move(other_pivot, other.end_ptr(), end_ptr()); + } else { + // There are fewer elements in `other`. + // Overwrite part of the existing range and destroy the rest. + destroy_range( + std::move(other.begin_ptr(), other.end_ptr(), begin_ptr()), + end_ptr()); + } + + // data_ptr and capacity do not change in this case + set_size(other.get_size()); + } + + alloc_interface::operator=(std::move(other)); + return *this; + } + + template + requires allocations_are_movable_v + constexpr small_vector_base& + move_assign(small_vector_base&& other) noexcept( + noexcept(std::declval().move_assign_default( + std::move(other)))) { + return move_assign_default(std::move(other)); + } + + template + requires(!allocations_are_movable_v) + constexpr small_vector_base& move_assign( + small_vector_base&& other) { + if (other.allocator_ref() == allocator_ref()) { + return move_assign_default(std::move(other)); + } + return move_assign_unequal_no_propagate(std::move(other)); + } + + template + requires(I == 0) + constexpr void move_initialize(small_vector_base&& other) noexcept { + set_data(other.data_ptr(), other.get_capacity(), other.get_size()); + other.set_default(); + } + + template + requires(LessEqualI <= InlineCapacity) + constexpr void + move_initialize(small_vector_base&& other) noexcept( + std::is_nothrow_move_constructible_v) { + if (InlineCapacity < other.get_capacity()) { + set_data(other.data_ptr(), other.get_capacity(), other.get_size()); + other.set_default(); + } else { + set_to_inline_storage(); + uninitialized_move(other.begin_ptr(), other.end_ptr(), data_ptr()); + set_size(other.get_size()); + } + } + + template + requires(InlineCapacity < GreaterI) + constexpr void move_initialize( + small_vector_base&& other) { + if (other.has_allocation()) { + set_data(other.data_ptr(), other.get_capacity(), other.get_size()); + other.set_default(); + } else { + if (InlineCapacity < other.get_size()) { + // We may throw in this case. + set_data_ptr( + unchecked_allocate(other.get_size(), other.allocation_end_ptr())); + set_capacity(other.get_size()); + + try { + uninitialized_move(other.begin_ptr(), other.end_ptr(), data_ptr()); + } catch (...) { + deallocate(data_ptr(), get_capacity()); + throw; + } + } else { + set_to_inline_storage(); + uninitialized_move(other.begin_ptr(), other.end_ptr(), data_ptr()); + } + + set_size(other.get_size()); + } + } + + public: + small_vector_base(const small_vector_base&) = delete; + small_vector_base(small_vector_base&&) noexcept = delete; + small_vector_base& operator=(const small_vector_base&) = delete; + small_vector_base& operator=(small_vector_base&&) noexcept = delete; + + constexpr small_vector_base() noexcept { set_default(); } + + static constexpr struct bypass_tag { + } bypass{}; + + template + constexpr small_vector_base(bypass_tag, + const small_vector_base& other, + const MaybeAlloc&... alloc) + : alloc_interface(other, alloc...) { + if (InlineCapacity < other.get_size()) { + set_data_ptr( + unchecked_allocate(other.get_size(), other.allocation_end_ptr())); + set_capacity(other.get_size()); + + try { + uninitialized_copy(other.begin_ptr(), other.end_ptr(), data_ptr()); + } catch (...) { + deallocate(data_ptr(), get_capacity()); + throw; + } + } else { + set_to_inline_storage(); + uninitialized_copy(other.begin_ptr(), other.end_ptr(), data_ptr()); + } + + set_size(other.get_size()); + } + + template + constexpr small_vector_base( + bypass_tag, + small_vector_base&& + other) noexcept(std::is_nothrow_move_constructible_v || + (I == 0 && I == InlineCapacity)) + : alloc_interface(std::move(other)) { + move_initialize(std::move(other)); + } + + template + requires std::same_as> || + std::allocator_traits::is_always_equal::value + constexpr small_vector_base( + bypass_tag, small_vector_base&& other, + const alloc_ty&) noexcept(noexcept(small_vector_base(bypass, + std::move(other)))) + : small_vector_base(bypass, std::move(other)) {} + + template + requires(!(std::same_as> || + std::allocator_traits::is_always_equal::value)) + constexpr small_vector_base(bypass_tag, + small_vector_base&& other, + const alloc_ty& alloc) + : alloc_interface(alloc) { + if (other.allocator_ref() == alloc) { + move_initialize(std::move(other)); + return; + } + + if (InlineCapacity < other.get_size()) { + // We may throw in this case. + set_data_ptr( + unchecked_allocate(other.get_size(), other.allocation_end_ptr())); + set_capacity(other.get_size()); + + try { + uninitialized_move(other.begin_ptr(), other.end_ptr(), data_ptr()); + } catch (...) { + deallocate(data_ptr(), get_capacity()); + throw; + } + } else { + set_to_inline_storage(); + uninitialized_move(other.begin_ptr(), other.end_ptr(), data_ptr()); + } + + set_size(other.get_size()); + } + + constexpr explicit small_vector_base(const alloc_ty& alloc) noexcept + : alloc_interface(alloc) { + set_default(); + } + + constexpr small_vector_base(size_ty count, const alloc_ty& alloc) + : alloc_interface(alloc) { + if (InlineCapacity < count) { + set_data_ptr(checked_allocate(count)); + set_capacity(count); + } else { + set_to_inline_storage(); + } + + try { + uninitialized_value_construct(begin_ptr(), + unchecked_next(begin_ptr(), count)); + } catch (...) { + if (has_allocation()) { + deallocate(data_ptr(), get_capacity()); + } + throw; + } + set_size(count); + } + + constexpr small_vector_base(size_ty count, const value_ty& val, + const alloc_ty& alloc) + : alloc_interface(alloc) { + if (InlineCapacity < count) { + set_data_ptr(checked_allocate(count)); + set_capacity(count); + } else { + set_to_inline_storage(); + } + + try { + uninitialized_fill(begin_ptr(), unchecked_next(begin_ptr(), count), val); + } catch (...) { + if (has_allocation()) { + deallocate(data_ptr(), get_capacity()); + } + throw; + } + set_size(count); + } + + template + constexpr small_vector_base(size_ty count, Generator& g, + const alloc_ty& alloc) + : alloc_interface(alloc) { + if (InlineCapacity < count) { + set_data_ptr(checked_allocate(count)); + set_capacity(count); + } else { + set_to_inline_storage(); + } + + ptr curr = begin_ptr(); + const ptr new_end = unchecked_next(begin_ptr(), count); + try { + for (; !(curr == new_end); ++curr) { + construct(curr, g()); + } + } catch (...) { + destroy_range(begin_ptr(), curr); + if (has_allocation()) { + deallocate(data_ptr(), get_capacity()); + } + throw; + } + set_size(count); + } + + template + constexpr small_vector_base(InputIt first, InputIt last, + std::input_iterator_tag, const alloc_ty& alloc) + : small_vector_base(alloc) { + using iterator_cat = + typename std::iterator_traits::iterator_category; + append_range(first, last, iterator_cat{}); + } + + template + constexpr small_vector_base(ForwardIt first, ForwardIt last, + std::forward_iterator_tag, const alloc_ty& alloc) + : alloc_interface(alloc) { + size_ty count = external_range_length(first, last); + if (InlineCapacity < count) { + set_data_ptr(unchecked_allocate(count)); + set_capacity(count); + try { + uninitialized_copy(first, last, begin_ptr()); + } catch (...) { + deallocate(data_ptr(), get_capacity()); + throw; + } + } else { + set_to_inline_storage(); + uninitialized_copy(first, last, begin_ptr()); + } + + set_size(count); + } + + constexpr ~small_vector_base() noexcept { + assert(InlineCapacity <= get_capacity() && "Invalid capacity."); + wipe(); + } + + protected: + constexpr void set_to_inline_storage() { + set_capacity(InlineCapacity); + if (std::is_constant_evaluated()) { + return set_data_ptr(alloc_interface::allocate(InlineCapacity)); + } + set_data_ptr(storage_ptr()); + } + + constexpr void assign_with_copies(size_ty count, const value_ty& val) { + if (get_capacity() < count) { + size_ty new_capacity = checked_calculate_new_capacity(count); + ptr new_begin = unchecked_allocate(new_capacity); + + try { + uninitialized_fill(new_begin, unchecked_next(new_begin, count), val); + } catch (...) { + deallocate(new_begin, new_capacity); + throw; + } + + reset_data(new_begin, new_capacity, count); + } else if (get_size() < count) { + std::fill(begin_ptr(), end_ptr(), val); + uninitialized_fill(end_ptr(), unchecked_next(begin_ptr(), count), val); + set_size(count); + } else { + erase_range(std::fill_n(begin_ptr(), count, val), end_ptr()); + } + } + + template + requires std::is_assignable_v())> + constexpr void assign_with_range(InputIt first, InputIt last, + std::input_iterator_tag) { + using iterator_cat = + typename std::iterator_traits::iterator_category; + + ptr curr = begin_ptr(); + for (; !(end_ptr() == curr || first == last); + ++curr, static_cast(++first)) { + *curr = *first; + } + + if (first == last) { + erase_to_end(curr); + } else { + append_range(first, last, iterator_cat{}); + } + } + + template + requires std::is_assignable_v())> + constexpr void assign_with_range(ForwardIt first, ForwardIt last, + std::forward_iterator_tag) { + const size_ty count = external_range_length(first, last); + if (get_capacity() < count) { + size_ty new_capacity = checked_calculate_new_capacity(count); + ptr new_begin = unchecked_allocate(new_capacity); + + try { + uninitialized_copy(first, last, new_begin); + } catch (...) { + deallocate(new_begin, new_capacity); + throw; + } + + reset_data(new_begin, new_capacity, count); + } else if (get_size() < count) { + ForwardIt pivot = copy_n_return_in(first, get_size(), begin_ptr()); + uninitialized_copy(pivot, last, end_ptr()); + set_size(count); + } else { + erase_range(copy_range(first, last, begin_ptr()), end_ptr()); + } + } + + template + requires( + !std::is_assignable_v())>) + constexpr void assign_with_range(InputIt first, InputIt last, + std::input_iterator_tag) { + using iterator_cat = + typename std::iterator_traits::iterator_category; + + // If not assignable then destroy all elements and append. + erase_all(); + append_range(first, last, iterator_cat{}); + } + + // Ie. move-if-noexcept. + struct strong_exception_policy {}; + + template + requires is_explicitly_move_insertable_v && + (!std::same_as || + relocate_with_move_v) + constexpr ptr uninitialized_move(ptr first, ptr last, ptr d_first) noexcept( + std::is_nothrow_move_constructible_v) { + return uninitialized_copy(std::make_move_iterator(first), + std::make_move_iterator(last), d_first); + } + + template + requires(!is_explicitly_move_insertable_v || + (std::same_as && + !relocate_with_move_v)) + constexpr ptr uninitialized_move(ptr first, ptr last, ptr d_first) noexcept( + alloc_interface::template is_uninitialized_memcpyable_iterator_v) { + return uninitialized_copy(first, last, d_first); + } + + constexpr ptr shift_into_uninitialized(ptr pos, size_ty n_shift) { + // Shift elements over to the right into uninitialized space. + // Returns the start of the shifted range. + // Precondition: shift < end_ptr () - pos + assert(n_shift != 0 && "The value of `n_shift` should not be 0."); + + const ptr original_end = end_ptr(); + const ptr pivot = unchecked_prev(original_end, n_shift); + + uninitialized_move(pivot, original_end, original_end); + increase_size(n_shift); + return move_right(pos, pivot, original_end); + } + + template + constexpr ptr append_element(Args&&... args) { + if (get_size() < get_capacity()) { + return emplace_into_current_end(std::forward(args)...); + } + return emplace_into_reallocation_end(std::forward(args)...); + } + + constexpr ptr append_copies(size_ty count, const value_ty& val) { + if (num_uninitialized() < count) { + // Reallocate. + if (get_max_size() - get_size() < count) { + throw_allocation_size_error(); + } + + size_ty original_size = get_size(); + size_ty new_size = get_size() + count; + + // The check is handled by the if-guard. + size_ty new_capacity = unchecked_calculate_new_capacity(new_size); + ptr new_data_ptr = unchecked_allocate(new_capacity, allocation_end_ptr()); + ptr new_last = unchecked_next(new_data_ptr, original_size); + + try { + new_last = + uninitialized_fill(new_last, unchecked_next(new_last, count), val); + uninitialized_move(begin_ptr(), end_ptr(), new_data_ptr); + } catch (...) { + destroy_range(unchecked_next(new_data_ptr, original_size), new_last); + deallocate(new_data_ptr, new_capacity); + throw; + } + + reset_data(new_data_ptr, new_capacity, new_size); + return unchecked_next(new_data_ptr, original_size); + } else { + const ptr ret = end_ptr(); + uninitialized_fill(ret, unchecked_next(ret, count), val); + increase_size(count); + return ret; + } + } + + template MovePolicy, typename InputIt> + constexpr ptr append_range(InputIt first, InputIt last, + std::input_iterator_tag) { + // Append with a strong exception guarantee. + size_ty original_size = get_size(); + for (; !(first == last); ++first) { + try { + append_element(*first); + } catch (...) { + erase_range(unchecked_next(begin_ptr(), original_size), end_ptr()); + throw; + } + } + return unchecked_next(begin_ptr(), original_size); + } + + template + requires(!std::same_as) + constexpr ptr append_range(InputIt first, InputIt last, + std::input_iterator_tag) { + size_ty original_size = get_size(); + for (; !(first == last); ++first) { + append_element(*first); + } + return unchecked_next(begin_ptr(), original_size); + } + + template + constexpr ptr append_range(ForwardIt first, ForwardIt last, + std::forward_iterator_tag) { + const size_ty num_insert = external_range_length(first, last); + + if (num_uninitialized() < num_insert) { + // Reallocate. + if (get_max_size() - get_size() < num_insert) { + throw_allocation_size_error(); + } + + size_ty original_size = get_size(); + size_ty new_size = get_size() + num_insert; + + // The check is handled by the if-guard. + size_ty new_capacity = unchecked_calculate_new_capacity(new_size); + ptr new_data_ptr = unchecked_allocate(new_capacity, allocation_end_ptr()); + ptr new_last = unchecked_next(new_data_ptr, original_size); + + try { + new_last = uninitialized_copy(first, last, new_last); + uninitialized_move(begin_ptr(), end_ptr(), new_data_ptr); + } catch (...) { + destroy_range(unchecked_next(new_data_ptr, original_size), new_last); + deallocate(new_data_ptr, new_capacity); + throw; + } + + reset_data(new_data_ptr, new_capacity, new_size); + return unchecked_next(new_data_ptr, original_size); + } else { + ptr ret = end_ptr(); + uninitialized_copy(first, last, ret); + increase_size(num_insert); + return ret; + } + } + + template + constexpr ptr emplace_at(ptr pos, Args&&... args) { + assert(get_size() <= get_capacity() && "size was greater than capacity"); + + if (get_size() < get_capacity()) { + return emplace_into_current(pos, std::forward(args)...); + } + return emplace_into_reallocation(pos, std::forward(args)...); + } + + constexpr ptr insert_copies(ptr pos, size_ty count, const value_ty& val) { + if (0 == count) { + return pos; + } + + if (end_ptr() == pos) { + if (1 == count) { + return append_element(val); + } + return append_copies(count, val); + } + + if (num_uninitialized() < count) { + // Reallocate. + if (get_max_size() - get_size() < count) { + throw_allocation_size_error(); + } + + const size_ty offset = internal_range_length(begin_ptr(), pos); + + const size_ty new_size = get_size() + count; + + // The check is handled by the if-guard. + const size_ty new_capacity = unchecked_calculate_new_capacity(new_size); + ptr new_data_ptr = unchecked_allocate(new_capacity, allocation_end_ptr()); + ptr new_first = unchecked_next(new_data_ptr, offset); + ptr new_last = new_first; + + try { + uninitialized_fill(new_first, unchecked_next(new_first, count), val); + unchecked_advance(new_last, count); + + uninitialized_move(begin_ptr(), pos, new_data_ptr); + new_first = new_data_ptr; + uninitialized_move(pos, end_ptr(), new_last); + } catch (...) { + destroy_range(new_first, new_last); + deallocate(new_data_ptr, new_capacity); + throw; + } + + reset_data(new_data_ptr, new_capacity, new_size); + return unchecked_next(begin_ptr(), offset); + } else { + // If we have fewer to insert than tailing elements after `pos`, we shift + // into uninitialized and then copy over. + + const size_ty tail_size = internal_range_length(pos, end_ptr()); + if (tail_size < count) { + // The number inserted is larger than the number after `pos`, + // so part of the input will be used to construct new elements, + // and another part of it will assign existing ones. + // In order: + // Construct new elements immediately after end_ptr () using the + // input. Move-construct existing elements over to the tail. Assign + // existing elements using the input. + + ptr original_end = end_ptr(); + + // Place a portion of the input into the uninitialized section. + size_ty num_val_tail = count - tail_size; + + if (std::is_constant_evaluated()) { + uninitialized_fill(end_ptr(), unchecked_next(end_ptr(), num_val_tail), + val); + increase_size(num_val_tail); + + const heap_temporary tmp(*this, val); + + uninitialized_move(pos, original_end, end_ptr()); + increase_size(tail_size); + + std::fill_n(pos, tail_size, tmp.get()); + + return pos; + } + + uninitialized_fill(end_ptr(), unchecked_next(end_ptr(), num_val_tail), + val); + increase_size(num_val_tail); + + try { + // We need to handle possible aliasing here. + const stack_temporary tmp(*this, val); + + // Now, move the tail to the end. + uninitialized_move(pos, original_end, end_ptr()); + increase_size(tail_size); + + try { + // Finally, try to copy the rest of the elements over. + std::fill_n(pos, tail_size, tmp.get()); + } catch (...) { + // Attempt to roll back and destroy the tail if we fail. + ptr inserted_end = unchecked_prev(end_ptr(), tail_size); + move_left(inserted_end, end_ptr(), pos); + destroy_range(inserted_end, end_ptr()); + decrease_size(tail_size); + throw; + } + } catch (...) { + // Destroy the elements constructed from the input. + destroy_range(original_end, end_ptr()); + decrease_size(internal_range_length(original_end, end_ptr())); + throw; + } + } else { + if (std::is_constant_evaluated()) { + const heap_temporary tmp(*this, val); + + ptr inserted_end = shift_into_uninitialized(pos, count); + std::fill(pos, inserted_end, tmp.get()); + + return pos; + } + const stack_temporary tmp(*this, val); + + ptr inserted_end = shift_into_uninitialized(pos, count); + + // Attempt to copy over the elements. + // If we fail we'll attempt a full roll-back. + try { + std::fill(pos, inserted_end, tmp.get()); + } catch (...) { + ptr original_end = move_left(inserted_end, end_ptr(), pos); + destroy_range(original_end, end_ptr()); + decrease_size(count); + throw; + } + } + return pos; + } + } + + template + constexpr ptr insert_range_helper(ptr pos, ForwardIt first, ForwardIt last) { + assert(!(first == last) && "The range should not be empty."); + assert(!(end_ptr() == pos) && "`pos` should not be at the end."); + + const size_ty num_insert = external_range_length(first, last); + if (num_uninitialized() < num_insert) { + // Reallocate. + if (get_max_size() - get_size() < num_insert) { + throw_allocation_size_error(); + } + + const size_ty offset = internal_range_length(begin_ptr(), pos); + const size_ty new_size = get_size() + num_insert; + + // The check is handled by the if-guard. + const size_ty new_capacity = unchecked_calculate_new_capacity(new_size); + const ptr new_data_ptr = + unchecked_allocate(new_capacity, allocation_end_ptr()); + ptr new_first = unchecked_next(new_data_ptr, offset); + ptr new_last = new_first; + + try { + uninitialized_copy(first, last, new_first); + unchecked_advance(new_last, num_insert); + + uninitialized_move(begin_ptr(), pos, new_data_ptr); + new_first = new_data_ptr; + uninitialized_move(pos, end_ptr(), new_last); + } catch (...) { + destroy_range(new_first, new_last); + deallocate(new_data_ptr, new_capacity); + throw; + } + + reset_data(new_data_ptr, new_capacity, new_size); + return unchecked_next(begin_ptr(), offset); + } else { + // if we have fewer to insert than tailing elements after + // `pos` we shift into uninitialized and then copy over + const size_ty tail_size = internal_range_length(pos, end_ptr()); + if (tail_size < num_insert) { + // Use the same method as insert_copies. + ptr original_end = end_ptr(); + ForwardIt pivot = unchecked_next(first, tail_size); + + // Place a portion of the input into the uninitialized section. + uninitialized_copy(pivot, last, end_ptr()); + increase_size(num_insert - tail_size); + + try { + // Now move the tail to the end. + uninitialized_move(pos, original_end, end_ptr()); + increase_size(tail_size); + + try { + // Finally, try to copy the rest of the elements over. + copy_range(first, pivot, pos); + } catch (...) { + // Attempt to roll back and destroy the tail if we fail. + ptr inserted_end = unchecked_prev(end_ptr(), tail_size); + move_left(inserted_end, end_ptr(), pos); + destroy_range(inserted_end, end_ptr()); + decrease_size(tail_size); + throw; + } + } catch (...) { + // If we throw, destroy the first copy we made. + destroy_range(original_end, end_ptr()); + decrease_size(internal_range_length(original_end, end_ptr())); + throw; + } + } else { + shift_into_uninitialized(pos, num_insert); + + // Attempt to copy over the elements. + // If we fail we'll attempt a full roll-back. + try { + copy_range(first, last, pos); + } catch (...) { + ptr inserted_end = unchecked_next(pos, num_insert); + ptr original_end = move_left(inserted_end, end_ptr(), pos); + destroy_range(original_end, end_ptr()); + decrease_size(num_insert); + throw; + } + } + return pos; + } + } + + template + constexpr ptr insert_range(ptr pos, InputIt first, InputIt last, + std::input_iterator_tag) { + assert(!(first == last) && "The range should not be empty."); + + // Ensure we use this specific overload to give a strong exception guarantee + // for 1 element. + if (end_ptr() == pos) { + return append_range(first, last, std::input_iterator_tag{}); + } + + using iterator_cat = + typename std::iterator_traits::iterator_category; + small_vector_base tmp(first, last, iterator_cat{}, allocator_ref()); + + return insert_range_helper(pos, std::make_move_iterator(tmp.begin_ptr()), + std::make_move_iterator(tmp.end_ptr())); + } + + template + constexpr ptr insert_range(ptr pos, ForwardIt first, ForwardIt last, + std::forward_iterator_tag) { + if (!(end_ptr() == pos)) { + return insert_range_helper(pos, first, last); + } + + if (unchecked_next(first) == last) { + return append_element(*first); + } + + using iterator_cat = + typename std::iterator_traits::iterator_category; + return append_range(first, last, iterator_cat{}); + } + + template + constexpr ptr emplace_into_current_end(Args&&... args) { + construct(end_ptr(), std::forward(args)...); + increase_size(1); + return unchecked_prev(end_ptr()); + } + + template + requires std::is_nothrow_move_constructible_v + constexpr ptr emplace_into_current(ptr pos, value_ty&& val) { + if (pos == end_ptr()) { + return emplace_into_current_end(std::move(val)); + } + + // In the special case of value_ty&& we don't make a copy because behavior + // is unspecified when it is an internal element. Hence, we'll take the + // opportunity to optimize and assume that it isn't an internal element. + shift_into_uninitialized(pos, 1); + destroy(pos); + construct(pos, std::move(val)); + return pos; + } + + template + constexpr ptr emplace_into_current(ptr pos, Args&&... args) { + if (pos == end_ptr()) { + return emplace_into_current_end(std::forward(args)...); + } + + if (std::is_constant_evaluated()) { + heap_temporary tmp(*this, std::forward(args)...); + shift_into_uninitialized(pos, 1); + *pos = tmp.release(); + return pos; + } + + // This is necessary because of possible aliasing. + stack_temporary tmp(*this, std::forward(args)...); + shift_into_uninitialized(pos, 1); + *pos = tmp.release(); + return pos; + } + + template + constexpr ptr emplace_into_reallocation_end(Args&&... args) { + // Appending; strong exception guarantee. + if (get_max_size() == get_size()) { + throw_allocation_size_error(); + } + + const size_ty new_size = get_size() + 1; + + // The check is handled by the if-guard. + const size_ty new_capacity = unchecked_calculate_new_capacity(new_size); + const ptr new_data_ptr = + unchecked_allocate(new_capacity, allocation_end_ptr()); + const ptr emplace_pos = unchecked_next(new_data_ptr, get_size()); + + try { + construct(emplace_pos, std::forward(args)...); + try { + uninitialized_move(begin_ptr(), end_ptr(), + new_data_ptr); + } catch (...) { + destroy(emplace_pos); + throw; + } + } catch (...) { + deallocate(new_data_ptr, new_capacity); + throw; + } + + reset_data(new_data_ptr, new_capacity, new_size); + return emplace_pos; + } + + template + constexpr ptr emplace_into_reallocation(ptr pos, Args&&... args) { + const size_ty offset = internal_range_length(begin_ptr(), pos); + if (offset == get_size()) { + return emplace_into_reallocation_end(std::forward(args)...); + } + + if (get_max_size() == get_size()) { + throw_allocation_size_error(); + } + + const size_ty new_size = get_size() + 1; + + // The check is handled by the if-guard. + const size_ty new_capacity = unchecked_calculate_new_capacity(new_size); + const ptr new_data_ptr = + unchecked_allocate(new_capacity, allocation_end_ptr()); + ptr new_first = unchecked_next(new_data_ptr, offset); + ptr new_last = new_first; + + try { + construct(new_first, std::forward(args)...); + unchecked_advance(new_last, 1); + + uninitialized_move(begin_ptr(), pos, new_data_ptr); + new_first = new_data_ptr; + uninitialized_move(pos, end_ptr(), new_last); + } catch (...) { + destroy_range(new_first, new_last); + deallocate(new_data_ptr, new_capacity); + throw; + } + + reset_data(new_data_ptr, new_capacity, new_size); + return unchecked_next(begin_ptr(), offset); + } + + constexpr ptr shrink_to_size() { + if (!has_allocation() || get_size() == get_capacity()) { + return begin_ptr(); + } + + // The rest runs only if allocated. + + size_ty new_capacity; + ptr new_data_ptr; + + if (InlineCapacity < get_size()) { + new_capacity = get_size(); + new_data_ptr = unchecked_allocate(new_capacity, allocation_end_ptr()); + } else { + // We move to inline storage. + new_capacity = InlineCapacity; + if (std::is_constant_evaluated()) { + new_data_ptr = alloc_interface::allocate(InlineCapacity); + } else { + new_data_ptr = storage_ptr(); + } + } + + uninitialized_move(begin_ptr(), end_ptr(), new_data_ptr); + + destroy_range(begin_ptr(), end_ptr()); + deallocate(data_ptr(), get_capacity()); + + set_data_ptr(new_data_ptr); + set_capacity(new_capacity); + + return begin_ptr(); + } + + template + constexpr void resize_with(size_ty new_size, const ValueT&... val) { + // ValueT... should either be value_ty or empty. + + if (new_size == 0) { + erase_all(); + } + + if (get_capacity() < new_size) { + // Reallocate. + + if (get_max_size() < new_size) { + throw_allocation_size_error(); + } + + const size_ty original_size = get_size(); + + // The check is handled by the if-guard. + const size_ty new_capacity = unchecked_calculate_new_capacity(new_size); + ptr new_data_ptr = unchecked_allocate(new_capacity, allocation_end_ptr()); + ptr new_last = unchecked_next(new_data_ptr, original_size); + + try { + new_last = uninitialized_fill( + new_last, unchecked_next(new_data_ptr, new_size), val...); + + // Strong exception guarantee. + uninitialized_move(begin_ptr(), end_ptr(), + new_data_ptr); + } catch (...) { + destroy_range(unchecked_next(new_data_ptr, original_size), new_last); + deallocate(new_data_ptr, new_capacity); + throw; + } + + reset_data(new_data_ptr, new_capacity, new_size); + } else if (get_size() < new_size) { + // Construct in the uninitialized section. + uninitialized_fill(end_ptr(), unchecked_next(begin_ptr(), new_size), + val...); + set_size(new_size); + } else { + erase_range(unchecked_next(begin_ptr(), new_size), end_ptr()); + } + + // Do nothing if the count is the same as the current size. + } + + constexpr void request_capacity(size_ty request) { + if (request <= get_capacity()) { + return; + } + + size_ty new_capacity = checked_calculate_new_capacity(request); + ptr new_begin = unchecked_allocate(new_capacity); + + try { + uninitialized_move(begin_ptr(), end_ptr(), + new_begin); + } catch (...) { + deallocate(new_begin, new_capacity); + throw; + } + + wipe(); + + set_data_ptr(new_begin); + set_capacity(new_capacity); + } + + constexpr ptr erase_at(ptr pos) { + move_left(unchecked_next(pos), end_ptr(), pos); + erase_last(); + return pos; + } + + constexpr void erase_last() { + decrease_size(1); + + // The element located at end_ptr is still alive since the size decreased. + destroy(end_ptr()); + } + + constexpr ptr erase_range(ptr first, ptr last) { + if (!(first == last)) { + erase_to_end(move_left(last, end_ptr(), first)); + } + return first; + } + + constexpr void erase_to_end(ptr pos) { + assert(0 <= (end_ptr() - pos) && "`pos` was in the uninitialized range"); + if (size_ty change = internal_range_length(pos, end_ptr())) { + decrease_size(change); + destroy_range(pos, unchecked_next(pos, change)); + } + } + + constexpr void erase_all() { + ptr curr_end = end_ptr(); + set_size(0); + destroy_range(begin_ptr(), curr_end); + } + + constexpr void swap_elements(small_vector_base& other) noexcept( + std::is_nothrow_move_constructible_v && + std::is_nothrow_swappable_v) { + assert(get_size() <= other.get_size()); + + const ptr other_tail = + std::swap_ranges(begin_ptr(), end_ptr(), other.begin_ptr()); + uninitialized_move(other_tail, other.end_ptr(), end_ptr()); + destroy_range(other_tail, other.end_ptr()); + + swap_size(other); + } + + constexpr void swap_default(small_vector_base& other) noexcept( + std::is_nothrow_move_constructible_v && + std::is_nothrow_swappable_v) { + // This function is used when: + // We are using the standard allocator. + // The allocators propagate and are equal. + // The allocators are always equal. + // The allocators do not propagate and are equal. + // The allocators propagate and are not equal. + + // Not handled: + // The allocators do not propagate and are not equal. + + assert(get_capacity() <= other.get_capacity()); + + if (has_allocation()) { // Implies that `other` also has an allocation. + swap_allocation(other); + } else if (other.has_allocation()) { + // Note: This will never be constant evaluated because both are always + // allocated. + uninitialized_move(begin_ptr(), end_ptr(), other.storage_ptr()); + destroy_range(begin_ptr(), end_ptr()); + + set_data_ptr(other.data_ptr()); + set_capacity(other.get_capacity()); + + other.set_data_ptr(other.storage_ptr()); + other.set_capacity(InlineCapacity); + + swap_size(other); + } else if (get_size() < other.get_size()) { + swap_elements(other); + } else { + other.swap_elements(*this); + } + + alloc_interface::swap(other); + } + + constexpr void swap_unequal_no_propagate(small_vector_base& other) { + assert(get_capacity() <= other.get_capacity()); + + if (get_capacity() < other.get_size()) { + // Reallocation required. + // We should always be able to reuse the allocation of `other`. + const size_ty new_capacity = + unchecked_calculate_new_capacity(other.get_size()); + const ptr new_data_ptr = unchecked_allocate(new_capacity, end_ptr()); + + try { + uninitialized_move(other.begin_ptr(), other.end_ptr(), new_data_ptr); + try { + destroy_range(std::move(begin_ptr(), end_ptr(), other.begin_ptr()), + other.end_ptr()); + } catch (...) { + destroy_range(new_data_ptr, + unchecked_next(new_data_ptr, other.get_size())); + throw; + } + } catch (...) { + deallocate(new_data_ptr, new_capacity); + throw; + } + + destroy_range(begin_ptr(), end_ptr()); + if (has_allocation()) { + deallocate(data_ptr(), get_capacity()); + } + + set_data_ptr(new_data_ptr); + set_capacity(new_capacity); + swap_size(other); + } else if (get_size() < other.get_size()) { + swap_elements(other); + } else { + other.swap_elements(*this); + } + + // This should have no effect. + alloc_interface::swap(other); + } + + template + requires allocations_are_swappable_v && (InlineCapacity == 0) + constexpr void swap(small_vector_base& other) noexcept { + swap_allocation(other); + alloc_interface::swap(other); + } + + template + requires allocations_are_swappable_v && (InlineCapacity != 0) + constexpr void swap(small_vector_base& other) noexcept( + std::is_nothrow_move_constructible_v && + std::is_nothrow_swappable_v) { + if (get_capacity() < other.get_capacity()) { + swap_default(other); + } else { + other.swap_default(*this); + } + } + + template + requires(!allocations_are_swappable_v) + constexpr void swap(small_vector_base& other) { + if (get_capacity() < other.get_capacity()) { + if (other.allocator_ref() == allocator_ref()) { + swap_default(other); + } else { + swap_unequal_no_propagate(other); + } + } else { + if (other.allocator_ref() == allocator_ref()) { + other.swap_default(*this); + } else { + other.swap_unequal_no_propagate(*this); + } + } + } + +#ifdef __GLIBCXX__ + + // These are compatibility fixes for libstdc++ because std::copy doesn't work + // for `move_iterator`s when constant evaluated. + + template + static constexpr InputIt unmove_iterator(InputIt it) { + return it; + } + + template + static constexpr auto unmove_iterator(std::move_iterator it) + -> decltype(unmove_iterator(it.base())) { + return unmove_iterator(it.base()); + } + + template + static constexpr auto unmove_iterator(std::reverse_iterator it) + -> std::reverse_iterator { + return std::reverse_iterator( + unmove_iterator(it.base())); + } + +#endif + + template + constexpr ptr copy_range(InputIt first, InputIt last, ptr dest) { +#ifdef __GLIBCXX__ + if (std::is_constant_evaluated()) { + if constexpr (!std::is_same_v())), + InputIt>) { + return std::move(unmove_iterator(first), unmove_iterator(last), dest); + } + } +#endif + + return std::copy(first, last, dest); + } + + template + requires is_memcpyable_iterator_v + constexpr InputIt copy_n_return_in(InputIt first, size_ty count, + ptr dest) noexcept { + if (std::is_constant_evaluated()) { + std::copy_n(first, count, dest); + return unchecked_next(first, count); + } + + if (count != 0) { + std::memcpy(to_address(dest), to_address(first), + count * sizeof(value_ty)); + } + // Note: The unsafe cast here should be proven to be safe in the caller + // function. + return unchecked_next(first, count); + } + + template + requires is_memcpyable_iterator_v + constexpr std::move_iterator copy_n_return_in( + std::move_iterator first, size_ty count, ptr dest) noexcept { + return std::move_iterator( + copy_n_return_in(first.base(), count, dest)); + } + + template + requires(!is_memcpyable_iterator_v && + std::is_base_of_v< + std::random_access_iterator_tag, + typename std::iterator_traits::iterator_category>) + constexpr RandomIt copy_n_return_in(RandomIt first, size_ty count, ptr dest) { +#ifdef __GLIBCXX__ + if (std::is_constant_evaluated()) { + if constexpr (!std::is_same_v())), + RandomIt>) { + auto bfirst = unmove_iterator(first); + auto blast = unchecked_next(bfirst, count); + std::move(bfirst, blast, dest); + return unchecked_next(first, count); + } + } +#endif + + std::copy_n(first, count, dest); + // Note: This unsafe cast should be proven safe in the caller function. + return unchecked_next(first, count); + } + + template + requires(!is_memcpyable_iterator_v && + !std::is_base_of_v< + std::random_access_iterator_tag, + typename std::iterator_traits::iterator_category>) + constexpr InputIt copy_n_return_in(InputIt first, size_ty count, ptr dest) { + for (; count != 0; + --count, static_cast(++dest), static_cast(++first)) { + *dest = *first; + } + return first; + } + + template + requires is_memcpyable_v + constexpr ptr move_left(ptr first, ptr last, ptr d_first) { + // Shift initialized elements to the left. + + if (std::is_constant_evaluated()) { + return std::move(first, last, d_first); + } + + const size_ty num_moved = internal_range_length(first, last); + if (num_moved != 0) { + std::memmove(to_address(d_first), to_address(first), + num_moved * sizeof(value_ty)); + } + return unchecked_next(d_first, num_moved); + } + + template + requires(!is_memcpyable_v) + constexpr ptr move_left(ptr first, ptr last, ptr d_first) { + // Shift initialized elements to the left. + return std::move(first, last, d_first); + } + + template + requires is_memcpyable_v + constexpr ptr move_right(ptr first, ptr last, ptr d_last) { + // Move initialized elements to the right. + + if (std::is_constant_evaluated()) { + return std::move_backward(first, last, d_last); + } + + const size_ty num_moved = internal_range_length(first, last); + const ptr dest = unchecked_prev(d_last, num_moved); + if (num_moved != 0) { + std::memmove(to_address(dest), to_address(first), + num_moved * sizeof(value_ty)); + } + return dest; + } + + template + requires(!is_memcpyable_v) + constexpr ptr move_right(ptr first, ptr last, ptr d_last) { + // move initialized elements to the right + // n should not be 0 + return std::move_backward(first, last, d_last); + } + + public: + constexpr void set_default() { + set_to_inline_storage(); + set_size(0); + } + + [[nodiscard]] + constexpr ptr data_ptr() noexcept { + return m_data.data_ptr(); + } + + [[nodiscard]] + constexpr cptr data_ptr() const noexcept { + return m_data.data_ptr(); + } + + [[nodiscard]] + constexpr size_ty get_capacity() const noexcept { + return m_data.capacity(); + } + + [[nodiscard]] + constexpr size_ty get_size() const noexcept { + return m_data.size(); + } + + [[nodiscard]] + constexpr size_ty num_uninitialized() const noexcept { + return get_capacity() - get_size(); + } + + [[nodiscard]] + constexpr ptr begin_ptr() noexcept { + return data_ptr(); + } + + [[nodiscard]] + constexpr cptr begin_ptr() const noexcept { + return data_ptr(); + } + + [[nodiscard]] + constexpr ptr end_ptr() noexcept { + return unchecked_next(begin_ptr(), get_size()); + } + + [[nodiscard]] + constexpr cptr end_ptr() const noexcept { + return unchecked_next(begin_ptr(), get_size()); + } + + [[nodiscard]] + constexpr ptr allocation_end_ptr() noexcept { + return unchecked_next(begin_ptr(), get_capacity()); + } + + [[nodiscard]] + constexpr cptr allocation_end_ptr() const noexcept { + return unchecked_next(begin_ptr(), get_capacity()); + } + + [[nodiscard]] + constexpr alloc_ty copy_allocator() const noexcept { + return alloc_ty(allocator_ref()); + } + + [[nodiscard]] + constexpr ptr storage_ptr() noexcept { + return m_data.storage(); + } + + [[nodiscard]] + constexpr cptr storage_ptr() const noexcept { + return m_data.storage(); + } + + [[nodiscard]] + constexpr bool has_allocation() const noexcept { + if (std::is_constant_evaluated()) { + return true; + } + return InlineCapacity < get_capacity(); + } + + [[nodiscard]] + constexpr bool is_inlinable() const noexcept { + return get_size() <= InlineCapacity; + } + + private: + small_vector_data m_data; +}; + +} // namespace detail + +template + requires concepts::small_vector::AllocatorFor +class small_vector + : private detail::small_vector_base { + using base = detail::small_vector_base; + + public: + static_assert(std::is_same_v, + "`Allocator::value_type` must be the same as `T`."); + + template + requires concepts::small_vector::AllocatorFor + friend class small_vector; + + using value_type = T; + using allocator_type = Allocator; + using size_type = typename base::size_type; + using difference_type = typename base::difference_type; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = typename std::allocator_traits::pointer; + using const_pointer = + typename std::allocator_traits::const_pointer; + + using iterator = small_vector_iterator; + using const_iterator = small_vector_iterator; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; + + static_assert(InlineCapacity <= (std::numeric_limits::max)(), + "InlineCapacity must be less than or equal to the maximum " + "value of size_type."); + + static constexpr unsigned inline_capacity_v = InlineCapacity; + + private: + static constexpr bool Destructible = + concepts::small_vector::Destructible; + + static constexpr bool MoveAssignable = + concepts::small_vector::MoveAssignable; + + static constexpr bool CopyAssignable = + concepts::small_vector::CopyAssignable; + + static constexpr bool MoveConstructible = + concepts::small_vector::MoveConstructible; + + static constexpr bool CopyConstructible = + concepts::small_vector::CopyConstructible; + + static constexpr bool Swappable = + concepts::small_vector::Swappable; + + static constexpr bool DefaultInsertable = + concepts::small_vector::DefaultInsertable; + + static constexpr bool MoveInsertable = + concepts::small_vector::MoveInsertable; + + static constexpr bool CopyInsertable = + concepts::small_vector::CopyInsertable; + + static constexpr bool Erasable = + concepts::small_vector::Erasable; + + template + struct EmplaceConstructible { + static constexpr bool value = + concepts::small_vector::EmplaceConstructible; + }; + + public: + constexpr small_vector() noexcept(noexcept(allocator_type())) + requires concepts::DefaultConstructible + = default; + + constexpr small_vector(const small_vector& other) + requires CopyInsertable + : base(base::bypass, other) {} + + constexpr small_vector(small_vector&& other) noexcept( + std::is_nothrow_move_constructible_v || InlineCapacity == 0) + requires MoveInsertable + : base(base::bypass, std::move(other)) {} + + constexpr explicit small_vector(const allocator_type& alloc) noexcept + : base(alloc) {} + + constexpr small_vector(const small_vector& other, const allocator_type& alloc) + requires CopyInsertable + : base(base::bypass, other, alloc) {} + + constexpr small_vector(small_vector&& other, const allocator_type& alloc) + requires MoveInsertable + : base(base::bypass, std::move(other), alloc) {} + + constexpr explicit small_vector( + size_type count, const allocator_type& alloc = allocator_type()) + requires DefaultInsertable + : base(count, alloc) {} + + constexpr small_vector(size_type count, const_reference value, + const allocator_type& alloc = allocator_type()) + requires CopyInsertable + : base(count, value, alloc) {} + + template + requires std::invocable && + EmplaceConstructible>::value + constexpr small_vector(size_type count, Generator g, + const allocator_type& alloc = allocator_type()) + : base(count, g, alloc) {} + + template + requires EmplaceConstructible>::value && + (std::forward_iterator || MoveInsertable) + constexpr small_vector(InputIt first, InputIt last, + const allocator_type& alloc = allocator_type()) + : base(first, last, + typename std::iterator_traits::iterator_category{}, + alloc) {} + + constexpr small_vector(std::initializer_list init, + const allocator_type& alloc = allocator_type()) + requires EmplaceConstructible::value + : small_vector(init.begin(), init.end(), alloc) {} + + template + requires CopyInsertable + constexpr explicit small_vector(const small_vector& other) + : base(base::bypass, other) {} + + template + requires MoveInsertable + constexpr explicit small_vector( + small_vector&& + other) noexcept(std::is_nothrow_move_constructible:: + value && + I < InlineCapacity) + : base(base::bypass, std::move(other)) {} + + template + requires CopyInsertable + constexpr small_vector(const small_vector& other, + const allocator_type& alloc) + : base(base::bypass, other, alloc) {} + + template + requires MoveInsertable + constexpr small_vector(small_vector&& other, + const allocator_type& alloc) + : base(base::bypass, std::move(other), alloc) {} + + constexpr ~small_vector() + requires Erasable + = default; + + constexpr small_vector& operator=(const small_vector& other) + requires CopyInsertable && CopyAssignable + { + assign(other); + return *this; + } + + constexpr small_vector& operator=(small_vector&& other) noexcept( + (std::is_same_v, Allocator> || + std::allocator_traits< + Allocator>::propagate_on_container_move_assignment::value || + std::allocator_traits::is_always_equal::value) && + ((std::is_nothrow_move_assignable_v && + std::is_nothrow_move_constructible_v) || + InlineCapacity == 0)) + // Note: The standard says here that + // std::allocator_traits::propagate_on_container_move_assignment + // == false implies MoveInsertable && MoveAssignable, but since we have + // inline storage we must always require moves [tab:container.alloc.req]. + requires MoveInsertable && MoveAssignable + { + assign(std::move(other)); + return *this; + } + + constexpr small_vector& operator=(std::initializer_list ilist) + requires CopyInsertable && CopyAssignable + { + assign(ilist); + return *this; + } + + constexpr void assign(size_type count, const_reference value) + requires CopyInsertable && CopyAssignable + { + base::assign_with_copies(count, value); + } + + template + requires EmplaceConstructible>::value && + (std::forward_iterator || MoveInsertable) + constexpr void assign(InputIt first, InputIt last) { + using iterator_cat = + typename std::iterator_traits::iterator_category; + base::assign_with_range(first, last, iterator_cat{}); + } + + constexpr void assign(std::initializer_list ilist) + requires EmplaceConstructible::value + { + assign(ilist.begin(), ilist.end()); + } + + constexpr void assign(const small_vector& other) + requires CopyInsertable && CopyAssignable + { + if (&other != this) { + base::copy_assign(other); + } + } + + template + requires CopyInsertable && CopyAssignable + constexpr void assign(const small_vector& other) { + base::copy_assign(other); + } + + constexpr void assign(small_vector&& other) noexcept( + (std::is_same_v, Allocator> || + std::allocator_traits< + Allocator>::propagate_on_container_move_assignment::value || + std::allocator_traits::is_always_equal::value) && + ((std::is_nothrow_move_assignable_v && + std::is_nothrow_move_constructible_v) || + InlineCapacity == 0)) + requires MoveInsertable && MoveAssignable + { + if (&other != this) { + base::move_assign(std::move(other)); + } + } + + template + requires MoveInsertable && MoveAssignable + constexpr void assign(small_vector&& other) noexcept( + I <= InlineCapacity && + (std::is_same_v, Allocator> || + std::allocator_traits< + Allocator>::propagate_on_container_move_assignment::value || + std::allocator_traits::is_always_equal::value) && + std::is_nothrow_move_assignable_v && + std::is_nothrow_move_constructible_v) { + base::move_assign(std::move(other)); + } + + constexpr void swap(small_vector& other) noexcept( + (std::is_same_v, Allocator> || + std::allocator_traits::propagate_on_container_swap::value || + std::allocator_traits::is_always_equal::value) && + ((std::is_nothrow_move_constructible_v && + std::is_nothrow_move_assignable_v && + std::is_nothrow_swappable_v) || + InlineCapacity == 0)) + requires(MoveInsertable && MoveAssignable && Swappable) || + ((std::is_same_v, Allocator> || + std::allocator_traits< + Allocator>::propagate_on_container_swap::value || + std::allocator_traits::is_always_equal::value) && + InlineCapacity == 0) + { + base::swap(other); + } + + constexpr iterator begin() noexcept { return iterator{base::begin_ptr()}; } + + constexpr const_iterator begin() const noexcept { + return const_iterator{base::begin_ptr()}; + } + + constexpr const_iterator cbegin() const noexcept { return begin(); } + + constexpr iterator end() noexcept { return iterator{base::end_ptr()}; } + + constexpr const_iterator end() const noexcept { + return const_iterator{base::end_ptr()}; + } + + constexpr const_iterator cend() const noexcept { return end(); } + + constexpr reverse_iterator rbegin() noexcept { + return reverse_iterator{end()}; + } + + constexpr const_reverse_iterator rbegin() const noexcept { + return const_reverse_iterator{end()}; + } + + constexpr const_reverse_iterator crbegin() const noexcept { return rbegin(); } + + constexpr reverse_iterator rend() noexcept { + return reverse_iterator{begin()}; + } + + constexpr const_reverse_iterator rend() const noexcept { + return const_reverse_iterator{begin()}; + } + + constexpr const_reverse_iterator crend() const noexcept { return rend(); } + + constexpr reference at(size_type pos) { + if (size() <= pos) { + base::throw_index_error(); + } + return begin()[static_cast(pos)]; + } + + constexpr const_reference at(size_type pos) const { + if (size() <= pos) { + base::throw_index_error(); + } + return begin()[static_cast(pos)]; + } + + constexpr reference operator[](size_type pos) { + return begin()[static_cast(pos)]; + } + + constexpr const_reference operator[](size_type pos) const { + return begin()[static_cast(pos)]; + } + + constexpr reference front() { return (*this)[0]; } + + constexpr const_reference front() const { return (*this)[0]; } + + constexpr reference back() { return (*this)[size() - 1]; } + + constexpr const_reference back() const { return (*this)[size() - 1]; } + + constexpr pointer data() noexcept { return base::begin_ptr(); } + + constexpr const_pointer data() const noexcept { return base::begin_ptr(); } + + constexpr size_type size() const noexcept { + return static_cast(base::get_size()); + } + + [[nodiscard]] + constexpr bool empty() const noexcept { + return size() == 0; + } + + constexpr size_type max_size() const noexcept { + return static_cast(base::get_max_size()); + } + + constexpr size_type capacity() const noexcept { + return static_cast(base::get_capacity()); + } + + constexpr allocator_type get_allocator() const noexcept { + return base::copy_allocator(); + } + + constexpr iterator insert(const_iterator pos, const_reference value) + requires CopyInsertable && CopyAssignable + { + return emplace(pos, value); + } + + constexpr iterator insert(const_iterator pos, value_type&& value) + requires MoveInsertable && MoveAssignable + { + return emplace(pos, std::move(value)); + } + + constexpr iterator insert(const_iterator pos, size_type count, + const_reference value) + requires CopyInsertable && CopyAssignable + { + return iterator(base::insert_copies(base::ptr_cast(pos), count, value)); + } + + // Note: Unlike std::vector, this does not require MoveConstructible because + // we + // don't use std::rotate (as was the reason for the change in C++17). + // Relevant: https://cplusplus.github.io/LWG/issue2266). + template + requires EmplaceConstructible>::value && + MoveInsertable && MoveAssignable + constexpr iterator insert(const_iterator pos, InputIt first, InputIt last) { + if (first == last) { + return iterator(base::ptr_cast(pos)); + } + + using iterator_cat = + typename std::iterator_traits::iterator_category; + return iterator( + base::insert_range(base::ptr_cast(pos), first, last, iterator_cat{})); + } + + constexpr iterator insert(const_iterator pos, + std::initializer_list ilist) + requires EmplaceConstructible::value && MoveInsertable + && MoveAssignable + { + return insert(pos, ilist.begin(), ilist.end()); + } + + template + requires EmplaceConstructible::value && MoveInsertable && + MoveAssignable + constexpr iterator emplace(const_iterator pos, Args&&... args) { + return iterator( + base::emplace_at(base::ptr_cast(pos), std::forward(args)...)); + } + + constexpr iterator erase(const_iterator pos) + requires MoveAssignable && Erasable + { + assert(0 <= (pos - begin()) && + "`pos` is out of bounds (before `begin ()`)."); + assert(0 < (end() - pos) && + "`pos` is out of bounds (at or after `end ()`)."); + + return iterator(base::erase_at(base::ptr_cast(pos))); + } + + constexpr iterator erase(const_iterator first, const_iterator last) + requires MoveAssignable && Erasable + { + assert(0 <= (last - first) && "Invalid range."); + assert(0 <= (first - begin()) && + "`first` is out of bounds (before `begin ()`)."); + assert(0 <= (end() - last) && "`last` is out of bounds (after `end ()`)."); + + return iterator( + base::erase_range(base::ptr_cast(first), base::ptr_cast(last))); + } + + constexpr void push_back(const_reference value) + requires CopyInsertable + { + emplace_back(value); + } + + constexpr void push_back(value_type&& value) + requires MoveInsertable + { + emplace_back(std::move(value)); + } + + template + requires EmplaceConstructible::value && MoveInsertable + constexpr reference emplace_back(Args&&... args) { + return *base::append_element(std::forward(args)...); + } + + constexpr void pop_back() + requires Erasable + { + assert(!empty() && "`pop_back ()` called on an empty `small_vector`."); + base::erase_last(); + } + + constexpr void reserve(size_type new_capacity) + requires MoveInsertable + { + base::request_capacity(new_capacity); + } + + constexpr void shrink_to_fit() + requires MoveInsertable + { + base::shrink_to_size(); + } + + constexpr void clear() noexcept + requires Erasable + { + base::erase_all(); + } + + constexpr void resize(size_type count) + requires MoveInsertable && DefaultInsertable + { + base::resize_with(count); + } + + constexpr void resize(size_type count, const_reference value) + requires CopyInsertable + { + base::resize_with(count, value); + } + + [[nodiscard]] + constexpr bool inlined() const noexcept { + return !base::has_allocation(); + } + + [[nodiscard]] + constexpr bool inlinable() const noexcept { + return base::is_inlinable(); + } + + [[nodiscard]] + static consteval size_type inline_capacity() noexcept { + return static_cast(inline_capacity_v); + } + + template + requires EmplaceConstructible>::value && + MoveInsertable + constexpr small_vector& append(InputIt first, InputIt last) { + using policy = typename base::strong_exception_policy; + using iterator_cat = + typename std::iterator_traits::iterator_category; + base::template append_range(first, last, iterator_cat{}); + return *this; + } + + constexpr small_vector& append(std::initializer_list ilist) + requires EmplaceConstructible::value && MoveInsertable + { + return append(ilist.begin(), ilist.end()); + } + + template + constexpr small_vector& append(const small_vector& other) + requires CopyInsertable + { + return append(other.begin(), other.end()); + } + + template + constexpr small_vector& append(small_vector&& other) + requires MoveInsertable + { + // Provide a strong exception guarantee for `other` as well. + using move_iter_type = typename std::conditional_t< + base::template relocate_with_move_v, + std::move_iterator, iterator>; + + append(move_iter_type{other.begin()}, move_iter_type{other.end()}); + other.clear(); + return *this; + } +}; + +template +inline constexpr bool operator==( + const small_vector& lhs, + const small_vector& rhs) { + return lhs.size() == rhs.size() && + std::equal(lhs.begin(), lhs.end(), rhs.begin()); +} + +template +inline constexpr bool operator==( + const small_vector& lhs, + const small_vector& rhs) { + return lhs.size() == rhs.size() && + std::equal(lhs.begin(), lhs.end(), rhs.begin()); +} + +template + requires std::three_way_comparable +constexpr auto operator<=>( + const small_vector& lhs, + const small_vector& rhs) { + return std::lexicographical_compare_three_way( + lhs.begin(), lhs.end(), rhs.begin(), rhs.end(), std::compare_three_way{}); +} + +template + requires std::three_way_comparable +constexpr auto operator<=>( + const small_vector& lhs, + const small_vector& rhs) { + return std::lexicographical_compare_three_way( + lhs.begin(), lhs.end(), rhs.begin(), rhs.end(), std::compare_three_way{}); +} + +template +constexpr auto operator<=>( + const small_vector& lhs, + const small_vector& rhs) { + constexpr auto comparison = [](const T& l, const T& r) { + return (l < r) ? std::weak_ordering::less + : (r < l) ? std::weak_ordering::greater + : std::weak_ordering::equivalent; + }; + + return std::lexicographical_compare_three_way( + lhs.begin(), lhs.end(), rhs.begin(), rhs.end(), comparison); +} + +template +constexpr auto operator<=>( + const small_vector& lhs, + const small_vector& rhs) { + constexpr auto comparison = [](const T& l, const T& r) { + return (l < r) ? std::weak_ordering::less + : (r < l) ? std::weak_ordering::greater + : std::weak_ordering::equivalent; + }; + + return std::lexicographical_compare_three_way( + lhs.begin(), lhs.end(), rhs.begin(), rhs.end(), comparison); +} + +template +inline constexpr void swap(small_vector& lhs, + small_vector& + rhs) noexcept(noexcept(lhs.swap(rhs))) + requires concepts::MoveInsertable< + T, small_vector, Allocator> && + concepts::Swappable +{ + lhs.swap(rhs); +} + +template +inline constexpr typename small_vector::size_type +erase(small_vector& v, const U& value) { + const auto original_size = v.size(); + v.erase(std::remove(v.begin(), v.end(), value), v.end()); + return original_size - v.size(); +} + +template +inline constexpr typename small_vector::size_type +erase_if(small_vector& v, Pred pred) { + const auto original_size = v.size(); + v.erase(std::remove_if(v.begin(), v.end(), pred), v.end()); + return original_size - v.size(); +} + +template +constexpr typename small_vector::iterator begin( + small_vector& v) noexcept { + return v.begin(); +} + +template +constexpr typename small_vector::const_iterator +begin(const small_vector& v) noexcept { + return v.begin(); +} + +template +constexpr typename small_vector::const_iterator +cbegin(const small_vector& v) noexcept { + return begin(v); +} + +template +constexpr typename small_vector::iterator end( + small_vector& v) noexcept { + return v.end(); +} + +template +constexpr typename small_vector::const_iterator +end(const small_vector& v) noexcept { + return v.end(); +} + +template +constexpr typename small_vector::const_iterator +cend(const small_vector& v) noexcept { + return end(v); +} + +template +constexpr typename small_vector::reverse_iterator +rbegin(small_vector& v) noexcept { + return v.rbegin(); +} + +template +constexpr + typename small_vector::const_reverse_iterator + rbegin(const small_vector& v) noexcept { + return v.rbegin(); +} + +template +constexpr + typename small_vector::const_reverse_iterator + crbegin(const small_vector& v) noexcept { + return rbegin(v); +} + +template +constexpr typename small_vector::reverse_iterator +rend(small_vector& v) noexcept { + return v.rend(); +} + +template +constexpr + typename small_vector::const_reverse_iterator + rend(const small_vector& v) noexcept { + return v.rend(); +} + +template +constexpr + typename small_vector::const_reverse_iterator + crend(const small_vector& v) noexcept { + return rend(v); +} + +template +constexpr typename small_vector::size_type size( + const small_vector& v) noexcept { + return v.size(); +} + +template +constexpr typename std::common_type_t< + std::ptrdiff_t, typename std::make_signed_t::size_type>> +ssize(const small_vector& v) noexcept { + using ret_type = typename std::common_type_t< + std::ptrdiff_t, typename std::make_signed_t>; + return static_cast(v.size()); +} + +template +[[nodiscard]] +constexpr bool empty( + const small_vector& v) noexcept { + return v.empty(); +} + +template +constexpr typename small_vector::pointer data( + small_vector& v) noexcept { + return v.data(); +} + +template +constexpr typename small_vector::const_pointer +data(const small_vector& v) noexcept { + return v.data(); +} + +template < + typename InputIt, + unsigned InlineCapacity = default_buffer_size_v< + std::allocator::value_type>>, + typename Allocator = + std::allocator::value_type>> +small_vector(InputIt, InputIt, Allocator = Allocator()) + -> small_vector::value_type, + InlineCapacity, Allocator>; + +} // namespace sleipnir diff --git a/src/main/java/frc/robot/Superstructure.java b/src/main/java/frc/robot/Superstructure.java new file mode 100644 index 0000000..49b7916 --- /dev/null +++ b/src/main/java/frc/robot/Superstructure.java @@ -0,0 +1,46 @@ +// Copyright (c) 2024 CurtinFRC +// Open Source Software, you can modify it according to the terms +// of the MIT License at the root of this project + +package frc.robot; + +import com.ctre.phoenix6.mechanisms.swerve.SwerveRequest; +import edu.wpi.first.math.geometry.Rotation2d; +import edu.wpi.first.wpilibj2.command.Command; +import edu.wpi.first.wpilibj2.command.Commands; +import frc.robot.jni.ShooterTrajoptJNI; +import frc.robot.subsystems.Arm; +import frc.robot.subsystems.CommandSwerveDrivetrain; +import frc.robot.subsystems.Intake; +import frc.robot.subsystems.Shooter; + +@SuppressWarnings("PMD.UnusedPrivateField") +public class Superstructure { + private final Intake m_intake; + private final Shooter m_shooter; + private final Arm m_arm; + private final CommandSwerveDrivetrain m_drivetrain; + + public Superstructure( + Intake intake, Shooter shooter, Arm arm, CommandSwerveDrivetrain drivetrain) { + m_intake = intake; + m_shooter = shooter; + m_arm = arm; + m_drivetrain = drivetrain; + } + + public Command shootFromRange() { + var pose = m_drivetrain.getState().Pose; + Trajectory traj = new Trajectory(); + + ShooterTrajoptJNI.calculateTrajectory(traj, pose.getX(), pose.getY(), 0, 0); + return Commands.parallel( + m_arm.goToAngle(traj.pitch), + m_drivetrain.applyRequest( + () -> + new SwerveRequest.FieldCentricFacingAngle() + .withTargetDirection(new Rotation2d(traj.yaw)))) + .andThen(m_shooter.spinup(traj.angular_velocity)) + .andThen(m_shooter.maintain()); + } +} diff --git a/src/main/java/frc/robot/Trajectory.java b/src/main/java/frc/robot/Trajectory.java new file mode 100644 index 0000000..f81aaf5 --- /dev/null +++ b/src/main/java/frc/robot/Trajectory.java @@ -0,0 +1,17 @@ +// Copyright (c) 2024 CurtinFRC +// Open Source Software, you can modify it according to the terms +// of the MIT License at the root of this project + +package frc.robot; + +public class Trajectory { + public double angular_velocity; + public double yaw; + public double pitch; + + Trajectory() { + angular_velocity = 0; + yaw = 0; + pitch = 0; + } +} diff --git a/src/main/java/frc/robot/jni/ShooterTrajoptJNI.java b/src/main/java/frc/robot/jni/ShooterTrajoptJNI.java new file mode 100644 index 0000000..85d3ba4 --- /dev/null +++ b/src/main/java/frc/robot/jni/ShooterTrajoptJNI.java @@ -0,0 +1,30 @@ +// Copyright (c) 2024 CurtinFRC +// Open Source Software, you can modify it according to the terms +// of the MIT License at the root of this project + +package frc.robot.jni; + +import edu.wpi.first.util.RuntimeLoader; +import frc.robot.Trajectory; +import java.io.IOException; + +public class ShooterTrajoptJNI { + static RuntimeLoader loader = null; + + static { + try { + loader = + new RuntimeLoader<>( + "ShooterTrajoptJNI", + RuntimeLoader.getDefaultExtractionRoot(), + ShooterTrajoptJNI.class); + loader.loadLibrary(); + } catch (IOException e) { + e.printStackTrace(); + System.exit(1); + } + } + + public static native void calculateTrajectory( + Trajectory javatraj, double x, double y, double vel_x, double vel_y); +} diff --git a/src/main/java/frc/robot/subsystems/Arm.java b/src/main/java/frc/robot/subsystems/Arm.java index fa58077..947ab97 100644 --- a/src/main/java/frc/robot/subsystems/Arm.java +++ b/src/main/java/frc/robot/subsystems/Arm.java @@ -90,7 +90,7 @@ public Command stop() { * @param position The desired position. * @return a {@link Command} to get to the desired position. */ - private Command moveToPosition(double position) { + public Command goToAngle(double position) { return achievePosition(position) .until( () -> @@ -146,6 +146,6 @@ public Command goToSetpoint(Setpoint setpoint) { break; } - return moveToPosition(position); + return goToAngle(position); } } diff --git a/src/main/native/cpp/ShooterTrajopt.cpp b/src/main/native/cpp/ShooterTrajopt.cpp new file mode 100644 index 0000000..f6f969b --- /dev/null +++ b/src/main/native/cpp/ShooterTrajopt.cpp @@ -0,0 +1,143 @@ +// Copyright (c) 2024 CurtinFRC +// Open Source Software, you can modify it according to the terms +// of the MIT License at the root of this project + +#include "ShooterTrajopt.h" + +#include +#include + +#include +#include +#include + +// FRC 2024 shooter trajectory optimization. +// +// This program finds the initial velocity, pitch, and yaw for a game piece to +// hit the 2024 FRC game's target that minimizes z sensitivity to initial +// velocity. + +namespace slp = sleipnir; + +using Eigen::Vector3d; +using Vector6d = Eigen::Vector; + +constexpr double field_width = 8.2296; // 27 ft -> m +constexpr double field_length = 16.4592; // 54 ft -> m +[[maybe_unused]] +constexpr double target_width = 1.05; // m +constexpr double target_lower_edge = 1.98; // m +constexpr double target_upper_edge = 2.11; // m +constexpr double target_depth = 0.46; // m +constexpr Vector6d target_wrt_field{{field_length - target_depth / 2.0}, + {field_width - 2.6575}, + {(target_upper_edge + target_lower_edge) / 2.0}, + {0.0}, + {0.0}, + {0.0}}; +constexpr double g = 9.806; // m/s² + +slp::VariableMatrix f(const slp::VariableMatrix& x) { + // x' = x' + // y' = y' + // z' = z' + // x" = −a_D(v_x) + // y" = −a_D(v_y) + // z" = −g − a_D(v_z) + // + // where a_D(v) = ½ρv² C_D A / m + constexpr double rho = 1.204; // kg/m³ + constexpr double C_D = 0.5; + constexpr double A = std::numbers::pi * 0.3; + constexpr double m = 2.0; // kg + auto a_D = [](auto v) { return 0.5 * rho * v * v * C_D * A / m; }; + + auto v_x = x(3, 0); + auto v_y = x(4, 0); + auto v_z = x(5, 0); + return slp::VariableMatrix{{v_x}, {v_y}, {v_z}, {-a_D(v_x)}, {-a_D(v_y)}, {-g - a_D(v_z)}}; +} + +traj calculate_trajectory(const double x_meter, const double y_meter, const double vel_x, + const double vel_y) { + // Robot initial state + Vector6d robot_wrt_field{{x_meter}, {y_meter}, {0.0}, {vel_x}, {vel_y}, {0.0}}; + + constexpr double max_initial_velocity = 15.0; // m/s + + Vector6d shooter_wrt_robot{{0.0}, {0.0}, {0.6096}, {0.0}, {0.0}, {0.0}}; + Vector6d shooter_wrt_field = robot_wrt_field + shooter_wrt_robot; + + slp::OptimizationProblem problem; + + // Set up duration decision variables + constexpr int N = 10; + auto T = problem.DecisionVariable(); + problem.SubjectTo(T >= 0); + T.SetValue(1); + auto dt = T / N; + + // Disc state in field frame + // + // [x position] + // [y position] + // [z position] + // x = [x velocity] + // [y velocity] + // [z velocity] + auto x = problem.DecisionVariable(6); + + // Position initial guess is start position + x.Segment(0, 3).SetValue(shooter_wrt_field.segment(0, 3)); + + // Velocity initial guess is max initial velocity toward target + Vector3d uvec_shooter_to_target = + (target_wrt_field.segment(0, 3) - shooter_wrt_field.segment(0, 3)).normalized(); + x.Segment(3, 3).SetValue(robot_wrt_field.segment(3, 3) + max_initial_velocity * uvec_shooter_to_target); + + // Shooter initial position + problem.SubjectTo(x.Segment(0, 3) == shooter_wrt_field.block(0, 0, 3, 1)); + + // Require initial velocity is below max + // + // √{v_x² + v_y² + v_z²) ≤ vₘₐₓ + // v_x² + v_y² + v_z² ≤ vₘₐₓ² + problem.SubjectTo(slp::pow(x(3) - robot_wrt_field(3), 2) + slp::pow(x(4) - robot_wrt_field(4), 2) + + slp::pow(x(5) - robot_wrt_field(5), 2) <= + max_initial_velocity * max_initial_velocity); + + // Dynamics constraints - RK4 integration + auto h = dt; + auto x_k = x; + for (int k = 0; k < N - 1; ++k) { + auto k1 = f(x_k); + auto k2 = f(x_k + h / 2 * k1); + auto k3 = f(x_k + h / 2 * k2); + auto k4 = f(x_k + h * k3); + x_k += h / 6 * (k1 + 2 * k2 + 2 * k3 + k4); + } + + // Require final position is in center of target circle + problem.SubjectTo(x_k.Segment(0, 3) == target_wrt_field.block(0, 0, 3, 1)); + + // Require the final velocity is up + problem.SubjectTo(x_k(5) > 0.0); + + // Minimize sensitivity of vertical position to velocity + auto sensitivity = slp::Gradient(x_k(3), x.Segment(3, 3)).Get(); + problem.Minimize(sensitivity.T() * sensitivity); + + problem.Solve({.diagnostics = true}); + + // Initial velocity vector + Eigen::Vector3d v0 = x.Segment(3, 3).Value() - robot_wrt_field.segment(3, 3); + + double velocity = v0.norm(); + double angular_velocity = velocity * 0.0508; // TODO! check + + double pitch = std::atan2(v0(2), std::hypot(v0(0), v0(1))); + + double yaw = std::atan2(v0(1), v0(0)); + + return traj{yaw, pitch, angular_velocity}; +} diff --git a/src/main/native/cpp/ShooterTrajoptJNI.cpp b/src/main/native/cpp/ShooterTrajoptJNI.cpp new file mode 100644 index 0000000..028d9f3 --- /dev/null +++ b/src/main/native/cpp/ShooterTrajoptJNI.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2024 CurtinFRC +// Open Source Software, you can modify it according to the terms +// of the MIT License at the root of this project + +#include +#include + +#include "ShooterTrajopt.h" +#include "jni_md.h" + +extern "C" { +/* + * Class: frc_robot_jni_ShooterTrajoptJNI + * Method: calculateTrajectory + * Signature: (Ljava/lang/Object;????)V + */ +JNIEXPORT void JNICALL +Java_frc_robot_jni_ShooterTrajoptJNI_calculateTrajectory + (JNIEnv* env, jclass, jobject javatraj, double x, double y, double vel_x, + double vel_y) +{ + auto traj = calculate_trajectory(x, y, vel_x, vel_y); + jclass clazz = (*env).GetObjectClass(javatraj); + + // Get Field references + jfieldID angular_velocity = (*env).GetFieldID(clazz, "angular_velocity", "F"); + jfieldID yaw = (*env).GetFieldID(clazz, "yaw", "F"); + jfieldID pitch = (*env).GetFieldID(clazz, "pitch", "F"); + + // Set fields for object + (*env).SetFloatField(javatraj, angular_velocity, traj.angular_velocity); + (*env).SetFloatField(javatraj, yaw, traj.yaw); + (*env).SetFloatField(javatraj, pitch, traj.pitch); +} +} // extern "C" diff --git a/src/main/native/include/ShooterTrajopt.h b/src/main/native/include/ShooterTrajopt.h new file mode 100644 index 0000000..064c482 --- /dev/null +++ b/src/main/native/include/ShooterTrajopt.h @@ -0,0 +1,13 @@ +// Copyright (c) 2024 CurtinFRC +// Open Source Software, you can modify it according to the terms +// of the MIT License at the root of this project + +#pragma once + +struct traj { + double yaw; + double pitch; + double angular_velocity; +}; + +traj calculate_trajectory(const double x_meter, const double y_meter, const double vel_x, const double vel_y);