Skip to content

Commit

Permalink
feat(predictor): Add support for denylisting accounts
Browse files Browse the repository at this point in the history
A training data set often includes accounts that should never be predicted
(e.g., accounts used for manual reconciliation of AR/AP). This feature lets the
user denylist such accounts: any transaction with a posting to a denylisted
account is excluded from training, which prevents contamination of the training
set without having to maintain a separate filtered copy of the transactions.

NB: Currently, all CI runs are broken because of beancount 3.0 hitting PyPI.
Manual testing with beancount<3.0.0 shows all tests passing.
  • Loading branch information
hlieberman committed Aug 14, 2024
1 parent 2354891 commit f25bf29
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
4 changes: 4 additions & 0 deletions smart_importer/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,12 @@ def __init__(
predict=True,
overwrite=False,
string_tokenizer: Callable[[str], list] | None = None,
denylist_accounts: list[str] = []
):
super().__init__()
self.training_data = None
self.open_accounts: dict[str, str] = {}
self.denylist_accounts = denylist_accounts
self.pipeline: Pipeline | None = None
self.is_fitted = False
self.lock = threading.Lock()
Expand Down Expand Up @@ -133,6 +135,8 @@ def training_data_filter(self, txn):
for pos in txn.postings:
if pos.account not in self.open_accounts:
return False
if pos.account in self.denylist_accounts:
return False
if self.account == pos.account:
found_import_account = True
return found_import_account or not self.account
Expand Down
14 changes: 13 additions & 1 deletion tests/predictors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
2017-01-13 * "Gas Quick"
Assets:US:BofA:Checking -17.45 USD
2017-01-14 * "Axe Throwing with Joe"
Assets:US:BofA:Checking -13.37 USD
"""
)

Expand All @@ -43,6 +46,7 @@
2016-01-01 open Expenses:Auto:Gas USD
2016-01-01 open Expenses:Food:Groceries USD
2016-01-01 open Expenses:Food:Restaurant USD
2016-01-01 open Expenses:Denylisted USD
2016-01-06 * "Farmer Fresh" "Buying groceries"
Assets:US:BofA:Checking -2.50 USD
Expand Down Expand Up @@ -93,6 +97,11 @@
2016-01-12 * "Gas Quick"
Assets:US:BofA:Checking -24.09 USD
Expenses:Auto:Gas
2016-01-08 * "Axe Throwing with Joe"
Assets:US:BofA:Checking -38.36 USD
Expenses:Denylisted
"""
)

Expand All @@ -105,6 +114,7 @@
"Gimme Coffee",
"Uncle Boons",
None,
None,
]

ACCOUNT_PREDICTIONS = [
Expand All @@ -116,8 +126,10 @@
"Expenses:Food:Coffee",
"Expenses:Food:Groceries",
"Expenses:Auto:Gas",
"Expenses:Food:Groceries",
]

DENYLISTED_ACCOUNTS = ["Expenses:Denylisted"]

class BasicTestImporter(ImporterProtocol):
def extract(self, file, existing_entries=None):
Expand All @@ -133,7 +145,7 @@ def file_account(self, file):


PAYEE_IMPORTER = apply_hooks(BasicTestImporter(), [PredictPayees()])
POSTING_IMPORTER = apply_hooks(BasicTestImporter(), [PredictPostings()])
POSTING_IMPORTER = apply_hooks(BasicTestImporter(), [PredictPostings(denylist_accounts=DENYLISTED_ACCOUNTS)])


def test_empty_training_data():
Expand Down

0 comments on commit f25bf29

Please sign in to comment.