2 * Portions Copyright 2001 Sun Microsystems, Inc.
3 * Portions Copyright 1999-2001 Language Technologies Institute,
4 * Carnegie Mellon University.
5 * All Rights Reserved. Use is subject to license terms.
7 * See the file "license.terms" for information on usage and
8 * redistribution of this file, and for a DISCLAIMER OF ALL
11 package com.sun.speech.freetts.cart;
13 import java.io.BufferedReader;
14 import java.io.DataInputStream;
15 import java.io.DataOutputStream;
16 import java.io.IOException;
17 import java.io.InputStreamReader;
19 import java.nio.ByteBuffer;
20 import java.util.StringTokenizer;
21 import java.util.logging.Level;
22 import java.util.logging.Logger;
23 import java.util.regex.Pattern;
25 import com.sun.speech.freetts.Item;
26 import com.sun.speech.freetts.PathExtractor;
27 import com.sun.speech.freetts.PathExtractorImpl;
28 import com.sun.speech.freetts.util.Utilities;
31 * Implementation of a Classification and Regression Tree (CART) that is
32 * used more like a binary decision tree, with each node containing a
33 * decision or a final value. The decision nodes in the CART trees
34 * operate on an Item and have the following format:
37 * NODE feat operand value qfalse
40 * <p>Where <code>feat</code> is an string that represents a feature
41 * to pass to the <code>findFeature</code> method of an item.
43 * <p>The <code>value</code> represents the value to be compared against
44 * the feature obtained from the item via the <code>feat</code> string.
45 * The <code>operand</code> is the operation to do the comparison. The
46 * available operands are as follows:
49 * <li>< - the feature is less than value
50 * <li>= - the feature is equal to the value
51 * <li>> - the feature is greater than the value
52 * <li>MATCHES - the feature matches the regular expression stored in value
53 * <li>IN - [[[TODO: still guessing because none of the CART's in
54 * Flite seem to use IN]]] the value is in the list defined by the
58 * <p>[[[TODO: provide support for the IN operator.]]]
60 * <p>For < and >, this CART coerces the value and feature to
61 * float's. For =, this CART coerces the value and feature to string and
62 * checks for string equality. For MATCHES, this CART uses the value as a
63 * regular expression and compares the obtained feature to that.
65 * <p>A CART is represented by an array in this implementation. The
66 * <code>qfalse</code> value represents the index of the array to go to if
67 * the comparison does not match. In this implementation, qtrue index
68 * is always implied, and represents the next element in the
69 * array. The root node of the CART is the first element in the array.
71 * <p>The interpretations always start at the root node of the CART
72 * and continue until a final node is found. The final nodes have the
79 * <p>Where <code>value</code> represents the value of the node.
80 * Reaching a final node indicates the interpretation is over and the
81 * value of the node is the interpretation result.
83 public class CARTImpl implements CART {
84 /** Logger instance. */
85 private static final Logger LOGGER =
86 Logger.getLogger(CARTImpl.class.getName());
88 * Entry in file represents the total number of nodes in the
89 * file. This should be at the top of the file. The format
90 * should be "TOTAL n" where n is an integer value.
92 final static String TOTAL = "TOTAL";
95 * Entry in file represents a node. The format should be
96 * "NODE feat op val f" where 'feat' represents a feature, op
97 * represents an operand, val is the value, and f is the index
98 * of the node to go to is there isn't a match.
100 final static String NODE = "NODE";
103 * Entry in file represents a final node. The format should be
104 * "LEAF val" where val represents the value.
106 final static String LEAF = "LEAF";
111 final static String OPERAND_MATCHES = "MATCHES";
115 * The CART. Entries can be DecisionNode or LeafNode. An
116 * ArrayList could be used here -- I chose not to because I
117 * thought it might be quicker to avoid dealing with the dynamic
123 * The number of nodes in the CART.
125 transient int curNode = 0;
128 * Creates a new CART by reading from the given URL.
130 * @param url the location of the CART data
132 * @throws IOException if errors occur while reading the data
134 public CARTImpl(URL url) throws IOException {
135 BufferedReader reader;
138 reader = new BufferedReader(new InputStreamReader(url.openStream()));
139 line = reader.readLine();
140 while (line != null) {
141 if (!line.startsWith("***")) {
144 line = reader.readLine();
150 * Creates a new CART by reading from the given reader.
152 * @param reader the source of the CART data
153 * @param nodes the number of nodes to read for this cart
155 * @throws IOException if errors occur while reading the data
157 public CARTImpl(BufferedReader reader, int nodes) throws IOException {
160 for (int i = 0; i < nodes; i++) {
161 line = reader.readLine();
162 if (!line.startsWith("***")) {
169 * Creates a new CART that will be populated with nodes later.
171 * @param numNodes the number of nodes
173 private CARTImpl(int numNodes) {
174 cart = new Node[numNodes];
178 * Dumps this CART to the output stream.
180 * @param os the output stream
182 * @throws IOException if an error occurs during output
184 public void dumpBinary(DataOutputStream os) throws IOException {
185 os.writeInt(cart.length);
186 for (int i = 0; i < cart.length; i++) {
187 cart[i].dumpBinary(os);
192 * Loads a CART from the input byte buffer.
194 * @param bb the byte buffer
198 * @throws IOException if an error occurs during output
200 * Note that cart nodes are really saved as strings that
203 public static CART loadBinary(ByteBuffer bb) throws IOException {
204 int numNodes = bb.getInt();
205 CARTImpl cart = new CARTImpl(numNodes);
207 for (int i = 0; i < numNodes; i++) {
208 String nodeCreationLine = Utilities.getString(bb);
209 cart.parseAndAdd(nodeCreationLine);
215 * Loads a CART from the input stream.
217 * @param is the input stream
221 * @throws IOException if an error occurs during output
223 * Note that cart nodes are really saved as strings that
226 public static CART loadBinary(DataInputStream is) throws IOException {
227 int numNodes = is.readInt();
228 CARTImpl cart = new CARTImpl(numNodes);
230 for (int i = 0; i < numNodes; i++) {
231 String nodeCreationLine = Utilities.getString(is);
232 cart.parseAndAdd(nodeCreationLine);
238 * Creates a node from the given input line and add it to the CART.
239 * It expects the TOTAL line to come before any of the nodes.
241 * @param line a line of input to parse
243 protected void parseAndAdd(String line) {
244 StringTokenizer tokenizer = new StringTokenizer(line," ");
245 String type = tokenizer.nextToken();
246 if (type.equals(LEAF) || type.equals(NODE)) {
247 cart[curNode] = getNode(type, tokenizer, curNode);
248 cart[curNode].setCreationLine(line);
250 } else if (type.equals(TOTAL)) {
251 cart = new Node[Integer.parseInt(tokenizer.nextToken())];
254 throw new Error("Invalid CART type: " + type);
259 * Gets the node based upon the type and tokenizer.
261 * @param type <code>NODE</code> or <code>LEAF</code>
262 * @param tokenizer the StringTokenizer containing the data to get
263 * @param currentNode the index of the current node we're looking at
267 protected Node getNode(String type,
268 StringTokenizer tokenizer,
270 if (type.equals(NODE)) {
271 String feature = tokenizer.nextToken();
272 String operand = tokenizer.nextToken();
273 Object value = parseValue(tokenizer.nextToken());
274 int qfalse = Integer.parseInt(tokenizer.nextToken());
275 if (operand.equals(OPERAND_MATCHES)) {
276 return new MatchingNode(feature,
281 return new ComparisonNode(feature,
287 } else if (type.equals(LEAF)) {
288 return new LeafNode(parseValue(tokenizer.nextToken()));
295 * Coerces a string into a value.
297 * @param string of the form "type(value)"; for example, "Float(2.3)"
301 protected Object parseValue(String string) {
302 int openParen = string.indexOf("(");
303 String type = string.substring(0,openParen);
304 String value = string.substring(openParen + 1, string.length() - 1);
305 if (type.equals("String")) {
307 } else if (type.equals("Float")) {
308 return new Float(Float.parseFloat(value));
309 } else if (type.equals("Integer")) {
310 return new Integer(Integer.parseInt(value));
311 } else if (type.equals("List")) {
312 StringTokenizer tok = new StringTokenizer(value, ",");
313 int size = tok.countTokens();
315 int[] values = new int[size];
316 for (int i = 0; i < size; i++) {
317 float fval = Float.parseFloat(tok.nextToken());
318 values[i] = Math.round(fval);
322 throw new Error("Unknown type: " + type);
327 * Passes the given item through this CART and returns the
330 * @param item the item to analyze
332 * @return the interpretation
334 public Object interpret(Item item) {
336 DecisionNode decision;
338 while (!(cart[nodeIndex] instanceof LeafNode)) {
339 decision = (DecisionNode) cart[nodeIndex];
340 nodeIndex = decision.getNextNode(item);
342 if (LOGGER.isLoggable(Level.FINER)) {
343 LOGGER.finer("LEAF " + cart[nodeIndex].getValue());
345 return ((LeafNode) cart[nodeIndex]).getValue();
349 * A node for the CART.
351 static abstract class Node {
353 * The value of this node.
355 protected Object value;
356 private String creationLine;
359 * Create a new Node with the given value.
361 public Node(Object value) {
368 public Object getValue() {
373 * Return a string representation of the type of the value.
375 public String getValueString() {
378 } else if (value instanceof String) {
379 return "String(" + value.toString() + ")";
380 } else if (value instanceof Float) {
381 return "Float(" + value.toString() + ")";
382 } else if (value instanceof Integer) {
383 return "Integer(" + value.toString() + ")";
385 return value.getClass().toString() + "(" + value.toString() + ")";
390 * sets the line of text used to create this node.
391 * @param line the creation line
393 public void setCreationLine(String line) {
398 * Dumps the binary form of this node.
399 * @param os the output stream to output the node on
400 * @throws IOException if an IO error occurs
402 final public void dumpBinary(DataOutputStream os) throws IOException {
403 Utilities.outString(os, creationLine);
408 * A decision node that determines the next Node to go to in the CART.
410 abstract static class DecisionNode extends Node {
412 * The feature used to find a value from an Item.
414 private PathExtractor path;
417 * Index of Node to go to if the comparison doesn't match.
419 protected int qfalse;
422 * Index of Node to go to if the comparison matches.
427 * The feature used to find a value from an Item.
429 public String getFeature() {
430 return path.toString();
435 * Find the feature associated with this DecisionNode
437 * @param item the item to start from
438 * @return the object representing the feature
440 public Object findFeature(Item item) {
441 return path.findFeature(item);
446 * Returns the next node based upon the
447 * descision determined at this node
448 * @param item the current item.
449 * @return the index of the next node
451 public final int getNextNode(Item item) {
452 return getNextNode(findFeature(item));
456 * Create a new DecisionNode.
457 * @param feature the string used to get a value from an Item
458 * @param value the value to compare to
459 * @param qtrue the Node index to go to if the comparison matches
460 * @param qfalse the Node machine index to go to upon no match
462 public DecisionNode(String feature,
467 this.path = new PathExtractorImpl(feature, true);
469 this.qfalse = qfalse;
473 * Get the next Node to go to in the CART. The return
474 * value is an index in the CART.
476 abstract public int getNextNode(Object val);
480 * A decision Node that compares two values.
482 static class ComparisonNode extends DecisionNode {
486 final static String LESS_THAN = "<";
491 final static String EQUALS = "=";
496 final static String GREATER_THAN = ">";
499 * The comparison type. One of LESS_THAN, GREATER_THAN, or
502 String comparisonType;
505 * Create a new ComparisonNode with the given values.
506 * @param feature the string used to get a value from an Item
507 * @param value the value to compare to
508 * @param comparisonType one of LESS_THAN, EQUAL_TO, or GREATER_THAN
509 * @param qtrue the Node index to go to if the comparison matches
510 * @param qfalse the Node index to go to upon no match
512 public ComparisonNode(String feature,
514 String comparisonType,
517 super(feature, value, qtrue, qfalse);
518 if (!comparisonType.equals(LESS_THAN)
519 && !comparisonType.equals(EQUALS)
520 && !comparisonType.equals(GREATER_THAN)) {
521 throw new Error("Invalid comparison type: " + comparisonType);
523 this.comparisonType = comparisonType;
528 * Compare the given value and return the appropriate Node index.
529 * IMPLEMENTATION NOTE: LESS_THAN and GREATER_THAN, the Node's
530 * value and the value passed in are converted to floating point
531 * values. For EQUAL, the Node's value and the value passed in
532 * are treated as String compares. This is the way of Flite, so
534 * @param val the value to compare
536 public int getNextNode(Object val) {
540 if (comparisonType.equals(LESS_THAN)
541 || comparisonType.equals(GREATER_THAN)) {
544 if (value instanceof Float) {
545 cart_fval = ((Float) value).floatValue();
547 cart_fval = Float.parseFloat(value.toString());
549 if (val instanceof Float) {
550 fval = ((Float) val).floatValue();
552 fval = Float.parseFloat(val.toString());
554 if (comparisonType.equals(LESS_THAN)) {
555 yes = (fval < cart_fval);
557 yes = (fval > cart_fval);
559 } else { // comparisonType = "="
560 String sval = val.toString();
561 String cart_sval = value.toString();
562 yes = sval.equals(cart_sval);
570 if (LOGGER.isLoggable(Level.FINER)) {
571 LOGGER.finer(trace(val, yes, ret));
577 private String trace(Object value, boolean match, int next) {
579 "NODE " + getFeature() + " ["
581 + comparisonType + " ["
583 + (match ? "Yes" : "No") + " next " +
588 * Get a string representation of this Node.
590 public String toString() {
592 "NODE " + getFeature() + " "
593 + comparisonType + " "
594 + getValueString() + " "
595 + Integer.toString(qtrue) + " "
596 + Integer.toString(qfalse);
601 * A Node that checks for a regular expression match.
603 static class MatchingNode extends DecisionNode {
607 * Create a new MatchingNode with the given values.
608 * @param feature the string used to get a value from an Item
609 * @param regex the regular expression
610 * @param qtrue the Node index to go to if the comparison matches
611 * @param qfalse the Node index to go to upon no match
613 public MatchingNode(String feature,
617 super(feature, regex, qtrue, qfalse);
618 this.pattern = Pattern.compile(regex);
622 * Compare the given value and return the appropriate CART index.
623 * @param val the value to compare -- this must be a String
625 public int getNextNode(Object val) {
626 return pattern.matcher((String) val).matches()
632 * Get a string representation of this Node.
634 public String toString() {
635 StringBuffer buf = new StringBuffer(
636 NODE + " " + getFeature() + " " + OPERAND_MATCHES);
637 buf.append(getValueString() + " ");
638 buf.append(Integer.toString(qtrue) + " ");
639 buf.append(Integer.toString(qfalse));
640 return buf.toString();
645 * The final Node of a CART. This just a marker class.
647 static class LeafNode extends Node {
649 * Create a new LeafNode with the given value.
650 * @param the value of this LeafNode
652 public LeafNode(Object value) {
657 * Get a string representation of this Node.
659 public String toString() {
660 return "LEAF " + getValueString();