-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #330 from hitonanode/kalman-filter
Multivariate Gaussian / Kalman filter
- Loading branch information
Showing
2 changed files
with
126 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
#ifndef MULTIVARIATE_GAUSSIAN_HPP | ||
#define MULTIVARIATE_GAUSSIAN_HPP | ||
|
||
#include <cassert> | ||
#include <vector> | ||
|
||
// #include "linear_algebra_matrix/matrix.hpp" | ||
|
||
// Multivariate Gausssian distribution / Kalman filter | ||
// 多変量正規分布の数値計算・カルマンフィルタ | ||
template <class Matrix> struct MultivariateGaussian { | ||
|
||
// 正規分布 N(x, P) | ||
std::vector<double> x; // 期待値 | ||
Matrix P; // 分散共分散行列 | ||
|
||
void set(const std::vector<double> &x0, const Matrix &P0) { | ||
const int dim = x0.size(); | ||
assert(P0.height() == dim and P0.width() == dim); | ||
|
||
x = x0; | ||
P = P0; | ||
} | ||
|
||
// 加算 | ||
// すなわち x <- x + dx | ||
void shift(const std::vector<double> &dx) { | ||
const int n = x.size(); | ||
assert(dx.size() == n); | ||
|
||
for (int i = 0; i < n; ++i) x.at(i) += dx.at(i); | ||
} | ||
|
||
// F: 状態遷移行列 正方行列を想定 | ||
// すなわち x <- Fx | ||
void linear_transform(const Matrix &F) { | ||
x = F * x; | ||
P = F * P * F.transpose(); | ||
} | ||
|
||
// Q: ゼロ平均ガウシアンノイズの分散共分散行列 | ||
// すなわち x <- x + w, w ~ N(0, Q) | ||
void add_noise(const Matrix &Q) { P = P + Q; } | ||
|
||
// 現在の x の分布を P(x | *) として、条件付き確率 P(x | *, z) で更新する | ||
// H: 観測行列, R: 観測ノイズの分散共分散行列, z: 観測値 | ||
// すなわち z = Hx + v, v ~ N(0, R) | ||
void measure(const Matrix &H, const Matrix &R, const std::vector<double> &z, | ||
double regularlize = 1e-9) { | ||
const int nobs = z.size(); | ||
|
||
// 残差 e = z - Hx | ||
const auto Hx = H * x; | ||
std::vector<double> e(nobs); | ||
for (int i = 0; i < nobs; ++i) e.at(i) = z.at(i) - Hx.at(i); | ||
|
||
// 残差共分散 S = R + H P H^T | ||
Matrix Sinv = R + H * P * H.transpose(); | ||
Sinv = Sinv + Matrix::Identity(nobs) * regularlize; // 不安定かも? | ||
Sinv.inverse(); | ||
|
||
// カルマンゲイン K = P H^T S^-1 | ||
Matrix K = P * H.transpose() * Sinv; | ||
|
||
// Update x | ||
const auto dx = K * e; | ||
for (int i = 0; i < (int)x.size(); ++i) x.at(i) += dx.at(i); | ||
|
||
P = P - K * H * P; | ||
} | ||
}; | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
--- | ||
title: Multivariate Gaussian Distribution, Kalman filter / 多変量正規分布・カルマンフィルタ | ||
documentation_of: ./multivariate_gaussian.hpp | ||
--- | ||
|
||
多変量正規分布のパラメータを管理するクラス.線形変換・ノイズの加算・観測による事後確率の更新が行える.カルマンフィルタの実装に利用可能. | ||
|
||
## 使用方法 | ||
|
||
線形システムのカルマンフィルタの実装例を以下に示す. | ||
|
||
```cpp | ||
#include "linear_algebra_matrix/matrix.hpp" | ||
|
||
// 初期化 | ||
MultivariateGaussian<matrix<double>> kf; | ||
vector<double> mu(dim); | ||
matrix<double> Sigma(dim, dim); | ||
kf.set(mu, Sigma); // N(mu, Sigma) で初期化 | ||
|
||
// 以下の「時間発展」「雑音の付与」「制御信号の注入」「推定」を任意の順序で任意の回数行ってよい。 | ||
|
||
// 時間発展 | ||
matrix<double> F(dim, dim); // 時間発展行列 | ||
kf.linear_transform(F); | ||
|
||
// 雑音の付与 | ||
matrix<double> Q(dim, dim); // 正規雑音の分散・共分散行列 | ||
kf.add_noise(Q); | ||
|
||
// 制御信号の注入 | ||
vector<double> u(dim); // 制御入力 | ||
kf.shift(u); | ||
|
||
// 観測 | ||
matrix<double> H(o, dim); // 観測行列 | ||
matrix<double> R(o, o); // 観測に重畳される正規雑音の分散・共分散行列 | ||
vector<double> z(o); // 観測行列による観測結果 | ||
double regularize = 1e-9; // 逆行列数値計算の安定のためのパラメータ | ||
kf.measure(H, R, z, regularize); | ||
|
||
// 推定 | ||
vector<double> est = kf.x; | ||
``` | ||
- 現在の MAP 推定量が知りたい -> mu を見ればよい | ||
- 周辺分布が欲しい -> mu と Sigma のうち特定の次元だけ取り出せばよい | ||
- 一部の次元だけ観測できた場合の残りの次元の条件付き分布が欲しい → 未実装です | ||
- サンプリングしたい → 未実装です | ||
## 問題例 | ||
- [第一回マスターズ選手権 -決勝- A - Windy Drone Control (A)](https://atcoder.jp/contests/masters2024-final/tasks/masters2024_final_a) |