Skip to content

Commit 80e53a2

Browse files
committed
unit tests passing
1 parent 6cddd43 commit 80e53a2

File tree

8 files changed

+60
-118
lines changed

8 files changed

+60
-118
lines changed

src/palimpzest/query/processor/__init__.py

Whitespace-only changes.

src/palimpzest/utils/model_helpers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import os
2-
from typing import List, Optional
32

43
from palimpzest.constants import Model
54

65

7-
def get_vision_models() -> List[Model]:
6+
def get_vision_models() -> list[Model]:
87
"""
98
Return the set of vision models which the system has access to based on the set of environment variables.
109
"""
@@ -18,7 +17,7 @@ def get_vision_models() -> List[Model]:
1817
return models
1918

2019

21-
def get_models(include_vision: Optional[bool] = False) -> List[Model]:
20+
def get_models(include_vision: bool = False) -> list[Model]:
2221
"""
2322
Return the set of models which the system has access to based on the set environment variables.
2423
"""

tests/pytest/test_cost_model.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,8 @@ def test_compute_operator_estimates(
5252
)
5353
def test_estimate_plan_cost(self, simple_plan_sample_execution_data, physical_plan, expected_cost_est_results):
5454
# register a fake dataset
55-
dataset_id = "foobar"
5655
vals = [1, 2, 3, 4, 5, 6]
57-
DataDirectory().register_dataset(
58-
vals=vals,
59-
dataset_id=dataset_id,
60-
)
56+
DataDirectory().get_or_register_memory_source(vals=vals)
6157
input_cardinality = len(vals)
6258

6359
# TODO: if we test with a plan other than the simple test plan; this will break

tests/pytest/test_datasource.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1-
import pytest
21
from palimpzest.core.data.datasources import MemorySource
32
from palimpzest.core.elements.records import DataRecord
4-
from palimpzest.core.lib.schemas import Schema, List
3+
from palimpzest.core.lib.fields import Field
4+
from palimpzest.core.lib.schemas import Schema, SourceRecord
55
from palimpzest.query.operators.datasource import MarshalAndScanDataOp
6-
from palimpzest.core.lib.schemas import SourceRecord
6+
7+
8+
class List(Schema):
9+
value = Field(desc="List item")
10+
711

812
def test_marshal_and_scan_memory_source():
913
# Create test data

tests/pytest/test_datasources.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
# write tests for src/palimpzest/core/data/datasources.py
22

33
import os
4-
import pytest
4+
55
import pandas as pd
6+
import pytest
7+
68
from palimpzest.core.data.datasources import (
7-
FileSource,
8-
TextFileDirectorySource,
9+
FileSource,
10+
HTMLFileDirectorySource,
911
ImageFileDirectorySource,
1012
MemorySource,
11-
HTMLFileDirectorySource
13+
TextFileDirectorySource,
1214
)
13-
from palimpzest.core.lib.fields import ListField
1415
from palimpzest.core.elements.records import DataRecord
15-
from palimpzest.core.lib.schemas import List, Schema, Number, File, TextFile, WebPage, ImageFile
16+
from palimpzest.core.lib.schemas import File, TextFile, WebPage
17+
1618

1719
@pytest.fixture
1820
def temp_text_file():

tests/pytest/test_execution_no_cache.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@
1717
)
1818

1919

20+
@pytest.fixture
21+
def optimizer():
22+
return Optimizer(policy=MaxQuality(), cost_model=CostModel())
23+
24+
@pytest.fixture
25+
def config():
26+
return QueryProcessorConfig(nocache=True)
27+
28+
2029
@pytest.mark.parametrize(
2130
argnames=("query_processor",),
2231
argvalues=[
@@ -97,16 +106,7 @@ class TestParallelExecutionNoCache:
97106
],
98107
indirect=True,
99108
)
100-
101-
@pytest.fixture
102-
def optimizer(self):
103-
return Optimizer(policy=MaxQuality(), cost_model=CostModel())
104-
105-
@pytest.fixture
106-
def config(self):
107-
return QueryProcessorConfig(nocache=True)
108-
109-
def test_execute_full_plan(self, mocker, query_processor, dataset, optimizer, config, physical_plan, expected_records, side_effect):
109+
def test_execute_full_plan(self, mocker, query_processor, optimizer, config, dataset, physical_plan, expected_records, side_effect):
110110
"""
111111
This test executes the given
112112
"""

tests/pytest/test_physical.py

Lines changed: 19 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
import os
44
import sys
5-
import pytest
5+
6+
from palimpzest.core.lib.fields import NumericField, StringField
67
from palimpzest.core.lib.schemas import Schema
78
from palimpzest.query.operators.physical import PhysicalOperator
8-
from palimpzest.core.lib.fields import StringField, NumericField
99

1010
sys.path.append("./tests/")
1111
sys.path.append("./tests/refactor-tests/")
@@ -16,11 +16,15 @@
1616
load_env()
1717

1818

19-
2019
class SimpleSchema(Schema):
2120
name = StringField(desc="The name of the person")
2221
age = NumericField(desc="The age of the person")
2322

23+
class SimpleSchemaTwo(Schema):
24+
name = StringField(desc="The name of the person")
25+
age = NumericField(desc="The age of the person")
26+
height = NumericField(desc="The height of the person")
27+
2428
def test_physical_operator_init():
2529
"""Test basic initialization of PhysicalOperator"""
2630

@@ -31,7 +35,7 @@ def test_physical_operator_init():
3135
logical_op_id="logical1",
3236
verbose=True
3337
)
34-
38+
3539
assert op.output_schema == SimpleSchema
3640
assert op.input_schema == SimpleSchema
3741
assert op.depends_on == ["op1", "op2"]
@@ -41,44 +45,44 @@ def test_physical_operator_init():
4145
def test_physical_operator_equality():
4246
"""Test equality comparison between PhysicalOperators"""
4347
schema1 = SimpleSchema()
44-
schema2 = SimpleSchema()
45-
48+
schema2 = SimpleSchemaTwo()
49+
4650
op1 = PhysicalOperator(output_schema=schema1)
4751
op2 = PhysicalOperator(output_schema=schema1)
4852
op3 = PhysicalOperator(output_schema=schema2, verbose=True)
49-
53+
5054
assert op1 == op2 # Same output schema
5155
assert op1 == op1 # Same instance
5256
assert op1 == op1.copy() # Copy should be equal
5357
assert op1 != op3 # Different parameters
5458

5559
def test_physical_operator_str():
5660
"""Test string representation of PhysicalOperator"""
57-
61+
5862
op = PhysicalOperator(
5963
output_schema=SimpleSchema,
6064
input_schema=SimpleSchema
6165
)
62-
66+
6367
str_rep = str(op)
6468
assert "SimpleSchema -> PhysicalOperator -> SimpleSchema" in str_rep
6569
assert "age, name" in str_rep
6670

6771
def test_physical_operator_id_generation():
6872
"""Test operator ID generation and hashing"""
6973
op = PhysicalOperator(output_schema=SimpleSchema)
70-
74+
7175
# Test that op_id is initially None
7276
assert op.op_id is None
73-
77+
7478
# Get op_id and verify it's generated
7579
op_id = op.get_op_id()
7680
assert op_id is not None
7781
assert isinstance(op_id, str)
78-
82+
7983
# Test that subsequent calls return the same id
8084
assert op.get_op_id() == op_id
81-
85+
8286
# Test that hash is based on op_id
8387
assert hash(op) == int(op_id, 16)
8488

@@ -91,72 +95,12 @@ def test_physical_operator_copy():
9195
logical_op_id="logical1",
9296
verbose=True
9397
)
94-
98+
9599
copied = original.copy()
96-
100+
97101
assert copied is not original # Different instances
98102
assert copied == original # But equal in content
99103
assert copied.get_op_id() == original.get_op_id() # Same op_id
100104
assert copied.depends_on == original.depends_on
101105
assert copied.logical_op_id == original.logical_op_id
102106
assert copied.verbose == original.verbose
103-
104-
105-
# TODO: uncomment once I understand what is supposed to be happening with
106-
# ParallelConvertFromCandidateOp and ParallelFilterCandidateOp (I don't
107-
# have these on my branch; possibly came from another branch)
108-
109-
# def test_convert(email_schema):
110-
# """Test the physical operators equality sign"""
111-
# remove_cache()
112-
113-
# params = {
114-
# "output_schema": email_schema,
115-
# "input_schema": File,
116-
# "model": pz.Model.GPT_4o_MINI,
117-
# "cardinality": "oneToOne",
118-
# }
119-
120-
# # simpleConvert = pz.Convert(**params)
121-
# parallelConvert = pz.ParallelConvertFromCandidateOp(**params, streaming="")
122-
# monolityhConvert = pz.ConvertOp(**params)
123-
124-
# assert parallelConvert == parallelConvert
125-
# assert monolityhConvert == monolityhConvert
126-
# assert parallelConvert != monolityhConvert
127-
128-
# print(str(parallelConvert))
129-
# print(str(monolityhConvert))
130-
131-
# a = parallelConvert.copy()
132-
# b = monolityhConvert.copy()
133-
# assert a == parallelConvert
134-
# assert b == monolityhConvert
135-
# assert a != b
136-
137-
# def test_filter(email_schema):
138-
# """Test the physical operators filter"""
139-
# remove_cache()
140-
141-
# params = {
142-
# "output_schema": email_schema,
143-
# "input_schema": email_schema,
144-
# "filter": pz.Filter("This is a sample filter"),
145-
# }
146-
147-
# # simpleConvert = pz.Convert(**params)
148-
# parallelFilter = pz.ParallelFilterCandidateOp(**params, streaming="")
149-
# monoFilter = pz.NonLLMFilter(**params)
150-
151-
# assert parallelFilter == parallelFilter
152-
# assert monoFilter == monoFilter
153-
# assert parallelFilter != monoFilter
154-
155-
# print(str(parallelFilter))
156-
# print(str(monoFilter))
157-
158-
# a = parallelFilter.copy()
159-
# b = monoFilter.copy()
160-
# assert a == parallelFilter
161-
# assert b == monoFilter
162-
# assert a != b

tests/pytest/test_rules.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
11
import pytest
2-
from palimpzest.query.optimizer.rules import (
3-
PushDownFilter, NonLLMConvertRule, LLMConvertBondedRule,
4-
BasicSubstitutionRule, NonLLMFilterRule
5-
)
6-
from palimpzest.query.optimizer.primitives import LogicalExpression, Group
7-
from palimpzest.query.operators.logical import (
8-
ConvertScan, FilteredScan, BaseScan
9-
)
10-
from palimpzest.query.operators.filter import Filter
2+
113
from palimpzest.core.lib.schemas import Schema, StringField
4+
from palimpzest.query.operators.logical import BaseScan
5+
from palimpzest.query.optimizer.primitives import LogicalExpression
6+
from palimpzest.query.optimizer.rules import BasicSubstitutionRule
127

138

149
@pytest.fixture
@@ -30,24 +25,26 @@ def test_substitute_methods(base_scan_op):
3025
logical_expr = LogicalExpression(
3126
operator=base_scan_op,
3227
input_group_ids=[],
33-
input_fields=set(),
34-
generated_fields=set(["id", "text"]),
28+
input_fields={},
29+
generated_fields={"some_id": StringField(desc="id"), "text": StringField(desc="text")},
30+
depends_on_field_names=set(),
3531
group_id=1
3632
)
37-
33+
3834
# Apply the BasicSubstitutionRule
3935
physical_exprs = BasicSubstitutionRule.substitute(logical_expr, verbose=False)
40-
36+
4137
# Verify the substitution
4238
assert len(physical_exprs) == 1
4339
physical_expr = list(physical_exprs)[0]
44-
40+
4541
# Check that the operator was correctly converted to MarshalAndScanDataOp
4642
assert physical_expr.operator.__class__.__name__ == "MarshalAndScanDataOp"
47-
43+
4844
# Verify that the important properties were preserved
4945
assert physical_expr.operator.dataset_id == base_scan_op.dataset_id
5046
assert physical_expr.input_group_ids == logical_expr.input_group_ids
5147
assert physical_expr.input_fields == logical_expr.input_fields
5248
assert physical_expr.generated_fields == logical_expr.generated_fields
53-
assert physical_expr.group_id == logical_expr.group_id
49+
assert physical_expr.depends_on_field_names == logical_expr.depends_on_field_names
50+
assert physical_expr.group_id == logical_expr.group_id

0 commit comments

Comments
 (0)