tensorflowonspark.TFSparkNode module
This module provides low-level functions for managing the TensorFlowOnSpark cluster.
- class TFNodeContext(executor_id=0, job_name='', task_index=0, cluster_spec={}, defaultFS='file://', working_dir='.', mgr=None, tmp_socket=None)[source]
Bases:
object
Encapsulates unique metadata for a TensorFlowOnSpark node/executor and provides methods to interact with Spark and HDFS.
An instance of this object will be passed to the TensorFlow “main” function via the ctx argument (see the usage sketch below the method list). To simplify the end-user API, this class now mirrors the functions of the TFNode module.
- Args:
- executor_id
integer identifier for this executor, per nodeRDD = sc.parallelize(range(num_executors), num_executors).
- job_name
TensorFlow job name (e.g. ‘ps’ or ‘worker’) of this TF node, per cluster_spec.
- task_index
integer rank per job_name, e.g. “worker:0”, “worker:1”, “ps:0”.
- cluster_spec
dictionary for constructing a tf.train.ClusterSpec.
- defaultFS
string representation of default FileSystem, e.g. file:// or hdfs://<namenode>:8020/.
- working_dir
the current working directory for local filesystems, or YARN containers.
- mgr
TFManager instance for this Python worker.
- tmp_socket
temporary socket used to select random port for TF GRPC server.
- absolute_path(path)[source]
Convenience function to access TFNode.hdfs_path directly from this object instance.
- export_saved_model(sess, export_dir, tag_set, signatures)[source]
Convenience function to access TFNode.export_saved_model directly from this object instance.
- get_data_feed(train_mode=True, qname_in='input', qname_out='output', input_mapping=None)[source]
Convenience function to access TFNode.DataFeed directly from this object instance.
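For illustration, a minimal sketch of a user-provided “main” function that receives a TFNodeContext instance as ctx (the name main_fun, the batch size, and args.model_dir are illustrative assumptions, not part of this module):

def main_fun(args, ctx):
    # Identify this node's role within the cluster_spec.
    print("executor {0}: {1}:{2}".format(ctx.executor_id, ctx.job_name, ctx.task_index))

    if ctx.job_name == "ps":
        # A real parameter server would start a tf.train.Server from
        # ctx.cluster_spec and block on server.join(); elided here.
        return

    # Workers pull Spark-fed partitions through the shared multiprocessing.Queue.
    data_feed = ctx.get_data_feed(train_mode=True)
    while not data_feed.should_stop():
        batch = data_feed.next_batch(64)
        if len(batch) == 0:
            break
        # ... run a training step on `batch` here ...

    # Resolve a relative path against defaultFS (args.model_dir is an assumed user argument).
    model_dir = ctx.absolute_path(args.model_dir)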
- class TFSparkNode[source]
Bases:
object
Low-level functions used by the high-level TFCluster APIs to manage cluster state.
This class is not intended for end-users (see TFNode for end-user APIs).
For cluster management, this wraps the per-node cluster logic as Spark RDD mapPartitions functions, where the RDD is expected to be a “nodeRDD” of the form:
nodeRDD = sc.parallelize(range(num_executors), num_executors).
For data feeding, this wraps the feeding logic as Spark RDD mapPartitions functions on a standard “dataRDD”.
This also manages a reference to the TFManager “singleton” per executor. Since Spark can spawn more than one Python worker per executor, this will reconnect to the “singleton” instance as needed.
- cluster_id = None
Unique ID for a given TensorFlowOnSpark cluster, used for invalidating state for new clusters.
- mgr = None
TFManager instance
- inference(cluster_info, feed_timeout=600, qname='input')[source]
Feeds Spark partitions into the shared multiprocessing.Queue and returns inference results.
- Args:
- cluster_info
node reservation information for the cluster (e.g. host, executor_id, pid, ports, etc)
- feed_timeout
number of seconds after which data feeding times out (600 sec default)
- qname
INTERNAL_USE
- Returns:
A dataRDD.mapPartitions() function
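A minimal sketch of the assumed internal usage by the high-level TFCluster.inference() API; cluster_info and dataRDD are assumed to already exist:

from tensorflowonspark import TFSparkNode

# Build the mapPartitions function, then apply it to the data RDD.
infer_fn = TFSparkNode.inference(cluster_info, feed_timeout=600, qname='input')
predictions = dataRDD.mapPartitions(infer_fn)  # RDD of inference results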
- run(fn, tf_args, cluster_meta, tensorboard, log_dir, queues, background)[source]
Wraps the user-provided TensorFlow main function in a Spark mapPartitions function.
- Args:
- fn
TensorFlow “main” function provided by the user.
- tf_args
argparse args, or command-line ARGV. These will be passed to the fn.
- cluster_meta
dictionary of cluster metadata (e.g. cluster_id, reservation.Server address, etc).
- tensorboard
boolean indicating if the chief worker should spawn a TensorBoard server.
- log_dir
directory to save TensorBoard event logs. If None, defaults to a fixed path on the local filesystem.
- queues
INTERNAL_USE
- background
boolean indicating if the TensorFlow “main” function should be run in a background process.
- Returns:
A nodeRDD.mapPartitions() function.
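A minimal sketch of the assumed internal usage: the high-level TFCluster.run() API builds a nodeRDD with one partition per executor and applies the wrapped function to launch one TF node per executor (cluster_meta, queues, and main_fun are assumed to have been prepared by the caller):

from tensorflowonspark import TFSparkNode

nodeRDD = sc.parallelize(range(num_executors), num_executors)
nodeRDD.mapPartitions(TFSparkNode.run(main_fun, tf_args, cluster_meta,
                                      tensorboard=False, log_dir=None,
                                      queues=['input', 'output'],
                                      background=True)).collect()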
- shutdown(cluster_info, grace_secs=0, queues=['input'])[source]
Stops all TensorFlow nodes by feeding None into the multiprocessing.Queues.
- Args:
- cluster_info
node reservation information for the cluster (e.g. host, executor_id, pid, ports, etc).
- queues
INTERNAL_USE
- Returns:
A nodeRDD.mapPartitions() function
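A minimal sketch of the assumed internal usage during TFCluster.shutdown(); cluster_info and nodeRDD are assumed to already exist:

from tensorflowonspark import TFSparkNode

# Feed None into each executor's queues to stop its TF node.
nodeRDD.mapPartitions(TFSparkNode.shutdown(cluster_info, grace_secs=0)).collect()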
- train(cluster_info, cluster_meta, feed_timeout=600, qname='input')[source]
Feeds Spark partitions into the shared multiprocessing.Queue.
- Args:
- cluster_info
node reservation information for the cluster (e.g. host, executor_id, pid, ports, etc)
- cluster_meta
dictionary of cluster metadata (e.g. cluster_id, reservation.Server address, etc)
- feed_timeout
number of seconds after which data feeding times out (600 sec default)
- qname
INTERNAL_USE
- Returns:
A dataRDD.mapPartitions() function
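A minimal sketch of the assumed internal usage by the high-level TFCluster.train() API; cluster_info, cluster_meta, and dataRDD are assumed to already exist:

from tensorflowonspark import TFSparkNode

train_fn = TFSparkNode.train(cluster_info, cluster_meta, feed_timeout=600)
dataRDD.mapPartitions(train_fn).count()  # the action triggers feeding on each executor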