diff --git a/CHANGES.rst b/CHANGES.rst index 7a94e8f8..3d684bf7 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -5,6 +5,7 @@ Unreleased - Drop support for Python 3.9. - Remove previously deprecated code. +- Improve C speedups (building upon idea in https://github.com/pallets/markupsafe/pull/438) Version 3.0.3 diff --git a/src/markupsafe/_speedups.c b/src/markupsafe/_speedups.c index 8a315f23..1ea11edf 100644 --- a/src/markupsafe/_speedups.c +++ b/src/markupsafe/_speedups.c @@ -1,154 +1,198 @@ #include +#include + +/* + * Lookup tables for HTML escaping. + * + * The five special characters and their replacements: + * '"' (34) -> """ (len 5, delta +4) + * '&' (38) -> "&" (len 5, delta +4) + * '\'' (39) -> "'" (len 5, delta +4) + * '<' (60) -> "<" (len 4, delta +3) + * '>' (62) -> ">" (len 4, delta +3) + * + * REPLACE_INDEX: 0 = no escaping needed, 1-5 = index into REPLACEMENT_STR. + * All escape chars fit in a byte, so for UCS2/UCS4 we guard with c < 256. + */ +static const uint8_t REPLACE_INDEX[256] = { + ['"'] = 1, + ['&'] = 2, + ['\''] = 3, + ['<'] = 4, + ['>'] = 5, +}; -#define GET_DELTA(inp, inp_end, delta) \ - while (inp < inp_end) { \ - switch (*inp++) { \ - case '"': \ - case '\'': \ - case '&': \ - delta += 4; \ - break; \ - case '<': \ - case '>': \ - delta += 3; \ - break; \ - } \ - } +static const char * const REPLACEMENT_STR[] = { + NULL, """, "&", "'", "<", ">" +}; + +static const uint8_t REPLACEMENT_LEN[] = {0, 5, 5, 5, 4, 4}; + +/* Extra output characters needed per input character (0 if no escaping). */ +static const uint8_t DELTA_TABLE[256] = { + ['"'] = 4, + ['&'] = 4, + ['\''] = 4, + ['<'] = 3, + ['>'] = 3, +}; -#define DO_ESCAPE(inp, inp_end, outp) \ - { \ - Py_ssize_t ncopy = 0; \ - while (inp < inp_end) { \ - switch (*inp) { \ - case '"': \ - memcpy(outp, inp-ncopy, sizeof(*outp)*ncopy); \ - outp += ncopy; ncopy = 0; \ - *outp++ = '&'; \ - *outp++ = '#'; \ - *outp++ = '3'; \ - *outp++ = '4'; \ - *outp++ = ';'; \ - break; \ - case '\'': \ - memcpy(outp, inp-ncopy, sizeof(*outp)*ncopy); \ - outp += ncopy; ncopy = 0; \ - *outp++ = '&'; \ - *outp++ = '#'; \ - *outp++ = '3'; \ - *outp++ = '9'; \ - *outp++ = ';'; \ - break; \ - case '&': \ - memcpy(outp, inp-ncopy, sizeof(*outp)*ncopy); \ - outp += ncopy; ncopy = 0; \ - *outp++ = '&'; \ - *outp++ = 'a'; \ - *outp++ = 'm'; \ - *outp++ = 'p'; \ - *outp++ = ';'; \ - break; \ - case '<': \ - memcpy(outp, inp-ncopy, sizeof(*outp)*ncopy); \ - outp += ncopy; ncopy = 0; \ - *outp++ = '&'; \ - *outp++ = 'l'; \ - *outp++ = 't'; \ - *outp++ = ';'; \ - break; \ - case '>': \ - memcpy(outp, inp-ncopy, sizeof(*outp)*ncopy); \ - outp += ncopy; ncopy = 0; \ - *outp++ = '&'; \ - *outp++ = 'g'; \ - *outp++ = 't'; \ - *outp++ = ';'; \ - break; \ - default: \ - ncopy++; \ - } \ - inp++; \ - } \ - memcpy(outp, inp-ncopy, sizeof(*outp)*ncopy); \ +/* Boolean: nonzero if this byte value requires HTML escaping. */ +static const uint8_t NEEDS_ESCAPE[256] = { + ['"'] = 1, + ['&'] = 1, + ['\''] = 1, + ['<'] = 1, + ['>'] = 1, +}; + +/* + * Count the total extra characters needed for escaping a UCS1 string. + * Processes 4 bytes at a time: if none of the four need escaping, the + * entire chunk is skipped with a single OR across four table lookups. + * Falls back to per-byte delta lookup only for chunks containing a + * special character. Returns 0 if no escaping is needed at all. + */ +static Py_ssize_t +count_delta_1(const Py_UCS1 *inp, Py_ssize_t len) +{ + Py_ssize_t i = 0; + Py_ssize_t delta = 0; + + for (; i + 4 <= len; i += 4) { + if (NEEDS_ESCAPE[inp[i]] | NEEDS_ESCAPE[inp[i+1]] | + NEEDS_ESCAPE[inp[i+2]] | NEEDS_ESCAPE[inp[i+3]]) { + delta += DELTA_TABLE[inp[i]] + DELTA_TABLE[inp[i+1]] + + DELTA_TABLE[inp[i+2]] + DELTA_TABLE[inp[i+3]]; + } } + for (; i < len; i++) + delta += DELTA_TABLE[inp[i]]; -static PyObject* + return delta; +} + +static PyObject * escape_unicode_kind1(PyUnicodeObject *in) { - Py_UCS1 *inp = PyUnicode_1BYTE_DATA(in); - Py_UCS1 *inp_end = inp + PyUnicode_GET_LENGTH(in); - Py_UCS1 *outp; - PyObject *out; - Py_ssize_t delta = 0; + const Py_UCS1 *inp = PyUnicode_1BYTE_DATA(in); + Py_ssize_t len = PyUnicode_GET_LENGTH(in); + Py_ssize_t delta = count_delta_1(inp, len); - GET_DELTA(inp, inp_end, delta); if (!delta) { Py_INCREF(in); - return (PyObject*)in; + return (PyObject *)in; } - out = PyUnicode_New(PyUnicode_GET_LENGTH(in) + delta, - PyUnicode_IS_ASCII(in) ? 127 : 255); + PyObject *out = PyUnicode_New(len + delta, + PyUnicode_IS_ASCII(in) ? 127 : 255); if (!out) return NULL; - inp = PyUnicode_1BYTE_DATA(in); - outp = PyUnicode_1BYTE_DATA(out); - DO_ESCAPE(inp, inp_end, outp); + Py_UCS1 *outp = PyUnicode_1BYTE_DATA(out); + Py_ssize_t prev = 0; + for (Py_ssize_t i = 0; i < len; i++) { + uint8_t ri = REPLACE_INDEX[inp[i]]; + if (ri) { + if (i > prev) { + memcpy(outp, inp + prev, i - prev); + outp += i - prev; + } + uint8_t rlen = REPLACEMENT_LEN[ri]; + memcpy(outp, REPLACEMENT_STR[ri], rlen); + outp += rlen; + prev = i + 1; + } + } + if (len > prev) + memcpy(outp, inp + prev, len - prev); + return out; } -static PyObject* +static PyObject * escape_unicode_kind2(PyUnicodeObject *in) { - Py_UCS2 *inp = PyUnicode_2BYTE_DATA(in); - Py_UCS2 *inp_end = inp + PyUnicode_GET_LENGTH(in); - Py_UCS2 *outp; - PyObject *out; + const Py_UCS2 *inp = PyUnicode_2BYTE_DATA(in); + Py_ssize_t len = PyUnicode_GET_LENGTH(in); Py_ssize_t delta = 0; - GET_DELTA(inp, inp_end, delta); + for (Py_ssize_t i = 0; i < len; i++) + delta += (inp[i] < 256) ? DELTA_TABLE[inp[i]] : 0; + if (!delta) { Py_INCREF(in); - return (PyObject*)in; + return (PyObject *)in; } - out = PyUnicode_New(PyUnicode_GET_LENGTH(in) + delta, 65535); + PyObject *out = PyUnicode_New(len + delta, 65535); if (!out) return NULL; - inp = PyUnicode_2BYTE_DATA(in); - outp = PyUnicode_2BYTE_DATA(out); - DO_ESCAPE(inp, inp_end, outp); + Py_UCS2 *outp = PyUnicode_2BYTE_DATA(out); + Py_ssize_t prev = 0; + for (Py_ssize_t i = 0; i < len; i++) { + uint8_t ri = (inp[i] < 256) ? REPLACE_INDEX[inp[i]] : 0; + if (ri) { + if (i > prev) { + memcpy(outp, inp + prev, (i - prev) * sizeof(Py_UCS2)); + outp += i - prev; + } + const char *repl = REPLACEMENT_STR[ri]; + uint8_t rlen = REPLACEMENT_LEN[ri]; + for (uint8_t j = 0; j < rlen; j++) + *outp++ = (Py_UCS2)(unsigned char)repl[j]; + prev = i + 1; + } + } + if (len > prev) + memcpy(outp, inp + prev, (len - prev) * sizeof(Py_UCS2)); + return out; } - -static PyObject* +static PyObject * escape_unicode_kind4(PyUnicodeObject *in) { - Py_UCS4 *inp = PyUnicode_4BYTE_DATA(in); - Py_UCS4 *inp_end = inp + PyUnicode_GET_LENGTH(in); - Py_UCS4 *outp; - PyObject *out; + const Py_UCS4 *inp = PyUnicode_4BYTE_DATA(in); + Py_ssize_t len = PyUnicode_GET_LENGTH(in); Py_ssize_t delta = 0; - GET_DELTA(inp, inp_end, delta); + for (Py_ssize_t i = 0; i < len; i++) + delta += (inp[i] < 256) ? DELTA_TABLE[inp[i]] : 0; + if (!delta) { Py_INCREF(in); - return (PyObject*)in; + return (PyObject *)in; } - out = PyUnicode_New(PyUnicode_GET_LENGTH(in) + delta, 1114111); + PyObject *out = PyUnicode_New(len + delta, 1114111); if (!out) return NULL; - inp = PyUnicode_4BYTE_DATA(in); - outp = PyUnicode_4BYTE_DATA(out); - DO_ESCAPE(inp, inp_end, outp); + Py_UCS4 *outp = PyUnicode_4BYTE_DATA(out); + Py_ssize_t prev = 0; + for (Py_ssize_t i = 0; i < len; i++) { + uint8_t ri = (inp[i] < 256) ? REPLACE_INDEX[inp[i]] : 0; + if (ri) { + if (i > prev) { + memcpy(outp, inp + prev, (i - prev) * sizeof(Py_UCS4)); + outp += i - prev; + } + const char *repl = REPLACEMENT_STR[ri]; + uint8_t rlen = REPLACEMENT_LEN[ri]; + for (uint8_t j = 0; j < rlen; j++) + *outp++ = (Py_UCS4)(unsigned char)repl[j]; + prev = i + 1; + } + } + if (len > prev) + memcpy(outp, inp + prev, (len - prev) * sizeof(Py_UCS4)); + return out; } -static PyObject* +static PyObject * escape_unicode(PyObject *self, PyObject *s) { if (!PyUnicode_Check(s)) @@ -160,11 +204,11 @@ escape_unicode(PyObject *self, PyObject *s) switch (PyUnicode_KIND(s)) { case PyUnicode_1BYTE_KIND: - return escape_unicode_kind1((PyUnicodeObject*) s); + return escape_unicode_kind1((PyUnicodeObject *) s); case PyUnicode_2BYTE_KIND: - return escape_unicode_kind2((PyUnicodeObject*) s); + return escape_unicode_kind2((PyUnicodeObject *) s); case PyUnicode_4BYTE_KIND: - return escape_unicode_kind4((PyUnicodeObject*) s); + return escape_unicode_kind4((PyUnicodeObject *) s); } assert(0); /* shouldn't happen */ return NULL;