Testing Rules Created by rpart

I want to programmatically test one rule generated from a tree. In trees, the path between the root and the leaf (terminal node) can be interpreted as a rule.

In R, we can use the rpart package as follows (in this post I use the iris dataset purely as an example):

    library(rpart)
    model <- rpart(Species ~ ., data = iris)

With these two lines I get a tree named model, an object of class rpart (described under rpart.object in the rpart documentation, p. 21). This object holds a lot of information and supports many methods. In particular, it has a frame component, accessible in the usual way as model$frame (same documentation), and the package provides the path.rpart function (rpart documentation, p. 7), which returns the path from the root to the node of interest (the nodes argument of the function).

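The row names of model$frame are the node numbers of the tree. The var column gives the splitting variable at each node (or <leaf> for a terminal node), yval gives the fitted value (for classification, the index of the predicted class), and yval2 holds the class counts, class probabilities, and other information.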

    > model$frame
               var   n  wt dev yval complexity ncompete nsurrogate    yval2.1     yval2.2     yval2.3     yval2.4    yval2.5    yval2.6    yval2.7
    1 Petal.Length 150 150 100    1       0.50        3          3 1.00000000 50.00000000 50.00000000 50.00000000 0.33333333 0.33333333 0.33333333
    2       <leaf>  50  50   0    1       0.01        0          0 1.00000000 50.00000000  0.00000000  0.00000000 1.00000000 0.00000000 0.00000000
    3  Petal.Width 100 100  50    2       0.44        3          3 2.00000000  0.00000000 50.00000000 50.00000000 0.00000000 0.50000000 0.50000000
    6       <leaf>  54  54   5    2       0.00        0          0 2.00000000  0.00000000 49.00000000  5.00000000 0.00000000 0.90740741 0.09259259
    7       <leaf>  46  46   1    3       0.01        0          0 3.00000000  0.00000000  1.00000000 45.00000000 0.00000000 0.02173913 0.97826087

Only the nodes marked <leaf> in the var column are terminal nodes (leaves); in this case, nodes 2, 6, and 7.

As mentioned above, the path.rpart function can be used to extract the rule (it is used this way in the rattle package and in the Sharma Credit Score article).

In addition, the model stores the levels (class labels) of the predicted variable, which can be retrieved with

 predicted.levels <- attr(model, "ylevels") 

The yval column of model$frame is an index into this vector.

For the leaf with node number 7 (row 5 of model$frame), the predicted value is

    > predicted.levels[model$frame[5, ]$yval]
    [1] "virginica"
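The same lookup can be done for every leaf at once. This is a minimal sketch that combines the var, yval, and n columns of model$frame with the ylevels attribute; the data frame name leaves is only illustrative.

```r
library(rpart)

# Refit the tree from the question
model <- rpart(Species ~ ., data = iris)
predicted.levels <- attr(model, "ylevels")

# Leaves are the rows of model$frame whose var column is "<leaf>"
is_leaf <- model$frame$var == "<leaf>"

# Node numbers come from the row names; yval indexes into predicted.levels
leaves <- data.frame(
  node      = as.integer(rownames(model$frame)[is_leaf]),
  predicted = predicted.levels[model$frame$yval[is_leaf]],
  n         = model$frame$n[is_leaf],
  stringsAsFactors = FALSE
)
print(leaves)
```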

and the rule is

    > rule <- path.rpart(model, nodes = 7)

     node number: 7 
       root
       Petal.Length>=2.45
       Petal.Width>=1.75

So the rule can be read as

 If Petal.Length >= 2.45 AND Petal.Width >= 1.75 THEN Species = Virginica 

I know I can check how many true positives this rule produces on a test dataset (here I use the iris dataset again) by subsetting it as follows:

 > hits <- subset(iris, Petal.Length >= 2.45 & Petal.Width >= 1.75) 

and then computing the confusion matrix:

    > table(hits$Species, hits$Species == "virginica")

                 FALSE TRUE
      setosa         0    0
      versicolor     1    0
      virginica      0   45
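From the covered subset, the usual rule-quality figures follow directly. A minimal base-R sketch; the names coverage, precision, and recall are illustrative, and the predicted class "virginica" is hard-coded for node 7.

```r
# Same subset as in the question: rows satisfying the rule for node 7
hits <- subset(iris, Petal.Length >= 2.45 & Petal.Width >= 1.75)

# Coverage: fraction of all rows that satisfy the rule's conditions
coverage <- nrow(hits) / nrow(iris)

# Precision (confidence): fraction of covered rows whose class matches the rule
true_pos <- sum(hits$Species == "virginica")
precision <- true_pos / nrow(hits)

# Recall: fraction of all virginica rows that the rule covers
recall <- true_pos / sum(iris$Species == "virginica")

print(c(coverage = coverage, precision = precision, recall = recall))
```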

(Note: I used the same iris dataset as the test set.)

How can I evaluate the rule programmatically? I can extract its conditions as follows:

    > unlist(rule, use.names = FALSE)[-1]
    [1] "Petal.Length>=2.45" "Petal.Width>=1.75"

But how do I continue from here? I cannot pass these strings to the subset function.

Thanks in advance

NOTE: This question has been heavily edited for clarity.

3 answers

I was able to solve it as follows.

DISCLAIMER: There are surely better ways to solve this, but this hack works and does what I want. (I'm not very proud of it; it's hacky, but it works.)

OK, let's get started. The idea is to use the sqldf package.

As you can see in the question, the last piece of code puts each part of the tree path into a vector, so I will start from there:

    library(sqldf)
    library(stringr)

    # Transform the path into a character vector, dropping the "root" entry
    rule.v <- unlist(rule, use.names = FALSE)[-1]

    # Replace the dots in variable names with underscores (sqldf doesn't handle dots in names)
    rule.v <- str_replace_all(rule.v, pattern = "([a-zA-Z])\\.([a-zA-Z])", replacement = "\\1_\\2")

    # Turn the equal signs of factor splits (e.g. "x=a,b") into " in ('"
    rule.v <- str_replace_all(rule.v, pattern = "([a-zA-Z0-9])=", replacement = "\\1 in ('")

    # Wrap each element of the value lists in single quotes
    # The last element couldn't be modified this way (any ideas?)
    rule.v <- str_replace_all(rule.v, pattern = ",", replacement = "','")

    # Close the last element with an apostrophe and a ")"
    # (match " in ('" rather than just "in", which also occurs inside names such as "checking")
    for (i in which(str_detect(rule.v, fixed(" in ('")))) {
      rule.v[i] <- paste0(rule.v[i], "')")
    }

    # Collapse all the conditions into one string joined by " AND "
    rule.v <- paste(rule.v, collapse = " AND ")

    # Generate the query
    # Use any metric that you can get from the data frame
    query <- paste0("SELECT Species, count(Species) FROM iris WHERE ", rule.v, " GROUP BY Species")

    # For debugging only
    print(query)

    # Execute and print the results
    print(sqldf(query))

That's all!

I warned you, it's a hack...

Hope this helps someone else.

Thanks for the help and suggestions!


In general, I do not recommend using eval(parse(...)), but in this case it works:

Extract the rule:

    > rule <- unname(unlist(path.rpart(model, nodes = 7)))[-1]

     node number: 7 
       root
       Petal.Length>=2.45
       Petal.Width>=1.75
    > rule
    [1] "Petal.Length>=2.45" "Petal.Width>=1.75"

Retrieve data using the rule:

    > node_data <- with(iris, iris[eval(parse(text = paste(rule, collapse = " & "))), ])
    > head(node_data)
        Sepal.Length Sepal.Width Petal.Length Petal.Width    Species
    71           5.9         3.2          4.8         1.8 versicolor
    101          6.3         3.3          6.0         2.5  virginica
    102          5.8         2.7          5.1         1.9  virginica
    103          7.1         3.0          5.9         2.1  virginica
    104          6.3         2.9          5.6         1.8  virginica
    105          6.5         3.0          5.8         2.2  virginica
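If you would rather avoid eval(parse(...)) altogether, one possible sketch parses each condition with a regular expression and applies the comparison operator through match.fun. It assumes every split is a numeric comparison like the ones printed by path.rpart above; eval_conditions is a hypothetical helper, not part of rpart, and factor splits (printed as var=a,b) are not handled.

```r
# Conditions extracted from path.rpart (the "root" entry already dropped)
rule <- c("Petal.Length>=2.45", "Petal.Width>=1.75")

# Hypothetical helper (not part of rpart): evaluates numeric split conditions
# of the form "<variable><op><number>" row-wise, without eval(parse())
eval_conditions <- function(conds, data) {
  # Capture groups: variable name, comparison operator, threshold
  pattern <- "^([^<>=]+)(>=|<=|<|>)\\s*(.+)$"
  parts <- regmatches(conds, regexec(pattern, conds, perl = TRUE))
  mask <- rep(TRUE, nrow(data))
  for (p in parts) {
    if (length(p) != 4) stop("unsupported condition (factor splits are not handled)")
    var <- trimws(p[2])
    if (is.null(data[[var]])) stop("unknown variable: ", var)
    op <- match.fun(p[3])          # e.g. the primitive `>=`
    value <- as.numeric(p[4])
    mask <- mask & op(data[[var]], value)
  }
  mask
}

node_data <- iris[eval_conditions(rule, iris), ]
print(table(node_data$Species))
```

Because the operator is looked up by name, only the four comparison operators in the pattern can ever be executed, which is the main advantage over evaluating arbitrary text.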

Starting with a rule such as:

    Rule number: 16 [yval=bad cover=220 N=121 Y=99 (37%) prob=0.04]
      checking< 2.5
      afford< 54
      history< 3.5
      coapp< 2.5

Create a prob vector that starts as all zeros, and update it with rule 16:

    # dat is your data frame; start with all-zero probabilities
    prob <- rep(0, nrow(dat))
    prob <- ifelse(dat[["checking"]] < 2.5 & dat[["afford"]] < 54 &
                     dat[["history"]] < 3.5 & dat[["coapp"]] < 2.5,
                   0.04, prob)

Then repeat this for all the other rules (they should not change the probabilities already assigned, since the leaves of a tree define disjoint regions). There are probably much more efficient ways to make predictions than this, for example the predict.rpart function.
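The "apply every rule in turn" idea can be sketched for the iris tree from the question. This sketch reuses the eval(parse(...)) trick from the answer above, and the variable names are illustrative. Comparing the result with predict() (which dispatches to predict.rpart for rpart objects) is a sanity check that the disjoint leaf rules reproduce the model's own predictions.

```r
library(rpart)

model <- rpart(Species ~ ., data = iris)
predicted.levels <- attr(model, "ylevels")

# Leaves of the tree and the root-to-leaf path (rule) of each one
leaf_rows <- which(model$frame$var == "<leaf>")
leaf_nodes <- rownames(model$frame)[leaf_rows]
paths <- path.rpart(model, nodes = as.integer(leaf_nodes), print.it = FALSE)

# Start with no prediction and let each (disjoint) leaf rule fill in its rows
rule_pred <- rep(NA_character_, nrow(iris))
for (k in seq_along(leaf_nodes)) {
  rule <- paths[[leaf_nodes[k]]][-1]          # drop the "root" entry
  covered <- with(iris, eval(parse(text = paste(rule, collapse = " & "))))
  rule_pred[covered] <- predicted.levels[model$frame$yval[leaf_rows[k]]]
}

# predict() dispatches to predict.rpart; type = "class" returns predicted labels
pkg_pred <- unname(as.character(predict(model, newdata = iris, type = "class")))

print(table(actual = iris$Species, predicted = rule_pred))
print(all(rule_pred == pkg_pred))
```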


Source: https://habr.com/ru/post/922229/

