altos: fix comment about decoding last byte of FEC data
[fw/altos] / src / core / ao_viterbi.c
index 2d441f4bba7fa1cbed988fa57edb9a50a24572c9..17464cd1df3fd94cbf3a80638e2d17890a9cd901 100644 (file)
 #include <ao_fec.h>
 #include <stdio.h>
 
-/*
- * 'input' is 8-bits per symbol soft decision data
- * 'len' is output byte length
- */
-
-static const uint8_t ao_fec_encode_table[16] = {
-/* next 0  1     state */
-       0, 3,   /* 000 */
-       1, 2,   /* 001 */
-       3, 0,   /* 010 */
-       2, 1,   /* 011 */
-       3, 0,   /* 100 */
-       2, 1,   /* 101 */
-       0, 3,   /* 110 */
-       1, 2    /* 111 */
-};
-
 struct ao_soft_sym {
        uint8_t a, b;
 };
 
-struct ao_soft_sym
-ao_soft_sym(uint8_t bits)
-{
-       struct ao_soft_sym      s;
-
-       s.a = ((bits & 2) >> 1) * 0xff;
-       s.b = (bits & 1) * 0xff;
-       return s;
-}
+#define NUM_STATE      8
+#define NUM_HIST       8
+#define MOD_HIST(b)    ((b) & 7)
+
+static const struct ao_soft_sym ao_fec_decode_table[NUM_STATE][2] = {
+/* next        0              1                 state */
+       { { 0x00, 0x00 }, { 0xff, 0xff } } ,    /* 000 */
+       { { 0x00, 0xff }, { 0xff, 0x00 } },     /* 001 */
+       { { 0xff, 0xff }, { 0x00, 0x00 } },     /* 010 */
+       { { 0xff, 0x00 }, { 0x00, 0xff } },     /* 011 */
+       { { 0xff, 0xff }, { 0x00, 0x00 } },     /* 100 */
+       { { 0xff, 0x00 }, { 0x00, 0xff } },     /* 101 */
+       { { 0x00, 0x00 }, { 0xff, 0xff } },     /* 110 */
+       { { 0x00, 0xff }, { 0xff, 0x00 } }      /* 111 */
+};
 
-uint8_t
+static inline uint8_t
 ao_next_state(uint8_t state, uint8_t bit)
 {
        return ((state << 1) | bit) & 0x7;
 }
 
-static inline abs(int x) { return x < 0 ? -x : x; }
+static inline uint16_t ao_abs(int16_t x) { return x < 0 ? -x : x; }
 
-int
+static inline uint16_t
 ao_cost(struct ao_soft_sym a, struct ao_soft_sym b)
 {
-       return abs(a.a - b.a) + abs(a.b - b.b);
+       return ao_abs(a.a - b.a) + ao_abs(a.b - b.b);
 }
 
+/*
+ * 'in' is 8-bits per symbol soft decision data
+ * 'len' is input byte length. 'out' must be
+ * 'len'/16 bytes long
+ */
+
 uint8_t
-ao_fec_decode(uint8_t *in, int len, uint8_t *out)
+ao_fec_decode(uint8_t *in, uint16_t len, uint8_t *out)
 {
-       int     cost[len/2 + 1][8];
-       uint8_t prev[len/2 + 1][8];
-       int     c;
-       int     i, b;
-       uint8_t state = 0, min_state;
-       uint8_t bits[len/2];
-
-       for (c = 0; c < 8; c++)
-               cost[0][c] = 10000;
+       static uint16_t cost[2][NUM_STATE];             /* path cost */
+       static uint16_t bits[2][NUM_STATE];             /* save bits to quickly output them */
+       uint16_t        i;                              /* input byte index */
+       uint16_t        b;                              /* encoded symbol index (bytes/2) */
+       uint16_t        o;                              /* output bit index */
+       uint8_t         p;                              /* previous cost/bits index */
+       uint8_t         n;                              /* next cost/bits index */
+       uint8_t         state;                          /* state index */
+       uint8_t         bit;                            /* original encoded bit index */
+
+       p = 0;
+       for (state = 0; state < NUM_STATE; state++) {
+               cost[0][state] = 0xffff;
+               bits[0][state] = 0;
+       }
        cost[0][0] = 0;
 
+       o = 0;
        for (i = 0; i < len; i += 2) {
                b = i/2;
+               n = p ^ 1;
                struct ao_soft_sym s = { .a = in[i], .b = in[i+1] };
 
-               for (state = 0; state < 8; state++)
-                       cost[b+1][state] = 10000;
-
-               for (state = 0; state < 8; state++) {
-                       struct ao_soft_sym zero = ao_soft_sym(ao_fec_encode_table[state * 2 + 0]);
-                       struct ao_soft_sym one = ao_soft_sym(ao_fec_encode_table[state * 2 + 1]);
-                       uint8_t zero_state = ao_next_state(state, 0);
-                       uint8_t one_state = ao_next_state(state, 1);
-                       int     zero_cost = ao_cost(s, zero);
-                       int     one_cost = ao_cost(s, one);
-
-#if 0
-                       printf ("saw %02x %02x expected %02x %02x (%d) or %02x %02x (%d)\n",
-                               s.a, s.b, zero.a, zero.b, zero_cost, one.a, one.b, one_cost);
-#endif
-                       zero_cost += cost[b][state];
-                       one_cost += cost[b][state];
-                       if (zero_cost < cost[b+1][zero_state]) {
-                               prev[b+1][zero_state] = state;
-                               cost[b+1][zero_state] = zero_cost;
-                       }
-
-                       if (one_cost < cost[b+1][one_state]) {
-                               prev[b+1][one_state] = state;
-                               cost[b+1][one_state] = one_cost;
+               /* Reset next costs to 'impossibly high' values so that
+                * the first path through this state is cheaper than this
+                */
+               for (state = 0; state < NUM_STATE; state++)
+                       cost[n][state] = 0xffff;
+
+               /* Compute path costs and accumulate output bit path
+                * for each state and encoded bit value
+                */
+               for (state = 0; state < NUM_STATE; state++) {
+                       for (bit = 0; bit < 2; bit++) {
+                               int     bit_cost = cost[p][state] + ao_cost(s, ao_fec_decode_table[state][bit]);
+                               uint8_t bit_state = ao_next_state(state, bit);
+
+                               /* Only track the minimal cost to reach
+                                * this state; the best path can never
+                                * go through the higher cost paths as
+                                * total path cost is cumulative
+                                */
+                               if (bit_cost < cost[n][bit_state]) {
+                                       cost[n][bit_state] = bit_cost;
+                                       bits[n][bit_state] = (bits[p][state] << 1) | (state & 1);
+                               }
                        }
                }
 
+#if 0
                printf ("bit %3d symbol %2x %2x:", i/2, s.a, s.b);
-               for (state = 0; state < 8; state++) {
-                       printf (" %5d", cost[b+1][state]);
+               for (state = 0; state < NUM_STATE; state++) {
+                       printf (" %5d(%04x)", cost[n][state], bits[n][state]);
                }
                printf ("\n");
-       }
-
-       b = len / 2;
-       c = cost[b][0];
-       min_state = 0;
-       for (state = 1; state < 8; state++) {
-               if (cost[b][state] < c) {
-                       c = cost[b][state];
-                       min_state = state;
-               }
-       }
-
-       for (b = len/2; b > 0; b--) {
-               bits[b-1] = min_state & 1;
-               min_state = prev[b][min_state];
-       }
+#endif
+               p = n;
+
+               /* A loop is needed to handle the last output byte. It
+                * won't have any bits of future data to perform full
+                * error correction, but we might as well give the
+                * best possible answer anyways.
+                */
+               while ((b - o) >= (8 + NUM_HIST) || (i + 2 >= len && b > o)) {
+
+                       /* Compute number of bits to the end of the
+                        * last full byte of data. This is generally
+                        * NUM_HIST, unless we've reached
+                        * the end of the input, in which case
+                        * it will be seven.
+                        */
+                       int8_t          dist = b - (o + 8);     /* distance to last ready-for-writing bit */
+                       uint16_t        min_cost;               /* lowest cost */
+                       uint8_t         min_state;              /* lowest cost state */
+
+                       /* Find the best fit at the current point
+                        * of the decode.
+                        */
+                       min_cost = cost[p][0];
+                       min_state = 0;
+                       for (state = 1; state < NUM_STATE; state++) {
+                               if (cost[p][state] < min_cost) {
+                                       min_cost = cost[p][state];
+                                       min_state = state;
+                               }
+                       }
 
-       for (i = 0; i < len/2; i += 8) {
-               uint8_t byte;
+                       /* The very last byte of data has the very last bit
+                        * of data left in the state value; just smash the
+                        * bits value in place and reset the 'dist' from
+                        * -1 to 0 so that the full byte is read out
+                        */
+                       if (dist < 0) {
+                               bits[p][min_state] = (bits[p][min_state] << 1) | (min_state & 1);
+                               dist = 0;
+                       }
 
-               byte = 0;
-               for (b = 0; b < 8; b++)
-                       byte = (byte << 1) | bits[i + b];
-               out[i/8] = byte;
+#if 0
+                       printf ("\tbit %3d min_cost %5d old bit %3d old_state %x bits %02x\n",
+                               i/2, min_cost, o + 8, min_state, (bits[p][min_state] >> dist) & 0xff);
+#endif
+                       out[o >> 3] = bits[p][min_state] >> dist;
+                       o += 8;
+               }
        }
        return len/16;
 }