# Copyright 2017 Yahoo Inc.
# Licensed under the terms of the Apache 2.0 license.
# Please see LICENSE file in the project root for terms.
"""This module contains client/server methods to manage node reservations during TFCluster startup."""
from __future__ import absolute_import
from __future__ import division
from __future__ import nested_scopes
from __future__ import print_function
import logging
import os
import pickle
import select
import socket
import struct
import sys
import threading
import time
from . import util
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Name of the environment variable that Server.get_server_ports() reads for the
# listening port, given either as a single port or as a 'low-high' range.
TFOS_SERVER_PORT = "TFOS_SERVER_PORT"
# Name of the environment variable that Server.get_server_ip() reads to
# override the advertised server host.
TFOS_SERVER_HOST = "TFOS_SERVER_HOST"
# Number of bytes requested per socket recv() call in MessageSocket.receive().
BUFSIZE = 1024
# Upper bound on send attempts per request in Client._request().
MAX_RETRIES = 3
class Reservations:
    """Thread-safe store for node reservations.

    Every access to the underlying list happens while holding ``self.lock``.

    Args:
      :required: expected number of nodes in the cluster.
    """

    def __init__(self, required):
        self.required = required
        # Re-entrant lock guarding all reads/writes of ``self.reservations``.
        self.lock = threading.RLock()
        self.reservations = []

    def add(self, meta):
        """Add a reservation.

        Args:
          :meta: a dictionary of metadata about a node
        """
        with self.lock:
            self.reservations.append(meta)

    def done(self):
        """Returns True if the ``required`` number of reservations have been fulfilled."""
        with self.lock:
            return len(self.reservations) >= self.required

    def get(self):
        """Get the list of current reservations.

        Returns:
          A new list holding the reservations registered so far.  A copy is
          returned (rather than the internal list) so that callers can iterate
          or serialize it after the lock is released without racing against
          concurrent ``add()`` calls, and cannot corrupt the store by mutating
          the result.
        """
        with self.lock:
            return list(self.reservations)

    def remaining(self):
        """Get a count of remaining/unfulfilled reservations."""
        with self.lock:
            return self.required - len(self.reservations)
class MessageSocket(object):
    """Abstract class w/ length-prefixed socket send/receive functions.

    Wire format (see ``send``): a 4-byte big-endian unsigned length followed by
    that many bytes of pickled payload.
    """

    @staticmethod
    def _recv_exact(sock, count):
        """Read exactly ``count`` bytes from ``sock``.

        Args:
          :sock: object exposing a socket-style ``recv(n)`` method.
          :count: number of bytes to read.

        Returns:
          A bytes object of length ``count``.

        Raises:
          Exception: ("socket closed") if the peer closes the connection
          before ``count`` bytes have arrived.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            # Never ask for more than what is left of this message, so bytes
            # belonging to a following message stay in the socket buffer.
            buf = sock.recv(min(BUFSIZE, remaining))
            if not buf:
                raise Exception("socket closed")
            chunks.append(buf)
            remaining -= len(buf)
        return b''.join(chunks)

    def receive(self, sock):
        """Receive a message on ``sock``.

        Reads the 4-byte length header in full (it may arrive split across
        several ``recv`` calls) and then exactly that many payload bytes.

        Returns:
          The unpickled message object.

        Raises:
          Exception: ("socket closed") if the connection is closed before a
          complete message has been received.
        """
        header = MessageSocket._recv_exact(sock, 4)
        recv_len = struct.unpack('>I', header)[0]
        data = MessageSocket._recv_exact(sock, recv_len)
        # SECURITY NOTE: pickle.loads() can execute arbitrary code while
        # deserializing; only use this protocol between trusted peers.
        return pickle.loads(data)

    def send(self, sock, msg):
        """Send ``msg`` to destination ``sock``.

        The message is pickled and prefixed with its length as a 4-byte
        big-endian unsigned integer.
        """
        data = pickle.dumps(msg)
        buf = struct.pack('>I', len(data)) + data
        sock.sendall(buf)
class Server(MessageSocket):
    """Simple socket server with length-prefixed pickle messages.

    Args:
      :count: expected number of nodes in the cluster.
    """
    reservations = None   #: List of reservations managed by this server.
    done = False          #: boolean indicating if server should be shutdown.

    def __init__(self, count):
        assert count > 0, "Expected number of reservations should be greater than zero"
        self.reservations = Reservations(count)

    def await_reservations(self, sc, status=None, timeout=600):
        """Block until all reservations are received.

        Polls once per second.

        Args:
          :sc: object exposing ``cancelAllJobs()`` and ``stop()`` (presumably a SparkContext -- confirm against callers).
          :status: optional dict of status flags; if it ever contains the key ``'error'``, jobs are cancelled, ``sc`` is stopped and the process exits with status 1.
          :timeout: approximate number of seconds to wait before giving up.

        Returns:
          The list of reservations.

        Raises:
          Exception: if the reservations are not complete within ``timeout``.
        """
        # Avoid a shared mutable default; a dict supplied by the caller is used
        # as-is so that updates made to it elsewhere remain visible here.
        if status is None:
            status = {}
        timespent = 0
        while not self.reservations.done():
            logger.info("waiting for {0} reservations".format(self.reservations.remaining()))
            # check status flags for any errors
            if 'error' in status:
                sc.cancelAllJobs()
                sc.stop()
                sys.exit(1)
            time.sleep(1)
            timespent += 1
            if timespent > timeout:
                raise Exception("timed out waiting for reservations to complete")
        logger.info("all reservations completed")
        return self.reservations.get()

    def _handle_message(self, sock, msg):
        """Dispatch one client message and send the reply on ``sock``.

        Supported ``msg['type']`` values: ``REG`` (store ``msg['data']``),
        ``QUERY`` (reply whether reservations are complete), ``QINFO`` (reply
        with the reservation list) and ``STOP`` (flag the server for shutdown).
        Any other type is answered with ``'ERR'``.
        """
        logger.debug("received: {0}".format(msg))
        msg_type = msg['type']
        if msg_type == 'REG':
            self.reservations.add(msg['data'])
            MessageSocket.send(self, sock, 'OK')
        elif msg_type == 'QUERY':
            MessageSocket.send(self, sock, self.reservations.done())
        elif msg_type == 'QINFO':
            rinfo = self.reservations.get()
            MessageSocket.send(self, sock, rinfo)
        elif msg_type == 'STOP':
            logger.info("setting server.done")
            # Acknowledge before flagging shutdown so the requester gets a reply.
            MessageSocket.send(self, sock, 'OK')
            self.done = True
        else:
            MessageSocket.send(self, sock, 'ERR')

    def start(self):
        """Start listener in a background thread

        Returns:
          address of the Server as a tuple of (host, port)
        """
        server_sock = self.start_listening_socket()

        # hostname may not be resolvable but IP address probably will be
        host = self.get_server_ip()
        port = server_sock.getsockname()[1]
        addr = (host, port)
        logger.info("listening for reservations at {0}".format(addr))

        def _listen():
            # Sockets being watched: the listening socket plus every accepted client.
            connections = [server_sock]
            try:
                while not self.done:
                    # The 60s timeout bounds how long a stop() call can go unnoticed.
                    readable, _, _ = select.select(connections, [], [], 60)
                    for ready in readable:
                        if ready == server_sock:
                            client_sock, client_addr = ready.accept()
                            connections.append(client_sock)
                            logger.debug("client connected from {0}".format(client_addr))
                        else:
                            try:
                                msg = self.receive(ready)
                                self._handle_message(ready, msg)
                            except Exception as e:
                                # Includes the "socket closed" case: drop this client.
                                logger.debug(e)
                                ready.close()
                                connections.remove(ready)
            finally:
                # Release the listening port even if the loop dies unexpectedly.
                server_sock.close()

        t = threading.Thread(target=_listen)
        t.daemon = True
        t.start()

        return addr

    def get_server_ip(self):
        """Returns the value of TFOS_SERVER_HOST environment variable (if set), otherwise defaults to current host/IP."""
        host = os.getenv(TFOS_SERVER_HOST)
        if host is None:
            # Only auto-detect when no override is configured, so that the
            # override still works where detection is not possible.
            host = util.get_ip_address()
        return host

    def get_server_ports(self):
        """Returns a list of target ports as defined in the TFOS_SERVER_PORT environment (if set), otherwise defaults to 0 (any port).

        TFOS_SERVER_PORT should be either a single port number or a range, e.g. '8888' or '9997-9999'

        Raises:
          Exception: if the value contains '-' but is not of the form 'low-high'.
        """
        port_string = os.getenv(TFOS_SERVER_PORT, "0")
        if '-' not in port_string:
            return [int(port_string)]
        ports = port_string.split('-')
        if len(ports) != 2:
            raise Exception("Invalid TFOS_SERVER_PORT: {}".format(port_string))
        # The range is inclusive of both endpoints.
        return list(range(int(ports[0]), int(ports[1]) + 1))

    def start_listening_socket(self):
        """Starts the registration server socket listener.

        Tries each port from ``get_server_ports()`` in order and returns the
        first socket that binds and listens successfully.

        Raises:
          Exception: if none of the candidate ports could be bound (including
          when the candidate list is empty).
        """
        port_list = self.get_server_ports()
        server_sock = None
        for port in port_list:
            candidate = None
            try:
                candidate = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                candidate.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                candidate.bind(('', port))
                candidate.listen(10)
            except Exception as e:
                logger.warning("Unable to bind to port {}, error {}".format(port, e))
                # Do not leak the descriptor of a socket that failed to bind.
                if candidate is not None:
                    candidate.close()
            else:
                logger.info("Reservation server binding to port {}".format(port))
                server_sock = candidate
                break

        if server_sock is None:
            raise Exception("Reservation server unable to bind to any ports, port_list = {}".format(port_list))
        return server_sock

    def stop(self):
        """Stop the Server's socket listener.

        Only sets the ``done`` flag; the listener thread checks it at the top
        of each loop iteration.
        """
        self.done = True
class Client(MessageSocket):
    """Client to register and await node reservations.

    Args:
      :server_addr: a tuple of (host, port) pointing to the Server.
    """
    sock = None           #: socket to server TCP connection
    server_addr = None    #: address of server

    def __init__(self, server_addr):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect(server_addr)
        except Exception:
            # Do not leak the descriptor when the server is unreachable.
            sock.close()
            raise
        self.sock = sock
        self.server_addr = server_addr
        logger.info("connected to server at {0}".format(server_addr))

    def _request(self, msg_type, msg_data=None):
        """Helper function to wrap msg w/ msg_type.

        Sends ``{'type': msg_type, 'data': msg_data}`` (``data`` is omitted
        only when ``msg_data`` is None) and returns the server's response.  A
        failed send is retried on a fresh connection, up to ``MAX_RETRIES``
        attempts in total; the last failure is re-raised.
        """
        msg = {'type': msg_type}
        # Compare against None so falsy-but-valid payloads (e.g. an empty dict)
        # are still transmitted.
        if msg_data is not None:
            msg['data'] = msg_data

        for attempt in range(1, MAX_RETRIES + 1):
            try:
                MessageSocket.send(self, self.sock, msg)
                break
            except socket.error as e:
                if attempt >= MAX_RETRIES:
                    raise
                logger.warning("Socket error: {}".format(e))
                # Reconnect before the next attempt.
                self.sock.close()
                self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                self.sock.connect(self.server_addr)

        logger.debug("sent: {0}".format(msg))
        resp = MessageSocket.receive(self, self.sock)
        logger.debug("received: {0}".format(resp))
        return resp

    def close(self):
        """Close the client socket."""
        self.sock.close()

    def register(self, reservation):
        """Register ``reservation`` with server."""
        return self._request('REG', reservation)

    def get_reservations(self):
        """Get current list of reservations."""
        return self._request('QINFO')

    def await_reservations(self):
        """Poll until all reservations completed, then return cluster_info.

        Queries the server once per second until it reports completion.
        """
        while not self._request('QUERY'):
            # Wait between polls only while the cluster is still incomplete.
            time.sleep(1)
        return self.get_reservations()

    def request_stop(self):
        """Request server stop."""
        return self._request('STOP')