package org.apache.flink.examples.java.clustering; /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; import java.util.List; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields; import org.apache.flink.api.java.operators.IterativeDataSet; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.configuration.Configuration; import org.apache.flink.examples.java.clustering.util.KMeansOutlierDetectionData;; /** * This example implements a basic K-Means clustering algorithm. * *

* K-Means is an iterative clustering algorithm and works as follows:
* K-Means is given a set of data points to be clustered and an initial set of * K cluster centers. In each iteration, the algorithm computes the * distance of each data point to each cluster center. Each point is assigned to * the cluster center which is closest to it. Subsequently, each cluster center * is moved to the center (mean) of all points that have been assigned to * it. The moved cluster centers are fed into the next iteration. The algorithm * terminates after a fixed number of iterations (as in this implementation) or * if cluster centers do not (significantly) move in an iteration.
* This is the Wikipedia entry for the * K-Means Clustering * algorithm. * *

* This implementation works on two-dimensional data points.
* It computes an assignment of data points to cluster centers, i.e., each data * point is annotated with the id of the final cluster (center) it belongs to. * *

* Input files are plain text files and must be formatted as follows: *

* *

* Usage: * KMeans <points path> <centers path> <result path> <num iterations> *
* If no parameters are provided, the program is run with default data from * {@link org.apache.flink.examples.java.clustering.util.KMeansData} and 10 * iterations. * *

* This example shows how to use: *

*/ @SuppressWarnings("serial") public class KMeansOutlierDetection { // ************************************************************************* // PROGRAM // ************************************************************************* public static void main(String[] args) throws Exception { if (!parseParameters(args)) { return; } // set up execution environment ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); // get input data DataSet points = getPointDataSet(env); DataSet centroids = getCentroidDataSet(env); // set number of bulk iterations for KMeans algorithm IterativeDataSet loop = centroids.iterate(numIterations); DataSet newCentroids = points // compute closest centroid for each point .map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids") // count and sum point coordinates for each centroid .map(new CountAppender()).groupBy(0).reduce(new CentroidAccumulator()) // compute new centroids from point counts and coordinate sums .map(new CentroidAverager()); // feed new centroids back into next iteration DataSet finalCentroids = loop.closeWith(newCentroids); DataSet> clusteredPoints = points // assign points to final clusters .map(new SelectNearestCenter()).withBroadcastSet(finalCentroids, "centroids"); DataSet fElements = env.fromCollection(findOutliers(clusteredPoints, finalCentroids)); // emit result if (fileOutput) { fElements.writeAsCsv(outputPath, "\n", " "); // clusteredPoints.writeAsCsv(outputPath, "\n", " "); // since file sinks are lazy, we trigger the execution explicitly env.execute("KMeans Example"); } else { fElements.print(); } } @SuppressWarnings("rawtypes") public static List findOutliers(DataSet> clusteredPoints, DataSet centroids) throws Exception { List finalElements = new ArrayList(); List> elements = clusteredPoints.collect(); List centroidList = centroids.collect(); List, Double>> elementsWithDistance = new ArrayList, Double>>(); for (Centroid centroid : centroidList) { elementsWithDistance = new ArrayList, Double>>(); double totalDistance = 0; int elementsCount = 0; for (Tuple2 e : elements) { // compute distance if (e.f0 == centroid.id) { Tuple3, Double> newElement = new Tuple3, Double>(); double distance = e.f1.euclideanDistance(centroid); totalDistance += distance; newElement.setFields(centroid, e, distance); elementsWithDistance.add(newElement); elementsCount++; } } // finding mean double mean = totalDistance / elementsCount; double sdTotalDistanceSquare = 0; for (Tuple3, Double> elementWithDistance : elementsWithDistance) { double distanceSquare = Math.pow(mean - elementWithDistance.f2, 2); sdTotalDistanceSquare += distanceSquare; } double sd = Math.sqrt(sdTotalDistanceSquare / elementsCount); double upperlimit = mean + 2 * sd; double lowerlimit = mean - 2 * sd; Tuple3 newElement = new Tuple3();// true // = // outlier for (Tuple3, Double> elementWithDistance : elementsWithDistance) { newElement = new Tuple3(); if (elementWithDistance.f2 < lowerlimit || elementWithDistance.f2 > upperlimit) { // set as outlier newElement.setFields(elementWithDistance.f1.f0, elementWithDistance.f1.f1, true); } else { newElement.setFields(elementWithDistance.f1.f0, elementWithDistance.f1.f1, false); } finalElements.add(newElement); } } return finalElements; } // ************************************************************************* // DATA TYPES // ************************************************************************* /** * A simple two-dimensional point. */ public static class Point implements Serializable { public double x, y; public Point() { } public Point(double x, double y) { this.x = x; this.y = y; } public Point add(Point other) { x += other.x; y += other.y; return this; } public Point div(long val) { x /= val; y /= val; return this; } public double euclideanDistance(Point other) { return Math.sqrt((x - other.x) * (x - other.x) + (y - other.y) * (y - other.y)); } public void clear() { x = y = 0.0; } @Override public String toString() { return x + " " + y; } } /** * A simple two-dimensional centroid, basically a point with an ID. */ public static class Centroid extends Point { public int id; public Centroid() { } public Centroid(int id, double x, double y) { super(x, y); this.id = id; } public Centroid(int id, Point p) { super(p.x, p.y); this.id = id; } @Override public String toString() { return id + " " + super.toString(); } } // ************************************************************************* // USER FUNCTIONS // ************************************************************************* /** Converts a {@code Tuple2} into a Point. */ @ForwardedFields("0->x; 1->y") public static final class TuplePointConverter implements MapFunction, Point> { @Override public Point map(Tuple2 t) throws Exception { return new Point(t.f0, t.f1); } } /** Converts a {@code Tuple3} into a Centroid. */ @ForwardedFields("0->id; 1->x; 2->y") public static final class TupleCentroidConverter implements MapFunction, Centroid> { @Override public Centroid map(Tuple3 t) throws Exception { return new Centroid(t.f0, t.f1, t.f2); } } /** Determines the closest cluster center for a data point. */ @ForwardedFields("*->1") public static final class SelectNearestCenter extends RichMapFunction> { private Collection centroids; /** * Reads the centroid values from a broadcast variable into a * collection. */ @Override public void open(Configuration parameters) throws Exception { this.centroids = getRuntimeContext().getBroadcastVariable("centroids"); } @Override public Tuple2 map(Point p) throws Exception { double minDistance = Double.MAX_VALUE; int closestCentroidId = -1; // check all cluster centers for (Centroid centroid : centroids) { // compute distance double distance = p.euclideanDistance(centroid); // update nearest cluster if necessary if (distance < minDistance) { minDistance = distance; closestCentroidId = centroid.id; } } // emit a new record with the center id and the data point. return new Tuple2(closestCentroidId, p); } } /** Appends a count variable to the tuple. */ @ForwardedFields("f0;f1") public static final class CountAppender implements MapFunction, Tuple3> { @Override public Tuple3 map(Tuple2 t) { return new Tuple3(t.f0, t.f1, 1L); } } /** Sums and counts point coordinates. */ @ForwardedFields("0") public static final class CentroidAccumulator implements ReduceFunction> { @Override public Tuple3 reduce(Tuple3 val1, Tuple3 val2) { return new Tuple3(val1.f0, val1.f1.add(val2.f1), val1.f2 + val2.f2); } } /** Computes new centroid from coordinate sum and count of points. */ @ForwardedFields("0->id") public static final class CentroidAverager implements MapFunction, Centroid> { @Override public Centroid map(Tuple3 value) { return new Centroid(value.f0, value.f1.div(value.f2)); } } // ************************************************************************* // UTIL METHODS // ************************************************************************* private static boolean fileOutput = false; private static String pointsPath = null; private static String centersPath = null; private static String outputPath = null; private static int numIterations = 10; private static boolean parseParameters(String[] programArguments) { if (programArguments.length > 0) { // parse input arguments fileOutput = true; if (programArguments.length == 4) { pointsPath = programArguments[0]; centersPath = programArguments[1]; outputPath = programArguments[2]; numIterations = Integer.parseInt(programArguments[3]); } else { System.err.println("Usage: KMeans "); return false; } } else { System.out.println("Executing K-Means example with default parameters and built-in default data."); System.out.println(" Provide parameters to read input data from files."); System.out.println(" See the documentation for the correct format of input files."); System.out.println(" We provide a data generator to create synthetic input files for this program."); System.out.println(" Usage: KMeans "); } return true; } private static DataSet getPointDataSet(ExecutionEnvironment env) { if (fileOutput) { // read points from CSV file return env.readCsvFile(pointsPath).fieldDelimiter(" ").includeFields(true, true) .types(Double.class, Double.class).map(new TuplePointConverter()); } else { return KMeansOutlierDetectionData.getDefaultPointDataSet(env); } } private static DataSet getCentroidDataSet(ExecutionEnvironment env) { if (fileOutput) { return env.readCsvFile(centersPath).fieldDelimiter(" ").includeFields(true, true, true) .types(Integer.class, Double.class, Double.class).map(new TupleCentroidConverter()); } else { return KMeansOutlierDetectionData.getDefaultCentroidDataSet(env); } } }