//::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
/** @author  John Miller, Susan George
 *  @version 1.6
 *  @date    Wed May 22 14:17:49 EDT 2019
 *  @see     LICENSE (MIT style license file).
 *
 *  @title   Model: ID3 Decision/Classification Tree with Pruning
 */

package scalation.analytics
package classifier

import scala.collection.mutable.Set

import scalation.linalgebra.{MatriI, MatrixI, VectoI}
import scalation.random.RandomVecI
import scalation.util.banner

import DecisionTree.hp

//::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
/** The `DecisionTreeID3wp` class extends `DecisionTreeID3` with pruning capabilities.
 *  The base class uses the ID3 algorithm to construct a decision tree for classifying
 *  instance vectors.
 *  @param x       the input/data matrix with instances stored in rows
 *  @param y       the response/classification vector, where y_i = class for row i of matrix x
 *  @param fn_     the names for all features/variables
 *  @param k       the number of classes
 *  @param cn_     the names for all classes
 *  @param hparam  the hyper-parameters for the decision tree
 */
class DecisionTreeID3wp (x: MatriI, y: VectoI, fn_ : Strings = null, k: Int = 2,
                         cn_ : Strings = null, hparam: HyperParameter = hp)
      extends DecisionTreeID3 (x, y, fn_, k, cn_, hparam)
{
    private val DEBUG = true                                         // debug flag

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Find candidate nodes that may be pruned, i.e., those that are parents
     *  of leaf nodes, restricted to those that don't have any children that
     *  are themselves internal nodes.
     */
    def candidates: Set [Node] =
    {
        val can = Set [Node] ()
        for (n <- leaves) {
            val p = n.parent                                         // parent of a leaf
            if (leafChildren (p)) can += p                           // prunable only if ALL children are leaves
        } // for
        can
    } // candidates

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Determine whether all the children of node 'n' are leaf nodes.
     *  @param n  the node in question
     */
    def leafChildren (n: Node): Boolean = n.branch.values.forall (_.isLeaf)

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Of all the pruning candidates, find the one with the least gain.
     *  Returns (null, Double.MaxValue) when 'can' is empty.
     *  @param can  the nodes that are candidates for pruning
     */
    def bestCandidate (can: Set [Node]): (Node, Double) =
    {
        var min = Double.MaxValue
        var best: Node = null
        for (n <- can if n.gn < min) { min = n.gn; best = n }
        (best, min)
    } // bestCandidate

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Prune 'nPrune' nodes from the tree, the ones providing the least gain.
     *  @param nPrune     the number of nodes to be pruned
     *  @param threshold  cut-off for pruning (IG < threshold, then prune)
     */
    def prune (nPrune: Int = 1, threshold: Double = 0.98): Unit =
    {
        for (i <- 0 until nPrune) {
            val can = candidates
            if (DEBUG) println (s"can = $can")
            val (best, gn) = bestCandidate (can)
            if (best == null) return                                 // no candidates left - nothing to prune
            println (s"prune: node $best with gain $gn identified as bestCandidate")
            if (gn < threshold) {
                println (s"prune: make node $best with gain $gn into a leaf")
                makeLeaf (best)
            } // if
        } // for
    } // prune

} // DecisionTreeID3wp class

//::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
/** The `DecisionTreeID3wp` companion object provides a factory function.
 */
object DecisionTreeID3wp
{
    import ClassifierInt.pullResponse

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Create a decision tree for the given combined matrix where the last column
     *  is the response/classification vector.
     *  @param xy      the combined data matrix (features and response)
     *  @param fn      the names for all features/variables
     *  @param k       the number of classes
     *  @param cn      the names for all classes
     *  @param hparam  the hyper-parameters for the decision tree
     */
    def apply (xy: MatriI, fn: Strings, k: Int, cn: Strings,
               hparam: HyperParameter = hp): DecisionTreeID3wp =
    {
        val (x, y) = pullResponse (xy)                               // split off the last column as the response
        new DecisionTreeID3wp (x, y, fn, k, cn, hparam)
    } // apply

} // DecisionTreeID3wp object

//::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
/** The `DecisionTreeID3wpTest` object is used to test the `DecisionTreeID3wp` class.
 *  > runMain scalation.analytics.classifier.DecisionTreeID3wpTest
 */
object DecisionTreeID3wpTest extends App
{
    import ExampleTennis.{xy, fn, k, cn}

    val tree = DecisionTreeID3wp (xy, fn, k, cn)
    tree.train ()
    banner ("Original Tree: entropy = " + tree.calcEntropy ())
    tree.printTree ()

    tree.prune (2)
    banner ("Pruned Tree: entropy = " + tree.calcEntropy ())
    tree.printTree ()

} // DecisionTreeID3wpTest object

//::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
/** The `DecisionTreeID3wpTest2` object is used to test the `DecisionTreeID3wp` class.
 *  > runMain scalation.analytics.classifier.DecisionTreeID3wpTest2
 */
object DecisionTreeID3wpTest2 extends App
{
    import ClassifierInt.pullResponse

    val fname = BASE_DIR + "breast_cancer.csv"
    val xy    = MatrixI (fname)
    val fn    = Array ("Clump Thickness", "Uniformity of Cell Size", "Uniformity of Cell Shape",
                       "Marginal Adhesion", "Single Epithelial Cell Size", "Bare Nuclei",
                       "Bland Chromatin", "Normal Nucleoli", "Mitoses")
    val cn    = Array ("benign", "malignant")
    val k     = cn.size

    banner ("create, train and print an ID3 decision tree")
    println (s"dataset xy: ${xy.dim1}-by-${xy.dim2} matrix")
    val (x, y) = pullResponse (xy)
    val ymin   = y.min ()
    println (s"unadjusted ymin = $ymin")
    if (ymin != 0) y -= ymin                                         // shift class labels so they start at 0

    val tree = new DecisionTreeID3wp (x, y, fn, k, cn)
    tree.train ()
    tree.printTree ()

    tree.prune ()
    tree.printTree ()

} // DecisionTreeID3wpTest2 object

//::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
/** The `DecisionTreeID3wpTest3` object is used to test the `DecisionTreeID3wp` class.
 *  > runMain scalation.analytics.classifier.DecisionTreeID3wpTest3
 */
object DecisionTreeID3wpTest3 extends App
{
    val fname = BASE_DIR + "breast_cancer.csv"
    val xy    = MatrixI (fname)
    val fn    = Array ("Clump Thickness", "Uniformity of Cell Size", "Uniformity of Cell Shape",
                       "Marginal Adhesion", "Single Epithelial Cell Size", "Bare Nuclei",
                       "Bland Chromatin", "Normal Nucleoli", "Mitoses")
    val cn    = Array ("benign", "malignant")
    val k     = cn.size

    val (x, y) = ClassifierInt.pullResponse (xy)
    val ymin   = y.min ()
    println (s"unadjusted ymin = $ymin")
    if (ymin != 0) y -= ymin                                         // shift class labels so they start at 0

    // Divide samples into training and testing datasets (70/30 split, fixed stream for repeatability)
    val trainSize = (y.dim * 0.7).toInt
    val rvv       = RandomVecI (min = 0, max = y.dim-1, dim = trainSize, unique = true, stream = 223)
    val trainData = new MatrixI (trainSize, xy.dim2)
    val testData  = new MatrixI (xy.dim1-trainSize, xy.dim2)
    val index     = rvv.igen
    var trainCount = 0
    var testCount  = 0
    for (i <- y.range) {
        if (index contains i) {
            trainData.set (trainCount, xy(i))
            trainCount += 1
        } else {
            testData.set (testCount, xy(i))
            testCount += 1
        } // if
    } // for

    val testFeature = testData.selectCols (Range (0, testData.dim2-1).toArray)
    val testTarget  = testData.col (testData.dim2-1)

    val tree = new DecisionTreeID3wp (trainData.selectCols (Range (0, xy.dim2 - 1).toArray),
                                      trainData.col (trainData.dim2 - 1).toInt, fn, k, cn)
    tree.train ()

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Compute classification accuracy of 'tree' on the held-out test set.
     *  Each row is classified once (the original classified every row twice).
     */
    def testAccuracy (): Double =
    {
        var accurateCount = 0.0
        for (i <- 0 until testFeature.dim1) {
            val d = tree.classify (testFeature(i))._1
            if (d == testTarget(i)) accurateCount += 1
        } // for
        accurateCount / testFeature.dim1
    } // testAccuracy

    // Print the accuracy for unseen data, before and after pruning
    println (s"Testing Accuracy = ${testAccuracy ()}")
    tree.prune (5)
    println (s"Testing Accuracy = ${testAccuracy ()}")

} // DecisionTreeID3wpTest3 object