altos/test: Adjust CRC error rate after FEC fix
[fw/altos] / src / kalman / kalman_filter.5c
1 load "matrix.5c"
2
3 /*
4  * Copyright © 2011 Keith Packard <keithp@keithp.com>
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; either version 2 of the License, or
9  * (at your option) any later version.
10  *
11  * This program is distributed in the hope that it will be useful, but
12  * WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License along
17  * with this program; if not, write to the Free Software Foundation, Inc.,
18  * 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
19  */
20
21 namespace kalman {
22
23         import matrix;
24
25         public typedef struct {
26                 vec_t   x;              /* state */
27                 mat_t   p;              /* error estimate */
28                 mat_t   k;              /* kalman factor */
29         } state_t;
30
31         public typedef struct {
32                 mat_t   a;              /* model */
33                 mat_t   q;              /* model error covariance */
34                 mat_t   r;              /* measurement error covariance */
35                 mat_t   h;              /* measurement from model */
36         } parameters_t;
37
38         public typedef struct {
39                 mat_t   a;              /* model */
40                 mat_t   k;              /* kalman coefficient */
41                 mat_t   h;              /* measurement from model */
42         } parameters_fast_t;
43
44         vec_t measurement_from_state(vec_t x, mat_t h) {
45                 return multiply_mat_vec(h, x);
46         }
47
48
49         void print_state(string name, state_t s) {
50                 print_vec(sprintf("%s state", name), s.x);
51                 print_mat(sprintf("%s error", name), s.p);
52         }
53
54         public bool debug = false;
55
56         public state_t predict (state_t s, parameters_t p) {
57                 state_t n;
58
59                 if (debug) {
60                         printf ("--------PREDICT--------\n");
61                         print_state("current", s);
62                 }
63
64                 /* Predict state
65                  *
66                  * x': predicted state
67                  * a:  model
68                  * x:  previous state
69                  *
70                  * x' = a * x;
71                  */
72
73                 n.x = multiply_mat_vec(p.a, s.x);
74
75                 /* t0 = a * p */
76                 mat_t t0 = multiply (p.a, s.p);
77                 if (debug)
78                         print_mat("t0", t0);
79
80                 /* t1 = a * p * transpose(a) */
81
82                 mat_t t1 = multiply (t0, transpose(p.a));
83
84                 /* Predict error
85                  *
86                  * p': predicted error
87                  * a:  model
88                  * p:  previous error
89                  * q:  model error
90                  *
91                  * p' = a * p * transpose(a) + q
92                  */
93
94                 n.p = add(t1, p.q);
95                 if (debug)
96                         print_state("predict", n);
97                 return n;
98         }
99
100         public vec_t predict_fast(vec_t x, parameters_fast_t p) {
101                 if (debug) {
102                         printf ("--------FAST PREDICT--------\n");
103                         print_vec("current", x);
104                 }
105                 vec_t new = multiply_mat_vec(p.a, x);
106                 if (debug)
107                         print_vec("predict", new);
108                 return new;
109         }
110
111         public vec_t correct_fast(vec_t x, vec_t z, parameters_fast_t p) {
112                 if (debug) {
113                         printf ("--------FAST CORRECT--------\n");
114                         print_vec("measure", z);
115                         print_vec("current", x);
116                 }
117                 vec_t   model = multiply_mat_vec(p.h, x);
118                 if (debug)
119                         print_vec("extract model", model);
120                 vec_t   diff = vec_subtract(z, model);
121                 if (debug)
122                         print_vec("difference", diff);
123                 vec_t   adjust = multiply_mat_vec(p.k, diff);
124                 if (debug)
125                         print_vec("adjust", adjust);
126
127                 vec_t new = vec_add(x,
128                                multiply_mat_vec(p.k,
129                                                 vec_subtract(z,
130                                                              multiply_mat_vec(p.h, x))));
131                 if (debug)
132                         print_vec("correct", new);
133                 return new;
134         }
135
136         public state_t correct(state_t s, vec_t z, parameters_t p) {
137                 state_t n;
138
139                 if (debug) {
140                         printf ("--------CORRECT--------\n");
141                         print_vec("measure", z);
142                         print_state("current", s);
143                 }
144
145                 /* t0 = p * T(h) */
146
147                 /* 3x2 = 3x3 * 3x2 */
148                 mat_t t0 = multiply(s.p, transpose(p.h));
149                 if (debug)
150                         print_mat("t0", t0);
151
152                 /* t1 = h * p */
153
154                 /* 2x3 = 2x3 * 3x3 */
155                 mat_t t1 = multiply(p.h, s.p);
156                 if (debug)
157                         print_mat("t1", t1);
158
159                 /* t2 = h * p * transpose(h) */
160
161                 /* 2x2 = 2x3 * 3x2 */
162                 mat_t t2 = multiply(t1, transpose(p.h));
163                 if (debug)
164                         print_mat("t2", t2);
165
166                 /* t3 = h * p * transpose(h) + r */
167
168                 /* 2x2 = 2x2 + 2x2 */
169                 mat_t t3 = add(t2, p.r);
170                 if (debug)
171                         print_mat("t3", t3);
172
173                 /* t4 = inverse(h * p * transpose(h) + r) */
174
175                 /* 2x2 = 2x2 */
176                 mat_t t4 = inverse(t3);
177                 if (debug)
178                         print_mat("t4", t4);
179
180                 /* Kalman value */
181
182                 /* k: Kalman value
183                  * p: error estimate
184                  * h: state to measurement matrix
185                  * r: measurement error covariance
186                  *
187                  * k = p * transpose(h) * inverse(h * p * transpose(h) + r)
188                  *
189                  * k = K(p)
190                  */
191
192                 /* 3x2 = 3x2 * 2x2 */
193                 mat_t k = multiply(t0, t4);
194                 if (debug)
195                         print_mat("k", k);
196                 n.k = k;
197
198                 /* t5 = h * x */
199
200                 /* 2 = 2x3 * 3 */
201                 vec_t t5 = multiply_mat_vec(p.h, s.x);
202                 if (debug)
203                         print_vec("t5", t5);
204
205                 /* t6 = z - h * x */
206
207                 /* 2 = 2 - 2 */
208                 vec_t t6 = vec_subtract(z, t5);
209                 if (debug)
210                         print_vec("t6", t6);
211
212                 /* t7 = k * (z - h * x) */
213
214                 /* 3 = 3x2 * 2 */
215                 vec_t t7 = multiply_mat_vec(k, t6);
216                 if (debug)
217                         print_vec("t7", t7);
218
219                 /* Correct state
220                  *
221                  * x:  predicted state
222                  * k:  kalman value
223                  * z:  measurement
224                  * h:  state to measurement matrix
225                  * x': corrected state
226                  *
227                  * x' = x + k * (z - h * x)
228                  */
229
230                 n.x = vec_add(s.x, t7);
231                 if (debug)
232                         print_vec("n->x", n.x);
233
234                 /* t8 = k * h */
235
236                 /* 3x3 = 3x2 * 2x3 */
237                 mat_t t8 = multiply(k, p.h);
238                 if (debug)
239                         print_mat("t8", t8);
240
241                 /* t9 = 1 - k * h */
242
243                 /* 3x3 = 3x3 - 3x3 */
244                 mat_t t9 = subtract(identity(dim(s.x)), t8);
245                 if (debug)
246                         print_mat("t9", t9);
247
248                 /* Correct error
249                  *
250                  * p:  predicted error
251                  * k:  kalman value
252                  * h:  state to measurement matrix
253                  * p': corrected error
254                  *
255                  * p' = (1 - k * h) * p
256                  *
257                  * p' = P(k,p)
258                  */
259
260                 /* 3x3 = 3x3 * 3x3 */
261                 n.p = multiply(t9, s.p);
262                 if (debug) {
263                         print_mat("n->p", n.p);
264 #                       print_state("correct", n);
265                 }
266                 return n;
267         }
268
269         real distance(mat_t a, mat_t b) {
270                 int[2]  d = dims(a);
271                 int     i_max = d[0];
272                 int     j_max = d[1];
273                 real    s = 0;
274
275                 for (int i = 0; i < i_max; i++)
276                         for (int j = 0; j < j_max; j++)
277                                 s += (a[i,j] - b[i,j]) ** 2;
278                 return sqrt(s);
279         }
280
281         public mat_t converge(parameters_t p) {
282                 int     model = dims(p.a)[0];
283                 int     measure = dims(p.r)[0];
284                 int     reps = 0;
285                 state_t s = {
286                         .x = (real[model]) { 0 ... },
287                         .p = (real[model,model]) { { 0 ... } ... },
288                         .k = (real[model,measure]) { { 0 ... } ... }
289                 };
290
291                 vec_t   z = (real [measure]) { 0 ... };
292                 for (;;) {
293                         state_t s_pre = predict(s, p);
294                         state_t s_post = correct(s_pre, z, p);
295                         real    d = distance(s.k, s_post.k);
296                         s = s_post;
297                         reps++;
298                         if (d < 1e-10 && reps > 10)
299                                 break;
300                 }
301                 return s.k;
302         }
303
304         public parameters_fast_t convert_to_fast(parameters_t p) {
305                 return (parameters_fast_t) {
306                         .a = p.a, .k = converge(p), .h = p.h
307                 };
308         }
309 }