Skip to content

Commit 669e073

Browse files
authored
Merge pull request #128 from metarank/fix/libsvm-feature-index
libsvm format assumes feature index numbering from 1
2 parents (01b96cf + f745e96), commit 669e073

File tree

3 files changed: 8 additions (+8) and 8 deletions (-8).

src/main/scala/io/github/metarank/ltrlib/booster/CatboostBooster.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ object CatboostBooster extends BoosterFactory[String, CatboostBooster, CatboostO
5151
override def formatData(ds: BoosterDataset, parent: Option[String]): String = {
5252
val file = File.newTemporaryFile("catboost-", ".svm")
5353
val stream = file.newFileOutputStream()
54-
LibSVMOutputFormat.write(stream, ds.original, 1)
54+
LibSVMOutputFormat.write(stream, ds.original)
5555
stream.close()
5656
s"libsvm://$file"
5757
}

src/main/scala/io/github/metarank/ltrlib/output/LibSVMOutputFormat.scala

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,12 +5,12 @@ import io.github.metarank.ltrlib.model.Dataset
55
import java.io.OutputStream
66

77
object LibSVMOutputFormat extends OutputFormat {
8-
def write(data: OutputStream, ds: Dataset, offset: Int = 0) = {
8+
def write(data: OutputStream, ds: Dataset) = {
99
for {
1010
query <- ds.groups
1111
rowid <- 0 until query.rows
1212
} {
13-
val features = query.getRow(rowid).zipWithIndex.filter(_._1 != 0).map(x => s"${x._2 + offset}:${x._1}")
13+
val features = query.getRow(rowid).zipWithIndex.filter(_._1 != 0).map(x => s"${x._2 + 1}:${x._1}")
1414
val label = math.round(query.labels(rowid)).toString
1515
val line = (List(label, s"qid:${query.group}") ++ features).mkString("", " ", "\n")
1616
data.write(line.getBytes())
@@ -25,7 +25,7 @@ object LibSVMOutputFormat extends OutputFormat {
2525
for {
2626
rowid <- 0 until query.rows
2727
} {
28-
val features = query.getRow(rowid).zipWithIndex.filter(_._1 != 0).map(x => s"${x._2}:${x._1}").toList
28+
val features = query.getRow(rowid).zipWithIndex.filter(_._1 != 0).map(x => s"${x._2 + 1}:${x._1}").toList
2929
val label = math.round(query.labels(rowid)).toString
3030
val line = (List(label) ++ features).mkString("", " ", "\n")
3131
data.write(line.getBytes())

src/test/scala/io/github/metarank/ltrlib/output/LibSVMOutputFormatTest.scala

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -23,9 +23,9 @@ class LibSVMOutputFormatTest extends AnyFlatSpec with Matchers {
2323
LibSVMOutputFormat.write(out, ds)
2424
val str = new String(out.toByteArray)
2525
str shouldBe
26-
"""1 qid:1 0:1.0 1:2.0
26+
"""1 qid:1 1:1.0 2:2.0
2727
|0 qid:1
28-
|1 qid:2 0:1.0 1:NaN
28+
|1 qid:2 1:1.0 2:NaN
2929
|""".stripMargin
3030
}
3131

@@ -36,9 +36,9 @@ class LibSVMOutputFormatTest extends AnyFlatSpec with Matchers {
3636
val str = new String(data.toByteArray)
3737
val gstr = new String(groups.toByteArray)
3838
str shouldBe
39-
"""1 0:1.0 1:2.0
39+
"""1 1:1.0 2:2.0
4040
|0
41-
|1 0:1.0 1:NaN
41+
|1 1:1.0 2:NaN
4242
|""".stripMargin
4343
gstr shouldBe "2\n1\n"
4444
}

0 commit comments

Comments
 (0)