Skip to content

Commit

Permalink
Merge pull request #506 from JuliaTrustworthyAI/505-sort-out-converge…
Browse files Browse the repository at this point in the history
…nce-for-flattenedce

hopefully this will do it?
  • Loading branch information
pat-alt authored Dec 31, 2024
2 parents 3961fd4 + 418534c commit 3c7b76f
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 17 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),

*Note*: We try to adhere to these practices as of version [v1.1.1].

## Version [1.4.2] - 2024-12-31

### Changed

- Slight change to `FlattenedCE` and `unflatten` to ensure that basic functionality remains intact. [#505]

## Version [1.4.1] - 2024-12-19

### Changed
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "CounterfactualExplanations"
uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
authors = ["Patrick Altmeyer <[email protected]> and contributors"]
version = "1.4.1"
version = "1.4.2"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
2 changes: 1 addition & 1 deletion src/counterfactuals/core_struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ function CounterfactualExplanation(
x,
data,
M,
deepcopy(generator),
generator,
nothing,
convergence,
num_counterfactuals,
Expand Down
31 changes: 23 additions & 8 deletions src/counterfactuals/flatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,32 @@ A flattened representation of a `CounterfactualExplanation`, containing only the
struct FlattenedCE <: AbstractCounterfactualExplanation
factual::AbstractArray
target::RawTargetType
counterfactual_state::AbstractArray
counterfactual::AbstractArray
search::Dict
end

"""
(ce::CounterfactualExplanation)()::FlattenedCE
Calling the `ce::CounterfactualExplanation` object results in a [`FlattenedCE`](@ref) instance, which is the flattened version of the original.
"""
(ce::CounterfactualExplanation)()::FlattenedCE =
FlattenedCE(ce.factual, ce.target, ce.counterfactual)
function (ce::CounterfactualExplanation)(; store_path::Bool=false)::FlattenedCE
search_dict = ce.search
if !store_path
search_dict[:path] = nothing
end
return FlattenedCE(
ce.factual, ce.target, ce.counterfactual_state, ce.counterfactual, search_dict
)
end

"""
flatten(ce::CounterfactualExplanation)
Alias for `(ce::CounterfactualExplanation)()`. Converts a `CounterfactualExplanation` to its flattened form.
"""
flatten(ce::CounterfactualExplanation) = ce()
flatten(ce::CounterfactualExplanation; kwrgs...) = ce(; kwrgs...)

function unflatten(
flat_ce::FlattenedCE,
Expand All @@ -32,16 +41,22 @@ function unflatten(
initialization::Symbol=:add_perturbation,
convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
)::CounterfactualExplanation
return CounterfactualExplanation(
ce = CounterfactualExplanation(
flat_ce.factual,
flat_ce.target,
target_encoded(flat_ce, data),
flat_ce.counterfactual_state,
flat_ce.counterfactual,
data,
M,
generator;
initialization=initialization,
convergence=convergence,
num_counterfactuals=size(flat_ce.counterfactual, 2),
generator,
flat_ce.search,
get_convergence_type(convergence, data.y_levels),
size(flat_ce.counterfactual, 2),
initialization,
)
adjust_shape!(ce)
return ce
end

"""
Expand Down
3 changes: 1 addition & 2 deletions src/counterfactuals/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ A convenience method that adjusts the dimensions of the counterfactual state and
function adjust_shape!(ce::CounterfactualExplanation)

# Dimensionality:
x = deepcopy(ce.factual)
counterfactual_state = adjust_shape(ce, x) # augment to account for specified number of counterfactuals
counterfactual_state = adjust_shape(ce, ce.counterfactual) # augment to account for specified number of counterfactuals
ce.counterfactual_state = counterfactual_state
target_encoded = ce.target_encoded
ce.target_encoded = adjust_shape(ce, target_encoded)
Expand Down
5 changes: 2 additions & 3 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,18 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[compat]
Aqua = "0.8"
BenchmarkTools = "1.5.0"
CausalInference = "0.17.0"
CausalInference = "0.17, 0.18"
Chain = "0.6.0"
CompatHelperLocal = "0.1.26"
DataFrames = "1.6.1"
DecisionTree = "0.12.4"
EnergySamplers = "1.0"
Flux = "0.12, 0.13, 0.14"
Flux = "0.12, 0.13, 0.14, 0.15, 0.16"
JointEnergyModels = "0.1.7"
LaplaceRedux = "1.1.0"
MLDatasets = "0.7.17"
MLJBase = "1.7.0"
MLJDecisionTreeInterface = "0.4.2"
MLJFlux = "0.5, 0.6"
MLJModels = "0.15, 0.16, 0.17"
MLUtils = "0.4.4"
MultivariateStats = "0.10.3"
Expand Down
19 changes: 17 additions & 2 deletions test/counterfactuals/generate_counterfactual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,22 @@ using CounterfactualExplanations.Objectives: distance_mad
flat_ce = CounterfactualExplanations.flatten(ce)
@test flat_ce isa FlattenedCE
target_encoded(flat_ce, ce.data)
_ce = unflatten(flat_ce, ce.data, ce.M, ce.generator)
@test _ce isa CounterfactualExplanation

@testset "Unflattened" begin
_ce = unflatten(flat_ce, ce.data, ce.M, ce.generator)
@test _ce isa CounterfactualExplanation
@test converged(ce) == converged(_ce)
@test ce.x′ == _ce.x′
@test ce.s′ == _ce.s′
@test ce.factual == _ce.factual
@test ce.counterfactual == _ce.counterfactual
@test ce.target == _ce.target
@test ce.data == _ce.data
@test ce.M == _ce.M
@test ce.generator == _ce.generator
@test ce.counterfactual_state == _ce.counterfactual_state
@test ce.target_encoded == _ce.target_encoded
@test converged(ce) == converged(_ce)
end
end
end

0 comments on commit 3c7b76f

Please sign in to comment.