Skip to content

Commit 01b96cf

Browse files
authored
Merge pull request #127 from metarank/fix/libsvm-label-format
handle doubles as label ids in libsvm format
2 parents af05a86 + 776e8ba commit 01b96cf

File tree

6 files changed

+21
-14
lines changed

6 files changed

+21
-14
lines changed

build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import Deps._
22

33
name := "ltrlib"
44

5-
version := "0.1.23-M5"
5+
version := "0.1.23-M7"
66

77
scalaVersion := "2.13.10"
88

src/main/scala/io/github/metarank/ltrlib/booster/CatboostOptions.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@ case class CatboostOptions(
1212
randomSeed: Int = math.abs(Random.nextInt()),
1313
objective: String = "QueryRMSE",
1414
loggingLevel: String = "Verbose",
15-
earlyStopping: Option[Int] = None
15+
earlyStopping: Option[Int] = None,
16+
boostingType: String = "Plain"
1617
) extends BoosterOptions

src/main/scala/io/github/metarank/ltrlib/input/LibsvmInputFormat.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ object LibsvmInputFormat extends InputFormat {
6262
throw new IllegalArgumentException(
6363
s"LibSVM format requires at least two columns: label and qid, but got ${tokens.length} on row $index"
6464
)
65-
val label = tokens(0).toInt
65+
val label = tokens(0).toDouble
6666
val qid = tokens(1) match {
6767
case queryPattern(_, id) => id.toInt
6868
case _ => throw new IllegalArgumentException(s"qid format for item '${tokens(1)}' is not supported on row $index")
@@ -88,6 +88,6 @@ object LibsvmInputFormat extends InputFormat {
8888
values(featureIndex - 1) = value // libsvm indexing starts from 1
8989
i += 1
9090
}
91-
new LabeledItem(label, qid, values)
91+
LabeledItem(label, qid, values)
9292
}
9393
}

src/main/scala/io/github/metarank/ltrlib/output/LibSVMOutputFormat.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ object LibSVMOutputFormat extends OutputFormat {
1010
query <- ds.groups
1111
rowid <- 0 until query.rows
1212
} {
13-
val row = query.getRow(rowid).zipWithIndex.filter(_._1 != 0).map(x => s"${x._2 + offset}:${x._1}")
14-
val line = s"${query.labels(rowid)} qid:${query.group} ${row.mkString(" ")}\n"
13+
val features = query.getRow(rowid).zipWithIndex.filter(_._1 != 0).map(x => s"${x._2 + offset}:${x._1}")
14+
val label = math.round(query.labels(rowid)).toString
15+
val line = (List(label, s"qid:${query.group}") ++ features).mkString("", " ", "\n")
1516
data.write(line.getBytes())
1617
}
1718
}
@@ -24,8 +25,9 @@ object LibSVMOutputFormat extends OutputFormat {
2425
for {
2526
rowid <- 0 until query.rows
2627
} {
27-
val row = query.getRow(rowid).zipWithIndex.filter(_._1 != 0).map(x => s"${x._2}:${x._1}")
28-
val line = s"${query.labels(rowid)} ${row.mkString(" ")}\n"
28+
val features = query.getRow(rowid).zipWithIndex.filter(_._1 != 0).map(x => s"${x._2}:${x._1}").toList
29+
val label = math.round(query.labels(rowid)).toString
30+
val line = (List(label) ++ features).mkString("", " ", "\n")
2931
data.write(line.getBytes())
3032
}
3133
}

src/test/scala/io/github/metarank/ltrlib/input/LibsvmInputFormatTest.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ class LibsvmInputFormatTest extends AnyFlatSpec with Matchers {
1111
parse1("1 qid:1 1:1", LabeledItem(1, 1, Array(1.0)))
1212
}
1313

14+
it should "load 1 feature with double label and qid" in {
15+
parse1("1.0 qid:1 1:1", LabeledItem(1, 1, Array(1.0)))
16+
}
17+
1418
it should "fail on nan" in {
1519
Try(LibsvmInputFormat.parseLine(1, "1 qid:1 1:NaN")).isFailure shouldBe true
1620
}

src/test/scala/io/github/metarank/ltrlib/output/LibSVMOutputFormatTest.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ class LibSVMOutputFormatTest extends AnyFlatSpec with Matchers {
2323
LibSVMOutputFormat.write(out, ds)
2424
val str = new String(out.toByteArray)
2525
str shouldBe
26-
"""1.0 qid:1 0:1.0 1:2.0
27-
|0.0 qid:1
28-
|1.0 qid:2 0:1.0 1:NaN
26+
"""1 qid:1 0:1.0 1:2.0
27+
|0 qid:1
28+
|1 qid:2 0:1.0 1:NaN
2929
|""".stripMargin
3030
}
3131

@@ -36,9 +36,9 @@ class LibSVMOutputFormatTest extends AnyFlatSpec with Matchers {
3636
val str = new String(data.toByteArray)
3737
val gstr = new String(groups.toByteArray)
3838
str shouldBe
39-
"""1.0 0:1.0 1:2.0
40-
|0.0
41-
|1.0 0:1.0 1:NaN
39+
"""1 0:1.0 1:2.0
40+
|0
41+
|1 0:1.0 1:NaN
4242
|""".stripMargin
4343
gstr shouldBe "2\n1\n"
4444
}

0 commit comments

Comments
 (0)