Flink Iterations vs. While loop
Posted by Dan Drewes
URL: http://deprecated-apache-flink-user-mailing-list-archive.369.s1.nabble.com/Flink-Iterations-vs-While-loop-tp8863.html
Hi,
for my bachelor thesis I'm benchmarking an implementation of the
L-BFGS algorithm that uses Flink Iterations against a version that
uses a plain while loop instead. Both programs use the same Map and
Reduce transformations in each iteration. I expected the Flink
Iterations version to scale better with increasing input size.
However, the measured runtimes on an IBM POWER cluster are very
similar for both versions, e.g. around 30 minutes for 200 GB of
data. The cluster has 8 nodes, was configured with 4 slots per node,
and I used a total parallelism of 32.
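For reference, a minimal sketch of how the parallelism is set in the
driver (hypothetical; the slots themselves are configured per
TaskManager in the cluster setup):

ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(32); // 8 nodes * 4 slots per node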
In every iteration of the while loop a new Flink job is started, and
I assumed that the data would also be redistributed over the network
each time, which should consume a significant and measurable amount
of time. Is that assumption wrong, or is there computational
overhead in the Flink Iterations that equalizes this disadvantage?
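To illustrate, here is a minimal, hypothetical sketch of timing one
pass of the while-loop version (statedataset and the transformations
are as shown below); collect() submits a complete Flink job and
blocks until it finishes:

long start = System.currentTimeMillis();
State current = statedataset.collect().get(0); // submits and runs a full job
long elapsed = System.currentTimeMillis() - start; // includes scheduling, data shipping and compute
System.out.println("Iteration job took " + elapsed + " ms");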
I've included the relevant part of both programs below and also
attached the generated execution plans. Thank you for any ideas; I
could not find much about this issue in the Flink docs.
Best, Dan
Flink Iterations:
DataSet<double[]> data = ...
State state = initialState(m, initweights, 0, new double[initweights.length]);
DataSet<State> statedataset = env.fromElements(state);

// Start of the iteration section: run the loop body at most niter times.
IterativeDataSet<State> loop = statedataset.iterate(niter);
DataSet<State> statewithnewlossgradient = data.map(difffunction).withBroadcastSet(loop, "state")
        .reduce(accumulate)
        .map(new NormLossGradient(datasize))
        .map(new SetLossGradient()).withBroadcastSet(loop, "state")
        .map(new LBFGS());

// Termination criterion: the iteration stops early once this dataset is
// empty, i.e. once iflag[0] == 0 signals convergence.
DataSet<State> converged = statewithnewlossgradient.filter(
        new FilterFunction<State>() {
            @Override
            public boolean filter(State value) throws Exception {
                return value.getIflag()[0] != 0;
            }
        }
);
DataSet<State> finalstate = loop.closeWith(statewithnewlossgradient, converged);
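The plan is then executed once; e.g. collecting the result triggers a
single job that runs all iterations (sketch):

State result = finalstate.collect().get(0); // one job for up to niter iterations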
While loop:
DataSet<double[]> data = ...
State state = initialState(m, initweights, 0, new double[initweights.length]);
// Turn the initial state into a DataSet so it can be broadcast in the first pass.
DataSet<State> statedataset = env.fromElements(state);
int cnt = 0;
do {
    LBFGS lbfgs = new LBFGS();
    statedataset = data.map(difffunction).withBroadcastSet(statedataset, "state")
            .reduce(accumulate)
            .map(new NormLossGradient(datasize))
            .map(new SetLossGradient()).withBroadcastSet(statedataset, "state")
            .map(lbfgs);
    cnt++;
    // collect() in the loop condition submits and executes a new Flink job per pass.
} while (cnt < niter && statedataset.collect().get(0).getIflag()[0] != 0);
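One variant worth noting (a hypothetical sketch, not what was
measured above): since the loop condition already materializes the
state via collect(), the next pass could start from that literal
value instead of the ever-growing lazy plan, so each submitted job
contains only a single iteration:

State current;
int cnt = 0;
do {
    statedataset = data.map(difffunction).withBroadcastSet(statedataset, "state")
            .reduce(accumulate)
            .map(new NormLossGradient(datasize))
            .map(new SetLossGradient()).withBroadcastSet(statedataset, "state")
            .map(new LBFGS());
    current = statedataset.collect().get(0);   // run the job once, keep the result
    statedataset = env.fromElements(current);  // re-inject as a small literal dataset
    cnt++;
} while (cnt < niter && current.getIflag()[0] != 0);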