import collections as _collections
import functools as _functools
import itertools as _itertools
import typing as _t
from collections import abc as _abc
from operator import is_not as _is_not

from .functional import flatmap as _flatmap

_T = _t.TypeVar('_T')
_T2 = _t.TypeVar('_T2')

@_functools.singledispatch
def capacity(_value: _t.Any) -> int:
    """
    Returns number of elements in value.

    Fallback for values which are neither iterable nor sized:
    raises ``TypeError`` carrying the type of the value.

    >>> capacity(range(0))
    0
    >>> capacity(range(10))
    10
    """
    raise TypeError(type(_value))


@capacity.register(_abc.Iterable)
def _(_iterable: _t.Iterable[_t.Any]) -> int:
    """
    Returns number of elements in iterable by exhausting it.
    """
    counter = _itertools.count()
    # order matters: if `counter` goes first,
    # then it will be incremented even for empty `iterable`
    _collections.deque(zip(_iterable, counter),
                       maxlen=0)
    return next(counter)


# Containers such as `list` or `range` are both iterable and sized,
# neither ABC is more specific than the other, so `Collection`
# (which derives from both) is registered too
# to give dispatching an unambiguous best match for them.
@capacity.register(_abc.Collection)
@capacity.register(_abc.Sized)
def _(_iterable: _t.Sized) -> int:
    """
    Returns number of elements in sized iterable.
    """
    return len(_iterable)
def first(_iterable: _t.Iterable[_T]) -> _T:
    """
    Returns the leading element of iterable.

    Raises ``ValueError`` if iterable has no elements.

    >>> first(range(10))
    0
    """
    try:
        iterator = iter(_iterable)
        element = next(iterator)
    except StopIteration as error:
        raise ValueError('Argument supposed to be non-empty.') from error
    else:
        return element
def last(_iterable: _t.Iterable[_T]) -> _T:
    """
    Returns the trailing element of iterable.

    Raises ``ValueError`` if iterable has no elements.

    >>> last(range(10))
    9
    """
    try:
        # single-slot deque drops everything but the most recent element
        tail = _collections.deque(_iterable,
                                  maxlen=1)
        element = tail[0]
    except IndexError as error:
        raise ValueError('Argument supposed to be non-empty.') from error
    else:
        return element
def cut(_iterable: _t.Iterable[_T],
        *,
        slice_: slice) -> _t.Iterable[_T]:
    """
    Yields elements of iterable picked by given slice.

    Slice fields are expected to be unset or non-negative:
    negative indices/step cannot be reasonably evaluated
    for an arbitrary iterable, which may be infinite
    or may change already visited elements when walked backwards.
    """
    start, stop, step = slice_.start, slice_.stop, slice_.step
    yield from _itertools.islice(_iterable, start, stop, step)
def cutter(_slice: slice) -> _t.Callable[[_t.Iterable[_T]], _t.Iterable[_T]]:
    """
    Builds function that picks elements of iterable by given slice.

    The resulting function carries a docstring
    describing the slice in human readable form.

    >>> to_first_triplet = cutter(slice(3))
    >>> list(to_first_triplet(range(10)))
    [0, 1, 2]

    >>> to_second_triplet = cutter(slice(3, 6))
    >>> list(to_second_triplet(range(10)))
    [3, 4, 5]

    >>> cut_out_every_third = cutter(slice(0, None, 3))
    >>> list(cut_out_every_third(range(10)))
    [0, 3, 6, 9]
    """
    selector = _functools.partial(cut,
                                  slice_=_slice)
    description = _slice_to_description(_slice)
    selector.__doc__ = ('Selects elements from iterable {slice}.'
                        .format(slice=description))
    return selector
def _slice_to_description(_slice: slice) -> str:
    """Builds human readable description of `slice` object."""
    start, stop, step = _slice.start, _slice.stop, _slice.step
    # zero start is the default position, so it is not worth mentioning
    has_start = bool(start)
    has_step = step is not None
    parts = []
    if has_start:
        parts.append('starting from position {start}'.format(start=start))
    if has_step:
        parts.append('with step {step}'.format(step=step))
    if stop is not None:
        # conjunction is needed only if something precedes the stop part
        prefix = 'and ' if (has_start or has_step) else ''
        parts.append(prefix
                     + 'stopping at position {stop}'.format(stop=stop))
    return ' '.join(parts)
def chopper(
        _size: int
) -> _t.Callable[[_t.Iterable[_T]], _t.Iterable[_t.Sequence[_T]]]:
    """
    Returns function that splits iterable into chunks of given size.

    >>> in_three = chopper(3)
    >>> list(map(tuple, in_three(range(10))))
    [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9,)]
    """
    return _functools.partial(chop,
                              size=_size)


@_functools.singledispatch
def chop(_iterable: _t.Iterable[_T],
         *,
         size: int) -> _t.Iterable[_t.Sequence[_T]]:
    """
    Splits iterable into chunks of given size.

    Chunks are yielded lazily as tuples,
    the last one may be shorter than the rest.
    """
    iterator = iter(_iterable)
    # empty tuple is the sentinel: it shows up once iterator is exhausted
    yield from iter(lambda: tuple(_itertools.islice(iterator, size)), ())


@chop.register(_abc.Sequence)
def _(_iterable: _t.Sequence[_T],
      *,
      size: int) -> _t.Iterable[_t.Sequence[_T]]:
    """
    Splits sequence into chunks of given size.

    Chunks are slices of the sequence, so they keep its type.
    """
    if not size:
        # zero step is not supported by `range`
        return
    for start in range(0, len(_iterable), size):
        yield _iterable[start:start + size]


# deque do not support slice notation
chop.register(_collections.deque, chop.registry[object])

in_two = chopper(2)
in_three = chopper(3)
in_four = chopper(4)
def slide(_iterable: _t.Iterable[_T],
          *,
          size: int) -> _t.Iterable[_t.Tuple[_T, ...]]:
    """
    Yields consecutive windows of given size taken from iterable.

    The first window holds up to ``size`` leading elements,
    each following one drops its oldest element and takes the next one.
    """
    iterator = iter(_iterable)
    window = tuple(_itertools.islice(iterator, size))
    yield window
    for element in iterator:
        window = window[1:] + (element,)
        yield window
def slider(
        _size: int
) -> _t.Callable[[_t.Iterable[_T]], _t.Iterable[_t.Tuple[_T, ...]]]:
    """
    Builds function that slides over iterable with window of given size.

    >>> pairwise = slider(2)
    >>> list(pairwise(range(10)))
    [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9)]
    """
    window_slider = _functools.partial(slide,
                                       size=_size)
    return window_slider
# ready-made sliders for windows of two, three and four elements
pairwise, triplewise, quadruplewise = map(slider, (2, 3, 4))
@_functools.singledispatch
def trail(_iterable: _t.Iterable[_T],
          *,
          size: int) -> _t.Iterable[_T]:
    """
    Selects elements from the end of iterable.
    Resulted iterable will have size not greater than given one.
    """
    # bounded deque keeps only the `size` most recent elements
    return _collections.deque(_iterable,
                              maxlen=size)


@trail.register(_abc.Sequence)
def _(iterable: _t.Sequence[_T],
      *,
      size: int) -> _t.Sequence[_T]:
    """
    Selects elements from the end of sequence.
    Resulted sequence will have size not greater than given one.
    """
    # `iterable[-0:]` would be the whole sequence,
    # hence zero size is handled with an empty leading slice
    return iterable[-size:] if size else iterable[:size]


# deque do not support slice notation
trail.register(_collections.deque, trail.registry[object])
def trailer(_size: int) -> _t.Callable[[_t.Iterable[_T]], _t.Iterable[_T]]:
    """
    Builds function that picks elements from the end of iterable.
    The result of that function has size not greater than given one.

    >>> to_last_pair = trailer(2)
    >>> list(to_last_pair(range(10)))
    [8, 9]
    """
    tail_selector = _functools.partial(trail,
                                       size=_size)
    return tail_selector
def mapper(
        _map: _t.Callable[[_T], _T2]
) -> _t.Callable[[_t.Iterable[_T]], _t.Iterable[_T2]]:
    """
    Builds function that applies given map to every element of iterable.

    >>> to_str = mapper(str)
    >>> list(to_str(range(10)))
    ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    """
    apply_to_each = _functools.partial(map, _map)
    # `partial` hides the precise signature from type checkers
    return _t.cast(_t.Callable[[_t.Iterable[_T]], _t.Iterable[_T2]],
                   apply_to_each)
def flatmapper(
        _map: _t.Callable[[_T], _t.Iterable[_T2]]
) -> _t.Callable[[_t.Iterable[_T]], _t.Iterable[_T2]]:
    """
    Builds function that applies map to every element of iterable
    and flattens the results.

    >>> relay = flatmapper(range)
    >>> list(relay(range(5)))
    [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
    """
    flat_mapping = _functools.partial(_flatmap, _map)
    return flat_mapping
# Pair of a hashable grouping key and the elements which share that key.
Group = _t.Tuple[_t.Hashable, _t.Iterable[_T]]
def group_by(_iterable: _t.Iterable[_T],
             *,
             key: _t.Callable[[_T], _t.Hashable]) -> _t.Iterable[Group[_T]]:
    """
    Splits elements of iterable into groups by given key.

    The iterable is consumed completely before the first group is yielded,
    groups go in order of first occurrence of their keys.
    """
    groups = _collections.defaultdict(
        list
    )  # type: _t.DefaultDict[_t.Hashable, _t.List[_T]]
    for element in _iterable:
        groups[key(element)].append(element)
    yield from groups.items()
def grouper(
        _key: _t.Callable[[_T], _t.Hashable]
) -> _t.Callable[[_t.Iterable[_T]], _t.Iterable[Group[_T]]]:
    """
    Builds function that splits elements of iterable
    into groups by given key.

    >>> group_by_absolute_value = grouper(abs)
    >>> list(group_by_absolute_value(range(-5, 5)))
    [(5, [-5]), (4, [-4, 4]), (3, [-3, 3]), (2, [-2, 2]), (1, [-1, 1]), (0, [0])]

    >>> def modulo_two(number: int) -> int:
    ...     return number % 2
    >>> group_by_evenness = grouper(modulo_two)
    >>> list(group_by_evenness(range(10)))
    [(0, [0, 2, 4, 6, 8]), (1, [1, 3, 5, 7, 9])]
    """
    grouping = _functools.partial(group_by,
                                  key=_key)
    return grouping
def expand(_value: _T) -> _t.Iterable[_T]:
    """
    Turns value into iterable with that value as the only element.

    >>> list(expand(0))
    [0]
    """
    yield from (_value,)
def flatten(_iterable: _t.Iterable[_t.Iterable[_T]]) -> _t.Iterable[_T]:
    """
    Yields elements of nested iterables one after another.

    >>> list(flatten([range(5), range(10, 20)]))
    [0, 1, 2, 3, 4, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
    """
    for nested in _iterable:
        for element in nested:
            yield element
def interleave(_iterable: _t.Iterable[_t.Iterable[_T]]) -> _t.Iterable[_T]:
    """
    Yields elements of given iterables in turns.

    Iterables which ran out of elements are dropped,
    the remaining ones keep taking turns until all of them are exhausted.

    >>> list(interleave([range(5), range(10, 20)]))
    [0, 10, 1, 11, 2, 12, 3, 13, 4, 14, 15, 16, 17, 18, 19]
    """
    pending = _itertools.cycle(
        _t.cast(_t.Iterable[_t.Iterator[_T]], map(iter, _iterable))
    )
    while True:
        try:
            for current in pending:
                yield next(current)
        except StopIteration:
            # `current` has just run dry: walking the old cycle
            # until it comes back to `current` collects the survivors
            # in their turn order; the value is bound right away
            # since the loop variable will be rebound later
            is_alive = _functools.partial(_is_not, current)
            pending = _itertools.cycle(_itertools.takewhile(is_alive,
                                                            pending))
            continue
        # the loop ended on its own, so there are no iterators left
        return