tensorflowonspark.pipeline module
This module extends the TensorFlowOnSpark API to support Spark ML Pipelines.
It provides a TFEstimator class to fit a TFModel using TensorFlow. The TFEstimator will actually spawn a TensorFlowOnSpark cluster to conduct distributed training, but due to architectural limitations, the TFModel will only run single-node TensorFlow instances when performing inference on the executors. Because the executors run in parallel, the TensorFlow model must fit in the memory of each executor (see the usage sketches under TFEstimator and TFModel below).
- class Namespace(d)[source]
Bases: object
Utility class to convert dictionaries to Namespace-like objects.
Based on https://docs.python.org/dev/library/types.html#types.SimpleNamespace
- argv = None
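Example (a minimal sketch, assuming attribute-style access in the manner of types.SimpleNamespace):

    from tensorflowonspark.pipeline import Namespace

    # Wrap a plain dict of arguments for attribute-style access
    args = Namespace({"batch_size": 100, "epochs": 3})
    print(args.batch_size)   # 100
    print(args.epochs)       # 3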
- class TFEstimator(*args: Any, **kwargs: Any)[source]
Bases: pyspark.ml.pipeline.Estimator, tensorflowonspark.pipeline.TFParams, tensorflowonspark.pipeline.HasInputMapping, tensorflowonspark.pipeline.HasClusterSize, tensorflowonspark.pipeline.HasNumPS, tensorflowonspark.pipeline.HasInputMode, tensorflowonspark.pipeline.HasMasterNode, tensorflowonspark.pipeline.HasProtocol, tensorflowonspark.pipeline.HasGraceSecs, tensorflowonspark.pipeline.HasTensorboard, tensorflowonspark.pipeline.HasModelDir, tensorflowonspark.pipeline.HasExportDir, tensorflowonspark.pipeline.HasTFRecordDir, tensorflowonspark.pipeline.HasBatchSize, tensorflowonspark.pipeline.HasEpochs, tensorflowonspark.pipeline.HasReaders, tensorflowonspark.pipeline.HasSteps
Spark ML Estimator which launches a TensorFlowOnSpark cluster for distributed training.
The columns of the DataFrame passed to the fit() method will be mapped to TensorFlow tensors according to the setInputMapping() method. Since the Spark ML Estimator API inherently relies on DataFrames/Datasets, InputMode.TENSORFLOW is not supported.
- Args:
- train_fn
TensorFlow “main” function for training.
- tf_args
Arguments specific to the TensorFlow “main” function.
- export_fn
TensorFlow function for exporting a saved_model. DEPRECATED for TF2.x.
- export_fn = None
- train_fn = None
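Example (a hedged sketch of typical usage: main_fun, args, and df are placeholders, and the setters follow the Spark ML set<Param> convention for the params listed above):

    from tensorflowonspark.pipeline import TFEstimator

    # train_fn (here: main_fun) is the TensorFlow "main" function run on each
    # cluster node; tf_args (here: args) are forwarded to it.
    estimator = TFEstimator(main_fun, args) \
        .setInputMapping({"image": "image", "label": "label"}) \
        .setClusterSize(4) \
        .setEpochs(3) \
        .setBatchSize(100) \
        .setModelDir("hdfs:///user/me/mnist_model") \
        .setExportDir("hdfs:///user/me/mnist_export")

    # fit() launches a TensorFlowOnSpark cluster, trains, and returns a TFModel
    model = estimator.fit(df)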
- class TFModel(*args: Any, **kwargs: Any)[source]
Bases: pyspark.ml.pipeline.Model, tensorflowonspark.pipeline.TFParams, tensorflowonspark.pipeline.HasInputMapping, tensorflowonspark.pipeline.HasOutputMapping, tensorflowonspark.pipeline.HasBatchSize, tensorflowonspark.pipeline.HasModelDir, tensorflowonspark.pipeline.HasExportDir, tensorflowonspark.pipeline.HasSignatureDefKey, tensorflowonspark.pipeline.HasTagSet
Spark ML Model backed by a TensorFlow model checkpoint/saved_model on disk.
During transform(), each executor will run an independent, single-node instance of TensorFlow in parallel, so the model must fit in memory. The model/session will be loaded/initialized just once per Spark Python worker, and the session will be cached for subsequent tasks/partitions to avoid re-loading the model for each partition.
- Args:
- tf_args
Dictionary of arguments specific to TensorFlow “main” function.
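Example (a corresponding inference sketch with placeholder names; "serving_default" and "serve" are TensorFlow's conventional defaults, not values mandated by this API):

    from tensorflowonspark.pipeline import TFModel

    # Map DataFrame columns to input tensors, and output tensors back to columns
    model = TFModel(args) \
        .setInputMapping({"image": "conv2d_input"}) \
        .setOutputMapping({"dense_1": "prediction"}) \
        .setExportDir("hdfs:///user/me/mnist_export") \
        .setSignatureDefKey("serving_default") \
        .setTagSet("serve") \
        .setBatchSize(100)

    # Each executor loads the saved_model once and runs inference in parallel
    predictions = model.transform(df)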
- class TFParams(*args: Any, **kwargs: Any)[source]
Bases: pyspark.ml.param.shared.Params
Mix-in class to store namespace-style args and merge them with Spark ML-style params.
- args = None
- class TFTypeConverters[source]
Bases: object
Custom DataFrame TypeConverter for dictionary types (since this is not provided by Spark core).
- get_meta_graph_def(saved_model_dir, tag_set)[source]
Utility function to read a meta_graph_def from disk.
Adapted from saved_model_cli.py. DEPRECATED for TF2.0+.
- Args:
- saved_model_dir
path to saved_model.
- tag_set
list of string tags identifying the TensorFlow graph within the saved_model.
- Returns:
A TensorFlow meta_graph_def; raises an Exception if no graph in the saved_model matches the given tag_set.
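Example (the path is hypothetical; passing tag_set as a comma-separated string, as saved_model_cli does, is an assumption here):

    # Load the meta_graph_def tagged "serve" from a TF1.x saved_model
    meta_graph_def = get_meta_graph_def("/path/to/saved_model", "serve")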
- single_node_env(args)[source]
Sets up environment for a single-node TF session.
- Args:
- args
command line arguments as either argparse args or argv list
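Example (a minimal sketch; the script name and flag are illustrative only):

    from tensorflowonspark.pipeline import single_node_env

    # Pass the arguments as an argv-style list (an argparse namespace is
    # also accepted, per the docstring above)
    single_node_env(["my_app.py", "--batch_size", "100"])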
- yield_batch(iterable, batch_size, num_tensors=1)[source]
Generator that yields batches of rows from a DataFrame partition iterator.
- Args:
- iterable
Spark partition iterator.
- batch_size
number of items to retrieve per invocation.
- num_tensors
number of tensors (columns) expected in each item.
- Returns:
An array of num_tensors arrays, each of length batch_size.
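Example (a small illustration of the batching behavior; the handling of a trailing partial batch is assumed here):

    rows = [(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e")]

    for batch in yield_batch(iter(rows), batch_size=2, num_tensors=2):
        print(batch)

    # Expected output, assuming the final partial batch is also yielded:
    # [[1, 2], ['a', 'b']]
    # [[3, 4], ['c', 'd']]
    # [[5], ['e']]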