Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up unintest #55

Merged
merged 2 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 71 additions & 34 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import shutil
from pathlib import Path

import pytest

from pyencrypt.encrypt import encrypt_file, encrypt_key, generate_so_file
Expand All @@ -19,21 +22,59 @@ def pytest_configure(config):
)


@pytest.fixture(scope="function")
def file_and_loader(request, tmp_path_factory):
tmp_path = tmp_path_factory.mktemp("file")
def _common_loader(path, license=False):
open("tmp", "a+").write(f"common:path: {path}" + "\n")

key = generate_aes_key()
cipher_key, d, n = encrypt_key(key)
loader_path = generate_so_file(cipher_key, d, n, path, license=license)

work_dir = loader_path.parent
work_dir.joinpath("loader.py").unlink()
work_dir.joinpath("loader.c").unlink()
work_dir.joinpath("loader_origin.py").unlink()
return key, loader_path.absolute()

file_marker = request.node.get_closest_marker("file")
file_name = file_marker.kwargs.get("name")
function_name = file_marker.kwargs.get("function")
code = file_marker.kwargs.get("code")

@pytest.fixture(scope="session")
def common_loader(tmp_path_factory):
open("tmp", "a+").write("make loader" + "\n")
tmp_path = tmp_path_factory.mktemp("loader")
return _common_loader(tmp_path, False)


@pytest.fixture(scope="session")
def common_loader_with_license(tmp_path_factory):
tmp_path = tmp_path_factory.mktemp("loader_with_license")
return _common_loader(tmp_path, True)


@pytest.fixture(scope="function")
def file_and_loader(request, common_loader, common_loader_with_license, tmp_path):
license_marker = request.node.get_closest_marker("license")
license, kwargs = False, {}
if license_marker is not None:
kwargs = license_marker.kwargs
license = kwargs.pop("enable", True)

if license:
key, loader_path = common_loader_with_license
else:
key, loader_path = common_loader

# copy loader -> tmp_path
loader_path = (
Path(shutil.copytree(loader_path.parent, tmp_path / "encrypted"))
/ loader_path.name
)
if license:
generate_license_file(key.decode(), loader_path.parent, **kwargs)

file_marker = request.node.get_closest_marker("file")
file_name = file_marker.kwargs.get("name")
function_name = file_marker.kwargs.get("function")
code = file_marker.kwargs.get("code")

file_path = tmp_path / f"{file_name}.py"
file_path.touch()
file_path.write_text(
Expand All @@ -45,38 +86,43 @@ def {function_name}():
),
encoding="utf-8",
)
# generate loader.so
key = generate_aes_key()

new_path = file_path.with_suffix(".pye")
encrypt_file(file_path, key.decode(), new_path=new_path)
file_path.unlink()
cipher_key, d, n = encrypt_key(key)
loader_path = generate_so_file(cipher_key, d, n, file_path.parent, license=license)
work_dir = loader_path.parent
work_dir.joinpath("loader.py").unlink()
work_dir.joinpath("loader.c").unlink()
work_dir.joinpath("loader_origin.py").unlink()

# License
license and generate_license_file(key.decode(), work_dir, **kwargs)
return (new_path, loader_path)


@pytest.fixture(scope="function")
def package_and_loader(request, tmp_path_factory):
pkg_path = tmp_path_factory.mktemp("package")

file_marker = request.node.get_closest_marker("package")
package_name = file_marker.kwargs.get("name")
function_name = file_marker.kwargs.get("function")
code = file_marker.kwargs.get("code")
def package_and_loader(request, common_loader, common_loader_with_license, tmp_path):
pkg_path = tmp_path

license_marker = request.node.get_closest_marker("license")
license, kwargs = False, {}
if license_marker is not None:
kwargs = license_marker.kwargs
license = kwargs.pop("enable", True)

if license:
key, loader_path = common_loader_with_license
else:
key, loader_path = common_loader

# copy loader -> tmp_path
loader_path = (
Path(shutil.copytree(loader_path.parent, tmp_path / "encrypted"))
/ loader_path.name
)

if license:
generate_license_file(key.decode(), loader_path.parent, **kwargs)

file_marker = request.node.get_closest_marker("package")
package_name = file_marker.kwargs.get("name")
function_name = file_marker.kwargs.get("function")
code = file_marker.kwargs.get("code")

current = pkg_path
for dir_name in package_name.split(".")[:-1]:
current = current.joinpath(dir_name)
Expand All @@ -95,16 +141,7 @@ def {function_name}():
)

new_path = file_path.with_suffix(".pye")
key = generate_aes_key()
encrypt_file(file_path, key, new_path=new_path)
encrypt_file(file_path, key.decode(), new_path=new_path)
file_path.unlink()

cipher_key, d, n = encrypt_key(key)
loader_path = generate_so_file(cipher_key, d, n, pkg_path, license)
work_dir = loader_path.parent
work_dir.joinpath("loader.py").unlink()
work_dir.joinpath("loader.c").unlink()
work_dir.joinpath("loader_origin.py").unlink()
# License
license and generate_license_file(key.decode(), work_dir, **kwargs)
return pkg_path, loader_path
6 changes: 3 additions & 3 deletions tests/test_encrypt.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_can_encrypt(path, expected):
assert can_encrypt(path) == expected


class TestGenarateSoFile:
class TestGenerateSoFile:
def setup_method(self, method):
if method.__name__ == "test_generate_so_file_default_path":
shutil.rmtree(
Expand All @@ -53,7 +53,7 @@ def setup_method(self, method):
)
def test_generate_so_file(self, key, tmp_path):
cipher_key, d, n = encrypt_key(key)
assert generate_so_file(cipher_key, d, n, tmp_path)
assert generate_so_file(cipher_key, d, n, tmp_path).exists()
assert (tmp_path / "encrypted" / "loader.py").exists() is True
assert (tmp_path / "encrypted" / "loader_origin.py").exists() is True
if sys.platform.startswith("win"):
Expand All @@ -76,7 +76,7 @@ def test_generate_so_file(self, key, tmp_path):
)
def test_generate_so_file_default_path(self, key):
cipher_key, d, n = encrypt_key(key)
assert generate_so_file(cipher_key, d, n)
assert generate_so_file(cipher_key, d, n).exists()
assert (Path(os.getcwd()) / "encrypted" / "loader.py").exists() is True
assert (Path(os.getcwd()) / "encrypted" / "loader_origin.py").exists() is True
if sys.platform.startswith("win"):
Expand Down