11.3 MNIST Client Example
# From: https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/mnist_input_data.py
"""Functions for downloading and reading MNIST data."""
from __future__ import print_function
import gzip
import os
import numpy
from six.moves import urllib
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
def maybe_download(filename, work_directory):
    """Return the local path of an MNIST archive, downloading it if absent.

    The archive is fetched from ``SOURCE_URL`` only when no file of that name
    already exists inside ``work_directory``.

    Args:
        filename: Name of the archive to fetch, e.g. ``TRAIN_IMAGES``.
        work_directory: Directory that holds (or will hold) the archive. It is
            created, together with any missing parents, when it does not exist.

    Returns:
        Path of the archive inside ``work_directory``.
    """
    # The directory must exist before urlretrieve can write into it.
    if not os.path.exists(work_directory):
        os.makedirs(work_directory)
    # Computed unconditionally so a pre-existing directory still yields a path.
    filepath = os.path.join(work_directory, filename)
    if not os.path.exists(filepath):
        filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
        statinfo = os.stat(filepath)
        print('Successfully downloaded %s %d bytes.' % (filename, statinfo.st_size))
    return filepath
def _read32(bytestream):
dt = numpy.dtype(numpy.uint32).newbyteorder('>')
return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]
def extract_images(filename):
    """Load an MNIST image archive as a 4D uint8 array [index, y, x, depth].

    Args:
        filename: Path of the gzip-compressed MNIST image file.

    Returns:
        A uint8 numpy array of shape (num_images, rows, cols, 1).

    Raises:
        ValueError: If the file does not start with the magic number 2051.
    """
    print('Extracting %s' % filename)
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            message = 'Invalid magic number %d in MNIST image file: %s' % (
                magic, filename)
            raise ValueError(message)
        # The header continues with the image count, then rows, then columns.
        num_images, rows, cols = (_read32(bytestream) for _ in range(3))
        pixels = numpy.frombuffer(
            bytestream.read(rows * cols * num_images), dtype=numpy.uint8)
        return pixels.reshape(num_images, rows, cols, 1)
def dense_to_one_hot(labels_dense, num_classes=10):
    """Turn an array of scalar class labels into one-hot rows.

    Args:
        labels_dense: Numpy array of integer class labels.
        num_classes: Width of each one-hot row.

    Returns:
        A float array of shape (num_labels, num_classes) holding a 1 at each
        row's label position and 0 elsewhere.
    """
    count = labels_dense.shape[0]
    encoded = numpy.zeros((count, num_classes))
    # Row i starts at flat position i * num_classes; add the label to reach
    # the column that has to be switched on.
    row_starts = numpy.arange(count) * num_classes
    hot_positions = row_starts + labels_dense.ravel()
    encoded.flat[hot_positions] = 1
    return encoded
def extract_labels(filename, one_hot=False):
    """Load an MNIST label archive as a 1D uint8 array [index].

    Args:
        filename: Path of the gzip-compressed MNIST label file.
        one_hot: When true, return the labels converted by
            ``dense_to_one_hot`` instead of the raw uint8 array.

    Returns:
        The labels, either as a 1D uint8 array or as one-hot rows.

    Raises:
        ValueError: If the file does not start with the magic number 2049.
    """
    print('Extracting %s' % filename)
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            message = 'Invalid magic number %d in MNIST label file: %s' % (
                magic, filename)
            raise ValueError(message)
        num_items = _read32(bytestream)
        labels = numpy.frombuffer(bytestream.read(num_items), dtype=numpy.uint8)
        return dense_to_one_hot(labels) if one_hot else labels
class DataSet(object):
    """One split (train, validation or test) of the MNIST data.

    Real images are stored flattened to ``[num_examples, rows * columns]`` as
    float32 values scaled to [0.0, 1.0]. ``next_batch`` hands out consecutive
    slices and reshuffles the split whenever a request runs past its end.
    """

    def __init__(self, images, labels, fake_data=False, one_hot=False):
        """Construct a DataSet.

        Args:
            images: Array of shape [num_examples, rows, columns, 1]. Stored
                untouched when ``fake_data`` is true.
            labels: Array with one entry per image. Always stored untouched.
            fake_data: When true, skip validation and conversion and report
                10000 examples.
            one_hot: Only read by ``next_batch(fake_data=True)`` to choose the
                form of the fake labels.
        """
        # Stored unconditionally so next_batch(fake_data=True) never hits a
        # missing attribute on a data set built from real arrays.
        self.one_hot = one_hot
        if fake_data:
            self._num_examples = 10000
        else:
            assert images.shape[0] == labels.shape[0], (
                'images.shape: %s labels.shape: %s' % (images.shape,
                                                       labels.shape))
            self._num_examples = images.shape[0]
            # Convert shape from [num examples, rows, columns, depth]
            # to [num examples, rows*columns] (assuming depth == 1)
            assert images.shape[3] == 1
            images = images.reshape(images.shape[0],
                                    images.shape[1] * images.shape[2])
            # Convert from [0, 255] -> [0.0, 1.0].
            images = images.astype(numpy.float32)
            images = numpy.multiply(images, 1.0 / 255.0)
        self._images = images
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        """The stored images (flattened and scaled unless fake)."""
        return self._images

    @property
    def labels(self):
        """The stored labels, in the same order as ``images``."""
        return self._labels

    @property
    def num_examples(self):
        """Number of examples in this split."""
        return self._num_examples

    @property
    def epochs_completed(self):
        """How many times ``next_batch`` has wrapped around the split."""
        return self._epochs_completed

    def next_batch(self, batch_size, fake_data=False):
        """Return the next `batch_size` examples from this data set.

        Args:
            batch_size: Number of examples to return.
            fake_data: When true, return constant placeholder lists instead
                of slicing the stored arrays.

        Returns:
            An ``(images, labels)`` pair with ``batch_size`` entries each.
        """
        if fake_data:
            fake_image = [1] * 784
            if self.one_hot:
                fake_label = [1] + [0] * 9
            else:
                fake_label = 0
            return [fake_image for _ in range(batch_size)], [
                fake_label for _ in range(batch_size)
            ]
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
            # Finished epoch
            self._epochs_completed += 1
            # Shuffle the data; one permutation is applied to both arrays so
            # every image keeps its label.
            perm = numpy.arange(self._num_examples)
            numpy.random.shuffle(perm)
            self._images = self._images[perm]
            self._labels = self._labels[perm]
            # Start next epoch
            start = 0
            self._index_in_epoch = batch_size
            assert batch_size <= self._num_examples
        end = self._index_in_epoch
        return self._images[start:end], self._labels[start:end]
def read_data_sets(train_dir, fake_data=False, one_hot=False,
                   validation_size=5000):
    """Return training, validation and testing data sets.

    Args:
        train_dir: Directory the MNIST archives are read from, and downloaded
            into when missing.
        fake_data: When true, return three placeholder ``DataSet`` objects
            without touching the disk or the network.
        one_hot: Passed to ``extract_labels`` (real data) or to ``DataSet``
            (fake data).
        validation_size: Number of leading training examples split off as the
            validation set. NOTE(review): the default of 5000 follows the
            upstream TensorFlow example; no module-level VALIDATION_SIZE is
            defined in this file -- confirm the intended value.

    Returns:
        An object with ``train``, ``validation`` and ``test`` attributes, each
        a ``DataSet``.
    """
    class DataSets(object):
        """Plain attribute holder for the three splits."""

    data_sets = DataSets()
    if fake_data:
        data_sets.train = DataSet([], [], fake_data=True, one_hot=one_hot)
        data_sets.validation = DataSet([], [], fake_data=True, one_hot=one_hot)
        data_sets.test = DataSet([], [], fake_data=True, one_hot=one_hot)
        return data_sets

    local_file = maybe_download(TRAIN_IMAGES, train_dir)
    train_images = extract_images(local_file)
    local_file = maybe_download(TRAIN_LABELS, train_dir)
    train_labels = extract_labels(local_file, one_hot=one_hot)
    local_file = maybe_download(TEST_IMAGES, train_dir)
    test_images = extract_images(local_file)
    local_file = maybe_download(TEST_LABELS, train_dir)
    test_labels = extract_labels(local_file, one_hot=one_hot)

    # The head of the training archive becomes the validation split and is
    # removed from the training split.
    validation_images = train_images[:validation_size]
    validation_labels = train_labels[:validation_size]
    train_images = train_images[validation_size:]
    train_labels = train_labels[validation_size:]

    data_sets.train = DataSet(train_images, train_labels)
    data_sets.validation = DataSet(validation_images, validation_labels)
    data_sets.test = DataSet(test_images, test_labels)
    return data_sets
# From https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/mnist_client.py
from __future__ import print_function
import sys
import threading
from grpc.beta import implementations
import numpy
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2
class _ResultCounter(object):
"""Counter for the prediction results."""
def __init__(self, num_tests, concurrency):
self._num_tests = num_tests
self._concurrency = concurrency
self._error = 0
self._done = 0
self._active = 0
self._condition = threading.Condition()
def inc_error(self):
with self._condition:
self._error += 1
def inc_done(self):
with self._condition:
self._done += 1
def dec_active(self):
with self._condition:
self._active -= 1
def get_error_rate(self):
with self._condition:
while self._done != self._num_tests:
return self._error / float(self._num_tests)
def throttle(self):
with self._condition:
while self._active == self._concurrency:
self._active += 1
def _create_rpc_callback(label, result_counter):
"""Creates RPC callback function.
label: The correct label for the predicted example.
result_counter: Counter for the prediction result.
The callback function.
def _callback(result_future):
"""Callback function.
Calculates the statistics for the prediction result.
result_future: Result future of the RPC.
exception = result_future.exception()
if exception:
response = numpy.array(
prediction = numpy.argmax(response)
if label != prediction:
return _callback
def do_inference(hostport, work_dir, concurrency, num_tests):
    """Tests PredictionService with concurrent requests.

    Args:
        hostport: Host:port address of the PredictionService.
        work_dir: The full path of working directory for test data set.
        concurrency: Maximum number of concurrent requests.
        num_tests: Number of test images to use.

    Returns:
        The classification error rate.

    Raises:
        IOError: An error occurred processing test data set.
    """
    test_data_set = read_data_sets(work_dir).test  # mnist_input_data.read_data_sets(work_dir).test
    host, port = hostport.split(':')
    channel = implementations.insecure_channel(host, int(port))
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
    result_counter = _ResultCounter(num_tests, concurrency)
    for i in range(num_tests):
        request = predict_pb2.PredictRequest()
        request.model_spec.name = 'mnist'
        request.model_spec.signature_name = 'predict_images'
        image, label = test_data_set.next_batch(1)
        # NOTE(review): the 'images' input key follows the upstream TensorFlow
        # Serving MNIST example; confirm it matches the exported signature.
        request.inputs['images'].CopyFrom(
            tf.contrib.util.make_tensor_proto(image[0], shape=[1, image[0].size]))
        # Wait for a free slot so at most `concurrency` requests are in flight.
        result_counter.throttle()
        result_future = stub.Predict.future(request, 5.0)  # 5 seconds
        result_future.add_done_callback(
            _create_rpc_callback(label[0], result_counter))
        if i % 50 == 0:
            # NOTE(review): the statement guarded by this check is missing
            # from this copy; a line break (presumably to wrap the progress
            # dots) is assumed -- confirm.
            print()
    return result_counter.get_error_rate()
# Run the client against the served model and report the error rate.
# NOTE(review): the call below is truncated in this copy -- the closing
# parenthesis and the remaining do_inference arguments (work_dir, concurrency,
# num_tests) are missing. hostport='' also cannot be split into host and port
# by do_inference. Restore these before running.
error_rate = do_inference(hostport='',
print('\nInference error rate: %s%%' % (error_rate * 100))
Extracting /home/armando/datasets/mnist/train-images-idx3-ubyte.gz
Extracting /home/armando/datasets/mnist/train-labels-idx1-ubyte.gz
Extracting /home/armando/datasets/mnist/t10k-images-idx3-ubyte.gz
Extracting /home/armando/datasets/mnist/t10k-labels-idx1-ubyte.gz
Inference error rate: 7.0%