"""
desitarget.internal.sharedmem
=============================

Easier parallel programming on shared memory computers.

The source code is at http://github.com/rainwoodman/sharedmem .

.. contents:: Topics
    :local:

Programming Model
-----------------
:py:class:`MapReduce` provides an equivalent of multiprocessing.Pool, with the following
differences:

- MapReduce does not require the work function to be picklable.
- MapReduce adds a reduction step that is guaranteed to run in the scope of the
  coordinator process.
- MapReduce allows the use of critical sections and ordered execution in the work
  function.

Modifications to shared memory arrays, allocated via

- :py:meth:`sharedmem.empty`,
- :py:meth:`sharedmem.empty_like`,
- :py:meth:`sharedmem.copy`,

are visible to all processes, including the coordinator process, as the sketch
below illustrates.
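
For instance, in the following minimal sketch (which assumes :code:`numpy` and
this module are imported, as in the examples below), the worker function writes
into a shared array and the coordinator sees the result after
:py:meth:`MapReduce.map` returns:

>>> data = sharedmem.empty(4, dtype='f8')
>>> with MapReduce() as pool:
>>>    def work(i):
>>>        data[i] = i * i
>>>    pool.map(work, range(len(data)))
>>> print(data)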

Usage
-----
The package can be installed via :code:`pip install sharedmem`.
Alternatively, the file :code:`sharedmem.py` can be directly embedded into
other projects.

The only external dependency is numpy, since the package is designed to
work with large shared memory chunks through numpy.ndarray.

The OMP_NUM_THREADS environment variable determines the default
number of workers.
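
The worker count can also be set explicitly per pool; a minimal sketch
(per the :py:class:`MapReduce` parameters, :code:`np=0` runs everything
serially on the coordinator process):

>>> with MapReduce(np=4) as pool:      # four worker processes
>>>    r = pool.map(lambda i: i * i, range(16))
>>> with MapReduce(np=0) as pool:      # serial, useful for debugging
>>>    r = pool.map(lambda i: i * i, range(16))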

Notes
-----
This module depends on the `fork` system call and is therefore available
only on POSIX systems (not Windows).

Examples
--------

Sum up a large array

>>> input = numpy.arange(1024 * 1024 * 128, dtype='f8')
>>> output = sharedmem.empty(1024 * 1024 * 128, dtype='f8')
>>> with MapReduce() as pool:
>>>    chunksize = 1024 * 1024
>>>    def work(i):
>>>        s = slice(i, i + chunksize)
>>>        output[s] = input[s]
>>>        return i, numpy.sum(input[s])
>>>    def reduce(i, r):
>>>        print('chunk', i, 'done')
>>>        return r
>>>    r = pool.map(work, range(0, len(input), chunksize), reduce=reduce)
>>> print(numpy.sum(r))
>>>

Textual analysis

>>> input = open('mytextfile.txt').readlines()
>>> word_count = {'bacon': 0, 'eggs': 0}
>>> with MapReduce() as pool:
>>>    def work(line):
>>>        wc = dict.fromkeys(word_count, 0)
>>>        for word in line.split():
>>>            if word in wc:
>>>                wc[word] += 1
>>>        return wc
>>>    def reduce(wc):
>>>        for key in word_count:
>>>            word_count[key] += wc[key]
>>>    pool.map(work, input, reduce=reduce)
>>> print(word_count)
>>>

pool.ordered can be used to require a block of code to be executed in the
order of the items of the input sequence

>>> with MapReduce() as pool:
>>>    def work(i):
>>>        with pool.ordered:
>>>            print(i)
>>>    pool.map(work, range(10))

pool.critical can be used to require a block of code to be executed in a
critical section, i.e. by only one process at a time.

>>> counter = sharedmem.empty(1)
>>> counter[:] = 0
>>> with MapReduce() as pool:
>>>    def work(i):
>>>        with pool.critical:
>>>            counter[:] += i
>>>    pool.map(work, range(10))
>>> print(counter)
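
When each item of the sequence is itself a tuple of arguments,
:code:`star=True` unpacks it into the positional arguments of the work
function (a minimal sketch of the :py:meth:`MapReduce.map` `star` option):

>>> with MapReduce() as pool:
>>>    def work(a, b):
>>>        return a + b
>>>    r = pool.map(work, [(1, 2), (3, 4)], star=True)
>>> print(r)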

API References
--------------

"""
__author__ = "Yu Feng"
__email__ = "rainwoodman@gmail.com"

__all__ = ['set_debug', 'get_debug',
        'total_memory', 'cpu_count',
        'WorkerException', 'StopProcessGroup',
        'background',
        'MapReduce', 'MapReduceByThread',
        'empty', 'empty_like',
        'full', 'full_like',
        'copy',
        ]

import os
import sys
import multiprocessing
import threading

if sys.version_info.major == 3:
    import queue
elif sys.version_info.major == 2:
    import Queue as queue


from collections import deque
import traceback
import warnings
import gc
import heapq
import pickle

import numpy
from multiprocessing import RawArray
import ctypes
import mmap
#logger = multiprocessing.log_to_stderr()
#logger.setLevel(multiprocessing.SUBDEBUG)

__shmdebug__ = False

def set_debug(flag):
    """ Set the debug mode.

    In debug mode (flag == True), no workers are spawned. All work
    is done serially on the coordinator thread/process. This eases
    debugging when a worker throws an exception.
    """
    global __shmdebug__
    __shmdebug__ = flag

def get_debug():
    """ Get the debug mode. """
    global __shmdebug__
    return __shmdebug__

def total_memory():
    """ Returns the amount of memory available for use.

    This function is not very useful. The memory is obtained from the
    MemTotal entry in /proc/meminfo.
    """
    with open('/proc/meminfo', 'r') as f:
        for line in f:
            words = line.split()
            if words[0].upper() == 'MEMTOTAL:':
                return int(words[1]) * 1024
    raise IOError('MemTotal unknown')

def cpu_count():
    """ Returns the number of worker processes to be spawned.

    The default value is the number of physical cpu cores seen by python.
    The :code:`OMP_NUM_THREADS` environment variable overrides it.

    On PBS/torque systems, if OMP_NUM_THREADS is empty, we try to use the
    value of the :code:`PBS_NUM_PPN` variable.

    Notes
    -----
    On some machines the physical number of cores does not equal the
    number of cpus that shall be used, e.g. PSC Blacklight.
    """
    num = os.getenv("OMP_NUM_THREADS")
    if num is None:
        num = os.getenv("PBS_NUM_PPN")
    try:
        return int(num)
    except (TypeError, ValueError):
        return multiprocessing.cpu_count()

class LostExceptionType(Warning):
    pass

class WorkerException(Exception):
    """ Represents an exception that has occurred during a worker process.

    Attributes
    ----------
    reason : Exception, or subclass of Exception.
        The underlying reason of the exception. If the original exception
        can be pickled, the type of the exception is preserved. Otherwise,
        a LostExceptionType warning is issued, and reason is of type
        Exception.

    traceback : str
        The string version of the traceback that can be used to inspect
        the error.
    """
    def __init__(self, reason, traceback):
        if not isinstance(reason, Exception):
            warnings.warn("Type information of unpicklable exception %s is lost"
                    % reason, LostExceptionType)
            reason = Exception(reason)
        self.reason = reason
        self.traceback = traceback
        Exception.__init__(self, "%s\n%s" % (str(reason), str(traceback)))

class StopProcessGroup(Exception):
    """ StopProcessGroup will terminate the worker process/thread. """
    def __init__(self):
        Exception.__init__(self, "StopProcessGroup")

class ProcessGroup(object):
    """ Monitors a group of worker processes. """
    def __init__(self, backend, main, np, args=()):
        self.Errors = backend.QueueFactory(1)
        self._tls = backend.StorageFactory()
        self.main = main
        self.args = args
        self.guard = threading.Thread(target=self._guardMain)
        self.errorguard = threading.Thread(target=self._errorGuard)

        # this has to be from backend because the workers will check
        # this variable.
        self.guardDead = backend.EventFactory()

        # each dead child releases one semaphore;
        # when all are dead the guard will proceed to set guardDead.
        self.semaphore = threading.Semaphore(0)
        self.JoinedProcesses = multiprocessing.RawValue('l')
        self.P = [
            backend.WorkerFactory(target=self._workerMain, args=(rank,))
            for rank in range(np)
            ]
        self.G = [
            threading.Thread(target=self._workerGuard, args=(rank, self.P[rank]))
            for rank in range(np)
            ]
        return

    def _workerMain(self, rank):
        self._tls.rank = rank
        try:
            self.main(self, *self.args)
        except WorkerException as e:
            raise RuntimeError("worker exception shall never be caught by a worker")
        except StopProcessGroup as e:
            pass
        except BaseException as e:
            try:
                # Put in the string version of the exception.
                # Some of the Exception types in extension types are probably
                # not picklable (thus can't be sent via a queue).
                # However, we don't use the extra information in customized
                # Exception types anyways.
                try:
                    pickle.dumps(e)
                except Exception as ee:
                    e = str(e)
                tb = traceback.format_exc()
                self.Errors.put((e, tb), timeout=0)
            except queue.Full:
                pass
        finally:
            # self.Errors.close()
            # self.Errors.join_thread()
            # make all workers exit one after another:
            # on some Linuxes, if many workers (56+) access
            # mmap randomly, the termination of the workers
            # runs into a deadlock.
            while self.JoinedProcesses.value < rank:
                continue
            pass

    def killall(self):
        for p in self.P:
            if not p.is_alive():
                continue
            try:
                if isinstance(p, threading.Thread):
                    p.join()
                else:
                    os.kill(p._popen.pid, 5)
            except Exception as e:
                print(e)
                continue

    def _errorGuard(self):
        # this guard will kill every child if an error is observed.
        # We watch for this every 0.5 seconds
        # (errors do not happen very often).
        # if guardDead is set or killall is emitted, this ends immediately.
        while not self.guardDead.is_set():
            if not self.Errors.empty():
                self.killall()
                break
            # for python 2.6.x wait returns None XXX
            self.guardDead.wait(timeout=0.5)

    def _workerGuard(self, rank, process):
        process.join()
        if isinstance(process, threading.Thread):
            pass
        else:
            if process.exitcode < 0 and process.exitcode != -5:
                e = Exception("worker process %d killed by signal %d"
                        % (rank, -process.exitcode))
                try:
                    self.Errors.put((e, ""), timeout=0)
                except queue.Full:
                    pass
        self.semaphore.release()

    def _guardMain(self):
        # this guard waits till all children are dead,
        # then sets the guardDead event.
        for x in self.G:
            self.semaphore.acquire()
            self.JoinedProcesses.value = self.JoinedProcesses.value + 1
        self.guardDead.set()

    def start(self):
        self.JoinedProcesses.value = 0
        self.guardDead.clear()

        # collect the garbage before forking so that left-over
        # junk won't throw out assertion errors due to a
        # wrong pid in multiprocessing.heap
        gc.collect()

        for x in self.P:
            x.start()

        # p is alive from the moment start returns.
        # thus we can join them immediately after start returns.
        # guardMain will check if the worker has been
        # killed by the os, and simulate an error if so.
        for x in self.G:
            x.start()

        self.errorguard.start()
        self.guard.start()

    def get_exception(self):
        exp = self.Errors.get(timeout=0)
        return WorkerException(*exp)

    def get(self, Q):
        """ Protected get. Get an item from Q.

        Will block, but if the process group has errors,
        raise a StopProcessGroup exception.

        A worker process will terminate upon StopProcessGroup.
        The coordinator process shall read the error.
        """
        while self.Errors.empty():
            try:
                return Q.get(timeout=1)
            except queue.Empty:
                if not self.is_alive():
                    raise StopProcessGroup
                else:
                    continue
        else:
            raise StopProcessGroup

    def put(self, Q, item):
        while self.Errors.empty():
            try:
                Q.put(item, timeout=1)
                return
            except queue.Full:
                if not self.is_alive():
                    raise StopProcessGroup
                else:
                    continue
        else:
            raise StopProcessGroup

    def is_alive(self):
        return not self.guardDead.is_set()

    def join(self):
        self.guardDead.wait()
        for x in self.G:
            x.join()
        self.errorguard.join()
        self.guard.join()

        if not self.Errors.empty():
            raise WorkerException(*self.Errors.get())


class Ordered(object):
    def __init__(self, backend):
        # self.counter = lambda : None #multiprocessing.RawValue('l')
        self.event = backend.EventFactory()
        self.counter = multiprocessing.RawValue('l')
        self.tls = backend.StorageFactory()

    def reset(self):
        self.counter.value = 0
        self.event.set()

    def move(self, iter):
        self.tls.iter = iter

    def __enter__(self):
        while self.counter.value != self.tls.iter:
            self.event.wait()
        self.event.clear()
        return self

    def __exit__(self, *args):
        # increase the counter before releasing the event
        # so that the others waiting will see the new counter.
        self.counter.value = self.counter.value + 1
        self.event.set()


class ThreadBackend:
    QueueFactory = staticmethod(queue.Queue)
    EventFactory = staticmethod(threading.Event)
    LockFactory = staticmethod(threading.Lock)
    StorageFactory = staticmethod(threading.local)

    @staticmethod
    def WorkerFactory(*args, **kwargs):
        worker = threading.Thread(*args, **kwargs)
        worker.daemon = True
        return worker


class ProcessBackend:
    QueueFactory = staticmethod(multiprocessing.Queue)
    EventFactory = staticmethod(multiprocessing.Event)
    LockFactory = staticmethod(multiprocessing.Lock)

    @staticmethod
    def WorkerFactory(*args, **kwargs):
        worker = multiprocessing.Process(*args, **kwargs)
        worker.daemon = True
        return worker

    @staticmethod
    def StorageFactory():
        return lambda: None

class background(object):
    """ Asynchronous function call via a background process.

    Parameters
    ----------
    function : callable
        the function to call
    *args : positional arguments
    **kwargs : keyword arguments

    Examples
    --------

    >>> def function(*args, **kwargs):
    >>>     pass
    >>> bg = background(function, *args, **kwargs)
    >>> rt = bg.wait()

    """
    def __init__(self, function, *args, **kwargs):
        backend = kwargs.pop('backend', ProcessBackend)

        self.result = backend.QueueFactory(1)
        self.worker = backend.WorkerFactory(target=self._closure,
                args=(function, args, kwargs, self.result))
        self.worker.start()

    def _closure(self, function, args, kwargs, result):
        try:
            rt = function(*args, **kwargs)
        except Exception as e:
            result.put((e, traceback.format_exc()))
        else:
            result.put((None, rt))

    def wait(self):
        """ Wait and join the child process.

        The return value of the function call is returned.
        If any exception occurred it is wrapped and raised.
        """
        e, r = self.result.get()
        self.worker.join()
        self.worker = None
        self.result = None
        if isinstance(e, Exception):
            raise WorkerException(e, r)
        return r

def MapReduceByThread(np=None):
    """ Creates a MapReduce object, but with the Thread backend.

    The process backend is usually preferred.
    """
    return MapReduce(backend=ThreadBackend, np=np)

class MapReduce(object):
    """ A pool of worker processes for a Map-Reduce operation.

    Parameters
    ----------
    backend : ProcessBackend or ThreadBackend
        ProcessBackend is preferred. ThreadBackend can be used in cases
        where process creation is not allowed.

    np : int or None
        Number of processes to use. Default (None) is from the
        OMP_NUM_THREADS environment variable or the number of available
        cores on the computer. If np is 0, all operations are performed
        on the coordinator process -- no child processes are created.

    Notes
    -----
    Always wrap the call to :py:meth:`map` in a context manager ('with')
    block.

    """
    def __init__(self, backend=ProcessBackend, np=None):
        self.backend = backend
        if np is None:
            self.np = cpu_count()
        else:
            self.np = np

    def _main(self, pg, Q, R, sequence, realfunc):
        # get and put will raise WorkerException
        # and terminate the process.
        # the exception is muted in ProcessGroup,
        # as it will only be dispatched from the coordinator.
        while True:
            capsule = pg.get(Q)
            if capsule is None:
                return
            if len(capsule) == 1:
                i, = capsule
                work = sequence[i]
            else:
                i, work = capsule
            self.ordered.move(i)
            r = realfunc(work)
            pg.put(R, (i, r))

    def __enter__(self):
        self.critical = self.backend.LockFactory()
        self.ordered = Ordered(self.backend)
        return self

    def __exit__(self, *args):
        self.ordered = None
        pass

    def map(self, func, sequence, reduce=None, star=False):
        """ Map-reduce with multiple processes.

        Apply func to each item of the sequence, in parallel. As the
        results are collected, reduce is called on each result. The
        reduced results are returned as a list.

        Parameters
        ----------
        func : callable
            The function to call. It must accept the same number of
            arguments as the length of an item in the sequence.

            .. warning::

                func is not supposed to use exceptions for flow control.
                In non-debug mode all exceptions will be wrapped into a
                :py:class:`WorkerException`.

        sequence : list or array_like
            The sequence of arguments to be applied to func.

        reduce : callable, optional
            Apply a reduction operation to the return values of func.
            If func returns a tuple, the items are treated as positional
            arguments of reduce.

        star : boolean
            if True, the items in sequence are treated as positional
            arguments of func.

        Returns
        -------
        results : list
            The list of reduced results from the map operation, in the
            order of the items of sequence.

        Raises
        ------
        WorkerException
            If any of the worker processes encounters an exception.
            Inspect :py:attr:`WorkerException.reason` for the underlying
            exception.

        """
        def realreduce(r):
            if reduce:
                if isinstance(r, tuple):
                    return reduce(*r)
                else:
                    return reduce(r)
            return r

        def realfunc(i):
            if star:
                return func(*i)
            else:
                return func(i)

        if self.np == 0 or get_debug():
            # Do this in serial
            return [realreduce(realfunc(i)) for i in sequence]

        Q = self.backend.QueueFactory(64)
        R = self.backend.QueueFactory(64)
        self.ordered.reset()

        pg = ProcessGroup(main=self._main, np=self.np,
                backend=self.backend,
                args=(Q, R, sequence, realfunc))

        pg.start()

        L = []
        N = []

        def feeder(pg, Q, N):
            # will fail silently if any error occurs.
            j = 0
            try:
                for i, work in enumerate(sequence):
                    if not hasattr(sequence, '__getitem__'):
                        pg.put(Q, (i, work))
                    else:
                        pg.put(Q, (i, ))
                    j = j + 1
                N.append(j)

                for i in range(self.np):
                    pg.put(Q, None)
            except StopProcessGroup:
                return
            finally:
                pass

        feeder = threading.Thread(None, feeder, args=(pg, Q, N))
        feeder.start()

        # we run the fetcher on the main thread to catch exceptions
        # raised by reduce.
        count = 0
        try:
            while True:
                try:
                    capsule = pg.get(R)
                except queue.Empty:
                    continue
                except StopProcessGroup:
                    raise pg.get_exception()
                capsule = capsule[0], realreduce(capsule[1])
                heapq.heappush(L, capsule)
                count = count + 1
                if len(N) > 0 and count == N[0]:
                    # if feeding has finished, check whether all
                    # results have been obtained.
                    break
            rt = []
            # R.close()
            # R.join_thread()
            while len(L) > 0:
                rt.append(heapq.heappop(L)[1])
            pg.join()
            feeder.join()
            assert N[0] == len(rt)
            return rt
        except BaseException as e:
            pg.killall()
            pg.join()
            feeder.join()
            raise

def empty_like(array, dtype=None):
    """ Create a shared memory array from the shape of array. """
    array = numpy.asarray(array)
    if dtype is None:
        dtype = array.dtype
    return anonymousmemmap(array.shape, dtype)

def empty(shape, dtype='f8'):
    """ Create an empty shared memory array. """
    return anonymousmemmap(shape, dtype)

def full_like(array, value, dtype=None):
    """ Create a shared memory array with the same shape and type as a
    given array, filled with `value`.
    """
    shared = empty_like(array, dtype)
    shared[:] = value
    return shared

def full(shape, value, dtype='f8'):
    """ Create a shared memory array of given shape and type, filled
    with `value`.
    """
    shared = empty(shape, dtype)
    shared[:] = value
    return shared

def copy(a):
    """ Copy an array to the shared memory.

    Notes
    -----
    copy is not always necessary because the private memory is always
    copy-on-write.

    Use :code:`a = copy(a)` to immediately dereference the old 'a' on
    private memory.
    """
    shared = anonymousmemmap(a.shape, dtype=a.dtype)
    shared[:] = a[:]
    return shared

def fromiter(iter, dtype, count=None):
    return copy(numpy.fromiter(iter, dtype, count))


def __unpickle__(ai, dtype):
    dtype = numpy.dtype(dtype)
    tp = numpy.ctypeslib._typecodes['|u1']

    # if there are strides, use the strides; otherwise the stride is the
    # itemsize of dtype.
    if ai['strides']:
        tp *= ai['strides'][-1]
    else:
        tp *= dtype.itemsize

    for i in numpy.asarray(ai['shape'])[::-1]:
        tp *= i

    # grab a flat char array at the shared memory address, long enough to
    # contain the array described by ai.
    ra = tp.from_address(ai['data'][0])
    buffer = numpy.ctypeslib.as_array(ra).ravel()
    # view it as what it should look like
    shm = numpy.ndarray(buffer=buffer, dtype=dtype,
            strides=ai['strides'], shape=ai['shape']).view(type=anonymousmemmap)
    return shm


class anonymousmemmap(numpy.memmap):
    """ Arrays allocated on shared memory.

    The array is stored in an anonymous memory map that is shared
    between child-processes.
    """
    def __new__(subtype, shape, dtype=numpy.uint8, order='C'):

        descr = numpy.dtype(dtype)
        _dbytes = descr.itemsize

        shape = numpy.atleast_1d(shape)
        size = 1
        for k in shape:
            size *= k

        bytes = int(size * _dbytes)

        if bytes > 0:
            mm = mmap.mmap(-1, bytes)
        else:
            mm = numpy.empty(0, dtype=descr)
        self = numpy.ndarray.__new__(subtype, shape, dtype=descr,
                buffer=mm, order=order)
        self._mmap = mm
        return self

    def __array_wrap__(self, outarr, context=None):
        # after a ufunc the result won't be on shared memory!
        return numpy.ndarray.__array_wrap__(
                self.view(numpy.ndarray), outarr, context)

    def __reduce__(self):
        return __unpickle__, (self.__array_interface__, self.dtype)