Skip to content

Commit

Permalink
Merge pull request #34 from lsst/tickets/DM-39756-revert
Browse files Browse the repository at this point in the history
DM-39756: Revert "Use importlib.resources"
  • Loading branch information
timj authored Jul 8, 2023
2 parents 71c565e + fc27b25 commit c1f99d0
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 94 deletions.
24 changes: 12 additions & 12 deletions python/lsst/alert/packet/bin/validateAvroRoundTrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,18 @@ def main():
schema_major, schema_minor = lsst.alert.packet.get_latest_schema_version()
else:
schema_major, schema_minor = args.schema_version.split(".")

with lsst.alert.packet.get_schema_path(schema_major, schema_minor) as schema_root:
alert_schema = lsst.alert.packet.Schema.from_file(
os.path.join(schema_root,
schema_filename(schema_major, schema_minor)),
)
if args.input_data:
input_data = args.input_data
else:
input_data = os.path.join(schema_root, "sample_data", SAMPLE_FILENAME)
with open(input_data) as f:
json_data = json.load(f)
schema_root = lsst.alert.packet.get_schema_path(schema_major, schema_minor)

alert_schema = lsst.alert.packet.Schema.from_file(
os.path.join(schema_root,
schema_filename(schema_major, schema_minor)),
)
if args.input_data:
input_data = args.input_data
else:
input_data = os.path.join(schema_root, "sample_data", SAMPLE_FILENAME)
with open(input_data) as f:
json_data = json.load(f)

# Load difference stamp if included
stamp_size = 0
Expand Down
45 changes: 15 additions & 30 deletions python/lsst/alert/packet/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,21 @@
"""Routines for working with Avro schemas.
"""

import contextlib
import io
import os.path
import pkg_resources
from pathlib import PurePath
from importlib import resources

import fastavro

__all__ = ["get_schema_root", "get_latest_schema_version", "get_schema_path",
"Schema", "get_path_to_latest_schema"]


def _get_ref(*args):
"""Return the package resource file path object.
Parameters are relative to lsst.alert.packet.
"""
return resources.files("lsst.alert.packet").joinpath(*args)


@contextlib.contextmanager
def get_schema_root():
"""Return the root of the directory within which schemas are stored.
Returned as a context manager yielding the path to the root.
"""
with resources.as_file(_get_ref("schema")) as f:
yield str(f)
return pkg_resources.resource_filename(__name__, "schema")


def get_latest_schema_version():
Expand All @@ -63,14 +50,12 @@ def get_latest_schema_version():
The minor version number.
"""
with _get_ref("schema", "latest.txt").open("rb") as fh:
val = fh.read()
val = pkg_resources.resource_string(__name__, "schema/latest.txt")
clean = val.strip()
major, minor = clean.split(b".", 1)
return int(major), int(minor)


@contextlib.contextmanager
def get_schema_path(major, minor):
"""Get the path to a package resource directory housing alert schema
definitions.
Expand All @@ -88,11 +73,13 @@ def get_schema_path(major, minor):
Path to the directory containing the schemas.
"""
with resources.as_file(_get_ref("schema", str(major), str(minor))) as f:
yield str(f)

# Note that as_posix() is right here, since pkg_resources
# always uses slash-delimited paths, even on Windows.
path = PurePath(f"schema/{major}/{minor}/")
return pkg_resources.resource_filename(__name__, path.as_posix())


@contextlib.contextmanager
def get_path_to_latest_schema():
"""Get the path to the primary schema file for the latest schema.
Expand All @@ -103,8 +90,8 @@ def get_path_to_latest_schema():
"""

major, minor = get_latest_schema_version()
with get_schema_path(major, minor) as schema_path:
yield (PurePath(schema_path) / f"lsst.v{major}_{minor}.alert.avsc").as_posix()
schema_path = PurePath(get_schema_path(major, minor))
return (schema_path / f"lsst.v{major}_{minor}.alert.avsc").as_posix()


def resolve_schema_definition(to_resolve, seen_names=None):
Expand Down Expand Up @@ -321,16 +308,14 @@ def from_file(cls, filename=None):
if filename is None:
major, minor = get_latest_schema_version()
root_name = f"lsst.v{major}_{minor}.alert"
with get_schema_path(major, minor) as schema_path:
filename = os.path.join(
schema_path,
root_name + ".avsc",
)
schema_definition = fastavro.schema.load_schema(filename)
filename = os.path.join(
get_schema_path(major, minor),
root_name + ".avsc",
)
else:
root_name = PurePath(filename).stem
schema_definition = fastavro.schema.load_schema(filename)

schema_definition = fastavro.schema.load_schema(filename)
if hasattr(fastavro.schema._schema, 'SCHEMA_DEFS'):
# Old fastavro gives a back a list if it recursively loaded more
# than one file, otherwise a dict.
Expand Down
20 changes: 9 additions & 11 deletions python/lsst/alert/packet/schemaRegistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,13 @@ def from_filesystem(cls, root=None, schema_root="lsst.v5_0.alert"):
"""
from .schema import Schema
from .schema import get_schema_root

with get_schema_root() as default_root:
if not root:
root = default_root
registry = cls()
schema_root_file = schema_root + ".avsc"
for root, dirs, files in os.walk(root, followlinks=False):
if schema_root_file in files:
schema = Schema.from_file(os.path.join(root, schema_root_file))
version = ".".join(root.split("/")[-2:])
registry.register_schema(schema, version)
if not root:
root = get_schema_root()
registry = cls()
schema_root_file = schema_root + ".avsc"
for root, dirs, files in os.walk(root, followlinks=False):
if schema_root_file in files:
schema = Schema.from_file(os.path.join(root, schema_root_file))
version = ".".join(root.split("/")[-2:])
registry.register_schema(schema, version)
return registry
11 changes: 5 additions & 6 deletions test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,11 @@ def setUp(self):
"""
self.test_schema_version = get_latest_schema_version()
self.test_schema = Schema.from_file()
with get_schema_path(*self.test_schema_version) as schema_path:
sample_json_path = posixpath.join(
schema_path, "sample_data", "alert.json",
)
with open(sample_json_path, "r") as f:
self.sample_alert = json.load(f)
sample_json_path = posixpath.join(
get_schema_path(*self.test_schema_version), "sample_data", "alert.json",
)
with open(sample_json_path, "r") as f:
self.sample_alert = json.load(f)

def _mock_alerts(self, n):
"""Return a list of alerts with mock values, matching
Expand Down
6 changes: 2 additions & 4 deletions test/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,15 @@ class SchemaRootTestCase(unittest.TestCase):
"""

def test_get_schema_root(self):
with get_schema_root() as schema_root:
self.assertTrue(os.path.isdir(schema_root))
self.assertTrue(os.path.isdir(get_schema_root()))


class PathLatestSchemTestCase(unittest.TestCase):
"""Test for get_path_to_latest_schema().
"""

def test_path_latest_schema(self):
with get_path_to_latest_schema() as schema_path:
self.assertTrue(os.path.isfile(schema_path))
self.assertTrue(os.path.isfile(get_path_to_latest_schema()))


class ResolveTestCase(unittest.TestCase):
Expand Down
60 changes: 29 additions & 31 deletions test/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,14 @@ def test_example_json(self):
no_data = ("1.0",) # No example data is available.

for version in self.registry.known_versions:
with get_schema_root() as schema_root:
path = path_to_sample_data(schema_root, version, "alert.json")
schema = self.registry.get_by_version(version) # noqa: F841
if version in no_data:
self.assertFalse(os.path.exists(path))
else:
with open(path, "r") as f:
data = json.load(f)
self.assertTrue(self.registry.get_by_version(version).validate(data))
path = path_to_sample_data(get_schema_root(), version, "alert.json")
schema = self.registry.get_by_version(version) # noqa: F841
if version in no_data:
self.assertFalse(os.path.exists(path))
else:
with open(path, "r") as f:
data = json.load(f)
self.assertTrue(self.registry.get_by_version(version).validate(data))

def test_example_avro(self):
"""Test that example data in Avro format can be loaded by the schema.
Expand All @@ -62,28 +61,27 @@ def test_example_avro(self):
bad_versions = ("2.0",) # This data is known not to parse.

for version in self.registry.known_versions:
with get_schema_root() as schema_root:
path = path_to_sample_data(schema_root, version,
"fakeAlert.avro")
schema = self.registry.get_by_version(version)
path = path_to_sample_data(get_schema_root(), version,
"fakeAlert.avro")
schema = self.registry.get_by_version(version)

if version in no_data:
self.assertFalse(os.path.exists(path))
else:
with open(path, "rb") as f:
if version in bad_versions:
with self.assertRaises(RuntimeError):
schema.retrieve_alerts(f)
else:
retrieved_schema, alerts = schema.retrieve_alerts(f)
if version in no_data:
self.assertFalse(os.path.exists(path))
else:
with open(path, "rb") as f:
if version in bad_versions:
with self.assertRaises(RuntimeError):
schema.retrieve_alerts(f)
else:
retrieved_schema, alerts = schema.retrieve_alerts(f)

fastavro_keys = list(schema.definition.keys())
for key in fastavro_keys:
if '__' in key and '__len__' not in key:
schema.definition.pop(key)
fastavro_keys = list(schema.definition.keys())
for key in fastavro_keys:
if '__' in key and '__len__' not in key:
schema.definition.pop(key)

self.assertEqual(retrieved_schema, schema,
f"schema not equal on version={version}")
for idx, alert in enumerate(alerts):
self.assertTrue(schema.validate(alert),
f"failed to validate version={version}, alert idx={idx}")
self.assertEqual(retrieved_schema, schema,
f"schema not equal on version={version}")
for idx, alert in enumerate(alerts):
self.assertTrue(schema.validate(alert),
f"failed to validate version={version}, alert idx={idx}")

0 comments on commit c1f99d0

Please sign in to comment.