Unlike RDD's groupByKey, a Spark Dataset can reduce with groupBy, and the aggregation is expressed through .agg.

However, the function passed to .agg cannot be just any plain declaration; custom logic has to be implemented as a UDAF (UserDefinedAggregateFunction), and its behavior is somewhat tricky.
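For reference, built-in aggregates do not need a UDAF at all. A minimal sketch (assuming a DataFrame df with jobid and time columns; the names are illustrative):

import org.apache.spark.sql.functions.{min, max}

// Built-in functions plug straight into .agg:
val simple = df.groupBy("jobid").agg(min("time"), max("time"))

A UDAF is only needed when the reduction logic, like the error collection below, is not covered by the built-ins.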



In the shell (or in main), run it as follows:

// The $-column syntax requires: import spark.implicits._
val td = rd.select($"jobid".alias("jobid"), $"time".alias("starttime"), $"time".alias("endtime"), $"result".alias("report_result"), $"result".alias("error"))

import CustomAgg.CustomAggregation._
val minagg = new minAggregation()
val maxagg = new maxAggregation()
val erragg = new errorAggregation()

val td2 = td.groupBy("jobid").agg(minagg(td("starttime")), maxagg(td("endtime")) , erragg(td("error"))) 
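
The aggregated columns come back with auto-generated names; if you want stable names, alias them inline (same objects as above, just a hedged variant):

val td2 = td.groupBy("jobid").agg(
  minagg(td("starttime")).alias("starttime"),
  maxagg(td("endtime")).alias("endtime"),
  erragg(td("error")).alias("error"))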


The UDAFs themselves are defined as below (an agg that finds the max value, one that finds the min value, one that collects error values, and so on).

As far as I can tell, update performs the merge within a partition, while merge combines the partial buffers across partitions.
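
To make the call order concrete, here is a rough sketch of the lifecycle Spark drives for each group (the method names are the real UDAF hooks; the ordering is my reading of the observed behavior):

// Per partition:
//   initialize(buffer)        // fresh buffer for the group
//   update(buffer, row)       // called once per input row in that partition
// Across partitions:
//   merge(buffer1, buffer2)   // fold partial buffers into one
// Finally:
//   evaluate(buffer)          // produce the value for the result column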

package CustomAgg

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scala.collection.mutable.ArrayBuffer

object CustomAggregation {

  class maxAggregation extends UserDefinedAggregateFunction {
    // Input schema: a single string column.
    def inputSchema = new StructType().add("x", StringType)

    // Intermediate buffer schema: an array holding the current maximum.
    def bufferSchema = new StructType().add("buff", ArrayType(StringType))

    // Returned data type.
    def dataType: DataType = StringType

    // The same input always yields the same result.
    def deterministic = true

    // Called once per group to set up a fresh buffer.
    def initialize(buffer: MutableAggregationBuffer) = {
      buffer.update(0, ArrayBuffer.empty[String])
    }

    // Called for each row within a partition: keep the lexicographically larger value.
    def update(buffer: MutableAggregationBuffer, input: Row) = {
      if (!input.isNullAt(0)) {
        val current = buffer.getSeq[String](0).mkString
        val incoming = input.getString(0)
        if (current.isEmpty || current.compareTo(incoming) < 0)
          buffer.update(0, Seq(incoming))
      }
    }

    // Merge two partial aggregates from different partitions.
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      if (buffer1.getSeq[String](0).mkString.compareTo(buffer2.getSeq[String](0).mkString) < 0)
        buffer1.update(0, buffer2.getSeq[String](0))
    }

    def evaluate(buffer: Row) = buffer.getSeq[String](0).mkString(",")
  }

  class minAggregation extends UserDefinedAggregateFunction {
    // Input schema: a single string column.
    def inputSchema = new StructType().add("x", StringType)

    // Intermediate buffer schema: one string holding the current minimum ("" = none yet).
    def bufferSchema = new StructType().add("buff", StringType)

    // Returned data type.
    def dataType: DataType = StringType

    def deterministic = true

    // Called once per group to set up a fresh buffer.
    def initialize(buffer: MutableAggregationBuffer) = {
      buffer.update(0, "")
    }

    // Called for each row within a partition: keep the lexicographically smaller value.
    def update(buffer: MutableAggregationBuffer, input: Row) = {
      if (!input.isNullAt(0)) {
        val current = buffer.getString(0)
        val incoming = input.getString(0)
        if (current.isEmpty || current.compareTo(incoming) > 0)
          buffer.update(0, incoming)
      }
    }

    // Merge two partial aggregates; an empty string means that side saw no rows,
    // so it must never win the comparison.
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      val left = buffer1.getString(0)
      val right = buffer2.getString(0)
      if (right.nonEmpty && (left.isEmpty || left.compareTo(right) > 0))
        buffer1.update(0, right)
    }

    def evaluate(buffer: Row) = buffer.getString(0)
  }


  class errorAggregation extends UserDefinedAggregateFunction {
    // Input schema: a single string column.
    def inputSchema = new StructType().add("x", StringType)

    // Intermediate buffer schema: the error values collected so far.
    def bufferSchema = new StructType().add("buff", ArrayType(StringType))

    // Returned data type.
    def dataType: DataType = StringType

    def deterministic = true

    // Called once per group to set up a fresh, empty buffer.
    def initialize(buffer: MutableAggregationBuffer) = {
      buffer.update(0, Seq.empty[String])
    }

    // Anything that is not an Info/Success/empty marker counts as an error.
    def checkerror(str: String): Boolean =
      str != "Info" && str != "0" && str != "Success" && str.nonEmpty

    // Called for each row within a partition: collect every error value.
    def update(buffer: MutableAggregationBuffer, input: Row) = {
      if (!input.isNullAt(0) && checkerror(input.getString(0)))
        buffer.update(0, buffer.getSeq[String](0) :+ input.getString(0))
    }

    // Merge two partial aggregates by concatenating the collected errors.
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      buffer1.update(0, buffer1.getSeq[String](0) ++ buffer2.getSeq[String](0))
    }

    def evaluate(buffer: Row) = buffer.getSeq[String](0).mkString(",")
  }


  // Reference example (not used in the driver above): geometric mean of a numeric column.
  class GeometricMean extends UserDefinedAggregateFunction {
    // Input: a single numeric column.
    override def inputSchema: StructType =
      StructType(StructField("value", DoubleType) :: Nil)

    // Internal fields: running count and running product.
    override def bufferSchema: StructType = StructType(
      StructField("count", LongType) ::
      StructField("product", DoubleType) :: Nil
    )

    // Output type of the aggregation function.
    override def dataType: DataType = DoubleType

    override def deterministic: Boolean = true

    // Initial buffer values: zero items seen, empty product.
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L
      buffer(1) = 1.0
    }

    // Fold one input row into the buffer.
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      buffer(0) = buffer.getAs[Long](0) + 1
      buffer(1) = buffer.getAs[Double](1) * input.getAs[Double](0)
    }

    // Merge two partial buffers.
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0) = buffer1.getAs[Long](0) + buffer2.getAs[Long](0)
      buffer1(1) = buffer1.getAs[Double](1) * buffer2.getAs[Double](1)
    }

    // Final value: the count-th root of the product.
    override def evaluate(buffer: Row): Any =
      math.pow(buffer.getDouble(1), 1.toDouble / buffer.getLong(0))
  }

} 
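
For a quick end-to-end check, here is a hedged usage sketch with made-up data (it assumes a SparkSession named spark is in scope; the timestamps are strings, so min/max here mean lexicographic order, which works for this fixed format):

import spark.implicits._
import CustomAgg.CustomAggregation._

// Toy input: two jobs, one of which logged an error.
val rd = Seq(
  ("job1", "2019-01-01 10:00", "Success"),
  ("job1", "2019-01-01 10:05", "Timeout"),
  ("job2", "2019-01-01 09:00", "Info")
).toDF("jobid", "time", "result")

val td = rd.select($"jobid", $"time".alias("starttime"), $"time".alias("endtime"), $"result".alias("error"))

val minagg = new minAggregation()
val maxagg = new maxAggregation()
val erragg = new errorAggregation()

td.groupBy("jobid")
  .agg(minagg(td("starttime")), maxagg(td("endtime")), erragg(td("error")))
  .show(false)
// Expected: job1 -> starttime 10:00, endtime 10:05, error "Timeout"; job2 -> no error.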

