upstream version 1.2.2
[debian/freetts] / com / sun / speech / freetts / cart / CARTImpl.java
1 /**
2  * Portions Copyright 2001 Sun Microsystems, Inc.
3  * Portions Copyright 1999-2001 Language Technologies Institute, 
4  * Carnegie Mellon University.
5  * All Rights Reserved.  Use is subject to license terms.
6  * 
7  * See the file "license.terms" for information on usage and
8  * redistribution of this file, and for a DISCLAIMER OF ALL 
9  * WARRANTIES.
10  */
11 package com.sun.speech.freetts.cart;
12
13 import java.io.BufferedReader;
14 import java.io.DataInputStream;
15 import java.io.DataOutputStream;
16 import java.io.IOException;
17 import java.io.InputStreamReader;
18 import java.net.URL;
19 import java.nio.ByteBuffer;
20 import java.util.StringTokenizer;
21 import java.util.logging.Level;
22 import java.util.logging.Logger;
23 import java.util.regex.Pattern;
24
25 import com.sun.speech.freetts.Item;
26 import com.sun.speech.freetts.PathExtractor;
27 import com.sun.speech.freetts.PathExtractorImpl;
28 import com.sun.speech.freetts.util.Utilities;
29
30 /**
31  * Implementation of a Classification and Regression Tree (CART) that is
32  * used more like a binary decision tree, with each node containing a
33  * decision or a final value.  The decision nodes in the CART trees
34  * operate on an Item and have the following format:
35  *
36  * <pre>
37  *   NODE feat operand value qfalse 
38  * </pre>
39  *
40  * <p>Where <code>feat</code> is an string that represents a feature
41  * to pass to the <code>findFeature</code> method of an item.
42  *
43  * <p>The <code>value</code> represents the value to be compared against
44  * the feature obtained from the item via the <code>feat</code> string.
45  * The <code>operand</code> is the operation to do the comparison.  The
46  * available operands are as follows:
47  *
48  * <ul>
49  *   <li>&lt; - the feature is less than value 
50  *   <li>= - the feature is equal to the value 
51  *   <li>> - the feature is greater than the value 
52  *   <li>MATCHES - the feature matches the regular expression stored in value 
53  *   <li>IN - [[[TODO: still guessing because none of the CART's in
54  *     Flite seem to use IN]]] the value is in the list defined by the
55  *     feature.
56  * </ul>
57  *
58  * <p>[[[TODO: provide support for the IN operator.]]]
59  *
60  * <p>For &lt; and >, this CART coerces the value and feature to
61  * float's. For =, this CART coerces the value and feature to string and
62  * checks for string equality. For MATCHES, this CART uses the value as a
63  * regular expression and compares the obtained feature to that.
64  *
65  * <p>A CART is represented by an array in this implementation. The
66  * <code>qfalse</code> value represents the index of the array to go to if
67  * the comparison does not match. In this implementation, qtrue index
68  * is always implied, and represents the next element in the
69  * array. The root node of the CART is the first element in the array.
70  *
71  * <p>The interpretations always start at the root node of the CART
72  * and continue until a final node is found.  The final nodes have the
73  * following form:
74  *
75  * <pre>
76  *   LEAF value
77  * </pre>
78  *
79  * <p>Where <code>value</code> represents the value of the node.
80  * Reaching a final node indicates the interpretation is over and the
81  * value of the node is the interpretation result.
82  */
83 public class CARTImpl implements CART {
84     /** Logger instance. */
85     private static final Logger LOGGER =
86         Logger.getLogger(CARTImpl.class.getName());
87     /**
88      * Entry in file represents the total number of nodes in the
89      * file.  This should be at the top of the file.  The format
90      * should be "TOTAL n" where n is an integer value.
91      */
92     final static String TOTAL = "TOTAL";
93
94     /**
95      * Entry in file represents a node.  The format should be
96      * "NODE feat op val f" where 'feat' represents a feature, op
97      * represents an operand, val is the value, and f is the index
98      * of the node to go to is there isn't a match.
99      */
100     final static String NODE = "NODE";
101
102     /**
103      * Entry in file represents a final node.  The format should be
104      * "LEAF val" where val represents the value.
105      */
106     final static String LEAF = "LEAF";
107
108     /**
109      * OPERAND_MATCHES
110      */
111     final static String OPERAND_MATCHES = "MATCHES";
112
113
114     /**
115      * The CART. Entries can be DecisionNode or LeafNode.  An
116      * ArrayList could be used here -- I chose not to because I
117      * thought it might be quicker to avoid dealing with the dynamic
118      * resizing.
119      */
120     Node[] cart = null;
121
122     /**
123      * The number of nodes in the CART.
124      */
125     transient int curNode = 0;
126
127     /**
128      * Creates a new CART by reading from the given URL.
129      *
130      * @param url the location of the CART data
131      *
132      * @throws IOException if errors occur while reading the data
133      */ 
134     public CARTImpl(URL url) throws IOException {
135         BufferedReader reader;
136         String line;
137
138         reader = new BufferedReader(new InputStreamReader(url.openStream()));
139         line = reader.readLine();
140         while (line != null) {
141             if (!line.startsWith("***")) {
142                 parseAndAdd(line);
143             }
144             line = reader.readLine();
145         }
146         reader.close();
147     }
148
149     /**
150      * Creates a new CART by reading from the given reader.
151      *
152      * @param reader the source of the CART data
153      * @param nodes the number of nodes to read for this cart
154      *
155      * @throws IOException if errors occur while reading the data
156      */ 
157     public CARTImpl(BufferedReader reader, int nodes) throws IOException {
158         this(nodes);
159         String line;
160         for (int i = 0; i < nodes; i++) {
161             line = reader.readLine();
162             if (!line.startsWith("***")) {
163                 parseAndAdd(line);
164             }
165         }
166     }
167     
168     /**
169      * Creates a new CART that will be populated with nodes later.
170      *
171      * @param numNodes the number of nodes
172      */
173     private CARTImpl(int numNodes) {
174         cart = new Node[numNodes];
175     }
176
177     /**
178      * Dumps this CART to the output stream.
179      *
180      * @param os the output stream
181      *
182      * @throws IOException if an error occurs during output
183      */
184     public void dumpBinary(DataOutputStream os) throws IOException {
185         os.writeInt(cart.length);
186         for (int i = 0; i < cart.length; i++) {
187             cart[i].dumpBinary(os);
188         }
189     }
190
191     /**
192      * Loads a CART from the input byte buffer.
193      *
194      * @param bb the byte buffer
195      *
196      * @return the CART
197      *
198      * @throws IOException if an error occurs during output
199      *
200      * Note that cart nodes are really saved as strings that
201      * have to be parsed.
202      */
203     public static CART loadBinary(ByteBuffer bb) throws IOException {
204         int numNodes = bb.getInt();
205         CARTImpl cart = new CARTImpl(numNodes);
206
207         for (int i = 0; i < numNodes; i++) {
208             String nodeCreationLine = Utilities.getString(bb);
209             cart.parseAndAdd(nodeCreationLine);
210         }
211         return cart;
212     }
213
214     /**
215      * Loads a CART from the input stream.
216      *
217      * @param is the input stream
218      *
219      * @return the CART
220      *
221      * @throws IOException if an error occurs during output
222      *
223      * Note that cart nodes are really saved as strings that
224      * have to be parsed.
225      */
226     public static CART loadBinary(DataInputStream is) throws IOException {
227         int numNodes = is.readInt();
228         CARTImpl cart = new CARTImpl(numNodes);
229
230         for (int i = 0; i < numNodes; i++) {
231             String nodeCreationLine = Utilities.getString(is);
232             cart.parseAndAdd(nodeCreationLine);
233         }
234         return cart;
235     }
236     
237     /**
238      * Creates a node from the given input line and add it to the CART.
239      * It expects the TOTAL line to come before any of the nodes.
240      *
241      * @param line a line of input to parse
242      */
243     protected void parseAndAdd(String line) {
244         StringTokenizer tokenizer = new StringTokenizer(line," ");
245         String type = tokenizer.nextToken();        
246         if (type.equals(LEAF) || type.equals(NODE)) {
247             cart[curNode] = getNode(type, tokenizer, curNode);
248             cart[curNode].setCreationLine(line);
249             curNode++;
250         } else if (type.equals(TOTAL)) {
251             cart = new Node[Integer.parseInt(tokenizer.nextToken())];
252             curNode = 0;
253         } else {
254             throw new Error("Invalid CART type: " + type);
255         }
256     }
257
258     /**
259      * Gets the node based upon the type and tokenizer.
260      *
261      * @param type <code>NODE</code> or <code>LEAF</code>
262      * @param tokenizer the StringTokenizer containing the data to get
263      * @param currentNode the index of the current node we're looking at
264      *
265      * @return the node
266      */
267     protected Node getNode(String type,
268                            StringTokenizer tokenizer,
269                            int currentNode) {
270         if (type.equals(NODE)) {
271             String feature = tokenizer.nextToken();
272             String operand = tokenizer.nextToken();
273             Object value = parseValue(tokenizer.nextToken());
274             int qfalse = Integer.parseInt(tokenizer.nextToken());
275             if (operand.equals(OPERAND_MATCHES)) {
276                 return new MatchingNode(feature,
277                                         value.toString(),
278                                         currentNode + 1,
279                                         qfalse);
280             } else {
281                 return new ComparisonNode(feature,
282                                           value,
283                                           operand,
284                                           currentNode + 1,
285                                           qfalse);
286             }
287         } else if (type.equals(LEAF)) {
288             return new LeafNode(parseValue(tokenizer.nextToken()));
289         }
290
291         return null;
292     }
293
294     /**
295      * Coerces a string into a value.
296      *
297      * @param string of the form "type(value)"; for example, "Float(2.3)"
298      *
299      * @return the value
300      */
301     protected Object parseValue(String string) {
302         int openParen = string.indexOf("(");
303         String type = string.substring(0,openParen);
304         String value = string.substring(openParen + 1, string.length() - 1);
305         if (type.equals("String")) {
306             return value;
307         } else if (type.equals("Float")) {
308             return new Float(Float.parseFloat(value));
309         } else if (type.equals("Integer")) {
310             return new Integer(Integer.parseInt(value));
311         } else if (type.equals("List")) {
312             StringTokenizer tok = new StringTokenizer(value, ",");
313             int size = tok.countTokens();
314
315             int[] values = new int[size];
316             for (int i = 0; i < size; i++) {
317                 float fval = Float.parseFloat(tok.nextToken());
318                 values[i] = Math.round(fval);
319             }
320             return values;
321         } else {
322             throw new Error("Unknown type: " + type);
323         }
324     }
325     
326     /**
327      * Passes the given item through this CART and returns the
328      * interpretation.
329      *
330      * @param item the item to analyze
331      *
332      * @return the interpretation
333      */
334     public Object interpret(Item item) {
335         int nodeIndex = 0;
336         DecisionNode decision;
337
338         while (!(cart[nodeIndex] instanceof LeafNode)) {
339             decision = (DecisionNode) cart[nodeIndex];
340             nodeIndex = decision.getNextNode(item);
341         }
342         if (LOGGER.isLoggable(Level.FINER)) {
343             LOGGER.finer("LEAF " + cart[nodeIndex].getValue());
344         }
345         return ((LeafNode) cart[nodeIndex]).getValue();
346     }
347
348     /**
349      * A node for the CART.
350      */
351     static abstract class Node {
352         /**
353          * The value of this node.
354          */
355         protected Object value;
356         private String creationLine;
357
358         /**
359          * Create a new Node with the given value.
360          */
361         public Node(Object value) {
362             this.value = value;
363         }
364     
365         /**
366          * Get the value.
367          */
368         public Object getValue() {
369             return value;
370         }
371
372         /**
373          * Return a string representation of the type of the value.
374          */
375         public String getValueString() {
376             if (value == null) {
377                 return "NULL()";
378             } else if (value instanceof String) {
379                 return "String(" + value.toString() + ")";
380             } else if (value instanceof Float) {
381                 return "Float(" + value.toString() + ")";
382             } else if (value instanceof Integer) {
383                 return "Integer(" + value.toString() + ")";
384             } else {
385                 return value.getClass().toString() + "(" + value.toString() + ")";
386             }
387         }    
388
389         /**
390          * sets the line of text used to create this node.
391          * @param line the creation line
392          */
393         public void setCreationLine(String line) {
394             creationLine = line;
395         }
396
397         /**
398          * Dumps the binary form of this node.
399          * @param os the output stream to output the node on
400          * @throws IOException if an IO error occurs
401          */
402         final public void dumpBinary(DataOutputStream os) throws IOException {
403             Utilities.outString(os, creationLine);
404         }
405     }
406
407     /**
408      * A decision node that determines the next Node to go to in the CART.
409      */
410     abstract static class DecisionNode extends Node {
411         /**
412          * The feature used to find a value from an Item.
413          */
414         private PathExtractor path;
415
416         /**
417          * Index of Node to go to if the comparison doesn't match.
418          */
419         protected int qfalse;
420
421         /**
422          * Index of Node to go to if the comparison matches.
423          */
424         protected int qtrue;
425
426         /**
427          * The feature used to find a value from an Item.
428          */
429         public String getFeature() {
430             return path.toString();
431         }
432
433
434         /**
435          * Find the feature associated with this DecisionNode
436          * and the given item
437          * @param item the item to start from
438          * @return the object representing the feature
439          */
440         public Object findFeature(Item item) {
441             return path.findFeature(item);
442         }
443
444
445         /**
446          * Returns the next node based upon the
447          * descision determined at this node
448          * @param item the current item.
449          * @return the index of the next node
450          */
451         public final int getNextNode(Item item) {
452             return getNextNode(findFeature(item));
453         }
454
455         /**
456          * Create a new DecisionNode.
457          * @param feature the string used to get a value from an Item
458          * @param value the value to compare to
459          * @param qtrue the Node index to go to if the comparison matches
460          * @param qfalse the Node machine index to go to upon no match
461          */
462         public DecisionNode(String feature,
463                             Object value,
464                             int qtrue,
465                             int qfalse) {
466             super(value);
467             this.path = new PathExtractorImpl(feature, true);
468             this.qtrue = qtrue;
469             this.qfalse = qfalse;
470         }
471     
472         /**
473          * Get the next Node to go to in the CART.  The return
474          * value is an index in the CART.
475          */
476         abstract public int getNextNode(Object val);
477     }
478
479     /**
480      * A decision Node that compares two values.
481      */
482     static class ComparisonNode extends DecisionNode {
483         /**
484          * LESS_THAN
485          */
486         final static String LESS_THAN = "<";
487     
488         /**
489          * EQUALS
490          */
491         final static String EQUALS = "=";
492     
493         /**
494          * GREATER_THAN
495          */
496         final static String GREATER_THAN = ">";
497     
498         /**
499          * The comparison type.  One of LESS_THAN, GREATER_THAN, or
500          *  EQUAL_TO.
501          */
502         String comparisonType;
503
504         /**
505          * Create a new ComparisonNode with the given values.
506          * @param feature the string used to get a value from an Item
507          * @param value the value to compare to
508          * @param comparisonType one of LESS_THAN, EQUAL_TO, or GREATER_THAN
509          * @param qtrue the Node index to go to if the comparison matches
510          * @param qfalse the Node index to go to upon no match
511          */
512         public ComparisonNode(String feature,
513                               Object value,
514                               String comparisonType,
515                               int qtrue,
516                               int qfalse) {
517             super(feature, value, qtrue, qfalse);
518             if (!comparisonType.equals(LESS_THAN)
519                 && !comparisonType.equals(EQUALS)
520                 && !comparisonType.equals(GREATER_THAN)) {
521                 throw new Error("Invalid comparison type: " + comparisonType);
522             } else {
523                 this.comparisonType = comparisonType;
524             }
525         }
526
527         /**
528          * Compare the given value and return the appropriate Node index.
529          * IMPLEMENTATION NOTE:  LESS_THAN and GREATER_THAN, the Node's
530          * value and the value passed in are converted to floating point
531          * values.  For EQUAL, the Node's value and the value passed in
532          * are treated as String compares.  This is the way of Flite, so
533          * be it Flite.
534          * @param val the value to compare
535          */
536         public int getNextNode(Object val) {
537             boolean yes = false;
538             int ret;
539
540             if (comparisonType.equals(LESS_THAN)
541                 || comparisonType.equals(GREATER_THAN)) {
542                 float cart_fval;
543                 float fval;
544                 if (value instanceof Float) {
545                     cart_fval = ((Float) value).floatValue();
546                 } else {
547                     cart_fval = Float.parseFloat(value.toString());
548                 }
549                 if (val instanceof Float) {
550                     fval = ((Float) val).floatValue();
551                 } else {
552                     fval = Float.parseFloat(val.toString());
553                 }
554                 if (comparisonType.equals(LESS_THAN)) {
555                     yes = (fval < cart_fval);
556                 } else {
557                     yes =  (fval > cart_fval);
558                 }
559             } else { // comparisonType = "="
560                 String sval = val.toString();
561                 String cart_sval = value.toString();
562                 yes = sval.equals(cart_sval);
563             }
564             if (yes) {
565                 ret = qtrue;
566             } else {
567                 ret = qfalse;
568             }
569
570             if (LOGGER.isLoggable(Level.FINER)) {
571                 LOGGER.finer(trace(val, yes, ret));
572             }
573
574             return ret;
575         }
576
577         private String trace(Object value, boolean match, int next) {
578             return
579                 "NODE " + getFeature() + " ["
580                 + value + "] " 
581                 + comparisonType + " [" 
582                 + getValue() + "] "
583                 + (match ? "Yes" : "No") + " next " +
584                     next;
585         }
586
587         /**
588          * Get a string representation of this Node.
589          */
590         public String toString() {
591             return
592                 "NODE " + getFeature() + " "
593                 + comparisonType + " "
594                 + getValueString() + " "
595                 + Integer.toString(qtrue) + " "
596                 + Integer.toString(qfalse);
597         }
598     }
599
600     /**
601      * A Node that checks for a regular expression match.
602      */
603     static class MatchingNode extends DecisionNode {
604         Pattern pattern;
605     
606         /**
607          * Create a new MatchingNode with the given values.
608          * @param feature the string used to get a value from an Item
609          * @param regex the regular expression
610          * @param qtrue the Node index to go to if the comparison matches
611          * @param qfalse the Node index to go to upon no match
612          */
613         public MatchingNode(String feature,
614                             String regex,
615                             int qtrue,
616                             int qfalse) {
617             super(feature, regex, qtrue, qfalse);
618             this.pattern = Pattern.compile(regex);
619         }
620
621         /**
622          * Compare the given value and return the appropriate CART index.
623          * @param val the value to compare -- this must be a String
624          */
625         public int getNextNode(Object val) {
626             return pattern.matcher((String) val).matches()
627                 ? qtrue
628                 : qfalse;
629         }
630
631         /**
632          * Get a string representation of this Node.
633          */
634         public String toString() {
635             StringBuffer buf = new StringBuffer(
636                 NODE + " " + getFeature() + " " + OPERAND_MATCHES);
637             buf.append(getValueString() + " ");
638             buf.append(Integer.toString(qtrue) + " ");
639             buf.append(Integer.toString(qfalse));
640             return buf.toString();
641         }
642     }
643
644     /**
645      * The final Node of a CART.  This just a marker class.
646      */
647     static class LeafNode extends Node {
648         /**
649          * Create a new LeafNode with the given value.
650          * @param the value of this LeafNode
651          */
652         public LeafNode(Object value) {
653             super(value);
654         }
655
656         /**
657          * Get a string representation of this Node.
658          */
659         public String toString() {
660             return "LEAF " + getValueString();
661         }
662     }
663 }
664