tensorflowonspark.TFNode module

This module provides helper functions for the TensorFlow application.

Primarily, these functions help with:

  • starting the TensorFlow tf.train.Server for the node (allocating GPUs as desired, and determining the node’s role in the cluster).

  • managing input/output data for InputMode.SPARK.

class DataFeed(mgr, train_mode=True, qname_in='input', qname_out='output', input_mapping=None)[source]

Bases: object

This class manages the InputMode.SPARK data feeding process from the perspective of the TensorFlow application.

Args:
mgr

TFManager instance associated with this Python worker process.

train_mode

boolean indicating whether the data feed is used for training (True) or inference (False). In inference mode, the application is expected to push an output response for each input batch via batch_results().

qname_in

INTERNAL_USE

qname_out

INTERNAL_USE

input_mapping

For Spark ML Pipelines only. Dictionary of input DataFrame columns to input TensorFlow tensors.
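For illustration, an input_mapping for a hypothetical DataFrame with 'image' and 'label' columns feeding tensors named 'x:0' and 'y_:0' might look like the following (all column and tensor names here are illustrative, not part of the API):

```python
# Hypothetical mapping of Spark DataFrame columns to TensorFlow tensor names.
# The column and tensor names below are illustrative only.
input_mapping = {
    'image': 'x:0',   # DataFrame column 'image' feeds the tensor named 'x:0'
    'label': 'y_:0',  # DataFrame column 'label' feeds the tensor named 'y_:0'
}
```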

batch_results(results)[source]

Push a batch of output results to the Spark output RDD of TFCluster.inference().

Note: this currently expects a one-to-one mapping of input to output data, so the length of the results array should match the length of the previously retrieved batch of input data.

Args:
results

array of output data for the equivalent batch of input data.

next_batch(batch_size)[source]

Gets a batch of items from the input RDD.

If multiple tensors are provided per row in the input RDD, e.g. tuple of (tensor1, tensor2, …, tensorN) and:

  • no input_mapping was provided to the DataFeed constructor, this will return an array of batch_size tuples, and the caller is responsible for separating the tensors.

  • an input_mapping was provided to the DataFeed constructor, this will return a dictionary of N tensors, with tensor names as keys and arrays of length batch_size as values.

Note: if the end of the data is reached, this may return with fewer than batch_size items.

Args:
batch_size

number of items to retrieve.

Returns:

A batch of items or a dictionary of tensors.
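To make the two return shapes concrete, here is a toy stand-in in plain Python (not the real DataFeed implementation) that mimics the documented behavior for rows of (feature, label) tuples:

```python
# Toy illustration of the two return shapes described above. This is NOT
# the real DataFeed, just plain Python mimicking the documented behavior.
rows = [(1.0, 0), (2.0, 1), (3.0, 0), (4.0, 1)]

# Without input_mapping: a batch is a list of batch_size tuples, and the
# caller is responsible for separating the tensors.
batch = rows[:2]                     # e.g. next_batch(2)
features = [r[0] for r in batch]
labels = [r[1] for r in batch]

# With an input_mapping such as {'col1': 'x:0', 'col2': 'y_:0'} (hypothetical
# names): a batch is a dictionary keyed by tensor name, with arrays of
# length batch_size as values.
batch_dict = {'x:0': features, 'y_:0': labels}
```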

should_stop()[source]

Check if the feed process was told to stop (by a call to terminate).

terminate()[source]

Terminate data feeding early.

Since TensorFlow applications can often terminate on conditions unrelated to the training data (e.g. steps, accuracy, etc.), this method signals the data feeding process to ignore any further incoming data. Note that Spark itself does not have a mechanism to terminate an RDD operation early, so the extra partitions will still be sent to the executors (but will be ignored). Because of this, you should size your input data accordingly to avoid excessive overhead.
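The early-termination pattern can be sketched with a toy stand-in in plain Python (the real DataFeed is backed by a TFManager queue, not a list; this only illustrates the control flow):

```python
# Toy stand-in for the early-termination pattern described above.
# NOT the real DataFeed implementation.
class ToyFeed:
    def __init__(self, data):
        self.data = list(data)
        self.done = False

    def next_batch(self, batch_size):
        batch, self.data = self.data[:batch_size], self.data[batch_size:]
        return batch

    def should_stop(self):
        return self.done

    def terminate(self):
        # Signal the feeder to ignore any further incoming data.
        self.done = True

feed = ToyFeed(range(100))
max_steps = 3   # the app stops on its own condition, not on end-of-data
step = 0
while not feed.should_stop() and step < max_steps:
    batch = feed.next_batch(10)
    step += 1
feed.terminate()  # remaining partitions are still sent, but ignored
```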

batch_results(mgr, results, qname='output')[source]

DEPRECATED. Use TFNode.DataFeed class instead.

export_saved_model(sess, export_dir, tag_set, signatures)[source]

Convenience function to export a saved_model using the provided arguments.

The caller specifies the saved_model signatures in a simplified python dictionary form, as follows:

signatures = {
  'signature_def_key': {
    'inputs': { 'input_tensor_alias': input_tensor_name },
    'outputs': { 'output_tensor_alias': output_tensor_name },
    'method_name': 'method'
  }
}

This function will then generate the signature_def_map and export the saved_model.

DEPRECATED for TensorFlow 2.x+.

Args:
sess

a tf.Session instance

export_dir

path to save exported saved_model

tag_set

string tag_set to identify the exported graph

signatures

simplified dictionary representation of a TensorFlow signature_def_map

Returns:

A saved_model exported to disk at export_dir.
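As an example, the simplified signatures dictionary for a hypothetical model whose input tensor is named 'x:0' and output tensor 'prediction:0' might be built as follows (the tensor names, aliases, and method_name value are illustrative, not prescribed by this API):

```python
# Hypothetical signatures dictionary for export_saved_model; the tensor
# names, aliases, and method_name below are illustrative only.
signatures = {
    'serving_default': {
        'inputs': {'features': 'x:0'},
        'outputs': {'prediction': 'prediction:0'},
        'method_name': 'tensorflow/serving/predict',
    }
}
# In a TF 1.x application this would then be passed as, e.g.:
# TFNode.export_saved_model(sess, export_dir, 'serve', signatures)
```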

hdfs_path(ctx, path)[source]

Convenience function to create a TensorFlow-compatible absolute HDFS path from a relative path.

Args:
ctx

TFNodeContext containing the metadata specific to this node in the cluster.

path

path to convert

Returns:

An absolute path prefixed with the correct filesystem scheme.
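A rough, self-contained sketch of the kind of normalization hdfs_path performs (the real implementation reads the default filesystem and working directory from the node's context and job configuration; the default_fs and working_dir parameters below are stand-ins, not the actual API):

```python
# Rough sketch of hdfs_path-style normalization; NOT the real implementation.
# default_fs and working_dir stand in for metadata that the real function
# derives from the TFNodeContext and cluster configuration.
def to_absolute_hdfs_path(path, default_fs='hdfs://namenode:8020',
                          working_dir='/user/alice'):
    if '://' in path:          # already has a filesystem scheme: leave as-is
        return path
    if path.startswith('/'):   # absolute path: just prefix the scheme
        return default_fs + path
    # relative path: resolve against the working directory
    return '{}{}/{}'.format(default_fs, working_dir, path)
```

For example, a relative path like 'mnist/model' would resolve to 'hdfs://namenode:8020/user/alice/mnist/model' under these assumed defaults.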

next_batch(mgr, batch_size, qname='input')[source]

DEPRECATED. Use TFNode.DataFeed class instead.

release_port(ctx)[source]

Closes the temporary socket created to assign a port to the TF node.

start_cluster_server(ctx, num_gpus=1, rdma=False)[source]

Function that wraps the creation of TensorFlow tf.train.Server for a node in a distributed TensorFlow cluster.

This is intended to be invoked from within the TF map_fun, replacing explicit code to instantiate tf.train.ClusterSpec and tf.train.Server objects.

DEPRECATED for TensorFlow 2.x+.

Args:
ctx

TFNodeContext containing the metadata specific to this node in the cluster.

num_gpus

number of GPUs desired

rdma

boolean indicating if RDMA 'verbs' should be used for cluster communications.

Returns:

A tuple of (cluster_spec, server)

terminate(mgr, qname='input')[source]

DEPRECATED. Use TFNode.DataFeed class instead.