Skip to content

Commit 4f43857

Browse files
committed
fix: po-block works with latest pipelines release
1 parent a68f008 commit 4f43857

File tree

4 files changed

+14
-3
lines changed

4 files changed

+14
-3
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@ CRAN-SUBMISSION
1717
paper/data
1818
.idea/
1919
.vsc/
20-
paper/data
20+
paper/data
21+
.vscode/

R/PipeOpTorchBlock.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,9 @@ PipeOpTorchBlock = R6Class("PipeOpTorchBlock",
146146
block = private$.block$clone(deep = TRUE)
147147
graph = private$.make_graph(block, param_vals$n_blocks)
148148
inputs = set_names(inputs, graph$input$name)
149-
graph$train(inputs, single_input = FALSE)
149+
out = graph$train(inputs, single_input = FALSE)
150+
self$state = map(out, "pointer_shape")
151+
return(out)
150152
},
151153
.param_set_base = NULL,
152154
.additional_phash_input = function() {

tests/testthat/helper_autotest.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ expect_pipeop_torch = function(graph, id, task, module_class = id, exclude_args
2626
require_namespaces(c("testthat"))
2727
po_test = graph$pipeops[[id]]
2828
result = graph$train(task)
29+
expect_true(graph$is_trained)
2930
md = result[[1]]
3031

3132
# PipeOp overwrites hash and phash if necessary
@@ -321,7 +322,7 @@ expect_torch_callback = function(torch_callback, check_man = TRUE, check_paramse
321322

322323
return(implemented_stages)
323324
}
324-
325+
325326
implemented_stages = get_all_implemented_stages(cbgen)
326327
expect_subset(implemented_stages, mlr_reflections$torch$callback_stages)
327328
expect_true(length(implemented_stages) > 0)

tests/testthat/test_PipeOpTorchBlock.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,10 @@ test_that("trafo works", {
111111
expect_class(network$module_list$`0`$bias, "torch_tensor")
112112
expect_true(is.null(network$module_list$`1`$bias))
113113
})
114+
115+
test_that("state", {
116+
md = po("torch_ingress_num")$train(list(tsk("iris")))
117+
block = nn("block", nn("relu"), n_blocks = 1L)
118+
block$train(md)
119+
expect_true(block$is_trained)
120+
})

0 commit comments

Comments
 (0)