--- /dev/null
+load "matrix.5c"
+
+/*
+ * Copyright © 2011 Keith Packard <keithp@keithp.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; version 2 of the License.
+ *
+ * This program is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program; if not, write to the Free Software Foundation, Inc.,
+ * 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
+ */
+
+namespace kalman {
+
+ import matrix;
+
+ public typedef struct {
+ vec_t x; /* state */
+ mat_t p; /* error estimate */
+ mat_t k; /* kalman factor */
+ } state_t;
+
+ public typedef struct {
+ mat_t a; /* model */
+ mat_t q; /* model error covariance */
+ mat_t r; /* measurement error covariance */
+ mat_t h; /* measurement from model */
+ } parameters_t;
+
+ public typedef struct {
+ mat_t a; /* model */
+ mat_t k; /* kalman coefficient */
+ mat_t h; /* measurement from model */
+ } parameters_fast_t;
+
+ vec_t measurement_from_state(vec_t x, mat_t h) {
+ return multiply_mat_vec(h, x);
+ }
+
+
+ void print_state(string name, state_t s) {
+ print_vec(sprintf("%s state", name), s.x);
+ print_mat(sprintf("%s error", name), s.p);
+ }
+
+ public bool debug = false;
+
+ public state_t predict (state_t s, parameters_t p) {
+ state_t n;
+
+ if (debug) {
+ printf ("--------PREDICT--------\n");
+ print_state("current", s);
+ }
+
+ /* Predict state
+ *
+ * x': predicted state
+ * a: model
+ * x: previous state
+ *
+ * x' = a * x;
+ */
+
+ n.x = multiply_mat_vec(p.a, s.x);
+
+ /* t0 = a * p */
+ mat_t t0 = multiply (p.a, s.p);
+ if (debug)
+ print_mat("t0", t0);
+
+ /* t1 = a * p * transpose(a) */
+
+ mat_t t1 = multiply (t0, transpose(p.a));
+
+ /* Predict error
+ *
+ * p': predicted error
+ * a: model
+ * p: previous error
+ * q: model error
+ *
+ * p' = a * p * transpose(a) + q
+ */
+
+ n.p = add(t1, p.q);
+ if (debug)
+ print_state("predict", n);
+ return n;
+ }
+
+ public vec_t predict_fast(vec_t x, parameters_fast_t p) {
+ if (debug) {
+ printf ("--------FAST PREDICT--------\n");
+ print_vec("current", x);
+ }
+ vec_t new = multiply_mat_vec(p.a, x);
+ if (debug)
+ print_vec("predict", new);
+ return new;
+ }
+
+ public vec_t correct_fast(vec_t x, vec_t z, parameters_fast_t p) {
+ if (debug) {
+ printf ("--------FAST CORRECT--------\n");
+ print_vec("measure", z);
+ print_vec("current", x);
+ }
+ vec_t model = multiply_mat_vec(p.h, x);
+ if (debug)
+ print_vec("extract model", model);
+ vec_t diff = vec_subtract(z, model);
+ if (debug)
+ print_vec("difference", diff);
+ vec_t adjust = multiply_mat_vec(p.k, diff);
+ if (debug)
+ print_vec("adjust", adjust);
+
+ vec_t new = vec_add(x,
+ multiply_mat_vec(p.k,
+ vec_subtract(z,
+ multiply_mat_vec(p.h, x))));
+ if (debug)
+ print_vec("correct", new);
+ return new;
+ }
+
+ public state_t correct(state_t s, vec_t z, parameters_t p) {
+ state_t n;
+
+ if (debug) {
+ printf ("--------CORRECT--------\n");
+ print_vec("measure", z);
+ print_state("current", s);
+ }
+
+ /* t0 = p * T(h) */
+
+ /* 3x2 = 3x3 * 3x2 */
+ mat_t t0 = multiply(s.p, transpose(p.h));
+ if (debug)
+ print_mat("t0", t0);
+
+ /* t1 = h * p */
+
+ /* 2x3 = 2x3 * 3x3 */
+ mat_t t1 = multiply(p.h, s.p);
+ if (debug)
+ print_mat("t1", t1);
+
+ /* t2 = h * p * transpose(h) */
+
+ /* 2x2 = 2x3 * 3x2 */
+ mat_t t2 = multiply(t1, transpose(p.h));
+ if (debug)
+ print_mat("t2", t2);
+
+ /* t3 = h * p * transpose(h) + r */
+
+ /* 2x2 = 2x2 + 2x2 */
+ mat_t t3 = add(t2, p.r);
+ if (debug)
+ print_mat("t3", t3);
+
+ /* t4 = inverse(h * p * transpose(h) + r) */
+
+ /* 2x2 = 2x2 */
+ mat_t t4 = inverse(t3);
+ if (debug)
+ print_mat("t4", t4);
+
+ /* Kalman value */
+
+ /* k: Kalman value
+ * p: error estimate
+ * h: state to measurement matrix
+ * r: measurement error covariance
+ *
+ * k = p * transpose(h) * inverse(h * p * transpose(h) + r)
+ *
+ * k = K(p)
+ */
+
+ /* 3x2 = 3x2 * 2x2 */
+ mat_t k = multiply(t0, t4);
+ if (debug)
+ print_mat("k", k);
+ n.k = k;
+
+ /* t5 = h * x */
+
+ /* 2 = 2x3 * 3 */
+ vec_t t5 = multiply_mat_vec(p.h, s.x);
+ if (debug)
+ print_vec("t5", t5);
+
+ /* t6 = z - h * x */
+
+ /* 2 = 2 - 2 */
+ vec_t t6 = vec_subtract(z, t5);
+ if (debug)
+ print_vec("t6", t6);
+
+ /* t7 = k * (z - h * x) */
+
+ /* 3 = 3x2 * 2 */
+ vec_t t7 = multiply_mat_vec(k, t6);
+ if (debug)
+ print_vec("t7", t7);
+
+ /* Correct state
+ *
+ * x: predicted state
+ * k: kalman value
+ * z: measurement
+ * h: state to measurement matrix
+ * x': corrected state
+ *
+ * x' = x + k * (z - h * x)
+ */
+
+ n.x = vec_add(s.x, t7);
+ if (debug)
+ print_vec("n->x", n.x);
+
+ /* t8 = k * h */
+
+ /* 3x3 = 3x2 * 2x3 */
+ mat_t t8 = multiply(k, p.h);
+ if (debug)
+ print_mat("t8", t8);
+
+ /* t9 = 1 - k * h */
+
+ /* 3x3 = 3x3 - 3x3 */
+ mat_t t9 = subtract(identity(dim(s.x)), t8);
+ if (debug)
+ print_mat("t9", t9);
+
+ /* Correct error
+ *
+ * p: predicted error
+ * k: kalman value
+ * h: state to measurement matrix
+ * p': corrected error
+ *
+ * p' = (1 - k * h) * p
+ *
+ * p' = P(k,p)
+ */
+
+ /* 3x3 = 3x3 * 3x3 */
+ n.p = multiply(t9, s.p);
+ if (debug) {
+ print_mat("n->p", n.p);
+# print_state("correct", n);
+ }
+ return n;
+ }
+
+ real distance(mat_t a, mat_t b) {
+ int[2] d = dims(a);
+ int i_max = d[0];
+ int j_max = d[1];
+ real s = 0;
+
+ for (int i = 0; i < i_max; i++)
+ for (int j = 0; j < j_max; j++)
+ s += (a[i,j] - b[i,j]) ** 2;
+ return sqrt(s);
+ }
+
+ public mat_t converge(parameters_t p) {
+ int model = dims(p.a)[0];
+ int measure = dims(p.r)[0];
+ int reps = 0;
+ state_t s = {
+ .x = (real[model]) { 0 ... },
+ .p = (real[model,model]) { { 0 ... } ... },
+ .k = (real[model,measure]) { { 0 ... } ... }
+ };
+
+ vec_t z = (real [measure]) { 0 ... };
+ for (;;) {
+ state_t s_pre = predict(s, p);
+ state_t s_post = correct(s_pre, z, p);
+ real d = distance(s.k, s_post.k);
+ s = s_post;
+ reps++;
+ if (d < 1e-10 && reps > 10)
+ break;
+ }
+ return s.k;
+ }
+
+ public parameters_fast_t convert_to_fast(parameters_t p) {
+ return (parameters_fast_t) {
+ .a = p.a, .k = converge(p), .h = p.h
+ };
+ }
+}