LoadBalancer.java
/*******************************************************************************
* Copyright (c) 2021 Handy Tools for Distributed Computing (HanDist) project.
*
* This program and the accompanying materials are made available to you under
* the terms of the Eclipse Public License 1.0 which accompanies this
* distribution,
* and is available at https://www.eclipse.org/legal/epl-v10.html
*
* SPDX-License-Identifier: EPL-1.0
******************************************************************************/
package handist.collections.dist;
import static apgas.Constructs.*;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.BiFunction;
import apgas.Place;
import handist.collections.dist.util.ObjectInput;
import handist.collections.dist.util.ObjectOutput;
import mpi.MPI;
import mpi.MPIException;
/**
 * Redistributes the elements of a local collection among the places of a
 * {@link TeamedPlaceGroup} so that every place ends up holding roughly the
 * same number of elements.
 * <p>
 * TODO: not used at the moment. For internal use only.
 * <p>
 * Concrete subclasses adapt a collection by implementing {@link #localSize()},
 * {@link #exportOne(ObjectOutput)} and {@link #importOne(ObjectInput)}. The
 * transfer plan is computed on the first place of the group, broadcast with
 * MPI, and the elements are then exchanged with an all-to-all-v. Because
 * {@link #execute()} issues collective MPI operations, it presumably has to be
 * called on every place of the group.
 */
@SuppressWarnings("deprecation")
abstract class LoadBalancer {

    /**
     * Load balancer backed by a {@link List}. Elements leave from the tail of
     * the list and arriving elements are appended to it.
     *
     * @param <T> type of the list elements
     */
    static final class ListBalancer<T> extends LoadBalancer {
        private final List<T> body;

        public ListBalancer(List<T> body, TeamedPlaceGroup pg) {
            super(pg);
            this.body = body;
        }

        /** Removes the last element of the list and serializes it to {@code out}. */
        @Override
        void exportOne(ObjectOutput out) throws IOException {
            final T one = body.remove(body.size() - 1);
            out.writeObject(one);
        }

        /** Deserializes one element from {@code in} and appends it to the list. */
        @Override
        @SuppressWarnings("unchecked")
        void importOne(ObjectInput in) throws ClassNotFoundException, IOException {
            body.add((T) in.readObject());
        }

        @Override
        int localSize() {
            return body.size();
        }
    }

    /**
     * Load balancer backed by a {@link Map}. Each transferred entry is
     * serialized as its key followed by its value.
     *
     * @param <K> key type
     * @param <V> value type
     */
    static final class MapBalancer<K, V> extends LoadBalancer {
        private final Map<K, V> body;

        public MapBalancer(Map<K, V> body, TeamedPlaceGroup pg) {
            super(pg);
            this.body = body;
        }

        /**
         * Removes the entry whose key is returned first by the key-set
         * iterator and writes its key, then its value, to {@code out}.
         */
        @Override
        void exportOne(ObjectOutput out) throws IOException {
            assert (!body.isEmpty());
            final K key = body.keySet().iterator().next();
            final V v = body.remove(key);
            out.writeObject(key);
            out.writeObject(v);
        }

        /** Reads a key and then a value from {@code obj} and puts them in the map. */
        @Override
        @SuppressWarnings("unchecked")
        void importOne(ObjectInput obj) throws ClassNotFoundException, IOException {
            final K key = (K) obj.readObject();
            final V v = (V) obj.readObject();
            body.put(key, v);
        }

        @Override
        int localSize() {
            return body.size();
        }
    }

    /** Rank of the current place within {@link #pg}. */
    private final int myRole;
    /** Group of places among which elements are balanced. */
    private final TeamedPlaceGroup pg;
    /** Scratch list of ranks holding fewer elements than the average. */
    private final ArrayList<Integer> receivers;
    /** Place that computes the transfer plan: the first place of {@link #pg}. */
    private final Place root;
    /** Scratch list of ranks holding more elements than the average. */
    private final ArrayList<Integer> senders;

    /**
     * Creates a balancer for the given group.
     *
     * @param pg group of places among which elements are redistributed; the
     *           current place is used to look up its own rank in it
     */
    public LoadBalancer(TeamedPlaceGroup pg) {
        this.pg = pg;
        root = pg.get(0);
        myRole = pg.rank(here());
        senders = new ArrayList<>(pg.size());
        receivers = new ArrayList<>(pg.size());
    }

    /**
     * Rebalances the local collection across the group. Returns immediately
     * when the group contains a single place.
     *
     * @throws Error if an MPI operation fails while planning the transfers, or
     *               if anything fails during the relocation; the original
     *               exception is attached as the cause
     */
    public void execute() {
        if (pg.size() == 1) {
            return;
        }
        try {
            relocate(getMoveCount());
        } catch (final MPIException e) {
            e.printStackTrace();
            // Chain the cause so callers can see the actual MPI failure.
            throw new Error("MPI Exception", e);
        }
    }

    /**
     * Removes one element from the local collection and serializes it to
     * {@code out}. Called only when the local collection is known to hold
     * enough elements for the planned transfer.
     */
    abstract void exportOne(ObjectOutput out) throws IOException;

    /** Deserializes one element from {@code in} and adds it to the local collection. */
    abstract void importOne(ObjectInput in) throws ClassNotFoundException, IOException;

    /** @return number of elements currently held by the local collection */
    abstract int localSize();

    /**
     * Computes how many elements each rank must send to each other rank.
     * <p>
     * Every place learns the local sizes of all places through an all-to-all.
     * Ranks above the (floor) average become senders and ranks below it become
     * receivers. The root place builds an {@code np x np} move matrix, which is
     * then broadcast to every place.
     *
     * @return function {@code (fromRank, toRank) -> number of elements to move}
     * @throws MPIException if one of the collective MPI operations fails
     */
    private BiFunction<Integer, Integer, Integer> getMoveCount() throws MPIException {
        final int np = pg.size();
        final int[] matrix = new int[np * np];
        senders.clear();
        receivers.clear();

        // Every place sends its local size to everyone, so overCounts[i] = size of rank i.
        final long[] tmpOverCounts = new long[np];
        Arrays.fill(tmpOverCounts, localSize());
        final long[] overCounts = new long[np];
        pg.comm.Alltoall(tmpOverCounts, 0, 1, MPI.LONG, overCounts, 0, 1, MPI.LONG);

        long total = 0;
        for (final long c : overCounts) {
            total += c;
        }
        final long average = total / np;
        // From here on, overCounts[i] is the surplus (> 0) or deficit (< 0) of rank i.
        for (int i = 0; i < np; i++) {
            overCounts[i] -= average;
        }
        for (int i = 0; i < np; i++) {
            if (overCounts[i] < 0) {
                receivers.add(i);
            } else if (overCounts[i] > 0) {
                senders.add(i);
            }
        }

        if (here().equals(root) && !senders.isEmpty() && !receivers.isEmpty()) {
            final Integer[] sendersX = senders.toArray(new Integer[senders.size()]);
            final Integer[] receiversX = receivers.toArray(new Integer[receivers.size()]);

            // Randomize the order in which senders are drained (unbiased Fisher-Yates).
            final Random random = new Random();
            for (int i = sendersX.length - 1; i > 0; i--) {
                final int j = random.nextInt(i + 1);
                final Integer tmp = sendersX[j];
                sendersX[j] = sendersX[i];
                sendersX[i] = tmp;
            }

            // Receivers sorted by decreasing overCount: smallest deficit first.
            final Comparator<Integer> comp = (a, b) -> Long.compare(overCounts[b], overCounts[a]);
            Arrays.sort(receiversX, comp);

            // Level the receivers: raise receiver k up to the level of the
            // previous receiver j, one element at a time, taking from sender i.
            int senderPointer = 0;
            int receiverPointer = 1;
            while ((receiverPointer < receiversX.length) && (senderPointer < sendersX.length)) {
                final int i = sendersX[senderPointer];
                final int j = receiversX[receiverPointer - 1];
                final int k = receiversX[receiverPointer];
                while ((overCounts[k] < overCounts[j]) && (0 < overCounts[i])) {
                    overCounts[i]--;
                    overCounts[k]++;
                    matrix[np * i + k]++;
                }
                if (overCounts[j] == overCounts[k]) {
                    receiverPointer++;
                }
                if (overCounts[i] == 0) {
                    senderPointer++;
                }
            }

            // Hand out the remaining surplus round-robin over all receivers.
            while (senderPointer < sendersX.length) {
                final int i = sendersX[senderPointer];
                while (0 < overCounts[i]) {
                    receiverPointer = (receiverPointer + 1) % receiversX.length;
                    final int j = receiversX[receiverPointer];
                    overCounts[i]--;
                    overCounts[j]++;
                    matrix[np * i + j]++;
                }
                senderPointer++;
            }
        }

        pg.comm.Bcast(matrix, 0, np * np, MPI.INT, pg.rank(root));
        return (i0, j0) -> matrix[np * i0 + j0];
    }

    /**
     * Moves elements between places according to {@code getCount}.
     * <p>
     * For every destination rank with a positive count, this place serializes
     * an {@code int} count followed by that many elements produced by
     * {@link #exportOne(ObjectOutput)}. The byte segments are exchanged with an
     * all-to-all-v, and each received segment is decoded with
     * {@link #importOne(ObjectInput)}.
     *
     * @param getCount function {@code (fromRank, toRank) -> number of elements}
     * @throws Error if anything fails; the original exception is attached as
     *               the cause
     */
    private void relocate(BiFunction<Integer, Integer, Integer> getCount) {
        try {
            final int np = pg.size();
            final ByteArrayOutputStream s0 = new ByteArrayOutputStream();
            final int[] scounts = new int[np];
            final int[] sdispls = new int[np];
            final int[] rcounts = new int[np];
            int s0used = 0;
            for (int j = 0; j < np; j++) {
                final int count = getCount.apply(myRole, j);
                if (count > 0) {
                    final ObjectOutput s = new ObjectOutput(s0);
                    s.writeInt(count);
                    for (int k = 0; k < count; k++) {
                        exportOne(s);
                    }
                    s.close();
                    final int prev = s0used;
                    sdispls[j] = prev;
                    s0used = s0.size();
                    scounts[j] = s0used - prev;
                } else {
                    sdispls[j] = s0used;
                    scounts[j] = 0;
                }
            }

            // Tell every rank how many bytes it will receive from us.
            pg.comm.Alltoall(scounts, 0, 1, MPI.INT, rcounts, 0, 1, MPI.INT);
            final byte[] sendbuf = s0.toByteArray();
            final int[] rdispls = new int[np];
            int rused = 0;
            for (int i = 0; i < np; i++) {
                rdispls[i] = rused;
                rused += rcounts[i];
            }
            final byte[] recvbuf = new byte[rused];
            pg.Alltoallv(sendbuf, 0, scounts, sdispls, MPI.BYTE, recvbuf, 0, rcounts, rdispls, MPI.BYTE);

            for (int i = 0; i < np; i++) {
                if (rcounts[i] == 0) {
                    continue;
                }
                final ByteArrayInputStream in = new ByteArrayInputStream(recvbuf, rdispls[i], rcounts[i]);
                final ObjectInput ds = new ObjectInput(in);
                final int count = ds.readInt();
                assert (getCount.apply(i, myRole) == count);
                for (int k = 0; k < count; k++) {
                    importOne(ds);
                }
                ds.close();
            }
        } catch (final Exception e) {
            e.printStackTrace(System.err);
            // Chain the cause so the real failure is not lost.
            throw new Error("Exception during LoadBalance Relocation.", e);
        }
    }
}
/*
* // test class Main {
*
* public static def main(args: Rail[String]): void { val o = new Main();
* o.run(); }
*
* def run(): void { val pg = Place.places(); val team = TeamOperations(pg);
* pg.broadcastFlat(() => { val executor = new Executor(pg, team);
* executor.start(); }); }
*
* static class Executor(pg: PlaceGroup, team: TeamOperations) {
*
* transient var map: x10.util.Map[Long, String];
*
* def start(): void { initialize(() => 1000000); balance(); }
*
* def initialize(count: ()=>Long): void { val begin = System.nanoTime(); if
* (map == null) { map = new x10.util.HashMap[Long, String](); } val num =
* count(); for (var i: Long = 0; i < num; i++) { map(num * here.id + i) =
* i.toString(); } val end = System.nanoTime(); System.out.println(here +
* " initialize " + ((end - begin) * 1e-6) + " ms"); }
*
* def balance(): void { val begin = System.nanoTime(); val al = new
* ArrayList[x10.util.Map.Entry[Long, String]](map.size());
* al.addAll(map.entries()); map.clear(); val balancer = new
* LoadBalancer[x10.util.Map.Entry[Long, String]](al, pg, team);
* balancer.execute(); for (e in al) { map(e.getKey()) = e.getValue(); } val end
* = System.nanoTime(); System.out.println(here + " balance " + ((end - begin) *
* 1e-6) + " ms"); } } }
*/