tensorflowonspark.TFSparkNode module

This module provides low-level functions for managing the TensorFlowOnSpark cluster.

class TFNodeContext(executor_id=0, job_name='', task_index=0, cluster_spec={}, defaultFS='file://', working_dir='.', mgr=None, tmp_socket=None)[source]

Bases: object

Encapsulates unique metadata for a TensorFlowOnSpark node/executor and provides methods to interact with Spark and HDFS.

An instance of this object will be passed to the TensorFlow “main” function via the ctx argument. To simplify the end-user API, this class now mirrors the functions of the TFNode module.

Args:

executor_id: integer identifier for this executor, per nodeRDD = sc.parallelize(range(num_executors), num_executors).

job_name: TensorFlow job name (e.g. ‘ps’ or ‘worker’) of this TF node, per cluster_spec.

task_index: integer rank within job_name, e.g. “worker:0”, “worker:1”, “ps:0”.

cluster_spec: dictionary for constructing a tf.train.ClusterSpec.

defaultFS: string representation of the default FileSystem, e.g. file:// or hdfs://<namenode>:8020/.

working_dir: the current working directory for local filesystems, or for YARN containers.

mgr: TFManager instance for this Python worker.

tmp_socket: temporary socket used to select a random port for the TF gRPC server.

absolute_path(path)[source]

Convenience function to access TFNode.hdfs_path directly from this object instance.

export_saved_model(sess, export_dir, tag_set, signatures)[source]

Convenience function to access TFNode.export_saved_model directly from this object instance.
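A hedged sketch of the signatures argument, assuming the nested ‘inputs’/‘outputs’/‘method_name’ layout used by TFNode.export_saved_model; the tensor variables (x, y), aliases, and export directory below are placeholders:

    # x and y stand for input/output tensors in the user's graph (placeholders here).
    signatures = {
        'predict': {
            'inputs':  {'image': x},        # input tensor alias -> tensor
            'outputs': {'prediction': y},   # output tensor alias -> tensor
            'method_name': 'predict'
        }
    }
    ctx.export_saved_model(sess, 'my_export_dir', 'serve', signatures)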

get_data_feed(train_mode=True, qname_in='input', qname_out='output', input_mapping=None)[source]

Convenience function to access TFNode.DataFeed directly from this object instance.
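A minimal feeding-loop sketch, assuming the returned TFNode.DataFeed exposes should_stop(), next_batch(), and terminate():

    # Inside the user's TensorFlow “main” function:
    tf_feed = ctx.get_data_feed(train_mode=True)
    while not tf_feed.should_stop():
        batch = tf_feed.next_batch(100)   # pull up to 100 rows from the Spark-fed queue
        if len(batch) == 0:
            break
        # ... run one training step on `batch` ...
    tf_feed.terminate()                   # optionally signal that feeding can stop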

release_port()[source]

Convenience function to access TFNode.release_assigned_port directly from this object instance.

start_cluster_server(num_gpus=1, rdma=False)[source]

Convenience function to access TFNode.start_cluster_server directly from this object instance.
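Putting these together, a skeleton of a user-provided “main” function that receives this context, assuming start_cluster_server() returns a (cluster, server) pair as in the TFNode module:

    def main_fun(args, ctx):
        # ctx is a TFNodeContext; its attributes mirror the constructor args above.
        print("{0}:{1} on executor {2}".format(ctx.job_name, ctx.task_index, ctx.executor_id))

        cluster, server = ctx.start_cluster_server(num_gpus=1, rdma=False)
        if ctx.job_name == "ps":
            server.join()                              # parameter servers block here
        else:
            model_dir = ctx.absolute_path("my_model")  # resolved against defaultFS/working_dir
            # ... build the graph, train or infer, optionally export a SavedModel ...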

class TFSparkNode[source]

Bases: object

Low-level functions used by the high-level TFCluster APIs to manage cluster state.

This class is not intended for end-users (see TFNode for end-user APIs).

For cluster management, this wraps the per-node cluster logic as Spark RDD mapPartitions functions, where the RDD is expected to be a “nodeRDD” of the form: nodeRDD = sc.parallelize(range(num_executors), num_executors).

For data feeding, this wraps the feeding logic as Spark RDD mapPartitions functions on a standard “dataRDD”.

This class also manages a reference to the TFManager “singleton” per executor. Since Spark can spawn more than one Python worker per executor, it will reconnect to the “singleton” instance as needed.
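For reference, the “nodeRDD” shape described above (the executor count is arbitrary here):

    # One Spark partition per executor, so each mapPartitions function below
    # runs exactly once per executor:
    num_executors = 4
    nodeRDD = sc.parallelize(range(num_executors), num_executors)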

cluster_id = None

Unique ID for a given TensorFlowOnSpark cluster, used for invalidating state for new clusters.

mgr = None

TFManager instance

inference(cluster_info, feed_timeout=600, qname='input')[source]

Feeds Spark partitions into the shared multiprocessing.Queue and returns inference results.

Args:

cluster_info: node reservation information for the cluster (e.g. host, executor_id, pid, ports, etc.).

feed_timeout: number of seconds after which data feeding times out (600 sec default).

qname: INTERNAL_USE.

Returns:

A dataRDD.mapPartitions() function.
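A hedged usage sketch; applying the result via dataRDD.mapPartitions() follows from the return type, and mirrors what the higher-level TFCluster API does internally:

    predict_fn = TFSparkNode.inference(cluster_info, feed_timeout=600, qname='input')
    predictions = dataRDD.mapPartitions(predict_fn).collect()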

run(fn, tf_args, cluster_meta, tensorboard, log_dir, queues, background)[source]

Wraps the user-provided TensorFlow main function in a Spark mapPartitions function.

Args:

fn: TensorFlow “main” function provided by the user.

tf_args: argparse args or command-line ARGV; these will be passed to fn.

cluster_meta: dictionary of cluster metadata (e.g. cluster_id, reservation.Server address, etc.).

tensorboard: boolean indicating if the chief worker should spawn a TensorBoard server.

log_dir: directory for saving TensorBoard event logs. If None, defaults to a fixed path on the local filesystem.

queues: INTERNAL_USE.

background: boolean indicating if the TensorFlow “main” function should be run in a background process.

Returns:

A nodeRDD.mapPartitions() function.
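A hedged wiring sketch, mirroring how the higher-level TFCluster API might apply the returned function (foreachPartition forces the otherwise-lazy function to execute on every executor):

    nodeRDD.foreachPartition(TFSparkNode.run(main_fun, tf_args, cluster_meta,
                                             tensorboard=False, log_dir=None,
                                             queues=['input'], background=True))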

shutdown(cluster_info, grace_secs=0, queues=['input'])[source]

Stops all TensorFlow nodes by feeding None into the multiprocessing.Queues.

Args:

cluster_info: node reservation information for the cluster (e.g. host, executor_id, pid, ports, etc.).

queues: INTERNAL_USE.

Returns:

A nodeRDD.mapPartitions() function.
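A hedged wiring sketch for stopping the cluster, applied to the same “nodeRDD” used to start it:

    nodeRDD.foreachPartition(TFSparkNode.shutdown(cluster_info, grace_secs=0, queues=['input']))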

train(cluster_info, cluster_meta, feed_timeout=600, qname='input')[source]

Feeds Spark partitions into the shared multiprocessing.Queue.

Args:

cluster_info: node reservation information for the cluster (e.g. host, executor_id, pid, ports, etc.).

cluster_meta: dictionary of cluster metadata (e.g. cluster_id, reservation.Server address, etc.).

feed_timeout: number of seconds after which data feeding times out (600 sec default).

qname: INTERNAL_USE.

Returns:

A dataRDD.mapPartitions() function.
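A hedged wiring sketch for feeding training data; each Spark partition is pushed into the queue of the TF node co-located with it:

    train_fn = TFSparkNode.train(cluster_info, cluster_meta, feed_timeout=600, qname='input')
    dataRDD.foreachPartition(train_fn)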