# Copyright 2017 Yahoo Inc.
# Licensed under the terms of the Apache 2.0 license.
# Please see LICENSE file in the project root for terms.
from __future__ import absolute_import
from __future__ import division
from __future__ import nested_scopes
from __future__ import print_function
from multiprocessing.managers import BaseManager
from multiprocessing import JoinableQueue
class TFManager(BaseManager):
  """Python multiprocessing.Manager for distributed, multi-process communication.

  The shared methods ``get_queue``, ``get`` and ``set`` are registered on this
  class at runtime: :func:`start` registers them together with their server-side
  implementations, and :func:`connect` registers just the names for clients.
  """
# Module-level state, one copy per Spark executor's Python worker process.
mgr = None      # TFManager created by start(); None until start() is called
qdict = dict()  # queue name -> JoinableQueue, (re)populated by start()
kdict = dict()  # key/value store read by _get() and written by _set()
def _get(key):
  """Return the value stored under ``key`` in this worker's key/value store.

  Raises:
    KeyError: if ``key`` has never been set.
  """
  value = kdict[key]
  return value
def _set(key, value):
  """Store ``value`` under ``key`` in this worker's key/value store."""
  kdict.update({key: value})
def _get_queue(qname):
  """Return the queue registered under ``qname``, or ``None`` if there is none."""
  return qdict.get(qname)
def start(authkey, queues, mode='local'):
  """Create and start a new TFManager, resetting this worker's shared state.

  Queues and key/value pairs left over from a previous call are discarded. A
  fresh ``JoinableQueue`` is created for each name in ``queues``, and the
  ``get_queue``/``get``/``set`` methods are registered before the manager
  server is started.

  Args:
    :authkey: authorization key that clients must present to :func:`connect`.
      Byte keys are passed through unchanged. On Python 3 a text (``str``) key
      is UTF-8 encoded, because ``multiprocessing`` only accepts byte keys.
    :queues: *INTERNAL_USE* names of the queues to create.
    :mode: ``'remote'`` binds the manager to all interfaces on an OS-assigned
      port so that other hosts can connect. Any other value (default
      ``'local'``) passes no address, so multiprocessing chooses an arbitrary
      address intended for same-host access only.

  Returns:
    The started TFManager instance, also cached in the module-level ``mgr``.
  """
  global mgr  # only ``mgr`` is rebound; qdict/kdict are mutated in place
  # Python 3's multiprocessing raises TypeError for text keys; encode them.
  if isinstance(authkey, str) and str is not bytes:
    authkey = authkey.encode('utf-8')
  qdict.clear()
  kdict.clear()
  for q in queues:
    qdict[q] = JoinableQueue()
  # The helpers already have the exact signatures the manager calls with,
  # so they are registered directly instead of via wrapping lambdas.
  TFManager.register('get_queue', callable=_get_queue)
  TFManager.register('get', callable=_get)
  TFManager.register('set', callable=_set)
  if mode == 'remote':
    # Empty host binds every interface; port 0 lets the OS pick a free port.
    mgr = TFManager(address=('', 0), authkey=authkey)
  else:
    mgr = TFManager(authkey=authkey)
  mgr.start()
  return mgr
def connect(address, authkey):
  """Connect to an already-running TFManager.

  Args:
    :address: address of the target TFManager. For a 'local' manager this is
      the unique connection address it was started on; for a 'remote' manager
      it is a (host, port) tuple.
    :authkey: authorization key matching the one the manager was started with.
      Byte keys are passed through unchanged. On Python 3 a text (``str``) key
      is UTF-8 encoded, because ``multiprocessing`` only accepts byte keys.

  Returns:
    A connected TFManager instance referencing the manager at ``address``,
    exposing ``get_queue``, ``get`` and ``set``.
  """
  # Python 3's multiprocessing raises TypeError for text keys; encode them.
  if isinstance(authkey, str) and str is not bytes:
    authkey = authkey.encode('utf-8')
  # Client side: register only the method names. The implementations are
  # registered (with callables) in the server process by start().
  TFManager.register('get_queue')
  TFManager.register('get')
  TFManager.register('set')
  m = TFManager(address, authkey=authkey)
  m.connect()
  return m