# Source code for tensorflowonspark.TFManager

# Copyright 2017 Yahoo Inc.
# Licensed under the terms of the Apache 2.0 license.
# Please see LICENSE file in the project root for terms.

from __future__ import absolute_import
from __future__ import division
from __future__ import nested_scopes
from __future__ import print_function

from multiprocessing.managers import BaseManager
from multiprocessing import JoinableQueue

class TFManager(BaseManager):
    """Multiprocessing manager used for distributed, multi-process communication.

    This subclass adds no behavior of its own. Exposed methods (``get_queue``,
    ``get``, ``set``) are attached to it at runtime via ``TFManager.register``.
    """
# global to each Spark executor's python worker mgr = None # TFManager qdict = {} # dictionary of queues kdict = {} # dictionary of key-values def _get(key): return kdict[key] def _set(key, value): kdict[key] = value def _get_queue(qname): try: return qdict[qname] except KeyError: return None
def start(authkey, queues, mode='local'):
    """Create and start a new TFManager for this Python worker.

    A fresh manager is always created and started; an existing one is never
    reused. All queues and key/value pairs previously held by this module are
    discarded first.

    Args:
      :authkey: authorization key for the manager. multiprocessing requires a
        byte string (``bytes``) on Python 3.
      :queues: *INTERNAL_USE* names of the JoinableQueues to create.
      :mode: ``'remote'`` binds the manager to all interfaces on an OS-assigned
        port so it is reachable from other hosts. Any other value (default
        ``'local'``) uses multiprocessing's default address, which is only
        accessible from the same host.

    Returns:
      The started TFManager instance, also cached in the module-level ``mgr``
      of this Python worker process.
    """
    global mgr  # rebound below; qdict/kdict are only mutated in place

    # Reset any state left over from an earlier start() in this process.
    qdict.clear()
    kdict.clear()
    for qname in queues:
        qdict[qname] = JoinableQueue()

    # Register the module-level accessors directly instead of wrapping them
    # in lambdas. Calls behave the same, and plain functions (unlike lambdas)
    # can be pickled if the manager process is started with 'spawn'.
    # NOTE(review): under 'spawn' the server would re-import this module with
    # an empty qdict/kdict, so this design effectively relies on 'fork'.
    TFManager.register('get_queue', callable=_get_queue)
    TFManager.register('get', callable=_get)
    TFManager.register('set', callable=_set)

    if mode == 'remote':
        # '' = all interfaces, port 0 = let the OS pick a free port.
        mgr = TFManager(address=('', 0), authkey=authkey)
    else:
        mgr = TFManager(authkey=authkey)
    mgr.start()
    return mgr
def connect(address, authkey):
    """Attach to an already-running TFManager.

    Args:
      :address: address of the target TFManager: a unique connection string
        for a 'local' manager, or a ``(host, port)`` tuple for a remote one.
      :authkey: authorization key matching the one the manager was started with.

    Returns:
      A TFManager instance connected to the manager at ``address``.
    """
    # The client side only needs the typeids. The callables are supplied by
    # the server process that start() launched.
    for typeid in ('get_queue', 'get', 'set'):
        TFManager.register(typeid)

    client = TFManager(address, authkey=authkey)
    client.connect()
    return client