Source code for fcdmft.utils.worksteal

from mpi4py import MPI
import numpy as np


class MPIWorkStealingScheduler:
    def __init__(self, cartcomm=None):
        """MPI-3 RMA-based work-stealing scheduler.

        Each rank owns a queue of work items and exposes a counter of its next
        unclaimed item through an RMA window; idle ranks claim work from other
        ranks' queues via atomic fetch-and-add. An illustrative usage sketch
        follows the class definition.

        Parameters
        ----------
        cartcomm : mpi4py.MPI.Cartcomm, optional
            1-D Cartesian communicator. If not given, a periodic 1-D Cartesian
            communicator is created from ``MPI.COMM_WORLD``.
        """
        if cartcomm is None:
            commworld = MPI.COMM_WORLD
            cartcomm = commworld.Create_cart(
                dims=[commworld.Get_size()], periods=[True], reorder=True
            )
        assert cartcomm.Get_dim() == 1, "Cartcomm must be 1-D"
        self.cartcomm = cartcomm
        self.workitems = None
        self.nwork_by_rank = None
        self.cur_job_win = None
        self.cur_job_local = None
        self.global_workitems = None
        # constant increment and receive buffer for the fetch-and-add
        self.one = np.array([1], dtype=np.int64)
        self.buf_one = np.zeros(1, dtype=np.int64)
        # start by consuming our own queue, then steal from the neighbors
        self.cur_victim = cartcomm.Get_rank()
        self.empty = False
    def initialize_workitems(self, workitems):
        """Put jobs into the scheduler.

        A window is then created on each rank to store that rank's current job
        index.

        Parameters
        ----------
        workitems : array_like of int
            2-D array (or list of tuples) of this rank's jobs, one per row.
        """
        self.workitems = np.asarray(workitems, dtype=np.int64)
        # every rank needs the full job table so it can execute stolen jobs locally
        self.global_workitems = self.cartcomm.allgather(self.workitems)
        self.nwork_by_rank = [w.shape[0] for w in self.global_workitems]
        # expose a single int64 counter per rank through an RMA window
        self.cur_job_win = MPI.Win.Allocate(size=MPI.INT64_T.size, comm=self.cartcomm)
        self.cur_job_local = np.frombuffer(self.cur_job_win.tomemory(), dtype=np.int64)
        self.cur_job_win.Lock_all()
        self.cur_job_local[0] = 0
        self.cur_job_win.Sync()
        # make sure every counter is zeroed before any rank starts stealing
        self.cartcomm.Barrier()
    def finalize(self):
        """Free the RMA window and clean up."""
        self.cur_job_win.Unlock_all()
        self.cur_job_win.Free()
        self.global_workitems = None
        self.workitems = None
    def get_workitems(self):
        """Generator that yields jobs to be done.

        The main work-stealing loop lives here.

        Yields
        ------
        ndarray
            1-D array of np.int64; a row of the global workitems array.
        """
        while True:
            # try to get a job from the victim's queue (atomic fetch-and-add)
            self.cur_job_win.Get_accumulate(
                [self.one, 1, MPI.INT64_T],
                [self.buf_one, 1, MPI.INT64_T],
                target_rank=self.cur_victim,
            )
            self.cur_job_win.Flush_local(self.cur_victim)
            job_idx = self.buf_one[0]
            if job_idx < self.nwork_by_rank[self.cur_victim]:
                # victim's queue has a job: steal it!
                yield self.global_workitems[self.cur_victim][job_idx]
            else:
                # victim's queue is empty: try the next victim
                self.cur_victim = (self.cur_victim + 1) % self.cartcomm.size
                # wrapped all the way around: done!
                if self.cur_victim == self.cartcomm.rank:
                    break
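
# Illustrative sketch (not part of the original module): the constructor also
# accepts an existing 1-D periodic Cartesian communicator. The helper below
# shows how a caller might build one explicitly instead of relying on the
# default created from MPI.COMM_WORLD; the function name is hypothetical.
def _example_custom_cartcomm():
    comm = MPI.COMM_WORLD
    # one dimension, periodic, so victims can be cycled through with wrap-around
    cart = comm.Create_cart(dims=[comm.Get_size()], periods=[True], reorder=True)
    return MPIWorkStealingScheduler(cartcomm=cart)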
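
# Illustrative usage sketch (not part of the original module): drive the
# scheduler end-to-end on every rank. The (job id, payload) layout mirrors the
# self-test below; names such as _example_usage and run_job are hypothetical.
# A script like this would typically be launched under MPI, e.g. with
# something like "mpirun -n 4 python my_script.py".
def _example_usage(run_job, nwork_per_rank=10):
    scheduler = MPIWorkStealingScheduler()
    rank = scheduler.cartcomm.Get_rank()
    # each rank contributes its own rows; job ids are globally unique by construction
    local_items = [(rank * nwork_per_rank + i, i) for i in range(nwork_per_rank)]
    scheduler.initialize_workitems(local_items)
    results = []
    for job_id, payload in scheduler.get_workitems():
        # run_job is a user-supplied callable executed for every claimed job
        results.append((int(job_id), run_job(int(payload))))
    scheduler.finalize()  # every rank must call this; it frees the RMA window
    return results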
if __name__ == "__main__":
    # Self-test: each rank generates random (job id, matrix size) work items,
    # runs the scheduler, and rank 0 checks that every job was done exactly once.
    import sys

    scheduler = MPIWorkStealingScheduler()
    rank = scheduler.cartcomm.Get_rank()
    size = scheduler.cartcomm.Get_size()

    nworkperrank = 40
    workprefactor = 1000
    loc_workitems = [
        (i + nworkperrank * rank, np.random.randint(workprefactor))
        for i in range(nworkperrank)
    ]

    scheduler.initialize_workitems(np.asarray(loc_workitems))
    print("Initialized work items.")

    local_done = []
    for idx, matsize in scheduler.get_workitems():
        # throwaway dense-matrix work proportional to the item's size
        x = np.random.random((matsize, matsize))
        y = x @ x.T + x.T @ x
        del x
        del y
        local_done.append(idx)

    global_done = scheduler.cartcomm.allgather(local_done)
    done = np.sort(np.concatenate(global_done, axis=0))

    sys.stdout.flush()
    scheduler.cartcomm.barrier()
    print(f"Rank {rank} did {len(local_done)} jobs")
    sys.stdout.flush()
    scheduler.cartcomm.barrier()

    if rank == 0:
        print(f"{len(done)} jobs done out of {nworkperrank * scheduler.cartcomm.size}")
        uniq, cts = np.unique(done, return_counts=True)
        print(f"{uniq.size} of these jobs unique")
        if uniq.size == nworkperrank * scheduler.cartcomm.size:
            print("All jobs done")

    scheduler.finalize()
    MPI.Finalize()