Skip to content

Commit fc33212

Browse files
Use fully vectorized code in c3plus projection step
1 parent 5297841 commit fc33212

File tree

9 files changed

+54
-58
lines changed

9 files changed

+54
-58
lines changed

examples/sampling_c3/anything/franka_hardware_anything.pmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ group "operator" {
88
host = "sampling_c3_localhost";
99
}
1010
cmd "logger" {
11-
exec = "python3 examples/sampling_c3/start_logging.py hw anything";
11+
exec = "python3 examples/sampling_c3/start_logging.py hw anything /mnt/data2/anything/logs/";
1212
host = "sampling_c3_localhost";
1313
}
1414
cmd "record_video" {

examples/sampling_c3/anything/franka_sim_anything.pmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ group "operator" {
1919
host = "localhost";
2020
}
2121
cmd "start_logging" {
22-
exec = "python3 examples/sampling_c3/start_logging.py sim anything";
22+
exec = "python3 examples/sampling_c3/start_logging.py sim anything /mnt/data2/anything/logs/";
2323
host = "localhost";
2424
}
2525
cmd "generate_files" {

examples/sampling_c3/jacktoy/franka_hardware_jack.pmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ group "operator" {
88
host = "sampling_c3_localhost";
99
}
1010
cmd "logger" {
11-
exec = "python3 examples/sampling_c3/start_logging.py hw jacktoy";
11+
exec = "python3 examples/sampling_c3/start_logging.py hw jacktoy /mnt/data2/jacktoy/logs/";
1212
host = "sampling_c3_localhost";
1313
}
1414
cmd "record_video" {

examples/sampling_c3/jacktoy/franka_sim_jack.pmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ group "operator" {
1919
host = "localhost";
2020
}
2121
cmd "start_logging" {
22-
exec = "python3 examples/sampling_c3/start_logging.py sim jacktoy";
22+
exec = "python3 examples/sampling_c3/start_logging.py sim jacktoy /mnt/data2/jacktoy/logs/";
2323
host = "localhost";
2424
}
2525
}

examples/sampling_c3/push_t/franka_hardware_t.pmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ group "operator" {
88
host = "sampling_c3_localhost";
99
}
1010
cmd "logger" {
11-
exec = "python3 examples/sampling_c3/start_logging.py hw push_t";
11+
exec = "python3 examples/sampling_c3/start_logging.py hw push_t /mnt/data2/push_t/logs/";
1212
host = "sampling_c3_localhost";
1313
}
1414
cmd "record_video" {

examples/sampling_c3/push_t/franka_sim_t.pmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ group "operator" {
1919
host = "localhost";
2020
}
2121
cmd "start_logging" {
22-
exec = "python3 examples/sampling_c3/start_logging.py sim push_t logs";
22+
exec = "python3 examples/sampling_c3/start_logging.py sim push_t /mnt/data2/push_t/logs/";
2323
host = "localhost";
2424
}
2525
}

solvers/base_c3.h

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,23 +29,13 @@ class C3Base {
2929
std::vector<Eigen::MatrixXd> U;
3030
};
3131

32-
/// @param lcs Parameters defining the LCS (Linear Complementarity System).
32+
/// @param lcs Parameters defining the LCS (Linear Complementarity
33+
/// System).
3334
/// @param costs Cost matrices used in the optimization.
3435
/// @param x_des Desired goal state.
3536
/// @param options Options specific to the C3 formulation.
36-
/// @param z_size The size of the z vector, which varies depending on the C3 variant.
37-
/// For example:
38-
/// - C3MIQP / C3QP: z = [x, u, lambda]
39-
/// - C3Plus: z = [x, u, lambda, eta]
40-
C3Base(const LCS& LCS, const CostMatrices& costs,
41-
const std::vector<Eigen::VectorXd>& x_des, const C3Options& options,
42-
const int z_size);
43-
44-
/// @param lcs Parameters defining the LCS (Linear Complementarity System).
45-
/// @param costs Cost matrices used in the optimization.
46-
/// @param x_des Desired goal state.
47-
/// @param options Options specific to the C3 formulation.
48-
/// @note Using this constructor will set z_size to the default value, which is size_x + size_u + size_lambda
37+
/// @note Using this constructor will set z_size to the default value, which
38+
/// is size_x + size_u + size_lambda
4939
C3Base(const LCS& LCS, const CostMatrices& costs,
5040
const std::vector<Eigen::VectorXd>& x_des, const C3Options& options);
5141

@@ -171,6 +161,22 @@ class C3Base {
171161
void UpdateTarget(const std::vector<Eigen::VectorXd>& x_des);
172162

173163
protected:
164+
/// @param lcs Parameters defining the LCS.
165+
/// @param costs Cost matrices used in the optimization.
166+
/// @param x_des Desired goal state trajectory.
167+
/// @param options Options specific to the C3 formulation.
168+
/// @param z_size Size of the z vector, which depends on the specific C3
169+
/// variant.
170+
/// For example:
171+
/// - C3MIQP / C3QP: z = [x, u, lambda]
172+
/// - C3Plus: z = [x, u, lambda, eta]
173+
///
174+
/// This constructor is intended for internal use only. The public constructor
175+
/// delegates to this one, passing in an explicitly computed z vector size.
176+
C3Base(const LCS& lcs, const CostMatrices& costs,
177+
const std::vector<Eigen::VectorXd>& x_des, const C3Options& options,
178+
int z_size);
179+
174180
// Helper functions for C3Base constructor
175181
void ScaleLCS();
176182
void InitializeWarmStarts();

solvers/c3_plus.cc

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -103,35 +103,30 @@ VectorXd C3Plus::SolveSingleProjection(const MatrixXd& U,
103103
const int& warm_start_index) {
104104
VectorXd delta_proj = delta_c;
105105

106-
// Handle complementarity constraints for each lambda-eta pair
107-
for (int i = 0; i < m_; ++i) {
108-
double w_eta = std::abs(U(n_ + m_ + k_ + i, n_ + m_ + k_ + i));
109-
double w_lambda = std::abs(U(n_ + i, n_ + i));
110-
111-
double lambda_val = delta_c(n_ + i);
112-
double eta_val = delta_c(n_ + m_ + k_ + i);
113-
114-
if (lambda_val <= 0) {
115-
delta_proj(n_ + i) = 0;
116-
delta_proj(n_ + m_ + k_ + i) = std::max(0.0, eta_val);
117-
} else {
118-
if (eta_val <= 0) {
119-
delta_proj(n_ + i) = lambda_val;
120-
delta_proj(n_ + m_ + k_ + i) = 0;
121-
} else {
122-
// If point (lambda, eta) is above the slope sqrt(w_lambda/w_eta), set
123-
// lambda to 0 and keep eta Otherwise, set lambda to lambda and set eta
124-
// to 0
125-
if (eta_val * std::sqrt(w_eta) > lambda_val * std::sqrt(w_lambda)) {
126-
delta_proj(n_ + i) = 0;
127-
delta_proj(n_ + m_ + k_ + i) = eta_val;
128-
} else {
129-
delta_proj(n_ + i) = lambda_val;
130-
delta_proj(n_ + m_ + k_ + i) = 0;
131-
}
132-
}
133-
}
134-
}
106+
// Extract the weight vectors for lambda and eta from the diagonal of the cost
107+
// matrix U.
108+
// Use absolute values to ensure numerical safety when taking square roots,
109+
// in case the user inadvertently supplies negative weights.
110+
VectorXd w_eta_vec =
111+
U.block(n_ + m_ + k_, n_ + m_ + k_, m_, m_).diagonal().cwiseAbs();
112+
VectorXd w_lambda_vec = U.block(n_, n_, m_, m_).diagonal().cwiseAbs();
113+
114+
VectorXd lambda_c = delta_c.segment(n_, m_);
115+
VectorXd eta_c = delta_c.segment(n_ + m_ + k_, m_);
116+
117+
// Set the smaller of lambda and eta to zero
118+
Eigen::Array<bool, Eigen::Dynamic, 1> eta_larger =
119+
eta_c.array() * w_eta_vec.array().sqrt() >
120+
lambda_c.array() * w_lambda_vec.array().sqrt();
121+
122+
delta_proj.segment(n_, m_) = eta_larger.select(VectorXd::Zero(m_), lambda_c);
123+
delta_proj.segment(n_ + m_ + k_, m_) =
124+
eta_larger.select(eta_c, VectorXd::Zero(m_));
125+
126+
// Clip lambda and eta at 0
127+
delta_proj.segment(n_, m_) = delta_proj.segment(n_, m_).cwiseMax(0);
128+
delta_proj.segment(n_ + m_ + k_, m_) =
129+
delta_proj.segment(n_ + m_ + k_, m_).cwiseMax(0);
135130

136131
return delta_proj;
137132
}

solvers/c3_plus.h

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,9 @@ namespace solvers {
3636
// onto the feasible set defined by the complementarity condition (i.e., λᵢ ηᵢ
3737
// = 0 for all i, with λ ≥ 0 and η ≥ 0).
3838
//
39-
// To get the solution, we can simply perform if-else to handle the following
40-
// cases:
41-
//
42-
// 1. λ₀ <= 0 and η₀ > 0, then λ = 0 and η = η₀
43-
// 2. λ₀ <= 0 and η₀ <= 0 then λ = 0 and η = 0
44-
// 3. λ₀ > 0 and η₀ <= 0, then λ = λ₀ and η = 0
45-
// 4. λ₀ > 0, η₀ > 0, and η₀ > sqrt(w_λ/w_η) * λ₀, then λ = 0 and η = η₀
46-
// 5. λ₀ > 0, η₀ > 0, and η₀ <= sqrt(w_λ/w_η) * λ₀, then λ = λ₀ and η = 0
39+
// To get the solution, we can simply do the following steps:
40+
// 1. If η₀ > sqrt(w_λ/w_η) * λ₀, then λ = 0, else η = 0
41+
// 2. [λ, η] = max(0, [λ, η])
4742
class C3Plus final : public C3Base {
4843
public:
4944
C3Plus(const LCS& LCS, const CostMatrices& costs,
@@ -74,4 +69,4 @@ class C3Plus final : public C3Base {
7469
};
7570

7671
} // namespace solvers
77-
} // namespace dairlib
72+
} // namespace dairlib

0 commit comments

Comments
 (0)