Added basic wavelet classifier blocks. GSL is now a prerequisite for
[debian/gnuradio] / gnuradio-core / src / python / gnuradio / gr / qa_classify.py
1 #!/usr/bin/env python
2
3 import numpy
4 from gnuradio import gr, gr_unittest
5 import copy
6 import pygsl.wavelet as wavelet
7 import math
8
9
10 def sqr(x):
11     return x*x
12
13 def np2(k):
14     m = 0
15     n = k - 1
16     while n > 0:
17         m += 1
18     return m
19
20
21 class qa_classify(gr_unittest.TestCase):
22
23     def setUp(self):
24         self.tb = gr.top_block()
25
26     def tearDown(self):
27         self.tb = None
28
29 #     def test_000_(self):
30 #         src_data = numpy.zeros(10)
31 #         trg_data = numpy.zeros(10)
32 #         src = gr.vector_source_f(src_data)
33 #         dst = gr.vector_sink_f()
34 #         self.tb.connect(src, dst)
35 #         self.tb.run()
36 #         rsl_data = dst.data()
37 #         sum = 0
38 #         for (u,v) in zip(trg_data, rsl_data):
39 #             w = u - v
40 #             sum += w * w
41 #         sum /= float(len(trg_data))
42 #         assert sum < 1e-6
43
44     def test_001_(self):
45         src_data = numpy.array([-1.0, 1.0, -1.0, 1.0])
46         trg_data = src_data * 0.5
47         src = gr.vector_source_f(src_data)
48         dst = gr.vector_sink_f()
49         rail = gr.rail_ff(-0.5, 0.5)
50         self.tb.connect(src, rail)
51         self.tb.connect(rail, dst)
52         self.tb.run()
53         rsl_data = dst.data()
54         sum = 0
55         for (u, v) in zip(trg_data, rsl_data):
56             w = u - v
57             sum += w * w
58         sum /= float(len(trg_data))
59         assert sum < 1e-6
60
61     def test_002_(self):
62         src_data = numpy.array([-1.0,
63                                 -1.0/2.0,
64                                 -1.0/3.0,
65                                 -1.0/4.0,
66                                 -1.0/5.0])
67         trg_data = copy.deepcopy(src_data)
68
69         src = gr.vector_source_f(src_data, False, len(src_data))
70         st = gr.stretch_ff(-1.0/5.0, len(src_data))
71         dst = gr.vector_sink_f(len(src_data))
72         self.tb.connect(src, st)
73         self.tb.connect(st, dst)
74         self.tb.run()
75         rsl_data = dst.data()
76         sum = 0
77         for (u, v) in zip(trg_data, rsl_data):
78             w = u - v
79             sum += w * w
80         sum /= float(len(trg_data))
81         assert sum < 1e-6
82         
83     def test_003_(self):
84         src_grid = (0.0, 1.0, 2.0, 3.0, 4.0)
85         trg_grid = copy.deepcopy(src_grid)
86         src_data = (0.0, 1.0, 0.0, 1.0, 0.0)
87
88         src = gr.vector_source_f(src_data, False, len(src_grid))
89         sq = gr.squash_ff(src_grid, trg_grid)
90         dst = gr.vector_sink_f(len(trg_grid))
91         self.tb.connect(src, sq)
92         self.tb.connect(sq, dst)
93         self.tb.run()
94         rsl_data = dst.data()
95         sum = 0
96         for (u, v) in zip(src_data, rsl_data):
97             w = u - v
98             sum += w * w
99         sum /= float(len(src_data))
100         assert sum < 1e-6
101
102     def test_004_(self):
103
104         n = 256
105         o = 4
106         ws = wavelet.workspace(n)
107         w = wavelet.daubechies(o)
108
109         a = numpy.arange(n)
110         b = numpy.sin(a*numpy.pi/16.0)
111         c = w.transform_forward(b, ws)
112         d = w.transform_inverse(c, ws)
113
114         src = gr.vector_source_f(b, False, n)
115         wv = gr.wavelet_ff(n, o, True)
116
117         dst = gr.vector_sink_f(n)
118         self.tb.connect(src, wv)
119         self.tb.connect(wv, dst)
120         self.tb.run()
121         e = dst.data()
122
123         sum = 0
124         for (u, v) in zip(c, e):
125             w = u - v
126             sum += w * w
127         sum /= float(len(c))
128         assert sum < 1e-6
129
130     def test_005_(self):
131
132         src_data = (1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
133
134         dwav = numpy.array(src_data)
135         wvps = numpy.zeros(3)
136         # wavelet power spectrum
137         scl = 1.0/sqr(dwav[0])
138         k = 1
139         for e in range(len(wvps)):
140             wvps[e] = scl*sqr(dwav[k:k+(01<<e)]).sum()
141             k += 01<<e
142
143         src = gr.vector_source_f(src_data, False, len(src_data))
144         kon = gr.wvps_ff(len(src_data))
145         dst = gr.vector_sink_f(int(math.ceil(math.log(len(src_data), 2))))
146
147         self.tb.connect(src, kon)
148         self.tb.connect(kon, dst)
149
150         self.tb.run()
151         snk_data = dst.data()
152
153         sum = 0
154         for (u,v) in zip(snk_data, wvps):
155             w = u - v
156             sum += w * w
157         sum /= float(len(snk_data))
158         assert sum < 1e-6
159
160 if __name__ == '__main__':
161     gr_unittest.main()
162