15.2 Distributed MNIST training with SLURM
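The listing below implements between-graph replicated training for MNIST: every process runs the same script, and the --job_name and --task_index flags decide whether it acts as a parameter server or a worker. Parameter servers hold the model variables; workers build the compute graph, pull training batches, and advance a shared global step.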
import sys

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Every process parses the same flags; only job_name and task_index differ
# between the launched tasks.
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('job_name', 'ps', 'Name of the job: ps or worker, default ps')
tf.flags.DEFINE_integer('task_index', 0, 'Index of the task within the job, default 0')
tf.flags.DEFINE_string('ps_hosts', 'localhost:9001',
                       'Comma-separated list of hostname:port pairs, default localhost:9001')
tf.flags.DEFINE_string('worker_hosts', 'localhost:9002',
                       'Comma-separated list of hostname:port pairs, default localhost:9002')
def main(_):
    mnist = input_data.read_data_sets('/home/armando/datasets/mnist', one_hot=True)

    # Build the cluster specification from the comma-separated host lists.
    ps = FLAGS.ps_hosts.split(',')
    workers = FLAGS.worker_hosts.split(',')
    clusterSpec = tf.train.ClusterSpec({'ps': ps, 'worker': workers})

    config = tf.ConfigProto()
    config.allow_soft_placement = True
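    # The same script now branches: ps tasks serve variables and block in
    # server.join(); worker tasks build the model graph and run training.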
    if FLAGS.job_name == 'ps':
        # Parameter servers only hold variables, so keep them off the GPU.
        config.device_count['GPU'] = 0
        server = tf.train.Server(clusterSpec,
                                 job_name=FLAGS.job_name,
                                 task_index=FLAGS.task_index,
                                 config=config)
        server.join()  # blocks forever, serving variables to the workers
        sys.exit(0)
    elif FLAGS.job_name == 'worker':
        # Cap per-worker GPU memory so several workers can share one GPU.
        config.gpu_options.per_process_gpu_memory_fraction = 0.2
        server = tf.train.Server(clusterSpec,
                                 job_name=FLAGS.job_name,
                                 task_index=FLAGS.task_index,
                                 config=config)
        is_chief = (FLAGS.task_index == 0)

        # replica_device_setter pins variables to the ps tasks and places
        # everything else on this worker's device.
        worker_device = '/job:worker/task:{}'.format(FLAGS.task_index)
        device_func = tf.train.replica_device_setter(worker_device=worker_device,
                                                     cluster=clusterSpec)
        with tf.device(device_func):
            global_step = tf.train.get_or_create_global_step()

            x_test = mnist.test.images
            y_test = mnist.test.labels

            n_outputs = 10
            n_inputs = 784
            learning_rate = 0.01
            n_epochs = 50
            batch_size = 100
            n_batches = int(mnist.train.num_examples / batch_size)
            n_epochs_print = 10
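            # The model itself is plain softmax regression: one dense layer
            # feeding a softmax cross-entropy loss.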
            x_p = tf.placeholder(dtype=tf.float32,
                                 name='x_p',
                                 shape=[None, n_inputs])
            y_p = tf.placeholder(dtype=tf.float32,
                                 name='y_p',
                                 shape=[None, n_outputs])

            # name= belongs on tf.Variable, not on tf.random_normal,
            # otherwise the variables end up unnamed in the graph.
            w = tf.Variable(tf.random_normal([n_inputs, n_outputs]), name='w')
            b = tf.Variable(tf.random_normal([n_outputs]), name='b')
            logits = tf.matmul(x_p, w) + b
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_p,
                                                                    logits=logits)
            loss_op = tf.reduce_mean(cross_entropy)
            optimizer = tf.train.GradientDescentOptimizer(learning_rate)
            train_op = optimizer.minimize(loss_op, global_step=global_step)

            correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y_p, 1))
            accuracy_op = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

        # Only the chief (task 0) runs init_op; the other workers wait for it.
        sv = tf.train.Supervisor(is_chief=is_chief,
                                 init_op=tf.global_variables_initializer(),
                                 global_step=global_step)
        with sv.prepare_or_wait_for_session(server.target) as mts:
            lstep = 0
            for epoch in range(n_epochs):
                for batch in range(n_batches):
                    x_batch, y_batch = mnist.train.next_batch(batch_size)
                    feed_dict = {x_p: x_batch, y_p: y_batch}
                    _, loss, gstep = mts.run([train_op, loss_op, global_step],
                                             feed_dict=feed_dict)
                    lstep += 1
                if (epoch + 1) % n_epochs_print == 0:
                    print('worker={}, epoch={}, global_step={}, local_step={}, loss={}'
                          .format(FLAGS.task_index, epoch, gstep, lstep, loss))
            feed_dict = {x_p: x_test, y_p: y_test}
            accuracy = mts.run(accuracy_op, feed_dict=feed_dict)
            print('worker={}, final accuracy = {}'.format(FLAGS.task_index, accuracy))
        sv.stop()
if __name__ == '__main__':
tf.app.run()
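The host lists are ordinary command-line flags, so on a SLURM cluster they can be generated from the job's allocation before the tasks are launched. The following is a minimal sketch under assumptions not in the original listing: one TensorFlow task per allocated node, SLURM's scontrol available on the PATH, and the helper name slurm_cluster with parameters n_ps and port chosen here purely for illustration.

import os
import subprocess

def slurm_cluster(n_ps=1, port=9001):
    # SLURM_NODELIST holds a compact list such as node[01-03];
    # 'scontrol show hostnames' expands it to one hostname per line.
    nodelist = os.environ['SLURM_NODELIST']
    hosts = subprocess.check_output(
        ['scontrol', 'show', 'hostnames', nodelist]).decode().split()
    pairs = ['{}:{}'.format(h, port) for h in hosts]
    # First n_ps nodes become parameter servers, the rest workers.
    return ','.join(pairs[:n_ps]), ','.join(pairs[n_ps:])

Each launched task would then pass these two strings as --ps_hosts and --worker_hosts, and could derive its own --job_name and --task_index from the per-task SLURM_PROCID variable, for example treating task 0 as the ps and the remaining tasks as workers.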