Skip to content

Commit

Permalink
Upgrade ChiSqSelectorModel to spark 3.2.0 compatable design
Browse files Browse the repository at this point in the history
sort filterIndiecs before using it
  • Loading branch information
austinzh committed Apr 27, 2022
1 parent d47d0b4 commit 5ce228d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,23 @@ import scala.collection.mutable
/**
* Created by hollinwilkins on 12/27/16.
*/
@SparkCode(uri = "https://github.com/apache/spark/blob/v2.0.0/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala")
@SparkCode(uri = "https://github.com/apache/spark/blob/v3.2.0/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala")
case class ChiSqSelectorModel(filterIndices: Seq[Int],
inputSize: Int) extends Model {
private val sortedFilterIndices = filterIndices.sorted
def apply(features: Vector): Vector = {
features match {
case SparseVector(size, indices, values) =>
val newSize = filterIndices.length
val newSize = sortedFilterIndices.length
val newValues = mutable.ArrayBuilder.make[Double]
val newIndices = mutable.ArrayBuilder.make[Int]
var i = 0
var j = 0
var indicesIdx = 0
var filterIndicesIdx = 0
while (i < indices.length && j < filterIndices.length) {
while (i < indices.length && j < sortedFilterIndices.length) {
indicesIdx = indices(i)
filterIndicesIdx = filterIndices(j)
filterIndicesIdx = sortedFilterIndices(j)
if (indicesIdx == filterIndicesIdx) {
newIndices += j
newValues += values(i)
Expand All @@ -43,7 +44,7 @@ case class ChiSqSelectorModel(filterIndices: Seq[Int],
Vectors.sparse(newSize, newIndices.result(), newValues.result())
case DenseVector(values) =>
val values = features.toArray
Vectors.dense(filterIndices.map(i => values(i)).toArray)
Vectors.dense(sortedFilterIndices.map(i => values(i)).toArray)
case other =>
throw new UnsupportedOperationException(
s"Only sparse and dense vectors are supported but got ${other.getClass}.")
Expand All @@ -52,5 +53,5 @@ case class ChiSqSelectorModel(filterIndices: Seq[Int],

override def inputSchema: StructType = StructType("input" -> TensorType.Double(inputSize)).get

override def outputSchema: StructType = StructType("output" -> TensorType.Double(filterIndices.length)).get
override def outputSchema: StructType = StructType("output" -> TensorType.Double(sortedFilterIndices.length)).get
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,22 @@ package ml.combust.mleap.core.feature

import ml.combust.mleap.core.types.{StructField, TensorType}
import org.scalatest.FunSpec
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}

class ChiSqSelectorModelSpec extends FunSpec {

describe("input/output schema"){
val model = new ChiSqSelectorModel(Seq(1,2,3), 3)
val model = new ChiSqSelectorModel(Seq(2,3, 1), 3)

it("Dense vector work with unsorted indices") {
val vector = Vectors.dense(1.0,2.0,3.0,4.0)
assert(model(vector) == Vectors.dense(2.0, 3.0, 4.0))
}

it("Sparse vector work with unsorted indices") {
val vector = Vectors.sparse(size = 4, indices=Array(0,1,2,3), values = Array(1.0,2.0,3.0,4.0))
assert(model(vector) == Vectors.sparse(size=3, indices=Array(0,1,2), values=Array(2.0,3.0,4.0)))
}

it("Has the right input schema") {
assert(model.inputSchema.fields ==
Expand Down

0 comments on commit 5ce228d

Please sign in to comment.