import ipyparallel as ipp
rc = ipp.Client()
e_all = rc[:]
Create MPI groups of 8 engines each, plus a communicator for each group:
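%%px
from mpi4py import MPI

world = MPI.COMM_WORLD
world_group = world.Get_group()

# set up sub-groups and sub-communicators
N = 8  # number of processes per group
# first world rank of the contiguous block of N ranks this rank belongs to
my_root = world.rank - (world.rank % N)
group = world_group.Incl(range(my_root, my_root + N))
# identify each group by the world rank of its rank-0 member;
# called through MPI.Group so it also works where Translate_ranks is an
# instance method rather than a classmethod
my_root = MPI.Group.Translate_ranks(group, [0], world_group)[0]
# create a communicator for the group (collective over world)
comm = world.Create(group)
"{}/{} root={}".format(group.rank, group.size, my_root)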
Out[0:41]: '1/8 root=0'
Out[1:43]: '0/8 root=0'
Out[2:41]: '3/8 root=0'
Out[3:43]: '2/8 root=0'
Out[4:41]: '5/8 root=0'
Out[5:43]: '4/8 root=0'
Out[6:41]: '7/8 root=0'
Out[7:43]: '6/8 root=0'
Out[8:41]: '3/8 root=8'
Out[9:43]: '0/8 root=8'
Out[10:41]: '1/8 root=8'
Out[11:43]: '2/8 root=8'
Out[12:43]: '4/8 root=8'
Out[13:41]: '5/8 root=8'
Out[14:41]: '7/8 root=8'
Out[15:43]: '6/8 root=8'
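The cell above splits the world into contiguous blocks of N ranks, and because `Incl` keeps the listed ranks in order, a rank's position within its group is its offset from the block's root. The following is a minimal plain-Python sketch of that arithmetic, runnable without MPI. `group_layout` is a hypothetical helper, and it assumes the world size is a multiple of N (16 ranks and N = 8 here).

```python
def group_layout(world_size, group_size):
    """Map each world rank to (root world rank, rank within its group).

    Hypothetical helper mirroring the my_root arithmetic of the %%px cell;
    assumes world_size is a multiple of group_size.
    """
    if world_size % group_size:
        raise ValueError("world_size must be a multiple of group_size")
    layout = {}
    for rank in range(world_size):
        root = rank - (rank % group_size)  # same formula as my_root
        local = rank - root                # offset within the contiguous block
        layout[rank] = (root, local)
    return layout


layout = group_layout(16, 8)
print(layout[3])   # (0, 3)
print(layout[11])  # (8, 3)
```

This reproduces the `root=0` / `root=8` split seen in the output; the `Translate_ranks` call in the cell simply confirms the root from the MPI side.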
Get roots from all engines to identify the groups, and create a DirectView for each group of engines:
# my_root is looked up in each engine's namespace, where the %%px cell defined it
roots = e_all.apply_async(lambda: my_root).get_dict()
roots
{0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 8, 9: 8, 10: 8, 11: 8, 12: 8, 13: 8, 14: 8, 15: 8}
groups = {}
for eid, root in roots.items():
    if root not in groups:
        groups[root] = []
    groups[root].append(eid)

views = [rc[group] for group in groups.values()]
views
[<DirectView [0, 1, 2, 3,...]>, <DirectView [8, 9, 10, 1...]>]
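The grouping loop above can also be written with `collections.defaultdict`. The sketch below sorts by root so that `views[0]` (and therefore `%%px0`) is always the group rooted at world rank 0, independent of the order in which `get_dict()` returns engines. The `roots` dict is copied from the output above. The final `views` line needs the live `Client`, so it is left as a comment.

```python
from collections import defaultdict

# Engine id -> root world rank, copied from the get_dict() output above.
roots = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0,
         8: 8, 9: 8, 10: 8, 11: 8, 12: 8, 13: 8, 14: 8, 15: 8}

# Invert the mapping: root -> list of engine ids in that group.
groups = defaultdict(list)
for eid, root in roots.items():
    groups[root].append(eid)

# Order groups by root (and engines within a group) for a deterministic view order.
groups = {root: sorted(eids) for root, eids in sorted(groups.items())}
print(groups)

# With a live client:
# views = [rc[eids] for eids in groups.values()]
```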
Activate magics with integer suffixes for each sub-view:
for i, view in enumerate(views):
    # registers %px0/%%px0 for views[0], %px1/%%px1 for views[1], ...
    view.activate(str(i))
%%px0 --block
print("{}/{} (world: {}/{})".format(group.rank, group.size, world.rank, world.size))
[stdout:0] 1/8 (world: 1/16)
[stdout:1] 0/8 (world: 0/16)
[stdout:2] 3/8 (world: 3/16)
[stdout:3] 2/8 (world: 2/16)
[stdout:4] 5/8 (world: 5/16)
[stdout:5] 4/8 (world: 4/16)
[stdout:6] 7/8 (world: 7/16)
[stdout:7] 6/8 (world: 6/16)
%%px1 --block
print("{}/{} (world: {}/{})".format(group.rank, group.size, world.rank, world.size))
[stdout:8] 3/8 (world: 11/16)
[stdout:9] 0/8 (world: 8/16)
[stdout:10] 1/8 (world: 9/16)
[stdout:11] 2/8 (world: 10/16)
[stdout:12] 4/8 (world: 12/16)
[stdout:13] 5/8 (world: 13/16)
[stdout:14] 7/8 (world: 15/16)
[stdout:15] 6/8 (world: 14/16)
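The `[stdout:N]` prefix is the ipyparallel engine id, and as the output shows, it does not coincide with the MPI world rank (engine 0 holds world rank 1, engine 8 holds world rank 11). To address a specific MPI rank from the client, one option is to invert an engine-to-rank mapping. The sketch below copies that mapping from the printed output. In a live session it could presumably be gathered the same way as `roots`, for example with `e_all.apply_async(lambda: world.rank).get_dict()`, but that call is not executed here.

```python
# Engine id -> MPI world rank, copied from the [stdout:eid] lines above.
engine_to_rank = {
    0: 1, 1: 0, 2: 3, 3: 2, 4: 5, 5: 4, 6: 7, 7: 6,
    8: 11, 9: 8, 10: 9, 11: 10, 12: 12, 13: 13, 14: 15, 15: 14,
}

# Invert it so a given MPI rank can be targeted from the client,
# e.g. rc[rank_to_engine[0]] would be a view on world rank 0.
rank_to_engine = {rank: eid for eid, rank in engine_to_rank.items()}

print(rank_to_engine[0])   # 1  (engine 1 holds world rank 0)
print(rank_to_engine[11])  # 8  (engine 8 holds world rank 11)
```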
%%px
comm.rank, comm.size
Out[0:45]: (1, 8)
Out[1:47]: (0, 8)
Out[2:45]: (3, 8)
Out[3:47]: (2, 8)
Out[4:45]: (5, 8)
Out[5:47]: (4, 8)
Out[6:45]: (7, 8)
Out[7:47]: (6, 8)
Out[8:45]: (3, 8)
Out[9:47]: (0, 8)
Out[10:45]: (1, 8)
Out[11:47]: (2, 8)
Out[12:47]: (4, 8)
Out[13:45]: (5, 8)
Out[14:45]: (7, 8)
Out[15:47]: (6, 8)