Calculating stateful counts per key with checkpoints

hkmaki
Hi,

I'm new to Flink and I'm trying to write a program that calculates stateful counts per key with checkpointing enabled. I expected my test program, which follows the "Checkpointing Instance Fields" example, to calculate the counts per key, but it seems to count per task rather than per key.

I'm able to calculate the counts per key using flatMapWithState, but I don't know how to use checkpoints with it. It's not obvious to me why its behavior differs from the "Checkpointing Instance Fields" case.

Is there a way to use checkpoints with stateful counts per key?

The output of inputData.keyBy(0).flatMap(new TestCounters).print is

1> (A,count=1)
1> (F,count=2)
2> (B,count=1)
2> (C,count=2)
2> (D,count=3)
2> (E,count=4)
2> (E,count=5)
2> (E,count=6)
2> (H,count=7)
4> (G,count=1)

while the output of inputData.keyBy(0).flatMapWithState(...).print is

2> (B,1)
4> (G,1)
1> (A,1)
2> (C,1)
1> (F,1)
2> (D,1)
2> (E,1)
2> (E,2)
2> (E,3)
2> (H,1)
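
To make the difference between the two outputs concrete, here is a minimal sketch in plain Scala with no Flink dependency. It assumes that the first variant keeps one counter per parallel task (the instance field is shared by every key routed to that task) and that the second keeps one counter per key. The object and method names are made up for illustration, and the task assignment of each key is copied from the printed output above rather than computed the way Flink partitions keys.

```scala
object CountScopeSketch {
  // (task index, key) pairs in the order each task printed them above.
  // The task assignment is copied from the output, not computed here.
  val routed: Seq[(Int, String)] = Seq(
    1 -> "A", 1 -> "F",
    2 -> "B", 2 -> "C", 2 -> "D", 2 -> "E", 2 -> "E", 2 -> "E", 2 -> "H",
    4 -> "G")

  // Running count with one counter per distinct value of `scope`.
  def runningCounts[S](scope: ((Int, String)) => S): Seq[(Int, String, Int)] = {
    val counters = scala.collection.mutable.Map.empty[S, Int].withDefaultValue(0)
    routed.map { case rec @ (task, key) =>
      val s = scope(rec)
      counters(s) += 1
      (task, key, counters(s))
    }
  }

  // One counter per task: every key on the same task shares it.
  def perTask: Seq[(Int, String, Int)] = runningCounts[Int](_._1)

  // One counter per key: each key starts from zero.
  def perKey: Seq[(Int, String, Int)] = runningCounts[String](_._2)

  def main(args: Array[String]): Unit = {
    perTask.foreach { case (t, k, c) => println(s"$t> ($k,count=$c)") }
    perKey.foreach { case (t, k, c) => println(s"$t> ($k,$c)") }
  }
}
```

With the counter scoped to the task, the counts reproduce the first output (task 2 reaches 7 on H); with the counter scoped to the key, they reproduce the counts in the second output, although the real job interleaves the lines of different tasks.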

The full code:

import org.apache.flink.api.common.functions.RichFlatMapFunction
import org.apache.flink.streaming.api.scala._
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.checkpoint.Checkpointed
import org.apache.flink.util.Collector

object FlinkStreamingTest {

  def main(args: Array[String]): Unit = {

    val env = StreamExecutionEnvironment.createLocalEnvironment()

    val checkpointIntervalMillis = 10000
    env.enableCheckpointing(checkpointIntervalMillis)

    val inputData = env.fromElements(("A",0),("B",0),("C",0),("D",0),
      ("E",0),("E",0),("E",0),
      ("F",0),("G",0),("H",0))

    inputData.keyBy(0).flatMap(new TestCounters).print

    /*
    inputData.keyBy(0).flatMapWithState((keyAndCount: (String, Int), count: Option[Int]) =>
      count match {
        case None => (Iterator((keyAndCount._1, 1)), Some(1))
        case Some(c) => (Iterator((keyAndCount._1, c+1)), Some(c+1))
      }).print
    */

    env.execute("Counters test")
  }
}

case class CounterClass(var count: Int)

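// Keeps the counter in an instance field and snapshots it through Checkpointed,
// following the "Checkpointing Instance Fields" example.
class TestCounters extends RichFlatMapFunction[(String, Int), (String, String)] with Checkpointed[CounterClass] {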

  var counterValue: CounterClass = null

  override def flatMap(in: (String, Int), out: Collector[(String, String)]) = {
    counterValue.count = counterValue.count + 1
    out.collect((in._1,"count="+counterValue.count))
  }

  override def open(config: Configuration): Unit = {
    if(counterValue == null) {
      counterValue = new CounterClass(0)
    }
  }

  override def snapshotState(l: Long, l1: Long): CounterClass = {
    counterValue
  }

  override def restoreState(state: CounterClass): Unit = {
    counterValue = state
  }
}