Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NU-1921] Add median aggregator #7321

Open
wants to merge 10 commits into
base: add-standard-deviation-and-variance-aggregations
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
* [#7184](https://github.com/TouK/nussknacker/pull/7184) Improve Nu Designer API notifications endpoint, to include events related to currently displayed scenario
* [#7323](https://github.com/TouK/nussknacker/pull/7323) Improve Periodic DeploymentManager db queries
* [#7307](https://github.com/TouK/nussknacker/pull/7307) Added StddevPop, StddevSamp, VarPop and VarSamp aggregators
* [#7321](https://github.com/TouK/nussknacker/pull/7321) Added Median aggregator

## 1.18

Expand Down
1 change: 1 addition & 0 deletions docs/scenarios_authoring/AggregatesInTimeWindows.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ Let’s map the above statement on the parameters of the Nussknacker Aggregate c
* StddevSamp - computes sample standard deviation
* VarPop - computes population variance
* VarSamp - computes sample variance
* Median - computes median
jedrz marked this conversation as resolved.
Show resolved Hide resolved
* ApproximateSetCardinality - computes approximate cardinality of a set using [HyperLogLog](https://en.wikipedia.org/wiki/HyperLogLog) algorithm. Please note that this aggregator treats null as a unique value. If this is undesirable and the set passed to ApproximateSetCardinality aggregator contained null (this can be tested with safe navigation in [SpEL](./Spel.md#safe-navigation)), subtract 1 from the obtained result.

If you need to count events in a window, use the CountWhen aggregate function and aggregate by fixed `true` expression - see the table with examples below. Subsequent sections use the Count function on the diagrams as an example for the **aggregator** - it is the easiest function to use in the examples. Please note, however, that technically, we provide an indirect implementation of this aggregator.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,24 @@ import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedClass, TypedObjectTypingResult, TypingResult, Unknown}
import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.AggregatesSpec.{EPS_BIG_DECIMAL, EPS_DOUBLE}
import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.aggregates.{AverageAggregator, CountWhenAggregator, FirstAggregator, LastAggregator, ListAggregator, MapAggregator, MaxAggregator, MinAggregator, OptionAggregator, PopulationStandardDeviationAggregator, PopulationVarianceAggregator, SampleStandardDeviationAggregator, SampleVarianceAggregator, SetAggregator, SumAggregator}
import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.aggregates.{
AverageAggregator,
CountWhenAggregator,
FirstAggregator,
LastAggregator,
ListAggregator,
MapAggregator,
MaxAggregator,
MedianAggregator,
MinAggregator,
OptionAggregator,
PopulationStandardDeviationAggregator,
PopulationVarianceAggregator,
SampleStandardDeviationAggregator,
SampleVarianceAggregator,
SetAggregator,
SumAggregator
}
import pl.touk.nussknacker.engine.util.Implicits.RichScalaMap

import java.lang.{Integer => JInt, Long => JLong}
Expand Down Expand Up @@ -127,7 +144,8 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat
MinAggregator,
FirstAggregator,
LastAggregator,
SumAggregator
SumAggregator,
MedianAggregator
)) { agg =>
addElementsAndComputeResult(List(null), agg) shouldEqual null
}
Expand All @@ -148,6 +166,62 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat
}
}

test("should calculate correct results for median aggregator on integers") {
val agg = MedianAggregator
val result = addElementsAndComputeResult(List(7, 8), agg)
result shouldBe a[Double]
result shouldEqual 7.5
}

test("should calculate correct results for median aggregator on integers on single value") {
val agg = MedianAggregator
val result = addElementsAndComputeResult(List(7), agg)
result shouldBe a[Double]
result shouldEqual 7
}

test("should calculate correct results for median aggregator on BigInt") {
val agg = MedianAggregator
addElementsAndComputeResult(List(new BigInteger("7"), new BigInteger("8")), agg) shouldEqual new java.math.BigDecimal("7.5")
}

test("should calculate correct results for median aggregator on floats") {
val agg = MedianAggregator
val result = addElementsAndComputeResult(List(7.0f, 8.0f), agg)
result shouldBe a[Double]
result shouldEqual 7.5
}

test("should calculate correct results for median aggregator on BigDecimals") {
val agg = MedianAggregator
addElementsAndComputeResult(
List(new java.math.BigDecimal("7"), new java.math.BigDecimal("8")),
agg
) shouldEqual new java.math.BigDecimal("7.5")
}

test("should ignore nulls for median aggregator") {
val agg = MedianAggregator
addElementsAndComputeResult(
List(null, new java.math.BigDecimal("7"), null, new java.math.BigDecimal("8")),
agg
) shouldEqual new java.math.BigDecimal("7.5")
}

test("MedianAggregator test on odd length list") {
val agg = MedianAggregator
val result = addElementsAndComputeResult(List(80, 70, 3, 1, 4, 60, 2, 5, 90), agg)

result shouldEqual 5
}

test("MedianAggregator test on even length list") {
val agg = MedianAggregator
val result = addElementsAndComputeResult(List(80, 70, 3, 1, 4, 60, 2, 5), agg)

result shouldEqual 4.5
}

test("should calculate correct results for standard deviation and variance on integers") {
val table = Table(
("aggregator", "value"),
Expand Down Expand Up @@ -230,7 +304,8 @@ class AggregatesSpec extends AnyFunSuite with TableDrivenPropertyChecks with Mat
( SumAggregator, 15.0),
( MaxAggregator, 5.0),
( MinAggregator, 1.0),
( AverageAggregator, 3.0)
( AverageAggregator, 3.0),
( MedianAggregator, 3.0)
)

forAll(table) { (agg, expectedResult) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
validateOk("#AGG.varPop", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal])
validateOk("#AGG.varSamp", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal])

validateOk("#AGG.median", """#input.eId""", Typed[Double])
validateOk("#AGG.median", """1""", Typed[Double])
validateOk("#AGG.median", """1.5""", Typed[Double])

validateOk("#AGG.median", """T(java.math.BigInteger).ONE""", Typed[java.math.BigDecimal])
validateOk("#AGG.median", """T(java.math.BigDecimal).ONE""", Typed[java.math.BigDecimal])

validateOk("#AGG.set", "#input.str", Typed.fromDetailedType[java.util.Set[String]])
validateOk(
"#AGG.map({f1: #AGG.sum, f2: #AGG.set})",
Expand All @@ -106,6 +113,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
validateError("#AGG.sum", "#input.str", "Invalid aggregate type: String, should be: Number")
validateError("#AGG.countWhen", "#input.str", "Invalid aggregate type: String, should be: Boolean")
validateError("#AGG.average", "#input.str", "Invalid aggregate type: String, should be: Number")
validateError("#AGG.median", "#input.str", "Invalid aggregate type: String, should be: Number")

validateError("#AGG.stddevPop", "#input.str", "Invalid aggregate type: String, should be: Number")
validateError("#AGG.stddevSamp", "#input.str", "Invalid aggregate type: String, should be: Number")
Expand Down Expand Up @@ -170,6 +178,18 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
aggregateVariables shouldBe List(1.0d, 1.5, 3.5)
}

test("median aggregate") {
val id = "1"

val model =
modelData(List(TestRecordHours(id, 0, 1, "a"), TestRecordHours(id, 1, 2, "b"), TestRecordHours(id, 2, 5, "b")))
val testProcess = sliding("#AGG.median", "#input.eId", emitWhenEventLeft = false)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
aggregateVariables shouldBe List(1.0d, 1.5, 3.5)
}


test("standard deviation and average aggregates") {
val table = Table(
("aggregate", "secondValue"),
Expand Down Expand Up @@ -455,6 +475,19 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
aggregateVariables(1).asInstanceOf[Double].isNaN shouldBe true
}

test("emit aggregate for extra window when no data come for median aggregator for return type double") {
val id = "1"

val model =
modelData(List(TestRecordHours(id, 0, 1, "a")))
val testProcess = tumbling("#AGG.median", "#input.eId", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
aggregateVariables.length shouldEqual (2)
aggregateVariables(0) shouldEqual 1.0
require((aggregateVariables(1).asInstanceOf[Double].isNaN))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use matcher :)

}

test(
"emit aggregate for extra window when no data come for standard deviation and variance aggregator for return type double"
) {
Expand Down Expand Up @@ -492,6 +525,18 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
aggregateVariables shouldEqual List(new java.math.BigDecimal("1"), null)
}

test("emit aggregate for extra window when no data come for median aggregator for return type BigDecimal") {
val id = "1"

val model =
modelData(List(TestRecordHours(id, 0, 1, "a")))
val testProcess =
tumbling("#AGG.median", """T(java.math.BigDecimal).ONE""", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
aggregateVariables shouldEqual List(new java.math.BigDecimal("1"), null)
}

test(
"emit aggregate for extra window when no data come for standard deviation and variance aggregator for return type BigDecimal"
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public class AggregateHelper implements Serializable {
new FixedExpressionValue("#AGG.stddevSamp", "StddevSamp"),
new FixedExpressionValue("#AGG.varPop", "VarPop"),
new FixedExpressionValue("#AGG.varSamp", "VarSamp"),
new FixedExpressionValue("#AGG.median", "Median"),
new FixedExpressionValue("#AGG.min", "Min"),
new FixedExpressionValue("#AGG.max", "Max"),
new FixedExpressionValue("#AGG.sum", "Sum"),
Expand All @@ -54,6 +55,7 @@ public class AggregateHelper implements Serializable {
private static final Aggregator STDDEV_SAMP = aggregates.SampleStandardDeviationAggregator$.MODULE$;
private static final Aggregator VAR_POP = aggregates.PopulationVarianceAggregator$.MODULE$;
private static final Aggregator VAR_SAMP = aggregates.SampleVarianceAggregator$.MODULE$;
private static final Aggregator MEDIAN = aggregates.MedianAggregator$.MODULE$;
private static final Aggregator APPROX_CARDINALITY = HyperLogLogPlusAggregator$.MODULE$.apply(5, 10);

// Instance methods below are for purpose of using in SpEL so your IDE can report that they are not used.
Expand All @@ -80,6 +82,8 @@ public class AggregateHelper implements Serializable {
public Aggregator varPop = VAR_POP;
public Aggregator varSamp = VAR_SAMP;

public Aggregator median = MEDIAN;

public Aggregator approxCardinality = APPROX_CARDINALITY;

public Aggregator map(@ParamName("parts") Map<String, Aggregator> parts) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@ import cats.data.{NonEmptyList, Validated}
import cats.instances.list._
import org.apache.flink.api.common.typeinfo.TypeInfo
import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy
import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy.ForLargeFloatingNumbersOperation
import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy.{
ForLargeFloatingNumbersOperation,
}
import pl.touk.nussknacker.engine.api.typed.typing._
import pl.touk.nussknacker.engine.api.typed.{NumberTypeUtils, typing}
import pl.touk.nussknacker.engine.flink.api.typeinfo.caseclass.CaseClassTypeInfoFactory
import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.median.MedianHelper
import pl.touk.nussknacker.engine.util.Implicits._
import pl.touk.nussknacker.engine.util.MathUtils
import pl.touk.nussknacker.engine.util.validated.ValidatedSyntax._

import java.util
import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._

/*
Expand Down Expand Up @@ -70,6 +74,26 @@ object aggregates {

}

object MedianAggregator extends Aggregator with LargeFloatingNumberAggregate {

override type Aggregate = ListBuffer[Number]

override type Element = Number

override def zero: Aggregate = ListBuffer.empty

override def addElement(el: Element, agg: Aggregate): Aggregate = if (el == null) agg else agg.addOne(el)

override def mergeAggregates(agg1: Aggregate, agg2: Aggregate): Aggregate = agg1 ++ agg2

override def result(finalAggregate: Aggregate): AnyRef = MedianHelper.calculateMedian(finalAggregate.toList).orNull

override def computeStoredType(input: TypingResult): Validated[String, TypingResult] = Valid(
Typed.genericTypeClass[ListBuffer[_]](List(input))
)

}

object ListAggregator extends Aggregator {

override type Aggregate = List[AnyRef]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package pl.touk.nussknacker.engine.flink.util.transformer.aggregate.median

import pl.touk.nussknacker.engine.api.typed.supertype.NumberTypesPromotionStrategy.{
ForLargeFloatingNumbersOperation,
}
import pl.touk.nussknacker.engine.util.MathUtils

import scala.annotation.tailrec
import scala.util.Random

object MedianHelper {
private val rand = new Random(42)

def calculateMedian(numbers: List[Number]): Option[Number] = {
if (numbers.isEmpty) {
None
} else if (numbers.size % 2 == 1) {
Some(MathUtils.convertToPromotedType(quickSelect(numbers, (numbers.size - 1) / 2))(ForLargeFloatingNumbersOperation))
} else {
// it is possible to fetch both numbers with single recursion, but it would complicate code
val firstNumber = quickSelect(numbers, numbers.size / 2 - 1)
val secondNumber = quickSelect(numbers, numbers.size / 2)

val sum = MathUtils.largeFloatingSum(firstNumber, secondNumber)
Some(MathUtils.divideWithDefaultBigDecimalScale(sum, 2))
}
}

// https://en.wikipedia.org/wiki/Quickselect
@tailrec
private def quickSelect(numbers: List[Number], indexToTake: Int): Number = {
jedrz marked this conversation as resolved.
Show resolved Hide resolved
require(numbers.nonEmpty)
require(indexToTake >= 0)
require(indexToTake < numbers.size)

val randomElement = numbers(rand.nextInt(numbers.size))
val groupedBy = numbers.groupBy(e => {
val cmp = MathUtils.compare(e, randomElement)
if (cmp < 0) {
jedrz marked this conversation as resolved.
Show resolved Hide resolved
-1
} else if (cmp == 0) {
0
} else 1
})
val smallerNumbers = groupedBy.getOrElse(-1, Nil)
val equalNumbers = groupedBy.getOrElse(0, Nil)
val largerNumbers = groupedBy.getOrElse(1, Nil)

if (indexToTake < smallerNumbers.size) {
quickSelect(smallerNumbers, indexToTake)
} else if (indexToTake < smallerNumbers.size + equalNumbers.size) {
equalNumbers(indexToTake - smallerNumbers.size)
} else {
quickSelect(largerNumbers, indexToTake - smallerNumbers.size - equalNumbers.size)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ object sampleTransformers {
new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"),
new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"),
new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"),
new LabeledExpression(label = "Median", expression = "#AGG.median"),
new LabeledExpression(label = "List", expression = "#AGG.list"),
new LabeledExpression(label = "Set", expression = "#AGG.set"),
new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality")
Expand Down Expand Up @@ -102,6 +103,7 @@ object sampleTransformers {
new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"),
new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"),
new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"),
new LabeledExpression(label = "Median", expression = "#AGG.median"),
new LabeledExpression(label = "List", expression = "#AGG.list"),
new LabeledExpression(label = "Set", expression = "#AGG.set"),
new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality")
Expand Down Expand Up @@ -163,6 +165,7 @@ object sampleTransformers {
new LabeledExpression(label = "StddevSamp", expression = "#AGG.stddevSamp"),
new LabeledExpression(label = "VarPop", expression = "#AGG.varPop"),
new LabeledExpression(label = "VarSamp", expression = "#AGG.varSamp"),
new LabeledExpression(label = "Median", expression = "#AGG.median"),
new LabeledExpression(label = "List", expression = "#AGG.list"),
new LabeledExpression(label = "Set", expression = "#AGG.set"),
new LabeledExpression(label = "ApproximateSetCardinality", expression = "#AGG.approxCardinality")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14930,6 +14930,15 @@
}
}
],
"median": [
{
"name": "median",
"signature": {
"noVarArgs": [],
"result": {"refClazzName": "pl.touk.nussknacker.engine.flink.util.transformer.aggregate.Aggregator"}
}
}
],
"first": [
{
"name": "first",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ trait MathUtils {
case n1: java.math.BigDecimal => n1.negate()
}

private def compare(n1: Number, n2: Number): Int = {
jedrz marked this conversation as resolved.
Show resolved Hide resolved
@Hidden
def compare(n1: Number, n2: Number): Int = {
withValuesWithTheSameType(n1, n2)(new SameNumericTypeHandler[Int] {
override def onBytes(n1: java.lang.Byte, n2: java.lang.Byte): Int = n1.compareTo(n2)
override def onShorts(n1: java.lang.Short, n2: java.lang.Short): Int = n1.compareTo(n2)
Expand Down Expand Up @@ -285,7 +286,8 @@ trait MathUtils {
}
}

private def convertToPromotedType(
@Hidden
def convertToPromotedType(
n: Number
)(implicit promotionStrategy: ReturningSingleClassPromotionStrategy): Number = {
// In some cases type can be promoted to other class e.g. Byte is promoted to Int for sum
Expand Down
Loading