I am trying to create a user-defined aggregate function that I can call from Python. I tried to follow the answer to this question and basically implemented the following (taken from here):
package com.blu.bla;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.Row;

public class MySum extends UserDefinedAggregateFunction {
    private StructType _inputDataType;
    private StructType _bufferSchema;
    private DataType _returnDataType;

    public MySum() {
        List<StructField> inputFields = new ArrayList<StructField>();
        inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
        _inputDataType = DataTypes.createStructType(inputFields);

        List<StructField> bufferFields = new ArrayList<StructField>();
        bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
        _bufferSchema = DataTypes.createStructType(bufferFields);

        _returnDataType = DataTypes.DoubleType;
    }

    @Override
    public StructType inputSchema() {
        return _inputDataType;
    }

    @Override
    public StructType bufferSchema() {
        return _bufferSchema;
    }

    @Override
    public DataType dataType() {
        return _returnDataType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, null);
    }

    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        if (!input.isNullAt(0)) {
            if (buffer.isNullAt(0)) {
                buffer.update(0, input.getDouble(0));
            } else {
                Double newValue = input.getDouble(0) + buffer.getDouble(0);
                buffer.update(0, newValue);
            }
        }
    }

    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        if (!buffer2.isNullAt(0)) {
            if (buffer1.isNullAt(0)) {
                buffer1.update(0, buffer2.getDouble(0));
            } else {
                Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
                buffer1.update(0, newValue);
            }
        }
    }

    @Override
    public Object evaluate(Row buffer) {
        if (buffer.isNullAt(0)) {
            return null;
        } else {
            return buffer.getDouble(0);
        }
    }
}
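The class is just a reimplementation of the built-in sum, so the expected result is easy to check against the built-in aggregate (nothing custom here, only the behavior I am trying to replicate):

from pyspark.sql import functions as F

df = sqlCtx.createDataFrame([(1.0, "a"), (2.0, "b"), (3.0, "C")], ["A", "B"])
df.agg(F.sum("A")).show()  # expected result: 6.0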
Then I compiled it with all its dependencies into a jar and launched pyspark with --jars myjar.jar.
In pyspark, I did:
df = sqlCtx.createDataFrame([(1.0, "a"), (2.0, "b"), (3.0, "C")], ["A", "B"])

from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql import Row

def myCol(col):
    _f = sc._jvm.com.blu.bla.MySum.apply
    return Column(_f(_to_seq(sc, [col], _to_java_column)))

b = df.agg(myCol("A"))
I got the following error:
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-24-f45b2a367e67> in <module>()
----> 1 b = df.agg(myCol("A"))

<ipython-input-22-afcb8884e1db> in myCol(col)
      4 def myCol(col):
      5     _f = sc._jvm.com.blu.bla.MySum.apply
----> 6     return Column(_f(_to_seq(sc,[col], _to_java_column)))

TypeError: 'JavaPackage' object is not callable
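As far as I understand py4j, sc._jvm returns a JavaPackage placeholder for any dotted name it cannot resolve to an actual class, and calling a JavaPackage raises exactly this TypeError, so the error seems to mean the JVM never found MySum. A quick diagnostic (my own check, using the class name from above):

# Prints <class 'py4j.java_gateway.JavaClass'> if the class was found on the
# driver classpath; <class 'py4j.java_gateway.JavaPackage'> if it was not.
print(type(sc._jvm.com.blu.bla.MySum))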
I also tried passing the path to the jar via --driver-class-path on the pyspark call, but got the same result.
I also tried to access the Java class through java_import:
from py4j.java_gateway import java_import

jvm = sc._gateway.jvm
java_import(jvm, "com.blu.bla.MySum")

def myCol2(col):
    # after java_import, the class should be reachable directly on the JVM view
    _f = jvm.MySum.apply
    return Column(_f(_to_seq(sc, [col], _to_java_column)))
I also tried simply creating an instance of the class (as suggested here):
a = jvm.com.blu.bla.MySum()
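To rule out a problem with the py4j gateway itself, one sanity check (my own, not specific to my jar) is to instantiate a standard JDK class the same way:

# This should succeed regardless of my jar, since java.util is always on the
# JVM classpath; if it also fails, the gateway itself is broken.
lst = sc._gateway.jvm.java.util.ArrayList()
lst.add("test")
print(lst)  # [test]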
All of these attempts fail with the same error message. I can't figure out what the problem is.