Skip to content

Commit f61ebd1

Browse files
authored
Merge pull request #120 from metarank/feature/catboost-options
more options to the catboost booster
2 parents 26b2640 + 38ed15a commit f61ebd1

File tree

4 files changed

+9
-5
lines changed

4 files changed

+9
-5
lines changed

src/main/scala/io/github/metarank/ltrlib/booster/CatboostBooster.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,15 @@ object CatboostBooster extends BoosterFactory[String, CatboostBooster, CatboostO
6464
val modelFile = dir.createChild("model.bin")
6565
val opts = Map(
6666
"--learn-set" -> dataset,
67-
"--loss-function" -> "QueryRMSE",
67+
"--loss-function" -> options.objective,
6868
"--eval-metric" -> s"NDCG:top=${options.ndcgCutoff}",
6969
"--iterations" -> options.trees.toString,
7070
"--depth" -> options.maxDepth.toString,
7171
"--learning-rate" -> options.learningRate.toString,
7272
"--train-dir" -> dir.toString(),
73-
"--model-file" -> modelFile.toString()
73+
"--model-file" -> modelFile.toString(),
74+
"--logging-level" -> "Silent",
75+
"--random-seed" -> options.randomSeed.toString
7476
) ++ test.map(t => Map("--test-set" -> t)).getOrElse(Map.empty)
7577
native_impl.ModeFitImpl(new TVector_TString(opts.flatMap(kv => List(kv._1, kv._2)).toArray))
7678
val bytes = modelFile.byteArray

src/main/scala/io/github/metarank/ltrlib/booster/CatboostOptions.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,7 @@ case class CatboostOptions(
99
learningRate: Double = 0.1,
1010
ndcgCutoff: Int = 10,
1111
maxDepth: Int = 8,
12-
randomSeed: Int = Random.nextInt()
12+
randomSeed: Int = math.abs(Random.nextInt()),
13+
objective: String = "QueryRMSE",
14+
loggingLevel: String = "Verbose"
1315
) extends BoosterOptions

src/main/scala/io/github/metarank/ltrlib/booster/LightGBMOptions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ case class LightGBMOptions(
88
learningRate: Double = 0.1,
99
ndcgCutoff: Int = 10,
1010
maxDepth: Int = 8,
11-
randomSeed: Int = Random.nextInt(),
11+
randomSeed: Int = math.abs(Random.nextInt()),
1212
numLeaves: Int = 16,
1313
featureFraction: Double = 1.0
1414
) extends BoosterOptions

src/main/scala/io/github/metarank/ltrlib/booster/XGBoostOptions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@ case class XGBoostOptions(
99
learningRate: Double = 0.1,
1010
ndcgCutoff: Int = 10,
1111
maxDepth: Int = 8,
12-
randomSeed: Int = Random.nextInt(),
12+
randomSeed: Int = math.abs(Random.nextInt()),
1313
subsample: Double = 1.0
1414
) extends BoosterOptions

0 commit comments

Comments
 (0)