# Copyright 2017 Yahoo Inc.
# Licensed under the terms of the Apache 2.0 license.
# Please see LICENSE file in the project root for terms.
from __future__ import absolute_import
from __future__ import division
from __future__ import nested_scopes
from __future__ import print_function

import errno
import logging
import os
import socket
import subprocess
from socket import error as socket_error

from . import gpu_info

logger = logging.getLogger(__name__)

def single_node_env(num_gpus=1, worker_index=-1, nodes=[]):
  """Setup environment variables for Hadoop compatibility and GPU allocation."""
  # ensure expanded CLASSPATH w/o glob characters (required for Spark 2.1 + JNI)
  if 'HADOOP_PREFIX' in os.environ and 'TFOS_CLASSPATH_UPDATED' not in os.environ:
    classpath = os.environ['CLASSPATH']
    hadoop_path = os.path.join(os.environ['HADOOP_PREFIX'], 'bin', 'hadoop')
    hadoop_classpath = subprocess.check_output([hadoop_path, 'classpath', '--glob']).decode()
    os.environ['CLASSPATH'] = classpath + os.pathsep + hadoop_classpath
    os.environ['TFOS_CLASSPATH_UPDATED'] = '1'

  if gpu_info.is_gpu_available() and num_gpus > 0:
    # reserve GPU(s), if requested
    if worker_index >= 0 and nodes:
      # compute my index relative to other nodes on the same host, if known;
      # compare full host components so e.g. '10.0.0.1' does not match '10.0.0.10'
      my_addr = nodes[worker_index]
      my_host = my_addr.split(':')[0]
      local_peers = [n for n in nodes if n.split(':')[0] == my_host]
      my_index = local_peers.index(my_addr)
    else:
      # otherwise, just use the global worker index
      my_index = worker_index
    gpus_to_use = gpu_info.get_gpus(num_gpus, my_index)
    logger.info("Using gpu(s): {0}".format(gpus_to_use))
    os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use
  else:
    # CPU
    logger.info("Using CPU")
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
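
# Usage sketch (an illustration added here, not part of the original module):
# in a single-process or Spark local-mode job, reserve one GPU *before*
# importing TensorFlow, so TF only sees the device(s) listed in
# CUDA_VISIBLE_DEVICES:
#
#   single_node_env(num_gpus=1)
#   import tensorflow as tf
#
# When several executors can share a host, pass the cluster's "host:port"
# list so GPU slots are assigned per-host (addresses below are hypothetical):
#
#   nodes = ['10.0.0.1:2222', '10.0.0.1:2223', '10.0.0.2:2222']
#   single_node_env(num_gpus=1, worker_index=1, nodes=nodes)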

def get_ip_address():
  """Simple utility to get the host IP address."""
  # create the socket outside the try block so `s` is always bound in `finally`
  s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  try:
    # connecting a UDP socket sends no packets; it just selects the local
    # interface that routes toward the public internet
    s.connect(("8.8.8.8", 80))
    ip_address = s.getsockname()[0]
  except socket_error as sockerr:
    if sockerr.errno != errno.ENETUNREACH:
      raise sockerr
    # no route to the public internet; fall back to resolving the local hostname
    ip_address = socket.gethostbyname(socket.getfqdn())
  finally:
    s.close()
  return ip_address
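
# Example (hypothetical usage): advertise a routable address for this host
# instead of 127.0.0.1, e.g. when building a TensorFlow cluster spec:
#
#   addr = '{0}:{1}'.format(get_ip_address(), 2222)   # e.g. '10.0.0.5:2222'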

def find_in_path(path, file):
  """Find a file in a given path string, returning its full path, or False if not found."""
  for p in path.split(os.pathsep):
    candidate = os.path.join(p, file)
    # os.path.isfile() already implies existence
    if os.path.isfile(candidate):
      return candidate
  return False
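
# Example (hypothetical usage): locate the hadoop launcher on the current
# PATH; the return value is the full path if found, or False otherwise:
#
#   hadoop_bin = find_in_path(os.environ.get('PATH', ''), 'hadoop')
#   if hadoop_bin:
#     subprocess.check_call([hadoop_bin, 'version'])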

def write_executor_id(num):
  """Write the executor id into the file `executor_id` in the executor's current working directory."""
  with open("executor_id", "w") as f:
    f.write(str(num))

def read_executor_id():
  """Read the executor id from the file `executor_id` in the executor's current working directory."""
  if os.path.isfile("executor_id"):
    with open("executor_id", "r") as f:
      return int(f.read())
  else:
    msg = "No executor_id file found on this node, please ensure that:\n" + \
          "1. Spark num_executors matches TensorFlow cluster_size\n" + \
          "2. Spark tasks per executor is 1\n" + \
          "3. Spark dynamic allocation is disabled\n" + \
          "4. There are no other root-cause exceptions on other nodes\n"
    raise Exception(msg)
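
# Minimal smoke test (an addition for illustration, not in the original
# module): round-trips an executor id through the local file and exercises
# the address/path helpers. Because of the relative `gpu_info` import, run it
# via the package, e.g. `python -m tensorflowonspark.util` (package name
# assumed).
if __name__ == '__main__':
  logging.basicConfig(level=logging.INFO)
  write_executor_id(7)
  assert read_executor_id() == 7
  os.remove("executor_id")              # clean up the scratch file
  print("host ip address: {0}".format(get_ip_address()))
  print("hadoop on PATH: {0}".format(find_in_path(os.environ.get('PATH', ''), 'hadoop')))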