MatrixMultiplication


MatrixMultiplication

Lydia Ickler
Hi, 

I wrote to you before about the MatrixMultiplication in Flink … Unfortunately, multiplying a pair of 1000 x 1000 matrices already takes almost a minute.
Could you please take a look at my attached code? Maybe you can suggest something to make it faster.
Or would it be better to tackle the problem with the Gelly API, since the matrix is an adjacency matrix? If so, how would you approach it?

Thanks in advance and best regards, 
Lydia

package de.tuberlin.dima.aim3.assignment3;

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.io.CsvReader;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.DataSet;


public class MatrixMultiplication {

    static String input = null;
    static String output = null;

    public void run() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        // The matrix is read as sparse (row, column, value) triples.
        DataSet<Tuple3<Integer, Integer, Double>> matrixA = readMatrix(env, input);

        // A * A: join entries a_ik with entries a_kj on the shared index k,
        // form the products a_ik * a_kj, and sum them per result cell (i, j).
        matrixA.join(matrixA).where(1).equalTo(0)
                .map(new ProjectJoinResultMapper()).groupBy(0, 1).sum(2).writeAsCsv(output);

        env.execute();
    }



    public static DataSource<Tuple3<Integer, Integer, Double>> readMatrix(ExecutionEnvironment env,
            String filePath) {
        CsvReader csvReader = env.readCsvFile(filePath);
        csvReader.fieldDelimiter(',');
        // Skip the first CSV column and read the following three as (row, column, value).
        csvReader.includeFields("fttt");
        return csvReader.types(Integer.class, Integer.class, Double.class);
    }

    /** Turns a joined pair (a_ik, a_kj) into the partial product (i, j, a_ik * a_kj). */
    public static final class ProjectJoinResultMapper implements
            MapFunction<Tuple2<Tuple3<Integer, Integer, Double>,
                               Tuple3<Integer, Integer, Double>>,
                    Tuple3<Integer, Integer, Double>> {

        @Override
        public Tuple3<Integer, Integer, Double> map(
                Tuple2<Tuple3<Integer, Integer, Double>, Tuple3<Integer, Integer, Double>> value)
                throws Exception {
            Integer row = value.f0.f0;
            Integer column = value.f1.f1;
            Double product = value.f0.f2 * value.f1.f2;
            return new Tuple3<Integer, Integer, Double>(row, column, product);
        }
    }


    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            System.err.println("Usage: MatrixMultiplication <input path> <result path>");
            System.exit(1); // non-zero exit code signals incorrect usage
        }
        input = args[0];
        output = args[1];
        new MatrixMultiplication().run();
    }

}

Re: MatrixMultiplication

Till Rohrmann
Hi Lydia,

The implementation looks correct. To speed up the computation, you could exploit existing partitionings in order to avoid unnecessary network shuffles. Moreover, you could block your matrices to increase the data granularity at the cost of parallelism.
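For illustration, here is a rough, untested sketch of the blocking idea. It assumes the same sparse Tuple3<row, column, value> representation as in your code; the block size B, the class name BlockedMultiplicationSketch and the methods toBlocks and multiplyBlocked are just made-up names for this sketch, not existing Flink API:

// Untested sketch only: assembles sparse entries into dense B x B blocks,
// multiplies matching blocks locally and sums the partial block products.
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.util.Collector;

public class BlockedMultiplicationSketch {

    // Block size: larger blocks mean more local work per task but fewer parallel tasks.
    static final int B = 100;

    // (row, column, value) entries -> (blockRow, blockCol, dense B x B block in row-major order).
    public static DataSet<Tuple3<Integer, Integer, double[]>> toBlocks(
            DataSet<Tuple3<Integer, Integer, Double>> entries) {
        return entries
                .map(new MapFunction<Tuple3<Integer, Integer, Double>,
                        Tuple5<Integer, Integer, Integer, Integer, Double>>() {
                    @Override
                    public Tuple5<Integer, Integer, Integer, Integer, Double> map(
                            Tuple3<Integer, Integer, Double> e) {
                        // (blockRow, blockCol, rowInBlock, colInBlock, value)
                        return new Tuple5<Integer, Integer, Integer, Integer, Double>(
                                e.f0 / B, e.f1 / B, e.f0 % B, e.f1 % B, e.f2);
                    }
                })
                .groupBy(0, 1)
                .reduceGroup(new GroupReduceFunction<
                        Tuple5<Integer, Integer, Integer, Integer, Double>,
                        Tuple3<Integer, Integer, double[]>>() {
                    @Override
                    public void reduce(
                            Iterable<Tuple5<Integer, Integer, Integer, Integer, Double>> values,
                            Collector<Tuple3<Integer, Integer, double[]>> out) {
                        double[] block = new double[B * B];
                        int blockRow = 0;
                        int blockCol = 0;
                        for (Tuple5<Integer, Integer, Integer, Integer, Double> v : values) {
                            blockRow = v.f0;
                            blockCol = v.f1;
                            block[v.f2 * B + v.f3] = v.f4;
                        }
                        out.collect(new Tuple3<Integer, Integer, double[]>(blockRow, blockCol, block));
                    }
                });
    }

    // C = A * A on block granularity: join block (i, k) with block (k, j),
    // multiply the two dense blocks locally, then sum the partial blocks per (i, j).
    public static DataSet<Tuple3<Integer, Integer, double[]>> multiplyBlocked(
            DataSet<Tuple3<Integer, Integer, double[]>> blocks) {
        return blocks
                .join(blocks).where(1).equalTo(0)
                .with(new JoinFunction<Tuple3<Integer, Integer, double[]>,
                        Tuple3<Integer, Integer, double[]>, Tuple3<Integer, Integer, double[]>>() {
                    @Override
                    public Tuple3<Integer, Integer, double[]> join(
                            Tuple3<Integer, Integer, double[]> left,
                            Tuple3<Integer, Integer, double[]> right) {
                        double[] product = new double[B * B];
                        for (int i = 0; i < B; i++) {
                            for (int k = 0; k < B; k++) {
                                double a = left.f2[i * B + k];
                                if (a == 0.0) continue; // skip empty cells of sparse blocks
                                for (int j = 0; j < B; j++) {
                                    product[i * B + j] += a * right.f2[k * B + j];
                                }
                            }
                        }
                        return new Tuple3<Integer, Integer, double[]>(left.f0, right.f1, product);
                    }
                })
                .groupBy(0, 1)
                .reduce(new ReduceFunction<Tuple3<Integer, Integer, double[]>>() {
                    @Override
                    public Tuple3<Integer, Integer, double[]> reduce(
                            Tuple3<Integer, Integer, double[]> a,
                            Tuple3<Integer, Integer, double[]> b) {
                        for (int i = 0; i < a.f2.length; i++) {
                            a.f2[i] += b.f2[i];
                        }
                        return a;
                    }
                });
    }
}

Starting from the output of readMatrix, this would be used roughly as multiplyBlocked(toBlocks(matrixA)), followed by a flatMap that expands the result blocks back into (row, column, value) triples before writing. The join then ships B x B blocks instead of single entries, and each joined pair does B^3 multiplications locally, so shuffle and per-record overhead shrink relative to the arithmetic; choosing a larger B trades parallelism for per-task efficiency.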

Cheers,
Till
