Skip to content

Commit

Permalink
Fix fcase() segfault (#6452) (#6451)
Browse files Browse the repository at this point in the history
* More descriptive variable names for fcase counting variables

* Fix fcase() segfault (#6452)
  • Loading branch information
MichaelChirico authored Sep 7, 2024
1 parent 0409294 commit b1c6bf3
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 31 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ rowwiseDT(

5. Queries like `DT[, min(x):max(x)]` now work as expected, i.e. the same as `DT[, seq(min(x), max(x))]` or `with(DT, min(x):max(x))`, [#2069](https://github.com/Rdatatable/data.table/issues/2069). Shorthand like `DT[, a:b]` meaning "select from columns `a` through `b`" still works. Thanks to @franknarf1 for reporting and @jangorecki for the fix.

6. Fixed a segfault in `fcase()`, [#6448](https://github.com/Rdatatable/data.table/issues/6448). Thanks @ethanbsmith for reporting with reprex, @aitap for finding the root cause, and @MichaelChirico for the PR.

## NOTES

1. Tests run again when some Suggests packages are missing, [#6411](https://github.com/Rdatatable/data.table/issues/6411). Thanks @aadler for the note and @MichaelChirico for the fix.
Expand Down
11 changes: 11 additions & 0 deletions inst/tests/tests.Rraw
Original file line number Diff line number Diff line change
Expand Up @@ -19182,3 +19182,14 @@ test(2285.09, merge(merge(y, x), z), data.table(a=3L, key="a"))
test(2285.10, merge(merge(y, z), x), data.table(a=3L, key="a"))
test(2285.11, merge(merge(z, x), y), data.table(a=3L, key="a"))
test(2285.12, merge(merge(z, y), x), data.table(a=3L, key="a"))

# ensure proper PROTECT() within fcase, #6448
x <- 1:3
test(2286,
fcase(
x<2, structure(list(1), class = "foo"),
x<3, structure(list(2), class = "foo"),
# Force gc() and some allocations which have a good chance at landing in the region that was earlier left unprotected
{ gc(full = TRUE); replicate(10, FALSE); x<4 },
`attr<-`(list(3), "class", "foo")),
structure(list(1, 2, 3), class = "foo"))
68 changes: 37 additions & 31 deletions src/fifelse.c
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,16 @@ SEXP fcaseR(SEXP rho, SEXP args) {
"Note that the default argument must be named explicitly, e.g., default=0"), narg - 2);
}
int nprotect=0, l;
int64_t len0=0, len1=0, len2=0;
SEXP ans=R_NilValue, value0=R_NilValue, tracker=R_NilValue, whens=R_NilValue, thens=R_NilValue;
int64_t n_ans=0, n_this_arg=0, n_undecided=0;
SEXP ans=R_NilValue, tracker=R_NilValue, whens=R_NilValue, thens=R_NilValue;
SEXP ans_class, ans_levels;
PROTECT_INDEX Iwhens, Ithens;
PROTECT_WITH_INDEX(whens, &Iwhens); nprotect++;
PROTECT_WITH_INDEX(thens, &Ithens); nprotect++;
SEXPTYPE type0=NILSXP;
SEXPTYPE ans_type=NILSXP;
// naout means if the output is scalar logic na
bool imask = true, naout = false, idefault = false;
bool ans_is_factor;
int *restrict p = NULL;
const int n = narg/2;
for (int i=0; i<n; ++i) {
Expand All @@ -238,35 +240,39 @@ SEXP fcaseR(SEXP rho, SEXP args) {
const int *restrict pwhens = LOGICAL(whens);
l = 0;
if (i == 0) {
len0 = xlength(whens);
len2 = len0;
type0 = TYPEOF(thens);
value0 = thens;
ans = PROTECT(allocVector(type0, len0)); nprotect++;
n_ans = xlength(whens);
n_undecided = n_ans;
ans_type = TYPEOF(thens);
ans_class = PROTECT(getAttrib(thens, R_ClassSymbol)); nprotect++;
ans_is_factor = isFactor(thens);
if (ans_is_factor) {
ans_levels = PROTECT(getAttrib(thens, R_LevelsSymbol)); nprotect++;
}
ans = PROTECT(allocVector(ans_type, n_ans)); nprotect++;
copyMostAttrib(thens, ans);
tracker = PROTECT(allocVector(INTSXP, len0)); nprotect++;
tracker = PROTECT(allocVector(INTSXP, n_ans)); nprotect++;
p = INTEGER(tracker);
} else {
imask = false;
naout = xlength(thens) == 1 && TYPEOF(thens) == LGLSXP && LOGICAL(thens)[0]==NA_LOGICAL;
if (xlength(whens) != len0) {
if (xlength(whens) != n_ans) {
// no need to check `idefault` here because the con for default is always `TRUE`
error(_("Argument #%d has length %lld which differs from that of argument #1 (%lld). "
"Please make sure all logical conditions have the same length."),
i*2+1, (long long)xlength(whens), (long long)len0);
i*2+1, (long long)xlength(whens), (long long)n_ans);
}
if (!naout && TYPEOF(thens) != type0) {
if (!naout && TYPEOF(thens) != ans_type) {
if (idefault) {
error(_("Resulting value is of type %s but 'default' is of type %s. "
"Please make sure that both arguments have the same type."), type2char(type0), type2char(TYPEOF(thens)));
"Please make sure that both arguments have the same type."), type2char(ans_type), type2char(TYPEOF(thens)));
} else {
error(_("Argument #%d is of type %s, however argument #2 is of type %s. "
"Please make sure all output values have the same type."),
i*2+2, type2char(TYPEOF(thens)), type2char(type0));
i*2+2, type2char(TYPEOF(thens)), type2char(ans_type));
}
}
if (!naout) {
if (!R_compute_identical(PROTECT(getAttrib(value0, R_ClassSymbol)), PROTECT(getAttrib(thens, R_ClassSymbol)), 0)) {
if (!R_compute_identical(ans_class, PROTECT(getAttrib(thens, R_ClassSymbol)), 0)) {
if (idefault) {
error(_("Resulting value has different class than 'default'. "
"Please make sure that both arguments have the same class."));
Expand All @@ -275,35 +281,35 @@ SEXP fcaseR(SEXP rho, SEXP args) {
"Please make sure all output values have the same class."), i*2+2);
}
}
UNPROTECT(2); // class(value0), class(thens)
UNPROTECT(1); // class(thens)
}
if (!naout && isFactor(value0)) {
if (!R_compute_identical(PROTECT(getAttrib(value0, R_LevelsSymbol)), PROTECT(getAttrib(thens, R_LevelsSymbol)), 0)) {
if (!naout && ans_is_factor) {
if (!R_compute_identical(ans_levels, PROTECT(getAttrib(thens, R_LevelsSymbol)), 0)) {
if (idefault) {
error(_("Resulting value and 'default' are both type factor but their levels are different."));
} else {
error(_("Argument #2 and argument #%d are both factor but their levels are different."), i*2+2);
}
}
UNPROTECT(2); // levels(value0), levels(thens)
UNPROTECT(1); // levels(thens)
}
}
len1 = xlength(thens);
if (len1 != len0 && len1 != 1) {
n_this_arg = xlength(thens);
if (n_this_arg != n_ans && n_this_arg != 1) {
if (idefault) {
error(_("Length of 'default' must be 1 or %lld."), (long long)len0);
error(_("Length of 'default' must be 1 or %lld."), (long long)n_ans);
} else {
error(_("Length of output value #%d (%lld) must either be 1 or match the length of the logical condition (%lld)."), i*2+2, (long long)len1, (long long)len0);
error(_("Length of output value #%d (%lld) must either be 1 or match the length of the logical condition (%lld)."), i*2+2, (long long)n_this_arg, (long long)n_ans);
}
}
int64_t thenMask = len1>1 ? INT64_MAX : 0;
int64_t thenMask = n_this_arg>1 ? INT64_MAX : 0;
switch(TYPEOF(ans)) {
case LGLSXP: {
const int *restrict pthens;
if (!naout) pthens = LOGICAL(thens); // the content is not useful if out is NA_LOGICAL scalar
int *restrict pans = LOGICAL(ans);
const int pna = NA_LOGICAL;
for (int64_t j=0; j<len2; ++j) {
for (int64_t j=0; j<n_undecided; ++j) {
const int64_t idx = imask ? j : p[j];
if (pwhens[idx]==1) {
pans[idx] = naout ? pna : pthens[idx & thenMask];
Expand All @@ -320,7 +326,7 @@ SEXP fcaseR(SEXP rho, SEXP args) {
if (!naout) pthens = INTEGER(thens); // the content is not useful if out is NA_LOGICAL scalar
int *restrict pans = INTEGER(ans);
const int pna = NA_INTEGER;
for (int64_t j=0; j<len2; ++j) {
for (int64_t j=0; j<n_undecided; ++j) {
const int64_t idx = imask ? j : p[j];
if (pwhens[idx]==1) {
pans[idx] = naout ? pna : pthens[idx & thenMask];
Expand All @@ -338,7 +344,7 @@ SEXP fcaseR(SEXP rho, SEXP args) {
double *restrict pans = REAL(ans);
const double na_double = INHERITS(ans, char_integer64) ? NA_INT64_D : NA_REAL;
const double pna = na_double;
for (int64_t j=0; j<len2; ++j) {
for (int64_t j=0; j<n_undecided; ++j) {
const int64_t idx = imask ? j : p[j];
if (pwhens[idx]==1) {
pans[idx] = naout ? pna : pthens[idx & thenMask];
Expand All @@ -355,7 +361,7 @@ SEXP fcaseR(SEXP rho, SEXP args) {
if (!naout) pthens = COMPLEX(thens); // the content is not useful if out is NA_LOGICAL scalar
Rcomplex *restrict pans = COMPLEX(ans);
const Rcomplex pna = NA_CPLX;
for (int64_t j=0; j<len2; ++j) {
for (int64_t j=0; j<n_undecided; ++j) {
const int64_t idx = imask ? j : p[j];
if (pwhens[idx]==1) {
pans[idx] = naout ? pna : pthens[idx & thenMask];
Expand All @@ -371,7 +377,7 @@ SEXP fcaseR(SEXP rho, SEXP args) {
const SEXP *restrict pthens=NULL;
if (!naout) pthens = STRING_PTR_RO(thens); // the content is not useful if out is NA_LOGICAL scalar
const SEXP pna = NA_STRING;
for (int64_t j=0; j<len2; ++j) {
for (int64_t j=0; j<n_undecided; ++j) {
const int64_t idx = imask ? j : p[j];
if (pwhens[idx]==1) {
SET_STRING_ELT(ans, idx, naout ? pna : pthens[idx & thenMask]);
Expand All @@ -388,7 +394,7 @@ SEXP fcaseR(SEXP rho, SEXP args) {
// assign the NA values as it does for other atomic types
const SEXP *restrict pthens=NULL;
if (!naout) pthens = SEXPPTR_RO(thens); // the content is not useful if out is NA_LOGICAL scalar
for (int64_t j=0; j<len2; ++j) {
for (int64_t j=0; j<n_undecided; ++j) {
const int64_t idx = imask ? j : p[j];
if (pwhens[idx]==1) {
if (!naout) SET_VECTOR_ELT(ans, idx, pthens[idx & thenMask]);
Expand All @@ -403,7 +409,7 @@ SEXP fcaseR(SEXP rho, SEXP args) {
if (l==0) {
break; // stop early as nothing left to do
}
len2 = l;
n_undecided = l;
}
UNPROTECT(nprotect); // whens, thens, ans, tracker
return ans;
Expand Down

0 comments on commit b1c6bf3

Please sign in to comment.