// Spark runtime version backing this notebook's session (reported as 2.4.5 below).
spark.sparkContext.version
2.4.5
import java.sql.Date
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
/** Counts how many rows following the first row of a group or window frame are "similar" to it.
  *
  * The first non-null `sms` value seen becomes the base string. Every later non-null value `s`
  * counts as a match when `similarity(base, s) > threshold`. Null inputs are skipped, as SQL
  * aggregates conventionally do.
  *
  * The buffer keeps a per-value tally of the rows after the base instead of a running match count.
  * This lets `merge` combine two partial buffers exactly: the rows of the second buffer are treated
  * as following those of the first. The tally grows with the number of distinct values in the group.
  *
  * @param threshold a row matches when its similarity to the base string is strictly greater than
  *                  this value; with the exact-equality `similarity` below, 0.0 counts equal strings
  */
class DiffCount(val threshold: Double) extends UserDefinedAggregateFunction {

  // A single string column is aggregated.
  override def inputSchema: StructType =
    StructType(StructField("sms", StringType) :: Nil)

  // count:       number of non-null rows seen so far
  // base_string: the first non-null row (meaningful only when count > 0)
  // tail_counts: occurrences of each value seen after base_string
  override def bufferSchema: StructType = StructType(
    StructField("count", LongType) ::
      StructField("base_string", StringType) ::
      StructField("tail_counts", MapType(StringType, LongType)) :: Nil
  )

  /** Similarity score between two strings: 1.0 when they are equal, 0.0 otherwise. */
  def similarity(str1: String, str2: String): Double =
    if (str1 == str2) 1.0 else 0.0

  // Single place where `threshold` decides whether a trailing row matches the base.
  private def isMatch(base: String, other: String): Boolean =
    similarity(base, other) > threshold

  // Returns `counts` with `n` more occurrences of `value`.
  private def addCount(counts: Map[String, Long], value: String, n: Long): Map[String, Long] =
    counts.updated(value, counts.getOrElse(value, 0L) + n)

  // Result: number of matching rows after the base row.
  override def dataType: DataType = LongType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = ""
    buffer(2) = Map.empty[String, Long]
  }

  // The first non-null row becomes the base; every later non-null row is tallied.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0)) {
      val value = input.getString(0)
      val seen = buffer.getLong(0)
      if (seen == 0L) buffer(1) = value
      else buffer(2) = addCount(buffer.getAs[Map[String, Long]](2), value, 1L)
      buffer(0) = seen + 1L
    }

  // Combines two partial buffers, treating buffer2's rows as coming after buffer1's rows.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val count2 = buffer2.getLong(0)
    if (count2 > 0L) {
      val count1 = buffer1.getLong(0)
      val tail2 = buffer2.getAs[Map[String, Long]](2)
      if (count1 == 0L) {
        // buffer1 has seen nothing yet: adopt buffer2 unchanged.
        buffer1(1) = buffer2.getString(1)
        buffer1(2) = tail2
      } else {
        // buffer2's base row is just another trailing row relative to buffer1's base.
        val withBase2 = addCount(buffer1.getAs[Map[String, Long]](2), buffer2.getString(1), 1L)
        buffer1(2) = tail2.foldLeft(withBase2) { case (acc, (value, n)) => addCount(acc, value, n) }
      }
      buffer1(0) = count1 + count2
    }
  }

  // Sums the tallies of all trailing values that match the base (0 for an empty buffer).
  override def evaluate(buffer: Row): Any = {
    val base = buffer.getString(1)
    buffer.getAs[Map[String, Long]](2).iterator.collect {
      case (value, n) if isMatch(base, value) => n
    }.sum
  }
}
defined class DiffCount
// Sample device names, one per row, in a single string column named "sms".
val devicesDf = Seq(
  "notebook",
  "notebook",
  "small phone",
  "camera",
  "small phone",
  "large phone",
  "camera",
  "small phone"
).toDF("sms")
devicesDf = [sms: string]
[sms: string]
// Aggregator counting, for each row, the later rows whose `sms` equals the current row's value.
val diff_count = new DiffCount(0.0)

// The "following rows" of a window frame are only well-defined under an explicit ordering;
// Window.rowsBetween alone leaves row order unspecified. Capture the input order in a temporary
// id column, order the frame by it, and drop the helper column before displaying.
// NOTE(review): with no partitionBy, Spark moves all rows into a single partition —
// acceptable for this small sample, not for large data.
val orderedFrame = Window.orderBy("row_id").rowsBetween(0, Window.unboundedFollowing)
devicesDf
  .withColumn("row_id", monotonically_increasing_id())
  .withColumn("diff_count", diff_count(col("sms")).over(orderedFrame))
  .drop("row_id")
  .show()
+-----------+----------+ | sms|diff_count| +-----------+----------+ | notebook| 1| | notebook| 0| |small phone| 2| | camera| 1| |small phone| 1| |large phone| 0| | camera| 0| |small phone| 0| +-----------+----------+
diff_count = DiffCount@51b29ee9
DiffCount@51b29ee9