Imported Upstream version 3.2.2
[debian/gnuradio] / gnuradio-core / src / python / gnuradio / gr / qa_classify.py
1 #!/usr/bin/env python
2 #
3 # Copyright 2008 Free Software Foundation, Inc.
4
5 # This file is part of GNU Radio
6
7 # GNU Radio is free software; you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation; either version 3, or (at your option)
10 # any later version.
11
12 # GNU Radio is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16
17 # You should have received a copy of the GNU General Public License
18 # along with GNU Radio; see the file COPYING.  If not, write to
19 # the Free Software Foundation, Inc., 51 Franklin Street,
20 # Boston, MA 02110-1301, USA.
21
22
23 import numpy
24 from gnuradio import gr, gr_unittest
25 import copy
26 #import pygsl.wavelet as wavelet # FIXME: pygsl not checked for in config
27 import math
28
29
30 def sqr(x):
31     return x*x
32
33 def np2(k):
34     m = 0
35     n = k - 1
36     while n > 0:
37         m += 1
38     return m
39
40
41 class qa_classify(gr_unittest.TestCase):
42
43     def setUp(self):
44         self.tb = gr.top_block()
45
46     def tearDown(self):
47         self.tb = None
48
49 #     def test_000_(self):
50 #         src_data = numpy.zeros(10)
51 #         trg_data = numpy.zeros(10)
52 #         src = gr.vector_source_f(src_data)
53 #         dst = gr.vector_sink_f()
54 #         self.tb.connect(src, dst)
55 #         self.tb.run()
56 #         rsl_data = dst.data()
57 #         sum = 0
58 #         for (u,v) in zip(trg_data, rsl_data):
59 #             w = u - v
60 #             sum += w * w
61 #         sum /= float(len(trg_data))
62 #         assert sum < 1e-6
63
64     def test_001_(self):
65         src_data = numpy.array([-1.0, 1.0, -1.0, 1.0])
66         trg_data = src_data * 0.5
67         src = gr.vector_source_f(src_data)
68         dst = gr.vector_sink_f()
69         rail = gr.rail_ff(-0.5, 0.5)
70         self.tb.connect(src, rail)
71         self.tb.connect(rail, dst)
72         self.tb.run()
73         rsl_data = dst.data()
74         sum = 0
75         for (u, v) in zip(trg_data, rsl_data):
76             w = u - v
77             sum += w * w
78         sum /= float(len(trg_data))
79         assert sum < 1e-6
80
81     def test_002_(self):
82         src_data = numpy.array([-1.0,
83                                 -1.0/2.0,
84                                 -1.0/3.0,
85                                 -1.0/4.0,
86                                 -1.0/5.0])
87         trg_data = copy.deepcopy(src_data)
88
89         src = gr.vector_source_f(src_data, False, len(src_data))
90         st = gr.stretch_ff(-1.0/5.0, len(src_data))
91         dst = gr.vector_sink_f(len(src_data))
92         self.tb.connect(src, st)
93         self.tb.connect(st, dst)
94         self.tb.run()
95         rsl_data = dst.data()
96         sum = 0
97         for (u, v) in zip(trg_data, rsl_data):
98             w = u - v
99             sum += w * w
100         sum /= float(len(trg_data))
101         assert sum < 1e-6
102         
103     def test_003_(self):
104         src_grid = (0.0, 1.0, 2.0, 3.0, 4.0)
105         trg_grid = copy.deepcopy(src_grid)
106         src_data = (0.0, 1.0, 0.0, 1.0, 0.0)
107
108         src = gr.vector_source_f(src_data, False, len(src_grid))
109         sq = gr.squash_ff(src_grid, trg_grid)
110         dst = gr.vector_sink_f(len(trg_grid))
111         self.tb.connect(src, sq)
112         self.tb.connect(sq, dst)
113         self.tb.run()
114         rsl_data = dst.data()
115         sum = 0
116         for (u, v) in zip(src_data, rsl_data):
117             w = u - v
118             sum += w * w
119         sum /= float(len(src_data))
120         assert sum < 1e-6
121
122 #    def test_004_(self): # FIXME: requires pygsl
123 #
124 #        n = 256
125 #        o = 4
126 #        ws = wavelet.workspace(n)
127 #        w = wavelet.daubechies(o)
128 #
129 #        a = numpy.arange(n)
130 #        b = numpy.sin(a*numpy.pi/16.0)
131 #        c = w.transform_forward(b, ws)
132 #        d = w.transform_inverse(c, ws)
133 #
134 #        src = gr.vector_source_f(b, False, n)
135 #        wv = gr.wavelet_ff(n, o, True)
136 #
137 #        dst = gr.vector_sink_f(n)
138 #        self.tb.connect(src, wv)
139 #        self.tb.connect(wv, dst)
140 #        self.tb.run()
141 #        e = dst.data()
142 #
143 #        sum = 0
144 #        for (u, v) in zip(c, e):
145 #            w = u - v
146 #            sum += w * w
147 #        sum /= float(len(c))
148 #        assert sum < 1e-6
149
150     def test_005_(self):
151
152         src_data = (1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
153
154         dwav = numpy.array(src_data)
155         wvps = numpy.zeros(3)
156         # wavelet power spectrum
157         scl = 1.0/sqr(dwav[0])
158         k = 1
159         for e in range(len(wvps)):
160             wvps[e] = scl*sqr(dwav[k:k+(01<<e)]).sum()
161             k += 01<<e
162
163         src = gr.vector_source_f(src_data, False, len(src_data))
164         kon = gr.wvps_ff(len(src_data))
165         dst = gr.vector_sink_f(int(math.ceil(math.log(len(src_data), 2))))
166
167         self.tb.connect(src, kon)
168         self.tb.connect(kon, dst)
169
170         self.tb.run()
171         snk_data = dst.data()
172
173         sum = 0
174         for (u,v) in zip(snk_data, wvps):
175             w = u - v
176             sum += w * w
177         sum /= float(len(snk_data))
178         assert sum < 1e-6
179
180 if __name__ == '__main__':
181     gr_unittest.main()
182