4 * Copyright © 2011 Keith Packard <keithp@keithp.com>
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; either version 2 of the License, or
9 * (at your option) any later version.
11 * This program is distributed in the hope that it will be useful, but
12 * WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 * General Public License for more details.
16 * You should have received a copy of the GNU General Public License along
17 * with this program; if not, write to the Free Software Foundation, Inc.,
18 * 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
25 public typedef struct {
27 mat_t p; /* error estimate */
28 mat_t k; /* kalman factor */
31 public typedef struct {
33 mat_t q; /* model error covariance */
34 mat_t r; /* measurement error covariance */
35 mat_t h; /* measurement from model */
38 public typedef struct {
40 mat_t k; /* kalman coefficient */
41 mat_t h; /* measurement from model */
44 vec_t measurement_from_state(vec_t x, mat_t h) {
45 return multiply_mat_vec(h, x);
49 void print_state(string name, state_t s) {
50 print_vec(sprintf("%s state", name), s.x);
51 print_mat(sprintf("%s error", name), s.p);
54 public bool debug = false;
56 public state_t predict (state_t s, parameters_t p) {
60 printf ("--------PREDICT--------\n");
61 print_state("current", s);
73 n.x = multiply_mat_vec(p.a, s.x);
76 mat_t t0 = multiply (p.a, s.p);
80 /* t1 = a * p * transpose(a) */
82 mat_t t1 = multiply (t0, transpose(p.a));
91 * p' = a * p * transpose(a) + q
96 print_state("predict", n);
100 public vec_t predict_fast(vec_t x, parameters_fast_t p) {
102 printf ("--------FAST PREDICT--------\n");
103 print_vec("current", x);
105 vec_t new = multiply_mat_vec(p.a, x);
107 print_vec("predict", new);
111 public vec_t correct_fast(vec_t x, vec_t z, parameters_fast_t p) {
113 printf ("--------FAST CORRECT--------\n");
114 print_vec("measure", z);
115 print_vec("current", x);
117 vec_t model = multiply_mat_vec(p.h, x);
119 print_vec("extract model", model);
120 vec_t diff = vec_subtract(z, model);
122 print_vec("difference", diff);
123 vec_t adjust = multiply_mat_vec(p.k, diff);
125 print_vec("adjust", adjust);
127 vec_t new = vec_add(x,
128 multiply_mat_vec(p.k,
130 multiply_mat_vec(p.h, x))));
132 print_vec("correct", new);
136 public state_t correct(state_t s, vec_t z, parameters_t p) {
140 printf ("--------CORRECT--------\n");
141 print_vec("measure", z);
142 print_state("current", s);
147 /* 3x2 = 3x3 * 3x2 */
148 mat_t t0 = multiply(s.p, transpose(p.h));
154 /* 2x3 = 2x3 * 3x3 */
155 mat_t t1 = multiply(p.h, s.p);
159 /* t2 = h * p * transpose(h) */
161 /* 2x2 = 2x3 * 3x2 */
162 mat_t t2 = multiply(t1, transpose(p.h));
166 /* t3 = h * p * transpose(h) + r */
168 /* 2x2 = 2x2 + 2x2 */
169 mat_t t3 = add(t2, p.r);
173 /* t4 = inverse(h * p * transpose(h) + r) */
176 mat_t t4 = inverse(t3);
184 * h: state to measurement matrix
185 * r: measurement error covariance
187 * k = p * transpose(h) * inverse(h * p * transpose(h) + r)
192 /* 3x2 = 3x2 * 2x2 */
193 mat_t k = multiply(t0, t4);
201 vec_t t5 = multiply_mat_vec(p.h, s.x);
208 vec_t t6 = vec_subtract(z, t5);
212 /* t7 = k * (z - h * x) */
215 vec_t t7 = multiply_mat_vec(k, t6);
224 * h: state to measurement matrix
225 * x': corrected state
227 * x' = x + k * (z - h * x)
230 n.x = vec_add(s.x, t7);
232 print_vec("n->x", n.x);
236 /* 3x3 = 3x2 * 2x3 */
237 mat_t t8 = multiply(k, p.h);
243 /* 3x3 = 3x3 - 3x3 */
244 mat_t t9 = subtract(identity(dim(s.x)), t8);
252 * h: state to measurement matrix
253 * p': corrected error
255 * p' = (1 - k * h) * p
260 /* 3x3 = 3x3 * 3x3 */
261 n.p = multiply(t9, s.p);
263 print_mat("n->p", n.p);
264 # print_state("correct", n);
269 real distance(mat_t a, mat_t b) {
275 for (int i = 0; i < i_max; i++)
276 for (int j = 0; j < j_max; j++)
277 s += (a[i,j] - b[i,j]) ** 2;
281 public mat_t converge(parameters_t p) {
282 int model = dims(p.a)[0];
283 int measure = dims(p.r)[0];
286 .x = (real[model]) { 0 ... },
287 .p = (real[model,model]) { { 0 ... } ... },
288 .k = (real[model,measure]) { { 0 ... } ... }
291 vec_t z = (real [measure]) { 0 ... };
293 state_t s_pre = predict(s, p);
294 state_t s_post = correct(s_pre, z, p);
295 real d = distance(s.k, s_post.k);
298 if (d < 1e-10 && reps > 10)
304 public parameters_fast_t convert_to_fast(parameters_t p) {
305 return (parameters_fast_t) {
306 .a = p.a, .k = converge(p), .h = p.h