I am using scikit-learn's DecisionTreeClassifier on a 3-class dataset. After fitting the classifier, I go through all the leaf nodes in its tree_ attribute to get the number of instances of each class that fall into each leaf.
from sklearn import tree
clf = tree.DecisionTreeClassifier(max_depth=5)
clf.fit(X, y)
For one of the leaves, this prints:
>>> array([[ 0., 1., 68.]])
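The loop that walks the leaves isn't shown above. The following is a minimal sketch of one way to do it. It assumes the Iris dataset as a stand-in for the 3-class data, and `random_state=0` is an added choice. Leaves are detected through `tree_.children_left == -1`. Note that newer scikit-learn releases may store per-class fractions rather than raw counts in `tree_.value`.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Iris stands in for the 3-class dataset (an assumption, not the asker's data)
X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=5, random_state=0).fit(X, y)

tree_ = clf.tree_
# A node is a leaf when it has no children (children_left == -1)
leaf_ids = np.flatnonzero(tree_.children_left == -1)

for leaf_id in leaf_ids:
    # tree_.value has shape (node_count, n_outputs, n_classes);
    # indexing one node keeps the (n_outputs, n_classes) part
    print(leaf_id, tree_.value[leaf_id])
```

Indexing with a single node id keeps the `n_outputs` axis. That is why each leaf prints as a nested array like `[[0., 1., 68.]]`.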
But how do I know which position in this array belongs to which class? The classifier has a classes_ attribute, which is also an array:
>>> clf.classes_
array(['CLASS_1', 'CLASS_2', 'CLASS_3'], dtype=object)
Does index i in a leaf's value array correspond to the class at index i in classes_ (so index 1 would be CLASS_2, and so on)?
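One way to check the mapping empirically is sketched below. It uses hypothetical toy data with string labels deliberately listed out of sorted order. It recounts the training labels that land in each leaf via `clf.apply(X)` and compares them column by column against `tree_.value`, ordered as in `clf.classes_`. Both sides are normalized first, because depending on the scikit-learn version `tree_.value` may hold counts or fractions.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Hypothetical toy data: string labels listed in non-sorted order on purpose
rng = np.random.default_rng(0)
X = rng.normal(size=(150, 2))
y = np.array(["CLASS_3", "CLASS_1", "CLASS_2"] * 50)
X[y == "CLASS_1", 0] += 3.0  # shift classes apart so the tree has something to learn
X[y == "CLASS_3", 1] += 3.0

clf = DecisionTreeClassifier(max_depth=5, random_state=0).fit(X, y)
tree_ = clf.tree_
leaf_ids = np.flatnonzero(tree_.children_left == -1)
sample_leaves = clf.apply(X)  # leaf node id reached by each training sample


def class_fractions(leaf_id):
    """Fraction of training samples in this leaf per class, ordered like clf.classes_."""
    labels = y[sample_leaves == leaf_id]
    counts = np.array([np.sum(labels == c) for c in clf.classes_], dtype=float)
    return counts / counts.sum()


def normalized_value(leaf_id):
    """This leaf's row of tree_.value scaled to sum to 1 (works for counts or fractions)."""
    v = tree_.value[leaf_id, 0, :]
    return v / v.sum()


print(clf.classes_)  # sorted: ['CLASS_1' 'CLASS_2' 'CLASS_3']
for leaf_id in leaf_ids:
    # If column j of tree_.value is clf.classes_[j], the two distributions agree
    assert np.allclose(normalized_value(leaf_id), class_fractions(leaf_id))
    print(leaf_id, class_fractions(leaf_id).round(2))
```

`classes_` comes out sorted even though the labels were given as CLASS_3, CLASS_1, CLASS_2. If the assertions pass, that indicates the columns of `tree_.value` follow the order of `clf.classes_` rather than the order in which labels first appear in `y`.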