Imported Upstream version 3.0.2
[debian/gnuradio] / gnuradio-core / src / python / gnuradio / gr / basic_flow_graph.py
1 #
2 # Copyright 2004 Free Software Foundation, Inc.
3
4 # This file is part of GNU Radio
5
6 # GNU Radio is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2, or (at your option)
9 # any later version.
10
11 # GNU Radio is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15
16 # You should have received a copy of the GNU General Public License
17 # along with GNU Radio; see the file COPYING.  If not, write to
18 # the Free Software Foundation, Inc., 51 Franklin Street,
19 # Boston, MA 02110-1301, USA.
20
21
22 from gnuradio_swig_python import gr_block_sptr
23 import types
24 import hier_block
25
26 def remove_duplicates (seq):
27     new = []
28     for x in seq:
29         if not x in new:
30             new.append (x)
31     return new
32
33
34 class endpoint (object):
35     __slots__ = ['block', 'port']
36     def __init__ (self, block, port):
37         self.block = block
38         self.port = port
39
40     def __cmp__ (self, other):
41         if self.block == other.block and self.port == other.port:
42             return 0
43         return 1
44
45     def __str__ (self):
46         return '<endpoint (%s, %s)>' % (self.block, self.port)
47         
48 def expand_src_endpoint (src_endpoint):
49     # A src_endpoint is an output of a block
50     src_endpoint = coerce_endpoint (src_endpoint)
51     if isinstance (src_endpoint.block, hier_block.hier_block_base):
52         return expand_src_endpoint (
53             coerce_endpoint (src_endpoint.block.resolve_output_port(src_endpoint.port)))
54     else:
55         return src_endpoint
56
57 def expand_dst_endpoint (dst_endpoint):
58     # a dst_endpoint is the input to a block
59     dst_endpoint = coerce_endpoint (dst_endpoint)
60     if isinstance (dst_endpoint.block, hier_block.hier_block_base):
61         exp = [coerce_endpoint(x) for x in
62                dst_endpoint.block.resolve_input_port(dst_endpoint.port)]
63         return expand_dst_endpoints (exp)
64     else:
65         return [dst_endpoint]
66
67 def expand_dst_endpoints (endpoint_list):
68     r = []
69     for e in endpoint_list:
70         r.extend (expand_dst_endpoint (e))
71     return r
72            
73
74 def coerce_endpoint (x):
75     if isinstance (x, endpoint):
76         return x
77     elif isinstance (x, types.TupleType) and len (x) == 2:
78         return endpoint (x[0], x[1])
79     elif hasattr (x, 'block'):          # assume it's a block
80         return endpoint (x, 0)
81     elif isinstance(x, hier_block.hier_block_base):
82         return endpoint (x, 0)
83     else:
84         raise ValueError, "Not coercible to endpoint: %s" % (x,)
85     
86
87 class edge (object):
88     __slots__ = ['src', 'dst']
89     def __init__ (self, src_endpoint, dst_endpoint):
90         self.src = src_endpoint
91         self.dst = dst_endpoint
92
93     def __cmp__ (self, other):
94         if self.src == other.src and self.dst == other.dst:
95             return 0
96         return 1
97
98     def __repr__ (self):
99         return '<edge (%s, %s)>' % (self.src, self.dst)
100
101 class basic_flow_graph (object):
102     '''basic_flow_graph -- describe connections between blocks'''
103     # __slots__ is incompatible with weakrefs (for some reason!)
104     # __slots__ = ['edge_list']
105     def __init__ (self):
106         self.edge_list = []
107
108     def connect (self, *points):
109         '''connect requires two or more arguments that can be coerced to endpoints.
110         If more than two arguments are provided, they are connected together successively.
111         '''
112         if len (points) < 2:
113             raise ValueError, ("connect requires at least two endpoints; %d provided." % (len (points),))
114         for i in range (1, len (points)):
115             self._connect (points[i-1], points[i])
116
117     def _connect (self, src_endpoint, dst_endpoint):
118         s = expand_src_endpoint (src_endpoint)
119         for d in expand_dst_endpoint (dst_endpoint):
120             self._connect_prim (s, d)
121
122     def _connect_prim (self, src_endpoint, dst_endpoint):
123         src_endpoint = coerce_endpoint (src_endpoint)
124         dst_endpoint = coerce_endpoint (dst_endpoint)
125         self._check_valid_src_port (src_endpoint)
126         self._check_valid_dst_port (dst_endpoint)
127         self._check_dst_in_use (dst_endpoint)
128         self._check_type_match (src_endpoint, dst_endpoint)
129         self.edge_list.append (edge (src_endpoint, dst_endpoint))
130
131     def disconnect (self, src_endpoint, dst_endpoint):
132         s = expand_src_endpoint (src_endpoint)
133         for d in expand_dst_endpoint (dst_endpoint):
134             self._disconnect_prim (s, d)
135
136     def _disconnect_prim (self, src_endpoint, dst_endpoint):
137         src_endpoint = coerce_endpoint (src_endpoint)
138         dst_endpoint = coerce_endpoint (dst_endpoint)
139         e = edge (src_endpoint, dst_endpoint)
140         self.edge_list.remove (e)
141
142     def disconnect_all (self):
143         self.edge_list = []
144         
145     def validate (self):
146         # check all blocks to ensure:
147         #  (1a) their input ports are contiguously assigned
148         #  (1b) the number of input ports is between min and max
149         #  (2a) their output ports are contiguously assigned
150         #  (2b) the number of output ports is between min and max
151         #  (3)  check_topology returns true
152
153         for m in self.all_blocks ():
154             # print m
155
156             edges = self.in_edges (m)
157             used_ports = [e.dst.port for e in edges]
158             ninputs = self._check_contiguity (m, m.input_signature (), used_ports, "input")
159
160             edges = self.out_edges (m)
161             used_ports = [e.src.port for e in edges]
162             noutputs = self._check_contiguity (m, m.output_signature (), used_ports, "output")
163
164             if not m.check_topology (ninputs, noutputs):
165                 raise ValueError, ("%s::check_topology (%d, %d) failed" % (m, ninputs, noutputs))
166             
167         
168     # --- public utilities ---
169                 
170     def all_blocks (self):
171         '''return list of all blocks in the graph'''
172         all_blocks = []
173         for edge in self.edge_list:
174             m = edge.src.block
175             if not m in all_blocks:
176                 all_blocks.append (m)
177             m = edge.dst.block
178             if not m in all_blocks:
179                 all_blocks.append (m)
180         return all_blocks
181         
182     def in_edges (self, m):
183         '''return list of all edges that have M as a destination'''
184         return [e for e in self.edge_list if e.dst.block == m]
185     
186     def out_edges (self, m):
187         '''return list of all edges that have M as a source'''
188         return [e for e in self.edge_list if e.src.block == m]
189     
190     def downstream_verticies (self, m):
191         return [e.dst.block for e in self.out_edges (m)]
192
193     def downstream_verticies_port (self, m, port):
194         return [e.dst.block for e in self.out_edges(m) if e.src.port == port]
195
196     def upstream_verticies (self, m):
197         return [e.src.block for e in self.in_edges (m)]
198
199     def adjacent_verticies (self, m):
200         '''return list of all verticies adjacent to M'''
201         return self.downstream_verticies (m) + self.upstream_verticies (m)
202
203     def sink_p (self, m):
204         '''return True iff this block is a sink'''
205         e = self.out_edges (m)
206         return len (e) == 0
207
208     def source_p (self, m):
209         '''return True iff this block is a source'''
210         e = self.in_edges (m)
211         return len (e) == 0
212         
213     # --- internal methods ---
214     
215     def _check_dst_in_use (self, dst_endpoint):
216         '''Ensure that there is not already an endpoint that terminates at dst_endpoint.'''
217         x = [ep for ep in self.edge_list if ep.dst == dst_endpoint]
218         if x:    # already in use
219             raise ValueError, ("destination endpoint already in use: %s" % (dst_endpoint))
220
221     def _check_valid_src_port (self, src_endpoint):
222         self._check_port (src_endpoint.block.output_signature(), src_endpoint.port)
223         
224     def _check_valid_dst_port (self, dst_endpoint):
225         self._check_port (dst_endpoint.block.input_signature(), dst_endpoint.port)
226         
227     def _check_port (self, signature, port):
228         if port < 0:
229             raise ValueError, 'port number out of range.'
230         if signature.max_streams () == -1: # infinite
231             return                         # OK
232         if port >= signature.max_streams ():
233             raise ValueError, 'port number out of range.'
234
235     def _check_type_match (self, src_endpoint, dst_endpoint):
236         # for now, we just ensure that the stream item sizes match
237         src_sig = src_endpoint.block.output_signature ()
238         dst_sig = dst_endpoint.block.input_signature ()
239         src_size = src_sig.sizeof_stream_item (src_endpoint.port)
240         dst_size = dst_sig.sizeof_stream_item (dst_endpoint.port)
241         if src_size != dst_size:
242             raise ValueError, (
243 ' '.join(('source and destination data sizes are different:',
244 src_endpoint.block.name(),
245 dst_endpoint.block.name())))
246
247     def _check_contiguity (self, m, sig, used_ports, dir):
248         used_ports.sort ()
249         used_ports = remove_duplicates (used_ports)
250         min_s = sig.min_streams ()
251
252         l = len (used_ports)
253         if l == 0:
254             if min_s == 0:
255                 return l
256             raise ValueError, ("%s requires %d %s connections.  It has none." %
257                                (m, min_s, dir))
258         
259         if used_ports[-1] + 1 < min_s:
260             raise ValueError, ("%s requires %d %s connections.  It has %d." %
261                                (m, min_s, dir, used_ports[-1] + 1))
262             
263         if used_ports[-1] + 1 != l:
264             for i in range (l):
265                 if used_ports[i] != i:
266                     raise ValueError, ("%s %s port %d is not connected" %
267                                        (m, dir, i))
268         
269         # print "%s ports: %s" % (dir, used_ports)
270         return l