GeneralDistManager.java
/*******************************************************************************
* Copyright (c) 2021 Handy Tools for Distributed Computing (HanDist) project.
*
* This program and the accompanying materials are made available to you under
* the terms of the Eclipse Public License 1.0 which accompanies this
* distribution,
* and is available at https://www.eclipse.org/legal/epl-v10.html
*
* SPDX-License-Identifier: EPL-1.0
******************************************************************************/
package handist.collections.dist;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import apgas.util.GlobalID;
import handist.collections.dist.util.IntFloatPair;
import handist.collections.dist.util.IntLongPair;
public abstract class GeneralDistManager<T> implements Serializable {
private static int _debug_level = 5;
/** Serial Version UID */
private static final long serialVersionUID = 7184736551394411890L;
protected T branch;
final GlobalID id;
// @TransientInitExpr(getLocalData())
transient float[] locality;
/*
* Ensure calling updateDist() before balance() balance() should be called in
* all places
*/
public final TeamedPlaceGroup placeGroup; // may be packed into T? or globalID??
/*
* public GeneralDistManager(TeamedPlaceGroup pg, T branch) { this(pg, new
* GlobalID(), branch); }
*/
public GeneralDistManager(TeamedPlaceGroup pg, GlobalID id, T branch) {
this.id = id;
this.placeGroup = pg;
this.branch = branch;
this.locality = new float[pg.size];
Arrays.fill(locality, 1.0f);
id.putHere(branch);
}
public void balance() {
final TeamedPlaceGroup pg = placeGroup;
final GeneralDistManager<T> handle = this;
pg.broadcastFlat(() -> {
handle.teamedBalance();
});
};
public void balance(final float[] balance) {
balanceSpecCheck(balance);
final TeamedPlaceGroup pg = this.placeGroup;
final GeneralDistManager<T> handle = this;
pg.broadcastFlat(() -> {
handle.teamedBalance(balance);
});
}
protected void balanceSpecCheck(final float[] balance) {
if (balance.length != placeGroup.size) {
throw new RuntimeException("[AbstractDistCollection");
}
}
abstract public void checkDistInfo(long[] result);
/**
* Destroy an instance of AbstractDistCollection.
*/
public void destroy() {
placeGroup.remove(id);
}
abstract protected void moveAtSyncCount(final ArrayList<IntLongPair> moveList, final CollectiveMoveManager mm)
throws Exception;
/**
* Return the PlaceGroup.
*
* @return PlaceGroup.
*/
public TeamedPlaceGroup placeGroup() {
return placeGroup;
}
// TODO
// public abstract void integrate(T src);
public void teamedBalance() {
teamedBalance(new CollectiveMoveManager(placeGroup));
}
public void teamedBalance(final float[] balance) {
teamedBalance(balance, new CollectiveMoveManager(placeGroup));
}
public void teamedBalance(final float[] newLocality, final CollectiveMoveManager mm) {
// Rail.copy[Float](ne wL ocality, locality)
if (newLocality.length != placeGroup.size()) {
throw new RuntimeException("[DistCol] the size of newLocality must be the same with placeGroup.size()");
}
System.arraycopy(newLocality, 0, locality, 0, locality.length);
teamedBalance(mm);
}
// TODO
// maybe these methods should move to the interface like RelocatableCollection
// or RelocatableMap
// as default methods.
public void teamedBalance(CollectiveMoveManager mm) {
final int pgSize = placeGroup.size();
final IntFloatPair[] listPlaceLocality = new IntFloatPair[pgSize];
float localitySum = 0.0f;
long globalDataSize = 0;
final long[] localDataSize = new long[pgSize];
for (int i = 0; i < pgSize; i++) {
localitySum += locality[i];
}
checkDistInfo(localDataSize);
for (int i = 0; i < pgSize; i++) {
globalDataSize += localDataSize[i];
final float normalizeLocality = locality[i] / localitySum;
listPlaceLocality[i] = new IntFloatPair(i, normalizeLocality);
}
Arrays.sort(listPlaceLocality, (IntFloatPair a1, IntFloatPair a2) -> {
return Float.compare(a1.second, a2.second);
});
if (_debug_level > 5) {
for (final IntFloatPair pair : listPlaceLocality) {
System.out.print("(" + pair.first + ", " + pair.second + ") ");
}
System.out.println();
placeGroup.barrier(); // for debug print
}
final IntFloatPair[] cumulativeLocality = new IntFloatPair[pgSize];
float sumLocality = 0.0f;
for (int i = 0; i < pgSize; i++) {
sumLocality += listPlaceLocality[i].second;
cumulativeLocality[i] = new IntFloatPair(listPlaceLocality[i].first, sumLocality);
}
cumulativeLocality[pgSize - 1] = new IntFloatPair(listPlaceLocality[pgSize - 1].first, 1.0f);
if (_debug_level > 5) {
for (int i = 0; i < pgSize; i++) {
final IntFloatPair pair = cumulativeLocality[i];
System.out.print("(" + pair.first + ", " + pair.second + ", " + localDataSize[pair.first] + "/"
+ globalDataSize + ") ");
}
System.out.println();
placeGroup.barrier(); // for debug print
}
final ArrayList<ArrayList<IntLongPair>> moveList = new ArrayList<>(pgSize); // ArrayList(index of dest Place,
// num
// data to export)
final ArrayList<IntLongPair> stagedData = new ArrayList<>(); // ArrayList(index of src, num data to export)
long previousCumuNumData = 0;
for (int i = 0; i < pgSize; i++) {
moveList.add(new ArrayList<IntLongPair>());
}
for (int i = 0; i < pgSize; i++) {
final int placeIdx = cumulativeLocality[i].first;
final float placeLocality = cumulativeLocality[i].second;
final long cumuNumData = (long) ((globalDataSize) * placeLocality);
final long targetNumData = cumuNumData - previousCumuNumData;
if (localDataSize[placeIdx] > targetNumData) {
stagedData.add(new IntLongPair(placeIdx, localDataSize[placeIdx] - targetNumData));
if (_debug_level > 5) {
System.out.print(
"stage src: " + placeIdx + " num: " + (localDataSize[placeIdx] - targetNumData) + ", ");
}
}
previousCumuNumData = cumuNumData;
}
if (_debug_level > 5) {
System.out.println();
placeGroup.barrier(); // for debug print
}
previousCumuNumData = 0;
for (int i = 0; i < pgSize; i++) {
final int placeIdx = cumulativeLocality[i].first;
final float placeLocality = cumulativeLocality[i].second;
final long cumuNumData = (long) ((globalDataSize) * placeLocality);
final long targetNumData = cumuNumData - previousCumuNumData;
if (targetNumData > localDataSize[placeIdx]) {
long numToImport = targetNumData - localDataSize[placeIdx];
while (numToImport > 0) {
final IntLongPair pair = stagedData.remove(0);
if (pair.second > numToImport) {
moveList.get(pair.first).add(new IntLongPair(placeIdx, numToImport));
stagedData.add(new IntLongPair(pair.first, pair.second - numToImport));
numToImport = 0;
} else {
moveList.get(pair.first).add(new IntLongPair(placeIdx, pair.second));
numToImport -= pair.second;
}
}
}
previousCumuNumData = cumuNumData;
}
if (_debug_level > 5) {
for (int i = 0; i < pgSize; i++) {
for (final IntLongPair pair : moveList.get(i)) {
System.out.print("src: " + i + " dest: " + pair.first + " size: " + pair.second + ", ");
}
}
System.out.println();
placeGroup.barrier(); // for debug print
}
if (_debug_level > 5) {
final long[] diffNumData = new long[pgSize];
for (int i = 0; i < pgSize; i++) {
for (final IntLongPair pair : moveList.get(i)) {
diffNumData[i] -= pair.second;
diffNumData[pair.first] += pair.second;
}
}
for (final IntFloatPair pair : listPlaceLocality) {
System.out.print("(" + pair.first + ", " + pair.second + ", "
+ (localDataSize[pair.first] + diffNumData[pair.first]) + "/" + globalDataSize + ") ");
}
System.out.println();
}
try {
moveAtSyncCount(moveList.get(placeGroup.myrank), mm);
} catch (final Exception e) {
e.printStackTrace();
throw new Error("[AbstractDistCollection] data transfer error raised.");
}
}
// abstract public Object writeReplace() throws ObjectStreamException;
// return new LaObjectReference(id, ()->{ new AbstractDistCollection<>());
/*
* public final def printAllData(){ for(p in placeGroup){ at(p){
* printLocalData(); } } }
*/
}