Fix #708 Check fastmath flags #1068

Merged
merged 25 commits into from
Apr 8, 2025

Conversation

NimaSarajpoor
Collaborator

@NimaSarajpoor NimaSarajpoor commented Jan 28, 2025

This PR is to fix #708. An initial inspection was done by @seanlaw in this comment. I am copying that list here for transparency and better tracking.

  • aamp._compute_diagonal - P, PL, PR contains np.inf and p can be np.inf
  • aamp._aamp - P, PL, PR contains np.inf
  • core._sliding_dot_product - Should be okay
  • core._calculate_squared_distance_profile - Should be okay (Returned value in D_squared might be np.inf but no arithmetic operation)
  • core.calculate_distance_profile - Should be okay (Returned value in D_squared might be np.inf)
  • core._p_norm_distance_profile - Should be okay (can p be np.inf? not supported. See: Add support for p=np.inf for non-normalized p-norm distance #1071 )
  • core._mass Should be okay
  • core._apply_exclusion_zone - val contains np.inf
  • core._count_diagonal_ndist - Should be okay
  • core._get_array_ranges - Should be okay
  • core._get_ranges - Should be okay
  • core._total_diagonal_ndists - Should be okay
  • fastmath._add_assoc - Should be okay
  • maamp._compute_multi_p_norm - p could possibly be np.inf (not supported. See: Add support for p=np.inf for non-normalized p-norm distance #1071). p_norm array can contain np.inf
  • mstump._compute_multi_D - Might be okay??
  • scraamp._compute_PI - P_NORM contains np.inf
  • scraamp._prescraamp - P_NORM contains np.inf
  • scrump._compute_PI - references np.inf values, so likely bad
  • scrump._prescrump - P_squared is np.inf
  • stump._compute_diagonal - ρ, ρL, and ρR contain np.inf
  • stump._stump - ρ, ρL, and ρR contain np.inf

@NimaSarajpoor NimaSarajpoor changed the title Fix #708 Fix #708 Check fastmath flags Jan 28, 2025

codecov bot commented Jan 28, 2025

Codecov Report

Attention: Patch coverage is 1.78571% with 110 lines in your changes missing coverage. Please review.

Project coverage is 96.62%. Comparing base (9504301) to head (8520288).
Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
fastmath.py 0.00% 110 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1068      +/-   ##
==========================================
- Coverage   97.31%   96.62%   -0.70%     
==========================================
  Files          93       93              
  Lines       15239    15376     +137     
==========================================
+ Hits        14830    14857      +27     
- Misses        409      519     +110     


@NimaSarajpoor
Collaborator Author

After going through core._p_norm_distance_profile, I noticed that we should not set p to np.inf. In fact, we should not set p to np.inf in any of the non-normalized functions. This is because the usual Minkowski distance calculation does not apply when p is np.inf, which is a limiting case. If we ignore that "limit", we can see the issue in the following example:

import numpy as np

arr = np.array([0.9, 1.0, 1.1])
out = np.power(arr, np.inf)  # array([0., 1., inf])

Any value less than 1.0 becomes 0.0, and any value greater than 1.0 becomes np.inf, which gives us the wrong output. So, we can add a note to the docstring that p=np.inf is not supported.

Maybe we should open another issue for this case, known as the Chebyshev distance, and then decide whether to add support for it. The distance can be computed via a rolling-max approach. I am not sure if it can be added to the current code base without hurting the design.
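For reference, the p=np.inf (Chebyshev) case reduces to a max over absolute differences. A minimal sketch (the helper name is hypothetical and not part of this PR; this naive version skips the rolling-max optimization):

```python
import numpy as np


def _chebyshev_distance_profile(Q, T):
    # Hypothetical helper (not in STUMPY): the Chebyshev (p = np.inf)
    # distance is the limit of the Minkowski distance as p -> inf,
    # i.e., max(|Q - T[i : i + m]|) for each subsequence of T.
    # Naive O(n * m) loop; a rolling-max approach would be faster.
    m = len(Q)
    D = np.empty(len(T) - m + 1, dtype=np.float64)
    for i in range(D.shape[0]):
        D[i] = np.max(np.abs(Q - T[i : i + m]))
    return D


Q = np.array([0.0, 1.0])
T = np.array([0.0, 1.0, 3.0])
D = _chebyshev_distance_profile(Q, T)  # array([0., 2.])
```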


For now, we should not consider p=np.inf when we make a decision about the fastmath flag. We can revise the flags later, if needed, once the functions support the case p=np.inf.

@seanlaw
Contributor

seanlaw commented Feb 2, 2025

@NimaSarajpoor I had noticed it as well when I looked the other day. I agree, we should not allow p = np.inf for now and simply ignore it (after adding a note to the docstring(s))

@NimaSarajpoor
Collaborator Author

@seanlaw

@NimaSarajpoor I had noticed it as well when I looked the other day. I agree, we should not allow p = np.inf for now and simply ignore it (after adding a note to the docstring(s))

Created the issue #1071

@seanlaw
Contributor

seanlaw commented Feb 7, 2025

@NimaSarajpoor Is this ready to be merged?

@NimaSarajpoor
Collaborator Author

NimaSarajpoor commented Feb 9, 2025

@seanlaw

@NimaSarajpoor Is this ready to be merged?

Not yet. I am trying to get the chains of caller-callees.

I have a script that produces a dictionary with (module_name, func_name) as the key and the list of callees as the value (some of the code is based on the work I did initially in #1025). The code is provided below. Do you think we should add it to STUMPY? In that case, I need to clean it up and push it. If not, I can just use it locally to get the caller-callee chains, list them here, and then check each chain's fastmath flags. The reason I am a bit hesitant to add it to STUMPY is that I am not sure whether we are going to make any exceptions. The goal is to catch a switch from fastmath == True to fastmath != True (or vice versa) within a caller-callee chain; but if we decide to allow that switch in some cases, then adding the check to STUMPY may not make sense.

Code
# in STUMPY's root directory 

import ast
import importlib
import pathlib


def _get_func_callees(node, so_far_callees):
    for n in ast.iter_child_nodes(node):
        if isinstance(n, ast.Call):
            obj = n.func
            if isinstance(obj, ast.Attribute):  # e.g., np.sum
                name = obj.attr 
            elif isinstance(obj, ast.Name):  # e.g., sum
                name = obj.id
            else:
                msg = f"The type {type(obj)} is not supported"
                raise ValueError(msg)

            so_far_callees.append(name)

        _get_func_callees(n, so_far_callees)


def get_func_callees(func_node):
    """
    For a given node of type ast.FunctionDef, visit all of its child nodes,
    and return a list of all of its callees
    """
    out = []
    _get_func_callees(func_node, so_far_callees=out)

    return out


def get_func_nodes(filepath):
    """
    For the given `filepath`, parse the file and return a list of
    its top-level `ast.FunctionDef` nodes
    """
    file_contents = ""
    with open(filepath, encoding="utf8") as f:
        file_contents = f.read()
    module = ast.parse(file_contents)

    func_nodes = [
        node for node in module.body if isinstance(node, ast.FunctionDef)
    ]

    return func_nodes


def get_callees():
    ignore = ["__init__.py", "__pycache__"]

    stumpy_path = pathlib.Path(__file__).parent / "stumpy"
    filepaths = sorted(f for f in pathlib.Path(stumpy_path).iterdir() if f.is_file())

    all_callees = {}
    for filepath in filepaths:
        file_name = filepath.name
        if (
            file_name not in ignore 
            and not file_name.startswith("gpu")
            and str(filepath).endswith(".py")
        ):
            module_name = file_name.replace(".py", "")
            module = importlib.import_module(f".{module_name}", package="stumpy")
            
            func_nodes = get_func_nodes(filepath)
            for node in func_nodes:
                all_callees[(module_name, node.name)] = get_func_callees(node)

    
    # clean all_callees to only include callees that are in stumpy
    all_stumpy_funcs = set(item[1] for item in all_callees.keys())

    out = {}
    for (module_name, func_name), callees in all_callees.items():
        lst = []
        for callee in callees:
            if callee in all_stumpy_funcs:
                lst.append(callee)
        out[(module_name, func_name)] = lst

    return out


out = get_callees()
print(out)

@NimaSarajpoor NimaSarajpoor mentioned this pull request Feb 9, 2025
@seanlaw
Contributor

seanlaw commented Feb 10, 2025

The reason I am a bit hesitant to add it to STUMPY is that I am not sure whether we are going to make any exceptions. The goal is to catch a switch from fastmath == True to fastmath != True (or vice versa) within a caller-callee chain; but if we decide to allow that switch in some cases, then adding the check to STUMPY may not make sense.

At this point, I don't anticipate allowing any exceptions. Since this is something tedious, I would prefer to automate it and add it to test.sh like all of the other checks. Naturally, I think we would now add this to the fastmath.py script, right?

One thing to consider is that ast may not easily allow you to traverse across different Python modules

@NimaSarajpoor
Collaborator Author

NimaSarajpoor commented Feb 11, 2025

@seanlaw

Since this is something tedious, I would prefer to automate it and add it to test.sh like all of the other checks

👍

Naturally, I think we would now add this to the fastmath.py script, right?

Right. This should be placed in ./fastmath.py, and then it can be used in the testing process via test.sh

One thing to consider is that ast may not easily allow you to traverse across different Python modules

Currently I am not trying to jump between modules. What I do is collect the one-level-deep callees of ALL stumpy functions; that's all I need to create the chain for a given caller. However, if I can find a tool that can jump from one module to a different module, then finding the chains should become easier. Going to look for it.
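For what it's worth, expanding the one-level-deep callee map into full chains only needs a depth-first walk. A sketch with a made-up mapping (`get_call_chains` and the example entries are hypothetical, and cycles are assumed absent):

```python
def get_call_chains(callees, root):
    # Expand a one-level-deep caller -> callees mapping into full
    # caller-callee chains rooted at `root`, via a depth-first walk.
    # Assumes the call graph has no cycles (true for STUMPY's njit funcs).
    children = callees.get(root, [])
    if not children:
        return [[root]]
    chains = []
    for child in children:
        for chain in get_call_chains(callees, child):
            chains.append([root] + chain)
    return chains


# Hypothetical example mapping (keys simplified to bare function names)
callees = {
    "_stump": ["_compute_diagonal", "_get_array_ranges"],
    "_compute_diagonal": ["_shift_insert_at_index"],
}
chains = get_call_chains(callees, "_stump")
# chains: [['_stump', '_compute_diagonal', '_shift_insert_at_index'],
#          ['_stump', '_get_array_ranges']]
```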

@seanlaw
Contributor

seanlaw commented Feb 11, 2025

Currently I am not trying to jump between modules. What I do is collect the one-level-deep callees of ALL stumpy functions; that's all I need to create the chain for a given caller. However, if I can find a tool that can jump from one module to a different module, then finding the chains should become easier. Going to look for it.

I have some ideas and will be able to share them soon

@seanlaw
Contributor

seanlaw commented Feb 11, 2025

@NimaSarajpoor While quite verbose, I believe that this will work nicely to generate a list of njit call stacks that is able to jump ACROSS modules:

import fastmath
import ast
import pathlib

class FunctionCallVisitor(ast.NodeVisitor):
    def __init__(self):
        super().__init__()
        self.module_names = []
        self.call_stack = []
        self.last_depth = 0
        self.out = []

        # Setup lists, dicts, and ast objects
        self.njit_funcs = fastmath.get_njit_funcs('stumpy')
        self.njit_modules = set(mod_name for mod_name, func_name in self.njit_funcs)
        self.njit_nodes = {}
        self.ast_modules = {}
        
        stumpy_path = pathlib.Path('__file__').parent / "stumpy"
        filepaths = sorted(f for f in pathlib.Path(stumpy_path).iterdir() if f.is_file())
        ignore = ["__init__.py", "__pycache__"]
        
        for filepath in filepaths:
            file_name = filepath.name
            if (
                file_name not in ignore 
                and not file_name.startswith("gpu")
                and str(filepath).endswith(".py")
            ):
                module_name = file_name.replace(".py", "")
                file_contents = ""
                with open(filepath, encoding="utf8") as f:
                    file_contents = f.read()
                self.ast_modules[module_name] = ast.parse(file_contents)
        
                for node in self.ast_modules[module_name].body:
                    if isinstance(node, ast.FunctionDef):
                        func_name = node.name
                        if (module_name, func_name) in self.njit_funcs:
                            self.njit_nodes[f'{module_name}.{func_name}'] = node


    def push_module(self, module_name):
        self.module_names.append(module_name)
        
    def pop_module(self):
        if self.module_names:
            self.module_names.pop()

    def push_call_stack(self, module_name, func_name):
        self.call_stack.append(f'{module_name}.{func_name}')
        
    def pop_call_stack(self):
        if self.call_stack:
            self.call_stack.pop()

    def goto_deeper_func(self, node):
        self.generic_visit(node)

    def goto_next_func(self, node):
        self.generic_visit(node)

    def push_out(self):
        unique = True
        for cs in self.out:
            if ' '.join(self.call_stack) in ' '.join(cs):
                unique = False
                break

        if unique:
            self.out.append(self.call_stack.copy())
        
    def visit_Call(self, node):
        callee_name = ast.unparse(node.func)

        if "." in callee_name:
            new_module_name, new_func_name = callee_name.split('.')[:2]

            if new_module_name in self.njit_modules:
                self.push_module(new_module_name)
        else:
            if self.module_names:
                new_module_name = self.module_names[-1]
                new_func_name = callee_name
                callee_name = f'{new_module_name}.{new_func_name}'

        if callee_name in self.njit_nodes.keys():
            callee_node = self.njit_nodes[callee_name]
            self.push_call_stack(new_module_name, new_func_name)
            self.goto_deeper_func(callee_node)
            self.pop_module()  # This line should be deleted!!!  See comments below
            self.push_out()
            self.pop_call_stack()

        self.goto_next_func(node)


def get_njit_call_stacks():
    visitor = FunctionCallVisitor()
    
    for module_name in visitor.njit_modules:
        visitor.push_module(module_name)
        
        for node in visitor.ast_modules[module_name].body:
            if isinstance(node, ast.FunctionDef):
                func_name = node.name
                if (module_name, func_name) in visitor.njit_funcs:
                    visitor.push_call_stack(module_name, func_name)
                    visitor.visit(node)
                    visitor.pop_call_stack()

        visitor.pop_module()

    return visitor.out


if __name__ == '__main__':
    for cs in get_njit_call_stacks():
        print(cs)

The output should be:

['core._calculate_squared_distance_profile', 'core._calculate_squared_distance']
['maamp._compute_multi_p_norm', 'core._apply_exclusion_zone']
['stump._compute_diagonal', 'core._shift_insert_at_index']
['stump._stump', 'core._count_diagonal_ndist']
['stump._stump', 'core._get_array_ranges']
['stump._stump', 'stump._compute_diagonal', 'core._shift_insert_at_index']
['stump._stump', 'core._merge_topk_ρI']
['aamp._compute_diagonal', 'core._shift_insert_at_index']
['aamp._aamp', 'core._count_diagonal_ndist']
['aamp._aamp', 'core._get_array_ranges']
['aamp._aamp', 'aamp._compute_diagonal', 'core._shift_insert_at_index']
['aamp._aamp', 'core._merge_topk_PI']
['scraamp._compute_PI', 'core._p_norm_distance_profile', 'core._sliding_dot_product']
['scraamp._compute_PI', 'core._apply_exclusion_zone']
['scraamp._compute_PI', 'core._shift_insert_at_index']
['scraamp._prescraamp', 'core._get_ranges']
['scraamp._prescraamp', 'core._merge_topk_PI']
['mstump._compute_multi_D', 'core._calculate_squared_distance_profile', 'core._calculate_squared_distance']
['mstump._compute_multi_D', 'core._apply_exclusion_zone']
['scrump._compute_PI', 'core._sliding_dot_product']
['scrump._compute_PI', 'core._calculate_squared_distance_profile', 'core._calculate_squared_distance']
['scrump._compute_PI', 'core._apply_exclusion_zone']
['scrump._compute_PI', 'core._shift_insert_at_index']
['scrump._prescrump', 'core._get_ranges']
['scrump._prescrump', 'core._merge_topk_PI']

One immediate observation is that our njit call stacks are very, very flat/shallow, which is a GREAT thing! In my head, I was dreading that we would find very deeply nested call stacks, so this is a pleasant surprise. Please verify that we haven't missed any edge cases.

Hopefully, you are able to take these call stacks and check the fastmath flags accordingly. Please let me know if you have any questions.

@NimaSarajpoor
Collaborator Author

NimaSarajpoor commented Feb 12, 2025

While quite verbose, I believe that this will work nicely to generate a list of njit call stacks that is able to jump ACROSS modules:

Please verify that we haven't missed any edge cases.

Hopefully, you are able to take these call stacks and check the fastmath flags accordingly

Thanks for sharing it!! I will go through it and compare its output with the output of my own script (a script that I used locally, on top of the one I shared in my previous comment, to get the caller-callee chains). Then I will work on checking the fastmath flags.

One immediate observation is that our njit call stacks are very, very flat/shallow, which is a GREAT thing!

YES! I noticed it too when I obtained the caller-callee chains. 👍

@NimaSarajpoor
Collaborator Author

NimaSarajpoor commented Mar 25, 2025

Using the code provided in #1068 (comment), I was able to detect call stacks that contain more than one distinct fastmath flag.

import importlib

if __name__ == '__main__':
    for cs in get_njit_call_stacks():
        cs_fastmath_flag = []
        for f in cs:
            mod_name, func_name = f.split('.')
            module = importlib.import_module(f".{mod_name}", package='stumpy')
            func = getattr(module, func_name)
            flag = func.targetoptions['fastmath']
            cs_fastmath_flag.append(flag)

        outer_flag = cs_fastmath_flag[0]
        if not all(flag == outer_flag for flag in cs_fastmath_flag):
            lst = [(f, flag) for f, flag in zip(cs, cs_fastmath_flag)]
            print('Inconsistent fastmath flags in a call stack (top-to-bottom):')
            print(*lst, sep='\n')
            print('='*80)  

Running the code in the branch check_njit_fastmath results in the following outputs:

Inconsistent fastmath flags in a call stack (top-to-bottom):
('aamp._aamp', {'contract', 'nsz', 'afn', 'arcp', 'reassoc'})
('core._count_diagonal_ndist', True)
================================================================================
Inconsistent fastmath flags in a call stack (top-to-bottom):
('aamp._aamp', {'contract', 'nsz', 'afn', 'arcp', 'reassoc'})
('core._get_array_ranges', True)
================================================================================
Inconsistent fastmath flags in a call stack (top-to-bottom):
('scrump._compute_PI', {'contract', 'nsz', 'afn', 'arcp', 'reassoc'})
('core._sliding_dot_product', True)
================================================================================
Inconsistent fastmath flags in a call stack (top-to-bottom):
('scrump._prescrump', {'contract', 'nsz', 'afn', 'arcp', 'reassoc'})
('core._get_ranges', True)
================================================================================
Inconsistent fastmath flags in a call stack (top-to-bottom):
('stump._stump', {'contract', 'nsz', 'afn', 'arcp', 'reassoc'})
('core._count_diagonal_ndist', True)
================================================================================
Inconsistent fastmath flags in a call stack (top-to-bottom):
('stump._stump', {'contract', 'nsz', 'afn', 'arcp', 'reassoc'})
('core._get_array_ranges', True)
================================================================================
Inconsistent fastmath flags in a call stack (top-to-bottom):
('scraamp._compute_PI', {'contract', 'nsz', 'afn', 'arcp', 'reassoc'})
('core._p_norm_distance_profile', True)
('core._sliding_dot_product', True)
================================================================================
Inconsistent fastmath flags in a call stack (top-to-bottom):
('scraamp._prescraamp', {'contract', 'nsz', 'afn', 'arcp', 'reassoc'})
('core._get_ranges', True)
================================================================================

I will add the code to ./fastmath.py and run it via test.sh so that we can see the error. In most of the cases above, the callee function does not seem to be computationally heavy, so it should be okay to replace fastmath=True with a set of flags that matches the fastmath flag of the very first caller. There are a couple of cases where the callee is _sliding_dot_product; changing its fastmath flag may affect performance. Will check it.
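The raising version of the consistency check could be as simple as the following sketch (hypothetical names; it operates on plain lists/dicts so it can run without importing STUMPY, whereas the real check in ./fastmath.py would pull the flags from the njit functions themselves):

```python
def check_fastmath_consistency(call_stacks, fastmath_flags):
    # For each call stack (outermost caller first), raise if any
    # callee's fastmath flag differs from the outer caller's flag
    for cs in call_stacks:
        outer_flag = fastmath_flags[cs[0]]
        for func in cs[1:]:
            if fastmath_flags[func] != outer_flag:
                msg = f"Inconsistent fastmath flags in call stack {cs}"
                raise ValueError(msg)


# Hypothetical example mirroring one of the inconsistent cases above
flags = {
    "aamp._aamp": {"contract", "nsz", "afn", "arcp", "reassoc"},
    "core._count_diagonal_ndist": True,
}
try:
    check_fastmath_consistency(
        [["aamp._aamp", "core._count_diagonal_ndist"]], flags
    )
    raised = False
except ValueError:
    raised = True
```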

In addition, I will use the code in #1025 locally to check the output of get_njit_call_stacks().

@NimaSarajpoor
Collaborator Author

In addition, I will use the code in #1025 locally to check the output of get_njit_call_stacks().

I used the ugly code in #1025 to see whether it could produce an njit caller-callee chain that is not in the output of get_njit_call_stacks(). I got the following list.

['core.calculate_distance_profile', 'core._calculate_squared_distance_profile']
['core._mass', 'core.calculate_distance_profile']
['core._update_incremental_PI', 'core._apply_exclusion_zone']
['core._update_incremental_PI', 'core._shift_insert_at_index']
['scrump._prescrump', 'scrump._compute_PI']
['scraamp._prescraamp', 'scraamp._compute_PI']

Note that this list is just one level deep, meaning it ignores the callee(s) of a callee. I will find and fix the root cause of the discrepancy.

@seanlaw
Contributor

seanlaw commented Mar 29, 2025

I got the following list.

I think I've spotted a small bug in #1068 (comment). In the method visit_Call:

        if callee_name in self.njit_nodes.keys():
            callee_node = self.njit_nodes[callee_name]
            self.push_call_stack(new_module_name, new_func_name)
            self.goto_deeper_func(callee_node)
            self.pop_module()  ###### THIS LINE SHOULD BE REMOVED!!!
            self.push_out()
            self.pop_call_stack()

If you delete the offending self.pop_module() line then you should get:

['scrump._compute_PI', 'core._sliding_dot_product']
['scrump._compute_PI', 'core._calculate_squared_distance_profile', 'core._calculate_squared_distance']
['scrump._compute_PI', 'core._apply_exclusion_zone']
['scrump._compute_PI', 'core._shift_insert_at_index']
['scrump._prescrump', 'core._get_ranges']
['scrump._prescrump', 'core._merge_topk_PI']
['maamp._compute_multi_p_norm', 'core._apply_exclusion_zone']
['scraamp._compute_PI', 'core._p_norm_distance_profile', 'core._sliding_dot_product']
['scraamp._compute_PI', 'core._apply_exclusion_zone']
['scraamp._compute_PI', 'core._shift_insert_at_index']
['scraamp._prescraamp', 'core._get_ranges']
['scraamp._prescraamp', 'core._merge_topk_PI']
['core.calculate_distance_profile', 'core._calculate_squared_distance_profile', 'core._calculate_squared_distance']
['core._mass', 'core.calculate_distance_profile', 'core._calculate_squared_distance_profile', 'core._calculate_squared_distance']
['core._update_incremental_PI', 'core._apply_exclusion_zone']
['core._update_incremental_PI', 'core._shift_insert_at_index']
['mstump._compute_multi_D', 'core._calculate_squared_distance_profile', 'core._calculate_squared_distance']
['mstump._compute_multi_D', 'core._apply_exclusion_zone']
['aamp._compute_diagonal', 'core._shift_insert_at_index']
['aamp._aamp', 'core._count_diagonal_ndist']
['aamp._aamp', 'core._get_array_ranges']
['aamp._aamp', 'core._merge_topk_PI']
['stump._compute_diagonal', 'core._shift_insert_at_index']
['stump._stump', 'core._count_diagonal_ndist']
['stump._stump', 'core._get_array_ranges']
['stump._stump', 'core._merge_topk_ρI']

I hope that helps.

@NimaSarajpoor
Collaborator Author

NimaSarajpoor commented Mar 30, 2025

Removing the line self.pop_module() did help cover some of those missing cases. However, four cases are missing now:

['aamp._aamp', 'aamp._compute_diagonal']  # NEW
['stump._stump', 'stump._compute_diagonal']  # NEW
['scrump._prescrump', 'scrump._compute_PI']
['scraamp._prescraamp', 'scraamp._compute_PI']

Regarding the first two cases, I noticed that the code now considers the module core for _compute_diagonal, i.e., it checks whether core._compute_diagonal is one of the njit functions. IIUC, the reason behind this new behaviour is that as we go deeper into a callee, its module name is appended to the list via the method push_module, but when we move on to the next callee, the latest module name in that list is no longer the correct one.


Potential solution
I noticed that if we pop only when we push, we should be good:

def visit_Call(self, node):
    callee_name = ast.unparse(node.func)

    flag_push = False    #### NEW
    if "." in callee_name:
        new_module_name, new_func_name = callee_name.split('.')[:2]

        if new_module_name in self.njit_modules:
            self.push_module(new_module_name)
            flag_push = True    #### NEW
    else:
        if self.module_names:
            new_module_name = self.module_names[-1]
            new_func_name = callee_name
            callee_name = f'{new_module_name}.{new_func_name}'
    
    if callee_name in self.njit_nodes.keys():
        callee_node = self.njit_nodes[callee_name]
        self.push_call_stack(new_module_name, new_func_name)
        
        self.goto_deeper_func(callee_node)
        if flag_push:    #### NEW
            self.pop_module()
        self.push_out()
        self.pop_call_stack()
    
    self.goto_next_func(node)

@seanlaw
Btw, could you please give me a hint on how visit_Call is used in the script? Are we overriding a method of a parent class? I couldn't find anything regarding visit_Call in the ast docs.

I think I've found the answer to my question:

visit(node)
Visit a node. The default implementation calls the method called self.visit_classname where classname is the name of the node class
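A tiny standalone illustration of that dispatch (not STUMPY code; ast.unparse requires Python >= 3.9):

```python
import ast


class CallNameCollector(ast.NodeVisitor):
    # visit() dispatches every ast.Call node to this class's
    # visit_Call method purely because of the method's name
    def __init__(self):
        self.names = []

    def visit_Call(self, node):
        self.names.append(ast.unparse(node.func))
        self.generic_visit(node)  # keep walking into nested calls


tree = ast.parse("np.sum(np.abs(x))")
collector = CallNameCollector()
collector.visit(tree)
# collector.names == ['np.sum', 'np.abs']
```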

@seanlaw
Contributor

seanlaw commented Mar 30, 2025

I noticed that if we pop only when we push, we should be good:

Yeah, pushing and popping should be paired actions, so as long as we keep track of them correctly (i.e., without mis-counting), I think that is the right thing to do.

@seanlaw
Contributor

seanlaw commented Mar 30, 2025

Yeah, pushing and popping should be paired actions, so as long as we keep track of them correctly (i.e., without mis-counting), I think that is the right thing to do.

Instead of calling it flag_push, I think it is better to call it module_changed so that it will read:

if module_changed:
    self.pop_module()

@seanlaw
Contributor

seanlaw commented Apr 5, 2025

@NimaSarajpoor How are things going here? I have two questions:

  1. Is our AST traversal function "good" now (i.e., sufficient to traverse all njit call stacks)? Did anything major change in the class or class methods?
  2. How many config.STUMPY_FASTMATH_TRUE do we have left after we've made all of the changes? My thinking is that at some point, if there are too few left, should we simply make everything config.STUMPY_FASTMATH_FLAGS instead?

@NimaSarajpoor
Collaborator Author

NimaSarajpoor commented Apr 6, 2025

How are things going here?

No particular concern from my side. The script you shared helped a lot!!

Is our AST traversal function "good" now (i.e., sufficient to traverse all njit call stacks)?

Noticed no missing call stacks. Will check again and provide an update if I notice otherwise.

Did anything major change in the class or class methods?

Not really. There were only two changes:

  • Added the boolean variable module_changed to track a push to module_names, and only pop when there was a push
  • Removed self.last_depth = 0 since it is not used in the script

How many config.STUMPY_FASTMATH_TRUE do we have left after we've made all of the changes? My thinking is that at some point, if there are too few left, should we simply make everything config.STUMPY_FASTMATH_FLAGS instead?

Good point. I checked and there are currently 7 functions with fastmath=True in the branch check_njit_fastmath.

core._sliding_dot_product
core._p_norm_distance_profile
core._mass
core._count_diagonal_ndist
core._get_array_ranges
core._get_ranges
core._total_diagonal_ndists

I checked the first three cases and noticed a < 5% performance hit after changing the flag to config.STUMPY_FASTMATH_FLAGS. The results are shown below. A few notes:

(1) The y-value is a ratio, which is "the running time when fastmath=config.STUMPY_FASTMATH_TRUE" divided by "the running time when fastmath=config.STUMPY_FASTMATH_FLAGS". So, a data point with y-value < 1 means a performance hit.

(2) In all cases: len(Q)==64

(3) Regarding core._mass: After changing its flag, the performance hit was measured for core.mass

[Plots omitted: running time ratio (default/new) vs. n_T for core._sliding_dot_product, core._p_norm_distance_profile, and core.mass]

I used the following code:

import importlib
import numpy as np
import time

from matplotlib import pyplot as plt
from stumpy import cache, config, core, fastmath


def measure_running_time(func, n_T_values, n_Q=64, n_iter=1000):
    """
    func: A function that accepts Q and T
    n_T_values: an array that contains `n_T`, i.e. len(T), values
    """
    seed = 0
    np.random.seed(seed)
    
    running_times = np.zeros(len(n_T_values), dtype=np.float64)
    for i, n_T in enumerate(n_T_values):
        T = np.random.rand(n_T)
        Q = np.random.rand(n_Q)
        
        func(Q, T)  # dummy call 

        lst = []
        for _ in range(n_iter):
            tic = time.time()    
            func(Q, T)
            toc = time.time()
            lst.append(toc - tic)
        
        running_times[i] = np.mean(lst)
    
    return running_times

### inputs
n_T_values = np.arange(1, 10 + 1) * 10000
n_Q = 64
n_iter = 1000

module_name = 'core'
func_name = '_sliding_dot_product' 
njit_func = '_sliding_dot_product'
# NOTE: For mass, set `func_name` to `mass`, and set the `njit_func` to `_mass`

### import function 
module = importlib.import_module(f".{module_name}", package="stumpy")
func = getattr(module, func_name)

### default
config._reset()
cache._recompile()
running_times_default = measure_running_time(func, n_T_values, n_Q=n_Q, n_iter=n_iter)
        

### change fastmath flag to `config.STUMPY_FASTMATH_FLAGS`
fastmath._set(module_name, njit_func, config.STUMPY_FASTMATH_FLAGS)
cache._recompile()
running_times_new = measure_running_time(func, n_T_values, n_Q=n_Q, n_iter=n_iter)

### plot
plt.figure(figsize=(20, 5))
plt.title(f'running time ratio: default/new \n for `{module_name}.{func_name}`')

r = running_times_default / running_times_new
r[r > 1.05] = 1.05
r[r < 0.95] = 0.95
plt.plot(np.arange(len(n_T_values)), r, marker='o')
plt.axhline(1, color='red', linestyle='--')
plt.ylabel('Running time Ratio (default / new)')

plt.xticks(ticks=np.arange(len(n_T_values)), labels=[str(i//1000) + '_000' for i in n_T_values])
plt.xlabel("n_T")
plt.grid()
plt.show()

@seanlaw
Contributor

seanlaw commented Apr 6, 2025

Please let me know if this is ready to be merged

@NimaSarajpoor
Collaborator Author

@seanlaw

Please let me know if this is ready to be merged

My thinking is that at some point, if there are too few left, should we simply make everything config.STUMPY_FASTMATH_FLAGS instead?

According to the results provided in my previous comment, I think it should be okay to change config.STUMPY_FASTMATH_TRUE to config.STUMPY_FASTMATH_FLAGS. Do you have any concern regarding that?

@seanlaw
Contributor

seanlaw commented Apr 6, 2025

Do you have any concern regarding that?

Not really. However, wouldn't we need to remove a lot of unused config.STUMPY_FASTMATH_TRUE code? I would say just leave it for now.

@NimaSarajpoor
Copy link
Collaborator Author

GitHub Actions raised this error:

_______________ test_mpdist_snippets_s_with_isconstant[3-3-9-T0] _______________

T = array([ 293.67128677,  308.19248823,  -90.46204887, -484.08056803,
       -815.73526712, -404.95991025, -591.06946033,...15185, -230.50234171, -360.47856789,  -41.299737  ,
       -955.80732529,   47.51877429,  372.92107172, -390.49869207])
m = 9, k = 3, s = 3

    @pytest.mark.parametrize("T", test_data)
    @pytest.mark.parametrize("m", m)
    @pytest.mark.parametrize("k", k)
    @pytest.mark.parametrize("s", s)
    def test_mpdist_snippets_s_with_isconstant(T, m, k, s):
        isconstant_custom_func = functools.partial(
            naive.isconstant_func_stddev_threshold, quantile_threshold=0.05
        )
        (
            ref_snippets,
            ref_indices,
            ref_profiles,
            ref_fractions,
            ref_areas,
            ref_regimes,
        ) = naive.mpdist_snippets(
            T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func
        )
        (
            cmp_snippets,
            cmp_indices,
            cmp_profiles,
            cmp_fractions,
            cmp_areas,
            cmp_regimes,
        ) = snippets(T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func)
    
>       npt.assert_almost_equal(
            ref_snippets, cmp_snippets, decimal=config.STUMPY_TEST_PRECISION
        )

tests/test_snippets.py:211: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/contextlib.py:81: in inner
    return func(*args, **kwds)
/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/contextlib.py:81: in inner
    return func(*args, **kwds)
/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/numpy/_utils/__init__.py:85: in wrapper
    return fun(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<function assert_array_almost_equal.<locals>.compare at 0x10ce03a60>, array([[-336.77194671, -854.63885204, -784.0031....1080978 ,   92.0933612 ,
        -213.1082644 ,  684.15146167,  639.03341797,  221.1559157 ,
          70.8424083 ]]))
kwds = {'err_msg': '', 'header': 'Arrays are not almost equal to 5 decimals', 'precision': 5, 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Arrays are not almost equal to 5 decimals
E           
E           Mismatched elements: 9 / 27 (33.3%)
E           Max absolute difference among violations: 1230.1028783
E           Max relative difference among violations: 7.87042798
E            ACTUAL: array([[-336.77195, -854.63885, -784.00315, -230.50234, -360.47857,
E                    -41.29974, -955.80733,   47.51877,  372.92107],
E                  [ 552.12966,  652.35062,  -73.57151,  601.59769, -180.18051,...
E            DESIRED: array([[-336.77195, -854.63885, -784.00315, -230.50234, -360.47857,
E                    -41.29974, -955.80733,   47.51877,  372.92107],
E                  [ 552.12966,  652.35062,  -73.57151,  601.59769, -180.18051,...

/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/contextlib.py:81: AssertionError
======================== 1 failed, 146 passed in 35.82s ========================
Error: Test execution encountered exit code 1

A similar error was previously reported in #1061 for the same test function. I am going to add the error above to that issue. I think it should be okay to re-run the GitHub Actions here.

@seanlaw ?
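As a side note, the failure above is a tolerance miss rather than a wrong answer: `npt.assert_almost_equal` with `decimal=5` rejects any absolute difference of `1.5e-5` or more. A toy example (made-up values, unrelated to the snippets data) reproduces the same failure mode:

```python
import numpy as np
import numpy.testing as npt

a = np.array([1.0, 2.0])
b = a + 2e-5  # perturbation just above the decimal=5 tolerance of 1.5e-5

try:
    npt.assert_almost_equal(a, b, decimal=5)
    failed = False
except AssertionError:
    failed = True

print(failed)  # True: a 2e-5 difference fails at decimal=5
```

Small reorderings of floating-point operations, such as those fastmath permits, can push borderline comparisons over a threshold like this.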

@NimaSarajpoor
Copy link
Collaborator Author

[Update]
The PR should be ready to merge.
Btw, I checked the output of the function get_njit_call_stacks again and noticed no issue.

@seanlaw seanlaw merged commit c707ad9 into stumpy-dev:main Apr 8, 2025
44 of 54 checks passed
@seanlaw
Copy link
Contributor

seanlaw commented Apr 8, 2025

Thank you @NimaSarajpoor for getting this over the finish line!

@NimaSarajpoor
Copy link
Collaborator Author

@seanlaw
Thank you for the support, and the lessons I learned!
