Skip to content

Commit ddd5707

Browse files
committed
support libsvm export
1 parent 44ba9c9 commit ddd5707

File tree

7 files changed

+95
-17
lines changed

7 files changed

+95
-17
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@ human-friendly API. Currently, is under an active development.
1818

1919
libLTR is published to maven-central for scala 2.12 and 2.13, so for SBT, add this snippet to `build.sbt`:
2020
```scala
21-
libraryDependencies += "io.github.metarank" %% "ltrlib" % "0.1.15"
21+
libraryDependencies += "io.github.metarank" %% "ltrlib" % "0.1.16"
2222
```
2323

2424
For maven:
2525
```xml
2626
<dependency>
2727
<groupId>io.github.metarank</groupId>
2828
<artifactId>ltrlib_2.13</artifactId>
29-
<version>0.1.15</version>
29+
<version>0.1.16</version>
3030
</dependency>
3131
```
3232
## Usage

build.sbt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ import Deps._
22

33
name := "ltrlib"
44

5-
version := "0.1.15"
5+
version := "0.1.16"
66

7-
scalaVersion := "2.12.17"
7+
scalaVersion := "2.13.10"
88

99
crossScalaVersions := List("2.13.10", "2.12.17")
1010

@@ -14,7 +14,7 @@ Test / logBuffered := false
1414

1515
Test / parallelExecution := false
1616

17-
scalacOptions ++= Seq("-feature", "-deprecation", "-target:jvm-1.8")
17+
scalacOptions ++= Seq("-feature", "-deprecation", "-release:8")
1818

1919
libraryDependencies ++= Seq(
2020
"org.scalatest" %% "scalatest" % scalatestVersion % Test,
@@ -28,7 +28,7 @@ libraryDependencies ++= Seq(
2828
"io.github.metarank" %% "cfor" % "0.2",
2929
"io.github.metarank" % "lightgbm4j" % "3.3.2-2",
3030
"io.github.metarank" % "xgboost-java" % "1.6.1-2",
31-
"com.opencsv" % "opencsv" % "5.7.0",
31+
"com.opencsv" % "opencsv" % "5.7.1",
3232
"org.scala-lang.modules" %% "scala-collection-compat" % "2.8.1"
3333
)
3434

project/build.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
sbt.version = 1.7.2
1+
sbt.version = 1.7.3

src/main/scala/io/github/metarank/ltrlib/output/CSVOutputFormat.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
package io.github.metarank.ltrlib.output
22

3-
import com.opencsv.CSVWriter
3+
import com.opencsv.{CSVWriter, CSVWriterBuilder}
44
import io.github.metarank.ltrlib.model.{Dataset, DatasetDescriptor, Feature, Query}
55

66
import java.io.{OutputStream, OutputStreamWriter}
77

88
object CSVOutputFormat extends OutputFormat {
9-
def write(stream: OutputStream, data: Dataset) = {
9+
def write(stream: OutputStream, data: Dataset, header: Boolean) = {
1010
val writer = new CSVWriter(new OutputStreamWriter(stream))
11-
writer.writeNext(writeHeader(data.desc).toArray)
11+
if (header) writer.writeNext(writeHeader(data.desc).toArray, false)
1212
for {
1313
query <- data.groups
1414
line <- writeGroup(query)
1515
} {
16-
writer.writeNext(line.toArray)
16+
writer.writeNext(line.toArray, false)
1717
}
1818
writer.close()
1919
logger.debug(s"wrote ${data.groups.size} groups to CSV file")
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package io.github.metarank.ltrlib.output
2+
3+
import io.github.metarank.ltrlib.model.Dataset
4+
5+
import java.io.OutputStream
6+
7+
object LibSVMOutputFormat extends OutputFormat {
8+
def write(data: OutputStream, ds: Dataset) = {
9+
for {
10+
query <- ds.groups
11+
rowid <- 0 until query.rows
12+
} {
13+
val row = query.getRow(rowid).zipWithIndex.filter(_._1 != 0).map(x => s"${x._2}:${x._1}")
14+
val line = s"${query.labels(rowid)} qid:${query.group} ${row.mkString(" ")}\n"
15+
data.write(line.getBytes())
16+
}
17+
}
18+
19+
def write(data: OutputStream, groups: OutputStream, ds: Dataset) = {
20+
for {
21+
query <- ds.groups
22+
} {
23+
groups.write(s"${query.rows}\n".getBytes())
24+
for {
25+
rowid <- 0 until query.rows
26+
} {
27+
val row = query.getRow(rowid).zipWithIndex.filter(_._1 != 0).map(x => s"${x._2}:${x._1}")
28+
val line = s"${query.labels(rowid)} ${row.mkString(" ")}\n"
29+
data.write(line.getBytes())
30+
}
31+
}
32+
}
33+
}

src/test/scala/io/github/metarank/ltrlib/output/CSVOutputFormatTest.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@ class CSVOutputFormatTest extends AnyFlatSpec with Matchers {
1212
val desc = DatasetDescriptor(List(SingularFeature("f1"), SingularFeature("f2")))
1313
val ds = Dataset(desc, List(Query(desc, List(LabeledItem(1, 1, Array(1.0, 2.0))))))
1414
val out = new ByteArrayOutputStream()
15-
CSVOutputFormat.write(out, ds)
15+
CSVOutputFormat.write(out, ds, true)
1616
val str = new String(out.toByteArray)
17-
str shouldBe """"label","group","f1","f2"
18-
|"1.0","1","1.0","2.0"
17+
str shouldBe """label,group,f1,f2
18+
|1.0,1,1.0,2.0
1919
|""".stripMargin
2020
}
2121

2222
it should "export CSV with vectors" in {
2323
val desc = DatasetDescriptor(List(SingularFeature("f1"), VectorFeature("f2", 2)))
2424
val ds = Dataset(desc, List(Query(desc, List(LabeledItem(1, 1, Array(1.0, 2.0, 3.0))))))
2525
val out = new ByteArrayOutputStream()
26-
CSVOutputFormat.write(out, ds)
26+
CSVOutputFormat.write(out, ds, true)
2727
val str = new String(out.toByteArray)
28-
str shouldBe """"label","group","f1","f2_0","f2_1"
29-
|"1.0","1","1.0","2.0","3.0"
28+
str shouldBe """label,group,f1,f2_0,f2_1
29+
|1.0,1,1.0,2.0,3.0
3030
|""".stripMargin
3131
}
3232
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package io.github.metarank.ltrlib.output
2+
3+
import io.github.metarank.ltrlib.model.{Dataset, DatasetDescriptor, LabeledItem, Query}
4+
import io.github.metarank.ltrlib.model.Feature.SingularFeature
5+
import org.scalatest.flatspec.AnyFlatSpec
6+
import org.scalatest.matchers.should.Matchers
7+
8+
import java.io.ByteArrayOutputStream
9+
import scala.collection.immutable.List
10+
11+
class LibSVMOutputFormatTest extends AnyFlatSpec with Matchers {
12+
val desc = DatasetDescriptor(List(SingularFeature("f1"), SingularFeature("f2")))
13+
val ds = Dataset(
14+
desc,
15+
List(
16+
Query(desc, List(LabeledItem(1, 1, Array(1.0, 2.0)), LabeledItem(0, 1, Array(0.0, 0.0)))),
17+
Query(desc, List(LabeledItem(1, 2, Array(1.0, 2.0))))
18+
)
19+
)
20+
21+
it should "export with qid label" in {
22+
val out = new ByteArrayOutputStream()
23+
LibSVMOutputFormat.write(out, ds)
24+
val str = new String(out.toByteArray)
25+
str shouldBe
26+
"""1.0 qid:1 0:1.0 1:2.0
27+
|0.0 qid:1
28+
|1.0 qid:2 0:1.0 1:2.0
29+
|""".stripMargin
30+
}
31+
32+
it should "export with separate groups file" in {
33+
val data = new ByteArrayOutputStream()
34+
val groups = new ByteArrayOutputStream()
35+
LibSVMOutputFormat.write(data, groups, ds)
36+
val str = new String(data.toByteArray)
37+
val gstr = new String(groups.toByteArray)
38+
str shouldBe
39+
"""1.0 0:1.0 1:2.0
40+
|0.0
41+
|1.0 0:1.0 1:2.0
42+
|""".stripMargin
43+
gstr shouldBe "2\n1\n"
44+
}
45+
}

0 commit comments

Comments
 (0)