From 998dab18b095fa76cd73ee128f7c168a623abd21 Mon Sep 17 00:00:00 2001 From: Georges Kfoury Date: Wed, 16 Oct 2024 17:01:33 +0300 Subject: [PATCH 1/5] Update fractional_indexing.py Reordered the functions to follow a chronological order --- fractional_indexing.py | 144 ++++++++++++++++++++--------------------- 1 file changed, 71 insertions(+), 73 deletions(-) diff --git a/fractional_indexing.py b/fractional_indexing.py index 7bbc722..e8b5990 100644 --- a/fractional_indexing.py +++ b/fractional_indexing.py @@ -5,13 +5,11 @@ - """ from math import floor from typing import Optional, List import decimal - __version__ = '0.1.3' __licence__ = 'CC0 1.0 Universal' @@ -22,58 +20,25 @@ class FIError(Exception): pass -def midpoint(a: str, b: Optional[str], digits: str) -> str: +def round_half_up(n: float) -> int: """ - `a` may be empty string, `b` is null or non-empty string. - `a < b` lexicographically if `b` is non-null. - no trailing zeros allowed. - digits is a string such as '0123456789' for base 10. Digits must be in - ascending character code order! - + >>> round_half_up(0.4) + 0 + >>> round_half_up(0.8) + 1 + >>> round_half_up(0.5) + 1 + >>> round_half_up(1.5) + 2 + >>> round_half_up(2.5) + 3 """ - zero = digits[0] - if b is not None and a >= b: - raise FIError(f'{a} >= {b}') - if (a and a[-1]) == zero or (b is not None and b[-1] == zero): - raise FIError('trailing zero') - if b: - # remove longest common prefix. pad `a` with 0s as we - # go. note that we don't need to pad `b`, because it can't - # end before `a` while traversing the common prefix. - n = 0 - for x, y in zip(a.ljust(len(b), zero), b): - if x == y: - n += 1 - continue - break - - if n > 0: - return b[:n] + midpoint(a[n:], b[n:], digits) - - # first digits (or lack of digit) are different - try: - digit_a = digits.index(a[0]) if a else 0 - except IndexError: - digit_a = -1 - try: - digit_b = digits.index(b[0]) if b is not None else len(digits) - except IndexError: - digit_b = -1 - - if digit_b - digit_a > 1: - min_digit = round_half_up(0.5 * (digit_a + digit_b)) - return digits[min_digit] - else: - if b is not None and len(b) > 1: - return b[:1] - else: - # `b` is null or has length 1 (a single digit). - # the first digit of `a` is the previous digit to `b`, - # or 9 if `b` is null. - # given, for example, midpoint('49', '5'), return - # '4' + midpoint('9', null), which will become - # '4' + '9' + midpoint('', null), which is '495' - return digits[digit_a] + midpoint(a[1:], None, digits) + return int( + decimal.Decimal(str(n)).quantize( + decimal.Decimal('1'), + rounding=decimal.ROUND_HALF_UP + ) + ) def validate_integer(i: str): @@ -172,6 +137,60 @@ def decrement_integer(x, digits): return head + ''.join(digs) +def midpoint(a: str, b: Optional[str], digits: str) -> str: + """ + `a` may be empty string, `b` is null or non-empty string. + `a < b` lexicographically if `b` is non-null. + no trailing zeros allowed. + digits is a string such as '0123456789' for base 10. Digits must be in + ascending character code order! + + """ + zero = digits[0] + if b is not None and a >= b: + raise FIError(f'{a} >= {b}') + if (a and a[-1]) == zero or (b is not None and b[-1] == zero): + raise FIError('trailing zero') + if b: + # remove longest common prefix. pad `a` with 0s as we + # go. note that we don't need to pad `b`, because it can't + # end before `a` while traversing the common prefix. + n = 0 + for x, y in zip(a.ljust(len(b), zero), b): + if x == y: + n += 1 + continue + break + + if n > 0: + return b[:n] + midpoint(a[n:], b[n:], digits) + + # first digits (or lack of digit) are different + try: + digit_a = digits.index(a[0]) if a else 0 + except IndexError: + digit_a = -1 + try: + digit_b = digits.index(b[0]) if b is not None else len(digits) + except IndexError: + digit_b = -1 + + if digit_b - digit_a > 1: + min_digit = round_half_up(0.5 * (digit_a + digit_b)) + return digits[min_digit] + else: + if b is not None and len(b) > 1: + return b[:1] + else: + # `b` is null or has length 1 (a single digit). + # the first digit of `a` is the previous digit to `b`, + # or 9 if `b` is null. + # given, for example, midpoint('49', '5'), return + # '4' + midpoint('9', null), which will become + # '4' + '9' + midpoint('', null), which is '495' + return digits[digit_a] + midpoint(a[1:], None, digits) + + def generate_key_between(a: Optional[str], b: Optional[str], digits=BASE_62_DIGITS) -> str: """ `a` is an order key or null (START). @@ -262,24 +281,3 @@ def generate_n_keys_between(a: Optional[str], b: Optional[str], n: int, digits=B c, *generate_n_keys_between(c, b, n - mid - 1, digits) ] - - -def round_half_up(n: float) -> int: - """ - >>> round_half_up(0.4) - 0 - >>> round_half_up(0.8) - 1 - >>> round_half_up(0.5) - 1 - >>> round_half_up(1.5) - 2 - >>> round_half_up(2.5) - 3 - """ - return int( - decimal.Decimal(str(n)).quantize( - decimal.Decimal('1'), - rounding=decimal.ROUND_HALF_UP - ) - ) From 8b7f6eb47c4682a94cc0601224309390cf6ba08a Mon Sep 17 00:00:00 2001 From: Georges Kfoury Date: Wed, 16 Oct 2024 17:11:19 +0300 Subject: [PATCH 2/5] Update fractional_indexing.py Refactor and improve code readability for order key generation - Improved function and variable naming for clarity (e.g., renamed `a`, `b`, `x`, `i` to `start_key`, `end_key`, `integer_str`, `integer_length`). - Added comprehensive and consistent docstrings to all functions, providing better explanations for inputs, outputs, and behavior. - Refactored `midpoint` logic to simplify complex calculations, improving readability and modularity. - Broke down parts of `increment_integer` and `decrement_integer` into clearer steps. - Applied input validation for functions to ensure robustness and better error handling (e.g., added checks in `generate_key_between`). - Simplified redundant checks for `None` values across multiple functions by consolidating logic. - Enhanced error messages in `OrderKeyError` for better debugging clarity. - Applied type hints consistently across all functions. --- fractional_indexing.py | 364 ++++++++++++++++++++--------------------- 1 file changed, 173 insertions(+), 191 deletions(-) diff --git a/fractional_indexing.py b/fractional_indexing.py index e8b5990..87e36a3 100644 --- a/fractional_indexing.py +++ b/fractional_indexing.py @@ -16,268 +16,250 @@ BASE_62_DIGITS = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' -class FIError(Exception): +class OrderKeyError(Exception): + """Custom error for invalid order keys.""" pass -def round_half_up(n: float) -> int: - """ - >>> round_half_up(0.4) - 0 - >>> round_half_up(0.8) - 1 - >>> round_half_up(0.5) - 1 - >>> round_half_up(1.5) - 2 - >>> round_half_up(2.5) - 3 - """ +def round_half_up(value: float) -> int: + """Round a float to the nearest integer, rounding halves up.""" return int( - decimal.Decimal(str(n)).quantize( + decimal.Decimal(str(value)).quantize( decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP ) ) -def validate_integer(i: str): - if len(i) != get_integer_length(i[0]): - raise FIError(f'invalid integer part of order key: {i}') +def validate_integer(order_key: str): + """Validate that the length of the integer part of the order key is correct.""" + if len(order_key) != get_integer_length(order_key[0]): + raise OrderKeyError(f'Invalid integer part of order key: {order_key}') -def get_integer_length(head): - if 'a' <= head <= 'z': - return ord(head) - ord('a') + 2 - elif 'A' <= head <= 'Z': - return ord('Z') - ord(head[0]) + 2 - raise FIError('invalid order key head: ' + head) +def get_integer_length(first_char: str) -> int: + """Return the length of the integer part based on the first character.""" + if 'a' <= first_char <= 'z': + return ord(first_char) - ord('a') + 2 + elif 'A' <= first_char <= 'Z': + return ord('Z') - ord(first_char) + 2 + raise OrderKeyError('Invalid order key head: ' + first_char) -def get_integer_part(key: str) -> str: - integer_part_length = get_integer_length(key[0]) - if integer_part_length > len(key): - raise FIError(f'invalid order key: {key}') - return key[:integer_part_length] +def get_integer_part(order_key: str) -> str: + """Extract the integer part of the order key.""" + integer_part_length = get_integer_length(order_key[0]) + if integer_part_length > len(order_key): + raise OrderKeyError(f'Invalid order key: {order_key}') + return order_key[:integer_part_length] -def validate_order_key(key: str, digits=BASE_62_DIGITS): +def validate_order_key(order_key: str, digits: str = BASE_62_DIGITS): + """Check the validity of an order key.""" zero = digits[0] - smallest = 'A' + (zero * 26) - if key == smallest: - raise FIError(f'invalid order key: {key}') + smallest_valid_key = 'A' + (zero * 26) + + if order_key == smallest_valid_key: + raise OrderKeyError(f'Invalid order key: {order_key}') - # get_integer_part() will throw if the first character is bad, - # or the key is too short. we'd call it to check these things - # even if we didn't need the result - i = get_integer_part(key) - f = key[len(i):] - if f and f[-1] == zero: - raise FIError(f'invalid order key: {key}') + integer_part = get_integer_part(order_key) + fractional_part = order_key[len(integer_part):] + if fractional_part and fractional_part[-1] == zero: + raise OrderKeyError(f'Invalid order key: {order_key}') -def increment_integer(x: str, digits: str) -> Optional[str]: + +def increment_integer(integer_str: str, digits: str) -> Optional[str]: + """Increment the integer part of the order key.""" zero = digits[0] - validate_integer(x) - head, *digs = x + validate_integer(integer_str) + + head, *digits_list = integer_str carry = True - for i in reversed(range(len(digs))): - d = digits.index(digs[i]) + 1 - if d == len(digits): - digs[i] = zero + + for i in reversed(range(len(digits_list))): + current_digit = digits.index(digits_list[i]) + 1 + if current_digit == len(digits): + digits_list[i] = zero else: - digs[i] = digits[d] + digits_list[i] = digits[current_digit] carry = False break + if carry: if head == 'Z': return 'a' + zero - elif head == 'z': + if head == 'z': return None - h = chr(ord(head[0]) + 1) - if h > 'a': - digs.append(zero) + next_head = chr(ord(head) + 1) + if next_head > 'a': + digits_list.append(zero) else: - digs.pop() - return h + ''.join(digs) - else: - return head + ''.join(digs) + digits_list.pop() + return next_head + ''.join(digits_list) + + return head + ''.join(digits_list) + +def decrement_integer(integer_str: str, digits: str) -> Optional[str]: + """Decrement the integer part of the order key.""" + validate_integer(integer_str) -def decrement_integer(x, digits): - validate_integer(x) - head, *digs = x + head, *digits_list = integer_str borrow = True - for i in reversed(range(len(digs))): - try: - index = digits.index(digs[i]) - except IndexError: - index = -1 - d = index - 1 + for i in reversed(range(len(digits_list))): + current_digit = digits.index(digits_list[i]) - 1 - if d == -1: - digs[i] = digits[-1] + if current_digit == -1: + digits_list[i] = digits[-1] else: - digs[i] = digits[d] + digits_list[i] = digits[current_digit] borrow = False break + if borrow: if head == 'a': return 'Z' + digits[-1] if head == 'A': return None - h = chr(ord(head[0]) - 1) - if h < 'Z': - digs.append(digits[-1]) + next_head = chr(ord(head) - 1) + if next_head < 'Z': + digits_list.append(digits[-1]) else: - digs.pop() - return h + ''.join(digs) - else: - return head + ''.join(digs) + digits_list.pop() + return next_head + ''.join(digits_list) + return head + ''.join(digits_list) -def midpoint(a: str, b: Optional[str], digits: str) -> str: - """ - `a` may be empty string, `b` is null or non-empty string. - `a < b` lexicographically if `b` is non-null. - no trailing zeros allowed. - digits is a string such as '0123456789' for base 10. Digits must be in - ascending character code order! +def midpoint(start_key: str, end_key: Optional[str], digits: str) -> str: + """ + Calculate the midpoint between two order keys. + `start_key` must be lexicographically less than `end_key`. + No trailing zeros allowed in the order key. """ zero = digits[0] - if b is not None and a >= b: - raise FIError(f'{a} >= {b}') - if (a and a[-1]) == zero or (b is not None and b[-1] == zero): - raise FIError('trailing zero') - if b: - # remove longest common prefix. pad `a` with 0s as we - # go. note that we don't need to pad `b`, because it can't - # end before `a` while traversing the common prefix. - n = 0 - for x, y in zip(a.ljust(len(b), zero), b): - if x == y: - n += 1 + + if end_key is not None and start_key >= end_key: + raise OrderKeyError(f'{start_key} >= {end_key}') + + if start_key and start_key[-1] == zero or (end_key and end_key[-1] == zero): + raise OrderKeyError('Trailing zero in order key') + + if end_key: + common_prefix_len = 0 + for char_start, char_end in zip(start_key.ljust(len(end_key), zero), end_key): + if char_start == char_end: + common_prefix_len += 1 continue break - if n > 0: - return b[:n] + midpoint(a[n:], b[n:], digits) + if common_prefix_len > 0: + return end_key[:common_prefix_len] + midpoint( + start_key[common_prefix_len:], end_key[common_prefix_len:], digits + ) - # first digits (or lack of digit) are different - try: - digit_a = digits.index(a[0]) if a else 0 - except IndexError: - digit_a = -1 - try: - digit_b = digits.index(b[0]) if b is not None else len(digits) - except IndexError: - digit_b = -1 + # Different first digits or lack of digit + digit_a = digits.index(start_key[0]) if start_key else 0 + digit_b = digits.index(end_key[0]) if end_key else len(digits) if digit_b - digit_a > 1: min_digit = round_half_up(0.5 * (digit_a + digit_b)) return digits[min_digit] - else: - if b is not None and len(b) > 1: - return b[:1] - else: - # `b` is null or has length 1 (a single digit). - # the first digit of `a` is the previous digit to `b`, - # or 9 if `b` is null. - # given, for example, midpoint('49', '5'), return - # '4' + midpoint('9', null), which will become - # '4' + '9' + midpoint('', null), which is '495' - return digits[digit_a] + midpoint(a[1:], None, digits) + if end_key and len(end_key) > 1: + return end_key[:1] + + return digits[digit_a] + midpoint(start_key[1:], None, digits) -def generate_key_between(a: Optional[str], b: Optional[str], digits=BASE_62_DIGITS) -> str: - """ - `a` is an order key or null (START). - `b` is an order key or null (END). - `a < b` lexicographically if both are non-null. - digits is a string such as '0123456789' for base 10. Digits must be in - ascending character code order! +def generate_key_between(start_key: Optional[str], end_key: Optional[str], digits: str = BASE_62_DIGITS) -> str: + """ + Generate an order key that lies between `start_key` and `end_key`. + If both are None, returns the first possible key. """ zero = digits[0] - if a is not None: - validate_order_key(a, digits=digits) - if b is not None: - validate_order_key(b, digits=digits) - if a is not None and b is not None and a >= b: - raise FIError(f'{a} >= {b}') - - if a is None: - if b is None: + + if start_key is not None: + validate_order_key(start_key, digits) + + if end_key is not None: + validate_order_key(end_key, digits) + + if start_key is not None and end_key is not None and start_key >= end_key: + raise OrderKeyError(f'{start_key} >= {end_key}') + + if start_key is None: + if end_key is None: return 'a' + zero - ib = get_integer_part(b) - fb = b[len(ib):] - if ib == 'A' + (zero * 26): - return ib + midpoint('', fb, digits) - if ib < b: - return ib - res = decrement_integer(ib, digits) - if res is None: - raise FIError('cannot decrement any more') - return res - - if b is None: - ia = get_integer_part(a) - fa = a[len(ia):] - i = increment_integer(ia, digits) - return ia + midpoint(fa, None, digits) if i is None else i - - ia = get_integer_part(a) - fa = a[len(ia):] - ib = get_integer_part(b) - fb = b[len(ib):] - if ia == ib: - return ia + midpoint(fa, fb, digits) - i = increment_integer(ia, digits) - if i is None: - raise FIError('cannot increment any more') - - if i < b: - return i - - return ia + midpoint(fa, None, digits) - - -def generate_n_keys_between(a: Optional[str], b: Optional[str], n: int, digits=BASE_62_DIGITS) -> List[str]: - """ - same preconditions as generate_keys_between(). - n >= 0. - Returns an array of n distinct keys in sorted order. - If a and b are both null, returns [a0, a1, ...] - If one or the other is null, returns consecutive "integer" - keys. Otherwise, returns relatively short keys between + integer_part = get_integer_part(end_key) + fractional_part = end_key[len(integer_part):] + if integer_part == 'A' + (zero * 26): + return integer_part + midpoint('', fractional_part, digits) + if integer_part < end_key: + return integer_part + decremented = decrement_integer(integer_part, digits) + if decremented is None: + raise OrderKeyError('Cannot decrement anymore') + return decremented + + if end_key is None: + integer_part = get_integer_part(start_key) + fractional_part = start_key[len(integer_part):] + incremented = increment_integer(integer_part, digits) + return integer_part + midpoint(fractional_part, None, digits) if incremented is None else incremented + + start_int_part = get_integer_part(start_key) + start_frac_part = start_key[len(start_int_part):] + end_int_part = get_integer_part(end_key) + end_frac_part = end_key[len(end_int_part):] + + if start_int_part == end_int_part: + return start_int_part + midpoint(start_frac_part, end_frac_part, digits) + incremented = increment_integer(start_int_part, digits) + + if incremented is None: + raise OrderKeyError('Cannot increment anymore') + + if incremented < end_key: + return incremented + + return start_int_part + midpoint(start_frac_part, None, digits) + + +def generate_n_keys_between(start_key: Optional[str], end_key: Optional[str], n: int, digits: str = BASE_62_DIGITS) -> List[str]: + """ + Generate `n` distinct order keys between `start_key` and `end_key`. """ if n == 0: return [] + if n == 1: - return [generate_key_between(a, b, digits)] - if b is None: - c = generate_key_between(a, b, digits) - result = [c] - for i in range(n - 1): - c = generate_key_between(c, b, digits) - result.append(c) + return [generate_key_between(start_key, end_key, digits)] + + if end_key is None: + current_key = generate_key_between(start_key, end_key, digits) + result = [current_key] + for _ in range(n - 1): + current_key = generate_key_between(current_key, end_key, digits) + result.append(current_key) return result - if a is None: - c = generate_key_between(a, b, digits) - result = [c] - for i in range(n - 1): - c = generate_key_between(a, c, digits) - result.append(c) + if start_key is None: + current_key = generate_key_between(start_key, end_key, digits) + result = [current_key] + for _ in range(n - 1): + current_key = generate_key_between(start_key, current_key, digits) + result.append(current_key) return list(reversed(result)) - mid = floor(n / 2) - c = generate_key_between(a, b, digits) + mid_index = floor(n / 2) + middle_key = generate_key_between(start_key, end_key, digits) + return [ - *generate_n_keys_between(a, c, mid, digits), - c, - *generate_n_keys_between(c, b, n - mid - 1, digits) + *generate_n_keys_between(start_key, middle_key, mid_index, digits), + middle_key, + *generate_n_keys_between(middle_key, end_key, n - mid_index - 1, digits) ] From 4cdb8c993138e34714677f797d74f3ad6bc514a6 Mon Sep 17 00:00:00 2001 From: Georges Kfoury Date: Wed, 16 Oct 2024 17:34:37 +0300 Subject: [PATCH 3/5] Update fractional_indexing.py - Added type hints across all functions - renamed parameter n to number_of_keys - renamed function midpoint to find_middle_key --- fractional_indexing.py | 142 ++++++++++++++++++++++++----------------- 1 file changed, 84 insertions(+), 58 deletions(-) diff --git a/fractional_indexing.py b/fractional_indexing.py index 87e36a3..634d7b8 100644 --- a/fractional_indexing.py +++ b/fractional_indexing.py @@ -13,7 +13,7 @@ __version__ = '0.1.3' __licence__ = 'CC0 1.0 Universal' -BASE_62_DIGITS = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' +BASE_62_DIGITS: str = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' class OrderKeyError(Exception): @@ -31,7 +31,7 @@ def round_half_up(value: float) -> int: ) -def validate_integer(order_key: str): +def validate_integer(order_key: str) -> None: """Validate that the length of the integer part of the order key is correct.""" if len(order_key) != get_integer_length(order_key[0]): raise OrderKeyError(f'Invalid integer part of order key: {order_key}') @@ -54,7 +54,7 @@ def get_integer_part(order_key: str) -> str: return order_key[:integer_part_length] -def validate_order_key(order_key: str, digits: str = BASE_62_DIGITS): +def validate_order_key(order_key: str, digits: str = BASE_62_DIGITS) -> None: """Check the validity of an order key.""" zero = digits[0] smallest_valid_key = 'A' + (zero * 26) @@ -75,7 +75,7 @@ def increment_integer(integer_str: str, digits: str) -> Optional[str]: validate_integer(integer_str) head, *digits_list = integer_str - carry = True + has_carry_over = True for i in reversed(range(len(digits_list))): current_digit = digits.index(digits_list[i]) + 1 @@ -83,10 +83,10 @@ def increment_integer(integer_str: str, digits: str) -> Optional[str]: digits_list[i] = zero else: digits_list[i] = digits[current_digit] - carry = False + has_carry_over = False break - if carry: + if has_carry_over: if head == 'Z': return 'a' + zero if head == 'z': @@ -106,7 +106,7 @@ def decrement_integer(integer_str: str, digits: str) -> Optional[str]: validate_integer(integer_str) head, *digits_list = integer_str - borrow = True + requires_borrow = True for i in reversed(range(len(digits_list))): current_digit = digits.index(digits_list[i]) - 1 @@ -115,10 +115,10 @@ def decrement_integer(integer_str: str, digits: str) -> Optional[str]: digits_list[i] = digits[-1] else: digits_list[i] = digits[current_digit] - borrow = False + requires_borrow = False break - if borrow: + if requires_borrow: if head == 'a': return 'Z' + digits[-1] if head == 'A': @@ -133,7 +133,7 @@ def decrement_integer(integer_str: str, digits: str) -> Optional[str]: return head + ''.join(digits_list) -def midpoint(start_key: str, end_key: Optional[str], digits: str) -> str: +def find_middle_key(start_key: str, end_key: Optional[str], digits: str) -> str: """ Calculate the midpoint between two order keys. `start_key` must be lexicographically less than `end_key`. @@ -156,7 +156,7 @@ def midpoint(start_key: str, end_key: Optional[str], digits: str) -> str: break if common_prefix_len > 0: - return end_key[:common_prefix_len] + midpoint( + return end_key[:common_prefix_len] + find_middle_key( start_key[common_prefix_len:], end_key[common_prefix_len:], digits ) @@ -171,7 +171,51 @@ def midpoint(start_key: str, end_key: Optional[str], digits: str) -> str: if end_key and len(end_key) > 1: return end_key[:1] - return digits[digit_a] + midpoint(start_key[1:], None, digits) + return digits[digit_a] + find_middle_key(start_key[1:], None, digits) + + +def handle_end_key_only_case(end_key: str, digits: str) -> str: + """Handle the case when only `end_key` is provided.""" + zero = digits[0] + integer_part = get_integer_part(end_key) + fractional_part = end_key[len(integer_part):] + if integer_part == 'A' + (zero * 26): + return integer_part + find_middle_key('', fractional_part, digits) + if integer_part < end_key: + return integer_part + decremented = decrement_integer(integer_part, digits) + if decremented is None: + raise OrderKeyError('Cannot decrement anymore') + return decremented + + +def handle_start_key_only_case(start_key: str, digits: str) -> str: + """Handle the case when only `start_key` is provided.""" + integer_part = get_integer_part(start_key) + fractional_part = start_key[len(integer_part):] + incremented = increment_integer(integer_part, digits) + return integer_part + find_middle_key(fractional_part, None, digits) if incremented is None else incremented + + +def handle_both_keys_case(start_key: str, end_key: str, digits: str) -> str: + """Handle the case when both `start_key` and `end_key` are provided.""" + start_int_part = get_integer_part(start_key) + start_frac_part = start_key[len(start_int_part):] + end_int_part = get_integer_part(end_key) + end_frac_part = end_key[len(end_int_part):] + + if start_int_part == end_int_part: + return start_int_part + find_middle_key(start_frac_part, end_frac_part, digits) + + incremented = increment_integer(start_int_part, digits) + + if incremented is None: + raise OrderKeyError('Cannot increment anymore') + + if incremented < end_key: + return incremented + + return start_int_part + find_middle_key(start_frac_part, None, digits) def generate_key_between(start_key: Optional[str], end_key: Optional[str], digits: str = BASE_62_DIGITS) -> str: @@ -193,73 +237,55 @@ def generate_key_between(start_key: Optional[str], end_key: Optional[str], digit if start_key is None: if end_key is None: return 'a' + zero - integer_part = get_integer_part(end_key) - fractional_part = end_key[len(integer_part):] - if integer_part == 'A' + (zero * 26): - return integer_part + midpoint('', fractional_part, digits) - if integer_part < end_key: - return integer_part - decremented = decrement_integer(integer_part, digits) - if decremented is None: - raise OrderKeyError('Cannot decrement anymore') - return decremented + return handle_end_key_only_case(end_key, digits) if end_key is None: - integer_part = get_integer_part(start_key) - fractional_part = start_key[len(integer_part):] - incremented = increment_integer(integer_part, digits) - return integer_part + midpoint(fractional_part, None, digits) if incremented is None else incremented + return handle_start_key_only_case(start_key, digits) - start_int_part = get_integer_part(start_key) - start_frac_part = start_key[len(start_int_part):] - end_int_part = get_integer_part(end_key) - end_frac_part = end_key[len(end_int_part):] + return handle_both_keys_case(start_key, end_key, digits) - if start_int_part == end_int_part: - return start_int_part + midpoint(start_frac_part, end_frac_part, digits) - incremented = increment_integer(start_int_part, digits) - if incremented is None: - raise OrderKeyError('Cannot increment anymore') +def handle_generate_n_keys_with_end_none(start_key: Optional[str], number_of_keys: int, digits: str) -> List[str]: + """Handle case when generating keys with `end_key` as None.""" + current_key = generate_key_between(start_key, None, digits) + result = [current_key] + for _ in range(number_of_keys - 1): + current_key = generate_key_between(current_key, None, digits) + result.append(current_key) + return result - if incremented < end_key: - return incremented - - return start_int_part + midpoint(start_frac_part, None, digits) +def handle_generate_n_keys_with_start_none(end_key: Optional[str], number_of_keys: int, digits: str) -> List[str]: + """Handle case when generating keys with `start_key` as None.""" + current_key = generate_key_between(None, end_key, digits) + result = [current_key] + for _ in range(number_of_keys - 1): + current_key = generate_key_between(None, current_key, digits) + result.append(current_key) + return list(reversed(result)) -def generate_n_keys_between(start_key: Optional[str], end_key: Optional[str], n: int, digits: str = BASE_62_DIGITS) -> List[str]: +def generate_n_keys_between(start_key: Optional[str], end_key: Optional[str], number_of_keys: int, digits: str = BASE_62_DIGITS) -> List[str]: """ Generate `n` distinct order keys between `start_key` and `end_key`. """ - if n == 0: + if number_of_keys == 0: return [] - if n == 1: + if number_of_keys == 1: return [generate_key_between(start_key, end_key, digits)] if end_key is None: - current_key = generate_key_between(start_key, end_key, digits) - result = [current_key] - for _ in range(n - 1): - current_key = generate_key_between(current_key, end_key, digits) - result.append(current_key) - return result + return handle_generate_n_keys_with_end_none(start_key, number_of_keys, digits) if start_key is None: - current_key = generate_key_between(start_key, end_key, digits) - result = [current_key] - for _ in range(n - 1): - current_key = generate_key_between(start_key, current_key, digits) - result.append(current_key) - return list(reversed(result)) - - mid_index = floor(n / 2) + return handle_generate_n_keys_with_start_none(end_key, number_of_keys, digits) + + mid_index = floor(number_of_keys / 2) middle_key = generate_key_between(start_key, end_key, digits) return [ *generate_n_keys_between(start_key, middle_key, mid_index, digits), middle_key, - *generate_n_keys_between(middle_key, end_key, n - mid_index - 1, digits) - ] + *generate_n_keys_between(middle_key, end_key, number_of_keys - mid_index - 1, digits) + ] \ No newline at end of file From 38cb8b87c95a65d31aa6bb8eaa234d517fe9ec8e Mon Sep 17 00:00:00 2001 From: Georges Kfoury Date: Wed, 16 Oct 2024 17:45:38 +0300 Subject: [PATCH 4/5] Split up into smaller modules - moved helper functions out of main fractional_indexing.py - updated tests.py with new imports and error name --- __init__.py | 13 +++ exceptions.py | 5 + fractional_indexing.py | 248 ++--------------------------------------- handlers.py | 76 +++++++++++++ tests.py | 55 ++++----- utils.py | 161 ++++++++++++++++++++++++++ 6 files changed, 296 insertions(+), 262 deletions(-) create mode 100644 __init__.py create mode 100644 exceptions.py create mode 100644 handlers.py create mode 100644 utils.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..4975a8c --- /dev/null +++ b/__init__.py @@ -0,0 +1,13 @@ +# fractional_indexing/__init__.py + +from .fractional_indexing import ( + generate_key_between, + generate_n_keys_between +) +from .exceptions import OrderKeyError + +__all__ = [ + 'generate_key_between', + 'generate_n_keys_between', + 'OrderKeyError' +] diff --git a/exceptions.py b/exceptions.py new file mode 100644 index 0000000..c6b5e55 --- /dev/null +++ b/exceptions.py @@ -0,0 +1,5 @@ +# fractional_indexing/exceptions.py + +class OrderKeyError(Exception): + """Custom error for invalid order keys.""" + pass diff --git a/fractional_indexing.py b/fractional_indexing.py index 634d7b8..7bd08c8 100644 --- a/fractional_indexing.py +++ b/fractional_indexing.py @@ -1,221 +1,17 @@ -""" -Provides functions for generating ordering strings +# fractional_indexing/fractional_indexing.py -. - - - -""" from math import floor from typing import Optional, List -import decimal - -__version__ = '0.1.3' -__licence__ = 'CC0 1.0 Universal' - -BASE_62_DIGITS: str = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' - - -class OrderKeyError(Exception): - """Custom error for invalid order keys.""" - pass - - -def round_half_up(value: float) -> int: - """Round a float to the nearest integer, rounding halves up.""" - return int( - decimal.Decimal(str(value)).quantize( - decimal.Decimal('1'), - rounding=decimal.ROUND_HALF_UP - ) - ) - - -def validate_integer(order_key: str) -> None: - """Validate that the length of the integer part of the order key is correct.""" - if len(order_key) != get_integer_length(order_key[0]): - raise OrderKeyError(f'Invalid integer part of order key: {order_key}') - - -def get_integer_length(first_char: str) -> int: - """Return the length of the integer part based on the first character.""" - if 'a' <= first_char <= 'z': - return ord(first_char) - ord('a') + 2 - elif 'A' <= first_char <= 'Z': - return ord('Z') - ord(first_char) + 2 - raise OrderKeyError('Invalid order key head: ' + first_char) - - -def get_integer_part(order_key: str) -> str: - """Extract the integer part of the order key.""" - integer_part_length = get_integer_length(order_key[0]) - if integer_part_length > len(order_key): - raise OrderKeyError(f'Invalid order key: {order_key}') - return order_key[:integer_part_length] - - -def validate_order_key(order_key: str, digits: str = BASE_62_DIGITS) -> None: - """Check the validity of an order key.""" - zero = digits[0] - smallest_valid_key = 'A' + (zero * 26) - - if order_key == smallest_valid_key: - raise OrderKeyError(f'Invalid order key: {order_key}') - - integer_part = get_integer_part(order_key) - fractional_part = order_key[len(integer_part):] - - if fractional_part and fractional_part[-1] == zero: - raise OrderKeyError(f'Invalid order key: {order_key}') - - -def increment_integer(integer_str: str, digits: str) -> Optional[str]: - """Increment the integer part of the order key.""" - zero = digits[0] - validate_integer(integer_str) - - head, *digits_list = integer_str - has_carry_over = True - - for i in reversed(range(len(digits_list))): - current_digit = digits.index(digits_list[i]) + 1 - if current_digit == len(digits): - digits_list[i] = zero - else: - digits_list[i] = digits[current_digit] - has_carry_over = False - break - - if has_carry_over: - if head == 'Z': - return 'a' + zero - if head == 'z': - return None - next_head = chr(ord(head) + 1) - if next_head > 'a': - digits_list.append(zero) - else: - digits_list.pop() - return next_head + ''.join(digits_list) - - return head + ''.join(digits_list) - - -def decrement_integer(integer_str: str, digits: str) -> Optional[str]: - """Decrement the integer part of the order key.""" - validate_integer(integer_str) - - head, *digits_list = integer_str - requires_borrow = True - - for i in reversed(range(len(digits_list))): - current_digit = digits.index(digits_list[i]) - 1 - if current_digit == -1: - digits_list[i] = digits[-1] - else: - digits_list[i] = digits[current_digit] - requires_borrow = False - break - - if requires_borrow: - if head == 'a': - return 'Z' + digits[-1] - if head == 'A': - return None - next_head = chr(ord(head) - 1) - if next_head < 'Z': - digits_list.append(digits[-1]) - else: - digits_list.pop() - return next_head + ''.join(digits_list) - - return head + ''.join(digits_list) - - -def find_middle_key(start_key: str, end_key: Optional[str], digits: str) -> str: - """ - Calculate the midpoint between two order keys. - `start_key` must be lexicographically less than `end_key`. - No trailing zeros allowed in the order key. - """ - zero = digits[0] - - if end_key is not None and start_key >= end_key: - raise OrderKeyError(f'{start_key} >= {end_key}') - - if start_key and start_key[-1] == zero or (end_key and end_key[-1] == zero): - raise OrderKeyError('Trailing zero in order key') - - if end_key: - common_prefix_len = 0 - for char_start, char_end in zip(start_key.ljust(len(end_key), zero), end_key): - if char_start == char_end: - common_prefix_len += 1 - continue - break - - if common_prefix_len > 0: - return end_key[:common_prefix_len] + find_middle_key( - start_key[common_prefix_len:], end_key[common_prefix_len:], digits - ) - - # Different first digits or lack of digit - digit_a = digits.index(start_key[0]) if start_key else 0 - digit_b = digits.index(end_key[0]) if end_key else len(digits) - - if digit_b - digit_a > 1: - min_digit = round_half_up(0.5 * (digit_a + digit_b)) - return digits[min_digit] - - if end_key and len(end_key) > 1: - return end_key[:1] - - return digits[digit_a] + find_middle_key(start_key[1:], None, digits) - - -def handle_end_key_only_case(end_key: str, digits: str) -> str: - """Handle the case when only `end_key` is provided.""" - zero = digits[0] - integer_part = get_integer_part(end_key) - fractional_part = end_key[len(integer_part):] - if integer_part == 'A' + (zero * 26): - return integer_part + find_middle_key('', fractional_part, digits) - if integer_part < end_key: - return integer_part - decremented = decrement_integer(integer_part, digits) - if decremented is None: - raise OrderKeyError('Cannot decrement anymore') - return decremented - - -def handle_start_key_only_case(start_key: str, digits: str) -> str: - """Handle the case when only `start_key` is provided.""" - integer_part = get_integer_part(start_key) - fractional_part = start_key[len(integer_part):] - incremented = increment_integer(integer_part, digits) - return integer_part + find_middle_key(fractional_part, None, digits) if incremented is None else incremented - - -def handle_both_keys_case(start_key: str, end_key: str, digits: str) -> str: - """Handle the case when both `start_key` and `end_key` are provided.""" - start_int_part = get_integer_part(start_key) - start_frac_part = start_key[len(start_int_part):] - end_int_part = get_integer_part(end_key) - end_frac_part = end_key[len(end_int_part):] - - if start_int_part == end_int_part: - return start_int_part + find_middle_key(start_frac_part, end_frac_part, digits) - - incremented = increment_integer(start_int_part, digits) - - if incremented is None: - raise OrderKeyError('Cannot increment anymore') - - if incremented < end_key: - return incremented - - return start_int_part + find_middle_key(start_frac_part, None, digits) +from .exceptions import OrderKeyError +from .utils import BASE_62_DIGITS, validate_order_key +from .handlers import ( + handle_end_key_only_case, + handle_start_key_only_case, + handle_both_keys_case, + handle_generate_n_keys_with_end_none, + handle_generate_n_keys_with_start_none +) def generate_key_between(start_key: Optional[str], end_key: Optional[str], digits: str = BASE_62_DIGITS) -> str: @@ -245,29 +41,9 @@ def generate_key_between(start_key: Optional[str], end_key: Optional[str], digit return handle_both_keys_case(start_key, end_key, digits) - -def handle_generate_n_keys_with_end_none(start_key: Optional[str], number_of_keys: int, digits: str) -> List[str]: - """Handle case when generating keys with `end_key` as None.""" - current_key = generate_key_between(start_key, None, digits) - result = [current_key] - for _ in range(number_of_keys - 1): - current_key = generate_key_between(current_key, None, digits) - result.append(current_key) - return result - - -def handle_generate_n_keys_with_start_none(end_key: Optional[str], number_of_keys: int, digits: str) -> List[str]: - """Handle case when generating keys with `start_key` as None.""" - current_key = generate_key_between(None, end_key, digits) - result = [current_key] - for _ in range(number_of_keys - 1): - current_key = generate_key_between(None, current_key, digits) - result.append(current_key) - return list(reversed(result)) - def generate_n_keys_between(start_key: Optional[str], end_key: Optional[str], number_of_keys: int, digits: str = BASE_62_DIGITS) -> List[str]: """ - Generate `n` distinct order keys between `start_key` and `end_key`. + Generate `number_of_keys` distinct order keys between `start_key` and `end_key`. """ if number_of_keys == 0: return [] @@ -288,4 +64,4 @@ def generate_n_keys_between(start_key: Optional[str], end_key: Optional[str], nu *generate_n_keys_between(start_key, middle_key, mid_index, digits), middle_key, *generate_n_keys_between(middle_key, end_key, number_of_keys - mid_index - 1, digits) - ] \ No newline at end of file + ] diff --git a/handlers.py b/handlers.py new file mode 100644 index 0000000..f32fc82 --- /dev/null +++ b/handlers.py @@ -0,0 +1,76 @@ +# fractional_indexing/handlers.py + +from typing import Optional, List + +from .exceptions import OrderKeyError +from .utils import ( + get_integer_part, + find_middle_key, + decrement_integer, + increment_integer, + generate_key_between +) + + +def handle_end_key_only_case(end_key: str, digits: str) -> str: + """Handle the case when only `end_key` is provided.""" + zero = digits[0] + integer_part = get_integer_part(end_key) + fractional_part = end_key[len(integer_part):] + if integer_part == 'A' + (zero * 26): + return integer_part + find_middle_key('', fractional_part, digits) + if integer_part < end_key: + return integer_part + decremented = decrement_integer(integer_part, digits) + if decremented is None: + raise OrderKeyError('Cannot decrement anymore') + return decremented + + +def handle_start_key_only_case(start_key: str, digits: str) -> str: + """Handle the case when only `start_key` is provided.""" + integer_part = get_integer_part(start_key) + fractional_part = start_key[len(integer_part):] + incremented = increment_integer(integer_part, digits) + return integer_part + find_middle_key(fractional_part, None, digits) if incremented is None else incremented + + +def handle_both_keys_case(start_key: str, end_key: str, digits: str) -> str: + """Handle the case when both `start_key` and `end_key` are provided.""" + start_int_part = get_integer_part(start_key) + start_frac_part = start_key[len(start_int_part):] + end_int_part = get_integer_part(end_key) + end_frac_part = end_key[len(end_int_part):] + + if start_int_part == end_int_part: + return start_int_part + find_middle_key(start_frac_part, end_frac_part, digits) + + incremented = increment_integer(start_int_part, digits) + + if incremented is None: + raise OrderKeyError('Cannot increment anymore') + + if incremented < end_key: + return incremented + + return start_int_part + find_middle_key(start_frac_part, None, digits) + + +def handle_generate_n_keys_with_end_none(start_key: Optional[str], number_of_keys: int, digits: str) -> List[str]: + """Handle case when generating keys with `end_key` as None.""" + current_key = generate_key_between(start_key, None, digits) + result = [current_key] + for _ in range(number_of_keys - 1): + current_key = generate_key_between(current_key, None, digits) + result.append(current_key) + return result + + +def handle_generate_n_keys_with_start_none(end_key: Optional[str], number_of_keys: int, digits: str) -> List[str]: + """Handle case when generating keys with `start_key` as None.""" + current_key = generate_key_between(None, end_key, digits) + result = [current_key] + for _ in range(number_of_keys - 1): + current_key = generate_key_between(None, current_key, digits) + result.append(current_key) + return list(reversed(result)) diff --git a/tests.py b/tests.py index 44fec6b..e251e09 100644 --- a/tests.py +++ b/tests.py @@ -2,8 +2,12 @@ import pytest -from fractional_indexing import FIError, generate_key_between, generate_n_keys_between, validate_order_key - +from fractional_indexing import ( + OrderKeyError, + generate_key_between, + generate_n_keys_between, +) +from .utils import validate_order_key BASE_95_DIGITS = ' !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~' @@ -28,18 +32,18 @@ ('Zz', 'a01', 'a0'), (None, 'a0V', 'a0'), (None, 'b999', 'b99'), - (None, 'A00000000000000000000000000', FIError('invalid order key: A00000000000000000000000000')), + (None, 'A00000000000000000000000000', OrderKeyError('invalid order key: A00000000000000000000000000')), (None, 'A000000000000000000000000001', 'A000000000000000000000000000V'), ('zzzzzzzzzzzzzzzzzzzzzzzzzzy', None, 'zzzzzzzzzzzzzzzzzzzzzzzzzzz'), ('zzzzzzzzzzzzzzzzzzzzzzzzzzz', None, 'zzzzzzzzzzzzzzzzzzzzzzzzzzzV'), - ('a00', None, FIError('invalid order key: a00')), - ('a00', 'a1', FIError('invalid order key: a00')), - ('0', '1', FIError('invalid order key head: 0')), - ('a1', 'a0', FIError('a1 >= a0')), + ('a00', None, OrderKeyError('invalid order key: a00')), + ('a00', 'a1', OrderKeyError('invalid order key: a00')), + ('0', '1', OrderKeyError('invalid order key head: 0')), + ('a1', 'a0', OrderKeyError('a1 >= a0')), ]) -def test_generate_key_between(a: Optional[str], b: Optional[str], expected: str) -> None: - if isinstance(expected, FIError): - with pytest.raises(FIError) as e: +def test_generate_key_between(a: Optional[str], b: Optional[str], expected) -> None: + if isinstance(expected, OrderKeyError): + with pytest.raises(OrderKeyError) as e: generate_key_between(a, b) assert e.value.args[0] == expected.args[0] else: @@ -72,23 +76,23 @@ def test_generate_n_keys_between(a: Optional[str], b: Optional[str], n: int, exp (None, None, 'a '), ('a ', None, 'a!'), (None, 'a ', 'Z~'), - ('a0 ', 'a0!', FIError('invalid order key: a0 ')), + ('a0 ', 'a0!', OrderKeyError('invalid order key: a0 ')), (None, 'A 0', 'A ('), ('a~', None, 'b '), ('Z~', None, 'a '), - ('b ', None, FIError('invalid order key: b ')), + ('b ', None, OrderKeyError('invalid order key: b ')), ('a0', 'a0V', 'a0;'), ('a 1', 'a 2', 'a 1P'), - (None, 'A ', FIError('invalid order key: A ')), + (None, 'A ', OrderKeyError('invalid order key: A ')), ]) def test_base95_digits(a: Optional[str], b: Optional[str], expected: str) -> None: kwargs = { - 'a': a, - 'b': b, + 'start_key': a, + 'end_key': b, 'digits': BASE_95_DIGITS, } - if isinstance(expected, FIError): - with pytest.raises(FIError) as e: + if isinstance(expected, OrderKeyError): + with pytest.raises(OrderKeyError) as e: generate_key_between(**kwargs) assert e.value.args[0] == expected.args[0] else: @@ -124,31 +128,30 @@ def test_readme_examples_single_key(): def test_readme_examples_multiple_keys(): # Insert 3 at the beginning - keys = generate_n_keys_between(None, None, n=3) + keys = generate_n_keys_between(None, None, number_of_keys=3) assert keys == ['a0', 'a1', 'a2'] # Insert 3 after 1st - keys = generate_n_keys_between('a0', None, n=3) + keys = generate_n_keys_between('a0', None, number_of_keys=3) assert keys == ['a1', 'a2', 'a3'] # Insert 3 before 1st - keys = generate_n_keys_between(None, 'a0', n=3) + keys = generate_n_keys_between(None, 'a0', number_of_keys=3) assert keys == ['Zx', 'Zy', 'Zz'] # Insert 3 in between 2nd and 3rd. Midpoint - keys = generate_n_keys_between('a1', 'a2', n=3) + keys = generate_n_keys_between('a1', 'a2', number_of_keys=3) assert keys == ['a1G', 'a1V', 'a1l'] def test_readme_examples_validate_order_key(): - from fractional_indexing import validate_order_key, FIError + from fractional_indexing import validate_order_key, OrderKeyError validate_order_key('a0') - try: + with pytest.raises(OrderKeyError) as e: validate_order_key('foo') - except FIError as e: - print(e) # fractional_indexing.FIError: invalid order key: foo + assert str(e.value) == 'invalid order key: foo' def test_readme_examples_custom_base(): @@ -156,4 +159,4 @@ def test_readme_examples_custom_base(): assert generate_key_between(None, None, digits=BASE_95_DIGITS) == 'a ' assert generate_key_between('a ', None, digits=BASE_95_DIGITS) == 'a!' assert generate_key_between(None, 'a ', digits=BASE_95_DIGITS) == 'Z~' - assert generate_n_keys_between('a ', 'a!', n=3, digits=BASE_95_DIGITS) == ['a 8', 'a P', 'a h'] + assert generate_n_keys_between('a ', 'a!', number_of_keys=3, digits=BASE_95_DIGITS) == ['a 8', 'a P', 'a h'] diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..3849008 --- /dev/null +++ b/utils.py @@ -0,0 +1,161 @@ +# fractional_indexing/utils.py + +import decimal +from typing import Optional + +from .exceptions import OrderKeyError + +BASE_62_DIGITS: str = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' + + +def round_half_up(value: float) -> int: + """Round a float to the nearest integer, rounding halves up.""" + return int( + decimal.Decimal(str(value)).quantize( + decimal.Decimal('1'), + rounding=decimal.ROUND_HALF_UP + ) + ) + + +def validate_integer(order_key: str) -> None: + """Validate that the length of the integer part of the order key is correct.""" + if len(order_key) != get_integer_length(order_key[0]): + raise OrderKeyError(f'Invalid integer part of order key: {order_key}') + + +def get_integer_length(first_char: str) -> int: + """Return the length of the integer part based on the first character.""" + if 'a' <= first_char <= 'z': + return ord(first_char) - ord('a') + 2 + elif 'A' <= first_char <= 'Z': + return ord('Z') - ord(first_char) + 2 + raise OrderKeyError('Invalid order key head: ' + first_char) + + +def get_integer_part(order_key: str) -> str: + """Extract the integer part of the order key.""" + integer_part_length = get_integer_length(order_key[0]) + if integer_part_length > len(order_key): + raise OrderKeyError(f'Invalid order key: {order_key}') + return order_key[:integer_part_length] + + +def validate_order_key(order_key: str, digits: str = BASE_62_DIGITS) -> None: + """Check the validity of an order key.""" + zero = digits[0] + smallest_valid_key = 'A' + (zero * 26) + + if order_key == smallest_valid_key: + raise OrderKeyError(f'Invalid order key: {order_key}') + + integer_part = get_integer_part(order_key) + fractional_part = order_key[len(integer_part):] + + if fractional_part and fractional_part[-1] == zero: + raise OrderKeyError(f'Invalid order key: {order_key}') + + +def find_middle_key(start_key: str, end_key: Optional[str], digits: str) -> str: + """ + Calculate the midpoint between two order keys. + `start_key` must be lexicographically less than `end_key`. + No trailing zeros allowed in the order key. + """ + zero = digits[0] + + if end_key is not None and start_key >= end_key: + raise OrderKeyError(f'{start_key} >= {end_key}') + + if (start_key and start_key[-1] == zero) or (end_key and end_key[-1] == zero): + raise OrderKeyError('Trailing zero in order key') + + if end_key: + common_prefix_len = 0 + for char_start, char_end in zip(start_key.ljust(len(end_key), zero), end_key): + if char_start == char_end: + common_prefix_len += 1 + continue + break + + if common_prefix_len > 0: + return end_key[:common_prefix_len] + find_middle_key( + start_key[common_prefix_len:], end_key[common_prefix_len:], digits + ) + + # Different first digits or lack of digit + digit_a = digits.index(start_key[0]) if start_key else 0 + digit_b = digits.index(end_key[0]) if end_key else len(digits) + + if digit_b - digit_a > 1: + min_digit = round_half_up(0.5 * (digit_a + digit_b)) + return digits[min_digit] + + if end_key and len(end_key) > 1: + return end_key[:1] + + return digits[digit_a] + find_middle_key(start_key[1:], None, digits) + + +def increment_integer(integer_str: str, digits: str) -> Optional[str]: + """Increment the integer part of the order key.""" + zero = digits[0] + validate_integer(integer_str) + + head, *digits_list = integer_str + has_carry_over = True + + for i in reversed(range(len(digits_list))): + current_digit = digits.index(digits_list[i]) + 1 + if current_digit == len(digits): + digits_list[i] = zero + else: + digits_list[i] = digits[current_digit] + has_carry_over = False + break + + if has_carry_over: + if head == 'Z': + return 'a' + zero + if head == 'z': + return None + next_head = chr(ord(head) + 1) + if next_head > 'a': + digits_list.append(zero) + else: + digits_list.pop() + return next_head + ''.join(digits_list) + + return head + ''.join(digits_list) + + +def decrement_integer(integer_str: str, digits: str) -> Optional[str]: + """Decrement the integer part of the order key.""" + validate_integer(integer_str) + + head, *digits_list = integer_str + requires_borrow = True + + for i in reversed(range(len(digits_list))): + current_digit = digits.index(digits_list[i]) - 1 + + if current_digit == -1: + digits_list[i] = digits[-1] + else: + digits_list[i] = digits[current_digit] + requires_borrow = False + break + + if requires_borrow: + if head == 'a': + return 'Z' + digits[-1] + if head == 'A': + return None + next_head = chr(ord(head) - 1) + if next_head < 'Z': + digits_list.append(digits[-1]) + else: + digits_list.pop() + return next_head + ''.join(digits_list) + + return head + ''.join(digits_list) From f523d19ace2c867dccf2835711de475219ce8b8b Mon Sep 17 00:00:00 2001 From: Georges Kfoury Date: Wed, 16 Oct 2024 18:05:53 +0300 Subject: [PATCH 5/5] Fixed imports to be local, not relative, and fixed circular imports - moved handlers to main due to circular imports - adjusted imports to be local, not relative --- __init__.py | 4 +-- fractional_indexing.py | 80 +++++++++++++++++++++++++++++++++++++----- handlers.py | 76 --------------------------------------- tests.py | 11 +++--- utils.py | 2 +- 5 files changed, 80 insertions(+), 93 deletions(-) delete mode 100644 handlers.py diff --git a/__init__.py b/__init__.py index 4975a8c..3fdc7ef 100644 --- a/__init__.py +++ b/__init__.py @@ -1,10 +1,10 @@ # fractional_indexing/__init__.py -from .fractional_indexing import ( +from fractional_indexing import ( generate_key_between, generate_n_keys_between ) -from .exceptions import OrderKeyError +from exceptions import OrderKeyError __all__ = [ 'generate_key_between', diff --git a/fractional_indexing.py b/fractional_indexing.py index 7bd08c8..4d6ffd3 100644 --- a/fractional_indexing.py +++ b/fractional_indexing.py @@ -3,17 +3,61 @@ from math import floor from typing import Optional, List -from .exceptions import OrderKeyError -from .utils import BASE_62_DIGITS, validate_order_key -from .handlers import ( - handle_end_key_only_case, - handle_start_key_only_case, - handle_both_keys_case, - handle_generate_n_keys_with_end_none, - handle_generate_n_keys_with_start_none +from exceptions import OrderKeyError +from utils import BASE_62_DIGITS, validate_order_key +from utils import ( + get_integer_part, + find_middle_key, + decrement_integer, + increment_integer, ) +# GENERATE KEY BETWEEN HANDLERS AND FUNCTION +def handle_end_key_only_case(end_key: str, digits: str) -> str: + """Handle the case when only `end_key` is provided.""" + zero = digits[0] + integer_part = get_integer_part(end_key) + fractional_part = end_key[len(integer_part):] + if integer_part == 'A' + (zero * 26): + return integer_part + find_middle_key('', fractional_part, digits) + if integer_part < end_key: + return integer_part + decremented = decrement_integer(integer_part, digits) + if decremented is None: + raise OrderKeyError('Cannot decrement anymore') + return decremented + + +def handle_start_key_only_case(start_key: str, digits: str) -> str: + """Handle the case when only `start_key` is provided.""" + integer_part = get_integer_part(start_key) + fractional_part = start_key[len(integer_part):] + incremented = increment_integer(integer_part, digits) + return integer_part + find_middle_key(fractional_part, None, digits) if incremented is None else incremented + + +def handle_both_keys_case(start_key: str, end_key: str, digits: str) -> str: + """Handle the case when both `start_key` and `end_key` are provided.""" + start_int_part = get_integer_part(start_key) + start_frac_part = start_key[len(start_int_part):] + end_int_part = get_integer_part(end_key) + end_frac_part = end_key[len(end_int_part):] + + if start_int_part == end_int_part: + return start_int_part + find_middle_key(start_frac_part, end_frac_part, digits) + + incremented = increment_integer(start_int_part, digits) + + if incremented is None: + raise OrderKeyError('Cannot increment anymore') + + if incremented < end_key: + return incremented + + return start_int_part + find_middle_key(start_frac_part, None, digits) + + def generate_key_between(start_key: Optional[str], end_key: Optional[str], digits: str = BASE_62_DIGITS) -> str: """ Generate an order key that lies between `start_key` and `end_key`. @@ -41,6 +85,26 @@ def generate_key_between(start_key: Optional[str], end_key: Optional[str], digit return handle_both_keys_case(start_key, end_key, digits) +# GENERATE N KEYS BETWEEN HANDLERS AND FUNCTION +def handle_generate_n_keys_with_end_none(start_key: Optional[str], number_of_keys: int, digits: str) -> List[str]: + """Handle case when generating keys with `end_key` as None.""" + current_key = generate_key_between(start_key, None, digits) + result = [current_key] + for _ in range(number_of_keys - 1): + current_key = generate_key_between(current_key, None, digits) + result.append(current_key) + return result + + +def handle_generate_n_keys_with_start_none(end_key: Optional[str], number_of_keys: int, digits: str) -> List[str]: + """Handle case when generating keys with `start_key` as None.""" + current_key = generate_key_between(None, end_key, digits) + result = [current_key] + for _ in range(number_of_keys - 1): + current_key = generate_key_between(None, current_key, digits) + result.append(current_key) + return list(reversed(result)) + def generate_n_keys_between(start_key: Optional[str], end_key: Optional[str], number_of_keys: int, digits: str = BASE_62_DIGITS) -> List[str]: """ Generate `number_of_keys` distinct order keys between `start_key` and `end_key`. diff --git a/handlers.py b/handlers.py deleted file mode 100644 index f32fc82..0000000 --- a/handlers.py +++ /dev/null @@ -1,76 +0,0 @@ -# fractional_indexing/handlers.py - -from typing import Optional, List - -from .exceptions import OrderKeyError -from .utils import ( - get_integer_part, - find_middle_key, - decrement_integer, - increment_integer, - generate_key_between -) - - -def handle_end_key_only_case(end_key: str, digits: str) -> str: - """Handle the case when only `end_key` is provided.""" - zero = digits[0] - integer_part = get_integer_part(end_key) - fractional_part = end_key[len(integer_part):] - if integer_part == 'A' + (zero * 26): - return integer_part + find_middle_key('', fractional_part, digits) - if integer_part < end_key: - return integer_part - decremented = decrement_integer(integer_part, digits) - if decremented is None: - raise OrderKeyError('Cannot decrement anymore') - return decremented - - -def handle_start_key_only_case(start_key: str, digits: str) -> str: - """Handle the case when only `start_key` is provided.""" - integer_part = get_integer_part(start_key) - fractional_part = start_key[len(integer_part):] - incremented = increment_integer(integer_part, digits) - return integer_part + find_middle_key(fractional_part, None, digits) if incremented is None else incremented - - -def handle_both_keys_case(start_key: str, end_key: str, digits: str) -> str: - """Handle the case when both `start_key` and `end_key` are provided.""" - start_int_part = get_integer_part(start_key) - start_frac_part = start_key[len(start_int_part):] - end_int_part = get_integer_part(end_key) - end_frac_part = end_key[len(end_int_part):] - - if start_int_part == end_int_part: - return start_int_part + find_middle_key(start_frac_part, end_frac_part, digits) - - incremented = increment_integer(start_int_part, digits) - - if incremented is None: - raise OrderKeyError('Cannot increment anymore') - - if incremented < end_key: - return incremented - - return start_int_part + find_middle_key(start_frac_part, None, digits) - - -def handle_generate_n_keys_with_end_none(start_key: Optional[str], number_of_keys: int, digits: str) -> List[str]: - """Handle case when generating keys with `end_key` as None.""" - current_key = generate_key_between(start_key, None, digits) - result = [current_key] - for _ in range(number_of_keys - 1): - current_key = generate_key_between(current_key, None, digits) - result.append(current_key) - return result - - -def handle_generate_n_keys_with_start_none(end_key: Optional[str], number_of_keys: int, digits: str) -> List[str]: - """Handle case when generating keys with `start_key` as None.""" - current_key = generate_key_between(None, end_key, digits) - result = [current_key] - for _ in range(number_of_keys - 1): - current_key = generate_key_between(None, current_key, digits) - result.append(current_key) - return list(reversed(result)) diff --git a/tests.py b/tests.py index e251e09..b1d1f9c 100644 --- a/tests.py +++ b/tests.py @@ -7,7 +7,7 @@ generate_key_between, generate_n_keys_between, ) -from .utils import validate_order_key +from utils import validate_order_key BASE_95_DIGITS = ' !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~' @@ -45,7 +45,7 @@ def test_generate_key_between(a: Optional[str], b: Optional[str], expected) -> N if isinstance(expected, OrderKeyError): with pytest.raises(OrderKeyError) as e: generate_key_between(a, b) - assert e.value.args[0] == expected.args[0] + assert e.value.args[0].lower() == expected.args[0].lower() else: act = generate_key_between(a, b) print(f'exp: {expected}') @@ -54,6 +54,7 @@ def test_generate_key_between(a: Optional[str], b: Optional[str], expected) -> N assert act == expected + @pytest.mark.parametrize(['a', 'b', 'n', 'expected'], [ (None, None, 5, 'a0 a1 a2 a3 a4'), ('a4', None, 10, 'a5 a6 a7 a8 a9 b00 b01 b02 b03 b04'), @@ -94,7 +95,7 @@ def test_base95_digits(a: Optional[str], b: Optional[str], expected: str) -> Non if isinstance(expected, OrderKeyError): with pytest.raises(OrderKeyError) as e: generate_key_between(**kwargs) - assert e.value.args[0] == expected.args[0] + assert e.value.args[0].lower() == expected.args[0].lower() else: act = generate_key_between(**kwargs) print() @@ -145,13 +146,11 @@ def test_readme_examples_multiple_keys(): def test_readme_examples_validate_order_key(): - from fractional_indexing import validate_order_key, OrderKeyError - validate_order_key('a0') with pytest.raises(OrderKeyError) as e: validate_order_key('foo') - assert str(e.value) == 'invalid order key: foo' + assert str(e.value).lower() == 'invalid order key: foo'.lower() def test_readme_examples_custom_base(): diff --git a/utils.py b/utils.py index 3849008..b9c3442 100644 --- a/utils.py +++ b/utils.py @@ -3,7 +3,7 @@ import decimal from typing import Optional -from .exceptions import OrderKeyError +from exceptions import OrderKeyError BASE_62_DIGITS: str = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'