Cache does not take into consideration variables passed into prompt_func / parse_func #148

Closed
RyanMarten opened this issue Nov 19, 2024 · 7 comments · Fixed by #167

@RyanMarten
Contributor

Simple examples:

from bespokelabs import curator

x = 10

def prompt_func():
    return f"'Say x is {x}.' Don't say anything else."

prompter = curator.Prompter(prompt_func=prompt_func, model_name="gpt-4o-mini")

print(prompter().to_pandas())

The first run says x is 10.

Now set x = 20 and run it again. The call hits the cache and the response still says x is 10.

@RyanMarten
Contributor Author

Instead of hashing the function source, let's use dill pickling, which includes the variables.

From https://huggingface.co/docs/datasets/en/about_cache:

The hash is computed by dumping the object using a dill pickler and hashing the dumped bytes. The pickler recursively dumps all the variables used in your function, so any change you do to an object that is used in your function, will cause the hash to change.

@RyanMarten
Contributor Author

from bespokelabs import curator
import dill
import hashlib
from xxhash import xxh64
import inspect
for x in [10, 20]:
    print(f"x: {x}")
    def prompt_func():
        return f"'Say x is {x}.' Don't say anything else."

    # https://github.com/bespokelabsai/curator/blob/main/src/bespokelabs/curator/prompter/prompter.py#L231
    current_hash = xxh64(inspect.getsource(prompt_func)).hexdigest()
    dill_hash = hashlib.md5(dill.dumps(prompt_func)).hexdigest()

    print(f"Current Hash: {current_hash}")
    print(f"Dill Hash: {dill_hash}")

Actually, a plain dill.dumps does not pick up the value of x here. Both the current hash and the dill hash are identical for x = 10 and x = 20:

x: 10
Current Hash: 2f7a00618810a082
Dill Hash: 6f31f77bf56705c46756875ac129d62f
x: 20
Current Hash: 2f7a00618810a082
Dill Hash: 6f31f77bf56705c46756875ac129d62f
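
One possible way to make a hash react to the variables a function reads is to hash their values explicitly. The following is a minimal sketch using only the standard library. The `fingerprint` helper is hypothetical: it is not part of curator, dill, or datasets, and it is not the fix that was merged for this issue. It assumes the referenced values have a stable `repr`.

```python
import hashlib
import inspect


def fingerprint(func):
    """Hash a function's bytecode together with the values it reads.

    Hypothetical helper for illustration only.
    """
    closure_vars = inspect.getclosurevars(func)
    hasher = hashlib.sha256()
    # Compiled body: changes when the logic changes.
    hasher.update(func.__code__.co_code)
    # Literal constants (e.g. the pieces of an f-string). Nested code objects
    # are skipped because their repr contains a memory address.
    consts = [c for c in func.__code__.co_consts if not inspect.iscode(c)]
    hasher.update(repr(consts).encode())
    # Values of the global and enclosing-scope names the body refers to.
    for scope in (closure_vars.globals, closure_vars.nonlocals):
        for name in sorted(scope):
            hasher.update(f"{name}={scope[name]!r}".encode())
    return hasher.hexdigest()


x = 10


def prompt_func():
    return f"'Say x is {x}.' Don't say anything else."


hash_10 = fingerprint(prompt_func)
x = 20
hash_20 = fingerprint(prompt_func)
x = 10
hash_10_again = fingerprint(prompt_func)

print(hash_10 != hash_20)        # the hash follows the value of x
print(hash_10 == hash_10_again)  # and is stable for the same value
```

This only looks one level deep: values read by functions that `prompt_func` calls are not followed, and objects whose `repr` is unstable (for example, one that includes a memory address) would need a different serialization.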

@vutrung96
Contributor

Unrelated note: dill is also sensitive to formatting. The two definitions below differ only in the whitespace after return.

import hashlib
import inspect

import dill
from bespokelabs import curator
from xxhash import xxh64


def prompt_func():
    return "Say x is 10. Don't say anything else."

# https://github.com/bespokelabsai/curator/blob/main/src/bespokelabs/curator/prompter/prompter.py#L231
current_hash = xxh64(inspect.getsource(prompt_func)).hexdigest()
dill_hash = hashlib.md5(dill.dumps(prompt_func)).hexdigest()

print(f"Current Hash: {current_hash}")
print(f"Dill Hash: {dill_hash}")

def prompt_func():
    return    "Say x is 10. Don't say anything else."


current_hash = xxh64(inspect.getsource(prompt_func)).hexdigest()
dill_hash = hashlib.md5(dill.dumps(prompt_func)).hexdigest()

print(f"Current Hash: {current_hash}")
print(f"Dill Hash: {dill_hash}")
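
If formatting sensitivity is unwanted, one option is to hash a parsed form of the source rather than the raw text. The sketch below uses the standard library `ast` module for this. The `normalized_source_hash` name is hypothetical and this is not what curator or dill do. It takes the source as a string to stay self-contained.

```python
import ast
import hashlib


def normalized_source_hash(source: str) -> str:
    """Hash the AST of the source, ignoring whitespace and comments.

    Hypothetical helper for illustration only.
    """
    tree = ast.parse(source)
    # ast.dump omits line and column attributes by default, so layout
    # differences do not affect the result.
    return hashlib.sha256(ast.dump(tree).encode()).hexdigest()


compact = 'def prompt_func():\n    return "Say x is 10. Don\'t say anything else."\n'
spaced = 'def prompt_func():\n    return    "Say x is 10. Don\'t say anything else."\n'
changed = 'def prompt_func():\n    return "Say x is 20. Don\'t say anything else."\n'

raw_differs = (
    hashlib.sha256(compact.encode()).hexdigest()
    != hashlib.sha256(spaced.encode()).hexdigest()
)
same_after_normalizing = normalized_source_hash(compact) == normalized_source_hash(spaced)
content_change_detected = normalized_source_hash(compact) != normalized_source_hash(changed)

print(raw_differs, same_after_normalizing, content_change_detected)
```

Like hashing the raw source, this still ignores the values of variables the function reads, so it addresses only the formatting point.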

@vutrung96
Contributor

Hugging Face datasets caching works with a similar example:

import hashlib
import inspect

import dill
from xxhash import xxh64

from datasets import Dataset

ds = Dataset.from_dict({"x": [1, 2, 3]})
print(ds._fingerprint)
x = 20
def add_x(row):
    return {"x": row["x"] + x}

ds_1 = ds.map(add_x)
print(ds_1.to_pandas())
print(ds_1._fingerprint)

x = 10
def add_x(row):
    return {"x": row["x"] + x}
ds = Dataset.from_dict({"x": [1, 2, 3]})
ds_2 = ds.map(add_x)
print(ds_2.to_pandas())
print(ds_2._fingerprint)

@vutrung96
Contributor

vutrung96 commented Nov 19, 2024

Another example where HF works. Note that the fingerprints swap along with the loop order, while the plain dill hash stays the same in both runs.

import hashlib
import inspect

import dill
from xxhash import xxh64

from datasets import Dataset

ds = Dataset.from_dict({"x": [1, 2, 3]})

for x in [20, 10]:
    def add_x(row):
        return {"x": row["x"] + x}

    print(f"Dill Hash: {xxh64(dill.dumps(add_x)).hexdigest()}")


    ds_1 = ds.map(add_x)
    print(ds_1.to_pandas())
    print(ds_1._fingerprint)

Output:

(base) trung@trung-cpu:~/dcft_private$ python cache.py
Dill Hash: ea9257454344127e
Map: 100%|██████████████████████████████████████| 3/3 [00:00<00:00, 1106.00 examples/s]
    x
0  21
1  22
2  23
7665613465f05548
Dill Hash: ea9257454344127e
Map: 100%|██████████████████████████████████████| 3/3 [00:00<00:00, 1026.67 examples/s]
    x
0  11
1  12
2  13
ae42f566ba0d4462

Same script with the loop order reversed ([10, 20] instead of [20, 10]):

import hashlib
import inspect

import dill
from xxhash import xxh64

from datasets import Dataset

ds = Dataset.from_dict({"x": [1, 2, 3]})

for x in [10, 20]:
    def add_x(row):
        return {"x": row["x"] + x}

    print(f"Dill Hash: {xxh64(dill.dumps(add_x)).hexdigest()}")


    ds_1 = ds.map(add_x)
    print(ds_1.to_pandas())
    print(ds_1._fingerprint)

Output:

(base) trung@trung-cpu:~/dcft_private$ python cache.py
Dill Hash: ea9257454344127e
Map: 100%|███████████████████████████████████████| 3/3 [00:00<00:00, 891.39 examples/s]
    x
0  11
1  12
2  13
ae42f566ba0d4462
Dill Hash: ea9257454344127e
Map: 100%|██████████████████████████████████████| 3/3 [00:00<00:00, 1086.04 examples/s]
    x
0  21
1  22
2  23
7665613465f05548

@CharlieJCJ
Contributor

Same issue appears in #161
