spark-user mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From Yong Zhang <java8...@hotmail.com>
Subject Re: Spark dataframe, UserDefinedAggregateFunction(UDAF) help!!
Date Fri, 24 Mar 2017 01:00:57 GMT
Change:

val arrayinput = input.getAs[Array[String]](0)

to:

val arrayinput = input.getAs[Seq[String]](0)


Yong


________________________________
From: shyla deshpande <deshpandeshyla@gmail.com>
Sent: Thursday, March 23, 2017 8:18 PM
To: user
Subject: Spark dataframe, UserDefinedAggregateFunction(UDAF) help!!

This is my input data. The UDAF needs to aggregate the goals for a team and return a map that
 gives the count for every goal in the team.
I am getting the following error

java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to
[Ljava.lang.String;
at com.whil.common.GoalAggregator.update(GoalAggregator.scala:27)

+------+--------------+
|teamid|goals         |
+------+--------------+
|t1    |[Goal1, Goal2]|
|t1    |[Goal1, Goal3]|
|t2    |[Goal1, Goal2]|
|t3    |[Goal2, Goal3]|
+------+--------------+

root
 |-- teamid: string (nullable = true)
 |-- goals: array (nullable = true)
 |    |-- element: string (containsNull = true)

/////////////////////////Calling the UDAF//////////

object TestUDAF {
  def main(args: Array[String]): Unit = {

    val spark = SparkSession
      .builder
      .getOrCreate()

    val sc: SparkContext = spark.sparkContext
    val sqlContext = spark.sqlContext

    import sqlContext.implicits._

    val data = Seq(
      ("t1", Seq("Goal1", "Goal2")),
      ("t1", Seq("Goal1", "Goal3")),
      ("t2", Seq("Goal1", "Goal2")),
      ("t3", Seq("Goal2", "Goal3"))).toDF("teamid","goals")

    data.show(truncate = false)
    data.printSchema()

    import spark.implicits._

    val sumgoals = new GoalAggregator
    val result = data.groupBy("teamid").agg(sumgoals(col("goals")))

    result.show(truncate = false)

  }
}

///////////////UDAF/////////////////

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

class GoalAggregator extends UserDefinedAggregateFunction{

  override def inputSchema: org.apache.spark.sql.types.StructType =
  StructType(StructField("value", ArrayType(StringType)) :: Nil)

  override def bufferSchema: StructType = StructType(
      StructField("combined", MapType(StringType,IntegerType)) :: Nil
  )

  override def dataType: DataType = MapType(StringType,IntegerType)

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, Map[String, Integer]())
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val mapbuf = buffer.getAs[Map[String, Int]](0)
    val arrayinput = input.getAs[Array[String]](0)
    val result = mapbuf ++ arrayinput.map(goal => {
      val cnt  = mapbuf.get(goal).getOrElse(0) + 1
      goal -> cnt
    })
    buffer.update(0, result)
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[String, Int]](0)
    val map2 = buffer2.getAs[Map[String, Int]](0)
    val result = map1 ++ map2.map { case (k,v) =>
      val cnt = map1.get(k).getOrElse(0) + 1
      k -> cnt
    }
    buffer1.update(0, result)
  }

  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Map[String, Int]](0)
  }
}



Mime
View raw message