The question is cross-posted on Stack Overflow https://stackoverflow.com/questions/67092978/pyflink-vectorized-udf-throws-nullpointerexception.
I have an ML model that takes two numpy.ndarray inputs, `users` and `items`, and returns a numpy.ndarray `predictions`. In normal Python code, I would do:

```python
model = load_model()
df = load_data()  # the DataFrame includes 4 columns, namely, user_id, movie_id, rating, and timestamp
users = df.user_id.values
items = df.movie_id.values
predictions = model(users, items)
```

I am looking into porting this code to Flink to leverage its distributed nature. My assumption is that by distributing the prediction workload across multiple Flink nodes, I should be able to run the whole prediction faster.

So I composed a PyFlink job. Note that I implement a UDF called `predict` to run the prediction.

```python
# batch_prediction.py
# Imports needed by this file (including the UDF snippets further down).
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.table import DataTypes, EnvironmentSettings, StreamTableEnvironment
from pyflink.table.udf import udf

model = load_model()

settings = EnvironmentSettings.new_instance().use_blink_planner().build()
exec_env = StreamExecutionEnvironment.get_execution_environment()
t_env = StreamTableEnvironment.create(exec_env, environment_settings=settings)

SOURCE_DDL = """
CREATE TABLE source (
    user_id INT,
    movie_id INT,
    rating TINYINT,
    event_ms BIGINT
) WITH (
    'connector' = 'filesystem',
    'format' = 'csv',
    'csv.field-delimiter' = '\t',
    'path' = 'ml-100k/u1.test'
)
"""

SINK_DDL = """
CREATE TABLE sink (
    prediction DOUBLE
) WITH (
    'connector' = 'print'
)
"""

t_env.execute_sql(SOURCE_DDL)
t_env.execute_sql(SINK_DDL)
t_env.execute_sql(
    "INSERT INTO sink SELECT PREDICT(user_id, movie_id) FROM source"
).wait()
```

Here is the UDF.

```python
# batch_prediction.py (cont)
@udf(result_type=DataTypes.DOUBLE())
def predict(user, item):
    return model([user], [item]).item()

t_env.create_temporary_function("predict", predict)
```

The job runs fine. However, the prediction actually runs on each and every row of the `source` table, which is not performant. Instead, I want to split the 80,000 (user_id, movie_id) pairs into, let's say, 100 batches, with each batch having 800 rows. The job would then trigger the `model(users, items)` function 100 times (once per batch), where both `users` and `items` have 800 elements.

I couldn't find a way to do this. Judging by the [docs](https://ci.apache.org/projects/flink/flink-docs-stable/dev/python/table-api-users-guide/udfs/vectorized_python_udfs.html), vectorized user-defined functions may work.

```python
# batch_prediction.py (snippet)
# I add the func_type="pandas"
@udf(result_type=DataTypes.DOUBLE(), func_type="pandas")
def predict(user, item):
    ...
```

Unfortunately, it doesn't.

```
> python batch_prediction.py
...
Traceback (most recent call last):
  File "batch_prediction.py", line 55, in <module>
    "INSERT INTO sink SELECT PREDICT(user_id, movie_id) FROM source"
  File "/usr/local/anaconda3/envs/flink-ml/lib/python3.7/site-packages/pyflink/table/table_result.py", line 76, in wait
    get_method(self._j_table_result, "await")()
  File "/usr/local/anaconda3/envs/flink-ml/lib/python3.7/site-packages/py4j/java_gateway.py", line 1286, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "/usr/local/anaconda3/envs/flink-ml/lib/python3.7/site-packages/pyflink/util/exceptions.py", line 147, in deco
    return f(*a, **kw)
  File "/usr/local/anaconda3/envs/flink-ml/lib/python3.7/site-packages/py4j/protocol.py", line 328, in get_return_value
    format(target_id, ".", name), value)
py4j.protocol.Py4JJavaError: An error occurred while calling o51.await.
: java.util.concurrent.ExecutionException: org.apache.flink.table.api.TableException: Failed to wait job finish
    at java.util.concurrent.CompletableFuture.reportGet(CompletableFuture.java:357)
    at java.util.concurrent.CompletableFuture.get(CompletableFuture.java:1908)
    at org.apache.flink.table.api.internal.TableResultImpl.awaitInternal(TableResultImpl.java:119)
    at org.apache.flink.table.api.internal.TableResultImpl.await(TableResultImpl.java:86)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at org.apache.flink.api.python.shaded.py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at org.apache.flink.api.python.shaded.py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at org.apache.flink.api.python.shaded.py4j.Gateway.invoke(Gateway.java:282)
    at org.apache.flink.api.python.shaded.py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at org.apache.flink.api.python.shaded.py4j.commands.CallCommand.execute(CallCommand.java:79)
    at org.apache.flink.api.python.shaded.py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.flink.table.api.TableException: Failed to wait job finish
    at org.apache.flink.table.api.internal.InsertResultIterator.hasNext(InsertResultIterator.java:59)
    at org.apache.flink.table.api.internal.TableResultImpl$CloseableRowIteratorWrapper.hasNext(TableResultImpl.java:355)
    at org.apache.flink.table.api.internal.TableResultImpl$CloseableRowIteratorWrapper.isFirstRowReady(TableResultImpl.java:368)
    at org.apache.flink.table.api.internal.TableResultImpl.lambda$awaitInternal$1(TableResultImpl.java:107)
    at java.util.concurrent.CompletableFuture$AsyncRun.run(CompletableFuture.java:1640)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    ... 1 more
Caused by: java.util.concurrent.ExecutionException: org.apache.flink.runtime.client.JobExecutionException: Job execution failed.
    at java.util.concurrent.CompletableFuture.reportGet(CompletableFuture.java:357)
    at java.util.concurrent.CompletableFuture.get(CompletableFuture.java:1908)
    at org.apache.flink.table.api.internal.InsertResultIterator.hasNext(InsertResultIterator.java:57)
    ... 7 more
Caused by: org.apache.flink.runtime.client.JobExecutionException: Job execution failed.
    at org.apache.flink.runtime.jobmaster.JobResult.toJobExecutionResult(JobResult.java:147)
    at org.apache.flink.runtime.minicluster.MiniClusterJobClient.lambda$getJobExecutionResult$2(MiniClusterJobClient.java:119)
    at java.util.concurrent.CompletableFuture.uniApply(CompletableFuture.java:616)
    at java.util.concurrent.CompletableFuture$UniApply.tryFire(CompletableFuture.java:591)
    at java.util.concurrent.CompletableFuture.postComplete(CompletableFuture.java:488)
    at java.util.concurrent.CompletableFuture.complete(CompletableFuture.java:1975)
    at org.apache.flink.runtime.rpc.akka.AkkaInvocationHandler.lambda$invokeRpc$0(AkkaInvocationHandler.java:229)
    at java.util.concurrent.CompletableFuture.uniWhenComplete(CompletableFuture.java:774)
    at java.util.concurrent.CompletableFuture$UniWhenComplete.tryFire(CompletableFuture.java:750)
    at java.util.concurrent.CompletableFuture.postComplete(CompletableFuture.java:488)
    at java.util.concurrent.CompletableFuture.complete(CompletableFuture.java:1975)
    at org.apache.flink.runtime.concurrent.FutureUtils$1.onComplete(FutureUtils.java:996)
    at akka.dispatch.OnComplete.internal(Future.scala:264)
    at akka.dispatch.OnComplete.internal(Future.scala:261)
    at akka.dispatch.japi$CallbackBridge.apply(Future.scala:191)
    at akka.dispatch.japi$CallbackBridge.apply(Future.scala:188)
    at scala.concurrent.impl.CallbackRunnable.run(Promise.scala:36)
    at org.apache.flink.runtime.concurrent.Executors$DirectExecutionContext.execute(Executors.java:74)
    at scala.concurrent.impl.CallbackRunnable.executeWithValue(Promise.scala:44)
    at scala.concurrent.impl.Promise$DefaultPromise.tryComplete(Promise.scala:252)
    at akka.pattern.PromiseActorRef.$bang(AskSupport.scala:572)
    at akka.pattern.PipeToSupport$PipeableFuture$$anonfun$pipeTo$1.applyOrElse(PipeToSupport.scala:22)
    at akka.pattern.PipeToSupport$PipeableFuture$$anonfun$pipeTo$1.applyOrElse(PipeToSupport.scala:21)
    at scala.concurrent.Future$$anonfun$andThen$1.apply(Future.scala:436)
    at scala.concurrent.Future$$anonfun$andThen$1.apply(Future.scala:435)
    at scala.concurrent.impl.CallbackRunnable.run(Promise.scala:36)
    at akka.dispatch.BatchingExecutor$AbstractBatch.processBatch(BatchingExecutor.scala:55)
    at akka.dispatch.BatchingExecutor$BlockableBatch$$anonfun$run$1.apply$mcV$sp(BatchingExecutor.scala:91)
    at akka.dispatch.BatchingExecutor$BlockableBatch$$anonfun$run$1.apply(BatchingExecutor.scala:91)
    at akka.dispatch.BatchingExecutor$BlockableBatch$$anonfun$run$1.apply(BatchingExecutor.scala:91)
    at scala.concurrent.BlockContext$.withBlockContext(BlockContext.scala:72)
    at akka.dispatch.BatchingExecutor$BlockableBatch.run(BatchingExecutor.scala:90)
    at akka.dispatch.TaskInvocation.run(AbstractDispatcher.scala:40)
    at akka.dispatch.ForkJoinExecutorConfigurator$AkkaForkJoinTask.exec(ForkJoinExecutorConfigurator.scala:44)
    at akka.dispatch.forkjoin.ForkJoinTask.doExec(ForkJoinTask.java:260)
    at akka.dispatch.forkjoin.ForkJoinPool$WorkQueue.runTask(ForkJoinPool.java:1339)
    at akka.dispatch.forkjoin.ForkJoinPool.runWorker(ForkJoinPool.java:1979)
    at akka.dispatch.forkjoin.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:107)
Caused by: org.apache.flink.runtime.JobException: Recovery is suppressed by NoRestartBackoffTimeStrategy
    at org.apache.flink.runtime.executiongraph.failover.flip1.ExecutionFailureHandler.handleFailure(ExecutionFailureHandler.java:116)
    at org.apache.flink.runtime.executiongraph.failover.flip1.ExecutionFailureHandler.getFailureHandlingResult(ExecutionFailureHandler.java:78)
    at org.apache.flink.runtime.scheduler.DefaultScheduler.handleTaskFailure(DefaultScheduler.java:224)
    at org.apache.flink.runtime.scheduler.DefaultScheduler.maybeHandleTaskFailure(DefaultScheduler.java:217)
    at org.apache.flink.runtime.scheduler.DefaultScheduler.updateTaskExecutionStateInternal(DefaultScheduler.java:208)
    at org.apache.flink.runtime.scheduler.SchedulerBase.updateTaskExecutionState(SchedulerBase.java:610)
    at org.apache.flink.runtime.scheduler.SchedulerNG.updateTaskExecutionState(SchedulerNG.java:89)
    at org.apache.flink.runtime.jobmaster.JobMaster.updateTaskExecutionState(JobMaster.java:419)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at org.apache.flink.runtime.rpc.akka.AkkaRpcActor.handleRpcInvocation(AkkaRpcActor.java:286)
    at org.apache.flink.runtime.rpc.akka.AkkaRpcActor.handleRpcMessage(AkkaRpcActor.java:201)
    at org.apache.flink.runtime.rpc.akka.FencedAkkaRpcActor.handleRpcMessage(FencedAkkaRpcActor.java:74)
    at org.apache.flink.runtime.rpc.akka.AkkaRpcActor.handleMessage(AkkaRpcActor.java:154)
    at akka.japi.pf.UnitCaseStatement.apply(CaseStatements.scala:26)
    at akka.japi.pf.UnitCaseStatement.apply(CaseStatements.scala:21)
    at scala.PartialFunction$class.applyOrElse(PartialFunction.scala:123)
    at akka.japi.pf.UnitCaseStatement.applyOrElse(CaseStatements.scala:21)
    at scala.PartialFunction$OrElse.applyOrElse(PartialFunction.scala:170)
    at scala.PartialFunction$OrElse.applyOrElse(PartialFunction.scala:171)
    at scala.PartialFunction$OrElse.applyOrElse(PartialFunction.scala:171)
    at akka.actor.Actor$class.aroundReceive(Actor.scala:517)
    at akka.actor.AbstractActor.aroundReceive(AbstractActor.scala:225)
    at akka.actor.ActorCell.receiveMessage(ActorCell.scala:592)
    at akka.actor.ActorCell.invoke(ActorCell.scala:561)
    at akka.dispatch.Mailbox.processMailbox(Mailbox.scala:258)
    at akka.dispatch.Mailbox.run(Mailbox.scala:225)
    at akka.dispatch.Mailbox.exec(Mailbox.scala:235)
    ... 4 more
Caused by: org.apache.flink.streaming.runtime.tasks.AsynchronousException: Caught exception while processing timer.
    at org.apache.flink.streaming.runtime.tasks.StreamTask$StreamTaskAsyncExceptionHandler.handleAsyncException(StreamTask.java:1108)
    at org.apache.flink.streaming.runtime.tasks.StreamTask.handleAsyncException(StreamTask.java:1082)
    at org.apache.flink.streaming.runtime.tasks.StreamTask.invokeProcessingTimeCallback(StreamTask.java:1213)
    at org.apache.flink.streaming.runtime.tasks.StreamTask.lambda$null$17(StreamTask.java:1202)
    at org.apache.flink.streaming.runtime.tasks.StreamTaskActionExecutor$SynchronizedStreamTaskActionExecutor.runThrowing(StreamTaskActionExecutor.java:92)
    at org.apache.flink.streaming.runtime.tasks.mailbox.Mail.run(Mail.java:78)
    at org.apache.flink.streaming.runtime.tasks.mailbox.MailboxExecutorImpl.tryYield(MailboxExecutorImpl.java:91)
    at org.apache.flink.streaming.runtime.tasks.StreamOperatorWrapper.quiesceTimeServiceAndCloseOperator(StreamOperatorWrapper.java:155)
    at org.apache.flink.streaming.runtime.tasks.StreamOperatorWrapper.close(StreamOperatorWrapper.java:130)
    at org.apache.flink.streaming.runtime.tasks.OperatorChain.closeOperators(OperatorChain.java:412)
    at org.apache.flink.streaming.runtime.tasks.StreamTask.afterInvoke(StreamTask.java:585)
    at org.apache.flink.streaming.runtime.tasks.StreamTask.invoke(StreamTask.java:547)
    at org.apache.flink.runtime.taskmanager.Task.doRun(Task.java:722)
    at org.apache.flink.runtime.taskmanager.Task.run(Task.java:547)
    at java.lang.Thread.run(Thread.java:748)
Caused by: TimerException{java.lang.RuntimeException: Failed to close remote bundle}
    ... 13 more
Caused by: java.lang.RuntimeException: Failed to close remote bundle
    at org.apache.flink.streaming.api.runners.python.beam.BeamPythonFunctionRunner.finishBundle(BeamPythonFunctionRunner.java:371)
    at org.apache.flink.streaming.api.runners.python.beam.BeamPythonFunctionRunner.flush(BeamPythonFunctionRunner.java:325)
    at org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator.invokeFinishBundle(AbstractPythonFunctionOperator.java:291)
    at org.apache.flink.table.runtime.operators.python.scalar.arrow.RowDataArrowPythonScalarFunctionOperator.invokeFinishBundle(RowDataArrowPythonScalarFunctionOperator.java:77)
    at org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator.checkInvokeFinishBundleByTime(AbstractPythonFunctionOperator.java:285)
    at org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator.lambda$open$0(AbstractPythonFunctionOperator.java:134)
    at org.apache.flink.streaming.runtime.tasks.StreamTask.invokeProcessingTimeCallback(StreamTask.java:1211)
    ... 12 more
Caused by: java.lang.NullPointerException
    at org.apache.flink.streaming.api.runners.python.beam.BeamPythonFunctionRunner.finishBundle(BeamPythonFunctionRunner.java:369)
    ... 18 more
```

The error messages are not very helpful. Can anyone help? Thanks!

Note: source code can be found [here](https://github.com/YikSanChan/flink-torch/tree/83ea0510172db3d7ff33db19883150f2fe5c1f43). To run the code, you will need Anaconda locally, then:

```
conda env create -f environment.yml
conda activate flink-ml
```

Best,
Yik San
|
Hi Yik San,

1) There are two kinds of Python UDFs in PyFlink:
- General Python UDFs, which process input elements on a per-row basis, i.e. one row at a time.
- Pandas UDFs, which process input elements on a per-batch basis.

So you are correct that you need to use a Pandas UDF for your requirements.

2) For a Pandas UDF, the input type of each argument is pandas.Series and the result type should also be pandas.Series. In addition, the length of the result should be the same as the length of the inputs.

Could you check whether this is the case for your Pandas UDF implementation?

Regards,
Dian

On Wed, Apr 14, 2021 at 9:44 PM Yik San Chan <[hidden email]> wrote:
|
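The Pandas UDF contract described in the reply above (pandas.Series in, pandas.Series of the same length out) can be checked locally without a Flink cluster. The following is only a minimal sketch in plain pandas and NumPy: `fake_model` is a hypothetical stand-in for the real model, and `predict_rowwise`, `predict_batch` and `satisfies_pandas_contract` are illustrative names, not PyFlink APIs. It contrasts a function body written for one row at a time with one written for a whole batch.

```python
import numpy as np
import pandas as pd


def fake_model(users, items):
    """Hypothetical stand-in for the real model: one score per (user, item) pair."""
    users = np.asarray(users, dtype=np.float64)
    items = np.asarray(items, dtype=np.float64)
    return users * 0.1 + items * 0.01


def predict_rowwise(user, item):
    # Body written for a general (row-based) UDF: wraps two scalars, returns a float.
    return fake_model([user], [item]).item()


def predict_batch(users, items):
    # Body written for a Pandas UDF: Series in, Series of equal length out.
    return pd.Series(fake_model(users.to_numpy(), items.to_numpy()))


def satisfies_pandas_contract(func, *columns):
    """Return True if func maps equal-length Series to a Series of that length."""
    try:
        result = func(*columns)
    except Exception:
        return False
    return isinstance(result, pd.Series) and len(result) == len(columns[0])


users = pd.Series([1, 2, 3])
items = pd.Series([1, 4, 9])
print(satisfies_pandas_contract(predict_rowwise, users, items))
print(satisfies_pandas_contract(predict_batch, users, items))
```

In this sketch the row-wise body fails the check when it is handed whole Series, while the batch body passes it. That matches the situation in the question, where only `func_type="pandas"` was added and the function body was left in its per-row form.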
Hi Dian,

Thanks for the reminder. Yes, the original UDF implementation does not satisfy the input and output type requirements. After adding a unit test, I was able to find what was wrong and fix my UDF implementation. Here is the new implementation, FYI.

    @udf(result_type=DataTypes.DOUBLE(), func_type="pandas")
    def predict(users, items):
        n_users, n_items = 943, 1682
        model = MatrixFactorization(n_users, n_items)
        model.load_state_dict(torch.load("model.pth"))
        return pd.Series(model(users, items).detach().numpy())

And here is the unit test.

    def test_predict():
        f = predict._func
        users = pd.Series([1, 2, 3])
        items = pd.Series([1, 4, 9])
        preds = f(users, items)
        assert isinstance(preds, pd.Series)
        assert len(preds) == 3

Thank you so much!

Best,
Yik San

On Wed, Apr 14, 2021 at 11:03 PM Dian Fu <[hidden email]> wrote:
|
Great! Thanks for letting me know~
|
Hi Dian,

I wonder if we can improve the error tracing and message so that it becomes more obvious where the problem is? To me, an NPE really says very little.

Best,
Yik San

On Thu, Apr 15, 2021 at 11:07 AM Dian Fu <[hidden email]> wrote:
|
Definitely agree with you. I have created https://issues.apache.org/jira/browse/FLINK-22297 as a follow-up.
|
Hi Dian,

Thank you so much for tracking the issue!

I ran into another NullPointerException when running a pandas UDF, even though this time I had already added a unit test to check the input and output types. The new issue looks even odder. Do you mind taking a look?

http://apache-flink-user-mailing-list-archive.2336050.n4.nabble.com/PyFlink-called-already-closed-and-NullPointerException-td42997.html

Thank you!

Best,
Yik San

On Fri, Apr 16, 2021 at 11:05 AM Dian Fu <[hidden email]> wrote:
|
Sure. I have replied. Let’s discuss it in that thread.
|