Skip to content

Commit c47f7d1

Browse files
authored
fix(learner): printer uses cli (#428)
1 parent a257499 commit c47f7d1

File tree

4 files changed

+34
-3
lines changed

4 files changed

+34
-3
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ Depends:
4646
R (>= 3.5.0)
4747
Imports:
4848
backports,
49+
cli,
4950
checkmate (>= 2.2.0),
5051
data.table,
5152
lgr,

R/LearnerTorch.R

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,11 @@ LearnerTorch = R6Class("LearnerTorch",
262262
#' Currently unused.
263263
print = function(...) {
264264
super$print(...)
265-
catn(str_indent("* Optimizer:", private$.optimizer$id))
266-
catn(str_indent("* Loss:", private$.loss$id))
267-
catn(str_indent("* Callbacks:", if (length(private$.callbacks)) as_short_string(paste0(ids(private$.callbacks), collapse = ","), 1000L) else "-"))
265+
mlr3misc::cat_cli({
266+
cli::cli_li("Optimizer: {private$.optimizer$id}")
267+
cli::cli_li("Loss: {private$.loss$id}")
268+
cli::cli_li(paste0("Callbacks: ", if (length(private$.callbacks)) as_short_string(paste0(ids(private$.callbacks), collapse = ","), 1000L) else "-"))
269+
})
268270
},
269271
#' @description
270272
#' Marshal the learner.

tests/testthat/_snaps/LearnerTorch.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# printer
2+
3+
Code
4+
lrn("classif.mlp", callbacks = list(t_clbk("history"), t_clbk("progress")))
5+
Output
6+
7+
-- <LearnerTorchMLP> (classif.mlp): Multi Layer Perceptron ---------------------
8+
* Model: -
9+
* Parameters: device=auto, num_threads=1, num_interop_threads=1, seed=random,
10+
eval_freq=1, measures_train=<list>, measures_valid=<list>, patience=0,
11+
min_delta=0, shuffle=TRUE, tensor_dataset=FALSE, jit_trace=FALSE,
12+
neurons=integer(0), p=0.5, activation=<nn_relu>, activation_args=<list>
13+
* Validate: NULL
14+
* Packages: mlr3, mlr3torch, torch, and progress
15+
* Predict Types: [response] and prob
16+
* Feature Types: integer, numeric, and lazy_tensor
17+
* Encapsulation: none (fallback: -)
18+
* Properties: internal_tuning, marshal, multiclass, twoclass, and validation
19+
* Other settings: use_weights = 'error'
20+
* Optimizer: adam
21+
* Loss: cross_entropy
22+
* Callbacks: history,progress
23+

tests/testthat/test_LearnerTorch.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,3 +1042,8 @@ test_that("NA prediction during validation does not cause issues.", {
10421042
is.na(learner$model$callbacks$history$valid.regr.mse[1L]),
10431043
)
10441044
})
1045+
1046+
test_that("printer", {
1047+
expect_snapshot(lrn("classif.mlp",
1048+
callbacks = list(t_clbk("history"), t_clbk("progress"))))
1049+
})

0 commit comments

Comments
 (0)