Random Forest Example

Implement the Random Forest algorithm with TensorFlow and apply it to classify handwritten digit images. This example uses the MNIST database of handwritten digits as training samples (http://yann.lecun.com/exdb/mnist/). Note that tensor_forest lives in tf.contrib, so this example requires TensorFlow 1.x.
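
For orientation, the same kind of classifier can be built outside TensorFlow. The snippet below is a minimal baseline sketch using scikit-learn's RandomForestClassifier on its small built-in 8x8 digits dataset (not full MNIST); it assumes scikit-learn is installed and is only meant to illustrate the model the TensorFlow graph below reproduces.

# Minimal scikit-learn baseline (assumes scikit-learn is installed).
# Uses the small 8x8 digits dataset bundled with scikit-learn, not full MNIST.
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.2, random_state=0)

clf = RandomForestClassifier(n_estimators=10, random_state=0)
clf.fit(X_train, y_train)
print("Baseline test accuracy:", clf.score(X_test, y_test))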

from __future__ import print_function

import tensorflow as tf
from tensorflow.python.ops import resources
from tensorflow.contrib.tensor_forest.python import tensor_forest

# Ignore all GPUs; TF random forest does not benefit from them.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=False)
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting /tmp/data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
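You can optionally sanity-check what was loaded; a quick sketch, assuming the download above succeeded. With one_hot=False the labels come back as a 1-D array of integer class ids, which is the format the forest expects.

# Optional sanity check of the loaded data (assumes the download above succeeded)
print(mnist.train.images.shape)   # (55000, 784) flattened 28x28 images
print(mnist.train.labels.shape)   # (55000,) integer class ids, since one_hot=False
print(mnist.train.labels[:5])     # first few digit labels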
# Parameters
num_steps = 500 # Total steps to train
batch_size = 1024 # The number of samples per batch
num_classes = 10 # The 10 digits
num_features = 784 # Each image is 28x28 pixels
num_trees = 10
max_nodes = 1000

# Input and Target data
X = tf.placeholder(tf.float32, shape=[None, num_features])
# For random forest, labels must be integers (the class id)
Y = tf.placeholder(tf.int32, shape=[None])

# Random Forest Parameters
hparams = tensor_forest.ForestHParams(num_classes=num_classes,
                                      num_features=num_features,
                                      num_trees=num_trees,
                                      max_nodes=max_nodes).fill()
# Build the Random Forest
forest_graph = tensor_forest.RandomForestGraphs(hparams)
# Get training graph and loss
train_op = forest_graph.training_graph(X, Y)
loss_op = forest_graph.training_loss(X, Y)

# Measure the accuracy
infer_op, _, _ = forest_graph.inference_graph(X)
correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Initialize the variables (i.e. assign their default value) and forest resources
init_vars = tf.group(tf.global_variables_initializer(),
    resources.initialize_resources(resources.shared_resources()))
INFO:tensorflow:Constructing forest with params = 
INFO:tensorflow:{'valid_leaf_threshold': 1, 'split_after_samples': 250, 'num_output_columns': 11, 'feature_bagging_fraction': 1.0, 'split_initializations_per_input': 3, 'bagged_features': None, 'min_split_samples': 5, 'max_nodes': 1000, 'num_features': 784, 'num_trees': 10, 'num_splits_to_consider': 784, 'base_random_seed': 0, 'num_outputs': 1, 'dominate_fraction': 0.99, 'max_fertile_nodes': 500, 'bagged_num_features': 784, 'dominate_method': 'bootstrap', 'bagging_fraction': 1.0, 'regression': False, 'num_classes': 10}
INFO:tensorflow:training graph for tree: 0
INFO:tensorflow:training graph for tree: 1
INFO:tensorflow:training graph for tree: 2
INFO:tensorflow:training graph for tree: 3
INFO:tensorflow:training graph for tree: 4
INFO:tensorflow:training graph for tree: 5
INFO:tensorflow:training graph for tree: 6
INFO:tensorflow:training graph for tree: 7
INFO:tensorflow:training graph for tree: 8
INFO:tensorflow:training graph for tree: 9
# Start TensorFlow session
sess = tf.train.MonitoredSession()

# Run the initializer
sess.run(init_vars)

# Training
for i in range(1, num_steps + 1):
    # Prepare Data
    # Get the next batch of MNIST data (images and their integer class labels)
    batch_x, batch_y = mnist.train.next_batch(batch_size)
    _, l = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})
    if i % 50 == 0 or i == 1:
        acc = sess.run(accuracy_op, feed_dict={X: batch_x, Y: batch_y})
        print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))

# Test Model
test_x, test_y = mnist.test.images, mnist.test.labels
print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y}))
Step 1, Loss: -0.000000, Acc: 0.112305
Step 50, Loss: -123.800003, Acc: 0.863281
Step 100, Loss: -274.200012, Acc: 0.863281
Step 150, Loss: -425.399994, Acc: 0.872070
Step 200, Loss: -582.799988, Acc: 0.917969
Step 250, Loss: -740.200012, Acc: 0.912109
Step 300, Loss: -895.799988, Acc: 0.939453
Step 350, Loss: -998.000000, Acc: 0.924805
Step 400, Loss: -998.000000, Acc: 0.940430
Step 450, Loss: -998.000000, Acc: 0.914062
Step 500, Loss: -998.000000, Acc: 0.927734
Test Accuracy: 0.9204
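
Once training is done, infer_op can also be run directly to get per-class probabilities for individual images; a minimal sketch, assuming the session above is still open:

# Predict a few individual test images (assumes the session above is still open)
import numpy as np
probs = sess.run(infer_op, feed_dict={X: test_x[:5]})  # per-class probabilities
print("Predicted:", np.argmax(probs, axis=1))
print("Actual:   ", test_y[:5])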
