Here's a little challenge: for testing, I need a generator that produces every combination of inputs drawn from several lists. For example:
# Three small input lists: every combination picks one element from each.
a = [1, 2]
b = [1, 2, 3]
c = [1, 2]
So how do we go about that?
Python's `itertools.combinations` can't help us: it picks r-length subsets from a single iterable, so it won't accept multiple lists as inputs.
What we need instead is `itertools.product`:
import itertools
# Pass the lists directly; itertools.product takes them as separate arguments.
for combo in itertools.product(a, b, c):
    print(combo)
(1, 1, 1) (1, 1, 2) (1, 2, 1) (1, 2, 2) (1, 3, 1) (1, 3, 2) (2, 1, 1) (2, 1, 2) (2, 2, 1) (2, 2, 2) (2, 3, 1) (2, 3, 2)
So how does it work?
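The simplest implementation I can come up with keeps a list of indices, one per input list, and advances them step by step like an odometer.
When an index would go past the end of its list, we reset it to 0 and carry over to the next index. Because the first index advances fastest, the combinations come out in a different order from `itertools.product`, where the last position changes fastest.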
def product(scales):
    """Generate every way of choosing one element from each list in *scales*.

    Each combination is yielded as a new list. The index into the first list
    advances fastest, so the order is the mirror image of
    ``itertools.product`` (which advances the last position fastest).

    Args:
        scales: a sequence of indexable, sized sequences.

    Yields:
        One list per combination, ``len(scales[0]) * len(scales[1]) * ...``
        lists in total. An empty *scales* yields a single empty list; if any
        member is empty, nothing is yielded.
    """
    indices = [0 for _ in scales]
    remaining = 1
    for options in scales:
        remaining *= len(options)

    while remaining > 0:
        remaining -= 1
        yield [choices[i] for i, choices in zip(indices, scales)]

        # Odometer step: a position already at its last index wraps back to 0
        # and carries into the next one; stop at the first that can advance.
        for position, options in enumerate(scales):
            if indices[position] + 1 != len(options):
                indices[position] += 1
                break
            indices[position] = 0
# Print each combination from the hand-written odometer generator.
for combo in product([a, b, c]):
    print(combo)
[1, 1, 1] [2, 1, 1] [1, 2, 1] [2, 2, 1] [1, 3, 1] [2, 3, 1] [1, 1, 2] [2, 1, 2] [1, 2, 2] [2, 2, 2] [1, 3, 2] [2, 3, 2]
With this approach I can also pick out the nth combination directly, simply by calculating the indices from n without generating all the earlier combinations.
def nth_combination(n, scales):
    """Return the *n*-th combination (1-based) of one element per list in *scales*.

    Positions are numbered so that the first list's index changes fastest,
    i.e. the same order the odometer-style ``product`` generator uses:
    ``n == 1`` selects the first element of every list.

    Args:
        n: 1-based position of the wanted combination, in ``1..total`` where
            ``total`` is the product of the lengths of the lists.
        scales: a sequence of indexable, sized sequences.

    Returns:
        A tuple containing one element from each list in *scales*.

    Raises:
        ValueError: if *n* is outside ``1..total`` (always the case when any
            list in *scales* is empty, since then ``total`` is 0).
    """
    total = 1
    for sub_scale in scales:
        total *= len(sub_scale)
    # Chained comparison: the whole range check is negated, not just `0 < n`.
    if not 0 < n <= total:
        raise ValueError(f"n must be in 1..{total}, got {n}")

    # Treat the 0-based position as a mixed-radix number whose least
    # significant digit indexes the first list.
    remainder = n - 1
    values = []
    for sub_scale in scales:
        remainder, ix = divmod(remainder, len(sub_scale))
        values.append(sub_scale[ix])
    return tuple(values)
# Cross-check nth_combination against itertools.product on a larger example.
a, b, c, d = [1, 2], [1, 2, 3], [4, 5], [6, 7, 8, 9]
scales = [a, b, c, d]
# nth_combination advances the first position fastest while itertools.product
# advances the last one fastest, so compare both sides in sorted order.
expected_result = sorted(itertools.product(*scales))
# Derive the count from the data instead of hard-coding 2*3*2*4, so the check
# stays correct if the input lists are edited.
total = len(expected_result)
all_nth_combinations = sorted(nth_combination(n, scales) for n in range(1, total + 1))
# Distinct loop names so the input lists `a` and `b` are not overwritten.
for expected, actual in zip(expected_result, all_nth_combinations):
    sign = "==" if expected == actual else "!="
    print(expected, sign, actual)
(1, 1, 4, 6) == (1, 1, 4, 6) (1, 1, 4, 7) == (1, 1, 4, 7) (1, 1, 4, 8) == (1, 1, 4, 8) (1, 1, 4, 9) == (1, 1, 4, 9) (1, 1, 5, 6) == (1, 1, 5, 6) (1, 1, 5, 7) == (1, 1, 5, 7) (1, 1, 5, 8) == (1, 1, 5, 8) (1, 1, 5, 9) == (1, 1, 5, 9) (1, 2, 4, 6) == (1, 2, 4, 6) (1, 2, 4, 7) == (1, 2, 4, 7) (1, 2, 4, 8) == (1, 2, 4, 8) (1, 2, 4, 9) == (1, 2, 4, 9) (1, 2, 5, 6) == (1, 2, 5, 6) (1, 2, 5, 7) == (1, 2, 5, 7) (1, 2, 5, 8) == (1, 2, 5, 8) (1, 2, 5, 9) == (1, 2, 5, 9) (1, 3, 4, 6) == (1, 3, 4, 6) (1, 3, 4, 7) == (1, 3, 4, 7) (1, 3, 4, 8) == (1, 3, 4, 8) (1, 3, 4, 9) == (1, 3, 4, 9) (1, 3, 5, 6) == (1, 3, 5, 6) (1, 3, 5, 7) == (1, 3, 5, 7) (1, 3, 5, 8) == (1, 3, 5, 8) (1, 3, 5, 9) == (1, 3, 5, 9) (2, 1, 4, 6) == (2, 1, 4, 6) (2, 1, 4, 7) == (2, 1, 4, 7) (2, 1, 4, 8) == (2, 1, 4, 8) (2, 1, 4, 9) == (2, 1, 4, 9) (2, 1, 5, 6) == (2, 1, 5, 6) (2, 1, 5, 7) == (2, 1, 5, 7) (2, 1, 5, 8) == (2, 1, 5, 8) (2, 1, 5, 9) == (2, 1, 5, 9) (2, 2, 4, 6) == (2, 2, 4, 6) (2, 2, 4, 7) == (2, 2, 4, 7) (2, 2, 4, 8) == (2, 2, 4, 8) (2, 2, 4, 9) == (2, 2, 4, 9) (2, 2, 5, 6) == (2, 2, 5, 6) (2, 2, 5, 7) == (2, 2, 5, 7) (2, 2, 5, 8) == (2, 2, 5, 8) (2, 2, 5, 9) == (2, 2, 5, 9) (2, 3, 4, 6) == (2, 3, 4, 6) (2, 3, 4, 7) == (2, 3, 4, 7) (2, 3, 4, 8) == (2, 3, 4, 8) (2, 3, 4, 9) == (2, 3, 4, 9) (2, 3, 5, 6) == (2, 3, 5, 6) (2, 3, 5, 7) == (2, 3, 5, 7) (2, 3, 5, 8) == (2, 3, 5, 8) (2, 3, 5, 9) == (2, 3, 5, 9)