4 * Copyright © 2011 Keith Packard <keithp@keithp.com>
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; version 2 of the License.
10 * This program is distributed in the hope that it will be useful, but
11 * WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13 * General Public License for more details.
15 * You should have received a copy of the GNU General Public License along
16 * with this program; if not, write to the Free Software Foundation, Inc.,
17 * 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
/*
 * Filter state carried between predict() and correct().
 * Visible fields: p (error estimate) and k (kalman gain).
 * NOTE(review): other field(s) and the typedef name are elided from this
 * excerpt; code elsewhere builds a state_t with .x, .p and .k, so this is
 * presumably state_t with an x (state vector) field — confirm.
 */
24 public typedef struct {
26 mat_t p; /* error estimate */
27 mat_t k; /* kalman factor */
/*
 * Model parameters for the full (covariance-tracking) filter.
 * Visible fields: q, r, h.
 * NOTE(review): predict() also reads p.a (state transition matrix), so an
 * `a` field is presumably elided here, and this is presumably
 * parameters_t — confirm against the full source.
 */
30 public typedef struct {
32 mat_t q; /* model error covariance */
33 mat_t r; /* measurement error covariance */
34 mat_t h; /* measurement from model */
/*
 * Parameters for the fast path (predict_fast/correct_fast).
 * A fixed, precomputed gain k is used directly instead of being derived
 * from a tracked error covariance on each step.
 * NOTE(review): predict_fast() reads p.a and convert_to_fast() sets .a, so
 * an `a` field is presumably elided here, and this is presumably
 * parameters_fast_t — confirm.
 */
37 public typedef struct {
39 mat_t k; /* kalman coefficient */
40 mat_t h; /* measurement from model */
/*
 * Project a state vector into measurement space: returns h * x.
 * NOTE(review): correct() and correct_fast() compute h * x inline rather
 * than calling this helper.
 */
43 vec_t measurement_from_state(vec_t x, mat_t h) {
44 return multiply_mat_vec(h, x);
/*
 * Debug helper: prints the state vector s.x labelled "<name> state" and
 * the error estimate s.p labelled "<name> error".
 */
48 void print_state(string name, state_t s) {
49 print_vec(sprintf("%s state", name), s.x);
50 print_mat(sprintf("%s error", name), s.p);
/*
 * Trace flag; false by default. It presumably enables the
 * "--------...--------" trace output in predict/correct and their fast
 * variants. The guarding `if (debug)` lines are not visible in this
 * excerpt — confirm.
 */
53 public bool debug = false;
/*
 * Kalman time update (prediction) for the full filter:
 *
 *   x' = a * x
 *   p' = a * p * transpose(a) + q
 *
 * s: current state; p: model parameters.
 * Returns the predicted state.
 * NOTE(review): the `+ q` step and the return statement are on lines
 * elided from this excerpt.
 */
55 public state_t predict (state_t s, parameters_t p) {
59 printf ("--------PREDICT--------\n");
60 print_state("current", s);
72 n.x = multiply_mat_vec(p.a, s.x);
75 mat_t t0 = multiply (p.a, s.p);
79 /* t1 = a * p * transpose(a) */
81 mat_t t1 = multiply (t0, transpose(p.a));
90 * p' = a * p * transpose(a) + q
95 print_state("predict", n);
/*
 * Fast-path prediction: returns a * x.
 * Only the state vector is propagated. No error estimate is carried,
 * because the fast path relies on the precomputed gain in
 * parameters_fast_t.
 */
99 public vec_t predict_fast(vec_t x, parameters_fast_t p) {
101 printf ("--------FAST PREDICT--------\n");
102 print_vec("current", x);
104 vec_t new = multiply_mat_vec(p.a, x);
106 print_vec("predict", new);
/*
 * Fast-path measurement update using the fixed gain p.k, presumably
 *
 *   x' = x + k * (z - h * x)
 *
 * (one line of the final expression is elided in this excerpt).
 * x: current state; z: measurement. Returns the corrected state vector.
 *
 * NOTE(review): model, diff and adjust are passed to print_vec for the
 * trace, but the final expression appears to recompute h * x and the
 * gain product instead of reusing `adjust`. Consider `vec_add(x, adjust)`.
 */
110 public vec_t correct_fast(vec_t x, vec_t z, parameters_fast_t p) {
112 printf ("--------FAST CORRECT--------\n");
113 print_vec("measure", z);
114 print_vec("current", x);
116 vec_t model = multiply_mat_vec(p.h, x);
118 print_vec("extract model", model);
119 vec_t diff = vec_subtract(z, model);
121 print_vec("difference", diff);
122 vec_t adjust = multiply_mat_vec(p.k, diff);
124 print_vec("adjust", adjust);
126 vec_t new = vec_add(x,
127 multiply_mat_vec(p.k,
129 multiply_mat_vec(p.h, x))));
131 print_vec("correct", new);
/*
 * Kalman measurement update for the full filter:
 *
 *   k  = p * transpose(h) * inverse(h * p * transpose(h) + r)
 *   x' = x + k * (z - h * x)
 *   p' = (I - k * h) * p
 *
 * s: predicted state; z: measurement; p: model parameters.
 * Returns the corrected state.
 * NOTE(review): the return statement, and presumably the store of k into
 * the result, are on elided lines.
 *
 * The gain k depends only on s.p, p.h and p.r, not on z. converge()
 * relies on this by iterating with a zero measurement.
 *
 * The size comments (3x3, 3x2, 2x2, ...) describe an example model with a
 * 3-element state and a 2-element measurement. The visible code does not
 * hard-code those sizes; the identity is built from dim(s.x).
 */
135 public state_t correct(state_t s, vec_t z, parameters_t p) {
139 printf ("--------CORRECT--------\n");
140 print_vec("measure", z);
141 print_state("current", s);
146 /* 3x2 = 3x3 * 3x2 */
147 mat_t t0 = multiply(s.p, transpose(p.h));
153 /* 2x3 = 2x3 * 3x3 */
154 mat_t t1 = multiply(p.h, s.p);
158 /* t2 = h * p * transpose(h) */
160 /* 2x2 = 2x3 * 3x2 */
161 mat_t t2 = multiply(t1, transpose(p.h));
165 /* t3 = h * p * transpose(h) + r */
167 /* 2x2 = 2x2 + 2x2 */
168 mat_t t3 = add(t2, p.r);
172 /* t4 = inverse(h * p * transpose(h) + r) */
175 mat_t t4 = inverse(t3);
183 * h: state to measurement matrix
184 * r: measurement error covariance
186 * k = p * transpose(h) * inverse(h * p * transpose(h) + r)
191 /* 3x2 = 3x2 * 2x2 */
192 mat_t k = multiply(t0, t4);
200 vec_t t5 = multiply_mat_vec(p.h, s.x);
207 vec_t t6 = vec_subtract(z, t5);
211 /* t7 = k * (z - h * x) */
214 vec_t t7 = multiply_mat_vec(k, t6);
223 * h: state to measurement matrix
224 * x': corrected state
226 * x' = x + k * (z - h * x)
229 n.x = vec_add(s.x, t7);
231 print_vec("n->x", n.x);
235 /* 3x3 = 3x2 * 2x3 */
236 mat_t t8 = multiply(k, p.h);
242 /* 3x3 = 3x3 - 3x3 */
243 mat_t t9 = subtract(identity(dim(s.x)), t8);
251 * h: state to measurement matrix
252 * p': corrected error
254 * p' = (1 - k * h) * p
259 /* 3x3 = 3x3 * 3x3 */
260 n.p = multiply(t9, s.p);
262 print_mat("n->p", n.p);
263 # print_state("correct", n);
/*
 * Convergence metric between two matrices.
 * Accumulates the squared element-wise difference (a[i,j] - b[i,j]) ** 2
 * into s. converge() uses it to compare successive gains.
 * NOTE(review): i_max, j_max, the initialisation of s and the return are
 * on elided lines. Confirm whether the raw sum or its square root is
 * returned, and that a and b must have equal dimensions: both are indexed
 * with the same bounds.
 */
268 real distance(mat_t a, mat_t b) {
274 for (int i = 0; i < i_max; i++)
275 for (int j = 0; j < j_max; j++)
276 s += (a[i,j] - b[i,j]) ** 2;
/*
 * Iterate the full filter until the kalman gain settles, and return the
 * gain matrix. The returned value is presumably the converged s.k; the
 * return line is elided here.
 *
 * model:   state size (rows of p.a)
 * measure: measurement size (rows of p.r)
 *
 * Starts from an all-zero state, error estimate and gain, and feeds an
 * all-zero measurement z. The gain computed in correct() depends only on
 * the error estimate, h and r, so the measurement value does not affect
 * it. The exit test requires successive gains to differ by less than
 * 1e-10 (as measured by distance()) and more than 10 repetitions.
 *
 * NOTE(review): no upper bound on repetitions is visible in this excerpt.
 * A model whose gain never settles may loop indefinitely — confirm.
 */
280 public mat_t converge(parameters_t p) {
281 int model = dims(p.a)[0];
282 int measure = dims(p.r)[0];
285 .x = (real[model]) { 0 ... },
286 .p = (real[model,model]) { { 0 ... } ... },
287 .k = (real[model,measure]) { { 0 ... } ... }
290 vec_t z = (real [measure]) { 0 ... };
292 state_t s_pre = predict(s, p);
293 state_t s_post = correct(s_pre, z, p);
294 real d = distance(s.k, s_post.k);
297 if (d < 1e-10 && reps > 10)
303 public parameters_fast_t convert_to_fast(parameters_t p) {
304 return (parameters_fast_t) {
305 .a = p.a, .k = converge(p), .h = p.h