# Copyright 2019 Yahoo Inc
# Licensed under the terms of the Apache 2.0 license.
# Please see LICENSE file in the project root for terms.
from __future__ import absolute_import
from __future__ import division
from __future__ import nested_scopes
from __future__ import print_function
import logging
from . import TFSparkNode
from . import util
logger = logging.getLogger(__name__)
def run(sc, map_fn, tf_args, num_executors, use_barrier=True):
"""Runs the user map_fn as parallel, independent instances of TF on the Spark executors.
Args:
:sc: SparkContext
:map_fun: user-supplied TensorFlow "main" function
:tf_args: ``argparse`` args, or command-line ``ARGV``. These will be passed to the ``map_fun``.
:num_executors: number of Spark executors. This should match your Spark job's ``--num_executors``.
:use_barrier: Boolean indicating if TFParallel should use Spark barrier execution mode to wait for all executors.
Returns:
None
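
  Example (a minimal sketch; ``main_fun``, the argument parser, and the
  executor count are placeholders, assuming this module is importable as
  ``tensorflowonspark.TFParallel``)::

    import argparse
    from pyspark import SparkContext
    from tensorflowonspark import TFParallel

    def main_fun(args, ctx):
      # runs once per executor; ctx describes this worker's placement
      print("worker {} of {}".format(ctx.worker_num, ctx.num_workers))

    if __name__ == '__main__':
      sc = SparkContext()
      parser = argparse.ArgumentParser()
      parser.add_argument("--num_gpus", type=int, default=1)
      args = parser.parse_args()
      TFParallel.run(sc, main_fun, args, num_executors=4)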
"""
  # get default filesystem from spark
  defaultFS = sc._jsc.hadoopConfiguration().get("fs.defaultFS")
  # strip trailing "root" slash from "file:///" to be consistent w/ "hdfs://..."
  if defaultFS.startswith("file://") and len(defaultFS) > 7 and defaultFS.endswith("/"):
    defaultFS = defaultFS[:-1]
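  # e.g. "file:///" becomes "file://" (illustrative), matching the
  # no-trailing-slash form of a URI like "hdfs://namenode:8020"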

  def _run(it):
    from pyspark import BarrierTaskContext

    # each partition holds exactly one element: this executor's worker index
    for i in it:
      worker_num = i
    if use_barrier:
      # use BarrierTaskContext to get placement of all nodes
      barrier_ctx = BarrierTaskContext.get()
      tasks = barrier_ctx.getTaskInfos()
      nodes = [t.address for t in tasks]
      num_workers = len(nodes)
    else:
      nodes = []
      num_workers = num_executors
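    # illustrative only (hypothetical addresses): with three executors, nodes
    # might be ["10.0.0.1:38451", "10.0.0.2:40120", "10.0.0.3:33997"]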

    # use the placement info to help allocate GPUs
    # note: defaults to CPU if no GPUs present
    num_gpus = tf_args.num_gpus if 'num_gpus' in tf_args else 1
    util.single_node_env(num_gpus=num_gpus, worker_index=worker_num, nodes=nodes)
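    # (util.single_node_env sets up the environment for a standalone TF
    # instance; presumably this includes restricting this worker to its share
    # of the visible GPUs based on worker_index and the node placement)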

    # run the user map_fn
    ctx = TFSparkNode.TFNodeContext()
    ctx.defaultFS = defaultFS
    ctx.worker_num = worker_num
    ctx.executor_id = worker_num
    ctx.num_workers = num_workers
    map_fn(tf_args, ctx)

    # return a dummy value, since mapPartitions must return an iterable
    return [0]
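
  # create one task per worker: an RDD of num_executors one-element partitions,
  # so each executor invokes _run (and thus map_fn) exactly once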
  nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)
  if use_barrier:
    nodeRDD.barrier().mapPartitions(_run).collect()
  else:
    nodeRDD.mapPartitions(_run).collect()
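
# Note: barrier execution mode requires Spark to schedule all num_executors
# tasks simultaneously; if the cluster cannot provide that many slots at once,
# the barrier stage fails rather than running with fewer workers.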