Several enhancements to gr-trellis and gnuradio-examples/python/channel-coding:
[debian/gnuradio] / gr-trellis / src / lib / trellis_siso_f.cc
1 /* -*- c++ -*- */
2 /*
3  * Copyright 2004 Free Software Foundation, Inc.
4  * 
5  * This file is part of GNU Radio
6  * 
7  * GNU Radio is free software; you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License as published by
9  * the Free Software Foundation; either version 2, or (at your option)
10  * any later version.
11  * 
12  * GNU Radio is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  * GNU General Public License for more details.
16  * 
17  * You should have received a copy of the GNU General Public License
18  * along with GNU Radio; see the file COPYING.  If not, write to
19  * the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
20  * Boston, MA 02111-1307, USA.
21  */
22
23 #ifdef HAVE_CONFIG_H
24 #include "config.h"
25 #endif
26
27 #include <trellis_siso_f.h>
28 #include <gr_io_signature.h>
29 #include <stdexcept>
30 #include <assert.h>
31 #include <iostream>
32   
33 static const float INF = 1.0e9;
34
35 trellis_siso_f_sptr 
36 trellis_make_siso_f (
37     const fsm &FSM,
38     int K,
39     int S0,
40     int SK,
41     bool POSTI,
42     bool POSTO,
43     trellis_siso_type_t SISO_TYPE)
44 {
45   return trellis_siso_f_sptr (new trellis_siso_f (FSM,K,S0,SK,POSTI,POSTO,SISO_TYPE));
46 }
47
48 trellis_siso_f::trellis_siso_f (
49     const fsm &FSM,
50     int K,
51     int S0,
52     int SK,
53     bool POSTI,
54     bool POSTO,
55     trellis_siso_type_t SISO_TYPE)
56   : gr_block ("siso_f",
57                           gr_make_io_signature (1, -1, sizeof (float)),
58                           gr_make_io_signature (1, -1, sizeof (float))),  
59   d_FSM (FSM),
60   d_K (K),
61   d_S0 (S0),
62   d_SK (SK),
63   d_POSTI (POSTI),
64   d_POSTO (POSTO),
65   d_SISO_TYPE (SISO_TYPE),
66   d_alpha(FSM.S()*(K+1)),
67   d_beta(FSM.S()*(K+1))
68 {
69     int multiple;
70     if (d_POSTI && d_POSTO) 
71         multiple = d_FSM.I()+d_FSM.O();
72     else if(d_POSTI)
73         multiple = d_FSM.I();
74     else if(d_POSTO)
75         multiple = d_FSM.O();
76     else
77         throw std::runtime_error ("Not both POSTI and POSTO can be false.");
78     //printf("constructor: Multiple = %d\n",multiple);
79     set_output_multiple (d_K*multiple);
80     //what is the meaning of relative rate for this?
81     // it was suggested to use the one furthest from 1.0
82     // let's do it.
83     set_relative_rate ( multiple / ((double) d_FSM.I()) );
84 }
85
86
87 void
88 trellis_siso_f::forecast (int noutput_items, gr_vector_int &ninput_items_required)
89 {
90   int multiple;
91   if (d_POSTI && d_POSTO)
92       multiple = d_FSM.I()+d_FSM.O();
93   else if(d_POSTI)
94       multiple = d_FSM.I();
95   else if(d_POSTO)
96       multiple = d_FSM.O();
97   else
98       throw std::runtime_error ("Not both POSTI and POSTO can be false.");
99   //printf("forecast: Multiple = %d\n",multiple); 
100   assert (noutput_items % (d_K*multiple) == 0);
101   int input_required1 =  d_FSM.I() * (noutput_items/multiple) ;
102   int input_required2 =  d_FSM.O() * (noutput_items/multiple) ;
103   //printf("forecast: Output requirements:  %d\n",noutput_items);
104   //printf("forecast: Input requirements:  %d   %d\n",input_required1,input_required2);
105   unsigned ninputs = ninput_items_required.size();
106   assert(ninputs % 2 == 0);
107   for (unsigned int i = 0; i < ninputs/2; i++) {
108     ninput_items_required[2*i] = input_required1;
109     ninput_items_required[2*i+1] = input_required2;
110   }
111 }
112
113 inline float min(float a, float b)
114 {
115   return a <= b ? a : b;
116 }
117
118 inline float min_star(float a, float b)
119 {
120   return (a <= b ? a : b)-log(1+exp(a <= b ? a-b : b-a));
121 }
122
123 void siso_algorithm(int I, int S, int O, 
124              const std::vector<int> &NS,
125              const std::vector<int> &OS,
126              const std::vector<int> &PS,
127              const std::vector<int> &PI,
128              int K,
129              int S0,int SK,
130              bool POSTI, bool POSTO,
131              float (*p2mymin)(float,float),
132              const float *priori, const float *prioro, float *post,
133              std::vector<float> &alpha,
134              std::vector<float> &beta) 
135 {
136   float norm,mm,minm;
137
138
139   if(S0<0) { // initial state not specified
140       for(int i=0;i<S;i++) alpha[0*S+i]=0;
141   }
142   else {
143       for(int i=0;i<S;i++) alpha[0*S+i]=INF;
144       alpha[0*S+S0]=0.0;
145   }
146
147   for(int k=0;k<K;k++) { // forward recursion
148       norm=INF;
149       for(int j=0;j<S;j++) {
150           minm=INF;
151           for(int i=0;i<I;i++) {
152               int i0 = j*I+i;
153               mm=alpha[k*S+PS[i0]]+priori[k*I+PI[i0]]+prioro[k*O+OS[PS[i0]*I+PI[i0]]];
154               minm=(*p2mymin)(minm,mm);
155           }
156           alpha[(k+1)*S+j]=minm;
157           if(minm<norm) norm=minm;
158       }
159       for(int j=0;j<S;j++) 
160           alpha[(k+1)*S+j]-=norm; // normalize total metrics so they do not explode
161   }
162
163   if(SK<0) { // final state not specified
164       for(int i=0;i<S;i++) beta[K*S+i]=0;
165   }
166   else {
167       for(int i=0;i<S;i++) beta[K*S+i]=INF;
168       beta[K*S+SK]=0.0;
169   }
170
171   for(int k=K-1;k>=0;k--) { // backward recursion
172       norm=INF;
173       for(int j=0;j<S;j++) { 
174           minm=INF;
175           for(int i=0;i<I;i++) {
176               int i0 = j*I+i;
177               mm=beta[(k+1)*S+NS[i0]]+priori[k*I+i]+prioro[k*O+OS[i0]];
178               minm=(*p2mymin)(minm,mm);
179           }
180           beta[k*S+j]=minm;
181           if(minm<norm) norm=minm;
182       }
183       for(int j=0;j<S;j++)
184           beta[k*S+j]-=norm; // normalize total metrics so they do not explode
185   }
186
187
188 if (POSTI && POSTO)
189 {
190   for(int k=0;k<K;k++) { // input combining
191       norm=INF;
192       for(int i=0;i<I;i++) {
193           minm=INF;
194           for(int j=0;j<S;j++) {
195               mm=alpha[k*S+j]+prioro[k*O+OS[j*I+i]]+beta[(k+1)*S+NS[j*I+i]];
196               minm=(*p2mymin)(minm,mm);
197           }
198           post[k*(I+O)+i]=minm;
199           if(minm<norm) norm=minm;
200       }
201       for(int i=0;i<I;i++)
202           post[k*(I+O)+i]-=norm; // normalize metrics
203   }
204
205
206   for(int k=0;k<K;k++) { // output combining
207       norm=INF;
208       for(int n=0;n<O;n++) {
209           minm=INF;
210           for(int j=0;j<S;j++) {
211               for(int i=0;i<I;i++) {
212                   mm= (n==OS[j*I+i] ? alpha[k*S+j]+priori[k*I+i]+beta[(k+1)*S+NS[j*I+i]] : INF);
213                   minm=(*p2mymin)(minm,mm);
214               }
215           }
216           post[k*(I+O)+I+n]=minm;
217           if(minm<norm) norm=minm;
218       }
219       for(int n=0;n<O;n++)
220           post[k*(I+O)+I+n]-=norm; // normalize metrics
221   }
222
223 else if(POSTI) 
224 {
225   for(int k=0;k<K;k++) { // input combining
226       norm=INF;
227       for(int i=0;i<I;i++) {
228           minm=INF;
229           for(int j=0;j<S;j++) {
230               mm=alpha[k*S+j]+prioro[k*O+OS[j*I+i]]+beta[(k+1)*S+NS[j*I+i]];
231               minm=(*p2mymin)(minm,mm);
232           }
233           post[k*I+i]=minm;
234           if(minm<norm) norm=minm;
235       }
236       for(int i=0;i<I;i++)
237           post[k*I+i]-=norm; // normalize metrics
238   }
239 }
240 else if(POSTO)
241 {
242   for(int k=0;k<K;k++) { // output combining
243       norm=INF;
244       for(int n=0;n<O;n++) {
245           minm=INF;
246           for(int j=0;j<S;j++) {
247               for(int i=0;i<I;i++) {
248                   mm= (n==OS[j*I+i] ? alpha[k*S+j]+priori[k*I+i]+beta[(k+1)*S+NS[j*I+i]] : INF);
249                   minm=(*p2mymin)(minm,mm);
250               }
251           }
252           post[k*O+n]=minm;
253           if(minm<norm) norm=minm;
254       }
255       for(int n=0;n<O;n++)
256           post[k*O+n]-=norm; // normalize metrics
257   }
258 }
259 else
260     throw std::runtime_error ("Not both POSTI and POSTO can be false.");
261
262 }
263
264
265
266
267
268
269 int
270 trellis_siso_f::general_work (int noutput_items,
271                         gr_vector_int &ninput_items,
272                         gr_vector_const_void_star &input_items,
273                         gr_vector_void_star &output_items)
274 {
275   assert (input_items.size() == 2*output_items.size());
276   int nstreams = output_items.size();
277   //printf("general_work:Streams:  %d\n",nstreams); 
278   int multiple;
279   if (d_POSTI && d_POSTO)
280       multiple = d_FSM.I()+d_FSM.O();
281   else if(d_POSTI)
282       multiple = d_FSM.I();
283   else if(d_POSTO)
284       multiple = d_FSM.O();
285   else
286       throw std::runtime_error ("Not both POSTI and POSTO can be false.");
287
288   assert (noutput_items % (d_K*multiple) == 0);
289   int nblocks = noutput_items / (d_K*multiple);
290   //printf("general_work:Blocks:  %d\n",nblocks); 
291   //for(int i=0;i<ninput_items.size();i++)
292       //printf("general_work:Input items available:  %d\n",ninput_items[i]);
293
294   float (*p2min)(float, float) = NULL; 
295   if(d_SISO_TYPE == TRELLIS_MIN_SUM)
296     p2min = &min;
297   else if(d_SISO_TYPE == TRELLIS_SUM_PRODUCT)
298     p2min = &min_star;
299
300
301   for (int m=0;m<nstreams;m++) {
302     const float *in1 = (const float *) input_items[2*m];
303     const float *in2 = (const float *) input_items[2*m+1];
304     float *out = (float *) output_items[m];
305     for (int n=0;n<nblocks;n++) {
306       siso_algorithm(d_FSM.I(),d_FSM.S(),d_FSM.O(),
307         d_FSM.NS(),d_FSM.OS(),d_FSM.PS(),d_FSM.PI(),
308         d_K,d_S0,d_SK,
309         d_POSTI,d_POSTO,
310         p2min,
311         &(in1[n*d_K*d_FSM.I()]),&(in2[n*d_K*d_FSM.O()]),
312         &(out[n*d_K*multiple]),
313         d_alpha,d_beta);
314     }
315   }
316
317   for (unsigned int i = 0; i < input_items.size()/2; i++) {
318     consume(2*i,d_FSM.I() * noutput_items / multiple );
319     consume(2*i+1,d_FSM.O() * noutput_items / multiple );
320   }
321
322   return noutput_items;
323 }