altos: Share log code between telescience and telebt. Add telebt log
[fw/altos] / src / kalman / kalman_filter.5c
1 load "matrix.5c"
2
3 /*
4  * Copyright © 2011 Keith Packard <keithp@keithp.com>
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; version 2 of the License.
9  *
10  * This program is distributed in the hope that it will be useful, but
11  * WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License along
16  * with this program; if not, write to the Free Software Foundation, Inc.,
17  * 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
18  */
19
20 namespace kalman {
21
22         import matrix;
23
24         public typedef struct {
25                 vec_t   x;              /* state */
26                 mat_t   p;              /* error estimate */
27                 mat_t   k;              /* kalman factor */
28         } state_t;
29
30         public typedef struct {
31                 mat_t   a;              /* model */
32                 mat_t   q;              /* model error covariance */
33                 mat_t   r;              /* measurement error covariance */
34                 mat_t   h;              /* measurement from model */
35         } parameters_t;
36
37         public typedef struct {
38                 mat_t   a;              /* model */
39                 mat_t   k;              /* kalman coefficient */
40                 mat_t   h;              /* measurement from model */
41         } parameters_fast_t;
42
43         vec_t measurement_from_state(vec_t x, mat_t h) {
44                 return multiply_mat_vec(h, x);
45         }
46
47
48         void print_state(string name, state_t s) {
49                 print_vec(sprintf("%s state", name), s.x);
50                 print_mat(sprintf("%s error", name), s.p);
51         }
52
53         public bool debug = false;
54
55         public state_t predict (state_t s, parameters_t p) {
56                 state_t n;
57
58                 if (debug) {
59                         printf ("--------PREDICT--------\n");
60                         print_state("current", s);
61                 }
62
63                 /* Predict state
64                  *
65                  * x': predicted state
66                  * a:  model
67                  * x:  previous state
68                  *
69                  * x' = a * x;
70                  */
71
72                 n.x = multiply_mat_vec(p.a, s.x);
73
74                 /* t0 = a * p */
75                 mat_t t0 = multiply (p.a, s.p);
76                 if (debug)
77                         print_mat("t0", t0);
78
79                 /* t1 = a * p * transpose(a) */
80
81                 mat_t t1 = multiply (t0, transpose(p.a));
82
83                 /* Predict error
84                  *
85                  * p': predicted error
86                  * a:  model
87                  * p:  previous error
88                  * q:  model error
89                  *
90                  * p' = a * p * transpose(a) + q
91                  */
92
93                 n.p = add(t1, p.q);
94                 if (debug)
95                         print_state("predict", n);
96                 return n;
97         }
98
99         public vec_t predict_fast(vec_t x, parameters_fast_t p) {
100                 if (debug) {
101                         printf ("--------FAST PREDICT--------\n");
102                         print_vec("current", x);
103                 }
104                 vec_t new = multiply_mat_vec(p.a, x);
105                 if (debug)
106                         print_vec("predict", new);
107                 return new;
108         }
109
110         public vec_t correct_fast(vec_t x, vec_t z, parameters_fast_t p) {
111                 if (debug) {
112                         printf ("--------FAST CORRECT--------\n");
113                         print_vec("measure", z);
114                         print_vec("current", x);
115                 }
116                 vec_t   model = multiply_mat_vec(p.h, x);
117                 if (debug)
118                         print_vec("extract model", model);
119                 vec_t   diff = vec_subtract(z, model);
120                 if (debug)
121                         print_vec("difference", diff);
122                 vec_t   adjust = multiply_mat_vec(p.k, diff);
123                 if (debug)
124                         print_vec("adjust", adjust);
125
126                 vec_t new = vec_add(x,
127                                multiply_mat_vec(p.k,
128                                                 vec_subtract(z,
129                                                              multiply_mat_vec(p.h, x))));
130                 if (debug)
131                         print_vec("correct", new);
132                 return new;
133         }
134
135         public state_t correct(state_t s, vec_t z, parameters_t p) {
136                 state_t n;
137
138                 if (debug) {
139                         printf ("--------CORRECT--------\n");
140                         print_vec("measure", z);
141                         print_state("current", s);
142                 }
143
144                 /* t0 = p * T(h) */
145
146                 /* 3x2 = 3x3 * 3x2 */
147                 mat_t t0 = multiply(s.p, transpose(p.h));
148                 if (debug)
149                         print_mat("t0", t0);
150
151                 /* t1 = h * p */
152
153                 /* 2x3 = 2x3 * 3x3 */
154                 mat_t t1 = multiply(p.h, s.p);
155                 if (debug)
156                         print_mat("t1", t1);
157
158                 /* t2 = h * p * transpose(h) */
159
160                 /* 2x2 = 2x3 * 3x2 */
161                 mat_t t2 = multiply(t1, transpose(p.h));
162                 if (debug)
163                         print_mat("t2", t2);
164
165                 /* t3 = h * p * transpose(h) + r */
166
167                 /* 2x2 = 2x2 + 2x2 */
168                 mat_t t3 = add(t2, p.r);
169                 if (debug)
170                         print_mat("t3", t3);
171
172                 /* t4 = inverse(h * p * transpose(h) + r) */
173
174                 /* 2x2 = 2x2 */
175                 mat_t t4 = inverse(t3);
176                 if (debug)
177                         print_mat("t4", t4);
178
179                 /* Kalman value */
180
181                 /* k: Kalman value
182                  * p: error estimate
183                  * h: state to measurement matrix
184                  * r: measurement error covariance
185                  *
186                  * k = p * transpose(h) * inverse(h * p * transpose(h) + r)
187                  *
188                  * k = K(p)
189                  */
190
191                 /* 3x2 = 3x2 * 2x2 */
192                 mat_t k = multiply(t0, t4);
193                 if (debug)
194                         print_mat("k", k);
195                 n.k = k;
196
197                 /* t5 = h * x */
198
199                 /* 2 = 2x3 * 3 */
200                 vec_t t5 = multiply_mat_vec(p.h, s.x);
201                 if (debug)
202                         print_vec("t5", t5);
203
204                 /* t6 = z - h * x */
205
206                 /* 2 = 2 - 2 */
207                 vec_t t6 = vec_subtract(z, t5);
208                 if (debug)
209                         print_vec("t6", t6);
210
211                 /* t7 = k * (z - h * x) */
212
213                 /* 3 = 3x2 * 2 */
214                 vec_t t7 = multiply_mat_vec(k, t6);
215                 if (debug)
216                         print_vec("t7", t7);
217
218                 /* Correct state
219                  *
220                  * x:  predicted state
221                  * k:  kalman value
222                  * z:  measurement
223                  * h:  state to measurement matrix
224                  * x': corrected state
225                  *
226                  * x' = x + k * (z - h * x)
227                  */
228
229                 n.x = vec_add(s.x, t7);
230                 if (debug)
231                         print_vec("n->x", n.x);
232
233                 /* t8 = k * h */
234
235                 /* 3x3 = 3x2 * 2x3 */
236                 mat_t t8 = multiply(k, p.h);
237                 if (debug)
238                         print_mat("t8", t8);
239
240                 /* t9 = 1 - k * h */
241
242                 /* 3x3 = 3x3 - 3x3 */
243                 mat_t t9 = subtract(identity(dim(s.x)), t8);
244                 if (debug)
245                         print_mat("t9", t9);
246
247                 /* Correct error
248                  *
249                  * p:  predicted error
250                  * k:  kalman value
251                  * h:  state to measurement matrix
252                  * p': corrected error
253                  *
254                  * p' = (1 - k * h) * p
255                  *
256                  * p' = P(k,p)
257                  */
258
259                 /* 3x3 = 3x3 * 3x3 */
260                 n.p = multiply(t9, s.p);
261                 if (debug) {
262                         print_mat("n->p", n.p);
263 #                       print_state("correct", n);
264                 }
265                 return n;
266         }
267
268         real distance(mat_t a, mat_t b) {
269                 int[2]  d = dims(a);
270                 int     i_max = d[0];
271                 int     j_max = d[1];
272                 real    s = 0;
273
274                 for (int i = 0; i < i_max; i++)
275                         for (int j = 0; j < j_max; j++)
276                                 s += (a[i,j] - b[i,j]) ** 2;
277                 return sqrt(s);
278         }
279
280         public mat_t converge(parameters_t p) {
281                 int     model = dims(p.a)[0];
282                 int     measure = dims(p.r)[0];
283                 int     reps = 0;
284                 state_t s = {
285                         .x = (real[model]) { 0 ... },
286                         .p = (real[model,model]) { { 0 ... } ... },
287                         .k = (real[model,measure]) { { 0 ... } ... }
288                 };
289
290                 vec_t   z = (real [measure]) { 0 ... };
291                 for (;;) {
292                         state_t s_pre = predict(s, p);
293                         state_t s_post = correct(s_pre, z, p);
294                         real    d = distance(s.k, s_post.k);
295                         s = s_post;
296                         reps++;
297                         if (d < 1e-10 && reps > 10)
298                                 break;
299                 }
300                 return s.k;
301         }
302
303         public parameters_fast_t convert_to_fast(parameters_t p) {
304                 return (parameters_fast_t) {
305                         .a = p.a, .k = converge(p), .h = p.h
306                 };
307         }
308 }