관리-도구
편집 파일: recipes.py
"""Imported from the recipes section of the itertools documentation. All functions taken from the recipes section of the itertools library docs [1]_. Some backward-compatible usability improvements have been made. .. [1] http://docs.python.org/library/itertools.html#recipes """ import math import operator from collections import deque from collections.abc import Sized from functools import partial, reduce from itertools import ( chain, combinations, compress, count, cycle, groupby, islice, product, repeat, starmap, tee, zip_longest, ) from random import randrange, sample, choice from sys import hexversion __all__ = [ 'all_equal', 'batched', 'before_and_after', 'consume', 'convolve', 'dotproduct', 'first_true', 'factor', 'flatten', 'grouper', 'iter_except', 'iter_index', 'matmul', 'ncycles', 'nth', 'nth_combination', 'padnone', 'pad_none', 'pairwise', 'partition', 'polynomial_eval', 'polynomial_from_roots', 'polynomial_derivative', 'powerset', 'prepend', 'quantify', 'reshape', 'random_combination_with_replacement', 'random_combination', 'random_permutation', 'random_product', 'repeatfunc', 'roundrobin', 'sieve', 'sliding_window', 'subslices', 'sum_of_squares', 'tabulate', 'tail', 'take', 'totient', 'transpose', 'triplewise', 'unique_everseen', 'unique_justseen', ] _marker = object() # zip with strict is available for Python 3.10+ try: zip(strict=True) except TypeError: _zip_strict = zip else: _zip_strict = partial(zip, strict=True) # math.sumprod is available for Python 3.12+ _sumprod = getattr(math, 'sumprod', lambda x, y: dotproduct(x, y)) def take(n, iterable): """Return first *n* items of the iterable as a list. >>> take(3, range(10)) [0, 1, 2] If there are fewer than *n* items in the iterable, all of them are returned. >>> take(10, range(3)) [0, 1, 2] """ return list(islice(iterable, n)) def tabulate(function, start=0): """Return an iterator over the results of ``func(start)``, ``func(start + 1)``, ``func(start + 2)``... *func* should be a function that accepts one integer argument. If *start* is not specified it defaults to 0. It will be incremented each time the iterator is advanced. >>> square = lambda x: x ** 2 >>> iterator = tabulate(square, -3) >>> take(4, iterator) [9, 4, 1, 0] """ return map(function, count(start)) def tail(n, iterable): """Return an iterator over the last *n* items of *iterable*. >>> t = tail(3, 'ABCDEFG') >>> list(t) ['E', 'F', 'G'] """ # If the given iterable has a length, then we can use islice to get its # final elements. Note that if the iterable is not actually Iterable, # either islice or deque will throw a TypeError. This is why we don't # check if it is Iterable. if isinstance(iterable, Sized): yield from islice(iterable, max(0, len(iterable) - n), None) else: yield from iter(deque(iterable, maxlen=n)) def consume(iterator, n=None): """Advance *iterable* by *n* steps. If *n* is ``None``, consume it entirely. Efficiently exhausts an iterator without returning values. Defaults to consuming the whole iterator, but an optional second argument may be provided to limit consumption. >>> i = (x for x in range(10)) >>> next(i) 0 >>> consume(i, 3) >>> next(i) 4 >>> consume(i) >>> next(i) Traceback (most recent call last): File "<stdin>", line 1, in <module> StopIteration If the iterator has fewer items remaining than the provided limit, the whole iterator will be consumed. >>> i = (x for x in range(3)) >>> consume(i, 5) >>> next(i) Traceback (most recent call last): File "<stdin>", line 1, in <module> StopIteration """ # Use functions that consume iterators at C speed. if n is None: # feed the entire iterator into a zero-length deque deque(iterator, maxlen=0) else: # advance to the empty slice starting at position n next(islice(iterator, n, n), None) def nth(iterable, n, default=None): """Returns the nth item or a default value. >>> l = range(10) >>> nth(l, 3) 3 >>> nth(l, 20, "zebra") 'zebra' """ return next(islice(iterable, n, None), default) def all_equal(iterable): """ Returns ``True`` if all the elements are equal to each other. >>> all_equal('aaaa') True >>> all_equal('aaab') False """ g = groupby(iterable) return next(g, True) and not next(g, False) def quantify(iterable, pred=bool): """Return the how many times the predicate is true. >>> quantify([True, False, True]) 2 """ return sum(map(pred, iterable)) def pad_none(iterable): """Returns the sequence of elements and then returns ``None`` indefinitely. >>> take(5, pad_none(range(3))) [0, 1, 2, None, None] Useful for emulating the behavior of the built-in :func:`map` function. See also :func:`padded`. """ return chain(iterable, repeat(None)) padnone = pad_none def ncycles(iterable, n): """Returns the sequence elements *n* times >>> list(ncycles(["a", "b"], 3)) ['a', 'b', 'a', 'b', 'a', 'b'] """ return chain.from_iterable(repeat(tuple(iterable), n)) def dotproduct(vec1, vec2): """Returns the dot product of the two iterables. >>> dotproduct([10, 10], [20, 20]) 400 """ return sum(map(operator.mul, vec1, vec2)) def flatten(listOfLists): """Return an iterator flattening one level of nesting in a list of lists. >>> list(flatten([[0, 1], [2, 3]])) [0, 1, 2, 3] See also :func:`collapse`, which can flatten multiple levels of nesting. """ return chain.from_iterable(listOfLists) def repeatfunc(func, times=None, *args): """Call *func* with *args* repeatedly, returning an iterable over the results. If *times* is specified, the iterable will terminate after that many repetitions: >>> from operator import add >>> times = 4 >>> args = 3, 5 >>> list(repeatfunc(add, times, *args)) [8, 8, 8, 8] If *times* is ``None`` the iterable will not terminate: >>> from random import randrange >>> times = None >>> args = 1, 11 >>> take(6, repeatfunc(randrange, times, *args)) # doctest:+SKIP [2, 4, 8, 1, 8, 4] """ if times is None: return starmap(func, repeat(args)) return starmap(func, repeat(args, times)) def _pairwise(iterable): """Returns an iterator of paired items, overlapping, from the original >>> take(4, pairwise(count())) [(0, 1), (1, 2), (2, 3), (3, 4)] On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`. """ a, b = tee(iterable) next(b, None) return zip(a, b) try: from itertools import pairwise as itertools_pairwise except ImportError: pairwise = _pairwise else: def pairwise(iterable): return itertools_pairwise(iterable) pairwise.__doc__ = _pairwise.__doc__ class UnequalIterablesError(ValueError): def __init__(self, details=None): msg = 'Iterables have different lengths' if details is not None: msg += (': index 0 has length {}; index {} has length {}').format( *details ) super().__init__(msg) def _zip_equal_generator(iterables): for combo in zip_longest(*iterables, fillvalue=_marker): for val in combo: if val is _marker: raise UnequalIterablesError() yield combo def _zip_equal(*iterables): # Check whether the iterables are all the same size. try: first_size = len(iterables[0]) for i, it in enumerate(iterables[1:], 1): size = len(it) if size != first_size: raise UnequalIterablesError(details=(first_size, i, size)) # All sizes are equal, we can use the built-in zip. return zip(*iterables) # If any one of the iterables didn't have a length, start reading # them until one runs out. except TypeError: return _zip_equal_generator(iterables) def grouper(iterable, n, incomplete='fill', fillvalue=None): """Group elements from *iterable* into fixed-length groups of length *n*. >>> list(grouper('ABCDEF', 3)) [('A', 'B', 'C'), ('D', 'E', 'F')] The keyword arguments *incomplete* and *fillvalue* control what happens for iterables whose length is not a multiple of *n*. When *incomplete* is `'fill'`, the last group will contain instances of *fillvalue*. >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x')) [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')] When *incomplete* is `'ignore'`, the last group will not be emitted. >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x')) [('A', 'B', 'C'), ('D', 'E', 'F')] When *incomplete* is `'strict'`, a subclass of `ValueError` will be raised. >>> it = grouper('ABCDEFG', 3, incomplete='strict') >>> list(it) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... UnequalIterablesError """ args = [iter(iterable)] * n if incomplete == 'fill': return zip_longest(*args, fillvalue=fillvalue) if incomplete == 'strict': return _zip_equal(*args) if incomplete == 'ignore': return zip(*args) else: raise ValueError('Expected fill, strict, or ignore') def roundrobin(*iterables): """Yields an item from each iterable, alternating between them. >>> list(roundrobin('ABC', 'D', 'EF')) ['A', 'D', 'E', 'B', 'F', 'C'] This function produces the same output as :func:`interleave_longest`, but may perform better for some inputs (in particular when the number of iterables is small). """ # Recipe credited to George Sakkis pending = len(iterables) nexts = cycle(iter(it).__next__ for it in iterables) while pending: try: for next in nexts: yield next() except StopIteration: pending -= 1 nexts = cycle(islice(nexts, pending)) def partition(pred, iterable): """ Returns a 2-tuple of iterables derived from the input iterable. The first yields the items that have ``pred(item) == False``. The second yields the items that have ``pred(item) == True``. >>> is_odd = lambda x: x % 2 != 0 >>> iterable = range(10) >>> even_items, odd_items = partition(is_odd, iterable) >>> list(even_items), list(odd_items) ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9]) If *pred* is None, :func:`bool` is used. >>> iterable = [0, 1, False, True, '', ' '] >>> false_items, true_items = partition(None, iterable) >>> list(false_items), list(true_items) ([0, False, ''], [1, True, ' ']) """ if pred is None: pred = bool t1, t2, p = tee(iterable, 3) p1, p2 = tee(map(pred, p)) return (compress(t1, map(operator.not_, p1)), compress(t2, p2)) def powerset(iterable): """Yields all possible subsets of the iterable. >>> list(powerset([1, 2, 3])) [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)] :func:`powerset` will operate on iterables that aren't :class:`set` instances, so repeated elements in the input will produce repeated elements in the output. Use :func:`unique_everseen` on the input to avoid generating duplicates: >>> seq = [1, 1, 0] >>> list(powerset(seq)) [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)] >>> from more_itertools import unique_everseen >>> list(powerset(unique_everseen(seq))) [(), (1,), (0,), (1, 0)] """ s = list(iterable) return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) def unique_everseen(iterable, key=None): """ Yield unique elements, preserving order. >>> list(unique_everseen('AAAABBBCCDAABBB')) ['A', 'B', 'C', 'D'] >>> list(unique_everseen('ABBCcAD', str.lower)) ['A', 'B', 'C', 'D'] Sequences with a mix of hashable and unhashable items can be used. The function will be slower (i.e., `O(n^2)`) for unhashable items. Remember that ``list`` objects are unhashable - you can use the *key* parameter to transform the list to a tuple (which is hashable) to avoid a slowdown. >>> iterable = ([1, 2], [2, 3], [1, 2]) >>> list(unique_everseen(iterable)) # Slow [[1, 2], [2, 3]] >>> list(unique_everseen(iterable, key=tuple)) # Faster [[1, 2], [2, 3]] Similarly, you may want to convert unhashable ``set`` objects with ``key=frozenset``. For ``dict`` objects, ``key=lambda x: frozenset(x.items())`` can be used. """ seenset = set() seenset_add = seenset.add seenlist = [] seenlist_add = seenlist.append use_key = key is not None for element in iterable: k = key(element) if use_key else element try: if k not in seenset: seenset_add(k) yield element except TypeError: if k not in seenlist: seenlist_add(k) yield element def unique_justseen(iterable, key=None): """Yields elements in order, ignoring serial duplicates >>> list(unique_justseen('AAAABBBCCDAABBB')) ['A', 'B', 'C', 'D', 'A', 'B'] >>> list(unique_justseen('ABBCcAD', str.lower)) ['A', 'B', 'C', 'A', 'D'] """ if key is None: return map(operator.itemgetter(0), groupby(iterable)) return map(next, map(operator.itemgetter(1), groupby(iterable, key))) def iter_except(func, exception, first=None): """Yields results from a function repeatedly until an exception is raised. Converts a call-until-exception interface to an iterator interface. Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel to end the loop. >>> l = [0, 1, 2] >>> list(iter_except(l.pop, IndexError)) [2, 1, 0] Multiple exceptions can be specified as a stopping condition: >>> l = [1, 2, 3, '...', 4, 5, 6] >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError))) [7, 6, 5] >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError))) [4, 3, 2] >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError))) [] """ try: if first is not None: yield first() while 1: yield func() except exception: pass def first_true(iterable, default=None, pred=None): """ Returns the first true value in the iterable. If no true value is found, returns *default* If *pred* is not None, returns the first item for which ``pred(item) == True`` . >>> first_true(range(10)) 1 >>> first_true(range(10), pred=lambda x: x > 5) 6 >>> first_true(range(10), default='missing', pred=lambda x: x > 9) 'missing' """ return next(filter(pred, iterable), default) def random_product(*args, repeat=1): """Draw an item at random from each of the input iterables. >>> random_product('abc', range(4), 'XYZ') # doctest:+SKIP ('c', 3, 'Z') If *repeat* is provided as a keyword argument, that many items will be drawn from each iterable. >>> random_product('abcd', range(4), repeat=2) # doctest:+SKIP ('a', 2, 'd', 3) This equivalent to taking a random selection from ``itertools.product(*args, **kwarg)``. """ pools = [tuple(pool) for pool in args] * repeat return tuple(choice(pool) for pool in pools) def random_permutation(iterable, r=None): """Return a random *r* length permutation of the elements in *iterable*. If *r* is not specified or is ``None``, then *r* defaults to the length of *iterable*. >>> random_permutation(range(5)) # doctest:+SKIP (3, 4, 0, 1, 2) This equivalent to taking a random selection from ``itertools.permutations(iterable, r)``. """ pool = tuple(iterable) r = len(pool) if r is None else r return tuple(sample(pool, r)) def random_combination(iterable, r): """Return a random *r* length subsequence of the elements in *iterable*. >>> random_combination(range(5), 3) # doctest:+SKIP (2, 3, 4) This equivalent to taking a random selection from ``itertools.combinations(iterable, r)``. """ pool = tuple(iterable) n = len(pool) indices = sorted(sample(range(n), r)) return tuple(pool[i] for i in indices) def random_combination_with_replacement(iterable, r): """Return a random *r* length subsequence of elements in *iterable*, allowing individual elements to be repeated. >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP (0, 0, 1, 2, 2) This equivalent to taking a random selection from ``itertools.combinations_with_replacement(iterable, r)``. """ pool = tuple(iterable) n = len(pool) indices = sorted(randrange(n) for i in range(r)) return tuple(pool[i] for i in indices) def nth_combination(iterable, r, index): """Equivalent to ``list(combinations(iterable, r))[index]``. The subsequences of *iterable* that are of length *r* can be ordered lexicographically. :func:`nth_combination` computes the subsequence at sort position *index* directly, without computing the previous subsequences. >>> nth_combination(range(5), 3, 5) (0, 3, 4) ``ValueError`` will be raised If *r* is negative or greater than the length of *iterable*. ``IndexError`` will be raised if the given *index* is invalid. """ pool = tuple(iterable) n = len(pool) if (r < 0) or (r > n): raise ValueError c = 1 k = min(r, n - r) for i in range(1, k + 1): c = c * (n - k + i) // i if index < 0: index += c if (index < 0) or (index >= c): raise IndexError result = [] while r: c, n, r = c * r // n, n - 1, r - 1 while index >= c: index -= c c, n = c * (n - r) // n, n - 1 result.append(pool[-1 - n]) return tuple(result) def prepend(value, iterator): """Yield *value*, followed by the elements in *iterator*. >>> value = '0' >>> iterator = ['1', '2', '3'] >>> list(prepend(value, iterator)) ['0', '1', '2', '3'] To prepend multiple values, see :func:`itertools.chain` or :func:`value_chain`. """ return chain([value], iterator) def convolve(signal, kernel): """Convolve the iterable *signal* with the iterable *kernel*. >>> signal = (1, 2, 3, 4, 5) >>> kernel = [3, 2, 1] >>> list(convolve(signal, kernel)) [3, 8, 14, 20, 26, 14, 5] Note: the input arguments are not interchangeable, as the *kernel* is immediately consumed and stored. """ # This implementation intentionally doesn't match the one in the itertools # documentation. kernel = tuple(kernel)[::-1] n = len(kernel) window = deque([0], maxlen=n) * n for x in chain(signal, repeat(0, n - 1)): window.append(x) yield _sumprod(kernel, window) def before_and_after(predicate, it): """A variant of :func:`takewhile` that allows complete access to the remainder of the iterator. >>> it = iter('ABCdEfGhI') >>> all_upper, remainder = before_and_after(str.isupper, it) >>> ''.join(all_upper) 'ABC' >>> ''.join(remainder) # takewhile() would lose the 'd' 'dEfGhI' Note that the first iterator must be fully consumed before the second iterator can generate valid results. """ it = iter(it) transition = [] def true_iterator(): for elem in it: if predicate(elem): yield elem else: transition.append(elem) return # Note: this is different from itertools recipes to allow nesting # before_and_after remainders into before_and_after again. See tests # for an example. remainder_iterator = chain(transition, it) return true_iterator(), remainder_iterator def triplewise(iterable): """Return overlapping triplets from *iterable*. >>> list(triplewise('ABCDE')) [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')] """ for (a, _), (b, c) in pairwise(pairwise(iterable)): yield a, b, c def sliding_window(iterable, n): """Return a sliding window of width *n* over *iterable*. >>> list(sliding_window(range(6), 4)) [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)] If *iterable* has fewer than *n* items, then nothing is yielded: >>> list(sliding_window(range(3), 4)) [] For a variant with more features, see :func:`windowed`. """ it = iter(iterable) window = deque(islice(it, n - 1), maxlen=n) for x in it: window.append(x) yield tuple(window) def subslices(iterable): """Return all contiguous non-empty subslices of *iterable*. >>> list(subslices('ABC')) [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']] This is similar to :func:`substrings`, but emits items in a different order. """ seq = list(iterable) slices = starmap(slice, combinations(range(len(seq) + 1), 2)) return map(operator.getitem, repeat(seq), slices) def polynomial_from_roots(roots): """Compute a polynomial's coefficients from its roots. >>> roots = [5, -4, 3] # (x - 5) * (x + 4) * (x - 3) >>> polynomial_from_roots(roots) # x^3 - 4 * x^2 - 17 * x + 60 [1, -4, -17, 60] """ factors = zip(repeat(1), map(operator.neg, roots)) return list(reduce(convolve, factors, [1])) def iter_index(iterable, value, start=0, stop=None): """Yield the index of each place in *iterable* that *value* occurs, beginning with index *start* and ending before index *stop*. See :func:`locate` for a more general means of finding the indexes associated with particular values. >>> list(iter_index('AABCADEAF', 'A')) [0, 1, 4, 7] >>> list(iter_index('AABCADEAF', 'A', 1)) # start index is inclusive [1, 4, 7] >>> list(iter_index('AABCADEAF', 'A', 1, 7)) # stop index is not inclusive [1, 4] """ seq_index = getattr(iterable, 'index', None) if seq_index is None: # Slow path for general iterables it = islice(iterable, start, stop) for i, element in enumerate(it, start): if element is value or element == value: yield i else: # Fast path for sequences stop = len(iterable) if stop is None else stop i = start - 1 try: while True: yield (i := seq_index(value, i + 1, stop)) except ValueError: pass def sieve(n): """Yield the primes less than n. >>> list(sieve(30)) [2, 3, 5, 7, 11, 13, 17, 19, 23, 29] """ if n > 2: yield 2 start = 3 data = bytearray((0, 1)) * (n // 2) limit = math.isqrt(n) + 1 for p in iter_index(data, 1, start, limit): yield from iter_index(data, 1, start, p * p) data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p))) start = p * p yield from iter_index(data, 1, start) def _batched(iterable, n, *, strict=False): """Batch data into tuples of length *n*. If the number of items in *iterable* is not divisible by *n*: * The last batch will be shorter if *strict* is ``False``. * :exc:`ValueError` will be raised if *strict* is ``True``. >>> list(batched('ABCDEFG', 3)) [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)] On Python 3.13 and above, this is an alias for :func:`itertools.batched`. """ if n < 1: raise ValueError('n must be at least one') it = iter(iterable) while batch := tuple(islice(it, n)): if strict and len(batch) != n: raise ValueError('batched(): incomplete batch') yield batch if hexversion >= 0x30D00A2: from itertools import batched as itertools_batched def batched(iterable, n, *, strict=False): return itertools_batched(iterable, n, strict=strict) else: batched = _batched batched.__doc__ = _batched.__doc__ def transpose(it): """Swap the rows and columns of the input matrix. >>> list(transpose([(1, 2, 3), (11, 22, 33)])) [(1, 11), (2, 22), (3, 33)] The caller should ensure that the dimensions of the input are compatible. If the input is empty, no output will be produced. """ return _zip_strict(*it) def reshape(matrix, cols): """Reshape the 2-D input *matrix* to have a column count given by *cols*. >>> matrix = [(0, 1), (2, 3), (4, 5)] >>> cols = 3 >>> list(reshape(matrix, cols)) [(0, 1, 2), (3, 4, 5)] """ return batched(chain.from_iterable(matrix), cols) def matmul(m1, m2): """Multiply two matrices. >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)])) [(49, 80), (41, 60)] The caller should ensure that the dimensions of the input matrices are compatible with each other. """ n = len(m2[0]) return batched(starmap(_sumprod, product(m1, transpose(m2))), n) def factor(n): """Yield the prime factors of n. >>> list(factor(360)) [2, 2, 2, 3, 3, 5] """ for prime in sieve(math.isqrt(n) + 1): while not n % prime: yield prime n //= prime if n == 1: return if n > 1: yield n def polynomial_eval(coefficients, x): """Evaluate a polynomial at a specific value. Example: evaluating x^3 - 4 * x^2 - 17 * x + 60 at x = 2.5: >>> coefficients = [1, -4, -17, 60] >>> x = 2.5 >>> polynomial_eval(coefficients, x) 8.125 """ n = len(coefficients) if n == 0: return x * 0 # coerce zero to the type of x powers = map(pow, repeat(x), reversed(range(n))) return _sumprod(coefficients, powers) def sum_of_squares(it): """Return the sum of the squares of the input values. >>> sum_of_squares([10, 20, 30]) 1400 """ return _sumprod(*tee(it)) def polynomial_derivative(coefficients): """Compute the first derivative of a polynomial. Example: evaluating the derivative of x^3 - 4 * x^2 - 17 * x + 60 >>> coefficients = [1, -4, -17, 60] >>> derivative_coefficients = polynomial_derivative(coefficients) >>> derivative_coefficients [3, -8, -17] """ n = len(coefficients) powers = reversed(range(1, n)) return list(map(operator.mul, coefficients, powers)) def totient(n): """Return the count of natural numbers up to *n* that are coprime with *n*. >>> totient(9) 6 >>> totient(12) 4 """ for p in unique_justseen(factor(n)): n = n // p * (p - 1) return n