Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/factor removal #2

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
26 changes: 20 additions & 6 deletions include/dcsam/DCSAM.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,24 @@ class DCSAM {
* factors to add.
* @param dcfg - a DCFactorGraph containing any joint discrete-continuous
* factors to add.
* @param initialGuess - an initial guess for any new continuous keys that.
* appear in the updated factors (or if one wants to force override previously
* obtained continuous values).
* @param initialGuessContinuous - an initial guess for any new continuous
* keys that appear in the updated factors (or if one wants to force override
* previously obtained continuous values).
* @param initialGuessDiscrete - an initial guess for any new discrete keys
* that appear in the updated factors (or if one wants to force override
* previously obtained discrete values).
* @param removeFactorIndices - indices of continuous factors to remove
* @param removeDiscreteFactorIndices - indices of discrete factors to remove

*/
void update(const gtsam::NonlinearFactorGraph &graph,
const gtsam::DiscreteFactorGraph &dfg, const DCFactorGraph &dcfg,
const gtsam::Values &initialGuessContinuous = gtsam::Values(),
const DiscreteValues &initialGuessDiscrete = DiscreteValues());
const DiscreteValues &initialGuessDiscrete = DiscreteValues(),
const gtsam::FactorIndices &removeFactorIndices =
gtsam::FactorIndices(),
const std::vector<size_t> &removeDiscreteFactorIndices =
std::vector<size_t>());
kurransingh marked this conversation as resolved.
Show resolved Hide resolved

/**
* A HybridFactorGraph is a container holding a NonlinearFactorGraph, a
Expand All @@ -86,11 +96,15 @@ class DCSAM {
* parameters: that is:
*
* update(hfg.nonlinearGraph(), hfg.discreteGraph(), hfg.dcGraph(),
* initialGuess);
* initialGuess, removeFactorIndices, removeDiscreteFactorIndices);
*/
void update(const HybridFactorGraph &hfg,
const gtsam::Values &initialGuessContinuous = gtsam::Values(),
const DiscreteValues &initialGuessDiscrete = DiscreteValues());
const DiscreteValues &initialGuessDiscrete = DiscreteValues(),
const gtsam::FactorIndices &removeFactorIndices =
gtsam::FactorIndices(),
const std::vector<size_t> &removeDiscreteFactorIndices =
std::vector<size_t>());

/**
* Inline convenience function to allow "skipping" the initial guess for
Expand Down
23 changes: 18 additions & 5 deletions src/DCSAM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,18 @@ void DCSAM::update(const gtsam::NonlinearFactorGraph &graph,
const gtsam::DiscreteFactorGraph &dfg,
const DCFactorGraph &dcfg,
const gtsam::Values &initialGuessContinuous,
const DiscreteValues &initialGuessDiscrete) {
// First things first: combine currContinuous_ estimate with the new values
const DiscreteValues &initialGuessDiscrete,
const gtsam::FactorIndices &removeFactorIndices,
kurransingh marked this conversation as resolved.
Show resolved Hide resolved
const std::vector<size_t> &removeDiscreteFactorIndices) {

// First things first: get rid of factors that are to be removed so updates
// to follow take the removals into account
isam_.update(gtsam::NonlinearFactorGraph(), gtsam::Values(), removeFactorIndices);
kurransingh marked this conversation as resolved.
Show resolved Hide resolved
for (auto& i : removeDiscreteFactorIndices) {
dfg_.remove(i);
}

// Next: combine currContinuous_ estimate with the new values
// from initialGuessContinuous to produce the full continuous variable state.
kurransingh marked this conversation as resolved.
Show resolved Hide resolved
for (const gtsam::Key k : initialGuessContinuous.keys()) {
if (currContinuous_.exists(k))
Expand Down Expand Up @@ -90,9 +100,12 @@ void DCSAM::update(const gtsam::NonlinearFactorGraph &graph,

void DCSAM::update(const HybridFactorGraph &hfg,
const gtsam::Values &initialGuessContinuous,
const DiscreteValues &initialGuessDiscrete) {
update(hfg.nonlinearGraph(), hfg.discreteGraph(), hfg.dcGraph(),
initialGuessContinuous, initialGuessDiscrete);
const DiscreteValues &initialGuessDiscrete,
const gtsam::FactorIndices &removeFactorIndices,
const std::vector<size_t> &removeDiscreteFactorIndices) {
update(hfg.nonlinearGraph(), hfg.discreteGraph(), hfg.dcGraph(),
initialGuessContinuous, initialGuessDiscrete,
removeFactorIndices, removeDiscreteFactorIndices);
}

void DCSAM::update() {
Expand Down
206 changes: 206 additions & 0 deletions tests/testDCSAM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,212 @@ TEST(TestSuite, dcMaxMixture_semantic_slam) {
EXPECT_EQ(mpeClassL1, 1);
}

/**
* This is for testing the behavior of factor removal
*/
TEST(TestSuite, factor_removal) {
kurransingh marked this conversation as resolved.
Show resolved Hide resolved
// Make a factor graph
HybridFactorGraph hfg;

// Values for initial guess
gtsam::Values initialGuess;
DiscreteValues initialGuessDiscrete;

gtsam::Symbol x0('x', 0);
gtsam::Symbol l1('l', 1);
gtsam::Symbol lc1('c', 1);
// Create a discrete key for landmark 1 class with cardinality 2.
gtsam::DiscreteKey lm1_class(lc1, 2);
gtsam::Pose2 pose0(0, 0, 0);
gtsam::Pose2 dx(1, 0, 0.78539816);
double prior_sigma = 0.1;
double meas_sigma = 1.0;
double circumradius = (std::sqrt(4 + 2 * std::sqrt(2))) / 2.0;
gtsam::Point2 landmark1(circumradius, circumradius);

gtsam::noiseModel::Isotropic::shared_ptr prior_noise =
gtsam::noiseModel::Isotropic::Sigma(3, prior_sigma);
gtsam::noiseModel::Isotropic::shared_ptr prior_lm_noise =
gtsam::noiseModel::Isotropic::Sigma(2, prior_sigma);
gtsam::noiseModel::Isotropic::shared_ptr meas_noise =
gtsam::noiseModel::Isotropic::Sigma(3, meas_sigma);

// 0.1 rad std on bearing, 10cm on range
gtsam::noiseModel::Isotropic::shared_ptr br_noise =
gtsam::noiseModel::Isotropic::Sigma(2, 0.1);

std::vector<double> prior_lm1_class;
prior_lm1_class.push_back(0.9);
prior_lm1_class.push_back(0.1);

gtsam::PriorFactor<gtsam::Pose2> p0(x0, pose0, prior_noise);
gtsam::PriorFactor<gtsam::Point2> pl1(l1, landmark1, prior_lm_noise);
DiscretePriorFactor plc1(lm1_class, prior_lm1_class);

initialGuess.insert(x0, pose0);
initialGuess.insert(l1, landmark1);
initialGuessDiscrete[lm1_class.first] = 0;

hfg.push_nonlinear(p0);
hfg.push_nonlinear(pl1);
hfg.push_discrete(plc1);

// set up for landmark 2
gtsam::Symbol l2('l', 2);
gtsam::Symbol lc2('c', 2);
// Create a discrete key for landmark 2 class with cardinality 2.
gtsam::DiscreteKey lm2_class(lc2, 2);
gtsam::Point2 landmark2(circumradius + .5, circumradius + 5);

std::vector<double> prior_lm2_class;
prior_lm2_class.push_back(0.1);
prior_lm2_class.push_back(0.9);

gtsam::PriorFactor<gtsam::Point2> pl2(l2, landmark2, prior_lm_noise);
DiscretePriorFactor plc2(lm2_class, prior_lm2_class);

initialGuess.insert(l2, landmark2);
initialGuessDiscrete[lm2_class.first] = 1;

hfg.push_nonlinear(pl2);
hfg.push_discrete(plc2);

// Setup dcsam
DCSAM dcsam;
dcsam.update(hfg, initialGuess, initialGuessDiscrete);

DCValues dcval_start = dcsam.calculateEstimate();
std::cout << "Printing first values" << std::endl;
dcval_start.discrete.print();

hfg.clear();
initialGuess.clear();
initialGuessDiscrete.clear();

gtsam::Pose2 odom(pose0);
gtsam::Pose2 noise(0.01, 0.01, 0.01);
for (size_t i = 0; i < 7; i++) {
gtsam::Symbol xi('x', i);
gtsam::Symbol xj('x', i + 1);

gtsam::Pose2 meas = dx * noise;

gtsam::BetweenFactor<gtsam::Pose2> bw(xi, xj, meas, meas_noise);
hfg.push_nonlinear(bw);

// Add semantic bearing-range measurement to landmark in center
gtsam::Rot2 bearing1 = gtsam::Rot2::fromDegrees(67.5);
double range1 = circumradius;

// For the first couple measurements, pick class=0, later pick class=1
std::vector<double> semantic_meas;
if (i < 2) {
semantic_meas.push_back(0.9);
semantic_meas.push_back(0.1);
} else {
semantic_meas.push_back(0.1);
semantic_meas.push_back(0.9);
}

gtsam::DiscreteKeys dks({lm1_class, lm2_class});

// build mixture: dcmaxmixture should be picking the component for lm1
SemanticBearingRangeFactor<gtsam::Pose2, gtsam::Point2> sbr1(
xi, l1, lm1_class, semantic_meas, bearing1, range1, br_noise);
SemanticBearingRangeFactor<gtsam::Pose2, gtsam::Point2> sbr2(
xi, l2, lm2_class, semantic_meas, bearing1, range1, br_noise);
DCMaxMixtureFactor<SemanticBearingRangeFactor<gtsam::Pose2,
gtsam::Point2>> dcmmf(
{xi, l1, l2}, dks, {sbr1, sbr2}, {.5, .5}, false);

hfg.push_dc(dcmmf);
odom = odom * meas;
initialGuess.insert(xj, odom);
dcsam.update(hfg, initialGuess);
DCValues dcvals = dcsam.calculateEstimate();

size_t mpeClassL1 = dcvals.discrete.at(lc1);

// Plot poses and landmarks
#ifdef ENABLE_PLOTTING
std::vector<double> xs, ys;
for (size_t j = 0; j < i + 2; j++) {
xs.push_back(
dcvals.continuous.at<gtsam::Pose2>(gtsam::Symbol('x', j)).x());
ys.push_back(
dcvals.continuous.at<gtsam::Pose2>(gtsam::Symbol('x', j)).y());
}

std::vector<double> lmxs, lmys;
lmxs.push_back(
dcvals.continuous.at<gtsam::Point2>(gtsam::Symbol('l', 1)).x());
lmys.push_back(
dcvals.continuous.at<gtsam::Point2>(gtsam::Symbol('l', 1)).y());

string color = (mpeClassL1 == 0) ? "b" : "orange";

plt::plot(xs, ys);
plt::scatter(lmxs, lmys, {{"color", color}});
plt::show();
#endif

hfg.clear();
initialGuess.clear();
}

gtsam::Symbol x7('x', 7);
gtsam::BetweenFactor<gtsam::Pose2> bw(x0, x7, dx * noise, meas_noise);

hfg.push_nonlinear(bw);
dcsam.update(hfg, initialGuess);

DCValues dcvals = dcsam.calculateEstimate();

size_t mpeClassL1 = dcvals.discrete.at(lc1);

// Plot the poses and landmarks
#ifdef ENABLE_PLOTTING
std::vector<double> xs, ys;
for (size_t i = 0; i < 8; i++) {
xs.push_back(dcvals.continuous.at<gtsam::Pose2>(gtsam::Symbol('x', i)).x());
ys.push_back(dcvals.continuous.at<gtsam::Pose2>(gtsam::Symbol('x', i)).y());
}

std::vector<double> lmxs, lmys;
lmxs.push_back(
dcvals.continuous.at<gtsam::Point2>(gtsam::Symbol('l', 1)).x());
lmys.push_back(
dcvals.continuous.at<gtsam::Point2>(gtsam::Symbol('l', 1)).y());

string color = (mpeClassL1 == 0) ? "b" : "orange";

plt::plot(xs, ys);
plt::scatter(lmxs, lmys, {{"color", color}});
plt::show();
#endif

EXPECT_EQ(mpeClassL1, 1);


// So far, this same as earlier example. Now let's start removing factors
// and see what happens.
EXPECT_EQ(dcsam.getDiscreteFactorGraph().size(), 18);
EXPECT_EQ(dcsam.getNonlinearFactorGraph().size(), 18);

hfg.clear();
initialGuess.clear();
initialGuessDiscrete.clear();
std::vector<size_t> discreteRemovals{17};
gtsam::FactorIndices removals{17};

dcsam.update(hfg, initialGuess, initialGuessDiscrete, removals, discreteRemovals);

EXPECT_EQ(dcsam.getDiscreteFactorGraph().at(17), nullptr);
EXPECT_EQ(dcsam.getNonlinearFactorGraph().at(17), nullptr);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test doesn't validate the new estimate after the factor removal.

Suggestion: Try to solve a (simple) factor graph with an known solution (that we can verify to within tol). Then, remove one or more factors, verify with EXPECT_EQ that they've been removed, retrieve the new estimate via dcsam_.calculateEstimate() and validate that the new estimate is correct (to within tol).

There are also a couple "edge cases" we might want to check on, e.g. what happens if we have a discrete variable with a single discrete factor attached to it, and then we remove that factor? I'm not actually sure what the behavior would be in this case, but it seems like it would be good to know.

Copy link
Contributor Author

@kurransingh kurransingh Sep 7, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if we have a discrete variable with a single discrete factor attached to it, and then we remove that factor?

The variable will be removed (I believe using the gtsam::VariableIndex that is computed), and any subsequent attempts to retrieve information about the variable will throw errors.

}



int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
Expand Down