
Zipping iterators in Python while asserting equal length

Stack Overflow user
Asked 2015-10-06 01:32:57
5 answers, 7.9K views, 0 followers, 36 votes

I am looking for a nice way to zip several iterables that raises an exception if the lengths of the iterables are not equal.

In the case where the iterables are lists or have a len method, this solution is clean and easy:

def zip_equal(it1, it2):
    if len(it1) != len(it2):
        raise ValueError("Lengths of iterables are different")
    return zip(it1, it2)

However, if it1 and it2 are generators, the previous function fails because their length is not defined: TypeError: object of type 'generator' has no len()
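One way to keep the cheap len() check for lists while still accepting generators is sketched below; it is not from the question, and zip_equal_hybrid and _checked_zip are hypothetical names. It checks len() only when every input is Sized and otherwise falls back to a sentinel-based zip_longest check, similar in spirit to what more-itertools' zip_equal does. It assumes that len() of a Sized input equals the number of items it yields.

```python
from collections.abc import Sized
from itertools import zip_longest


def zip_equal_hybrid(*iterables):
    """Hypothetical helper: eager len() check when every input is Sized,
    otherwise a lazy sentinel-based check via zip_longest."""
    if all(isinstance(it, Sized) for it in iterables):
        if len({len(it) for it in iterables}) > 1:
            raise ValueError("Lengths of iterables are different")
        return zip(*iterables)
    return _checked_zip(iterables)


def _checked_zip(iterables):
    """Generator fallback for inputs without len(), e.g. generators."""
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        # Identity check: no real element can be mistaken for the private sentinel.
        if any(value is sentinel for value in combo):
            raise ValueError("Lengths of iterables are different")
        yield combo
```

Note the asymmetry: mismatched sized inputs raise immediately, while mismatched generators raise only when iteration reaches the end of the shorter one.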

I imagine the itertools module offers a simple way to do this, but so far I have not been able to find it. I came up with this home-made solution:

def zip_equal(it1, it2):
    exhausted = False
    while True:
        try:
            el1 = next(it1)
            if exhausted: # in a previous iteration it2 was exhausted but it1 still has elements
                raise ValueError("it1 and it2 have different lengths")
        except StopIteration:
            exhausted = True
            # it2 must be exhausted too.
        try:
            el2 = next(it2)
            # here it2 is not exhausted.
            if exhausted:  # it1 was exhausted => raise
                raise ValueError("it1 and it2 have different lengths")
        except StopIteration:
            # here it2 is exhausted
            if not exhausted:
                # but it1 was not exhausted => raise
                raise ValueError("it1 and it2 have different lengths")
            exhausted = True
        if not exhausted:
            yield (el1, el2)
        else:
            return

The solution can be tested with the following code:

it1 = (x for x in ['a', 'b', 'c'])  # it1 has length 3
it2 = (x for x in [0, 1, 2, 3])     # it2 has length 4
list(zip_equal(it1, it2))           # len(it1) < len(it2) => raise
it1 = (x for x in ['a', 'b', 'c'])  # it1 has length 3
it2 = (x for x in [0, 1, 2, 3])     # it2 has length 4
list(zip_equal(it2, it1))           # len(it2) > len(it1) => raise
it1 = (x for x in ['a', 'b', 'c', 'd'])  # it1 has length 4
it2 = (x for x in [0, 1, 2, 3])          # it2 has length 4
list(zip_equal(it1, it2))                # like zip (or izip in python2)

Am I overlooking an alternative solution? Is there a simpler implementation of my zip_equal function?
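For the two-iterator case, one simpler shape (a sketch, not taken from any answer below; zip_equal_pair and _MISSING are hypothetical names) is to call next() with a default so that exhaustion comes back as a value instead of an exception. It assumes the private marker object never appears in the data; it is compared by identity only.

```python
_MISSING = object()  # hypothetical private marker, compared by identity only


def zip_equal_pair(it1, it2):
    """Sketch: yield pairs from two iterables, raising ValueError on unequal lengths."""
    it1, it2 = iter(it1), iter(it2)  # also accept lists, not only iterators
    while True:
        a = next(it1, _MISSING)
        b = next(it2, _MISSING)  # always advance it2 so its length is checked too
        if a is _MISSING and b is _MISSING:
            return  # both ran out on the same step: equal lengths
        if a is _MISSING or b is _MISSING:
            raise ValueError("it1 and it2 have different lengths")
        yield a, b
```

Like the original, this still runs a Python-level check for every item; it only replaces the try/except bookkeeping with two identity comparisons.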

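Update:

  • Requiring Python 3.10 or newer: see the corresponding answer

  • Thorough performance benchmarking and best-performing solution: Stefan's answer

  • Simple solution: Martijn's answer (please check the comments for a bugfix in some corner cases)

  • More complex than Martijn's, but with better performance: cjerdonek's answer

  • If you don't mind a package dependency, see pylang's answer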
5 answers

Stack Overflow user

Accepted answer

Posted 2021-10-07 17:14:38

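A new solution, even faster than cjerdonek's solution on which it is based, plus a benchmark. Benchmark first; my solution is the green line. Note that the "total size" is the same in all cases, two million values. The x-axis is the number of iterables: from 1 iterable with two million values, then 2 iterables with a million values each, all the way up to 100,000 iterables with 20 values each.

Black is Python's own zip. I used Python 3.8 here, so zip doesn't perform this question's task of checking for equal lengths, but I include it as a reference: the maximum speed one could hope for. You can see my solution is pretty close.

For the most common case of zipping two iterables, mine is almost three times as fast as the previous fastest solution, cjerdonek's, and not much slower than zip. Times as text (in milliseconds):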

         number of iterables     1     2     3     4     5    10   100  1000 10000 50000 100000
-----------------------------------------------------------------------------------------------
       more_itertools__pylang 209.3 132.1 105.8  93.7  87.4  74.4  54.3  51.9  53.9  66.9  84.5
   fillvalue__Martijn_Pieters 159.1 101.5  85.6  74.0  68.8  59.0  44.1  43.0  44.9  56.9  72.0
     chain_raising__cjerdonek  58.5  35.1  26.3  21.9  19.7  16.6  10.4  12.7  34.4 115.2 223.2
     ziptail__Stefan_Pochmann  10.3  12.4  10.4   9.2   8.7   7.8   6.7   6.8   9.4  22.6  37.8
                          zip  10.3   8.5   7.8   7.4   7.4   7.1   6.4   6.8   9.0  19.4  32.3

My code (Try it online!):

from itertools import chain

def zip_equal(*iterables):

    # For trivial cases, use pure zip.
    if len(iterables) < 2:
        return zip(*iterables)

    # Tail for the first iterable
    first_stopped = False
    def first_tail():
        nonlocal first_stopped 
        first_stopped = True
        return
        yield

    # Tail for the zip
    def zip_tail():
        if not first_stopped:
            raise ValueError('zip_equal: first iterable is longer')
        for _ in chain.from_iterable(rest):
            raise ValueError('zip_equal: first iterable is shorter')
            yield

    # Put the pieces together
    iterables = iter(iterables)
    first = chain(next(iterables), first_tail())
    rest = list(map(iter, iterables))
    return chain(zip(first, *rest), zip_tail())

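The basic idea is to let zip(*iterables) do all the work and then, after it has stopped, check whether all iterables were equally long. They were if and only if:

  1. zip stopped because the first iterable had no further element (i.e., no other iterable is shorter), and
  2. none of the other iterables has any further element (i.e., no other iterable is longer).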

How I check these conditions:

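  • Since I need to check these conditions after zip has ended, I can't return the zip object directly. Instead, I chain an empty zip_tail iterator behind it that does the checking.
  • To support checking the first condition, I chain an empty first_tail iterator behind the first iterable. Its only job is to record that the first iterable's iteration stopped (i.e., it was asked for another element, had none, and so the first_tail iterator was asked instead).
  • To support checking the second condition, I fetch iterators for all the other iterables and keep them in a list before giving them to zip, so I can inspect them afterwards.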
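To see the two checks in isolation, here is a small sketch (flag_on_exhaustion is a hypothetical helper, not part of the answer) that chains a flag-setting empty generator onto the first iterable, the same trick first_tail uses, and afterwards peeks at another iterable for leftover items.

```python
from itertools import chain


def flag_on_exhaustion(iterable):
    """Hypothetical helper: chain an empty generator onto `iterable` that
    records when a consumer asks for an element past the end."""
    state = {"stopped": False}

    def tail():
        state["stopped"] = True
        return
        yield  # unreachable; its presence makes tail() a generator function

    return chain(iterable, tail()), state


# Case A: the first iterable is shortest, so zip stops because of it
# and never pulls from `rest` on the final step.
first, state = flag_on_exhaustion([1, 2])
rest = iter("abc")
pairs = list(zip(first, rest))
leftover = next(rest, None)  # condition 2: does another iterable still have items?

# Case B: another iterable is shorter, so `first` is never asked past its end.
first2, state2 = flag_on_exhaustion([1, 2, 3])
pairs2 = list(zip(first2, "ab"))
```

In case A the first condition holds (the flag is set) but the second fails ("c" is left over); in case B the first condition already fails because the flag was never set.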

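Side note: more-itertools uses pretty much the same method as Martijn, but does proper is checks instead of Martijn's not-quite-correct sentinel in combo (the in operator also falls back to ==, so an element whose __eq__ claims equality with everything would trigger a false "different lengths" error). Those Python-level is checks are probably the main reason it's slower.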
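A small demonstration of why sentinel in combo is not quite correct, using a deliberately pathological element type (AlwaysEqual is hypothetical): the in operator falls back to ==, so an element that claims equality with everything looks like the fill value, while an identity check is not fooled.

```python
from itertools import zip_longest


class AlwaysEqual:
    """Hypothetical pathological element: claims to equal everything."""

    def __eq__(self, other):
        return True


def zip_equal_in(*iterables):
    """Martijn-style check: `in` tries identity, then falls back to ==."""
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if sentinel in combo:
            raise ValueError("Iterables have different lengths")
        yield combo


def zip_equal_is(*iterables):
    """Identity-only check, like the more-itertools version."""
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if any(value is sentinel for value in combo):
            raise ValueError("Iterables have different lengths")
        yield combo


items = [AlwaysEqual()]  # same length as [1]

try:
    list(zip_equal_in(items, [1]))
    in_version_raised = False
except ValueError:
    in_version_raised = True  # false positive: AlwaysEqual() == sentinel

is_version_result = list(zip_equal_is(items, [1]))  # one pair, no error
```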

Benchmark code (Try it online!):

import timeit
import warnings
import itertools
from itertools import repeat, chain, zip_longest
from collections import deque
from sys import hexversion, maxsize

#-----------------------------------------------------------------------------
# Solution by Martijn Pieters
#-----------------------------------------------------------------------------

def zip_equal__fillvalue__Martijn_Pieters(*iterables):
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if sentinel in combo:
            raise ValueError('Iterables have different lengths')
        yield combo

#-----------------------------------------------------------------------------
# Solution by pylang
#-----------------------------------------------------------------------------

def zip_equal__more_itertools__pylang(*iterables):
    return more_itertools__zip_equal(*iterables)

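_marker = object()

class UnequalIterablesError(ValueError):
    # Stand-in modeled on more_itertools' UnequalIterablesError so this copy runs standalone.
    def __init__(self, details=None):
        msg = 'Iterables have different lengths'
        if details is not None:
            msg += (': index 0 has length {}; index {} has length {}').format(*details)
        super().__init__(msg)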

def _zip_equal_generator(iterables):
    for combo in zip_longest(*iterables, fillvalue=_marker):
        for val in combo:
            if val is _marker:
                raise UnequalIterablesError()
        yield combo

def more_itertools__zip_equal(*iterables):
    """``zip`` the input *iterables* together, but raise
    ``UnequalIterablesError`` if they aren't all the same length.

        >>> it_1 = range(3)
        >>> it_2 = iter('abc')
        >>> list(zip_equal(it_1, it_2))
        [(0, 'a'), (1, 'b'), (2, 'c')]

        >>> it_1 = range(3)
        >>> it_2 = iter('abcd')
        >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        more_itertools.more.UnequalIterablesError: Iterables have different
        lengths

    """
    if hexversion >= 0x30A00A6:
        warnings.warn(
            (
                'zip_equal will be removed in a future version of '
                'more-itertools. Use the builtin zip function with '
                'strict=True instead.'
            ),
            DeprecationWarning,
        )
    # Check whether the iterables are all the same size.
    try:
        first_size = len(iterables[0])
        for i, it in enumerate(iterables[1:], 1):
            size = len(it)
            if size != first_size:
                break
        else:
            # If we didn't break out, we can use the built-in zip.
            return zip(*iterables)

        # If we did break out, there was a mismatch.
        raise UnequalIterablesError(details=(first_size, i, size))
    # If any one of the iterables didn't have a length, start reading
    # them until one runs out.
    except TypeError:
        return _zip_equal_generator(iterables)

#-----------------------------------------------------------------------------
# Solution by cjerdonek
#-----------------------------------------------------------------------------

class ExhaustedError(Exception):
    def __init__(self, index):
        """The index is the 0-based index of the exhausted iterable."""
        self.index = index

def raising_iter(i):
    """Return an iterator that raises an ExhaustedError."""
    raise ExhaustedError(i)
    yield

def terminate_iter(i, iterable):
    """Return an iterator that raises an ExhaustedError at the end."""
    return itertools.chain(iterable, raising_iter(i))

def zip_equal__chain_raising__cjerdonek(*iterables):
    iterators = [terminate_iter(*args) for args in enumerate(iterables)]
    try:
        yield from zip(*iterators)
    except ExhaustedError as exc:
        index = exc.index
        if index > 0:
            raise RuntimeError('iterable {} exhausted first'.format(index)) from None
        # Check that all other iterators are also exhausted.
        for i, iterator in enumerate(iterators[1:], start=1):
            try:
                next(iterator)
            except ExhaustedError:
                pass
            else:
                raise RuntimeError('iterable {} is longer'.format(i)) from None
            
#-----------------------------------------------------------------------------
# Solution by Stefan Pochmann
#-----------------------------------------------------------------------------

def zip_equal__ziptail__Stefan_Pochmann(*iterables):

    # For trivial cases, use pure zip.
    if len(iterables) < 2:
        return zip(*iterables)

    # Tail for the first iterable
    first_stopped = False
    def first_tail():
        nonlocal first_stopped 
        first_stopped = True
        return
        yield

    # Tail for the zip
    def zip_tail():
        if not first_stopped:
            raise ValueError(f'zip_equal: first iterable is longer')
        for _ in chain.from_iterable(rest):
            raise ValueError(f'zip_equal: first iterable is shorter')
            yield

    # Put the pieces together
    iterables = iter(iterables)
    first = chain(next(iterables), first_tail())
    rest = list(map(iter, iterables))
    return chain(zip(first, *rest), zip_tail())

#-----------------------------------------------------------------------------
# List of solutions to be speedtested
#-----------------------------------------------------------------------------

solutions = [
    zip_equal__more_itertools__pylang,
    zip_equal__fillvalue__Martijn_Pieters,
    zip_equal__chain_raising__cjerdonek,
    zip_equal__ziptail__Stefan_Pochmann,
    zip,
]

def name(solution):
    return solution.__name__[11:] or 'zip'

#-----------------------------------------------------------------------------
# The speedtest code
#-----------------------------------------------------------------------------

def test(m, n):
    """Speedtest all solutions with m iterables of n elements each."""

    all_times = {solution: [] for solution in solutions}
    def show_title():
        print(f'{m} iterators of length {n:,}:')
    if verbose: show_title()
    def show_times(times, solution):
        print(*('%3d ms ' % t for t in times),
              name(solution))
        
    for _ in range(3):
        for solution in solutions:
            times = sorted(timeit.repeat(lambda: deque(solution(*(repeat(i, n) for i in range(m))), 0), number=1, repeat=5))[:3]
            times = [round(t * 1e3, 3) for t in times]
            all_times[solution].append(times)
            if verbose: show_times(times, solution)
        if verbose: print()
        
    if verbose:
        print('best by min:')
        show_title()
        for solution in solutions:
            show_times(min(all_times[solution], key=min), solution)
        print('best by max:')
    show_title()
    for solution in solutions:
        show_times(min(all_times[solution], key=max), solution)
    print()

    stats.append((m,
                  [min(all_times[solution], key=min)
                   for solution in solutions]))

#-----------------------------------------------------------------------------
# Run the speedtest for several numbers of iterables
#-----------------------------------------------------------------------------

stats = []
verbose = False
total_elements = 2 * 10**6
for m in 1, 2, 3, 4, 5, 10, 100, 1000, 10000, 50000, 100000:
    test(m, total_elements // m)

#-----------------------------------------------------------------------------
# Print the speedtest results for use in the plotting script
#-----------------------------------------------------------------------------

print('data for plotting by https://replit.com/@pochmann/zipequal-plot')
names = [name(solution) for solution in solutions]
print(f'{names = }')
print(f'{stats = }')

Code for the plot and table (also at Replit):

import matplotlib.pyplot as plt

names = ['more_itertools__pylang', 'fillvalue__Martijn_Pieters', 'chain_raising__cjerdonek', 'ziptail__Stefan_Pochmann', 'zip']
stats = [(1, [[208.762, 211.211, 214.189], [159.568, 162.233, 162.24], [57.668, 58.94, 59.23], [10.418, 10.583, 10.723], [10.057, 10.443, 10.456]]), (2, [[130.065, 130.26, 130.52], [100.314, 101.206, 101.276], [34.405, 34.998, 35.188], [12.152, 12.473, 12.773], [8.671, 8.857, 9.395]]), (3, [[106.417, 107.452, 107.668], [90.693, 91.154, 91.386], [26.908, 27.863, 28.145], [10.457, 10.461, 10.789], [8.071, 8.157, 8.228]]), (4, [[97.547, 98.686, 98.726], [77.076, 78.31, 79.381], [23.134, 23.176, 23.181], [9.321, 9.4, 9.581], [7.541, 7.554, 7.635]]), (5, [[86.393, 88.046, 88.222], [68.633, 69.649, 69.742], [19.845, 20.006, 20.135], [8.726, 8.935, 9.016], [7.201, 7.26, 7.304]]), (10, [[70.384, 71.762, 72.473], [57.87, 58.149, 58.411], [15.808, 16.252, 16.262], [7.568, 7.57, 7.864], [6.732, 6.888, 6.911]]), (100, [[53.108, 54.245, 54.465], [44.436, 44.601, 45.226], [10.502, 11.073, 11.109], [6.721, 6.733, 6.847], [6.753, 6.774, 6.815]]), (1000, [[52.119, 52.476, 53.341], [42.775, 42.808, 43.649], [12.538, 12.853, 12.862], [6.802, 6.971, 7.002], [6.679, 6.724, 6.838]]), (10000, [[54.802, 55.006, 55.187], [45.981, 46.066, 46.735], [34.416, 34.672, 35.009], [9.485, 9.509, 9.626], [9.036, 9.042, 9.112]]), (50000, [[66.681, 66.98, 67.441], [56.593, 57.341, 57.631], [113.988, 114.022, 114.106], [22.088, 22.412, 22.595], [19.412, 19.431, 19.934]]), (100000, [[86.846, 88.111, 88.258], [74.796, 75.431, 75.927], [218.977, 220.182, 223.343], [38.89, 39.385, 39.88], [32.332, 33.117, 33.594]])]

colors = {
    'more_itertools__pylang': 'm',
    'fillvalue__Martijn_Pieters': 'red',
    'chain_raising__cjerdonek': 'gold',
    'ziptail__Stefan_Pochmann': 'lime',
    'zip': 'black',
}

ns = [n for n, _ in stats]
print('%28s' % 'number of iterables', *('%5d' % n for n in ns))
print('-' * 95)
x = range(len(ns))
for i, name in enumerate(names):
    ts = [min(tss[i]) for _, tss in stats]
    color = colors[name]
    if color:
        plt.plot(x, ts, '.-', color=color, label=name)
        print('%29s' % name, *('%5.1f' % t for t in ts))
plt.xticks(x, ns, size=9)
plt.ylim(0, 133)
plt.title('zip_equal(m iterables with 2,000,000/m values each)', weight='bold')
plt.xlabel('Number of zipped *iterables* (not their lengths)', weight='bold')
plt.ylabel('Time (for complete iteration) in milliseconds', weight='bold')
plt.legend(loc='upper center')
#plt.show()
plt.savefig('zip_equal_plot.png', dpi=200)
Votes: 4

Stack Overflow user

Posted 2015-10-06 01:45:04

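I can think of a simpler solution: use itertools.zip_longest() and raise an exception if the sentinel value used to pad out shorter iterables is present in a produced tuple: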

from itertools import zip_longest

def zip_equal(*iterables):
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if sentinel in combo:
            raise ValueError('Iterables have different lengths')
        yield combo
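Because this zip_equal is a generator, nothing is checked until iteration reaches the mismatch: calling zip_equal(...) never raises by itself, and a consumer that processes pairs as it goes will already have handled the matching prefix when ValueError arrives. A short illustration (the function is copied from above so the snippet runs on its own):

```python
from itertools import zip_longest


def zip_equal(*iterables):
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if sentinel in combo:
            raise ValueError('Iterables have different lengths')
        yield combo


gen = zip_equal([1, 2, 3], "ab")    # no check has run yet
first_two = [next(gen), next(gen)]  # matching prefix is produced normally
try:
    next(gen)                       # third step contains the fill value
except ValueError:
    raised_at_third = True
else:
    raised_at_third = False
```

If partial processing is a problem, materialize the result with list() before acting on it.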

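Unfortunately, we can't use zip() with yield from to avoid a Python-code loop with a test on every iteration: once the shortest iterator runs out, zip() will already have advanced all the preceding iterators, and so would swallow the evidence if those have just one extra item.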
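The "swallowed evidence" is easy to see with plain iterators: zip() pulls from its arguments left to right, so when a later argument runs out, the element already taken from an earlier one is discarded. When the shortest iterable comes first, nothing is lost, which is the ordering the accepted answer's first_tail/zip_tail approach builds on.

```python
longer = iter([1, 2, 3])
shorter = iter(["a", "b"])
pairs = list(zip(longer, shorter))
# zip() took 3 from `longer` before finding `shorter` empty; that item is gone.
leftover = list(longer)

# With the shorter iterable first, zip() stops before touching the longer one,
# so the extra item is still there to be detected.
longer2 = iter([1, 2, 3])
pairs2 = list(zip(iter(["a", "b"]), longer2))
leftover2 = list(longer2)
```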

Votes: 28

Stack Overflow user

Posted 2016-11-15 03:27:28

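Here is an approach that doesn't require any extra check on each loop of the iteration. This can be desirable, especially for long iterables.

The idea is to pad each iterable at the end with a "value" that raises an exception when reached, and to do the needed verification only at the end. The approach uses zip() and itertools.chain().

The code below was written for Python 3.5.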

import itertools

class ExhaustedError(Exception):
    def __init__(self, index):
        """The index is the 0-based index of the exhausted iterable."""
        self.index = index

def raising_iter(i):
    """Return an iterator that raises an ExhaustedError."""
    raise ExhaustedError(i)
    yield

def terminate_iter(i, iterable):
    """Return an iterator that raises an ExhaustedError at the end."""
    return itertools.chain(iterable, raising_iter(i))

def zip_equal(*iterables):
    iterators = [terminate_iter(*args) for args in enumerate(iterables)]
    try:
        yield from zip(*iterators)
    except ExhaustedError as exc:
        index = exc.index
        if index > 0:
            raise RuntimeError('iterable {} exhausted first'.format(index)) from None
        # Check that all other iterators are also exhausted.
        for i, iterator in enumerate(iterators[1:], start=1):
            try:
                next(iterator)
            except ExhaustedError:
                pass
            else:
                raise RuntimeError('iterable {} is longer'.format(i)) from None

Here is what it looks like in use:

>>> list(zip_equal([1, 2], [3, 4], [5, 6]))
[(1, 3, 5), (2, 4, 6)]

>>> list(zip_equal([1, 2], [3], [4]))
RuntimeError: iterable 1 exhausted first

>>> list(zip_equal([1], [2, 3], [4]))
RuntimeError: iterable 1 is longer

>>> list(zip_equal([1], [2], [3, 4]))
RuntimeError: iterable 2 is longer
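The trick relies on zip() treating only StopIteration as "end of input": any other exception raised by an argument's iterator propagates straight out of zip(). A stripped-down sketch of just that mechanism (Exhausted and raise_at_end are hypothetical stand-ins for ExhaustedError and raising_iter):

```python
from itertools import chain


class Exhausted(Exception):
    """Hypothetical marker exception (stand-in for ExhaustedError above)."""


def raise_at_end():
    raise Exhausted
    yield  # unreachable; makes this a generator function


padded = chain([1, 2], raise_at_end())
collected = []
try:
    for pair in zip(padded, "abc"):
        collected.append(pair)
except Exhausted:
    exhausted_seen = True  # zip() did not swallow the custom exception
else:
    exhausted_seen = False
```

The pairs produced before the padded iterable ran out are still delivered; the custom exception then tells the caller exactly which input ended first.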
Votes: 6

Page content originally provided by Stack Overflow; translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/32954486