diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 173926f..eb4902d 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -14,6 +14,6 @@ jobs: uses: actions/checkout@v4 - name: Run Labeler - uses: crazy-max/ghaction-github-labeler@v5.0.0 + uses: crazy-max/ghaction-github-labeler@v5.1.0 with: skip-delete: true diff --git a/.github/workflows/publish-docker.yml b/.github/workflows/publish-docker.yml new file mode 100644 index 0000000..42c88bc --- /dev/null +++ b/.github/workflows/publish-docker.yml @@ -0,0 +1,64 @@ +name: Create and publish a Docker image + +# Configures this workflow to run every time a release is published +on: + release: + types: [published] + +# Defines two custom environment variables for the workflow. +# These are used for the Container registry domain, and a name for the Docker image that this workflow builds. +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + +# There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu. +jobs: + build-and-push-image: + runs-on: ubuntu-latest + + # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job. + permissions: + contents: read + packages: write + attestations: write + id-token: write + + steps: + # Necessary for buildx + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup QEMU + uses: docker/setup-qemu-action@v3 + + # Set up BuildKit Docker container builder to be able to build + # multi-platform images and export cache + # https://github.com/docker/setup-buildx-action + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to the Container registry + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata (tags, labels) for Docker + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + + # This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages. + # It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository. + # It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step. + - name: Build and push Docker image + id: build-and-push + uses: docker/build-push-action@v6 + with: + context: . 
+ push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + platforms: linux/amd64,linux/arm64 diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 04c527b..49e77a5 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -18,22 +18,32 @@ jobs: python-version: ["3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v4 + - name: Check out the repository + uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + + - name: Install Poetry + run: | + pipx install poetry + poetry --version + - name: Install dependencies run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + poetry install --with development + + - name: Build package + run: poetry build --ansi + - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + poetry run flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --ignore W503 + poetry run flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --ignore W503,D212 - name: Test with pytest run: | - pytest + poetry run pytest diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 8b75506..f894f2f 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -32,7 +32,7 @@ jobs: - name: Build package run: python -m build - name: Publish package - uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 + uses: pypa/gh-action-pypi-publish@15c56dba361d8335944d31a2ecd17d700fc7bcbc with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index b4da1bc..65a296f 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -3,8 +3,6 @@ name: Run drevalpy Tests on: push: branches: - - development - - main - "release/*" pull_request: branches: @@ -23,7 +21,6 @@ jobs: - { python-version: "3.10", os: ubuntu-latest, session: "mypy" } - { python-version: "3.10", os: ubuntu-latest, session: "tests" } - { python-version: "3.10", os: windows-latest, session: "tests" } - - { python-version: "3.10", os: macos-latest, session: "tests" } - { python-version: "3.10", os: ubuntu-latest, session: "typeguard" } - { python-version: "3.10", os: ubuntu-latest, session: "xdoctest" } - { python-version: "3.10", os: ubuntu-latest, session: "docs-build" } @@ -130,4 +127,6 @@ jobs: run: nox --force-color --session=coverage -- xml -i - name: Upload coverage report - uses: codecov/codecov-action@v4.6.0 + uses: codecov/codecov-action@v5.0.2 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ec20a08..231ac24 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,7 +35,7 @@ repos: types: [python] require_serial: true args: - - --ignore=W503 + - --ignore=D212,W503,C901 - id: pyupgrade name: pyupgrade description: Automatically upgrade syntax for newer versions. 
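The two workflow changes above move CI onto Poetry and add a release-triggered multi-arch image publish to the GitHub Container Registry. As a rough sketch (not part of the diff), the steps could be exercised locally as below; the image name is assumed from `${{ github.repository }}` (here `daisybio/drevalpy`) and the `:latest` tag assumes `docker/metadata-action` applied it to a semver release, so adjust both to the actual release.

```bash
# Mirror the python-package.yml steps locally (assumes pipx and Python 3.10)
pipx install poetry
poetry install --with development
poetry run flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
poetry run pytest

# Pull a published multi-arch image from GHCR after a release
# (image name and tag are assumptions, see note above)
docker pull ghcr.io/daisybio/drevalpy:latest
docker run --rm ghcr.io/daisybio/drevalpy:latest \
    python -c "import drevalpy; print(drevalpy.__version__)"
```
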
diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..8bf0fe7 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,53 @@ +# I followed this article's recommendations +# https://medium.com/@albertazzir/blazing-fast-python-docker-builds-with-poetry-a78a66f5aed0 + +# The builder image, used to build the virtual environment +FROM python:3.10-buster as builder + +RUN pip install poetry==1.8.4 + +# POETRY_VIRTUALENVS_CREATE=1: Makes sure that environment will be as isolated as possible and above all that +# installation will not mess up with the system Python or, even worse, with Poetry itself. +# POETRY_CACHE_DIR: When removing the cache folder, make sure this is done in the same RUN command. If it’s done in a +# separate RUN command, the cache will still be part of the previous Docker layer (the one containing poetry install ) +# effectively rendering your optimization useless. + +ENV POETRY_NO_INTERACTION=1 \ + POETRY_VIRTUALENVS_IN_PROJECT=1 \ + POETRY_VIRTUALENVS_CREATE=1 \ + POETRY_CACHE_DIR=/tmp/poetry_cache + +WORKDIR /root + +COPY pyproject.toml poetry.lock ./ + +# First, we install only the dependencies. This way, we can cache this layer and avoid re-installing dependencies +# every time we change our application code. +# Because poetry will complain if a README.md is not found, we create a dummy one. +RUN touch README.md + +RUN poetry install --without dev --no-root && rm -rf $POETRY_CACHE_DIR + +# The runtime image, used to run the code +FROM python:3.10-slim-buster as runtime + +LABEL image.author.name="Judith Bernett" +LABEL image.author.email="judith.bernett@tum.de" + +ENV VIRTUAL_ENV=/root/.venv \ + PATH="/root/.venv/bin:$PATH" + +COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV} + +# Copy all relevant code + +COPY drevalpy ./drevalpy +COPY create_report.py ./ +COPY README.md ./ +COPY run_suite.py ./ +COPY setup.py ./ +COPY pyproject.toml ./ +COPY poetry.lock ./ + +# Install drevalpy +RUN pip install . diff --git a/README.md b/README.md index 69cf3ad..fe56a23 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ pip install drevalpy From Source: ```bash -conda env create -f models/simple_neural_network/env.yml +conda env create -f models/SimpleNeuralNetwork/env.yml pip install . ``` diff --git a/README.rst b/README.rst index 7671a00..1812fca 100644 --- a/README.rst +++ b/README.rst @@ -15,10 +15,10 @@ DrEvalPy: Python Cancer Cell Line Drug Response Prediction Suite .. |Read the Docs| image:: https://img.shields.io/readthedocs/drevalpy/latest.svg?label=Read%20the%20Docs :target: https://drevalpy.readthedocs.io/ :alt: Read the documentation at https://drevalpy.readthedocs.io/ -.. |Build| image:: https://github.com/daisybio/drevalpy/workflows/Build%20drevalpy%20Package/badge.svg +.. |Build| image:: https://github.com/daisybio/drevalpy/actions/workflows/build_package.yml/badge.svg :target: https://github.com/daisybio/drevalpy/actions?workflow=Package :alt: Build Package Status -.. |Tests| image:: https://github.com/daisybio/drevalpy/workflows/Run%20drevalpy%20Tests/badge.svg +.. |Tests| image:: https://github.com/daisybio/drevalpy/actions/workflows/run_tests.yml/badge.svg :target: https://github.com/daisybio/drevalpy/actions?workflow=Tests :alt: Run Tests Status .. 
|Codecov| image:: https://codecov.io/gh/daisybio/drevalpy/branch/main/graph/badge.svg diff --git a/configs/CCLE.yaml b/configs/CCLE.yaml deleted file mode 100644 index e69de29..0000000 diff --git a/configs/GDSC.yaml b/configs/GDSC.yaml deleted file mode 100644 index f7536af..0000000 --- a/configs/GDSC.yaml +++ /dev/null @@ -1,37 +0,0 @@ ---- -moli: - epochs: [5, 10, 15, 20] - batch_size: [8, 16, 32] - h_dim1: 512 - h_dim2: 128 - h_dim3: 64 - lr_exp: 0.001 - lr_mut: 0.0001 - lr_cnv: 0.0005 - lr_class: 0.0005 - dropout_rate_exp: 0.5 - dropout_rate_mut: 0.5 - dropout_rate_cnv: 0.5 - dropout_rate_class: 0.5 - weight_decay: 0.0001 - gamma: 0.5 - margin: 1.5 - loss: triplet - -superFELT: - epochs: [5, 10, 15, 20] - batch_size: [8, 16, 32] - h_dim1: 512 - h_dim2: 128 - h_dim3: 64 - lr_exp: 0.001 - lr_mut: 0.0001 - lr_cnv: 0.0005 - lr_class: 0.0005 - dropout_rate_exp: 0.5 - dropout_rate_mut: 0.5 - dropout_rate_cnv: 0.5 - dropout_rate_class: 0.5 - weight_decay: 0.0001 - gamma: 0.5 - margin: 1.5 diff --git a/create_report.py b/create_report.py index b0c5f05..a38ce22 100644 --- a/create_report.py +++ b/create_report.py @@ -3,6 +3,8 @@ import argparse import os +import pandas as pd + from drevalpy.visualization import ( CorrelationComparisonScatter, CriticalDifferencePlot, @@ -14,8 +16,10 @@ from drevalpy.visualization.utils import create_html, create_index_html, parse_results, prep_results, write_results -def create_output_directories(custom_id): - """If they do not exist yet, make directories for the visualization files. +def create_output_directories(custom_id: str) -> None: + """ + If they do not exist yet, make directories for the visualization files. + :param custom_id: run id passed via command line """ os.makedirs(f"results/{custom_id}/violin_plots", exist_ok=True) @@ -26,13 +30,22 @@ def create_output_directories(custom_id): os.makedirs(f"results/{custom_id}/critical_difference_plots", exist_ok=True) -def draw_setting_plots(lpo_lco_ldo, ev_res, ev_res_per_drug, ev_res_per_cell_line, custom_id): - """Draw all plots for a specific setting (LPO, LCO, LDO). +def draw_setting_plots( + lpo_lco_ldo: str, + ev_res: pd.DataFrame, + ev_res_per_drug: pd.DataFrame, + ev_res_per_cell_line: pd.DataFrame, + custom_id: str, +) -> list[str]: + """ + Draw all plots for a specific setting (LPO, LCO, LDO). + :param lpo_lco_ldo: setting :param ev_res: overall evaluation results :param ev_res_per_drug: evaluation results per drug :param ev_res_per_cell_line: evaluation results per cell line :param custom_id: run id passed via command line + :returns: list of unique algorithms """ ev_res_subset = ev_res[ev_res["LPO_LCO_LDO"] == lpo_lco_ldo] # PIPELINE: SAVE_TABLES @@ -99,8 +112,12 @@ def draw_setting_plots(lpo_lco_ldo, ev_res, ev_res_per_drug, ev_res_per_cell_lin return eval_results_preds["algorithm"].unique() -def draw_per_grouping_setting_plots(grouping, ev_res_per_group, lpo_lco_ldo, custom_id): - """Draw plots for a specific grouping (drug or cell line) for a specific setting (LPO, LCO, LDO). +def draw_per_grouping_setting_plots( + grouping: str, ev_res_per_group: pd.DataFrame, lpo_lco_ldo: str, custom_id: str +) -> None: + """ + Draw plots for a specific grouping (drug or cell line) for a specific setting (LPO, LCO, LDO). 
+ :param grouping: drug or cell_line :param ev_res_per_group: evaluation results per drug or per cell line :param lpo_lco_ldo: setting @@ -132,15 +149,17 @@ def draw_per_grouping_setting_plots(grouping, ev_res_per_group, lpo_lco_ldo, cus def draw_algorithm_plots( - model, - ev_res, - ev_res_per_drug, - ev_res_per_cell_line, - t_vs_p, - lpo_lco_ldo, - custom_id, -): - """Draw all plots for a specific algorithm. + model: str, + ev_res: pd.DataFrame, + ev_res_per_drug: pd.DataFrame, + ev_res_per_cell_line: pd.DataFrame, + t_vs_p: pd.DataFrame, + lpo_lco_ldo: str, + custom_id: str, +) -> None: + """ + Draw all plots for a specific algorithm. + :param model: name of the model/algorithm :param ev_res: overall evaluation results :param ev_res_per_drug: evaluation results per drug @@ -194,15 +213,17 @@ def draw_algorithm_plots( def draw_per_grouping_algorithm_plots( - grouping_slider, - grouping_scatter_table, - model, - ev_res_per_group, - t_v_p, - lpo_lco_ldo, - custom_id, + grouping_slider: str, + grouping_scatter_table: str, + model: str, + ev_res_per_group: pd.DataFrame, + t_v_p: pd.DataFrame, + lpo_lco_ldo: str, + custom_id: str, ): - """Draw plots for a specific grouping (drug or cell line) for a specific algorithm. + """ + Draw plots for a specific grouping (drug or cell line) for a specific algorithm. + :param grouping_slider: the grouping variable for the regression plots :param grouping_scatter_table: the grouping variable for the scatter plots. If grouping_slider is drug, this should be cell_line and vice versa @@ -320,8 +341,8 @@ def draw_per_grouping_algorithm_plots( custom_id=run_id, ) # get all html files from results/{run_id} - all_files = [] - for _, _, files in os.walk(f"results/{run_id}"): + all_files: list[str] = [] + for _, _, files in os.walk(f"results/{run_id}"): # type: ignore[assignment] for file in files: if file.endswith(".html") and file not in ["index.html", "LPO.html", "LCO.html", "LDO.html"]: all_files.append(file) diff --git a/docs/conf.py b/docs/conf.py index 7c31e3b..491b879 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +"""Configuration file for the Sphinx documentation builder.""" # mypy: ignore-errors # drevalpy documentation build configuration file # @@ -15,6 +16,8 @@ from jinja2.defaults import DEFAULT_FILTERS +import drevalpy + sys.path.insert(0, os.path.abspath("../")) @@ -55,9 +58,9 @@ # the built documents. # # The short X.Y version. -version = "0.0.1" +version = drevalpy.__version__ # The full version, including alpha/beta/rc tags. -release = "0.0.1" +release = drevalpy.__version__ # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -205,7 +208,12 @@ def get_obj_module(qualname): - """Get a module/class/attribute and its original module by qualname.""" + """ + Get a module/class/attribute and its original module by qualname. + + :param qualname: The qualified name of the object. + :returns: The object and its original module. + """ modname = qualname classname = None attrname = None @@ -225,7 +233,12 @@ def get_obj_module(qualname): def get_linenos(obj): - """Get an object’s line numbers.""" + """ + Get an object’s line numbers. + + :param obj: The object. + :returns: The start and end line numbers. 
+ """ try: lines, start = inspect.getsourcelines(obj) except TypeError: # obj is an attribute or None @@ -239,7 +252,12 @@ def get_linenos(obj): def modurl(qualname): - """Get the full GitHub URL for some object’s qualname.""" + """ + Get the full GitHub URL for some object’s qualname. + + :param qualname: The qualified name of the object. + :returns: The full GitHub URL. + """ obj, module = get_obj_module(qualname) path = Path(module.__file__).relative_to(project_dir) start, end = get_linenos(obj) diff --git a/docs/requirements.txt b/docs/requirements.txt index 8a033ff..521aa7c 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,5 @@ -sphinx-autobuild==2024.10.3 ; python_version >= "3.9" and python_version < "3.12" -sphinx-autodoc-typehints<3.0 ; python_version >= "3.9" and python_version < "3.12" -sphinx-click==6.0.0 ; python_version >= "3.9" and python_version < "3.12" -sphinx-rtd-theme==3.0.1 ; python_version >= "3.9" and python_version < "3.12" +sphinx-autobuild==2024.10.3 ; python_version >= "3.9" and python_full_version <= "3.13.0" +sphinx-autodoc-typehints==2.3.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" +sphinx-click==6.0.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" +sphinx-rtd-theme==3.0.2 ; python_version >= "3.9" and python_full_version <= "3.13.0" +-e . diff --git a/drevalpy/__init__.py b/drevalpy/__init__.py index e69de29..233d8a5 100644 --- a/drevalpy/__init__.py +++ b/drevalpy/__init__.py @@ -0,0 +1,5 @@ +"""Module containing the drevalpy suite.""" + +from importlib.metadata import version + +__version__ = version("drevalpy") diff --git a/drevalpy/datasets/dataset.py b/drevalpy/datasets/dataset.py index 9049b57..394c438 100644 --- a/drevalpy/datasets/dataset.py +++ b/drevalpy/datasets/dataset.py @@ -16,27 +16,33 @@ import copy import os from abc import ABC, abstractmethod -from typing import Any, Callable, Optional, Union +from pathlib import Path +from typing import Any, Callable import networkx as nx import numpy as np import pandas as pd -from numpy.typing import ArrayLike from sklearn.base import TransformerMixin from sklearn.model_selection import GroupKFold, train_test_split +from ..pipeline_function import pipeline_function from .utils import permute_features, randomize_graph +np.set_printoptions(threshold=6) + class Dataset(ABC): """Abstract wrapper class for datasets.""" + @classmethod @abstractmethod - def load(self, path: str): + def from_csv(cls: type["Dataset"], input_file: str | Path, dataset_name: str = "unknown") -> "Dataset": """ Loads the dataset from data. - :param path: path to the dataset + :param input_file: Path to the csv file containing the data to be loaded + :param dataset_name: Optional name to associate the dataset with, default = "unknown" + :returns: Dataset object containing data from provided csv file. """ @abstractmethod @@ -51,14 +57,108 @@ def save(self, path: str): class DrugResponseDataset(Dataset): """Drug response dataset.""" + _response: np.ndarray + _cell_line_ids: np.ndarray + _drug_ids: np.ndarray + _predictions: np.ndarray | None = None + _cv_splits: list[dict[str, "DrugResponseDataset"]] = [] + _name: str + + @classmethod + def from_csv( + cls: type["DrugResponseDataset"], input_file: str | Path, dataset_name: str = "unknown" + ) -> "DrugResponseDataset": + """ + Load a dataset from a csv file. + + This function creates a DrugResponseDataset from a provided input file in csv format. 
+ The following columns are required: + - response: the drug response values as floating point values + - cell_line_ids: a string identifier for cell lines + - drug_ids: a string identifier for drugs + - predictions: an optional column containing a predicted value TODO what exactly? + + :param input_file: Path to the csv file containing the data to be loaded + :param dataset_name: Optional name to associate the dataset with, default = "unknown" + + :returns: DrugResponseDataset object containing data from provided csv file. + """ + data = pd.read_csv(input_file) + if "predictions" in data.columns: + predictions = data["predictions"].values + else: + predictions = None + return cls( + response=data["response"].values, + cell_line_ids=data["cell_line_id"].values, + drug_ids=data["drug_id"].values, + predictions=predictions, + dataset_name=dataset_name, + ) + + @property + def response(self) -> np.ndarray: + """ + Returns the response values. + + :returns: numpy array containing response values. + """ + return self._response + + @property + def cell_line_ids(self) -> np.ndarray: + """ + Returns the cell_line_ids. + + :returns: numpy array containing cell_line_ids values. + """ + return self._cell_line_ids + + @property + def drug_ids(self) -> np.ndarray: + """ + Returns the drug_ids. + + :returns: numpy array containing drug_ids values. + """ + return self._drug_ids + + @property + def predictions(self) -> np.ndarray | None: + """ + Returns the predictions if they exist. + + :returns: numpy array containing prediction values or None. + """ + return self._predictions + + @property + def cv_splits(self) -> list[dict[str, "DrugResponseDataset"]]: + """ + Returns the cv_splits. + + :returns: DrugResponseDatasets containing the CV_splits. + """ + return self._cv_splits + + @property + def dataset_name(self) -> str: + """ + Returns the name of this DrugResponseDataset. + + :returns: dataset name. + """ + return self._name + + @pipeline_function def __init__( self, - response: Optional[np.ndarray] = None, - cell_line_ids: Optional[np.ndarray] = None, - drug_ids: Optional[np.ndarray] = None, - predictions: Optional[np.ndarray] = None, - dataset_name: Optional[str] = None, - ): + response: np.ndarray, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + predictions: np.ndarray | None = None, + dataset_name: str = "unnamed", + ) -> None: """ Initializes the drug response dataset. @@ -66,59 +166,53 @@ def __init__( :param cell_line_ids: cell line IDs :param drug_ids: drug IDs :param predictions: optional. Predicted drug response values per cell line and drug - :param dataset_name: optional. Name of the dataset + :param dataset_name: optional. Name of the dataset, default: "unnamed" + :raises AssertionError: If response, cell_line_ids, drug_ids, (and the optional predictions) do not all have + the same length. 
""" super().__init__() - if response is not None: - self.response = np.array(response) - self.cell_line_ids = np.array(cell_line_ids) - self.drug_ids = np.array(drug_ids) - if len(self.response) != len(self.cell_line_ids): - raise AssertionError("response and cell_line_ids have different lengths") - if len(self.response) != len(self.drug_ids): - raise AssertionError("response and drug_ids/cell_line_ids have different lengths") - self.dataset_name = dataset_name - else: - self.response = response - self.cell_line_ids = cell_line_ids - self.drug_ids = drug_ids - self.dataset_name = dataset_name - - if predictions is not None: - self.predictions = np.array(predictions) - if len(self.predictions) != len(self.response): - raise AssertionError("predictions and response have different lengths") - else: - self.predictions = None - self.cv_splits = None + if len(response) != len(cell_line_ids): + raise AssertionError("Response and cell_line_ids have different lengths.") + if len(response) != len(drug_ids): + raise AssertionError("Response and drug_ids have different lengths.") + if predictions is not None and len(response) != len(predictions): + raise AssertionError("Response and predictions have different lengths.") + self._response = response + self._cell_line_ids = cell_line_ids + self._drug_ids = drug_ids + self._predictions = predictions + self._name = dataset_name - def __len__(self): - """Overwrites the default length method.""" + def __len__(self) -> int: + """ + Overwrites the default length method. + + :returns: Number of samples in the dataset + """ return len(self.response) - def __str__(self): - """Overwrite the default str method.""" - if len(self.response) > 3: - string = ( - f"DrugResponseDataset: CLs {self.cell_line_ids[:3]}...; " - f"Drugs {self.drug_ids[:3]}...; " - f"Response {self.response[:3]}..." - ) - else: - string = ( - f"DrugResponseDataset: CLs {self.cell_line_ids}; " - f"Drugs {self.drug_ids}; " - f"Response {self.response}" - ) + def __str__(self) -> str: + """ + Overwrite the default str method. + + :return: Text summary of the dataset + """ + string = ( + f"{self.dataset_name} DrugResponseDataset with {len(self)} entries:\n" + f"CLs {self.cell_line_ids}\n" + f"Drugs {self.drug_ids}\n" + f"Response {self.response}\n" + ) if self.predictions is not None: - if len(self.predictions) > 3: - string += f"; Predictions {self.predictions[:3]}..." - else: - string += f"; Predictions {self.predictions}" + string += f"Predictions {self.predictions}\n" return string - def to_dataframe(self): - """Convert the dataset into a pandas DataFrame.""" + def to_dataframe(self) -> pd.DataFrame: + """ + Convert the dataset into a pandas DataFrame. + + :returns: pandas DataFrame of the dataset with columns 'cell_line_id', 'drug_id', 'response'(, 'predictions') + """ data = { "cell_line_id": self.cell_line_ids, "drug_id": self.drug_ids, @@ -128,74 +222,55 @@ def to_dataframe(self): data["predictions"] = self.predictions return pd.DataFrame(data) - def load(self, path: str): + def save(self, path: str | Path): """ - Loads the drug response dataset from data. + Stores the drug response dataset on disk. 
- :param path: path to the dataset - """ - data = pd.read_csv(path) - self.response = data["response"].values - self.cell_line_ids = data["cell_line_ids"].values - self.drug_ids = data["drug_ids"].values - if "predictions" in data.columns: - self.predictions = data["predictions"].values - - def save(self, path: str): + :param path: path to desired storage location """ - Saves the drug response dataset to data. - - :param path: path to the dataset - """ - out = pd.DataFrame( - { - "cell_line_ids": self.cell_line_ids, - "drug_ids": self.drug_ids, - "response": self.response, - } - ) - if self.predictions is not None: - out["predictions"] = self.predictions - out.to_csv(path, index=False) + self.to_dataframe().to_csv(path, index=False) + @pipeline_function def add_rows(self, other: "DrugResponseDataset") -> None: """ Adds rows from another dataset. :param other: other dataset """ - self.response = np.concatenate([self.response, other.response]) - self.cell_line_ids = np.concatenate([self.cell_line_ids, other.cell_line_ids]) - self.drug_ids = np.concatenate([self.drug_ids, other.drug_ids]) + self._response = np.concatenate([self._response, other.response]) + self._cell_line_ids = np.concatenate([self._cell_line_ids, other.cell_line_ids]) + self._drug_ids = np.concatenate([self._drug_ids, other.drug_ids]) if self.predictions is not None and other.predictions is not None: - self.predictions = np.concatenate([self.predictions, other.predictions]) + self._predictions = np.concatenate([self._predictions, other.predictions]) + @pipeline_function def remove_nan_responses(self) -> None: """Removes rows with NaN values in the response.""" mask = np.isnan(self.response) - self.response = self.response[~mask] - self.cell_line_ids = self.cell_line_ids[~mask] - self.drug_ids = self.drug_ids[~mask] + self._response = self.response[~mask] + self._cell_line_ids = self.cell_line_ids[~mask] + self._drug_ids = self.drug_ids[~mask] if self.predictions is not None: - self.predictions = self.predictions[~mask] + self._predictions = self.predictions[~mask] + @pipeline_function def shuffle(self, random_state: int = 42) -> None: """ Shuffles the dataset. :param random_state: random state """ - indices = np.arange(len(self.response)) + indices = np.arange(len(self)) np.random.seed(random_state) np.random.shuffle(indices) - self.response = self.response[indices] - self.cell_line_ids = self.cell_line_ids[indices] - self.drug_ids = self.drug_ids[indices] + self._response = self.response[indices] + self._cell_line_ids = self.cell_line_ids[indices] + self._drug_ids = self.drug_ids[indices] if self.predictions is not None: - self.predictions = self.predictions[indices] + self._predictions = self.predictions[indices] - def remove_drugs(self, drugs_to_remove: Union[str, list]) -> None: + def _remove_drugs(self, drugs_to_remove: str | list[str | int]) -> None: """ Removes drugs from the dataset. @@ -205,11 +280,11 @@ def remove_drugs(self, drugs_to_remove: Union[str, list]) -> None: drugs_to_remove = [drugs_to_remove] mask = [drug not in drugs_to_remove for drug in self.drug_ids] - self.drug_ids = self.drug_ids[mask] - self.cell_line_ids = self.cell_line_ids[mask] - self.response = self.response[mask] + self._drug_ids = self.drug_ids[mask] + self._cell_line_ids = self.cell_line_ids[mask] + self._response = self.response[mask] - def remove_cell_lines(self, cell_lines_to_remove: Union[str, list]) -> None: + def _remove_cell_lines(self, cell_lines_to_remove: str | list[str | int]) -> None: """ Removes cell lines from the dataset. 
@@ -219,23 +294,24 @@ def remove_cell_lines(self, cell_lines_to_remove: Union[str, list]) -> None: cell_lines_to_remove = [cell_lines_to_remove] mask = [cell_line not in cell_lines_to_remove for cell_line in self.cell_line_ids] - self.drug_ids = self.drug_ids[mask] - self.cell_line_ids = self.cell_line_ids[mask] - self.response = self.response[mask] + self._drug_ids = self.drug_ids[mask] + self._cell_line_ids = self.cell_line_ids[mask] + self._response = self.response[mask] - def remove_rows(self, indices: ArrayLike) -> None: + def remove_rows(self, indices: np.ndarray) -> None: """ Removes rows from the dataset. :param indices: indices of rows to remove """ - self.drug_ids = np.delete(self.drug_ids, indices) - self.cell_line_ids = np.delete(self.cell_line_ids, indices) - self.response = np.delete(self.response, indices) + indices = np.array(indices, dtype=int) + self._drug_ids = np.delete(self.drug_ids, indices) + self._cell_line_ids = np.delete(self.cell_line_ids, indices) + self._response = np.delete(self.response, indices) if self.predictions is not None: - self.predictions = np.delete(self.predictions, indices) + self._predictions = np.delete(self.predictions, indices) - def reduce_to(self, cell_line_ids: Optional[ArrayLike], drug_ids: Optional[ArrayLike]) -> None: + def reduce_to(self, cell_line_ids: np.ndarray | None = None, drug_ids: np.ndarray | None = None) -> None: """ Removes all rows which contain a cell_line not in cell_line_ids or a drug not in drug_ids. @@ -243,11 +319,12 @@ def reduce_to(self, cell_line_ids: Optional[ArrayLike], drug_ids: Optional[Array :param drug_ids: drug IDs or None to keep all cell lines """ if drug_ids is not None: - self.remove_drugs(list(set(self.drug_ids) - set(drug_ids))) + self._remove_drugs(list(set(self.drug_ids) - set(drug_ids))) if cell_line_ids is not None: - self.remove_cell_lines(list(set(self.cell_line_ids) - set(cell_line_ids))) + self._remove_cell_lines(list(set(self.cell_line_ids) - set(cell_line_ids))) + @pipeline_function def split_dataset( self, n_cv_splits: int, @@ -266,8 +343,9 @@ def split_dataset( :param split_early_stopping: if True, an early stopping set is generated :param validation_ratio: ratio of validation set size to training set size :param random_state: random state - :return: list of dictionaries containing the cross-validation datasets. + :returns: list of dictionaries containing the cross-validation datasets. Each fold is a dictionary with keys 'train', 'validation', 'test', 'validation_es', 'early_stopping'. + :raises ValueError: if mode is not 'LPO', 'LCO', or 'LDO' """ cell_line_ids = self.cell_line_ids drug_ids = self.drug_ids @@ -306,16 +384,17 @@ def split_dataset( validation_es, early_stopping = _split_early_stopping_data(split["validation"], test_mode=mode) split["validation_es"] = validation_es split["early_stopping"] = early_stopping - self.cv_splits = cv_splits + self._cv_splits = cv_splits return cv_splits - def save_splits(self, path: str) -> None: + def save_splits(self, path: str): """ Save cross validation splits to path/cv_split_0_train.csv and path/cv_split_0_test.csv. 
:param path: path to the directory where the cv split files are saved + :raises AssertionError: if DrugResponseDataset was not split """ - if self.cv_splits is None: + if not self.cv_splits: raise AssertionError("Trying to save splits, but DrugResponseDataset was not split.") os.makedirs(path, exist_ok=True) for i, split in enumerate(self.cv_splits): @@ -336,6 +415,7 @@ def load_splits(self, path: str) -> None: Load cross validation splits from path/cv_split_0_train.csv and path/cv_split_0_test.csv. :param path: path to the directory containing the cv split files + :raises AssertionError: if no cv split files are found in path """ files = os.listdir(path) files = [file for file in files if (file.endswith(".csv") and file.startswith("cv_split"))] @@ -363,25 +443,24 @@ def load_splits(self, path: str) -> None: "validation_es": validation_es_splits, "early_stopping": early_stopping_splits, } - self.cv_splits = [] + self._cv_splits.clear() # TODO do we need this? for split_train, split_test in zip(train_splits, test_splits, strict=True): - tr_split = DrugResponseDataset(dataset_name=self.dataset_name) - tr_split.load(os.path.join(path, split_train)) - - te_split = DrugResponseDataset(dataset_name=self.dataset_name) - te_split.load(os.path.join(path, split_test)) - self.cv_splits.append({"train": tr_split, "test": te_split}) + tr_split = DrugResponseDataset.from_csv(os.path.join(path, split_train), dataset_name=self.dataset_name) + te_split = DrugResponseDataset.from_csv(os.path.join(path, split_test), dataset_name=self.dataset_name) + self._cv_splits.append({"train": tr_split, "test": te_split}) for mode in ["validation", "validation_es", "early_stopping"]: if len(optional_splits[mode]) > 0: for i, v_split in enumerate(optional_splits[mode]): - split = DrugResponseDataset(dataset_name=self.dataset_name) - split.load(os.path.join(path, v_split)) - self.cv_splits[i][mode] = split + split = DrugResponseDataset.from_csv(os.path.join(path, v_split), dataset_name=self.dataset_name) + self._cv_splits[i][mode] = split def copy(self): - """Returns a copy of the drug response dataset.""" + """Returns a copy of the drug response dataset. + + :returns: copy of the dataset + """ return DrugResponseDataset( response=copy.deepcopy(self.response), cell_line_ids=copy.deepcopy(self.cell_line_ids), @@ -390,8 +469,11 @@ def copy(self): dataset_name=self.dataset_name, ) - def __hash__(self): - """Overwrites default hash method.""" + def __hash__(self) -> int: + """Overwrites default hash method. + + :returns: hash value of the dataset + """ return hash( ( self.dataset_name, @@ -402,17 +484,17 @@ def __hash__(self): ) ) - def mask(self, mask: list[bool]) -> None: + def mask(self, mask: np.ndarray) -> None: """ Removes rows from the dataset based on a boolean mask. 
:param mask: boolean mask """ - self.response = self.response[mask] - self.cell_line_ids = self.cell_line_ids[mask] - self.drug_ids = self.drug_ids[mask] + self._response = self.response[mask] + self._cell_line_ids = self.cell_line_ids[mask] + self._drug_ids = self.drug_ids[mask] if self.predictions is not None: - self.predictions = self.predictions[mask] + self._predictions = self.predictions[mask] def transform(self, response_transformation: TransformerMixin) -> None: """ @@ -420,9 +502,9 @@ def transform(self, response_transformation: TransformerMixin) -> None: :param response_transformation: e.g., StandardScaler, MinMaxScaler, RobustScaler """ - self.response = response_transformation.transform(self.response.reshape(-1, 1)).squeeze() + self._response = response_transformation.transform(self.response.reshape(-1, 1)).squeeze() if self.predictions is not None: - self.predictions = response_transformation.transform(self.predictions.reshape(-1, 1)).squeeze() + self._predictions = response_transformation.transform(self.predictions.reshape(-1, 1)).squeeze() def fit_transform(self, response_transformation: TransformerMixin) -> None: """ @@ -439,9 +521,9 @@ def inverse_transform(self, response_transformation: TransformerMixin) -> None: :param response_transformation: e.g., StandardScaler, MinMaxScaler, RobustScaler """ - self.response = response_transformation.inverse_transform(self.response.reshape(-1, 1)).squeeze() + self._response = response_transformation.inverse_transform(self.response.reshape(-1, 1)).squeeze() if self.predictions is not None: - self.predictions = response_transformation.inverse_transform(self.predictions.reshape(-1, 1)).squeeze() + self._predictions = response_transformation.inverse_transform(self.predictions.reshape(-1, 1)).squeeze() def _split_early_stopping_data( @@ -452,7 +534,7 @@ def _split_early_stopping_data( :param validation_dataset: input validation dataset :param test_mode: LCO, LDO, LPO - :return: the resulting validation and early stopping datasets + :returns: the resulting validation and early stopping datasets """ validation_dataset.shuffle(random_state=42) cv_v = validation_dataset.split_dataset( @@ -470,14 +552,14 @@ def _split_early_stopping_data( def _leave_pair_out_cv( n_cv_splits: int, - response: ArrayLike, - cell_line_ids: ArrayLike, - drug_ids: ArrayLike, - split_validation=True, - validation_ratio=0.1, - random_state=42, - dataset_name: Optional[str] = None, -) -> list[dict]: + response: np.ndarray, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + split_validation: bool = True, + validation_ratio: float = 0.1, + random_state: int = 42, + dataset_name: str = "unknown", +) -> list[dict[str, DrugResponseDataset]]: """ Leave pair out cross validation. Splits data into n_cv_splits number of cross validation splits. 
@@ -489,7 +571,8 @@ def _leave_pair_out_cv( :param validation_ratio: ratio of validation set (of the training set) :param random_state: random state :param dataset_name: name of the dataset - :return: list of dicts of the cross validation sets + :returns: list of dicts of the cross validation sets + :raises AssertionError: if response, cell_line_ids and drug_ids have different lengths """ if not (len(response) == len(cell_line_ids) == len(drug_ids)): raise AssertionError("response, cell_line_ids and drug_ids must have the same length") @@ -546,21 +629,28 @@ def _leave_pair_out_cv( def _leave_group_out_cv( group: str, n_cv_splits: int, - response: ArrayLike, - cell_line_ids: ArrayLike, - drug_ids: ArrayLike, - split_validation=True, - validation_ratio=0.1, - random_state=42, - dataset_name: Optional[str] = None, + response: np.ndarray, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + split_validation: bool = True, + validation_ratio: float = 0.1, + random_state: int = 42, + dataset_name: str = "unknown", ): """ Leave group out cross validation: Splits data into n_cv_splits number of cross validation splits. :param group: group to leave out (cell_line or drug) :param n_cv_splits: number of cross validation splits + :param response: response (e.g. ic50 values) + :param cell_line_ids: cell line IDs + :param drug_ids: drug IDs + :param split_validation: whether to split the training set into training and validation set + :param validation_ratio: ratio of validation set (of the training set) :param random_state: random state - :return: list of dicts of the cross validation sets + :param dataset_name: name of the dataset + :returns: list of dicts of the cross validation sets + :raises AssertionError: if group is not 'cell_line' or 'drug' """ if group not in {"cell_line", "drug"}: raise AssertionError(f"group must be 'cell_line' or 'drug', but is {group}") @@ -629,10 +719,63 @@ def _leave_group_out_cv( class FeatureDataset(Dataset): """Class for feature datasets.""" + _meta_info: dict[str, Any] = {} + _features: dict[str, dict[str, Any]] = {} + + @classmethod + def from_csv( + cls: type["FeatureDataset"], input_file: str | Path, dataset_name: str = "unknown" + ) -> "FeatureDataset": + """ + Load a feature dataset from a csv file. + + This function creates a FeatureDataset from a provided input file in csv format. + :param input_file: Path to the csv file containing the data to be loaded + :param dataset_name: Optional name to associate the dataset with, default = "unknown" + :raises NotImplementedError: This method is currently not implemented. + """ + raise NotImplementedError + + @property + def meta_info(self) -> dict[str, Any]: + """ + Returns the meta information. + + :returns: Meta information of this FeatureDataset + """ + return self._meta_info + + @property + def features(self) -> dict[str, dict[str, Any]]: + """ + Returns the features. + + :returns: features of this FeatureDataset + """ + return self._features + + @property + def identifiers(self) -> np.ndarray: + """ + Returns the identifiers of the features. + + :returns: feature identifiers of this FeatureDataset + """ + return np.array(list(self.features.keys())) + + @property + def view_names(self) -> list[str]: + """ + Returns the view_names. + + :returns: view_names of this FeatureDataset + """ + return list(self.features[list(self.features.keys())[0]].keys()) # TODO whut?! 
+ def __init__( self, features: dict[str, dict[str, Any]], - meta_info: Optional[dict[str, Any]] = None, + meta_info: dict[str, Any] | None = None, ): """ Initializes the feature dataset. @@ -641,46 +784,39 @@ def __init__( key: drug ID/cell line ID, value: Dict of feature views, key: feature name, value: feature vector :param meta_info: additional information for the views, e.g. gene names for gene expression + :raises AssertionError: if meta_info keys are not in view names """ super().__init__() - self.features = features - self.view_names = self.get_view_names() + self._features = features if meta_info is not None: # assert that str of meta Dict[str, Any] is in view_names if not all(meta_key in self.view_names for meta_key in meta_info.keys()): raise AssertionError(f"Meta keys {meta_info.keys()} not in view names {self.view_names}") - self.meta_info = meta_info - else: - self.meta_info = None - self.identifiers = self.get_ids() + self._meta_info = meta_info def save(self, path: str): """ Saves the feature dataset to data. :param path: path to the dataset + :raises NotImplementedError: if method is not implemented """ raise NotImplementedError("save method not implemented") - def load(self, path: str): - """ - Loads the feature dataset from data. - - :param path: path to the dataset - """ - raise NotImplementedError("load method not implemented") - - def randomize_features(self, views_to_randomize: Union[str, list], randomization_type: str) -> None: + def randomize_features(self, views_to_randomize: str | list[str], randomization_type: str) -> None: """ Randomizes the feature vectors. + Permutation permutes the feature vectors. + Invariant means that the randomization is done in a way that a key characteristic of the feature is + preserved. In case of matrices, this is the mean and standard deviation of the feature view for this + instance, for networks it is the degree distribution. + :param views_to_randomize: name of feature view or list of names of multiple feature views to randomize. The other views are not randomized. :param randomization_type: randomization type ('permutation', 'invariant'). - :return: Permutation permutes the feature vectors. - Invariant means that the randomization is done in a way that a key characteristic of the - feature is preserved. In case of matrices, this is the mean and standard deviation of the - feature view for this instance, for networks it is the degree distribution. + :raises AssertionError: if randomization_type is not 'permutation' or 'invariant' + :raises ValueError: if no invariant randomization is available for the feature view type """ if randomization_type not in ["permutation", "invariant"]: raise AssertionError( @@ -695,7 +831,7 @@ def randomize_features(self, views_to_randomize: Union[str, list], randomization # E.g. each cell line gets the feature vector/graph/image... # of another cell line. # Drawn without replacement. 
- self.features = permute_features( + self._features = permute_features( features=self.features, views_to_permute=views_to_randomize, identifiers=self.identifiers, @@ -726,23 +862,19 @@ def randomize_features(self, views_to_randomize: Union[str, list], randomization ) self.features[identifier][view] = new_features - def get_ids(self): - """Returns drug ids of the dataset.""" - return np.array(list(self.features.keys())) - - def get_view_names(self): - """Returns feature view names.""" - return list(self.features[list(self.features.keys())[0]].keys()) - - def get_feature_matrix(self, view: str, identifiers: ArrayLike, stack: bool = True) -> Union[np.ndarray, list]: + def get_feature_matrix(self, view: str, identifiers: np.ndarray) -> np.ndarray: """ Returns the feature matrix for the given view. The feature view must be a vector or matrix. :param view: view name :param identifiers: list of identifiers (cell lines oder drugs) - :param stack: if True, stacks the feature vectors to a matrix. If False, returns a list of features. - :return: feature matrix + :returns: feature matrix + :raises AssertionError: if no identifiers are given + :raises AssertionError: if view is not in the FeatureDataset + :raises AssertionError: if identifiers are not in the FeatureDataset + :raises AssertionError: if feature vectors of view have different lengths + :raises AssertionError: if view is not a numpy array, i.e. not a vector or matrix """ if len(identifiers) == 0: raise AssertionError("get_feature_matrix: No identifiers given.") @@ -761,36 +893,39 @@ def get_feature_matrix(self, view: str, identifiers: ArrayLike, stack: bool = Tr if not all(isinstance(self.features[id_][view], np.ndarray) for id_ in identifiers): raise AssertionError(f"get_feature_matrix only works for vectors or matrices. {view} is not a numpy array.") - out = [self.features[id_][view] for id_ in identifiers] - return np.stack(out, axis=0) if stack else out + out = np.array([self.features[id_][view] for id_ in identifiers]) + return out def copy(self): - """Returns a copy of the feature dataset.""" + """Returns a copy of the feature dataset. + + :returns: copy of the dataset + """ return FeatureDataset(features=copy.deepcopy(self.features)) - def _add_features(self, other: "FeatureDataset") -> None: + def add_features(self, other: "FeatureDataset") -> None: """ Adds features views from another dataset. Inner join (only common identifiers are kept). :param other: other dataset + :raises AssertionError: if feature views overlap """ if len(set(self.view_names) & set(other.view_names)) != 0: raise AssertionError( "Trying to add features but feature views overlap. FeatureDatasets should be distinct." 
) - if other.meta_info is not None: + if other.meta_info: self.add_meta_info(other) common_identifiers = set(self.identifiers).intersection(other.identifiers) new_features = {} for id_ in common_identifiers: + id_ = str(id_) new_features[id_] = {view: self.features[id_][view] for view in self.view_names} for view in other.view_names: new_features[id_][view] = other.features[id_][view] - self.features = new_features - self.view_names = self.get_view_names() - self.identifiers = self.get_ids() + self._features = new_features def add_meta_info(self, other: "FeatureDataset") -> None: """ @@ -799,15 +934,22 @@ def add_meta_info(self, other: "FeatureDataset") -> None: :param other: other dataset """ other_meta = other.meta_info - self.meta_info.update(other_meta) + if self.meta_info is None: + self.meta_info = other_meta + else: + if other_meta is not None: + self.meta_info.update(other_meta) - def transform_features(self, ids: ArrayLike, transformer: TransformerMixin, view: str): + def transform_features(self, ids: np.ndarray, transformer: TransformerMixin, view: str): """ Applies a transformation like standard scaling to features. :param ids: The IDs to transform :param transformer: fitted sklearn transformer :param view: the view to transform + :raises AssertionError: if view is not in the FeatureDataset + :raises AssertionError: if a cell line is missing + :raises AssertionError: if IDs are not unique """ if view not in self.view_names: raise AssertionError(f"Transform view {view!r} not in in the FeatureDataset.") @@ -822,14 +964,16 @@ def transform_features(self, ids: ArrayLike, transformer: TransformerMixin, view scaled_feature_vector = transformer.transform([feature_vector])[0] self.features[identifier][view] = scaled_feature_vector - def fit_transform_features(self, train_ids: ArrayLike, transformer: TransformerMixin, view: str): + def fit_transform_features(self, train_ids: np.ndarray, transformer: TransformerMixin, view: str): """ Fits and applies a transformation. Fitting is done only on the train_ids. :param train_ids: The IDs corresponding to the training dataset. :param transformer: sklearn transformer :param view: the view to transform - :return: The modified FeatureDataset with transformed gene expression features. + :returns: The modified FeatureDataset with transformed gene expression features. 
+ :raises AssertionError: if view is not in the FeatureDataset + :raises AssertionError: if train IDs are not unique """ if view not in self.view_names: raise AssertionError(f"Transform view {view!r} not in in the FeatureDataset.") @@ -837,15 +981,8 @@ def fit_transform_features(self, train_ids: ArrayLike, transformer: TransformerM if len(np.unique(train_ids)) != len(train_ids): raise AssertionError("Train IDs should be unique.") - train_features = [] - # Collect all features of the view for fitting the scaler - for identifier in train_ids: - feature_vector = self.features[identifier][view] - train_features.append(feature_vector) - - # Fit the scaler on the collected feature data - train_features = np.vstack(train_features) + train_features = np.vstack([self.features[identifier][view] for identifier in train_ids]) transformer.fit(train_features) # Apply transformation and scaling to each feature vector @@ -855,7 +992,11 @@ def fit_transform_features(self, train_ids: ArrayLike, transformer: TransformerM self.features[identifier][view] = scaled_gene_expression return transformer - def _apply(self, function: Callable, view: str): - """Applies a function to the features of a view.""" + def apply(self, function: Callable, view: str): + """Applies a function to the features of a view. + + :param function: function to apply + :param view: view to apply the function to + """ for identifier in self.features: self.features[identifier][view] = function(self.features[identifier][view]) diff --git a/drevalpy/datasets/loader.py b/drevalpy/datasets/loader.py index 239c984..396d32a 100644 --- a/drevalpy/datasets/loader.py +++ b/drevalpy/datasets/loader.py @@ -1,7 +1,11 @@ +"""Contains functions to load the GDSC1, GDSC2, CCLE, and Toy datasets.""" + import os +from typing import Callable import pandas as pd +from ..pipeline_function import pipeline_function from .dataset import DrugResponseDataset from .utils import download_dataset @@ -15,7 +19,7 @@ def load_gdsc1( :param path_data: Path to the dataset. :param file_name: File name of the dataset. :param dataset_name: Name of the dataset. - :return: Dictionary containing response, cell line IDs, and drug IDs. + :return: DrugResponseDataset containing response, cell line IDs, and drug IDs. """ path = os.path.join(path_data, dataset_name, file_name) if not os.path.exists(path): @@ -38,7 +42,7 @@ def load_gdsc2(path_data: str = "data", file_name: str = "response_GDSC2.csv"): :param path_data: Path to the dataset. :param file_name: File name of the dataset. - :return: Dictionary containing response, cell line IDs, and drug IDs. + :return: DrugResponseDataset containing response, cell line IDs, and drug IDs. """ return load_gdsc1(path_data=path_data, file_name=file_name, dataset_name="GDSC2") @@ -49,7 +53,7 @@ def load_ccle(path_data: str = "data", file_name: str = "response_CCLE.csv") -> :param path_data: Path to the dataset. :param file_name: File name of the dataset. - :return: Dictionary containing response, cell line IDs, and drug IDs. + :return: DrugResponseDataset containing response, cell line IDs, and drug IDs. """ dataset_name = "CCLE" path = os.path.join(path_data, dataset_name, file_name) @@ -72,8 +76,7 @@ def load_toy(path_data: str = "data") -> DrugResponseDataset: Loads small Toy dataset, subsampled from GDSC1. :param path_data: Path to the dataset. - :param file_name: File name of the dataset. - :return: Dictionary containing response, cell line IDs, and drug IDs. + :return: DrugResponseDataset containing response, cell line IDs, and drug IDs. 
""" dataset_name = "Toy_Data" path = os.path.join(path_data, dataset_name, "toy_data.csv") @@ -89,18 +92,25 @@ def load_toy(path_data: str = "data") -> DrugResponseDataset: ) -AVAILABLE_DATASETS = {"GDSC1": load_gdsc1, "GDSC2": load_gdsc2, "CCLE": load_ccle, "Toy_Data": load_toy} +AVAILABLE_DATASETS: dict[str, Callable] = { + "GDSC1": load_gdsc1, + "GDSC2": load_gdsc2, + "CCLE": load_ccle, + "Toy_Data": load_toy, +} +@pipeline_function def load_dataset(dataset_name: str, path_data: str = "data") -> DrugResponseDataset: """ Load a dataset based on the dataset name. :param dataset_name: The name of the dataset to load ('GDSC1', 'GDSC2', 'CCLE', or 'Toy_Data'). :param path_data: The path to the dataset. - :return: A dictionary containing response, cell line IDs, drug IDs, and dataset name. + :return: A DrugResponseDataset containing response, cell line IDs, drug IDs, and dataset name. + :raises ValueError: If the dataset name is unknown. """ if dataset_name in AVAILABLE_DATASETS: - return AVAILABLE_DATASETS[dataset_name](path_data) + return AVAILABLE_DATASETS[dataset_name](path_data) # type: ignore else: raise ValueError(f"Unknown dataset name: {dataset_name}") diff --git a/drevalpy/datasets/utils.py b/drevalpy/datasets/utils.py index e5e070f..50c2bb2 100644 --- a/drevalpy/datasets/utils.py +++ b/drevalpy/datasets/utils.py @@ -1,33 +1,33 @@ """Utility functions for datasets.""" -import os import zipfile +from pathlib import Path +from typing import Any import networkx as nx import numpy as np import requests -from numpy.typing import ArrayLike def download_dataset( dataset_name: str, - data_path: str = "data", + data_path: str | Path = "data", redownload: bool = False, ): """ Download the latets dataset from Zenodo. - :param dataset: dataset name, e.g., "GDSC1", "GDSC2", "CCLE" or "Toy_Data" + :param dataset_name: dataset name, e.g., "GDSC1", "GDSC2", "CCLE" or "Toy_Data" :param data_path: where to save the data :param redownload: whether to redownload the data - :return: + :raises HTTPError: if the download fails """ file_name = f"{dataset_name}.zip" - file_path = os.path.join(data_path, file_name) - extracted_folder_path = os.path.join(data_path, dataset_name) + file_path = Path(data_path) / file_name + extracted_folder_path = file_path.with_suffix("") # Check if the extracted data exists and skip download if not redownloading - if os.path.exists(extracted_folder_path) and not redownload: + if extracted_folder_path.exists() and not redownload: print(f"{dataset_name} is already extracted, skipping download.") else: url = "https://zenodo.org/doi/10.5281/zenodo.12633909" @@ -42,7 +42,7 @@ def download_dataset( data = response.json() # Ensure the save path exists - os.makedirs(data_path, exist_ok=True) + extracted_folder_path.parent.mkdir(exist_ok=True, parents=True) # Download each file name_to_url = {file["key"]: file["links"]["self"] for file in data["files"]} @@ -61,7 +61,7 @@ def download_dataset( for member in z.infolist(): if not member.filename.startswith("__MACOSX/"): z.extract(member, data_path) - os.remove(file_path) # Remove zip file after extraction + file_path.unlink() # Remove zip file after extraction print(f"{dataset_name} data downloaded and extracted to {data_path}") @@ -100,10 +100,10 @@ def randomize_graph(original_graph: nx.Graph) -> nx.Graph: def permute_features( - features: dict, - identifiers: ArrayLike, - views_to_permute: list, - all_views: list, + features: dict[str, dict[str, Any]], + identifiers: np.ndarray, + views_to_permute: list[str], + all_views: list[str], ) 
-> dict: """ Permute the specified views for each entity (= cell line or drug). diff --git a/drevalpy/evaluation.py b/drevalpy/evaluation.py index 4c06af1..24dadc8 100644 --- a/drevalpy/evaluation.py +++ b/drevalpy/evaluation.py @@ -1,7 +1,6 @@ """Functions for evaluating model performance.""" import warnings -from typing import Union import numpy as np import pandas as pd @@ -10,6 +9,7 @@ from sklearn import metrics from .datasets.dataset import DrugResponseDataset +from .pipeline_function import pipeline_function warning_shown = False constant_prediction_warning_shown = False @@ -31,7 +31,9 @@ def partial_correlation( :param cell_line_ids: cell line IDs :param drug_ids: drug IDs :param method: method to compute the partial correlation (pearson, spearman) - :return: partial correlation float + :param return_pvalue: whether to return the p-value + :returns: partial correlation float + :raises AssertionError: if predictions, response, drug_ids, and cell_line_ids do not have the same length """ if len(y_true) < 3: return np.nan if not return_pvalue else (np.nan, np.nan) @@ -94,28 +96,28 @@ def partial_correlation( return r -def check_constant_prediction(y_pred: np.ndarray) -> bool: +def _check_constant_prediction(y_pred: np.ndarray) -> bool: """ Check if predictions are constant. - :param y_pred: - :return: + :param y_pred: predictions + :return: bool whether predictions are constant """ tol = 1e-6 # no variation in predictions - return np.all(np.isclose(y_pred, y_pred[0], atol=tol)) + return bool(np.all(np.isclose(y_pred, y_pred[0], atol=tol))) -def check_constant_target_or_small_sample(y_true: np.ndarray) -> bool: +def _check_constant_target_or_small_sample(y_true: np.ndarray) -> bool: """ Check if target is constant or sample size is too small. - :param y_true: - :return: + :param y_true: true response + :returns: bool whether target is constant or sample size is too small """ tol = 1e-6 # Check for insufficient sample size or no variation in target - return len(y_true) < 2 or np.all(np.isclose(y_true, y_true[0], atol=tol)) + return len(y_true) < 2 or bool(np.all(np.isclose(y_true, y_true[0], atol=tol))) def pearson(y_pred: np.ndarray, y_true: np.ndarray) -> float: @@ -125,14 +127,14 @@ def pearson(y_pred: np.ndarray, y_true: np.ndarray) -> float: :param y_pred: predictions :param y_true: response :return: pearson correlation float + :raises AssertionError: if predictions and response do not have the same length """ - if len(y_pred) != len(y_true): raise AssertionError("predictions, response must have the same length") - if check_constant_prediction(y_pred): + if _check_constant_prediction(y_pred): return 0.0 - if check_constant_target_or_small_sample(y_true): + if _check_constant_target_or_small_sample(y_true): return np.nan return pearsonr(y_pred, y_true)[0] @@ -145,13 +147,14 @@ def spearman(y_pred: np.ndarray, y_true: np.ndarray) -> float: :param y_pred: predictions :param y_true: response :return: spearman correlation float + :raises AssertionError: if predictions and response do not have the same length """ # we can use scipy.stats.spearmanr if len(y_pred) != len(y_true): raise AssertionError("predictions, response must have the same length") - if check_constant_prediction(y_pred): + if _check_constant_prediction(y_pred): return 0.0 - if check_constant_target_or_small_sample(y_true): + if _check_constant_target_or_small_sample(y_true): return np.nan return spearmanr(y_pred, y_true)[0] @@ -164,13 +167,14 @@ def kendall(y_pred: np.ndarray, y_true: np.ndarray) -> float: :param y_pred: 
predictions :param y_true: response :return: kendall tau correlation float + :raises AssertionError: if predictions and response do not have the same length """ # we can use scipy.stats.spearmanr if len(y_pred) != len(y_true): raise AssertionError("predictions, response must have the same length") - if check_constant_prediction(y_pred): + if _check_constant_prediction(y_pred): return 0.0 - if check_constant_target_or_small_sample(y_true): + if _check_constant_target_or_small_sample(y_true): return np.nan return kendalltau(y_pred, y_true)[0] @@ -186,6 +190,7 @@ def kendall(y_pred: np.ndarray, y_true: np.ndarray) -> float: "Kendall": kendall, "Partial_Correlation": partial_correlation, } +# both used by pipeline! MINIMIZATION_METRICS = ["MSE", "RMSE", "MAE"] MAXIMIZATION_METRICS = [ "R^2", @@ -200,8 +205,9 @@ def get_mode(metric: str): """ Get whether the optimum value of the metric is the minimum or maximum. - :param metric: - :return: + :param metric: metric, e.g., RMSE + :returns: whether the optimum value of the metric is the minimum or maximum + :raises ValueError: if the metric is not in MINIMIZATION_METRICS or MAXIMIZATION_METRICS """ if metric in MINIMIZATION_METRICS: mode = "min" @@ -214,7 +220,8 @@ def get_mode(metric: str): return mode -def evaluate(dataset: DrugResponseDataset, metric: Union[list[str], str]): +@pipeline_function +def evaluate(dataset: DrugResponseDataset, metric: list[str] | str): """ Evaluates the model on the given dataset. @@ -222,17 +229,20 @@ def evaluate(dataset: DrugResponseDataset, metric: Union[list[str], str]): :param metric: evaluation metric(s) (one or a list of "MSE", "RMSE", "MAE", "r2", "Pearson", "spearman", "kendall", "partial_correlation") :return: evaluation metric + :raises AssertionError: if metric is not in AVAILABLE_METRICS """ if isinstance(metric, str): metric = [metric] predictions = dataset.predictions + if predictions is None: + raise AssertionError("No predictions found in the dataset") response = dataset.response results = {} for m in metric: if m not in AVAILABLE_METRICS: raise AssertionError(f"invalid metric {m}. Available: {list(AVAILABLE_METRICS.keys())}") - if len(response) < 2: + if len(response) < 2 or np.all(np.isnan(response)) or np.all(np.isnan(predictions)): results[m] = float(np.nan) else: if m == "Partial_Correlation": @@ -248,14 +258,3 @@ def evaluate(dataset: DrugResponseDataset, metric: Union[list[str], str]): results[m] = float(AVAILABLE_METRICS[m](y_pred=predictions, y_true=response)) return results - - -def visualize_results(results: pd.DataFrame, mode: Union[list[str], str]): - """ - Visualizes the model on the given dataset.
- - :param dataset: dataset to evaluate on - :param mode: - :return: evaluation metric - """ - raise NotImplementedError("visualize not implemented yet") diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index f760df2..98a2e0f 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -4,7 +4,7 @@ import os import shutil import warnings -from typing import Optional +from typing import Any, Optional import numpy as np import pandas as pd @@ -16,7 +16,8 @@ from .datasets.dataset import DrugResponseDataset, FeatureDataset from .evaluation import evaluate, get_mode from .models import MODEL_FACTORY, MULTI_DRUG_MODEL_FACTORY, SINGLE_DRUG_MODEL_FACTORY -from .models.drp_model import DRPModel, SingleDrugModel +from .models.drp_model import DRPModel +from .pipeline_function import pipeline_function def drug_response_experiment( @@ -41,47 +42,47 @@ def drug_response_experiment( Run the drug response prediction experiment. Save results to disc. :param models: list of model classes to compare - :param baselines: list of baseline models. No randomization or robustness tests are run for the - baseline models. + :param baselines: list of baseline models. No randomization or robustness tests are run for the baseline models. :param response_data: drug response dataset :param response_transformation: normalizer to use for the response data :param metric: metric to use for hyperparameter optimization :param n_cv_splits: number of cross-validation splits :param multiprocessing: whether to use multiprocessing - :param randomization_mode: list of randomization modes to do. - Modes: SVCC, SVRC, SVCD, SVRD - Can be a list of randomization tests e.g. 'SVCC SVCD'. Default is None, which means no - randomization tests are run. - SVCC: Single View Constant for Cell Lines: in this mode, one experiment is done for every - cell line view the model uses (e.g. gene expression, mutation, ..). - For each experiment one cell line view is held constant while the others are randomized. - SVRC Single View Random for Cell Lines: in this mode, one experiment is done for every - cell line view the model uses (e.g. gene expression, mutation, ..). - For each experiment one cell line view is randomized while the others are held constant. - SVCD: Single View Constant for Drugs: in this mode, one experiment is done for every drug - view the model uses (e.g. fingerprints, target_information, ..). - For each experiment one drug view is held constant while the others are randomized. - SVRD: Single View Random for Drugs: in this mode, one experiment is done for every drug - view the model uses (e.g. gene expression, target_information, ..). - For each experiment one drug view is randomized while the others are held constant. - :param randomization_type: type of randomization to use. Choose from "gaussian", "zeroing", - "permutation". Default is "permutation" - "gaussian": replace the features with random values sampled from a gaussian distribution - with the same mean and standard deviation - "zeroing": replace the features with zeros - "permutation": permute the features over the instances, keeping the distribution of the - features the same but dissolving the relationship to the target - :param n_trials_robustness: number of trials to run for the robustness test. - The robustness test is a test where models are - retrained multiple tiems with varying seeds. Default is 0, which - means no robustness test is run. + :param randomization_mode: list of randomization modes to do. 
Modes: SVCC, SVRC, SVCD, SVRD Can be a list of + randomization tests e.g. 'SVCC SVCD'. Default is None, which means no randomization tests are run. + + * SVCC: Single View Constant for Cell Lines: in this mode, one experiment is done for every cell line view + the model uses (e.g. gene expression, mutation, ...). For each experiment one cell line view is held + constant while the others are randomized. + * SVRC Single View Random for Cell Lines: in this mode, one experiment is done for every cell line view the + model uses (e.g. gene expression, mutation, ...). For each experiment one cell line view is randomized while + the others are held constant. + * SVCD: Single View Constant for Drugs: in this mode, one experiment is done for every drug view the model + uses (e.g. fingerprints, target_information, ...). For each experiment one drug view is held constant + while the others are randomized. + * SVRD: Single View Random for Drugs: in this mode, one experiment is done for every drug view the model uses + (e.g. gene expression, target_information, ...). For each experiment one drug view is randomized while + the others are held constant. + + :param randomization_type: type of randomization to use. Choose from "permutation" and "invariant". + Default is "permutation". + + * "permutation": permute the features over the instances, keeping the distribution of the features the same + but dissolving the relationship to the target + * "invariant": the features are permuted in a way that a key characteristic of the feature is kept. In case of + matrices, this is the mean and standard deviation of the feature view for this instance, for networks it + is the degree distribution. + + :param cross_study_datasets: list of datasets for the cross-study prediction. The trained model is assessed for + its generalization to these datasets. Default is None, which means no cross-study prediction is run. + :param n_trials_robustness: number of trials to run for the robustness test. The robustness test is a test where + models are retrained multiple times with varying seeds. Default is 0, which means no robustness test is run. 
:param path_out: path to the output directory :param run_id: identifier to save the results - :param test_mode: test mode one of "LPO", "LCO", "LDO" (leave-pair-out, leave-cell-line-out, - leave-drug-out) + :param test_mode: test mode one of "LPO", "LCO", "LDO" (leave-pair-out, leave-cell-line-out, leave-drug-out) :param overwrite: whether to overwrite existing results :param path_data: path to the data directory, usually data/ - :return: None + :raises ValueError: if no cv splits are found """ if baselines is None: baselines = [] @@ -89,7 +90,6 @@ def drug_response_experiment( result_path = os.path.join(path_out, run_id, test_mode) split_path = os.path.join(result_path, "splits") result_folder_exists = os.path.exists(result_path) - randomization_test_views = [] if result_folder_exists and overwrite: # if results exists, delete them if overwrite is True print(f"Overwriting existing results at {result_path}") @@ -146,6 +146,9 @@ def drug_response_experiment( model_hpam_set = model_class.get_hyperparameter_set() + if response_data.cv_splits is None: + raise ValueError("No cv splits found.") + for split_index, split in enumerate(response_data.cv_splits): print(f"################# FOLD {split_index+1}/{len(response_data.cv_splits)} " f"#################") @@ -233,7 +236,7 @@ def drug_response_experiment( best_hpams = json.load(f) if not is_baseline: if randomization_mode is not None: - print(f"Randomization tests for {model_class.model_name}") + print(f"Randomization tests for {model_class.get_model_name()}") # if this line changes, it also needs to be changed in pipeline: # randomization_split.py randomization_test_views = get_randomization_test_views( @@ -253,7 +256,7 @@ def drug_response_experiment( response_transformation=response_transformation, ) if n_trials_robustness > 0: - print(f"Robustness test for {model_class.model_name}") + print(f"Robustness test for {model_class.get_model_name()}") robustness_test( n_trials=n_trials_robustness, model=model, @@ -278,6 +281,7 @@ def drug_response_experiment( print("Done!") +@pipeline_function def consolidate_single_drug_model_predictions( models: list[type[DRPModel]], n_cv_splits: int, @@ -287,13 +291,24 @@ def consolidate_single_drug_model_predictions( n_trials_robustness: int = 0, out_path: str = "", ) -> None: - """Consolidate SingleDrugModel predictions into a single file.""" - + """ + Consolidate single drug model predictions into a single file. + + :param models: list of model classes to compare, e.g., [SimpleNeuralNetwork, RandomForest] + :param n_cv_splits: number of cross-validation splits, e.g., 5 + :param results_path: path to the results directory, e.g., results/ + :param cross_study_datasets: list of cross-study datasets, e.g., [CCLE, GDSC1] + :param randomization_mode: list of randomization modes, e.g., ["SVCC", "SVRC"] + :param n_trials_robustness: number of robustness trials, e.g., 10 + :param out_path: for the package, this is the same as results_path. For the pipeline, this is empty because it + will be stored in the work directory. 
+ """ for model in models: - if model.model_name in SINGLE_DRUG_MODEL_FACTORY: - model_instance = MODEL_FACTORY[model.model_name]() - model_path = os.path.join(results_path, str(model.model_name)) - out_path = os.path.join(out_path, str(model.model_name)) + if model.get_model_name() in SINGLE_DRUG_MODEL_FACTORY: + + model_instance = MODEL_FACTORY[model.get_model_name()]() + model_path = os.path.join(results_path, model.get_model_name()) + out_path = os.path.join(out_path, model.get_model_name()) os.makedirs(os.path.join(out_path, "predictions"), exist_ok=True) if cross_study_datasets: os.makedirs(os.path.join(out_path, "cross_study"), exist_ok=True) @@ -305,7 +320,7 @@ def consolidate_single_drug_model_predictions( for split in range(n_cv_splits): # Collect predictions for drugs across all scenarios (main, cross_study, robustness, randomization) - predictions = { + predictions: Any = { "main": [], "cross_study": {}, "robustness": {}, @@ -410,22 +425,23 @@ def consolidate_single_drug_model_predictions( ) -def handle_overwrite(path: str, overwrite: bool) -> None: - """Handle overwrite logic for a given path.""" - if os.path.exists(path) and overwrite: - shutil.rmtree(path) - os.makedirs(path, exist_ok=True) - - def load_features( model: DRPModel, path_data: str, dataset: DrugResponseDataset -) -> tuple[FeatureDataset, FeatureDataset]: - """Load and reduce cell line and drug features for a given dataset.""" +) -> tuple[FeatureDataset, Optional[FeatureDataset]]: + """ + Load and reduce cell line and drug features for a given dataset. + + :param model: model to use, e.g., SimpleNeuralNetwork + :param path_data: path to the data directory, e.g., data/ + :param dataset: dataset to load features for, e.g., GDSC2 + :returns: tuple of cell line and, potentially, drug features + """ cl_features = model.load_cell_line_features(data_path=path_data, dataset_name=dataset.dataset_name) drug_features = model.load_drug_features(data_path=path_data, dataset_name=dataset.dataset_name) return cl_features, drug_features +@pipeline_function def cross_study_prediction( dataset: DrugResponseDataset, model: DRPModel, @@ -439,15 +455,20 @@ def cross_study_prediction( single_drug_id: Optional[str] = None, ) -> None: """ - Run the drug response prediction experiment on a cross-study dataset. Save results to disc. + Run the drug response prediction experiment on a cross-study dataset to assess the generalizability of the model. 
- :param dataset: cross-study dataset - :param model: model to use + :param dataset: cross-study dataset, e.g., GDSC1 if trained on GDSC2 + :param model: model to use, e.g, SimpleNeuralNetwork :param test_mode: test mode one of "LPO", "LCO", "LDO" (leave-pair-out, leave-cell-line-out, leave-drug-out) - :param train_dataset: training dataset + :param train_dataset: training dataset, e.g., GDSC2 + :param path_data: path to the data directory, e.g., data/ :param early_stopping_dataset: early stopping dataset + :param response_transformation: normalizer to use for the response data, e.g., StandardScaler + :param path_out: path to the output directory, e.g., results/ + :param split_index: index of the split :param single_drug_id: drug id to use for single drug models None for global models + :raises ValueError: if feature loading fails or if the test mode is invalid """ dataset = dataset.copy() os.makedirs(os.path.join(path_out, "cross_study"), exist_ok=True) @@ -458,15 +479,16 @@ def cross_study_prediction( try: cl_features, drug_features = load_features(model, path_data, dataset) except ValueError as e: - warnings.warn(e, stacklevel=2) + warnings.warn(str(e), stacklevel=2) return cell_lines_to_keep = cl_features.identifiers if cl_features is not None else None + drugs_to_keep: Optional[np.ndarray] = None if single_drug_id is not None: - drugs_to_keep = [single_drug_id] - else: - drugs_to_keep = drug_features.identifiers if drug_features is not None else None + drugs_to_keep = np.array([single_drug_id]) + elif drug_features is not None: + drugs_to_keep = drug_features.identifiers print( f"Reducing cross study dataset ... feature data available for " @@ -486,33 +508,33 @@ def cross_study_prediction( } dataset_pairs = [f"{cl}_{drug}" for cl, drug in zip(dataset.cell_line_ids, dataset.drug_ids, strict=True)] - dataset.remove_rows([i for i, pair in enumerate(dataset_pairs) if pair in train_pairs]) + dataset.remove_rows(np.array([i for i, pair in enumerate(dataset_pairs) if pair in train_pairs])) elif test_mode == "LCO": - train_cell_lines = set(train_dataset.cell_line_ids) + train_cell_lines = train_dataset.cell_line_ids dataset.reduce_to( - cell_line_ids=[cl for cl in dataset.cell_line_ids if cl not in train_cell_lines], + cell_line_ids=np.setdiff1d(dataset.cell_line_ids, train_cell_lines), drug_ids=None, ) elif test_mode == "LDO": - train_drugs = set(train_dataset.drug_ids) + train_drugs = train_dataset.drug_ids dataset.reduce_to( cell_line_ids=None, - drug_ids=[drug for drug in dataset.drug_ids if drug not in train_drugs], + drug_ids=np.setdiff1d(dataset.drug_ids, train_drugs), ) else: raise ValueError(f"Invalid test mode: {test_mode}. Choose from LPO, LCO, LDO") if len(dataset) > 0: dataset.shuffle(random_state=42) - dataset.predictions = model.predict( + dataset._predictions = model.predict( cell_line_ids=dataset.cell_line_ids, drug_ids=dataset.drug_ids, cell_line_input=cl_features, drug_input=drug_features, ) if response_transformation: - dataset.response = response_transformation.inverse_transform(dataset.response) + dataset._response = response_transformation.inverse_transform(dataset.response) else: - dataset.predictions = np.array([]) + dataset._predictions = np.array([]) dataset.save( os.path.join( path_out, @@ -522,13 +544,19 @@ def cross_study_prediction( ) +@pipeline_function def get_randomization_test_views(model: DRPModel, randomization_mode: list[str]) -> dict[str, list[str]]: """ Get the views to use for the randomization tests. 
- :param model: - :param randomization_mode: - :return: + * For SVCC, a single cell line view (e.g., gene expression) is held constant while the others are randomized. + * For SVCD, a single drug view (e.g., fingerprints) is held constant while the others are randomized. + * For SVRC, a single cell line view is randomized while the others are held constant. + * For SVRD, a single drug view is randomized while the others are held constant. + + :param model: model to use, e.g., SimpleNeuralNetwork + :param randomization_mode: list of randomization modes to do, e.g., ["SVCC", "SVRC"] + :returns: dictionary of randomization test views """ cell_line_views = model.cell_line_views drug_views = model.drug_views @@ -536,12 +564,12 @@ def get_randomization_test_views(model: DRPModel, randomization_mode: list[str]) if "SVCC" in randomization_mode: for view in cell_line_views: randomization_test_views[f"SVCC_{view}"] = [v for v in cell_line_views if v != view] - if "SVRC" in randomization_mode: - for view in cell_line_views: - randomization_test_views[f"SVRC_{view}"] = [view] if "SVCD" in randomization_mode: for view in drug_views: randomization_test_views[f"SVCD_{view}"] = [v for v in drug_views if v != view] + if "SVRC" in randomization_mode: + for view in cell_line_views: + randomization_test_views[f"SVRC_{view}"] = [view] if "SVRD" in randomization_mode: for view in drug_views: randomization_test_views[f"SVRD_{view}"] = [view] @@ -565,19 +593,18 @@ def robustness_test( Run robustness tests for the given model and dataset. This will run the model n times with different random seeds to get a distribution of the results. + :param n_trials: number of trials to run :param model: model to evaluate :param hpam_set: hyperparameters to use + :param path_data: path to the data directory :param train_dataset: training dataset :param test_dataset: test dataset :param early_stopping_dataset: early stopping dataset :param path_out: path to the output directory :param split_index: index of the split - :param test_mode: test mode one of "LPO", "LCO", "LDO" (leave-pair-out, leave-cell-line-out, - leave-drug-out) - :param response_transformation: sklearn.preprocessing scaler like StandardScaler or - MinMaxScaler to use to scale the target - :return: None (save results to disk) + :param response_transformation: sklearn.preprocessing scaler like StandardScaler or MinMaxScaler to use to scale + the target """ robustness_test_path = os.path.join(path_out, "robustness") os.makedirs(robustness_test_path, exist_ok=True) @@ -601,6 +628,7 @@ def robustness_test( ) +@pipeline_function def robustness_train_predict( trial: int, trial_file: str, @@ -615,16 +643,15 @@ def robustness_train_predict( """ Train and predict for the robustness test. 
- :param trial: - :param trial_file: - :param train_dataset: - :param test_dataset: - :param early_stopping_dataset: - :param model: - :param hpam_set: - :param path_data: - :param response_transformation: - :return: + :param trial: trial number + :param trial_file: file to save the results to + :param train_dataset: training dataset + :param test_dataset: test dataset + :param early_stopping_dataset: early stopping dataset + :param model: model to evaluate + :param hpam_set: hyperparameters to use + :param path_data: path to the data directory, e.g., data/ + :param response_transformation: sklearn.preprocessing scaler like StandardScaler or MinMaxScaler to use to scale the target """ train_dataset.shuffle(random_state=trial) test_dataset.shuffle(random_state=trial) @@ -664,12 +691,12 @@ def randomization_test( "methylation_only": ["gene_expression", "copy_number_var", "mutation"]}" :param model: model to evaluate :param hpam_set: hyperparameters to use + :param path_data: path to the data directory :param train_dataset: training dataset :param test_dataset: test dataset :param early_stopping_dataset: early stopping dataset :param path_out: path to the output directory :param split_index: index of the split - :param test_mode: test mode one of "LPO", "LCO", "LDO" :param randomization_type: type of randomization to use. Choose from "permutation", "invariant". Default is "permutation" which permutes the features over the instances, keeping the distribution of the features the same but dissolving the relationship to the target. @@ -678,7 +705,6 @@ def randomization_test( instance, for networks it is the degree distribution. :param response_transformation: sklearn.preprocessing scaler like StandardScaler or MinMaxScaler to use to scale the target - :return: None (save results to disk) """ for test_name, views in randomization_test_views.items(): randomization_test_path = os.path.join(path_out, "randomization") @@ -708,6 +734,7 @@ def randomization_test( print(f"Randomization test {test_name} already exists. Skipping.") + +@pipeline_function def randomize_train_predict( view: str, test_name: str, @@ -720,38 +747,51 @@ def randomize_train_predict( test_dataset: DrugResponseDataset, early_stopping_dataset: Optional[DrugResponseDataset], response_transformation: Optional[TransformerMixin], -): +) -> None: """ Randomize the features for a given view and run the model.
- :param view: - :param test_name: - :param randomization_type: - :param randomization_test_file: - :param model: - :param hpam_set: - :param path_data: - :param train_dataset: - :param test_dataset: - :param early_stopping_dataset: - :param response_transformation: - :return: + :param view: view to randomize, e.g., gene_expression + :param test_name: name of the randomization test, e.g., SVRC_gene_expression + :param randomization_type: type of randomization to use, e.g., permutation + :param randomization_test_file: file to save the results to + :param model: model to evaluate + :param hpam_set: hyperparameters to use + :param path_data: path to the data directory + :param train_dataset: training dataset + :param test_dataset: test dataset + :param early_stopping_dataset: early stopping dataset + :param response_transformation: sklearn.preprocessing scaler like StandardScaler or MinMaxScaler to use to scale """ cl_features, drug_features = load_features(model, path_data, train_dataset) - if (view not in cl_features.get_view_names()) and (view not in drug_features.get_view_names()): + # Handle case where both features are None early on + if cl_features is None and drug_features is None: warnings.warn( - f"View {view} not found in features. Skipping randomization test {test_name} " f"which includes this view.", + "Both cl_features and drug_features are None. Skipping randomization test.", stacklevel=2, ) return - cl_features_rand = cl_features.copy() if cl_features is not None else None - drug_features_rand = drug_features.copy() if drug_features is not None else None - if view in cl_features.get_view_names(): - cl_features_rand.randomize_features(view, randomization_type=randomization_type) - elif view in drug_features.get_view_names(): - drug_features_rand.randomize_features(view, randomization_type=randomization_type) + # Check if view is in either feature set, if not, warn and skip + if (cl_features is not None and view not in cl_features.view_names) and ( + drug_features is not None and view not in drug_features.view_names + ): + warnings.warn( + f"View {view} not found in features. Skipping randomization test {test_name} which includes this view.", + stacklevel=2, + ) + return + + cl_features_rand: Optional[FeatureDataset] = None + if cl_features is not None: + cl_features_rand = cl_features.copy() + cl_features_rand.randomize_features(view, randomization_type=randomization_type) # type: ignore[union-attr] + + drug_features_rand: Optional[FeatureDataset] = None + if drug_features is not None: + drug_features_rand = drug_features.copy() + drug_features_rand.randomize_features(view, randomization_type=randomization_type) # type: ignore[union-attr] test_dataset_rand = train_and_predict( model=model, @@ -773,9 +813,9 @@ def split_early_stopping( """ Split the validation dataset into a validation and early stopping dataset. - :param validation_dataset: - :param test_mode: - :return: + :param validation_dataset: validation dataset + :param test_mode: test mode one of "LPO", "LCO", "LDO" (leave-pair-out, leave-cell-line-out, leave-drug-out) + :returns: tuple of validation and early stopping datasets """ validation_dataset.shuffle(random_state=42) cv_v = validation_dataset.split_dataset( @@ -790,6 +830,7 @@ def split_early_stopping( return validation_dataset, early_stopping_dataset +@pipeline_function def train_and_predict( model: DRPModel, hpams: dict, @@ -804,19 +845,21 @@ def train_and_predict( """ Train the model and predict the response for the prediction dataset. 
- :param model: - :param hpams: - :param path_data: - :param train_dataset: - :param prediction_dataset: - :param early_stopping_dataset: - :param response_transformation: - :param cl_features: - :param drug_features: - :return: + :param model: model to use, e.g., SimpleNeuralNetwork + :param hpams: hyperparameters to use + :param path_data: path to the data directory, e.g., data/ + :param train_dataset: training dataset + :param prediction_dataset: prediction dataset + :param early_stopping_dataset: early stopping dataset, optional + :param response_transformation: normalizer to use for the response data, e.g., StandardScaler + :param cl_features: cell line features + :param drug_features: drug features + :returns: prediction dataset with predictions + :raises ValueError: if train_dataset does not have a dataset_name """ model.build_model(hyperparameters=hpams) - + if train_dataset.dataset_name is None: + raise ValueError("train_dataset must have a dataset_name") if cl_features is None: print("Loading cell line features ...") cl_features = model.load_cell_line_features(data_path=path_data, dataset_name=train_dataset.dataset_name) @@ -830,10 +873,18 @@ def train_and_predict( # making sure there are no missing features: len_train_before = len(train_dataset) len_pred_before = len(prediction_dataset) + print(f"Number of cell lines in features: {len(cell_lines_to_keep)}") + if drugs_to_keep is not None: + print(f"Number of drugs in features: {len(drugs_to_keep)}") + print(f"Number of cell lines in train dataset: {len(np.unique(train_dataset.cell_line_ids))}") + print(f"Number of drugs in train dataset: {len(np.unique(train_dataset.drug_ids))}") + train_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) prediction_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) - print(f"Reduced training dataset from {len_train_before} to {len(train_dataset)}") - print(f"Reduced prediction dataset from {len_pred_before} to {len(prediction_dataset)}") + print(f"Reduced training dataset from {len_train_before} to {len(train_dataset)}, because of missing features") + print( + f"Reduced prediction dataset from {len_pred_before} to {len(prediction_dataset)}, because of missing features" + ) if early_stopping_dataset is not None: len_es_before = len(early_stopping_dataset) @@ -842,7 +893,8 @@ def train_and_predict( if response_transformation: train_dataset.fit_transform(response_transformation) - early_stopping_dataset.transform(response_transformation) + if early_stopping_dataset is not None: + early_stopping_dataset.transform(response_transformation) prediction_dataset.transform(response_transformation) print("Training model ...") @@ -852,15 +904,18 @@ def train_and_predict( drug_input=drug_features, output_earlystopping=early_stopping_dataset, ) - prediction_dataset.predictions = model.predict( - cell_line_ids=prediction_dataset.cell_line_ids, - drug_ids=prediction_dataset.drug_ids, - cell_line_input=cl_features, - drug_input=drug_features, - ) + if len(prediction_dataset) > 0: + prediction_dataset._predictions = model.predict( + cell_line_ids=prediction_dataset.cell_line_ids, + drug_ids=prediction_dataset.drug_ids, + cell_line_input=cl_features, + drug_input=drug_features, + ) - if response_transformation: - prediction_dataset.inverse_transform(response_transformation) + if response_transformation: + prediction_dataset.inverse_transform(response_transformation) + else: + prediction_dataset._predictions = np.array([]) return prediction_dataset @@ -876,17 +931,17 
@@ def train_and_evaluate( metric: str = "rmse", ) -> dict[str, float]: """ - Train and evaluate the model. - - :param model: - :param hpams: - :param path_data: - :param train_dataset: - :param validation_dataset: - :param early_stopping_dataset: - :param response_transformation: - :param metric: - :return: + Train and evaluate the model, i.e., call train_and_predict() and then evaluate(). + + :param model: model to use + :param hpams: hyperparameters to use + :param path_data: path to the data directory + :param train_dataset: training dataset + :param validation_dataset: validation dataset + :param early_stopping_dataset: early stopping dataset + :param response_transformation: normalizer to use for the response data + :param metric: metric to evaluate the model on + :returns: dictionary of the evaluation results, e.g., {"RMSE": 0.1} """ validation_dataset = train_and_predict( model=model, @@ -911,17 +966,18 @@ def hpam_tune( path_data: str = "data", ) -> dict: """ - Tune the hyperparameters for the given model. - - :param model: - :param train_dataset: - :param validation_dataset: - :param hpam_set: - :param early_stopping_dataset: - :param response_transformation: - :param metric: - :param path_data: - :return: + Tune the hyperparameters for the given model in an iterative manner. + + :param model: model to use + :param train_dataset: training dataset + :param validation_dataset: validation dataset + :param hpam_set: hyperparameters to tune + :param early_stopping_dataset: early stopping dataset + :param response_transformation: normalizer to use for the response data + :param metric: metric to evaluate which model is the best + :param path_data: path to the data directory, e.g., data/ + :returns: best hyperparameters + :raises AssertionError: if hpam_set is empty """ if len(hpam_set) == 0: raise AssertionError("hpam_set must contain at least one hyperparameter configuration") @@ -973,16 +1029,16 @@ def hpam_tune_raytune( """ Tune the hyperparameters for the given model using raytune. - :param model: - :param train_dataset: - :param validation_dataset: - :param early_stopping_dataset: - :param hpam_set: - :param response_transformation: - :param metric: - :param ray_path: - :param path_data: - :return: + :param model: model to use + :param train_dataset: training dataset + :param validation_dataset: validation dataset + :param early_stopping_dataset: early stopping dataset + :param hpam_set: hyperparameters to tune + :param response_transformation: normalizer to use for the response data + :param metric: metric to evaluate which model is the best + :param ray_path: path to the raytune directory + :param path_data: path to the data directory, e.g., data/ + :returns: best hyperparameters """ if len(hpam_set) == 1: return hpam_set[0] @@ -1016,32 +1072,35 @@ def hpam_tune_raytune( return best_config +@pipeline_function def make_model_list(models: list[type[DRPModel]], response_data: DrugResponseDataset) -> dict[str, str]: """ - Make a list of models to evaluate. + Make a list of models to evaluate: if it is a single drug model, add the drug id to the model name. 
- :param models: - :param baselines: - :param response_data: - :return: + :param models: list of models to evaluate + :param response_data: response data, needed to get the unique drugs for single drug models + :returns: dictionary of model names: model class, e.g., {"SimpleNeuralNetwork": "SimpleNeuralNetwork", + "MOLIR.Afatinib": "MOLIR"} """ model_list = {} unique_drugs = np.unique(response_data.drug_ids) for model in models: - if issubclass(model, SingleDrugModel): + if model.is_single_drug_model: for drug in unique_drugs: - model_list[f"{model.model_name}.{drug}"] = str(model.model_name) + model_list[f"{model.get_model_name()}.{drug}"] = model.get_model_name() else: - model_list[str(model.model_name)] = str(model.model_name) + model_list[model.get_model_name()] = model.get_model_name() return model_list -def get_model_name_and_drug_id(model_name: str): +@pipeline_function +def get_model_name_and_drug_id(model_name: str) -> tuple[str, Optional[str]]: """ Get the model name and drug id from the model name. - :param model_name: - :return: + :param model_name: model name, e.g., SimpleNeuralNetwork or MOLIR.Afatinib + :returns: tuple of model name and, potentially drug id if it is a single drug model + :raises AssertionError: if the model name is not found in the model factory """ if model_name in MULTI_DRUG_MODEL_FACTORY: return model_name, None @@ -1058,13 +1117,18 @@ def get_model_name_and_drug_id(model_name: str): return model_name, drug_id -def get_datasets_from_cv_split(split, model_class, model_name, drug_id): +@pipeline_function +def get_datasets_from_cv_split( + split: dict[str, DrugResponseDataset], model_class: type[DRPModel], model_name: str, drug_id: Optional[str] = None +) -> tuple[DrugResponseDataset, DrugResponseDataset, Optional[DrugResponseDataset], DrugResponseDataset]: """ - Get dataset from cross validation split. + Get train, validation, (early stopping), and test datasets from the CV split. - :param model_class: - :param model_name: - :param drug_id: + :param split: dictionary of the CV split + :param model_class: model class + :param model_name: model name + :param drug_id: drug id for single drug models + :returns: tuple of train, validation, (early stopping), and test datasets """ train_dataset = split["train"] validation_dataset = split["validation"] @@ -1101,14 +1165,18 @@ def get_datasets_from_cv_split(split, model_class, model_name, drug_id): ) -def generate_data_saving_path(model_name, drug_id, result_path, suffix): +@pipeline_function +def generate_data_saving_path(model_name, drug_id, result_path, suffix) -> str: """ Generate a path to save data to. - :param model_name: - :param drug_id: - :param result_path: - :param suffix: + For single drug models, the path is result_path/model_name/drugs/drug_id/suffix. + For all others, it is result_path/model_name/suffix. 
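+
+ A hypothetical example (model, drug, and path are placeholder values): a call like
+ ``generate_data_saving_path("MOLIR", "Afatinib", "results/", "predictions")`` should resolve to
+ ``results/MOLIR/drugs/Afatinib/predictions``, since MOLIR is a single-drug model.
+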
+ :param model_name: model name + :param drug_id: drug id + :param result_path: path to the results directory + :param suffix: suffix to add to the path, e.g., "predictions", "best_hpams", "randomization", "robustness" + :returns: path to save data to """ is_single_drug_model = model_name in SINGLE_DRUG_MODEL_FACTORY if is_single_drug_model: diff --git a/drevalpy/models/DIPK/__init__.py b/drevalpy/models/DIPK/__init__.py new file mode 100644 index 0000000..54ad098 --- /dev/null +++ b/drevalpy/models/DIPK/__init__.py @@ -0,0 +1 @@ +"""Module for the DIPK model.""" diff --git a/drevalpy/models/DIPK/attention_utils.py b/drevalpy/models/DIPK/attention_utils.py new file mode 100644 index 0000000..e7cd88b --- /dev/null +++ b/drevalpy/models/DIPK/attention_utils.py @@ -0,0 +1,78 @@ +"""Contains a custom MultiHeadAttentionLayer for the DIPK model.""" + +import torch +from torch import nn + + +class MultiHeadAttentionLayer(nn.Module): + """Custom multi-head attention layer for the DIPK model.""" + + def __init__(self, hid_dim: int, n_heads: int, dropout: float, device: str | torch.device | int | None): + """ + Initialize the multi-head attention layer. + + :param hid_dim: dimension of hidden layer + :param n_heads: number of heads + :param dropout: dropout rate + :param device: which device to use, e.g. "cuda" or "cpu" + :raises ValueError: if hidden dimension is not divisible by the number of heads + """ + super().__init__() + + # Ensure head dimension divides evenly + if hid_dim % n_heads != 0: + raise ValueError("Hidden dimension must be divisible by the number of heads.") + + # Define dimensions + self.hid_dim = hid_dim + self.n_heads = n_heads + self.head_dim = hid_dim // n_heads + + # Define fully connected layers for Q, K, V, and output + self.fc_q = nn.Linear(hid_dim, hid_dim) + self.fc_k = nn.Linear(hid_dim, hid_dim) + self.fc_v = nn.Linear(hid_dim, hid_dim) + self.fc_o = nn.Linear(hid_dim, hid_dim) + + # Dropout and scaling factor + self.dropout = nn.Dropout(dropout) + self.scale = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32, device=device)) + + def forward( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass of the multi-head attention layer. 
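+
+ Shape note, derived from the implementation below (batch-first tensors): ``query``, ``key``, and ``value`` are
+ expected as ``[batch, seq_len, hid_dim]``; the returned output has shape ``[batch, query_len, hid_dim]`` and the
+ returned attention weights have shape ``[batch, n_heads, query_len, key_len]``.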
+ + :param query: query tensor + :param key: key tensor + :param value: value tensor + :param mask: mask tensor + :returns: output tensor and attention tensor + """ + batch_size = query.size(0) + + # Transform inputs + transformed_query = self.fc_q(query) + transformed_key = self.fc_k(key) + transformed_value = self.fc_v(value) + + # Split into heads and transpose for multi-head processing + transformed_query = transformed_query.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) + transformed_key = transformed_key.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) + transformed_value = transformed_value.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) + + # Scaled dot-product attention + energy = torch.matmul(transformed_query, transformed_key.transpose(-2, -1)) / self.scale + if mask is not None: + energy = energy.masked_fill(mask == 0, float("-inf")) + attention = torch.softmax(energy, dim=-1) + + # Apply attention weights + x = torch.matmul(self.dropout(attention), transformed_value) + + # Concatenate heads and pass through output layer + x = x.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.hid_dim) + x = self.fc_o(x) + + return x, attention diff --git a/drevalpy/models/DIPK/data_utils.py b/drevalpy/models/DIPK/data_utils.py new file mode 100644 index 0000000..11926e3 --- /dev/null +++ b/drevalpy/models/DIPK/data_utils.py @@ -0,0 +1,207 @@ +""" +Includes functions to load and process the DIPK dataset. + +get_data: Creates a list of dictionaries with drug and cell line features. +CollateFn: Class to collate the DataLoader batches. +DIPKDataset: Dataset class for the DIPK model. + +""" + +import os +from abc import ABC + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import Dataset + +from drevalpy.datasets.dataset import FeatureDataset + + +def load_bionic_features(data_path: str, dataset_name: str, gene_add_num: int = 512) -> FeatureDataset: + """ + Load biological network (BIONIC) features for DIPK. 
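+
+ In short: for every cell line, genes are ranked by expression, the ``gene_add_num`` most highly expressed genes
+ are selected, and their BIONIC network embeddings are averaged into a single feature vector (a zero vector is
+ used if none of the selected genes has an embedding).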
+ + :param data_path: Path to the data, e.g., "data/" + :param dataset_name: Name of the dataset, e.g., GDSC2 + :param gene_add_num: Number of genes to add to the feature set + :returns: FeatureDataset with gene expression and biological network features + """ + # Load gene expression dataset + gene_expression_path = os.path.join(data_path, dataset_name, "gene_expression.csv") + gene_expression = pd.read_csv(gene_expression_path) + expression_dict = gene_expression.set_index("cell_line_name").drop("cellosaurus_id", axis=1).T.to_dict() + + # Load gene list and PPI features + gene_list_path = os.path.join(data_path, dataset_name, "DIPK_features", "gene_list_sel.txt") + with open(gene_list_path, encoding="gbk") as f: + gene_list = {line.strip() for line in f} + + ppi_path = os.path.join(data_path, dataset_name, "DIPK_features", "human_ppi_features.tsv") + dataset = pd.read_csv(ppi_path, index_col=0, sep="\t") + + # Ensure BIONIC dictionary uses gene names directly + bionic_gene_dict = {gene: dataset.loc[gene].values for gene in gene_list if gene in dataset.index} + + # Compute BIONIC features + bionic_feature_dict = {} + for cell_line, expressions in expression_dict.items(): + # Sort genes based on descending expression values + sorted_genes = sorted(expressions.items(), key=lambda x: -x[1]) + top_genes = [gene for gene, _ in sorted_genes[:gene_add_num]] + + # Aggregate BIONIC features for selected genes + selected_features = [bionic_gene_dict[gene] for gene in top_genes if gene in bionic_gene_dict] + if selected_features: + aggregated_feature = np.mean(selected_features, axis=0) + else: + # Handle case where no features are found (padding with zeros) + aggregated_feature = np.zeros(next(iter(bionic_gene_dict.values())).shape) + + bionic_feature_dict[cell_line] = aggregated_feature + + feature_data = {cell_line: {"bionic_features": features} for cell_line, features in bionic_feature_dict.items()} + return FeatureDataset(features=feature_data) + + +def get_data( + cell_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_features: FeatureDataset, + drug_features: FeatureDataset, + ic50: np.ndarray | None = None, +) -> list: + """ + Prepare data samples for training or prediction. + + Each sample includes: + - Drug features (e.g., molecular embeddings). + - Cell line features (gene expression and bionic_features). + - Optional IC50 response values for supervised tasks. + + :param cell_ids: IDs of the cell lines from the dataset. + :param drug_ids: IDs of the drugs from the dataset. + :param cell_line_features: Input features associated with the cell lines. + :param drug_features: Input features associated with the drugs. + :param ic50: (Optional) Response values (e.g., IC50) to associate with samples. + :return: List of dictionaries, each containing drug and cell line features, with optional IC50. 
+ """ + data_list = [] + for i in range(len(cell_ids)): + drug_id = str(drug_ids[i]) + cell_id = str(cell_ids[i]) + drug_tensor = torch.tensor(drug_features.features[drug_id]["molgnet_features"], dtype=torch.float32) + gene_expression = torch.tensor(cell_line_features.features[cell_id]["gene_expression"], dtype=torch.float32) + bionic_features = torch.tensor(cell_line_features.features[cell_id]["bionic_features"], dtype=torch.float32) + + sample = { + "molgnet_features": drug_tensor, + "gene_expression": gene_expression, + "bionic_features": bionic_features, + } + if ic50 is not None: + sample["ic50"] = torch.tensor([ic50[i]], dtype=torch.float32) + + data_list.append(sample) + + return data_list + + +class CollateFn: + """Collate function for the DataLoader, either for training or testing.""" + + def __init__(self, train=True): + """ + Initialize the CollateFn. + + :param train: indicates whether the DataLoader is used for training + """ + self.train = train + + def __call__(self, batch): + """ + Collate the batch. + + :param batch: batch of feature dictionaries + :returns: collated node features, gene features, bionic features, and (optional) IC50 values + """ + # Find the max number of atoms (nodes) in the batch for molgnet_features padding + max_atoms_molgnet = max([sample["molgnet_features"].size(0) for sample in batch]) + + # Pad molgnet_features to match the maximum number of atoms + padded_molgnet_features = [] + molgnet_mask = [] + + for sample in batch: + num_atoms = sample["molgnet_features"].size(0) + padding_size = max_atoms_molgnet - num_atoms + + # Pad molgnet_features + padded_features = torch.cat( + [sample["molgnet_features"], torch.zeros(padding_size, sample["molgnet_features"].size(1))], dim=0 + ) + padded_molgnet_features.append(padded_features) + + # Create a mask where valid atom features are True and padded ones are False + mask = torch.cat( + [torch.ones(num_atoms, dtype=torch.bool), torch.zeros(padding_size, dtype=torch.bool)], dim=0 + ) + molgnet_mask.append(mask) + + # Stack the padded molgnet features into a single tensor + molgnet_features = torch.stack(padded_molgnet_features) + molgnet_mask = torch.stack(molgnet_mask) + + # Collate other features + gene_features = torch.stack([sample["gene_expression"] for sample in batch]) + bionic_features = torch.stack([sample["bionic_features"] for sample in batch]) + + if self.train: + ic50_values = torch.stack([sample["ic50"] for sample in batch]) + # Return a dictionary with all features + return { + "molgnet_features": molgnet_features, + "gene_features": gene_features, + "bionic_features": bionic_features, + "ic50_values": ic50_values, + "molgnet_mask": molgnet_mask, + } + else: + # Return a dictionary without ic50_values for inference + return { + "molgnet_features": molgnet_features, + "gene_features": gene_features, + "bionic_features": bionic_features, + "molgnet_mask": molgnet_mask, + } + + +class DIPKDataset(Dataset, ABC): + """Dataset of graphs from get_data.""" + + def __init__(self, samples): + """ + Initialize the GraphDataset. + + :param samples: list + """ + super().__init__() + self._samples = samples + + def __getitem__(self, idx): + """ + Get the sample at index idx. + + :param idx: index + :returns: sample + """ + sample = self._samples[idx] + return sample + + def __len__(self) -> int: + """ + Get the number of graphs in the dataset. 
+ + :return: number of samples + """ + return len(self._samples) diff --git a/drevalpy/models/DIPK/dipk.py b/drevalpy/models/DIPK/dipk.py new file mode 100644 index 0000000..5b1c685 --- /dev/null +++ b/drevalpy/models/DIPK/dipk.py @@ -0,0 +1,333 @@ +""" +DIPK model. Adapted from https://github.com/user15632/DIPK. + +Original publication: +Improving drug response prediction via integrating gene relationships with deep learning +Pengyong Li, Zhengxiang Jiang, Tianxiao Liu, Xinyu Liu, Hui Qiao, Xiaojun Yao +Briefings in Bioinformatics, Volume 25, Issue 3, May 2024, bbae153, https://doi.org/10.1093/bib/bbae153 +""" + +import os +from typing import Any + +import numpy as np +import pandas as pd +import torch +import torch.optim as optim +from torch import nn +from torch.utils.data import DataLoader + +from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset +from drevalpy.models.drp_model import DRPModel +from drevalpy.models.utils import load_and_reduce_gene_features + +from .data_utils import CollateFn, DIPKDataset, get_data, load_bionic_features +from .gene_expression_encoder import GeneExpressionEncoder, encode_gene_expression, train_gene_expession_autoencoder +from .model_utils import Predictor + + +class DIPKModel(DRPModel): + """DIPK model. Adapted from https://github.com/user15632/DIPK.""" + + cell_line_views = ["gene_expression", "bionic_features"] + drug_views = ["molgnet_features"] + early_stopping = True + + def __init__(self) -> None: + """Initialize the DIPK model.""" + super().__init__() + self.DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # all of this gets initialized in build_model + self.model: Predictor | None = None + self.epochs: int = 0 + self.batch_size: int = 0 + self.lr: float = 0.0 + self.gene_expression_encoder: GeneExpressionEncoder | None = None + self.epochs_autoencoder: int = 100 + + @classmethod + def get_model_name(cls) -> str: + """ + Get the model name. + + :returns: DIPK + """ + return "DIPK" + + def build_model(self, hyperparameters: dict[str, Any]) -> None: + """ + Builds the DIPK model with the specified hyperparameters. + + :param hyperparameters: embedding_dim, heads, fc_layer_num, fc_layer_dim, dropout_rate, EPOCHS, batch_size, lr + + Details of hyperparameters: + + - embedding_dim: int, embedding dimension used for the graph encoder which is not used in the final model + - heads: int, number of heads for the multi-head attention layer, defaults to 1 + - fc_layer_num: int, number of fully connected layers for the dense layers + - fc_layer_dim: list[int], number of neurons for each fully connected layer + - dropout_rate: float, dropout rate for all fully connected layers + - EPOCHS: int, number of epochs to train the model + - batch_size: int, batch size for training + - lr: float, learning rate for training + """ + self.model = Predictor( + hyperparameters["heads"], + hyperparameters["fc_layer_num"], + hyperparameters["fc_layer_dim"], + hyperparameters["dropout_rate"], + ).to(self.DEVICE) + self.epochs = hyperparameters["epochs"] + self.batch_size = hyperparameters["batch_size"] + self.lr = hyperparameters["lr"] + self.epochs_autoencoder = hyperparameters["epochs_autoencoder"] + self.patience = hyperparameters["patience"] + + def train( + self, + output: DrugResponseDataset, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, + ) -> None: + """ + Trains the model. 
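+
+ Training happens in two stages: first, a gene expression autoencoder is fitted (with its own early stopping on
+ the validation cell lines) and used to encode the gene expression view; then, the predictor is trained with MSE
+ loss and stops early once the validation loss has not improved for ``patience`` epochs.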
+ + :param output: training data associated with the response output + :param cell_line_input: input data associated with the cell line + :param drug_input: input data associated with the drug + :param output_earlystopping: early stopping data associated with the response output, not used + :raises ValueError: if drug_input is None or if the model is not initialized + """ + if drug_input is None: + raise ValueError("DIPK model requires drug features.") + if not isinstance(self.model, Predictor): + raise ValueError("DIPK model not initialized.") + if output_earlystopping is None: + raise ValueError("DIPK model requires early stopping data.") + + loss_func = nn.MSELoss() + params = [{"params": self.model.parameters()}] + optimizer = optim.Adam(params, lr=self.lr) + + self.gene_expression_encoder = train_gene_expession_autoencoder( + cell_line_input.get_feature_matrix(view="gene_expression", identifiers=output.cell_line_ids), + cell_line_input.get_feature_matrix(view="gene_expression", identifiers=output_earlystopping.cell_line_ids), + epochs_autoencoder=self.epochs_autoencoder, + ) + + cell_line_input.apply( + lambda x: encode_gene_expression(x, self.gene_expression_encoder), # type: ignore[arg-type] + view="gene_expression", + ) # type: ignore[arg-type] + + # Load data + collate = CollateFn(train=True) + train_samples = get_data( + cell_ids=output.cell_line_ids, + drug_ids=output.drug_ids, + cell_line_features=cell_line_input, + drug_features=drug_input, + ic50=output.response, + ) + early_stopping_samples = get_data( + cell_ids=output_earlystopping.cell_line_ids, + drug_ids=output_earlystopping.drug_ids, + cell_line_features=cell_line_input, + drug_features=drug_input, + ic50=output_earlystopping.response, + ) + + train_loader: DataLoader = DataLoader( + DIPKDataset(train_samples), batch_size=self.batch_size, shuffle=True, collate_fn=collate + ) + early_stopping_loader: DataLoader = DataLoader( + DIPKDataset(early_stopping_samples), batch_size=self.batch_size, shuffle=True, collate_fn=collate + ) + + # Early stopping parameters + best_val_loss = float("inf") + epochs_without_improvement = 0 + + # Train model + print("Training DIPK model") + for epoch in range(self.epochs): + self.model.train() + epoch_loss = 0.0 + batch_count = 0 + + # Training phase + for batch in train_loader: + drug_features = batch["molgnet_features"].to(self.DEVICE) + gene_features = batch["gene_features"].to(self.DEVICE) + bionic_features = batch["bionic_features"].to(self.DEVICE) + molgnet_mask = batch["molgnet_mask"].to(self.DEVICE) + ic50_values = batch["ic50_values"].to(self.DEVICE) + + # Forward pass + prediction = self.model( + molgnet_drug_features=drug_features, + gene_expression=gene_features, + bionic=bionic_features, + molgnet_mask=molgnet_mask, + ) + + # Compute the loss + loss = loss_func(torch.squeeze(prediction), torch.squeeze(ic50_values)) + + # Backpropagation + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # Update loss and batch count + epoch_loss += loss.detach().item() + batch_count += 1 + + epoch_loss /= batch_count + print(f"Epoch [{epoch + 1}] Training Loss: {epoch_loss:.4f}") + + # Validation phase for early stopping + self.model.eval() + val_loss = 0.0 + val_batch_count = 0 + with torch.no_grad(): + for batch in early_stopping_loader: + drug_features = batch["molgnet_features"].to(self.DEVICE) + gene_features = batch["gene_features"].to(self.DEVICE) + bionic_features = batch["bionic_features"].to(self.DEVICE) + molgnet_mask = batch["molgnet_mask"].to(self.DEVICE) + 
ic50_values = batch["ic50_values"].to(self.DEVICE) + + # Forward pass + prediction = self.model( + molgnet_drug_features=drug_features, + gene_expression=gene_features, + bionic=bionic_features, + molgnet_mask=molgnet_mask, + ) + + # Compute the loss + loss = loss_func(torch.squeeze(prediction), torch.squeeze(ic50_values)) + + # Update validation loss + val_loss += loss.item() + val_batch_count += 1 + + val_loss /= val_batch_count + print(f"Epoch [{epoch + 1}] Validation Loss: {val_loss:.4f}") + + # Early stopping check + if val_loss < best_val_loss: + best_val_loss = val_loss + epochs_without_improvement = 0 + else: + epochs_without_improvement += 1 + if epochs_without_improvement >= self.patience: + print(f"Early stopping triggered at epoch {epoch + 1}") + break + + def predict( + self, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, + ) -> np.ndarray: + """ + Predicts the response values for the given cell lines and drugs. + + :param cell_line_ids: list of cell line IDs + :param drug_ids: list of drug IDs + :param cell_line_input: input data associated with the cell line + :param drug_input: input data associated with the drug + :return: predicted response values + :raises ValueError: if drug_input is None or if the model is not initialized + """ + if drug_input is None: + raise ValueError("DIPK model requires drug features.") + if not isinstance(self.model, Predictor): + raise ValueError("DIPK model not initialized.") + + # Load data + collate = CollateFn(train=False) + test_samples = get_data( + cell_ids=cell_line_ids, + drug_ids=drug_ids, + cell_line_features=cell_line_input, + drug_features=drug_input, + ) + test_loader: DataLoader = DataLoader( + DIPKDataset(test_samples), batch_size=self.batch_size, shuffle=False, collate_fn=collate + ) + + # Run prediction + self.model.eval() + predictions = [] + with torch.no_grad(): + for batch in test_loader: + drug_features = batch["molgnet_features"].to(self.DEVICE) + gene_features = batch["gene_features"].to(self.DEVICE) + bionic_features = batch["bionic_features"].to(self.DEVICE) + molgnet_mask = batch["molgnet_mask"].to(self.DEVICE) + + prediction = self.model( + molgnet_drug_features=drug_features, + gene_expression=gene_features, + bionic=bionic_features, + molgnet_mask=molgnet_mask, + ) + predictions += torch.squeeze(prediction).cpu().tolist() + + return np.array(predictions) + + def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """ + Load cell line features. + + :param data_path: path to the data + :param dataset_name: path to the dataset + :returns: cell line features + """ + gene_expression = load_and_reduce_gene_features( + feature_type="gene_expression", + gene_list=None, + data_path=data_path, + dataset_name=dataset_name, + ) + bionic_features = load_bionic_features( + data_path=data_path, + dataset_name=dataset_name, + ) + bionic_features.add_features(gene_expression) + + return bionic_features + + def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """ + Load drug features. 
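+
+ Concretely, the per-drug MolGNet feature matrices (one row per atom) are read from
+ ``<data_path>/<dataset_name>/DIPK_features/Drugs/MolGNet_<drug_name>.csv``.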
+ + :param data_path: path to the data + :param dataset_name: path to the dataset + :returns: drug features + """ + + def load_feature(file_path, sep="\t"): + return np.array(pd.read_csv(file_path, index_col=0, sep=sep)) + + drug_path = os.path.join(data_path, dataset_name, "DIPK_features", "Drugs") + files_in_drug_path = os.listdir(drug_path) + drug_list = [ + file.split("_")[1].split(".csv")[0] + for file in files_in_drug_path + if file.endswith(".csv") and file.startswith("MolGNet") + ] + + return FeatureDataset( + features={ + drug: { + "molgnet_features": load_feature(os.path.join(drug_path, f"MolGNet_{drug}.csv")), + } + for drug in drug_list + } + ) diff --git a/drevalpy/models/DIPK/gene_expression_encoder.py b/drevalpy/models/DIPK/gene_expression_encoder.py new file mode 100644 index 0000000..21e8c2a --- /dev/null +++ b/drevalpy/models/DIPK/gene_expression_encoder.py @@ -0,0 +1,258 @@ +"""Gene expression Autoencoder for DIPK model.""" + +from abc import ABC +from copy import deepcopy + +import numpy as np +import torch +import torch.nn +import torch.nn as nn +import torch.optim as optim +from torch.nn import functional +from torch.utils.data import DataLoader, Dataset + +ldim = 512 +hdim = [2048, 1024] + + +class GeneExpressionEncoder(nn.Module): + """Gene expression encoder. + + Code adapted from the + DIPK model https://github.com/user15632/DIPK. + """ + + def __init__(self, input_dim, latent_dim=ldim, h_dims=None, drop_out_rate=0.3): + """Initialize the gene expression encoder. + + :param input_dim: input dimension + :param latent_dim: latent dimension + :param h_dims: hidden dimensions + :param drop_out_rate: dropout rate + """ + super().__init__() + if h_dims is None: + h_dims = hdim + hidden_dims = deepcopy(h_dims) + hidden_dims.insert(0, input_dim) + modules = [] + for i in range(1, len(hidden_dims)): + modules.append( + nn.Sequential( + nn.Linear(hidden_dims[i - 1], hidden_dims[i]), + nn.BatchNorm1d(hidden_dims[i]), + nn.ReLU(), + nn.Dropout(drop_out_rate), + ) + ) + self.encoder = nn.Sequential(*modules) + self.bottleneck = nn.Linear(hidden_dims[-1], latent_dim) + + def forward(self, input): + """Forward pass of the gene expression encoder. + + :param input: input data + :return: encoded data + """ + result = self.encoder(input) + embedding = functional.relu(self.bottleneck(result)) + return embedding + + +class GeneExpressionDecoder(nn.Module): + """Gene expression decoder.""" + + def __init__(self, input_dim, latent_dim=ldim, h_dims=None, drop_out_rate=0.3): + """Initialize the gene expression decoder. + + :param input_dim: input dimension + :param latent_dim: latent dimension + :param h_dims: hidden dimensions + :param drop_out_rate: dropout rate + """ + super().__init__() + if h_dims is None: + h_dims = hdim + hidden_dims = deepcopy(h_dims) + hidden_dims.insert(0, input_dim) + self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1]) + hidden_dims.reverse() + modules = [] + for i in range(len(hidden_dims) - 2): + modules.append( + nn.Sequential( + nn.Linear(hidden_dims[i], hidden_dims[i + 1]), + nn.BatchNorm1d(hidden_dims[i + 1]), + nn.ReLU(), + nn.Dropout(drop_out_rate), + ) + ) + self.decoder = nn.Sequential(*modules) + self.decoder_output = nn.Linear(hidden_dims[-2], hidden_dims[-1]) + + def forward(self, embedding): + """ + Forward pass of the gene expression decoder. 
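+
+ The decoder mirrors the encoder: the latent embedding of size ``latent_dim`` is expanded back through the hidden
+ layers to a reconstruction of the original ``input_dim``-dimensional input.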
+ + :param embedding: input data + :return: decoded data + """ + result = self.decoder_input(embedding) + result = self.decoder(result) + output = self.decoder_output(result) + return output + + +class CollateFn: + """Collate function for the DataLoader, either for training or testing.""" + + def __call__(self, batch): + """Collate the batch. + + :param batch: batch of PyG Data objects + :returns: PyG Batch, gene features, and bionic features + """ + batch_data = torch.stack(batch) + return batch_data + + +class DataSet(Dataset, ABC): + """Dataset class for gene expression data.""" + + def __init__(self, data): + """Initialize the dataset. + + :param data: data + """ + self._data = data + + def __getitem__(self, idx): + """Return the data at the given index. + + :param idx: index + :return: data + """ + data = self._data[idx] + return data + + def __len__(self): + """Return the length of the dataset. + + :return: length of the dataset + """ + return len(self._data) + + +def train_gene_expession_autoencoder( + gene_expression_input: np.ndarray, gene_expression_input_early_stopping: np.ndarray, epochs_autoencoder: int = 100 +) -> GeneExpressionEncoder: + """Train the autoencoder model for gene expression data with early stopping. + + :param gene_expression_input: gene expression data + :param gene_expression_input_early_stopping: validation data for early stopping + :param epochs_autoencoder: number of epochs for training the autoencoder + :return: trained encoder model + """ + lr = 1e-4 + batch_size = 1024 + noising = True + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # create model + encoder = GeneExpressionEncoder(len(gene_expression_input[0])).to(device) + decoder = GeneExpressionDecoder(len(gene_expression_input[0])).to(device) + loss_func = nn.MSELoss() + params = [{"params": encoder.parameters()}, {"params": decoder.parameters()}] + optimizer = optim.Adam(params, lr=lr) + + # load data + my_collate = CollateFn() + gene_expression_tensor = torch.tensor(gene_expression_input, dtype=torch.float32).to(device) + train_loader = DataLoader( + DataSet(gene_expression_tensor), batch_size=batch_size, shuffle=True, collate_fn=my_collate + ) + + # prepare early stopping validation data + gene_expression_val_tensor = torch.tensor(gene_expression_input_early_stopping, dtype=torch.float32).to(device) + + # early stopping parameters + patience = 5 + best_val_loss = float("inf") + epochs_without_improvement = 0 + + print("Training DIPK autoencoder for gene expression data") + for epoch_index in range(epochs_autoencoder): + # training + encoder.train() + decoder.train() + epoch_loss = 0.0 + batch_count = 0 + for gene_expression_batch in train_loader: + gene_expression_batch = gene_expression_batch.to(device) + if noising: + z = gene_expression_batch.clone() + y = np.random.binomial(1, 0.2, (z.shape[0], z.shape[1])) + z[np.array(y, dtype=bool)] = 0 + gene_expression_batch.requires_grad_(True) + output = decoder(encoder(z)) + else: + output = decoder(encoder(gene_expression_batch)) + loss = loss_func(output, gene_expression_batch) + optimizer.zero_grad() + loss.backward() + optimizer.step() + epoch_loss += loss.detach().item() + batch_count += 1 + epoch_loss /= batch_count + + # validation + encoder.eval() + decoder.eval() + with torch.no_grad(): + val_output = decoder(encoder(gene_expression_val_tensor)) + val_loss = loss_func(val_output, gene_expression_val_tensor).item() + + print(f"DIPK Autoenc. 
Epoch: {epoch_index}, Train Loss: {epoch_loss}, Val Loss: {val_loss}") + + # early stopping check + if val_loss < best_val_loss: + best_val_loss = val_loss + epochs_without_improvement = 0 + else: + epochs_without_improvement += 1 + if epochs_without_improvement >= patience: + print(f"DIPK Autoenc. Early stopping triggered at epoch {epoch_index}") + break + + encoder.eval() + return encoder + + +def encode_gene_expression(gene_expression_input: np.ndarray, encoder: GeneExpressionEncoder) -> np.ndarray: + """Encode gene expression data. + + :param gene_expression_input: gene expression data + :param encoder: trained encoder model + :return: encoded gene expression data + """ + encoder.eval() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + encoder.to(device) + + # Check the original input shape, because we have to unsqueeze the input if it is a single vector + original_shape = gene_expression_input.shape + + gene_expression_tensor = torch.tensor(gene_expression_input, dtype=torch.float32).to(device) + + # Add batch dimension if input is a single vector + if gene_expression_tensor.ndim == 1: + gene_expression_tensor = gene_expression_tensor.unsqueeze(0) + + with torch.no_grad(): + encoded_data = encoder(gene_expression_tensor).cpu().numpy() + + # Match the output shape to the input shape + if len(original_shape) == 1: + encoded_data = encoded_data.squeeze(0) + + return encoded_data diff --git a/drevalpy/models/DIPK/hyperparameters.yaml b/drevalpy/models/DIPK/hyperparameters.yaml new file mode 100644 index 0000000..2107f5e --- /dev/null +++ b/drevalpy/models/DIPK/hyperparameters.yaml @@ -0,0 +1,25 @@ +--- +DIPK: + batch_size: + - 64 + lr: + - 0.00001 + heads: + - 2 + fc_layer_num: + - 3 + fc_layer_dim: + - - 256 + - 128 + - 64 + - 32 + - 16 + - 1 + dropout_rate: + - 0.3 + epochs: + - 100 + epochs_autoencoder: + - 100 + patience: + - 10 diff --git a/drevalpy/models/DIPK/model_utils.py b/drevalpy/models/DIPK/model_utils.py new file mode 100644 index 0000000..329c550 --- /dev/null +++ b/drevalpy/models/DIPK/model_utils.py @@ -0,0 +1,144 @@ +"""Includes custom torch.nn.Modules for the DIPK model: AttentionLayer, DenseLayer, Predictor.""" + +import torch +import torch.nn as nn + +from .attention_utils import MultiHeadAttentionLayer + +features_dim_gene = 512 +features_dim_bionic = 512 +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class AttentionLayer(nn.Module): + """Custom attention layer for the DIPK model.""" + + def __init__(self, heads: int = 1): + """ + Initialize the attention layer with a multi-head attention layer with a specified number of heads. + + :param heads: number of heads for the multi-head attention layer + """ + super().__init__() + self.fc_layer_0 = nn.Linear(features_dim_gene, 768) + self.fc_layer_1 = nn.Linear(features_dim_bionic, 768) + self.attention_0 = MultiHeadAttentionLayer(hid_dim=768, n_heads=heads, dropout=0.3, device=DEVICE) + self.attention_1 = MultiHeadAttentionLayer(hid_dim=768, n_heads=heads, dropout=0.3, device=DEVICE) + + def forward( + self, molgnet_features: torch.Tensor, mask: torch.Tensor, gene_expression: torch.Tensor, bionic: torch.Tensor + ) -> torch.Tensor: + """ + Forward pass of the attention layer. 
+ + :param molgnet_features: MolGNet features + :param mask: mask for the MolGNet features, as molecules have varying sizes (valid atom features are True) + :param gene_expression: gene expression features of the graph data + :param bionic: bionic network features of the graph data + :returns: tensor of MolGNet features after attention layer + """ + gene_expression = nn.functional.relu(self.fc_layer_0(gene_expression)) # Shape: [batch_size, feature_dim_gene] + bionic = nn.functional.relu(self.fc_layer_1(bionic)) # Shape: [batch_size, feature_dim_bionic] + + # Preparing query, key, value for attention layers + query_0 = torch.unsqueeze(gene_expression, 1) # Shape: [batch_size, 1, 768] for gene + query_1 = torch.unsqueeze(bionic, 1) # Shape: [batch_size, 1, 768] for bionic + key = molgnet_features # Shape: [batch_size, seq_len, 768] (features from MolGNet) + value = molgnet_features # Shape: [batch_size, seq_len, 768] (same as key) + + mask = torch.unsqueeze(mask, 1).unsqueeze(2) + + # Apply the first attention layer + x_att = self.attention_0(query_0, key, value, mask) # Output: [batch_size, seq_len, hid_dim] + x = torch.squeeze(x_att[0]) # Squeeze to remove the extra dimension (1) + + # Apply the second attention layer + x_att = self.attention_1(query_1, key, value, mask) # Output: [batch_size, seq_len, hid_dim] + x += torch.squeeze(x_att[0]) # Add the result of the second attention to the first + + return x + + +class DenseLayers(nn.Module): + """Custom dense layers for the DIPK model.""" + + def __init__(self, fc_layer_num: int, fc_layer_dim: list[int], dropout_rate: float): + """ + Initialize the dense layers of the DIPK model which follow the attention layer. + + :param fc_layer_num: number of fully connected layers + :param fc_layer_dim: list of dimensions for each fully connected layer + :param dropout_rate: dropout rate for all fully connected layers + """ + super().__init__() + self.fc_layer_num = fc_layer_num + self.fc_layer_0 = nn.Linear(features_dim_gene, 512) + self.fc_layer_1 = nn.Linear(features_dim_bionic, 512) + self.fc_input = nn.Linear(768 + 512, 768 + 512) + self.fc_layers = torch.nn.Sequential( + nn.Linear(768 + 512, 512), + nn.Linear(512, fc_layer_dim[0]), + nn.Linear(fc_layer_dim[0], fc_layer_dim[1]), + nn.Linear(fc_layer_dim[1], fc_layer_dim[2]), + nn.Linear(fc_layer_dim[2], fc_layer_dim[3]), + nn.Linear(fc_layer_dim[3], fc_layer_dim[4]), + nn.Linear(fc_layer_dim[4], fc_layer_dim[5]), + ) + self.dropout_layers = torch.nn.ModuleList([nn.Dropout(p=dropout_rate) for _ in range(fc_layer_num)]) + self.fc_output = nn.Linear(fc_layer_dim[fc_layer_num - 2], 1) + + def forward(self, x: torch.Tensor, gene: torch.Tensor, bionic: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the dense layers. 
+ + :param x: output tensor from the attention layer + :param gene: gene expression features (GEF) of the graph data + :param bionic: biological network features (BNF) of the graph data + :returns: output tensor after the dense layers + """ + gene = torch.nn.functional.relu(self.fc_layer_0(gene)) + bionic = torch.nn.functional.relu(self.fc_layer_1(bionic)) + f = torch.cat((x, gene + bionic), 1) + f = torch.nn.functional.relu(self.fc_input(f)) + for layer_index in range(self.fc_layer_num): + f = torch.nn.functional.relu(self.fc_layers[layer_index](f)) + f = self.dropout_layers[layer_index](f) + f = self.fc_output(f) + return f + + +class Predictor(nn.Module): + """Whole DIPK model.""" + + def __init__(self, heads: int, fc_layer_num: int, fc_layer_dim: list[int], dropout_rate: float): + """ + Initialize the DIPK model with the specified hyperparameters. + + :param heads: number of heads for the multi-head attention layer + :param fc_layer_num: number of fully connected layers for the dense layers + :param fc_layer_dim: number of neurons for each fully connected layer + :param dropout_rate: dropout rate for all fully connected layers + """ + super().__init__() + self.attention_layer = AttentionLayer(heads=heads) + self.dense_layers = DenseLayers(fc_layer_num=fc_layer_num, fc_layer_dim=fc_layer_dim, dropout_rate=dropout_rate) + + def forward( + self, + molgnet_drug_features: torch.Tensor, + gene_expression: torch.Tensor, + bionic: torch.Tensor, + molgnet_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass of the DIPK model. + + :param molgnet_drug_features: tensor of MolGNet features from graph data + :param gene_expression: gene expression features (GEF) of the graph data + :param bionic: biological network features (BNF) of the graph data + :param molgnet_mask: mask for the MolGNet features, as molecules have varying sizes + :returns: output tensor of the DIPK model + """ + molgnet_drug_features = self.attention_layer(molgnet_drug_features, molgnet_mask, gene_expression, bionic) + f = self.dense_layers(molgnet_drug_features, gene_expression, bionic) + return f diff --git a/drevalpy/models/DrugRegNet/DrugRegNetModel.py b/drevalpy/models/DrugRegNet/DrugRegNetModel.py deleted file mode 100644 index 3938a58..0000000 --- a/drevalpy/models/DrugRegNet/DrugRegNetModel.py +++ /dev/null @@ -1,141 +0,0 @@ -import numpy as np -import pandas as pd -from scipy import stats -from sklearn.linear_model import Lasso - - -class DrugRegNetModel: - """Class for DrugRegNetModel.""" - - def __init__(self, path_drug_response, path_dysregnet_scores, features): - """ - Initialization method for DrugRegNet model. - - :param path_drug_response: - :param path_dysregnet_scores: - :param features: - """ - self.drug_response = pd.read_csv(path_drug_response, index_col=0).T - self.dysregnet_scores = pd.read_feather(path_dysregnet_scores) - self.dysregnet_scores = self.dysregnet_scores.set_index("patient id") - self.features = features - - def create_train_data(self): - """Creates a training data set.""" - all_data = dict() - for drug in self.drug_response.columns: - print("Creating train data for drug:", drug) - drp = self.drug_response[drug] - drp = drp[~drp.isna()] - drp = drp[~drp.index.duplicated(keep="first")] - drp = drp[drp.index.isin(self.dysregnet_scores.index)] - x = self.dysregnet_scores.loc[drp.index] - x = self.feature_selection(x) - all_data[drug] = DrugRegNetDataset(drug, x, drp) - self.all_data = all_data - - def feature_selection(self, x, n_features=300): - """ - Selects features. 
- - :param x: - :param n_features: - """ - if self.features == "topN": - # get the n_features columns with the highest variance - x = x.loc[:, x.var().nlargest(n_features).index] - return x - - def train_model(self): - """Train the model.""" - for drug in self.all_data.keys(): - print("Training model for drug:", drug) - x = self.all_data[drug].x - y = self.all_data[drug].y - # TODO: cross validation? - model = Lasso(alpha=0.1) - model.fit(x, y) - # get p-values for coefficients - p_values = self.calculate_pvalues(model, x, y) - # do Bonferroni correction by getting minimum of p-value * number of features and 1 - p_adj = np.minimum(p_values * x.shape[1], 1) - result_df = pd.DataFrame( - { - "edge": x.columns, - "coef": model.coef_, - "p_val": p_values, - "p_adj": p_adj, - } - ) - model.results = result_df - setattr(self, drug, model) - - @staticmethod - def calculate_pvalues(model, x, y): - """ - Calculate p-values. - - :param model: - :param x: - :param y: - """ - params = np.append(model.intercept_, model.coef_) - predictions = model.predict(x) - new_x = pd.DataFrame({"Constant": np.ones(len(x))}, index=x.index).join(x) - mse = (sum((y - predictions) ** 2)) / (len(new_x) - len(new_x.columns)) - var_b = mse * (np.linalg.inv(np.dot(new_x.T, new_x)).diagonal()) - sd_b = np.sqrt(var_b) - ts_b = params / sd_b - p_values = [2 * (1 - stats.t.cdf(np.abs(i), (len(new_x) - 1))) for i in ts_b] - p_values = np.round(p_values, 3) - p_values = p_values[1:] - return p_values - - def export_results(self, path): - """ - Export the results to a provdied path. - - :param path: The path to export the results to - """ - for drug in self.all_data.keys(): - result_df = getattr(self, drug).results - # order by p-value - result_df = result_df.sort_values("p_val") - result_df.to_csv(path + "/results_" + drug + ".csv") - drug_specific_network = result_df[result_df["p_val"] < 0.5] - # only get edge column - if not drug_specific_network.empty: - drug_specific_network = drug_specific_network["edge"] - # split column such that (g1, g2) becomes g1 and g2 - drug_specific_network = drug_specific_network.str.replace("(", "").str.replace(")", "") - drug_specific_network = drug_specific_network.str.replace("'", "") - drug_specific_network = drug_specific_network.str.split(", ", expand=True) - drug_specific_network.columns = ["intA", "intB"] - drug_specific_network.to_csv(path + "/network_" + drug + ".csv", index=False) - - -class DrugRegNetDataset: - """Class for DrugRegNetDataset.""" - - def __init__(self, drug, x, y): - """ - Initialization method for the DrugRegNet dataset. 
- - :param drug: - :param x: - :param y: - """ - self.drug = drug - self.x = x - self.y = y - - -if __name__ == "__main__": - model = DrugRegNetModel( - "../../data/response_output/CCLE/curve_curator_pEC50_CCLE.csv", - "../../data/cell_line_input/DysRegNet/ccle_fake.fea", - features="topN", - ) - model.create_train_data() - model.train_model() - model.export_results("results") diff --git a/drevalpy/models/DrugRegNet/__init__.py b/drevalpy/models/DrugRegNet/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/drevalpy/models/MOLIR/__init__.py b/drevalpy/models/MOLIR/__init__.py index e69de29..5f90f05 100644 --- a/drevalpy/models/MOLIR/__init__.py +++ b/drevalpy/models/MOLIR/__init__.py @@ -0,0 +1 @@ +"""Module for the regression adaption of the MOLI model: MOLIR.""" diff --git a/drevalpy/models/MOLIR/molir.py b/drevalpy/models/MOLIR/molir.py index 82142cb..7a3627e 100644 --- a/drevalpy/models/MOLIR/molir.py +++ b/drevalpy/models/MOLIR/molir.py @@ -1,46 +1,64 @@ """ -Contains the MOLIR model. +Contains the MOLIR model, a regression adaptation of the MOLI model. + Original authors: Sharifi-Noghabi et al. (2019, 10.1093/bioinformatics/btz318) Code adapted from their Github: https://github.com/hosseinshn/MOLI and Hauptmann et al. (2023, 10.1186/s12859-023-05166-7) https://github.com/kramerlab/Multi-Omics_analysis """ -from typing import Any, Optional +from typing import Any import numpy as np from sklearn.feature_selection import VarianceThreshold from sklearn.preprocessing import StandardScaler from ...datasets.dataset import DrugResponseDataset, FeatureDataset -from ..drp_model import SingleDrugModel -from ..utils import load_and_reduce_gene_features +from ..drp_model import DRPModel +from ..utils import get_multiomics_feature_dataset from .utils import MOLIModel, get_dimensions_of_omics_data -class MOLIR(SingleDrugModel): +class MOLIR(DRPModel): """ - Regression extension of - MOLI: multi-omics late integration deep neural network. - Takes somatic mutation, copy number variation and gene expression data as input. - MOLI uses type-specific encoding subnetworks to learn features for each omics type, - concatenates them into one representation and optimizes this representation via a combined cost - function consisting of a triplet loss and a binary cross-entropy loss. - We use a regression adaption with MSE loss and an mechanism to find positive and negative samples. + Regression extension of MOLI: multi-omics late integration deep neural network. + + Takes somatic mutation, copy number variation and gene expression data as input. MOLI uses type-specific encoding + subnetworks to learn features for each omics type, concatenates them into one representation and optimizes this + representation via a combined cost function consisting of a triplet loss and a binary cross-entropy loss. + We use a regression adaption with MSE loss and a mechanism to find positive and negative samples. """ + is_single_drug_model = True cell_line_views = ["gene_expression", "mutations", "copy_number_variation_gistic"] drug_views = [] early_stopping = True - model_name = "MOLIR" def __init__(self) -> None: + """ + Initializes the MOLIR model. + + The hyperparameters are set in build_model, the model is set in train when we know the dimensionality of the + gene expression, mutation and copy number variation data. 
+ """ super().__init__() - self.model = None - self.hyperparameters = None + self.model: MOLIModel | None = None + self.hyperparameters: dict[str, Any] = dict() + + @classmethod + def get_model_name(cls) -> str: + """ + Returns the model name. + + :returns: MOLIR + """ + return "MOLIR" def build_model(self, hyperparameters: dict[str, Any]) -> None: """ Builds the model from hyperparameters. + + :param hyperparameters: Custom hyperparameters for the model, includes mini_batch, layer dimensions (h_dim1, + h_dim2, h_dim3), learning_rate, dropout_rate, weight_decay, gamma, epochs, and margin. """ self.hyperparameters = hyperparameters @@ -48,83 +66,115 @@ def train( self, output: DrugResponseDataset, cell_line_input: FeatureDataset, - drug_input: Optional[FeatureDataset] = None, - output_earlystopping: Optional[DrugResponseDataset] = None, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, ) -> None: - selector_gex = VarianceThreshold(0.05) - cell_line_input.fit_transform_features( - train_ids=np.unique(output.cell_line_ids), - transformer=selector_gex, - view="gene_expression", - ) - scaler_gex = StandardScaler() - cell_line_input.fit_transform_features( - train_ids=np.unique(output.cell_line_ids), - transformer=scaler_gex, - view="gene_expression", - ) - if self.early_stopping and len(output_earlystopping) < 2: - output_earlystopping = None - dim_gex, dim_mut, dim_cnv = get_dimensions_of_omics_data(cell_line_input) - self.model = MOLIModel( - hpams=self.hyperparameters, - input_dim_expr=dim_gex, - input_dim_mut=dim_mut, - input_dim_cnv=dim_cnv, - ) - self.model.fit( - output_train=output, - cell_line_input=cell_line_input, - output_earlystopping=output_earlystopping, - ) + """ + Initializes and trains the model. + + First, the gene expression data is reduced using a variance threshold (0.05) and standardized. Then, + the model is initialized with the hyperparameters and the dimensions of the gene expression, mutation and + copy number variation data. If there is no training data, the model is set to None (and predictions will be + skipped as well). If there is not enough training data, the predictions will be made on the randomly + initialized model. 
+ + :param output: drug response data + :param cell_line_input: cell line omics features, i.e., gene expression, mutations and copy number variation + :param drug_input: drug features, not needed + :param output_earlystopping: early stopping data, not used when there is not enough data + """ + if len(output) > 0: + selector_gex = VarianceThreshold(0.05) + cell_line_input.fit_transform_features( + train_ids=np.unique(output.cell_line_ids), + transformer=selector_gex, + view="gene_expression", + ) + scaler_gex = StandardScaler() + cell_line_input.fit_transform_features( + train_ids=np.unique(output.cell_line_ids), + transformer=scaler_gex, + view="gene_expression", + ) + if output_earlystopping is not None and self.early_stopping and len(output_earlystopping) < 2: + output_earlystopping = None + dim_gex, dim_mut, dim_cnv = get_dimensions_of_omics_data(cell_line_input) + self.model = MOLIModel( + hpams=self.hyperparameters, + input_dim_expr=dim_gex, + input_dim_mut=dim_mut, + input_dim_cnv=dim_cnv, + ) + if len(output) >= self.hyperparameters["mini_batch"]: + self.model.fit( + output_train=output, + cell_line_input=cell_line_input, + output_earlystopping=output_earlystopping, + ) + else: + print(f"Not enough training data provided ({len(output)}), will predict on randomly initialized model.") + else: + print("No training data provided, skipping model") + self.model = None def predict( self, - drug_ids: np.ndarray, cell_line_ids: np.ndarray, - drug_input: FeatureDataset = None, - cell_line_input: FeatureDataset = None, + drug_ids: np.ndarray, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, ) -> np.ndarray: + """ + Predicts the drug response. + + If there was no training data, only nans will be returned. + + :param cell_line_ids: Cell lines to predict + :param drug_ids: Drugs to predict + :param cell_line_input: cell line omics features + :param drug_input: drug features, not needed + :returns: Predicted drug response + """ input_data = self.get_feature_matrices( cell_line_ids=cell_line_ids, drug_ids=drug_ids, cell_line_input=cell_line_input, drug_input=drug_input, ) - gene_expression = input_data["gene_expression"] - mutations = input_data["mutations"] - cnvs = input_data["copy_number_variation_gistic"] + (gene_expression, mutations, cnvs) = ( + input_data["gene_expression"], + input_data["mutations"], + input_data["copy_number_variation_gistic"], + ) + if self.model is None: + print("No model trained, will predict NA.") + return np.array([np.nan] * len(cell_line_ids)) return self.model.predict(gene_expression, mutations, cnvs) def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: - all_data = load_and_reduce_gene_features( - feature_type="gene_expression", - gene_list=None, - data_path=data_path, - dataset_name=dataset_name, - ) - # log transformation - all_data._apply(function=np.log, view="gene_expression") - # in Toy_Data, everything is already in the dataset - # TODO: implement this in models/utils.py - mut_data = load_and_reduce_gene_features( - feature_type="mutations", - gene_list=None, + """ + Loads the cell line features: gene expression, mutations and copy number variation. 
+ + :param data_path: path to the data + :param dataset_name: name of the dataset + :returns: FeatureDataset with gene expression, mutations and copy number variation + """ + feature_dataset = get_multiomics_feature_dataset( data_path=data_path, dataset_name=dataset_name, - ) - cnv_data = load_and_reduce_gene_features( - feature_type="copy_number_variation_gistic", gene_list=None, - data_path=data_path, - dataset_name=dataset_name, + omics=self.cell_line_views, ) - for fd in [mut_data, cnv_data]: - all_data._add_features(fd) - return all_data + # log transformation + feature_dataset.apply(function=np.log, view="gene_expression") + return feature_dataset - def load(self, path): - raise NotImplementedError + def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset | None: + """ + Returns None, as drug features are not needed for MOLIR. - def save(self, path): - raise NotImplementedError + :param data_path: path to the data + :param dataset_name: name of the dataset + :returns: None + """ + return None diff --git a/drevalpy/models/MOLIR/utils.py b/drevalpy/models/MOLIR/utils.py index 53c24b4..e1d5d37 100644 --- a/drevalpy/models/MOLIR/utils.py +++ b/drevalpy/models/MOLIR/utils.py @@ -1,13 +1,15 @@ """ -Code for the MOLI model. -Original authors: Sharifi-Noghabi et al. (2019, 10.1093/bioinformatics/btz318) +Utility functions for the MOLIR model. + +Original authors of MOLI: Sharifi-Noghabi et al. (2019, 10.1093/bioinformatics/btz318) Code adapted from: Hauptmann et al. (2023, 10.1186/s12859-023-05166-7), https://github.com/kramerlab/Multi-Omics_analysis """ import os import random -from typing import Optional, Union +import secrets +from typing import Optional import numpy as np import pytorch_lightning as pl @@ -20,31 +22,45 @@ class RegressionDataset(Dataset): - """ - Dataset for regression tasks for the data loader. - """ + """Dataset for regression tasks for the data loader.""" def __init__( self, output: DrugResponseDataset, - cell_line_input: FeatureDataset = None, + cell_line_input: FeatureDataset, ) -> None: + """ + Initializes the dataset by setting the output and the cell line input. + + :param output: drug response dataset + :param cell_line_input: omics features of the cell lines + """ self.output = output self.cell_line_input = cell_line_input - def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - response = self.output.response[idx].astype(np.float32) + def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.float32]: + """ + Overwrites the getitem method. 
+ + :param idx: index of the sample + :returns: gene expression, mutations, copy number variation, and response of the sample as numpy arrays + """ + response: np.float32 = np.float32(self.output.response[idx]) cell_line_id = str(self.output.cell_line_ids[idx]) - gene_expression = self.cell_line_input.features[cell_line_id]["gene_expression"].astype(np.float32) - mutations = self.cell_line_input.features[cell_line_id]["mutations"].astype(np.float32) - copy_number = self.cell_line_input.features[cell_line_id]["copy_number_variation_gistic"].astype(np.float32) + gene_expression: np.ndarray = self.cell_line_input.features[cell_line_id]["gene_expression"].astype(np.float32) + mutations: np.ndarray = self.cell_line_input.features[cell_line_id]["mutations"].astype(np.float32) + copy_number: np.ndarray = self.cell_line_input.features[cell_line_id]["copy_number_variation_gistic"].astype( + np.float32 + ) return gene_expression, mutations, copy_number, response def __len__(self) -> int: """ Overwrites the len method. + + :returns: number of samples in the dataset """ return len(self.output.response) @@ -56,7 +72,17 @@ def generate_triplets_indices( random_seed: Optional[int] = None, ) -> tuple[np.ndarray, np.ndarray]: """ - Generates triplets for the MOLI model. + Generates triplets for the MOLIR model. + + The positive and negative range are determined by the standard deviation of the response values. A sample is + considered positive if its response value is within the positive range of the label. The positive range is ±10% + of the standard deviation of all response values. A sample is considered negative if its response value is at + least one standard deviation away from the response value of the sample. + :param y: response values + :param positive_range: positive range for the triplet loss + :param negative_range: negative range for the triplet loss + :param random_seed: random seed for reproducibility + :returns: positive and negative sample indices for each sample """ if random_seed is not None: random.seed(random_seed) @@ -65,17 +91,26 @@ def generate_triplets_indices( negative_sample_indices = [] # Iterate over each label in the dataset for idx_current_label, current_label in enumerate(y): - positive_class_indices = get_positive_class_indices(current_label, idx_current_label, y, positive_range) + positive_class_indices = _get_positive_class_indices(current_label, idx_current_label, y, positive_range) positive_sample_idx = np.random.choice(positive_class_indices, 1)[0] - negative_class_indices = get_negative_class_indices(current_label, y, negative_range) + negative_class_indices = _get_negative_class_indices(current_label, y, negative_range) negative_sample_idx = np.random.choice(negative_class_indices, 1)[0] positive_sample_indices.append(positive_sample_idx) negative_sample_indices.append(negative_sample_idx) return np.array(positive_sample_indices), np.array(negative_sample_indices) -def get_positive_class_indices(label: float, idx_label: int, y: np.ndarray, positive_range: float) -> np.ndarray: - # find the samples that are within the positive range of the label except the label itself +def _get_positive_class_indices(label: np.float32, idx_label: int, y: np.ndarray, positive_range: float) -> np.ndarray: + """ + Find the samples that are within the positive range of the label except the label itself. + + If there is no similar sample within the positive range, the method returns the closest sample to the label. 
+ :param label: response of interest + :param idx_label: index of the response of interest + :param y: all responses + :param positive_range: 0.1 * the standard deviation of all training responses + :returns: indices of the samples that can be considered positive examples (=similar to the response of interest) + """ indices_similar_samples = np.where(np.logical_and(label - positive_range <= y, y <= label + positive_range))[0] indices_similar_samples = np.delete(indices_similar_samples, np.where(indices_similar_samples == idx_label)) if len(indices_similar_samples) == 0: @@ -84,7 +119,16 @@ def get_positive_class_indices(label: float, idx_label: int, y: np.ndarray, posi return indices_similar_samples -def get_negative_class_indices(label: float, y: np.ndarray, negative_range: float) -> np.ndarray: +def _get_negative_class_indices(label: np.float32, y: np.ndarray, negative_range: float) -> np.ndarray: + """ + Finds dissimilar samples to the label. + + If there is no dissimilar sample within the negative range, the method returns the sample that is the furthest away. + :param label: reponse of interest + :param y: all responses + :param negative_range: 1 * the standard deviation of all training responses + :returns: indices of the samples that can be considered negative examples (=dissimilar to the response of interest) + """ dissimilar_samples = np.where(np.logical_or(label - negative_range >= y, y >= label + negative_range))[0] if len(dissimilar_samples) == 0: # return the sample that is the furthest away from the label @@ -95,9 +139,12 @@ def get_negative_class_indices(label: float, y: np.ndarray, negative_range: floa def make_ranges(output: DrugResponseDataset) -> tuple[float, float]: """ Compute the positive and negative range for the triplet loss. + + :param output: drug response dataset + :returns: positive and negative range for the triplet loss """ - positive_range = np.std(output.response) * 0.1 - negative_range = np.std(output.response) + positive_range = float(np.std(output.response) * 0.1) + negative_range = float(np.std(output.response)) return positive_range, negative_range @@ -107,7 +154,15 @@ def create_dataset_and_loaders( cell_line_input: FeatureDataset, output_earlystopping: Optional[DrugResponseDataset] = None, ) -> tuple[DataLoader, Optional[DataLoader]]: - # Create datasets and dataloaders + """ + Creates the RegressionDataset (torch Dataset) and the DataLoader for the training and validation data. + + :param batch_size: specified batch size + :param output_train: response values for the training data + :param cell_line_input: omic input features of the cell lines + :param output_earlystopping: early stopping dataset + :returns: training and validation data loaders + """ train_dataset = RegressionDataset(output_train, cell_line_input) train_loader = DataLoader( train_dataset, @@ -135,6 +190,12 @@ def create_dataset_and_loaders( def get_dimensions_of_omics_data(cell_line_input: FeatureDataset) -> tuple[int, int, int]: + """ + Determines the dimensions of the omics data for the creation of the input layers. 
+ + :param cell_line_input: omic input features of the cell lines + :returns: dimensions of the gene expression, mutations, and copy number variation data + """ first_item = next(iter(cell_line_input.features.values())) dim_gex = first_item["gene_expression"].shape[0] dim_mut = first_item["mutations"].shape[0] @@ -143,7 +204,21 @@ def get_dimensions_of_omics_data(cell_line_input: FeatureDataset) -> tuple[int, class MOLIEncoder(nn.Module): + """ + Encoders of the MOLIR model, which is identical to the encoders of the original MOLI model. + + The MOLIR model has three encoders for the gene expression, mutations, and copy number variation data which are + trained together. + """ + def __init__(self, input_size: int, output_size: int, dropout_rate: float) -> None: + """ + Initializes the encoder for the MOLIR model. + + :param input_size: input size determined by feature selection. + :param output_size: output size of the encoder, set as hyperparameter. + :param dropout_rate: dropout rate for regularization, set as hyperparameter. + """ super().__init__() self.encode = nn.Sequential( nn.Linear(input_size, output_size), @@ -153,40 +228,87 @@ def __init__(self, input_size: int, output_size: int, dropout_rate: float) -> No ) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the encoder. + + :param x: omic input features + :returns: encoded omic features + """ return self.encode(x) class MOLIRegressor(nn.Module): - def __init__(self, input_size: int, dropout_rate: int) -> None: + """ + Regressor of the MOLIR model. + + It is identical to the regressor of the original MOLI model, except for the omission of the final sigmoid + activation function. After the three encoders, the encoded features are concatenated and fed into the regressor. + """ + + def __init__(self, input_size: int, dropout_rate: float) -> None: + """ + Initializes the regressor for the MOLIR model. + + :param input_size: determined by the output sizes of the encoders. + :param dropout_rate: set as hyperparameter. + """ super().__init__() self.regressor = nn.Sequential(nn.Linear(input_size, 1), nn.Dropout(dropout_rate)) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the regressor. + + :param x: concatenated encoded features + :returns: predicted drug response + """ return self.regressor(x) class MOLIModel(pl.LightningModule): + """ + PyTorch Lightning module for the MOLIR model. + + The architecture of the MOLIR model is identical to the MOLI model, except for the omission of the final sigmoid + layer and the usage of a regression MSE loss instead of a binary cross-entropy loss. Additionally, early stopping is + added instead of tuning the number of epochs as hyperparameter. + """ + def __init__( - self, hpams: dict[str, Union[int, float]], input_dim_expr: int, input_dim_mut: int, input_dim_cnv: int + self, hpams: dict[str, int | float], input_dim_expr: int, input_dim_mut: int, input_dim_cnv: int ) -> None: + """ + Initializes the MOLIR model. + + The MOLIR model uses a combined loss function of a triplet margin loss for the concatenated representation + and an MSE loss for the regression loss. + + :param hpams: includes mini_batch, layer dimensions (h_dim1, h_dim2, h_dim3), learning_rate, dropout_rate, + weight decay, gamma, epochs, and margin. + :param input_dim_expr: determined by the feature selection of the gene expression data. 
+ :param input_dim_mut: determined by dataset size + :param input_dim_cnv: determined by dataset size + """ super().__init__() self.save_hyperparameters() - self.mini_batch = hpams["mini_batch"] - self.h_dim1 = hpams["h_dim1"] - self.h_dim2 = hpams["h_dim2"] - self.h_dim3 = hpams["h_dim3"] + self.mini_batch = int(hpams["mini_batch"]) + self.h_dim1 = int(hpams["h_dim1"]) + self.h_dim2 = int(hpams["h_dim2"]) + self.h_dim3 = int(hpams["h_dim3"]) self.lr = hpams["learning_rate"] self.dropout_rate = hpams["dropout_rate"] self.weight_decay = hpams["weight_decay"] self.gamma = hpams["gamma"] - self.epochs = hpams["epochs"] + self.epochs = int(hpams["epochs"]) self.triplet_loss = nn.TripletMarginLoss(margin=hpams["margin"], p=2) self.regression_loss = nn.MSELoss() - # Positive and Negative range for triplet loss - self.positive_range = None - self.negative_range = None - self.checkpoint_callback = None + # Positive and Negative range for triplet loss, determined by the standard deviation of the training responses, + # set in fit method + self.positive_range = 1.0 + self.negative_range = 1.0 + # Checkpoint callback for early stopping, set in fit method + self.checkpoint_callback: pl.callbacks.ModelCheckpoint | None = None self.expression_encoder = MOLIEncoder(input_dim_expr, self.h_dim1, self.dropout_rate) self.mutation_encoder = MOLIEncoder(input_dim_mut, self.h_dim2, self.dropout_rate) @@ -200,6 +322,17 @@ def fit( output_earlystopping: Optional[DrugResponseDataset] = None, patience: int = 5, ) -> None: + """ + Trains the MOLIR model. + + First, the ranges for the triplet loss are determined using the standard deviation of the training responses. + Then, the training and validation data loaders are created. The model is trained using the Lightning Trainer + with an early stopping callback and patience of 5. + :param output_train: training dataset containing the response output + :param cell_line_input: feature dataset containing the omics data of the cell lines + :param output_earlystopping: early stopping dataset + :param patience: for early stopping + """ self.positive_range, self.negative_range = make_ranges(output_train) train_loader, val_loader = create_dataset_and_loaders( @@ -214,7 +347,7 @@ def fit( early_stop_callback = EarlyStopping(monitor=monitor, mode="min", patience=patience) name = "version-" + "".join( - [random.choice("0123456789abcdef") for _ in range(20)] + [secrets.choice("0123456789abcdef") for _ in range(20)] ) # preventing conflicts of filenames self.checkpoint_callback = pl.callbacks.ModelCheckpoint( dirpath=None, @@ -247,28 +380,39 @@ def predict( ) -> np.ndarray: """ Perform prediction on given input data. + + If there was enough training data to train the model, the model from the best epoch was saved in the checkpoint + callback and is loaded now. If there was not enough training data, the model is only randomly initialized. 
+ :param gene_expression: gene expression data + :param mutations: mutation data + :param copy_number: copy number variation data + :returns: predicted drug response """ # load best model - if self.checkpoint_callback.best_model_path: + if hasattr(self, "checkpoint_callback") and self.checkpoint_callback is not None: best_model = MOLIModel.load_from_checkpoint(self.checkpoint_callback.best_model_path) else: best_model = self # convert to torch tensors - gene_expression = torch.from_numpy(gene_expression).float().to(best_model.device) - mutations = torch.from_numpy(mutations).float().to(best_model.device) - copy_number = torch.from_numpy(copy_number).float().to(best_model.device) + gene_expression_tensor = torch.from_numpy(gene_expression).float().to(best_model.device) + mutations_tensor = torch.from_numpy(mutations).float().to(best_model.device) + copy_number_tensor = torch.from_numpy(copy_number).float().to(best_model.device) best_model.eval() with torch.no_grad(): - z = best_model.encode_and_concatenate(gene_expression, mutations, copy_number) + z = best_model._encode_and_concatenate(gene_expression_tensor, mutations_tensor, copy_number_tensor) preds = best_model.regressor(z) return preds.squeeze().cpu().detach().numpy() - def encode_and_concatenate( + def _encode_and_concatenate( self, gene_expression: torch.Tensor, mutations: torch.Tensor, copy_number: torch.Tensor ) -> torch.Tensor: """ - Encodes the input modalities (gene expression, mutations, and copy number) - and concatenates the resulting embeddings. + Encodes the input modalities, concatenates, and normalizes the resulting embeddings. + + :param gene_expression: gene expression data + :param mutations: mutation data + :param copy_number: copy number variation data + :returns: concatenated, normalized embeddings """ z_ex = self.expression_encoder(gene_expression) z_mu = self.mutation_encoder(mutations) @@ -279,13 +423,26 @@ def encode_and_concatenate( return z def forward(self, x_gene: torch.Tensor, x_mutation: torch.Tensor, x_cna: torch.Tensor) -> torch.Tensor: - z = self.encode_and_concatenate(x_gene, x_mutation, x_cna) + """ + Forward pass of the MOLIR model. + + :param x_gene: gene expression input + :param x_mutation: mutation input + :param x_cna: copy number variation input + :returns: predicted drug response + """ + z = self._encode_and_concatenate(x_gene, x_mutation, x_cna) preds = self.regressor(z) return preds - def compute_loss(self, z: torch.Tensor, preds: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def _compute_loss(self, z: torch.Tensor, preds: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ Computes the combined triplet loss and regression loss. + + :param z: concatenated, normalized embeddings on which the triplet loss is calculated + :param preds: predicted drug response on which the regression loss is calculated + :param y: true drug response + :returns: combined loss """ positive_indices, negative_indices = generate_triplets_indices( y.cpu().detach().numpy(), self.positive_range, self.negative_range @@ -295,39 +452,54 @@ def compute_loss(self, z: torch.Tensor, preds: torch.Tensor, y: torch.Tensor) -> regression_loss = self.regression_loss(preds.squeeze(), y) return triplet_loss + regression_loss - def training_step( - self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int - ) -> torch.Tensor: + def training_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Tensor: + """ + Training step of the MOLIR model. 
+ + :param batch: batch of gene expression, mutations, copy number variation, and response + :param batch_idx: index of the batch + :returns: combined loss + """ gene_expression, mutations, copy_number, response = batch # Encode and concatenate - z = self.encode_and_concatenate(gene_expression, mutations, copy_number) + z = self._encode_and_concatenate(gene_expression, mutations, copy_number) # Get predictions preds = self.regressor(z) # Compute loss - loss = self.compute_loss(z, preds, response) + loss = self._compute_loss(z, preds, response) self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True) return loss - def validation_step( - self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int - ) -> torch.Tensor: + def validation_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Tensor: + """ + Validation step of the MOLIR model. + + :param batch: batch of gene expression, mutations, copy number variation, and response + :param batch_idx: index of the batch + :returns: combined loss + """ gene_expression, mutations, copy_number, response = batch # Encode and concatenate - z = self.encode_and_concatenate(gene_expression, mutations, copy_number) + z = self._encode_and_concatenate(gene_expression, mutations, copy_number) # Get predictions preds = self.regressor(z) # Compute loss - val_loss = self.compute_loss(z, preds, response) + val_loss = self._compute_loss(z, preds, response) self.log("val_loss", val_loss, on_step=False, on_epoch=True, prog_bar=True) return val_loss def configure_optimizers(self) -> torch.optim.Optimizer: + """ + Overwrites the configure_optimizers method from PyTorch Lightning. + + :returns: optimizers for the MOLIR expression, mutation, copy number variation encoders, and regressor + """ optimizer = torch.optim.Adagrad( [ {"params": self.expression_encoder.parameters(), "lr": self.lr}, diff --git a/drevalpy/models/SRMF/__init__.py b/drevalpy/models/SRMF/__init__.py index e69de29..4691984 100644 --- a/drevalpy/models/SRMF/__init__.py +++ b/drevalpy/models/SRMF/__init__.py @@ -0,0 +1 @@ +"""Module for the SRMF (Similarity Regularization Matrix Factorization) model.""" diff --git a/drevalpy/models/SRMF/env.yml b/drevalpy/models/SRMF/env.yml deleted file mode 100644 index e69de29..0000000 diff --git a/drevalpy/models/SRMF/srmf.py b/drevalpy/models/SRMF/srmf.py index 7144b67..53590e9 100644 --- a/drevalpy/models/SRMF/srmf.py +++ b/drevalpy/models/SRMF/srmf.py @@ -1,6 +1,14 @@ +""" +Contains the SRMF (Similarity Regularization Matrix Factorization) model. + +Original publication: Wang, L., Li, X., Zhang, L. et al. Improved anticancer drug response prediction in cell lines +using matrix factorization with similarity regularization. BMC Cancer 17, 513 (2017). +https://doi.org/10.1186/s12885-017-3500-5. +Matlab code adapted from https://github.com/linwang1982/SRMF. +""" + import numpy as np import pandas as pd -from numpy.typing import ArrayLike from scipy.spatial.distance import jaccard from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset @@ -9,23 +17,52 @@ class SRMF(DRPModel): - """SRMF model: Similarity Regularization Matrix Factorization.""" + """ + SRMF model: Similarity Regularization Matrix Factorization. + + The primary idea is to map m drugs and n cell lines into a shared latent space, with a low dimensionality K, + where K << min (m, n). 
The properties of a drug $d_i$ and a cell line $c_j$ are described by two latent coordinates + $u_i$ and $v_j$ (K dimensional row vectors), respectively. The drug response matrix Y is approximated by: + $min_{U,V} || W * (Y - U * V^T) ||^2_F + lambda_l * (||U||^2_F + ||V||^2_F) + lambda_d * ||S_d - U * U^T||^2_F + + lambda_c * ||S_c - V * V^T||^2_F$ + where W is a weight matrix ($W_{ij} = 1 if Y_{ij}$ is a known response value, else 0). U, V contain $u_i$ , + $v_j$ as row vectors, respectively, $||.||_F$ is the Frobenius norm. To avoid overfitting, L2 regularization is + used. S_d, S_c are drug/cell line similarity matrices. Differences between two drugs/cell lines are minimized in + latent space. + """ - model_name = "SRMF" cell_line_views = ["gene_expression"] drug_views = ["fingerprints"] - def __init__(self): + def __init__(self) -> None: """Initalization method for SRMF Model.""" super().__init__() - self.best_u = None - self.best_v = None - self.w = None + self.best_u: pd.DataFrame = pd.DataFrame() + self.best_v: pd.DataFrame = pd.DataFrame() + self.w: pd.DataFrame = pd.DataFrame() + self.k: int = 45 + self.lambda_l: float = 0.01 + self.lambda_d: float = 0.0 + self.lambda_c: float = 0.01 + self.max_iter: int = 50 + self.seed: int = 1 + + @classmethod + def get_model_name(cls) -> str: + """ + Returns the model name. + + :returns: SRMF + """ + return "SRMF" - def build_model(self, hyperparameters: dict): + def build_model(self, hyperparameters: dict) -> None: """ Initializes hyperparameters for SRMF model. + K is the latent dimensionality, lambda_l, lambda_d, lambda_c are regularization parameters, max_iter is the + number of iterations, seed is the random seed. + :param hyperparameters: dictionary containing the hyperparameters """ self.k = hyperparameters.get("K", 45) @@ -38,9 +75,9 @@ def build_model(self, hyperparameters: dict): def train( self, output: DrugResponseDataset, - cell_line_input: FeatureDataset = None, - drug_input: FeatureDataset = None, - output_earlystopping=None, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, ) -> None: """ Prepares data and trains the SRMF model. @@ -48,7 +85,12 @@ def train( :param output: response data :param cell_line_input: feature data for cell lines :param drug_input: feature data for drugs + :param output_earlystopping: optional early stopping dataset + :raises ValueError: if drug_input is None """ + if drug_input is None: + raise ValueError("SRMF requires drug features.") + drugs = np.unique(drug_input.identifiers) # transductive approach - all drug features are used cell_lines = np.unique(cell_line_input.identifiers) # transductive approach - all cell line features are used @@ -85,7 +127,7 @@ def train( drug_response_matrix[np.isnan(drug_response_matrix)] = 0 # Train the model - best_u, best_v = self.cmf( + best_u, best_v = self._cmf( w=self.w.T.values, int_mat=drug_response_matrix.values.T, drug_mat=drug_similarity.values, @@ -96,17 +138,19 @@ def train( def predict( self, - drug_ids: ArrayLike, - cell_line_ids: ArrayLike, - drug_input: FeatureDataset = None, - cell_line_input: FeatureDataset = None, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, ) -> np.ndarray: """ Predicts the drug response based on the trained latent factors. 
:param drug_ids: drug identifiers :param cell_line_ids: cell line identifiers - :return: predicted response matrix + :param cell_line_input: not needed for prediction in SRMF + :param drug_input: not needed for prediction in SRMF + :returns: predicted response matrix """ best_u = self.best_u.loc[drug_ids].values best_v = self.best_v.loc[cell_line_ids].values @@ -116,14 +160,15 @@ def predict( return diagonal_predictions - def cmf(self, w, int_mat, drug_mat, cell_mat): + def _cmf(self, w, int_mat, drug_mat, cell_mat) -> tuple[np.ndarray, np.ndarray]: """ Implements the SRMF model with specific update rules and regularization. - :param w: - :param int_mat: - :param drug_mat: - :param cell_mat: + :param w: weight matrix + :param int_mat: interaction matrix + :param drug_mat: drug similarity matrix + :param cell_mat: cell line similarity matrix + :returns: best drug and cell line latent factors """ np.random.seed(self.seed) m, n = w.shape @@ -132,14 +177,14 @@ def cmf(self, w, int_mat, drug_mat, cell_mat): best_u, best_v = u0, v0 - last_loss = self.compute_loss(u0, v0, w, int_mat, drug_mat, cell_mat) + last_loss = self._compute_loss(u0, v0, w, int_mat, drug_mat, cell_mat) best_loss = last_loss wr = w * int_mat for _ in range(self.max_iter): - u = self.alg_update(u0, v0, w, wr, drug_mat, self.lambda_l, self.lambda_d) - v = self.alg_update(v0, u, w.T, wr.T, cell_mat, self.lambda_l, self.lambda_c) - curr_loss = self.compute_loss(u, v, w, int_mat, drug_mat, cell_mat) + u = self._alg_update(u0, v0, w, wr, drug_mat, self.lambda_l, self.lambda_d) + v = self._alg_update(v0, u, w.T, wr.T, cell_mat, self.lambda_l, self.lambda_c) + curr_loss = self._compute_loss(u, v, w, int_mat, drug_mat, cell_mat) if curr_loss < best_loss: best_u, best_v = u, v @@ -154,16 +199,17 @@ def cmf(self, w, int_mat, drug_mat, cell_mat): return best_u, best_v - def compute_loss(self, u, v, w, int_mat, drug_mat, cell_mat): + def _compute_loss(self, u, v, w, int_mat, drug_mat, cell_mat) -> np.float64: """ Computes the loss for SRMF, including similarity regularization. - :param u: - :param v: - :param w: - :param int_mat: - :param drug_mat: - :param cell_mat: + :param u: drug latent factors + :param v: cell line latent factors + :param w: weight matrix + :param int_mat: interaction matrix + :param drug_mat: drug similarity matrix + :param cell_mat: cell line similarity matrix + :returns: loss value """ loss = np.sum((w * (int_mat - np.dot(u, v.T))) ** 2) loss += self.lambda_l * (np.sum(u**2) + np.sum(v**2)) @@ -171,17 +217,18 @@ def compute_loss(self, u, v, w, int_mat, drug_mat, cell_mat): loss += self.lambda_c * np.sum((cell_mat - np.dot(v, v.T)) ** 2) return loss - def alg_update(self, u, v, w, r, s, lambda_l, lambda_d): + def _alg_update(self, u, v, w, r, s, lambda_l, lambda_d) -> np.ndarray: """ Algorithm update rule for u or v in the SRMF model. 
- :param u: - :param v: - :param w: - :param r: - :param s: - :param lambda_l: - :param lambda_d: + :param u: drug latent factors + :param v: cell line latent factors + :param w: weight matrix + :param r: weight * interaction matrix + :param s: drug/cell line similarity matrix + :param lambda_l: regularization parameter + :param lambda_d: drug/cell line similarity regularization parameter + :returns: updated u or v """ x = np.dot(r, v) + 2 * lambda_d * np.dot(s, u) y = 2 * lambda_d * np.dot(u.T, u) @@ -205,10 +252,11 @@ def alg_update(self, u, v, w, r, s, lambda_l, lambda_d): def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: """ - Loads the cell line features. + Loads the cell line features, in this case the gene expression features. - :param path: Path to the gene expression and landmark genes - :return: FeatureDataset containing the cell line gene expression features, filtered + :param data_path: Path to the gene expression and landmark genes, e.g., data/ + :param dataset_name: Name of the dataset, e.g., GDSC2 + :returns: FeatureDataset containing the cell line gene expression features, filtered through the landmark genes """ return load_and_reduce_gene_features( @@ -220,25 +268,10 @@ def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureD def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: """ - Loads the drug features. + Loads the drug features, in this case the drug fingerprints. - :param data_path: - :param dataset_name: + :param data_path: Path to the drug features, in this case the drug fingerprints, e.g., data/ + :param dataset_name: Name of the dataset, e.g., GDSC2 + :returns: FeatureDataset containing the drug fingerprint features """ return load_drug_fingerprint_features(data_path, dataset_name) - - def load(self, path): - """ - Loads the model from a given path. - - :param path: Path to the model - """ - raise NotImplementedError("SRMF does not support loading yet ...") - - def save(self, path): - """ - Saves the model to a given path. - - :param path: Path to save the model - """ - raise NotImplementedError("SRMF does not support saving yet ...") diff --git a/drevalpy/models/SimpleNeuralNetwork/__init__.py b/drevalpy/models/SimpleNeuralNetwork/__init__.py new file mode 100644 index 0000000..4db64a7 --- /dev/null +++ b/drevalpy/models/SimpleNeuralNetwork/__init__.py @@ -0,0 +1 @@ +"""Module for the baseline neural network models SimpleNeuralNetwork and MultiOmicsNeuralNetwork.""" diff --git a/drevalpy/models/simple_neural_network/hyperparameters.yaml b/drevalpy/models/SimpleNeuralNetwork/hyperparameters.yaml similarity index 100% rename from drevalpy/models/simple_neural_network/hyperparameters.yaml rename to drevalpy/models/SimpleNeuralNetwork/hyperparameters.yaml diff --git a/drevalpy/models/simple_neural_network/multiomics_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py similarity index 61% rename from drevalpy/models/simple_neural_network/multiomics_neural_network.py rename to drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py index 4738a74..ef35964 100644 --- a/drevalpy/models/simple_neural_network/multiomics_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py @@ -1,12 +1,8 @@ -""" -Contains the MultiOmicsNeuralNetwork model. 
-""" +"""Contains the baseline MultiOmicsNeuralNetwork model.""" import warnings -from typing import Optional import numpy as np -from numpy.typing import ArrayLike from sklearn.decomposition import PCA from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset @@ -17,14 +13,7 @@ class MultiOmicsNeuralNetwork(DRPModel): - """ - Simple Feedforward Neural Network model with dropout. - - hyperparameters: - units_per_layer: number of units per layer e.g. [100, 50] means 2 layers with 100 and 50 - units respectively and the output layer with one unit. - dropout_prob: dropout probability for layers 1, 2, ..., n-1 - """ + """Simple Feedforward Neural Network model with dropout using multiple omics data.""" cell_line_views = [ "gene_expression", @@ -34,37 +23,59 @@ class MultiOmicsNeuralNetwork(DRPModel): ] drug_views = ["fingerprints"] early_stopping = True - model_name = "MultiOmicsNeuralNetwork" def __init__(self): + """ + Initalization method for MultiOmicsNeuralNetwork Model. + + The model and the PCA are initialized to None because they are built later in the build_model method. + """ super().__init__() self.model = None + self.hyperparameters = None self.pca = None + @classmethod + def get_model_name(cls) -> str: + """ + Returns the model name. + + :returns: MultiOmicsNeuralNetwork + """ + return "MultiOmicsNeuralNetwork" + def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. + + The model is a simple feedforward neural network with dropout. The PCA is used to reduce the dimensionality of + the methylation data. + + :param hyperparameters: dictionary containing the hyperparameters units_per_layer, dropout_prob, and + methylation_pca_components. """ - self.model = FeedForwardNetwork( - n_units_per_layer=hyperparameters["units_per_layer"], - dropout_prob=hyperparameters["dropout_prob"], - ) + self.hyperparameters = hyperparameters self.pca = PCA(n_components=hyperparameters["methylation_pca_components"]) def train( self, output: DrugResponseDataset, cell_line_input: FeatureDataset, - drug_input: FeatureDataset = None, - output_earlystopping: Optional[DrugResponseDataset] = None, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, ): """ - Trains the model. + Fits the PCA and trains the model. 
+ :param output: training data associated with the response output :param cell_line_input: cell line omics features :param drug_input: drug omics features :param output_earlystopping: optional early stopping dataset + :raises ValueError: if drug_input (fingerprints) is missing """ + if drug_input is None: + raise ValueError("Drug input (fingerprints) is needed for the MultiOmicsNeuralNetwork model.") + unique_methylation = np.stack( [cell_line_input.features[id_]["methylation"] for id_ in np.unique(output.cell_line_ids)], axis=0, @@ -73,6 +84,18 @@ def train( self.pca.n_components = min(self.pca.n_components, len(unique_methylation)) self.pca = self.pca.fit(unique_methylation) + first_feature = next(iter(cell_line_input.features.values())) + dim_gex = first_feature["gene_expression"].shape[0] + dim_met = self.pca.n_components + dim_mut = first_feature["mutations"].shape[0] + dim_cnv = first_feature["copy_number_variation_gistic"].shape[0] + dim_fingerprint = next(iter(drug_input.features.values()))["fingerprints"].shape[0] + + self.model = FeedForwardNetwork( + hyperparameters=self.hyperparameters, + input_dim=dim_gex + dim_met + dim_mut + dim_cnv + dim_fingerprint, + ) + with warnings.catch_warnings(): warnings.filterwarnings( "ignore", @@ -91,25 +114,21 @@ def train( met_transform=self.pca, ) - def save(self, path: str): - """ - Saves the model. - :param path: path to save the model - """ - raise NotImplementedError("save method not implemented") - - def load(self, path: str): - raise NotImplementedError("load method not implemented") - def predict( self, - drug_ids: ArrayLike, - cell_line_ids: ArrayLike, - drug_input: FeatureDataset = None, - cell_line_input: FeatureDataset = None, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, ) -> np.ndarray: """ - Predicts the response for the given input. + Transforms the methylation data using the fitted PCA and then predicts the response for the given input. + + :param drug_ids: drug identifiers + :param cell_line_ids: cell line identifiers + :param drug_input: drug omics features + :param cell_line_input: cell line omics features + :returns: predicted response """ inputs = self.get_feature_matrices( cell_line_ids=cell_line_ids, @@ -152,14 +171,14 @@ def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureD :return: FeatureDataset containing the cell line omics features, filtered through the drug target genes """ - return get_multiomics_feature_dataset(data_path=data_path, dataset_name=dataset_name) def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: """ Load the drug features. 
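# --- Sketch of the methylation PCA handling shown above, on toy data (assumed sizes). ---
# n_components is capped by the number of unique training cell lines, because sklearn's PCA cannot
# extract more components than there are samples.
import numpy as np
from sklearn.decomposition import PCA

methylation = np.random.rand(20, 500)   # 20 unique training cell lines, 500 methylation features
pca = PCA(n_components=100)
pca.n_components = min(pca.n_components, len(methylation))  # capped to 20 here
pca = pca.fit(methylation)
print(pca.transform(methylation).shape)  # (20, 20)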
- :param data_path: - :param dataset_name: + :param data_path: path to the drug features, in this case the drug fingerprints, e.g., data/ + :param dataset_name: name of the dataset, e.g., GDSC1 + :returns: FeatureDataset containing the drug fingerprint features """ return load_drug_fingerprint_features(data_path, dataset_name) diff --git a/drevalpy/models/simple_neural_network/simple_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py similarity index 55% rename from drevalpy/models/simple_neural_network/simple_neural_network.py rename to drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py index 438665b..ffd3c7b 100644 --- a/drevalpy/models/simple_neural_network/simple_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py @@ -1,12 +1,8 @@ -""" -Contains the SimpleNeuralNetwork model. -""" +"""Contains the SimpleNeuralNetwork model.""" import warnings -from typing import Optional import numpy as np -from numpy.typing import ArrayLike from sklearn.preprocessing import StandardScaler from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset @@ -17,58 +13,79 @@ class SimpleNeuralNetwork(DRPModel): - """ - Simple Feedforward Neural Network model with dropout. - - hyperparameters: - units_per_layer: number of units per layer e.g. [100, 50] means 2 layers with 100 and 50 - units respectively and the output layer with one unit. - dropout_prob: dropout probability for layers 1, 2, ..., n-1 - """ + """Simple Feedforward Neural Network model with dropout using only gene expression data.""" cell_line_views = ["gene_expression"] drug_views = ["fingerprints"] early_stopping = True - model_name = "SimpleNeuralNetwork" def __init__(self): + """Initializes the SimpleNeuralNetwork. + + The model is build in train(). The gene_expression_scalar is set to the StandardScaler() and later fitted + using the training data only. + """ super().__init__() self.model = None + self.hyperparameters = None self.gene_expression_scaler = StandardScaler() + @classmethod + def get_model_name(cls) -> str: + """ + Returns the model name. + + :returns: SimpleNeuralNetwork + """ + return "SimpleNeuralNetwork" + def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. + + :param hyperparameters: includes units_per_layer and dropout_prob. """ - self.model = FeedForwardNetwork( - n_units_per_layer=hyperparameters["units_per_layer"], - dropout_prob=hyperparameters["dropout_prob"], - ) + self.hyperparameters = hyperparameters def train( self, output: DrugResponseDataset, - cell_line_input: FeatureDataset = None, - drug_input: FeatureDataset = None, - output_earlystopping: Optional[DrugResponseDataset] = None, - ): + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, + ) -> None: """ - Trains the model. + First scales the gene expression data and trains the model. + + The gene expression data is first arcsinh transformed. Afterward, the StandardScaler() is fitted on the + training gene expression data only. Then, it transforms all gene expression data. 
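# --- Minimal sketch of the scaling strategy documented above, using plain numpy/sklearn instead of
# FeatureDataset (assumed toy data): arcsinh-transform gene expression, fit the StandardScaler on
# training cell lines only, then apply it to all cell lines. ---
import numpy as np
from sklearn.preprocessing import StandardScaler

gex = np.random.rand(10, 1000) * 1e4      # all cell lines x genes
train_idx = np.arange(6)                  # training cell lines only

gex = np.arcsinh(gex)                     # variance-stabilizing transform
scaler = StandardScaler().fit(gex[train_idx])
gex_scaled = scaler.transform(gex)        # applied to train and test alike
print(gex_scaled.shape)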
:param output: training data associated with the response output :param cell_line_input: cell line omics features :param drug_input: drug omics features :param output_earlystopping: optional early stopping dataset + :raises ValueError: if drug_input (fingerprints) is missing """ + if drug_input is None: + raise ValueError("drug_input (fingerprints) are required for SimpleNeuralNetwork.") + # Apply arcsinh transformation and scaling to gene expression features if "gene_expression" in self.cell_line_views: - cell_line_input._apply(function=np.arcsinh, view="gene_expression") + cell_line_input.apply(function=np.arcsinh, view="gene_expression") self.gene_expression_scaler = cell_line_input.fit_transform_features( train_ids=np.unique(output.cell_line_ids), transformer=self.gene_expression_scaler, view="gene_expression", ) + dim_gex = next(iter(cell_line_input.features.values()))["gene_expression"].shape[0] + dim_fingerprint = next(iter(drug_input.features.values()))["fingerprints"].shape[0] + + self.model = FeedForwardNetwork( + hyperparameters=self.hyperparameters, + input_dim=dim_gex + dim_fingerprint, + ) + with warnings.catch_warnings(): warnings.filterwarnings( "ignore", @@ -86,32 +103,22 @@ def train( num_workers=1, ) - def save(self, path: str): - raise NotImplementedError("save method not implemented") - - def load(self, path: str): - raise NotImplementedError("load method not implemented") - def predict( self, - drug_ids: ArrayLike, - cell_line_ids: ArrayLike, - drug_input: FeatureDataset = None, - cell_line_input: FeatureDataset = None, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, ) -> np.ndarray: """ Predicts the response for the given input. - """ - # Apply transformation to gene expression features before prediction - if "gene_expression" in self.cell_line_views: - cell_line_input = cell_line_input.copy() - cell_line_input._apply(function=np.arcsinh, view="gene_expression") - cell_line_input.transform_features( - ids=np.unique(cell_line_ids), - transformer=self.gene_expression_scaler, - view="gene_expression", - ) + :param cell_line_ids: IDs of the cell lines to be predicted + :param drug_ids: IDs of the drugs to be predicted + :param cell_line_input: gene expression of the test data + :param drug_input: fingerprints of the test data + :returns: the predicted drug responses + """ x = self.get_concatenated_features( cell_line_view="gene_expression", drug_view="fingerprints", @@ -126,9 +133,10 @@ def predict( def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: """ Loads the cell line features. - :param path: Path to the gene expression and landmark genes - :return: FeatureDataset containing the cell line gene expression features, filtered - through the landmark genes + + :param data_path: Path to the gene expression and landmark genes + :param dataset_name: name of the dataset + :return: FeatureDataset containing the cell line gene expression features, filtered through the landmark genes """ return load_and_reduce_gene_features( feature_type="gene_expression", @@ -138,4 +146,11 @@ def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureD ) def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """ + Loads the fingerprint data. 
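# --- Hedged sketch of the idea behind RegressionDataset: each item concatenates the requested cell
# line views and drug views into one float32 vector paired with the response. Class and field names
# below are illustrative stand-ins, not the drevalpy classes. ---
import numpy as np
import torch
from torch.utils.data import Dataset


class _PairDataset(Dataset):
    def __init__(self, cl_feats: dict, drug_feats: dict, triples: list[tuple[str, str, float]]):
        self.cl_feats = cl_feats      # {cell_line_id: {"gene_expression": np.ndarray}}
        self.drug_feats = drug_feats  # {drug_id: {"fingerprints": np.ndarray}}
        self.triples = triples        # [(cell_line_id, drug_id, response), ...]

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        cl, drug, response = self.triples[idx]
        x = np.concatenate([self.cl_feats[cl]["gene_expression"], self.drug_feats[drug]["fingerprints"]])
        return torch.from_numpy(x).float(), torch.tensor(response, dtype=torch.float32)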
+ + :param data_path: Path to the fingerprints, e.g., data/ + :param dataset_name: name of the dataset, e.g., GDSC1 + :returns: FeatureDataset containing the fingerprints + """ return load_drug_fingerprint_features(data_path, dataset_name) diff --git a/drevalpy/models/simple_neural_network/utils.py b/drevalpy/models/SimpleNeuralNetwork/utils.py similarity index 60% rename from drevalpy/models/simple_neural_network/utils.py rename to drevalpy/models/SimpleNeuralNetwork/utils.py index b2aa322..bb28c61 100644 --- a/drevalpy/models/simple_neural_network/utils.py +++ b/drevalpy/models/SimpleNeuralNetwork/utils.py @@ -1,10 +1,8 @@ -""" -Utility functions for the simple neural network models. -""" +"""Utility functions for the simple neural network models.""" import os -import random -from typing import Optional +import secrets +from typing import Any import numpy as np import pytorch_lightning as pl @@ -17,19 +15,29 @@ class RegressionDataset(Dataset): - """ - Dataset for regression tasks for the data loader. - """ + """Dataset for regression tasks for the data loader.""" def __init__( self, output: DrugResponseDataset, - cell_line_input: FeatureDataset = None, - drug_input: FeatureDataset = None, - cell_line_views: list[str] = None, - drug_views: list[str] = None, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset, + cell_line_views: list[str], + drug_views: list[str], met_transform=None, ): + """ + Initializes the regression dataset. + + :param output: response values + :param cell_line_input: input omics data + :param drug_input: input fingerprint data + :param cell_line_views: either gene expression for the SimpleNeuralNetwork or all omics data for the + MultiOMICSNeuralNetwork + :param drug_views: fingerprints + :param met_transform: How to transform the methylation data for the MultiOMICSNeuralNetwork (the fitted PCA) + :raises AssertionError: if the views are not found in the input data + """ self.cell_line_views = cell_line_views self.drug_views = drug_views self.output = output @@ -44,6 +52,15 @@ def __init__( self.met_transform = met_transform def __getitem__(self, idx): + """ + Overwrites the getitem method from the Dataset class. + + Retrieves the cell line and drug features and the response for the given index. If methylation data is + present, the data is transformed using the fitted PCA. + :param idx: index of the sample of interest + :returns: the cell line feature(s) and the response + :raises TypeError: if the features are not numpy arrays + """ cell_line_id = self.output.cell_line_ids[idx] drug_id = self.output.drug_ids[idx] response = self.output.response[idx] @@ -77,45 +94,78 @@ def __getitem__(self, idx): return data, response def __len__(self): - "Overwrites the len method." + """ + Overwrites the len method from the Dataset class. + + :returns: the length of the output + """ return len(self.output.response) class FeedForwardNetwork(pl.LightningModule): - """ - Feed forward neural network for regression tasks with basic architecture. - """ + """Feed forward neural network for regression tasks with basic architecture.""" + + def __init__(self, hyperparameters: dict[str, int | float | list[int]], input_dim: int) -> None: + """ + Initializes the feed forward network. + + The model uses a simple architecture with fully connected layers, batch normalization, and dropout. An MSE + loss is used. 
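# --- Hedged sketch of the layer stack built in the FeedForwardNetwork __init__ above: an input layer
# sized by the concatenated feature dimension, one Linear+BatchNorm1d pair per entry of
# units_per_layer, optional dropout, and a single output unit. The exact ReLU/dropout ordering of
# forward() is only partially visible in the diff, so this is an approximation. ---
import torch
from torch import nn


def build_stack(input_dim: int, units_per_layer: list[int], dropout_prob: float | None):
    fcs, bns = nn.ModuleList(), nn.ModuleList()
    fcs.append(nn.Linear(input_dim, units_per_layer[0]))
    bns.append(nn.BatchNorm1d(units_per_layer[0]))
    for i in range(1, len(units_per_layer)):
        fcs.append(nn.Linear(units_per_layer[i - 1], units_per_layer[i]))
        bns.append(nn.BatchNorm1d(units_per_layer[i]))
    fcs.append(nn.Linear(units_per_layer[-1], 1))  # regression head with one output unit
    dropout = nn.Dropout(p=dropout_prob) if dropout_prob is not None else None
    return fcs, bns, dropout


fcs, bns, dropout = build_stack(input_dim=128, units_per_layer=[64, 16], dropout_prob=0.3)
x = torch.randn(8, 128)
x = torch.relu(bns[0](fcs[0](x)))                      # first hidden step, to show the shapes
print(fcs[-1](torch.relu(bns[1](fcs[1](x)))).shape)    # torch.Size([8, 1])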
- def __init__(self, n_units_per_layer=None, dropout_prob=None) -> None: + :param hyperparameters: hyperparameters + :param input_dim: input dimension, for SimpleNeuralNetwork it is the sum of the gene expression and + fingerprint, for MultiOMICSNeuralNetwork it is the sum of all omics data and fingerprints + :raises TypeError: if the hyperparameters are not of the correct type + """ super().__init__() - if n_units_per_layer is None: - n_units_per_layer = [256, 64] + self.save_hyperparameters() + + if not isinstance(hyperparameters["units_per_layer"], list): + raise TypeError("units_per_layer must be a list of integers") + if not isinstance(hyperparameters["dropout_prob"], float): + raise TypeError("dropout_prob must be a float") + + n_units_per_layer: list[int] = hyperparameters["units_per_layer"] + dropout_prob: float = hyperparameters["dropout_prob"] self.n_units_per_layer = n_units_per_layer self.dropout_prob = dropout_prob - self.model_initialized = False self.loss = nn.MSELoss() - self.checkpoint_callback = None + # self.checkpoint_callback is initialized in the fit method + self.checkpoint_callback: pl.callbacks.ModelCheckpoint | None = None self.fully_connected_layers = nn.ModuleList() self.batch_norm_layers = nn.ModuleList() self.dropout_layer = None + self.fully_connected_layers.append(nn.Linear(input_dim, self.n_units_per_layer[0])) + self.batch_norm_layers.append(nn.BatchNorm1d(self.n_units_per_layer[0])) + + for i in range(1, len(self.n_units_per_layer)): + self.fully_connected_layers.append(nn.Linear(self.n_units_per_layer[i - 1], self.n_units_per_layer[i])) + self.batch_norm_layers.append(nn.BatchNorm1d(self.n_units_per_layer[i])) + + self.fully_connected_layers.append(nn.Linear(self.n_units_per_layer[-1], 1)) + if self.dropout_prob is not None: + self.dropout_layer = nn.Dropout(p=self.dropout_prob) + def fit( self, output_train: DrugResponseDataset, cell_line_input: FeatureDataset, - drug_input: FeatureDataset = None, - cell_line_views: list[str] = None, - drug_views: list[str] = None, - output_earlystopping: Optional[DrugResponseDataset] = None, - trainer_params: Optional[dict] = None, + drug_input: FeatureDataset | None, + cell_line_views: list[str], + drug_views: list[str], + output_earlystopping: DrugResponseDataset | None = None, + trainer_params: dict | None = None, batch_size=32, patience=5, - checkpoint_path: Optional[str] = None, + checkpoint_path: str | None = None, num_workers: int = 2, - met_transform=None, + met_transform: Any = None, ) -> None: """ Fits the model. + + First, the data is loaded using a DataLoader. Then, the model is trained using the Lightning Trainer. :param output_train: Response values for training :param cell_line_input: Cell line features :param drug_input: Drug features @@ -123,13 +173,18 @@ def fit( :param drug_views: Drug info needed for this model :param output_earlystopping: Response values for early stopping :param trainer_params: custom parameters for the trainer - :param batch_size: - :param patience: - :param checkpoint_path: - :param num_workers: - :param met_transform: - :return: + :param batch_size: batch size for the DataLoader, default is 32 + :param patience: patience for early stopping, default is 5 + :param checkpoint_path: path to save the checkpoints + :param num_workers: number of workers for the DataLoader, default is 2 + :param met_transform: transformation for methylation data, default is None, PCA is used for the MultiOMICSNN. 
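# --- Sketch of the run-naming and callback setup used in fit(): a random hex suffix drawn with the
# secrets module (which replaced random in this hunk) avoids checkpoint filename collisions between
# parallel runs. Only the callback arguments visible in the diff are reproduced; any extra options
# the real code passes are omitted. ---
import secrets
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

name = "version-" + "".join(secrets.choice("0123456789abcdef") for _ in range(20))
monitor = "val_loss"  # "train_loss" is used instead when no validation loader is available
early_stop = EarlyStopping(monitor=monitor, mode="min", patience=5)
checkpoint = ModelCheckpoint(dirpath=None, monitor=monitor, filename=name)
print(name)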
+ :raises ValueError: if drug_input is missing """ + if drug_input is None: + raise ValueError( + "Drug input (fingerprints) are required for SimpleNeuralNetwork and " "MultiOMICsNeuralNetwork." + ) + if trainer_params is None: trainer_params = { "progress_bar_refresh_rate": 300, @@ -176,7 +231,7 @@ def fit( early_stop_callback = EarlyStopping(monitor=monitor, mode="min", patience=patience) name = "version-" + "".join( - [random.choice("0123456789abcdef") for i in range(20)] + [secrets.choice("0123456789abcdef") for i in range(20)] ) # preventing conflicts of filenames self.checkpoint_callback = pl.callbacks.ModelCheckpoint( dirpath=checkpoint_path, @@ -190,9 +245,6 @@ def fit( trainer_params_copy = trainer_params.copy() del trainer_params_copy["progress_bar_refresh_rate"] - # Force initialize model with dummy data - self.force_initialize(train_loader) - # Initialize the Lightning trainer trainer = pl.Trainer( callbacks=[ @@ -204,22 +256,17 @@ def fit( **trainer_params_copy, ) if val_loader is None: - trainer.fit(self, train_loader) else: trainer.fit(self, train_loader, val_loader) - # TODO use best model from history self.load_from_checkpoint( - # self.checkpoint_callback.best_model_path) - def forward(self, x): + def forward(self, x) -> torch.Tensor: """ Forward pass of the model. - :param x: - :return: - """ - if not self.model_initialized: - self.initialize_model(x) + :param x: input data + :returns: predicted response + """ for i in range(len(self.fully_connected_layers) - 2): x = self.fully_connected_layers[i](x) x = self.batch_norm_layers[i](x) @@ -232,59 +279,64 @@ def forward(self, x): return x.squeeze() - def initialize_model(self, x): - """ - Initializes the model. - :param x: - :return: + def _forward_loss_and_log(self, x, y, log_as: str): """ - n_features = x.size(1) - self.fully_connected_layers.append(nn.Linear(n_features, self.n_units_per_layer[0])) - self.batch_norm_layers.append(nn.BatchNorm1d(self.n_units_per_layer[0])) - - for i in range(1, len(self.n_units_per_layer)): - self.fully_connected_layers.append(nn.Linear(self.n_units_per_layer[i - 1], self.n_units_per_layer[i])) - self.batch_norm_layers.append(nn.BatchNorm1d(self.n_units_per_layer[i])) - - self.fully_connected_layers.append(nn.Linear(self.n_units_per_layer[-1], 1)) - if self.dropout_prob is not None: - self.dropout_layer = nn.Dropout(p=self.dropout_prob) - self.model_initialized = True - - def force_initialize(self, dataloader): - """Force initialize the model by running a dummy forward pass.""" - for batch in dataloader: - x, _ = batch - self.forward(x) - break + Forward pass, calculates the loss, and logs the loss. - def _forward_loss_and_log(self, x, y, log_as: str): + :param x: input data + :param y: response + :param log_as: either train_loss or val_loss + :returns: loss + """ y_pred = self.forward(x) result = self.loss(y_pred, y) self.log(log_as, result, on_step=True, on_epoch=True, prog_bar=True) return result def training_step(self, batch): + """ + Overwrites the training step from the LightningModule. + + Does a forward pass, calculates the loss and logs the loss. + :param batch: batch of data + :returns: loss + """ x, y = batch return self._forward_loss_and_log(x, y, "train_loss") def validation_step(self, batch): + """ + Overwrites the validation step from the LightningModule. + + Does a forward pass, calculates the loss and logs the loss. 
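# --- Hedged sketch of the predict() pattern introduced above: reload the best checkpoint tracked by
# the ModelCheckpoint callback (if one exists), switch to eval mode, and run inference under
# no_grad on the model's device. The helper name below is illustrative, not drevalpy API. ---
import numpy as np
import torch


def predict_with_best(model, checkpoint_callback, x: np.ndarray) -> np.ndarray:
    if checkpoint_callback is not None and checkpoint_callback.best_model_path:
        model = type(model).load_from_checkpoint(checkpoint_callback.best_model_path)
    was_training = model.training
    model.eval()
    with torch.no_grad():
        y = model(torch.from_numpy(x).float().to(model.device))
    model.train(was_training)  # restore the previous mode
    return y.cpu().numpy()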
+ :param batch: batch of data + :returns: loss + """ x, y = batch return self._forward_loss_and_log(x, y, "val_loss") def predict(self, x: np.ndarray) -> np.ndarray: """ Predicts the response for the given input. - :param x: - :return: + + :param x: input data + :returns: predicted response """ - is_training = self.training - self.eval() + if hasattr(self, "checkpoint_callback") and self.checkpoint_callback is not None: + best_model = FeedForwardNetwork.load_from_checkpoint(self.checkpoint_callback.best_model_path) + else: + best_model = self + is_training = best_model.training + best_model.eval() with torch.no_grad(): - y_pred = self.forward(torch.from_numpy(x).float()) - self.train(is_training) + y_pred = best_model.forward(torch.from_numpy(x).float().to(best_model.device)) + best_model.train(is_training) return y_pred.cpu().detach().numpy() - def configure_optimizers(self): + def configure_optimizers(self) -> torch.optim.Optimizer: + """ + Overwrites the configure_optimizers from the LightningModule. + :returns: Adam optimizer + """ return torch.optim.Adam(self.parameters()) diff --git a/drevalpy/models/SuperFELTR/__init__.py b/drevalpy/models/SuperFELTR/__init__.py index e69de29..7451545 100644 --- a/drevalpy/models/SuperFELTR/__init__.py +++ b/drevalpy/models/SuperFELTR/__init__.py @@ -0,0 +1 @@ +"""Module for the SuperFELTR model, a regression adaptation of SuperFELT.""" diff --git a/drevalpy/models/SuperFELTR/hyperparameters.yaml b/drevalpy/models/SuperFELTR/hyperparameters.yaml index ddc0c27..cb2d2d1 100644 --- a/drevalpy/models/SuperFELTR/hyperparameters.yaml +++ b/drevalpy/models/SuperFELTR/hyperparameters.yaml @@ -21,5 +21,5 @@ SuperFELTR: GDSC1: 0.7 GDSC2: 0.7 Toy_Data: 0.6 - margin: 1 + margin: 1.0 learning_rate: 0.01 diff --git a/drevalpy/models/SuperFELTR/superfeltr.py b/drevalpy/models/SuperFELTR/superfeltr.py index fa14cd8..4bfffe5 100644 --- a/drevalpy/models/SuperFELTR/superfeltr.py +++ b/drevalpy/models/SuperFELTR/superfeltr.py @@ -1,51 +1,80 @@ """ Contains the SuperFELTR model. + +Regression extension of Super.FELT: supervised feature extraction learning using triplet loss for drug response +prediction with multi-omics data. +Very similar to MOLI. Differences: + + * In MOLI, encoders and the classifier were trained jointly. Super.FELT trains them independently + * MOLI was trained without feature selection (except for the Variance Threshold on the gene expression). + Super.FELT uses feature selection for all omics data. + +The input remains the same: somatic mutation, copy number variation and gene expression data. Original authors of SuperFELT: Park, Soh & Lee. (2021, 10.1186/s12859-021-04146-z) Code adapted from their Github: https://github.com/DMCB-GIST/Super.FELT and Hauptmann et al. 
(2023, 10.1186/s12859-023-05166-7) https://github.com/kramerlab/Multi-Omics_analysis """ -from typing import Optional +from typing import Any import numpy as np +import pytorch_lightning as pl from sklearn.feature_selection import VarianceThreshold from ...datasets.dataset import DrugResponseDataset, FeatureDataset -from ..drp_model import SingleDrugModel +from ..drp_model import DRPModel from ..MOLIR.utils import get_dimensions_of_omics_data, make_ranges -from ..utils import load_and_reduce_gene_features +from ..utils import get_multiomics_feature_dataset from .utils import SuperFELTEncoder, SuperFELTRegressor, train_superfeltr_model -class SuperFELTR(SingleDrugModel): - """ - Regression extension of Super.FELT: supervised feature extraction learning using triplet loss for drug response - prediction with multi-omics data. - Very similar to MOLI. Differences: - - In MOLI, encoders and the classifier were trained jointly. Super.FELT trains them independently - - MOLI was trained without feature selection (except for the Variance Threshold on the gene expression). - Super.FELT uses feature selection for all omics data. - The input remains the same: somatic mutation, copy number variation and gene expression data. - """ +class SuperFELTR(DRPModel): + """Regression extension of Super.FELT.""" + is_single_drug_model = True cell_line_views = ["gene_expression", "mutations", "copy_number_variation_gistic"] drug_views = [] early_stopping = True - model_name = "SuperFELTR" - def __init__(self): + def __init__(self) -> None: + """ + Initialization method for SuperFELTR Model. + + The encoders and the regressor are initialized to None because they are built later in the first training pass. + The hyperparameters are also initialized to an empty dict because they are initialized in build_model. The + ranges are initialized during training which is why here, they get dummy values. The best checkpoint is + determined after training. + """ super().__init__() - self.expr_encoder = None - self.mut_encoder = None - self.cnv_encoder = None - self.regressor = None - self.hyperparameters = None - self.ranges = None - self.best_checkpoint = None + # encoders and regressor are initialized to None because they are built later in the first training pass + self.expr_encoder: SuperFELTEncoder | None = None + self.mut_encoder: SuperFELTEncoder | None = None + self.cnv_encoder: SuperFELTEncoder | None = None + self.regressor: SuperFELTRegressor | None = None + # hyperparameters are initialized to None because they are initialized in build_model + self.hyperparameters: dict[str, Any] = dict() + # ranges are initialized later because they are initialized using the standard variation of the train + # response data which is only available when entering the training + self.ranges: tuple[float, float] = (0.0, 1.0) + # best checkpoint is determined after training + self.best_checkpoint: pl.callbacks.ModelCheckpoint | None = None - def build_model(self, hyperparameters): + @classmethod + def get_model_name(cls) -> str: + """ + Returns the model name. + + :returns: SuperFELTR + """ + return "SuperFELTR" + + def build_model(self, hyperparameters) -> None: """ Builds the model from hyperparameters. + + :param hyperparameters: dictionary containing the hyperparameters for the model. Contain mini_batch, + dropout_rate, weight_decay, out_dim_expr_encoder, out_dim_mutation_encoder, out_dim_cnv_encoder, epochs, + variance thresholds for gene expression, mutation, and copy number variation, margin, and learning rate. 
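# --- Sketch of the data-availability guards added around SuperFELTR training in this file: with no
# training data the model is skipped (NaN is predicted later), and with fewer samples than one
# mini-batch the randomly initialized network is kept instead of being trained. ---
def fit_or_skip(n_train_samples: int, mini_batch: int) -> str:
    if n_train_samples == 0:
        return "skip model, predict NaN"
    if n_train_samples < mini_batch:
        return "keep random initialization"
    return "train encoders and regressor"


print(fit_or_skip(0, 16), "|", fit_or_skip(5, 16), "|", fit_or_skip(500, 16))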
""" self.hyperparameters = hyperparameters @@ -53,70 +82,111 @@ def train( self, output: DrugResponseDataset, cell_line_input: FeatureDataset, - drug_input: Optional[FeatureDataset] = None, - output_earlystopping: Optional[DrugResponseDataset] = None, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, ) -> None: """ - Trains the model. - """ - cell_line_input = self.feature_selection(output, cell_line_input) - if self.early_stopping and len(output_earlystopping) < 2: - output_earlystopping = None - dim_gex, dim_mut, dim_cnv = get_dimensions_of_omics_data(cell_line_input) - self.ranges = make_ranges(output) - - # difference to MOLI: encoders and regressor are trained independently - # Create and train encoders - encoders = {} - encoder_dims = {"expression": dim_gex, "mutation": dim_mut, "copy_number_variation_gistic": dim_cnv} - for omic_type, dim in encoder_dims.items(): - encoder = SuperFELTEncoder( - input_size=dim, hpams=self.hyperparameters, omic_type=omic_type, ranges=self.ranges + Does feature selection, trains the encoders sequentially, and then trains the regressor. + + If there is not enough training data, the model is trained with random initialization, if there is no + training data at all, the model is skipped and later on, NA is predicted. + + :param output: training data associated with the response output + :param cell_line_input: cell line omics features + :param drug_input: not needed, as it is a single drug model + :param output_earlystopping: optional early stopping dataset + :raises ValueError: if drug_input is not None + """ + if drug_input is not None: + raise ValueError("SuperFELTR is a single drug model and does not require drug input.") + + if len(output) > 0: + cell_line_input = self._feature_selection(output, cell_line_input) + if output_earlystopping is not None and self.early_stopping and len(output_earlystopping) < 2: + output_earlystopping = None + dim_gex, dim_mut, dim_cnv = get_dimensions_of_omics_data(cell_line_input) + self.ranges = make_ranges(output) + + # difference to MOLI: encoders and regressor are trained independently + # Create and train encoders + encoders = {} + encoder_dims = {"expression": dim_gex, "mutation": dim_mut, "copy_number_variation_gistic": dim_cnv} + for omic_type, dim in encoder_dims.items(): + encoder = SuperFELTEncoder( + input_size=dim, hpams=self.hyperparameters, omic_type=omic_type, ranges=self.ranges + ) + if len(output) >= self.hyperparameters["mini_batch"]: + print(f"Training SuperFELTR Encoder for {omic_type} ... ") + best_checkpoint = train_superfeltr_model( + model=encoder, + hpams=self.hyperparameters, + output_train=output, + cell_line_input=cell_line_input, + output_earlystopping=output_earlystopping, + patience=5, + ) + encoders[omic_type] = SuperFELTEncoder.load_from_checkpoint(best_checkpoint.best_model_path) + else: + print( + f"Not enough training data provided for SuperFELTR Encoder for {omic_type}. Using random " + f"initialization." + ) + encoders[omic_type] = encoder + + self.expr_encoder, self.mut_encoder, self.cnv_encoder = ( + encoders["expression"], + encoders["mutation"], + encoders["copy_number_variation_gistic"], ) - print(f"Training SuperFELTR Encoder for {omic_type} ... 
") - best_checkpoint = train_superfeltr_model( - model=encoder, + + self.regressor = SuperFELTRegressor( + input_size=self.hyperparameters["out_dim_expr_encoder"] + + self.hyperparameters["out_dim_mutation_encoder"] + + self.hyperparameters["out_dim_cnv_encoder"], hpams=self.hyperparameters, - output_train=output, - cell_line_input=cell_line_input, - output_earlystopping=output_earlystopping, - patience=5, + encoders=(self.expr_encoder, self.mut_encoder, self.cnv_encoder), ) - encoders[omic_type] = SuperFELTEncoder.load_from_checkpoint(best_checkpoint.best_model_path) - - self.expr_encoder, self.mut_encoder, self.cnv_encoder = ( - encoders["expression"], - encoders["mutation"], - encoders["copy_number_variation_gistic"], - ) - - self.regressor = SuperFELTRegressor( - input_size=self.hyperparameters["out_dim_expr_encoder"] - + self.hyperparameters["out_dim_mutation_encoder"] - + self.hyperparameters["out_dim_cnv_encoder"], - hpams=self.hyperparameters, - encoders=(self.expr_encoder, self.mut_encoder, self.cnv_encoder), - ranges=self.ranges, - ) - self.best_checkpoint = train_superfeltr_model( - model=self.regressor, - hpams=self.hyperparameters, - output_train=output, - cell_line_input=cell_line_input, - output_earlystopping=output_earlystopping, - patience=5, - ) + if len(output) >= self.hyperparameters["mini_batch"]: + print("Training SuperFELTR Regressor ... ") + self.best_checkpoint = train_superfeltr_model( + model=self.regressor, + hpams=self.hyperparameters, + output_train=output, + cell_line_input=cell_line_input, + output_earlystopping=output_earlystopping, + patience=5, + ) + else: + print("Not enough training data provided for SuperFELTR Regressor. Using random initialization.") + self.best_checkpoint = None + else: + print("No training data provided, skipping model") + self.best_checkpoint = None + self.expr_encoder, self.mut_encoder, self.cnv_encoder, self.regressor = None, None, None, None def predict( self, - drug_ids: np.ndarray, cell_line_ids: np.ndarray, - drug_input: FeatureDataset = None, - cell_line_input: FeatureDataset = None, + drug_ids: np.ndarray, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, ) -> np.ndarray: """ Predicts the drug response. + + If there is no training data, NA is predicted. If there was not enough training data, predictions are made + with the randomly initialized model. + + :param cell_line_ids: cell line ids + :param drug_ids: drug ids + :param cell_line_input: cell line omics features + :param drug_input: drug omics features, not needed + :returns: predicted drug response + :raises ValueError: if drug_input is not None """ + if drug_input is not None: + raise ValueError("SuperFELTR is a single drug model and does not require drug input.") + input_data = self.get_feature_matrices( cell_line_ids=cell_line_ids, drug_ids=drug_ids, @@ -128,6 +198,12 @@ def predict( input_data["mutations"], input_data["copy_number_variation_gistic"], ) + if self.expr_encoder is None or self.mut_encoder is None or self.cnv_encoder is None or self.regressor is None: + print("No training data was available, predicting NA") + return np.array([np.nan] * len(cell_line_ids)) + if self.best_checkpoint is None: + print("Not enough training data provided for SuperFELTR Regressor. 
Predicting with random initialization.") + return self.regressor.predict(gene_expression, mutations, cnvs) best_regressor = SuperFELTRegressor.load_from_checkpoint( self.best_checkpoint.best_model_path, input_size=self.hyperparameters["out_dim_expr_encoder"] @@ -135,13 +211,16 @@ def predict( + self.hyperparameters["out_dim_cnv_encoder"], hpams=self.hyperparameters, encoders=(self.expr_encoder, self.mut_encoder, self.cnv_encoder), - ranges=self.ranges, ) return best_regressor.predict(gene_expression, mutations, cnvs) - def feature_selection(self, output: DrugResponseDataset, cell_line_input: FeatureDataset) -> FeatureDataset: + def _feature_selection(self, output: DrugResponseDataset, cell_line_input: FeatureDataset) -> FeatureDataset: """ - Feature selection for all omics data. + Feature selection for all omics data using the predefined variance thresholds. + + :param output: training data associated with the response output + :param cell_line_input: cell line omics features + :returns: cell line omics features with selected features """ thresholds = { "gene_expression": self.hyperparameters["expression_var_threshold"][output.dataset_name], @@ -156,26 +235,26 @@ def feature_selection(self, output: DrugResponseDataset, cell_line_input: Featur return cell_line_input def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: - all_data = load_and_reduce_gene_features( - feature_type="gene_expression", - gene_list=None, - data_path=data_path, - dataset_name=dataset_name, + """ + Loads the cell line features: gene expression, mutations, and copy number variation. + + :param data_path: path to the data, e.g., data/ + :param dataset_name: name of the dataset, e.g., GDSC2 + :returns: FeatureDataset containing the cell line gene expression features, mutations, and copy number variation + """ + feature_dataset = get_multiomics_feature_dataset( + data_path=data_path, dataset_name=dataset_name, gene_list=None, omics=self.cell_line_views ) # log transformation - all_data._apply(function=np.log, view="gene_expression") - feature_types = ["mutations", "copy_number_variation_gistic"] - # in Toy_Data, everything is already in the dataset - # TODO: implement this in models/utils.py - for feature in feature_types: - fd = load_and_reduce_gene_features( - feature_type=feature, gene_list=None, data_path=data_path, dataset_name=dataset_name - ) - all_data._add_features(fd) - return all_data + feature_dataset.apply(function=np.log, view="gene_expression") + return feature_dataset - def load(self, path): - raise NotImplementedError + def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset | None: + """ + Returns None, as drug features are not needed for SuperFELTR. - def save(self, path): - raise NotImplementedError + :param data_path: Path to the fingerprints, e.g., data/ + :param dataset_name: Name of the dataset + :returns: None + """ + return None diff --git a/drevalpy/models/SuperFELTR/utils.py b/drevalpy/models/SuperFELTR/utils.py index 10a2312..b5c23cf 100644 --- a/drevalpy/models/SuperFELTR/utils.py +++ b/drevalpy/models/SuperFELTR/utils.py @@ -1,6 +1,7 @@ +"""Utility functions for the SuperFELTR model.""" + import os -import random -from typing import Union +import secrets import numpy as np import pytorch_lightning as pl @@ -13,14 +14,37 @@ class SuperFELTEncoder(pl.LightningModule): + """ + SuperFELT encoder definition for a single omic type, i.e., gene expression, mutation, or copy number variation. 
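# --- Hedged sketch of the triplet objective used by SuperFELTEncoder._compute_loss: a "positive"
# sample has a similar response to the anchor and a "negative" sample a dissimilar one. The real
# index selection lives in generate_triplets_indices (not part of this diff), so the selection
# below is a deliberately naive stand-in; the margin corresponds to the margin hyperparameter. ---
import numpy as np
import torch
from torch import nn

response = np.array([0.1, 0.15, 0.9, 0.85, 0.5])
encoded = torch.randn(5, 8)

anchor = 0
positive_idx = int(np.argmin(np.abs(response - response[anchor])[1:]) + 1)  # most similar response
negative_idx = int(np.argmax(np.abs(response - response[anchor])))          # most dissimilar response

triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
loss = triplet_loss(encoded[[anchor]], encoded[[positive_idx]], encoded[[negative_idx]])
print(float(loss))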
+ + Very similar to MOLIEncoder, but with BatchNorm1d before ReLU. + """ + def __init__( - self, input_size: int, hpams: dict[str, Union[int, float]], omic_type: str, ranges: tuple[float, float] + self, input_size: int, hpams: dict[str, int | float | dict], omic_type: str, ranges: tuple[float, float] ) -> None: + """ + Initializes the SuperFELTEncoder. + + Save_hyperparameters is turned on to facilitate loading the model from a checkpoint. + :param input_size: determined by the variance threshold feature selection + :param hpams: hyperparameters for the model + :param omic_type: gene expression, mutation, or copy number variation + :param ranges: positive and negative ranges for the triplet loss + :raises ValueError: if the hyperparameters are not of the correct type + """ super().__init__() self.save_hyperparameters() + if ( + not isinstance(hpams["dropout_rate"], float) + or not isinstance(hpams["margin"], float) + or not isinstance(hpams["learning_rate"], float) + or not isinstance(hpams["weight_decay"], float) + ): + raise ValueError("dropout_rate, margin, learning_rate, and weight_decay must be floats!") self.omic_type = omic_type - output_size = self.get_output_size(hpams) + output_size = self._get_output_size(hpams) # only change vs MOLIEncoder: BatchNorm1d before ReLU self.encode = nn.Sequential( @@ -35,20 +59,56 @@ def __init__( self.positive_range, self.negative_range = ranges def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the SuperFELTEncoder. + + :param x: input tensor + :returns: encoded tensor + """ return self.encode(x) def configure_optimizers(self) -> torch.optim.Optimizer: + """ + Override the configure_optimizers method to use the Adam optimizer. + + :returns: Adam optimizer + """ optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) return optimizer - def get_output_size(self, hpams: dict[str, Union[int, float]]) -> int: - return { + def _get_output_size(self, hpams: dict[str, int | float | dict]) -> int: + """ + Get the output size of the encoder based on the omic type from the hyperparameters. + + :param hpams: hyperparameters for the model + :returns: output size of the encoder + :raises ValueError: if the output sizes are not of the correct type + """ + if ( + not isinstance(hpams["out_dim_expr_encoder"], int) + or not isinstance(hpams["out_dim_mutation_encoder"], int) + or not isinstance(hpams["out_dim_cnv_encoder"], int) + ): + raise ValueError("out_dim_expr_encoder, out_dim_mutation_encoder, and out_dim_cnv_encoder must be ints!") + + output_sizes = { "expression": hpams["out_dim_expr_encoder"], "mutation": hpams["out_dim_mutation_encoder"], "copy_number_variation_gistic": hpams["out_dim_cnv_encoder"], - }[self.omic_type] + } + output_size = output_sizes[self.omic_type] + return output_size - def get_omic_data(self, data_expr: torch.Tensor, data_mut: torch.Tensor, data_cnv: torch.Tensor) -> torch.Tensor: + def _get_omic_data(self, data_expr: torch.Tensor, data_mut: torch.Tensor, data_cnv: torch.Tensor) -> torch.Tensor: + """ + Get the omic data based on the omic type. 
+ + :param data_expr: expression data + :param data_mut: mutation data + :param data_cnv: copy number variation data + :returns: the omic data + :raises ValueError: if the omic type is not recognized + """ if self.omic_type == "expression": data = data_expr elif self.omic_type == "mutation": @@ -59,93 +119,167 @@ def get_omic_data(self, data_expr: torch.Tensor, data_mut: torch.Tensor, data_cn raise ValueError(f"omic_type {self.omic_type} not recognized.") return data - def compute_loss(self, encoded: torch.Tensor, response: torch.Tensor) -> torch.Tensor: + def _compute_loss(self, encoded: torch.Tensor, response: torch.Tensor) -> torch.Tensor: + """ + Computes the triplet loss. + + :param encoded: encoded data + :param response: response data + :returns: triplet loss + """ positive_indices, negative_indices = generate_triplets_indices( response.cpu().detach().numpy(), self.positive_range, self.negative_range ) triplet_loss = self.triplet_loss(encoded, encoded[positive_indices], encoded[negative_indices]) return triplet_loss - def training_step( - self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int - ) -> torch.Tensor: + def training_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Tensor: + """ + Override the training_step method to compute the triplet loss. + + :param batch: batch containing the omic data and response + :param batch_idx: index of the batch + :returns: triplet loss + """ data_expr, data_mut, data_cnv, response = batch - data = self.get_omic_data(data_expr, data_mut, data_cnv) + data = self._get_omic_data(data_expr, data_mut, data_cnv) encoded = self.encode(data) - triplet_loss = self.compute_loss(encoded, response) + triplet_loss = self._compute_loss(encoded, response) self.log("train_loss", triplet_loss, on_step=False, on_epoch=True, prog_bar=True) return triplet_loss - def validation_step( - self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int - ) -> torch.Tensor: + def validation_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Tensor: + """ + Override the validation_step method to compute the triplet loss. + + :param batch: batch containing the omic data and response + :param batch_idx: index of the batch + :returns: triplet loss + """ data_expr, data_mut, data_cnv, response = batch - data = self.get_omic_data(data_expr, data_mut, data_cnv) + data = self._get_omic_data(data_expr, data_mut, data_cnv) encoded = self.encode(data) - triplet_loss = self.compute_loss(encoded, response) + triplet_loss = self._compute_loss(encoded, response) self.log("val_loss", triplet_loss, on_step=False, on_epoch=True, prog_bar=True) return triplet_loss class SuperFELTRegressor(pl.LightningModule): + """ + SuperFELT regressor definition. + + Very similar to SuperFELT classifier, but with a regression loss and without the last sigmoid layer. + """ + def __init__( self, input_size: int, - hpams: dict[str, Union[int, float]], + hpams: dict[str, int | float | dict], encoders: tuple[SuperFELTEncoder, SuperFELTEncoder, SuperFELTEncoder], - ranges: tuple[float, float], ) -> None: + """ + Initializes the SuperFELTRegressor. + + The encoders are put in eval mode because they were fitted before. 
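# --- Sketch of the SuperFELTRegressor idea documented above: three encoders are trained first and
# then frozen (eval mode), and the regressor is a single linear layer (plus dropout) on top of
# their concatenated embeddings. The encoders here are toy stand-ins with assumed sizes. ---
import torch
from torch import nn

enc_expr = nn.Sequential(nn.Linear(100, 16), nn.ReLU()).eval()
enc_mut = nn.Sequential(nn.Linear(50, 8), nn.ReLU()).eval()
enc_cnv = nn.Sequential(nn.Linear(50, 8), nn.ReLU()).eval()
regressor = nn.Sequential(nn.Linear(16 + 8 + 8, 1), nn.Dropout(0.2))

x_expr, x_mut, x_cnv = torch.randn(4, 100), torch.randn(4, 50), torch.randn(4, 50)
concatenated = torch.cat((enc_expr(x_expr), enc_mut(x_mut), enc_cnv(x_cnv)), dim=1)
print(regressor(concatenated).squeeze().shape)  # torch.Size([4])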
+ + :param input_size: depends on the output of the encoders + :param hpams: hyperparameters for the model + :param encoders: the fitted encoders for the gene expression, mutation, and copy number variation data + :raises ValueError: if the hyperparameters are not of the correct type + """ super().__init__() + if ( + not isinstance(hpams["learning_rate"], float) + or not isinstance(hpams["weight_decay"], float) + or not isinstance(hpams["dropout_rate"], float) + ): + raise ValueError("learning_rate, weight_decay and dropout_rate must be floats!") self.regressor = nn.Sequential(nn.Linear(input_size, 1), nn.Dropout(hpams["dropout_rate"])) - self.lr = hpams["learning_rate"] - self.weight_decay = hpams["weight_decay"] + self.lr = float(hpams["learning_rate"]) + self.weight_decay = float(hpams["weight_decay"]) self.encoders = encoders - self.positive_ranges, self.negative_ranges = ranges # put the encoders in eval mode for encoder in self.encoders: encoder.eval() self.regression_loss = nn.MSELoss() def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the SuperFELTRegressor. + + :param x: input tensor + :returns: predicted response + """ return self.regressor(x) def predict(self, data_expr: np.ndarray, data_mut: np.ndarray, data_cnv: np.ndarray) -> np.ndarray: - data_expr, data_mut, data_cnv = map( + """ + Predicts the response for the given input. + + :param data_expr: expression data + :param data_mut: mutation data + :param data_cnv: copy number variation data + :returns: predicted response + """ + data_expr_tensor, data_mut_tensor, data_cnv_tensor = map( lambda data: torch.from_numpy(data).float().to(self.device), [data_expr, data_mut, data_cnv] ) self.eval() with torch.no_grad(): - encoded = self.encode_and_concatenate(data_expr, data_mut, data_cnv) + encoded = self._encode_and_concatenate(data_expr_tensor, data_mut_tensor, data_cnv_tensor) preds = self.regressor(encoded) return preds.squeeze().cpu().detach().numpy() def configure_optimizers(self) -> torch.optim.Optimizer: + """ + Override the configure_optimizers method to use the Adagrad optimizer. + + :returns: Adagrad optimizer + """ return torch.optim.Adagrad(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) - def encode_and_concatenate( + def _encode_and_concatenate( self, data_expr: torch.Tensor, data_mut: torch.Tensor, data_cnv: torch.Tensor ) -> torch.Tensor: + """ + Encodes the omic data and concatenates the encoded tensors. + + :param data_expr: expression data + :param data_mut: mutation data + :param data_cnv: copy number variation data + :returns: concatenated encoded tensor + """ encoded_expr = self.encoders[0].encode(data_expr) encoded_mut = self.encoders[1].encode(data_mut) encoded_cnv = self.encoders[2].encode(data_cnv) return torch.cat((encoded_expr, encoded_mut, encoded_cnv), dim=1) - def training_step( - self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int - ) -> torch.Tensor: + def training_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Tensor: + """ + Override the training_step method to compute the regression loss. 
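# --- Small sketch of the regressor's objective and optimizer as shown in this hunk: MSE between the
# squeezed predictions and the responses, optimized with Adagrad (the encoders use Adam). Sizes and
# hyperparameter values below are assumptions for illustration. ---
import torch
from torch import nn

regressor = nn.Linear(32, 1)
optimizer = torch.optim.Adagrad(regressor.parameters(), lr=0.01, weight_decay=1e-4)
loss_fn = nn.MSELoss()

encoded = torch.randn(4, 32)
response = torch.randn(4)
optimizer.zero_grad()
loss = loss_fn(regressor(encoded).squeeze(), response)
loss.backward()
optimizer.step()
print(float(loss))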
+ + :param batch: batch containing the omic data and response + :param batch_idx: index of the batch + :returns: regression loss + """ data_expr, data_mut, data_cnv, response = batch - encoded = self.encode_and_concatenate(data_expr, data_mut, data_cnv) + encoded = self._encode_and_concatenate(data_expr, data_mut, data_cnv) pred = self.regressor(encoded) loss = self.regression_loss(pred.squeeze(), response) self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True) return loss - def validation_step( - self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int - ) -> torch.Tensor: + def validation_step(self, batch: list[torch.Tensor], batch_idx: int) -> torch.Tensor: + """ + Override the validation_step method to compute the regression loss. + + :param batch: batch containing the omic data and response + :param batch_idx: index of the batch + :returns: regression loss + """ data_expr, data_mut, data_cnv, response = batch - encoded = self.encode_and_concatenate(data_expr, data_mut, data_cnv) + encoded = self._encode_and_concatenate(data_expr, data_mut, data_cnv) pred = self.regressor(encoded) loss = self.regression_loss(pred.squeeze(), response) self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True) @@ -153,13 +287,29 @@ def validation_step( def train_superfeltr_model( - model: Union[SuperFELTEncoder, SuperFELTRegressor], - hpams: dict[str, Union[int, float]], + model: SuperFELTEncoder | SuperFELTRegressor, + hpams: dict[str, int | float | dict], output_train: DrugResponseDataset, cell_line_input: FeatureDataset, - output_earlystopping: DrugResponseDataset, + output_earlystopping: DrugResponseDataset | None = None, patience: int = 5, ) -> pl.callbacks.ModelCheckpoint: + """ + Trains one encoder or the regressor. + + First, the dataset and loaders are created. Then, the model is trained with the Lightning trainer. + :param model: either one of the encoders or the regressor + :param hpams: hyperparameters for the model + :param output_train: response data for training + :param cell_line_input: cell line omics features + :param output_earlystopping: response data for early stopping + :param patience: for early stopping, defaults to 5 + :returns: checkpoint callback with the best model + :raises ValueError: if the epochs and mini_batch are not integers + """ + if not isinstance(hpams["epochs"], int) or not isinstance(hpams["mini_batch"], int): + raise ValueError("epochs and mini_batch must be integers!") + train_loader, val_loader = create_dataset_and_loaders( batch_size=hpams["mini_batch"], output_train=output_train, @@ -169,7 +319,7 @@ def train_superfeltr_model( monitor = "train_loss" if (val_loader is None) else "val_loss" early_stop_callback = EarlyStopping(monitor=monitor, mode="min", patience=patience) name = "version-" + "".join( - [random.choice("0123456789abcdef") for _ in range(20)] + [secrets.choice("0123456789abcdef") for _ in range(20)] ) # preventing conflicts of filenames checkpoint_callback = pl.callbacks.ModelCheckpoint( dirpath=None, diff --git a/drevalpy/models/__init__.py b/drevalpy/models/__init__.py index 3f97046..1f3bf96 100644 --- a/drevalpy/models/__init__.py +++ b/drevalpy/models/__init__.py @@ -1,6 +1,4 @@ -""" -Module containing all drug response prediction models. 
-""" +"""Module containing all drug response prediction models.""" __all__ = [ "NaivePredictor", @@ -20,25 +18,30 @@ "MULTI_DRUG_MODEL_FACTORY", "SINGLE_DRUG_MODEL_FACTORY", "MODEL_FACTORY", + "DIPKModel", ] from .baselines.multi_omics_random_forest import MultiOmicsRandomForest from .baselines.naive_pred import NaiveCellLineMeanPredictor, NaiveDrugMeanPredictor, NaivePredictor from .baselines.singledrug_random_forest import SingleDrugRandomForest from .baselines.sklearn_models import ElasticNetModel, GradientBoosting, RandomForest, SVMRegressor +from .DIPK.dipk import DIPKModel +from .drp_model import DRPModel from .MOLIR.molir import MOLIR -from .simple_neural_network.multiomics_neural_network import MultiOmicsNeuralNetwork -from .simple_neural_network.simple_neural_network import SimpleNeuralNetwork +from .SimpleNeuralNetwork.multiomics_neural_network import MultiOmicsNeuralNetwork +from .SimpleNeuralNetwork.simple_neural_network import SimpleNeuralNetwork from .SRMF.srmf import SRMF from .SuperFELTR.superfeltr import SuperFELTR -SINGLE_DRUG_MODEL_FACTORY = { +# SINGLE_DRUG_MODEL_FACTORY is used in the pipeline! +SINGLE_DRUG_MODEL_FACTORY: dict[str, type[DRPModel]] = { "SingleDrugRandomForest": SingleDrugRandomForest, "MOLIR": MOLIR, "SuperFELTR": SuperFELTR, } -MULTI_DRUG_MODEL_FACTORY = { +# MULTI_DRUG_MODEL_FACTORY is used in the pipeline! +MULTI_DRUG_MODEL_FACTORY: dict[str, type[DRPModel]] = { "NaivePredictor": NaivePredictor, "NaiveDrugMeanPredictor": NaiveDrugMeanPredictor, "NaiveCellLineMeanPredictor": NaiveCellLineMeanPredictor, @@ -50,7 +53,9 @@ "MultiOmicsRandomForest": MultiOmicsRandomForest, "GradientBoosting": GradientBoosting, "SRMF": SRMF, + "DIPK": DIPKModel, } +# MODEL_FACTORY is used in the pipeline! MODEL_FACTORY = MULTI_DRUG_MODEL_FACTORY.copy() MODEL_FACTORY.update(SINGLE_DRUG_MODEL_FACTORY) diff --git a/drevalpy/models/baselines/__init__.py b/drevalpy/models/baselines/__init__.py index 774320f..6d8f7ca 100644 --- a/drevalpy/models/baselines/__init__.py +++ b/drevalpy/models/baselines/__init__.py @@ -1,3 +1 @@ -""" -Module containing the baseline models. -""" +"""Module containing the baseline models.""" diff --git a/drevalpy/models/baselines/env.yml b/drevalpy/models/baselines/env.yml deleted file mode 100644 index e69de29..0000000 diff --git a/drevalpy/models/baselines/multi_omics_random_forest.py b/drevalpy/models/baselines/multi_omics_random_forest.py index a64b4da..fd8ed60 100644 --- a/drevalpy/models/baselines/multi_omics_random_forest.py +++ b/drevalpy/models/baselines/multi_omics_random_forest.py @@ -1,9 +1,6 @@ -""" -Contains the Multi-OMICS Random Forest model. -""" +"""Contains the Multi-OMICS Random Forest model.""" import numpy as np -from numpy.typing import ArrayLike from sklearn.decomposition import PCA from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset @@ -13,9 +10,7 @@ class MultiOmicsRandomForest(RandomForest): - """ - Multi-OMICS Random Forest model. - """ + """Multi-OMICS Random Forest model.""" cell_line_views = [ "gene_expression", @@ -23,15 +18,29 @@ class MultiOmicsRandomForest(RandomForest): "mutations", "copy_number_variation_gistic", ] - model_name = "MultiOmicsRandomForest" def __init__(self): + """ + Initializes the model. + + Sets the PCA to None, which is initialized in the build_model method. + """ super().__init__() self.pca = None + @classmethod + def get_model_name(cls) -> str: + """ + Returns the model name. 
+ + :returns: MultiOmicsRandomForest + """ + return "MultiOmicsRandomForest" + def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. + :param hyperparameters: Hyperparameters for the model. """ super().build_model(hyperparameters) @@ -43,22 +52,21 @@ def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureD :param data_path: data path e.g. data/ :param dataset_name: dataset name e.g. GDSC1 - :return: FeatureDataset containing the cell line omics features, filtered through the + :returns: FeatureDataset containing the cell line omics features, filtered through the drug target genes """ - return get_multiomics_feature_dataset(data_path=data_path, dataset_name=dataset_name) def train( self, output: DrugResponseDataset, cell_line_input: FeatureDataset, - drug_input: FeatureDataset = None, - output_earlystopping=None, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, ) -> None: """ - Trains the model: the number of features is the number of genes + the number of - fingerprints. + Trains the model: the number of features is the number of genes + the number of fingerprints. + :param output: training dataset containing the response output :param cell_line_input: training dataset containing the OMICs :param drug_input: training dataset containing fingerprints data @@ -99,13 +107,19 @@ def train( def predict( self, - drug_ids: ArrayLike, - cell_line_ids: ArrayLike, - drug_input: FeatureDataset = None, - cell_line_input: FeatureDataset = None, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, ) -> np.ndarray: """ Predicts the response for the given input. + + :param cell_line_ids: cell line ids + :param drug_ids: drug ids + :param cell_line_input: cell line input + :param drug_input: drug input + :returns: predicted response """ inputs = self.get_feature_matrices( cell_line_ids=cell_line_ids, diff --git a/drevalpy/models/baselines/naive_pred.py b/drevalpy/models/baselines/naive_pred.py index 04dd7a0..f43213b 100644 --- a/drevalpy/models/baselines/naive_pred.py +++ b/drevalpy/models/baselines/naive_pred.py @@ -1,12 +1,12 @@ """ -Describes the naive predictor models. The naive predictor models are simple models that predict -the mean of the response values. The NaivePredictor predicts the overall mean of the response, -the NaiveCellLineMeanPredictor predicts the mean of the response per cell line, and the -NaiveDrugMeanPredictor predicts the mean of the response per drug. +Implements the naive predictor models. + +The naive predictor models are simple models that predict the mean of the response values. The NaivePredictor +predicts the overall mean of the response, the NaiveCellLineMeanPredictor predicts the mean of the response per cell +line, and the NaiveDrugMeanPredictor predicts the mean of the response per drug. """ import numpy as np -from numpy.typing import ArrayLike from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset from drevalpy.models.drp_model import DRPModel @@ -14,30 +14,47 @@ class NaivePredictor(DRPModel): - """ - Naive predictor model that predicts the overall mean of the response. - """ + """Naive predictor model that predicts the overall mean of the response.""" - model_name = "NaivePredictor" cell_line_views = ["cell_line_id"] drug_views = ["drug_id"] def __init__(self): + """ + Initializes the model. 
+ + Sets the dataset mean to None, which is initialized in the train method. + """ super().__init__() self.dataset_mean = None + @classmethod + def get_model_name(cls) -> str: + """ + Returns the model name. + + :returns: NaivePredictor + """ + return "NaivePredictor" + def build_model(self, hyperparameters: dict): + """ + Builds the model from hyperparameters. Not needed for the NaivePredictor. + + :param hyperparameters: Hyperparameters for the model, not needed + """ pass def train( self, output: DrugResponseDataset, - cell_line_input=None, - drug_input=None, - output_earlystopping=None, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, ) -> None: """ Computes the overall mean of the output response values and saves them. + :param output: training dataset containing the response output :param cell_line_input: not needed :param drug_input: not needed @@ -47,13 +64,14 @@ def train( def predict( self, - drug_ids=None, - cell_line_ids: ArrayLike = None, - drug_input: FeatureDataset = None, - cell_line_input: FeatureDataset = None, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, ) -> np.ndarray: """ - Predicts the dataset mean for each drug-cell line combination + Predicts the dataset mean for each drug-cell line combination. + :param cell_line_ids: cell line ids :param drug_ids: not needed :param cell_line_input: not needed @@ -62,51 +80,78 @@ def predict( """ return np.full(cell_line_ids.shape[0], self.dataset_mean) - def save(self, path): - raise NotImplementedError("Naive predictor does not support saving yet ...") - - def load(self, path): - raise NotImplementedError("Naive predictor does not support loading yet ...") - def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """ + Loads the cell line features, in this case the cell line ids. + + :param data_path: path to the data + :param dataset_name: name of the dataset + :returns: FeatureDataset containing the cell line ids + """ return load_cl_ids_from_csv(data_path, dataset_name) def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """ + Loads the drug features, in this case the drug ids. + + :param data_path: path to the data + :param dataset_name: name of the dataset + :returns: FeatureDataset containing the drug ids + """ return load_drug_ids_from_csv(data_path, dataset_name) class NaiveDrugMeanPredictor(DRPModel): - """ - Naive predictor model that predicts the mean of the response per drug. - """ + """Naive predictor model that predicts the mean of the response per drug.""" - model_name = "NaiveDrugMeanPredictor" cell_line_views = ["cell_line_id"] drug_views = ["drug_id"] def __init__(self): + """ + Initializes the model. + + Drug means and dataset mean are set to None, which are initialized in the train method. + """ super().__init__() self.drug_means = None self.dataset_mean = None + @classmethod + def get_model_name(cls) -> str: + """ + Returns the model name. + + :returns: NaiveDrugMeanPredictor + """ + return "NaiveDrugMeanPredictor" + def build_model(self, hyperparameters: dict): + """ + Builds the model from hyperparameters. Not needed for the NaiveDrugMeanPredictor. 
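# --- Sketch of the NaiveDrugMeanPredictor logic documented above: remember the mean response per
# drug seen during training and fall back to the overall dataset mean for unseen drugs. Toy data
# below is assumed for illustration. ---
import numpy as np

drug_ids = np.array(["a", "a", "b", "b", "b"])
response = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

dataset_mean = float(response.mean())
drug_means = {d: float(response[drug_ids == d].mean()) for d in np.unique(drug_ids)}


def predict_drug(drug_id: str) -> float:
    return drug_means.get(drug_id, dataset_mean)


print(predict_drug("a"), predict_drug("unseen"))  # 1.5 3.0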
+ + :param hyperparameters: Hyperparameters for the model, not needed + """ pass def train( self, output: DrugResponseDataset, - cell_line_input=None, - drug_input: FeatureDataset = None, - output_earlystopping=None, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, ) -> None: """ - Computes the mean per drug. If - later on - the drug is not in the training set, - the overall mean is used. + Computes the mean per drug. If - later on - the drug is not in the training set, the overall mean is used. + :param output: training dataset containing the response output - :param drug_input: drug id :param cell_line_input: not needed + :param drug_input: drug id :param output_earlystopping: not needed + :raises ValueError: If drug_input is None """ + if drug_input is None: + raise ValueError("drug_input (drug_id) is required for the NaiveDrugMeanPredictor.") drug_ids = drug_input.get_feature_matrix(view="drug_id", identifiers=output.drug_ids) self.dataset_mean = np.mean(output.response) self.drug_means = {} @@ -119,26 +164,30 @@ def train( def predict( self, - drug_ids: ArrayLike, - cell_line_ids=None, - drug_input=None, - cell_line_input=None, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, ) -> np.ndarray: """ - Predicts the drug mean for each drug-cell line combination. If the drug is not in the - training set, the dataset mean is used. - :param drug_ids: drug ids + Predicts the drug mean for each drug-cell line combination. + + If the drug is not in the training set, the dataset mean is used. + :param cell_line_ids: not needed - :param drug_input: not needed + :param drug_ids: drug ids :param cell_line_input: not needed + :param drug_input: not needed :return: array of the same length as the input drug_id containing the drug mean """ return np.array([self.predict_drug(drug) for drug in drug_ids]) def predict_drug(self, drug_id: str): """ - Predicts the mean of the response for a given drug. If the drug is not in the training set, - the dataset mean is used. + Predicts the mean of the response for a given drug. + + If the drug is not in the training set, the dataset mean is used. + :param drug_id: ID of the drug :return: predicted response """ @@ -146,46 +195,71 @@ def predict_drug(self, drug_id: str): return self.drug_means[drug_id] return self.dataset_mean - def save(self, path): - raise NotImplementedError("Naive predictor does not support saving yet ...") - - def load(self, path): - raise NotImplementedError("Naive predictor does not support loading yet ...") - def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """ + Loads the cell line features, in this case the cell line ids. + + :param data_path: path to the data + :param dataset_name: name of the dataset + :returns: FeatureDataset containing the cell line ids + """ return load_cl_ids_from_csv(data_path, dataset_name) def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """ + Loads the drug features, in this case the drug ids. + + :param data_path: path to the data + :param dataset_name: name of the dataset + :returns: FeatureDataset containing the drug ids + """ return load_drug_ids_from_csv(data_path, dataset_name) class NaiveCellLineMeanPredictor(DRPModel): - """ - Naive predictor model that predicts the mean of the response per cell line. 
- """ + """Naive predictor model that predicts the mean of the response per cell line.""" - model_name = "NaiveCellLineMeanPredictor" cell_line_views = ["cell_line_id"] drug_views = ["drug_id"] def __init__(self): + """ + Initializes the model. + + Cell line means and dataset mean are set to None, which are initialized in the train method. + """ super().__init__() self.cell_line_means = None self.dataset_mean = None + @classmethod + def get_model_name(cls) -> str: + """ + Returns the model name. + + :returns: NaiveCellLineMeanPredictor + """ + return "NaiveCellLineMeanPredictor" + def build_model(self, hyperparameters: dict): + """ + Builds the model from hyperparameters. Not needed for the NaiveCellLineMeanPredictor. + + :param hyperparameters: not needed + """ pass def train( self, output: DrugResponseDataset, cell_line_input: FeatureDataset, - drug_input=None, - output_earlystopping=None, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, ) -> None: """ - Computes the mean per cell line. If - later on - the cell line is not in the training - set, the overall mean is used. + Computes the mean per cell line. + + If - later on - the cell line is not in the training set, the overall mean is used. :param output: training dataset containing the response output :param cell_line_input: cell line inputs :param drug_input: not needed @@ -205,26 +279,29 @@ def train( def predict( self, - drug_ids=None, - cell_line_ids: ArrayLike = None, - drug_input=None, - cell_line_input=None, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, ) -> np.ndarray: """ - Predicts the cell line mean for each drug-cell line combination. If the cell line is not - in the training set, the dataset mean is used. + Predicts the cell line mean for each drug-cell line combination. + + If the cell line is not in the training set, the dataset mean is used. + :param cell_line_ids: cell line ids :param drug_ids: not needed - :param drug_input: not needed :param cell_line_input: not needed + :param drug_input: not needed :return: array of the same length as the input cell_line_id containing the cell line mean """ return np.array([self.predict_cl(cl) for cl in cell_line_ids]) - def predict_cl(self, cl_id: str): + def predict_cl(self, cl_id: str) -> float: """ - Predicts the mean of the response for a given cell line. If the cell line is not in the - training set, the dataset mean is used. + Predicts the mean of the response for a given cell line. + + If the cell line is not in the training set, the dataset mean is used. :param cl_id: Cell line ID :return: predicted response """ @@ -232,14 +309,22 @@ def predict_cl(self, cl_id: str): return self.cell_line_means[cl_id] return self.dataset_mean - def save(self, path): - raise NotImplementedError("Naive predictor does not support saving yet ...") - - def load(self, path): - raise NotImplementedError("Naive predictor does not support loading yet ...") - def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """ + Loads the cell line features, in this case the cell line ids. + + :param data_path: path to the data + :param dataset_name: name of the dataset + :returns: FeatureDataset containing the cell line ids + """ return load_cl_ids_from_csv(data_path, dataset_name) def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """ + Loads the drug features, in this case the drug ids. 
+ + :param data_path: path to the data + :param dataset_name: name of the dataset + :returns: FeatureDataset containing the drug ids + """ return load_drug_ids_from_csv(data_path, dataset_name) diff --git a/drevalpy/models/baselines/singledrug_random_forest.py b/drevalpy/models/baselines/singledrug_random_forest.py index f1b8387..6849a84 100644 --- a/drevalpy/models/baselines/singledrug_random_forest.py +++ b/drevalpy/models/baselines/singledrug_random_forest.py @@ -1,34 +1,38 @@ """ -Contains the SingleDrugRandomForest class, which is a RandomForest model that uses only gene -expression dataset for drug response prediction and trains one model per drug. -""" +Contains the SingleDrugRandomForest class. -from typing import Optional +It is a RandomForest model that uses only gene expression dataset for drug response prediction and trains one model +per drug. +""" import numpy as np -from numpy.typing import ArrayLike - -from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset -from ..drp_model import SingleDrugModel +from ...datasets.dataset import DrugResponseDataset, FeatureDataset from .sklearn_models import RandomForest -class SingleDrugRandomForest(SingleDrugModel, RandomForest): - """ - SingleDrugRandomForest class. - """ +class SingleDrugRandomForest(RandomForest): + """SingleDrugRandomForest class.""" + is_single_drug_model = True drug_views = [] - model_name = "SingleDrugRandomForest" early_stopping = False + @classmethod + def get_model_name(cls) -> str: + """ + Returns the model name. + + :returns: SingleDrugRandomForest + """ + return "SingleDrugRandomForest" + def train( self, output: DrugResponseDataset, cell_line_input: FeatureDataset, - drug_input=None, - output_earlystopping=None, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, ) -> None: """ Trains the model; the number of features is the number of fingerprints. @@ -37,27 +41,48 @@ def train( :param cell_line_input: training dataset containing gene expression data :param drug_input: not needed :param output_earlystopping: not needed + :raises ValueError: if drug_input is not None """ - if drug_input is not None or output_earlystopping is not None: - raise ValueError("SingleDrugRandomForest does not support drug_input or " "output_earlystopping!") + if drug_input is not None: + raise ValueError("SingleDrugRandomForest does not support drug_input!") - x = self.get_concatenated_features( - cell_line_view="gene_expression", - drug_view=None, - cell_line_ids_output=output.cell_line_ids, - drug_ids_output=output.drug_ids, - cell_line_input=cell_line_input, - drug_input=None, - ) - self.model.fit(x, output.response) + if len(output) > 0: + x = self.get_concatenated_features( + cell_line_view="gene_expression", + drug_view=None, + cell_line_ids_output=output.cell_line_ids, + drug_ids_output=output.drug_ids, + cell_line_input=cell_line_input, + drug_input=None, + ) + self.model.fit(x, output.response) + else: + print("No training data provided, will predict NA.") + self.model = None def predict( self, - drug_ids: ArrayLike, - cell_line_ids: ArrayLike, - drug_input: Optional[FeatureDataset] = None, - cell_line_input: FeatureDataset = None, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, ) -> np.ndarray: + """ + Predicts the drug response for the given cell lines. 
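+
+        A usage sketch (``model``, ``test_dataset`` and ``cell_line_features`` are placeholder names for a
+        trained model, a test split and the loaded gene expression features; they are not defined here)::
+
+            predictions = model.predict(
+                cell_line_ids=test_dataset.cell_line_ids,
+                drug_ids=test_dataset.drug_ids,
+                cell_line_input=cell_line_features,
+            )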
+ + :param cell_line_ids: cell line ids + :param drug_ids: drug ids, not needed here + :param cell_line_input: cell line input + :param drug_input: drug input, not needed here + :returns: predicted drug response + :raises ValueError: if drug_input is not None + """ + if drug_input is not None: + raise ValueError("drug_input is not needed.") + + if self.model is None: + print("No training data was available, predicting NA.") + return np.array([np.nan] * len(cell_line_ids)) x = self.get_concatenated_features( cell_line_view="gene_expression", drug_view=None, diff --git a/drevalpy/models/baselines/sklearn_models.py b/drevalpy/models/baselines/sklearn_models.py index e6b3bd6..84a67e0 100644 --- a/drevalpy/models/baselines/sklearn_models.py +++ b/drevalpy/models/baselines/sklearn_models.py @@ -1,7 +1,8 @@ """Contains sklearn baseline models: ElasticNet, RandomForest, SVM.""" +from typing import Optional + import numpy as np -from numpy.typing import ArrayLike from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor from sklearn.linear_model import ElasticNet, Lasso, Ridge from sklearn.svm import SVR @@ -13,22 +14,35 @@ class SklearnModel(DRPModel): - """ - Parent class that contains the common methods for the sklearn models. - """ + """Parent class that contains the common methods for the sklearn models.""" cell_line_views = ["gene_expression"] drug_views = ["fingerprints"] def __init__(self): + """ + Initializes the model. + + Sets the model to None, which is initialized in the build_model method to the respective sklearn model. + """ super().__init__() self.model = None + @classmethod + def get_model_name(cls) -> str: + """ + Returns the model name. + + :raises NotImplementedError: If the method is not implemented in the child class. + """ + raise NotImplementedError("get_model_name method has to be implemented in the child class.") + def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. - :param hyperparameters: Custom hyperparameters for the model, have to be defined in the - child class. + + :param hyperparameters: Custom hyperparameters for the model, have to be defined in the child class. + :raises NotImplementedError: If the method is not implemented in the child class. """ raise NotImplementedError("build_model method has to be implemented in the child class.") @@ -36,17 +50,21 @@ def train( self, output: DrugResponseDataset, cell_line_input: FeatureDataset, - drug_input: FeatureDataset = None, - output_earlystopping=None, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, ) -> None: """ - Trains the model: the number of features is the number of genes + the number of - fingerprints. + Trains the model. + + The number of features is the number of genes + the number of fingerprints. :param output: training dataset containing the response output :param cell_line_input: training dataset containing gene expression data :param drug_input: training dataset containing fingerprints data :param output_earlystopping: not needed + :raises ValueError: If drug_input is None. 
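+
+        A training sketch (``response``, ``cell_line_features`` and ``drug_features`` are placeholders for a
+        DrugResponseDataset and the FeatureDatasets returned by load_cell_line_features()/load_drug_features())::
+
+            model = RandomForest()
+            model.build_model(model.get_hyperparameter_set()[0])
+            model.train(output=response, cell_line_input=cell_line_features, drug_input=drug_features)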
        """
+        if drug_input is None:
+            raise ValueError("drug_input (fingerprints) is required for the sklearn models.")

         x = self.get_concatenated_features(
             cell_line_view="gene_expression",
@@ -60,17 +78,24 @@ def train(

     def predict(
         self,
-        drug_ids: ArrayLike,
-        cell_line_ids: ArrayLike,
-        drug_input: FeatureDataset = None,
-        cell_line_input: FeatureDataset = None,
+        cell_line_ids: np.ndarray,
+        drug_ids: np.ndarray,
+        cell_line_input: FeatureDataset,
+        drug_input: FeatureDataset | None = None,
     ) -> np.ndarray:
         """
         Predicts the response for the given input.
-        :param gene_expression: gene expression data
-        :param fingerprints: fingerprints data
-        :return: predicted response
+
+        :param cell_line_ids: cell line ids
+        :param drug_ids: drug ids
+        :param cell_line_input: cell line input
+        :param drug_input: drug input
+        :returns: predicted drug response
+        :raises ValueError: If drug_input is None.
         """
+        if drug_input is None:
+            raise ValueError("drug_input (fingerprints) is required.")
+
         x = self.get_concatenated_features(
             cell_line_view="gene_expression",
             drug_view="fingerprints",
@@ -81,18 +106,13 @@ def predict(
         )
         return self.model.predict(x)

-    def save(self, path):
-        raise NotImplementedError("ElasticNetModel does not support saving yet ...")
-
-    def load(self, path):
-        raise NotImplementedError("ElasticNetModel does not support loading yet ...")
-
     def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset:
         """
         Loads the cell line features.
-        :param path: Path to the gene expression and landmark genes
-        :return: FeatureDataset containing the cell line gene expression features, filtered
-            through the landmark genes
+
+        :param data_path: Path to the gene expression and landmark genes
+        :param dataset_name: Name of the dataset
+        :returns: FeatureDataset containing the cell line gene expression features, filtered through the landmark genes
         """
         return load_and_reduce_gene_features(
             feature_type="gene_expression",
@@ -101,20 +121,33 @@ def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureD
             dataset_name=dataset_name,
         )

-    def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset:
+    def load_drug_features(self, data_path: str, dataset_name: str) -> Optional[FeatureDataset]:
+        """
+        Load the drug features, in this case the fingerprints.
+
+        :param data_path: Path to the data
+        :param dataset_name: Name of the dataset
+        :returns: FeatureDataset containing the drug fingerprints
+        """
         return load_drug_fingerprint_features(data_path, dataset_name)


 class ElasticNetModel(SklearnModel):
-    """
-    ElasticNet model for drug response prediction.
-    """
+    """ElasticNet model for drug response prediction."""

-    model_name = "ElasticNet"
+    @classmethod
+    def get_model_name(cls) -> str:
+        """
+        Returns the model name.
+
+        :returns: ElasticNet
+        """
+        return "ElasticNet"

     def build_model(self, hyperparameters: dict):
         """
         Builds the ElasticNet model from hyperparameters.
+
         :param hyperparameters: Contains L1 ratio and alpha.
         """
         if hyperparameters["l1_ratio"] == 0.0:
@@ -129,16 +162,23 @@ def build_model(self, hyperparameters: dict):

 class RandomForest(SklearnModel):
-    """
-    RandomForest model for drug response prediction.
-    """
+    """RandomForest model for drug response prediction."""

-    model_name = "RandomForest"
+    @classmethod
+    def get_model_name(cls) -> str:
+        """
+        Returns the model name.
+
+        :returns: RandomForest
+        """
+        return "RandomForest"

     def build_model(self, hyperparameters: dict):
         """
         Builds the model from hyperparameters.
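+
+        A hypothetical hyperparameter dictionary (the values are illustrative, not the shipped defaults)::
+
+            {"n_estimators": 100, "criterion": "squared_error", "max_depth": "None", "max_samples": 0.8, "n_jobs": -1}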
- :param hyperparameters: Hyperparameters for the model. + + :param hyperparameters: Hyperparameters for the model. Contains n_estimators, criterion, max_samples, + and n_jobs. """ if hyperparameters["max_depth"] == "None": hyperparameters["max_depth"] = None @@ -151,16 +191,22 @@ def build_model(self, hyperparameters: dict): class SVMRegressor(SklearnModel): - """ - SVM model for drug response prediction. - """ + """SVM model for drug response prediction.""" - model_name = "SVR" + @classmethod + def get_model_name(cls) -> str: + """ + Returns the model name. + + :returns: SVR (Support Vector Regressor) + """ + return "SVR" def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. - :param hyperparameters: Hyperparameters for the model. + + :param hyperparameters: Hyperparameters for the model. Contains kernel, C, epsilon, and max_iter. """ self.model = SVR( kernel=hyperparameters["kernel"], @@ -171,16 +217,23 @@ def build_model(self, hyperparameters: dict): class GradientBoosting(SklearnModel): - """ - Gradient Boosting model for drug response prediction. - """ + """Gradient Boosting model for drug response prediction.""" - model_name = "GradientBoosting" + @classmethod + def get_model_name(cls) -> str: + """ + Returns the model name. + + :returns: GradientBoosting + """ + return "GradientBoosting" def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. - :param hyperparameters: Hyperparameters for the model. + + :param hyperparameters: Hyperparameters for the model. Contains n_estimators, learning_rate, max_depth, + and subsample """ if hyperparameters["max_depth"] == "None": hyperparameters["max_depth"] = None diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index bf7efdc..ad64238 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -1,57 +1,67 @@ """ -Contains the DRPModel class, which is an abstract wrapper class for drug response prediction -models, the SingleDrugModel class, which is an abstract wrapper class for single drug models and -CompositeDrugModel class, which transforms multiple separate single drug response prediction models -into a global model by applying a separate model for each drug. +Contains the DRPModel class. + +The DRPModel class is an abstract wrapper class for drug response prediction models. + + """ import inspect import os -import warnings from abc import ABC, abstractmethod from typing import Any, Optional import numpy as np import yaml -from numpy.typing import ArrayLike from sklearn.model_selection import ParameterGrid from ..datasets.dataset import DrugResponseDataset, FeatureDataset +from ..pipeline_function import pipeline_function class DRPModel(ABC): """ Abstract wrapper class for drug response prediction models. + + The DRPModel class is an abstract wrapper class for drug response prediction models. + It has a boolean attribute is_single_drug_model indicating whether it is a single drug model and a boolean + attribute early_stopping indicating whether early stopping is used. """ + # Used in the pipeline! early_stopping = False + # Then, the model is trained per drug + is_single_drug_model = False + @classmethod @abstractmethod - def __init__(self, *args, **kwargs): + @pipeline_function + def get_model_name(cls) -> str: """ - Creates an instance of a drug response prediction model. 
- :param model_name: model name for displaying results - :param args: optional arguments - :param kwargs: optional keyword arguments + Returns the name of the model. + + :return: model name """ @classmethod - def get_hyperparameter_set(cls, hyperparameter_file: Optional[str] = None): + @pipeline_function + def get_hyperparameter_set(cls) -> list[dict[str, Any]]: """ - Loads the hyperparameters from a yaml file. - :param hyperparameter_file: yaml file containing the hyperparameters - :return: + Loads the hyperparameters from a yaml file which is located in the same directory as the model. + + :returns: list of hyperparameter sets + :raises ValueError: if the hyperparameters are not in the correct format + :raises KeyError: if the model is not found in the hyperparameters file """ - if hyperparameter_file is None: - hyperparameter_file = os.path.join(os.path.dirname(inspect.getfile(cls)), "hyperparameters.yaml") + hyperparameter_file = os.path.join(os.path.dirname(inspect.getfile(cls)), "hyperparameters.yaml") with open(hyperparameter_file, encoding="utf-8") as f: try: - hpams = yaml.safe_load(f)[cls.model_name] + hpams = yaml.safe_load(f)[cls.get_model_name()] except yaml.YAMLError as exc: raise ValueError(f"Error in hyperparameters.yaml: {exc}") from exc except KeyError as key_exc: - raise KeyError(f"Model {cls.model_name} not found in hyperparameters.yaml") from key_exc + raise KeyError(f"Model {cls.get_model_name()} not found in hyperparameters.yaml") from key_exc if hpams is None: return [{}] @@ -64,33 +74,34 @@ def get_hyperparameter_set(cls, hyperparameter_file: Optional[str] = None): @property @abstractmethod - def model_name(self): - """ - Returns the model name. - :return: model name - """ - - @property - @abstractmethod - def cell_line_views(self): + def cell_line_views(self) -> list[str]: """ Returns the sources the model needs as input for describing the cell line. + :return: cell line views, e.g., ["methylation", "gene_expression", "mirna_expression", - "mutation"] + "mutation"]. If the model does not use cell line features, return an empty list. """ @property @abstractmethod - def drug_views(self): + def drug_views(self) -> list[str]: """ Returns the sources the model needs as input for describing the drug. - :return: drug views, e.g., ["descriptors", "fingerprints", "targets"] + + :return: drug views, e.g., ["descriptors", "fingerprints", "targets"]. If the model does not use drug features, + return an empty list. """ @abstractmethod - def build_model(self, hyperparameters: dict[str, Any]): + def build_model(self, hyperparameters: dict[str, Any]) -> None: """ Builds the model, for models that use hyperparameters. + + :param hyperparameters: hyperparameters for the model + + Example:: + + self.model = ElasticNet(alpha=hyperparameters["alpha"], l1_ratio=hyperparameters["l1_ratio"]) """ @abstractmethod @@ -98,78 +109,101 @@ def train( self, output: DrugResponseDataset, cell_line_input: FeatureDataset, - drug_input: Optional[FeatureDataset] = None, - output_earlystopping: Optional[DrugResponseDataset] = None, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, ) -> None: """ Trains the model. 
+ :param output: training data associated with the response output - :param cell_line_input: input associated with the cell line - :param drug_input: input associated with the drug + :param cell_line_input: input associated with the cell line, required for all models + :param drug_input: input associated with the drug, optional because single drug models do not use drug features :param output_earlystopping: optional early stopping dataset """ @abstractmethod def predict( self, - drug_ids: ArrayLike, - cell_line_ids: ArrayLike, - drug_input: FeatureDataset = None, - cell_line_input: FeatureDataset = None, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, ) -> np.ndarray: """ Predicts the response for the given input. + :param drug_ids: list of drug ids, also used for single drug models, there it is just an array containing the + same drug id + :param cell_line_ids: list of cell line ids + :param cell_line_input: input associated with the cell line, required for all models + :param drug_input: input associated with the drug, optional because single drug models do not use drug features + :returns: predicted response """ @abstractmethod - def save(self, path): - """ - Saves the model. - - :param path: path to save the model + def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: """ + Load the cell line features before the train/predict method is called. - @abstractmethod - def load(self, path): - """ - Loads the model. + Required to implement for all models. Could, e.g., call get_multiomics_feature_dataset() or + load_and_reduce_gene_features() from models/utils.py. - :param path: path to load the model + :param data_path: path to the data, e.g., data/ + :param dataset_name: name of the dataset, e.g., "GDSC2" + :returns: FeatureDataset with the cell line features """ @abstractmethod - def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: - """ - :return: FeatureDataset + def load_drug_features(self, data_path: str, dataset_name: str) -> Optional[FeatureDataset]: """ + Load the drug features before the train/predict method is called. - @abstractmethod - def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: - """ - Load the drug features. + Required to implement for all models that use drug features. Could, e.g., + call load_drug_fingerprint_features() or load_drug_ids_from_csv() from models/utils.py. + + For single drug models, this method can return None. - :return: FeatureDataset + :param data_path: path to the data, e.g., data/ + :param dataset_name: name of the dataset, e.g., "GDSC2" + :returns: FeatureDataset or None """ def get_concatenated_features( self, - cell_line_view: str, + cell_line_view: Optional[str], drug_view: Optional[str], - cell_line_ids_output: ArrayLike, - drug_ids_output: ArrayLike, + cell_line_ids_output: np.ndarray, + drug_ids_output: np.ndarray, cell_line_input: Optional[FeatureDataset], drug_input: Optional[FeatureDataset], - ): + ) -> np.ndarray: """ - Concatenates the features for the given cell line and drug view. - :param cell_line_view: - :param drug_view: - :param cell_line_ids_output: - :param drug_ids_output: - :param cell_line_input: - :param drug_input: - :return: X, the feature matrix needed for, e.g., sklearn models + Concatenates the features to an input matrix X for the given cell line and drug views. + + :param cell_line_view: gene expression, methylation, etc. 
+ :param drug_view: ids, fingerprints, etc. + :param cell_line_ids_output: cell line ids + :param drug_ids_output: drug ids + :param cell_line_input: input associated with the cell line + :param drug_input: input associated with the drug + :returns: X, the feature matrix needed for, e.g., sklearn models + :raises ValueError: if no features are provided + + This can, e.g., be done in the training method to produce a large input feature matrix for the model where + the rows are the samples and the columns are the cell line and drug features concatenated. This method is an + alternative to using DataLoaders. It is used for models operating on the whole input matrix at once. + + Example:: + + x = self.get_concatenated_features( + cell_line_view="gene_expression", + drug_view="fingerprints", + cell_line_ids_output=output.cell_line_ids, + drug_ids_output=output.drug_ids, + cell_line_input=cell_line_input, + drug_input=drug_input, + ) + self.model.fit(x, output.response) """ inputs = self.get_feature_matrices( cell_line_ids=cell_line_ids_output, @@ -177,8 +211,8 @@ def get_concatenated_features( cell_line_input=cell_line_input, drug_input=drug_input, ) - cell_line_features = inputs.get(cell_line_view) - drug_features = inputs.get(drug_view) + cell_line_features = None if cell_line_view is None else inputs.get(cell_line_view) + drug_features = None if drug_view is None else inputs.get(drug_view) if cell_line_features is not None and drug_features is not None: x = np.concatenate((cell_line_features, drug_features), axis=1) @@ -192,24 +226,71 @@ def get_concatenated_features( def get_feature_matrices( self, - cell_line_ids: ArrayLike, - drug_ids: ArrayLike, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, cell_line_input: Optional[FeatureDataset], drug_input: Optional[FeatureDataset], - ): + ) -> dict[str, np.ndarray]: """ - Returns the feature matrices for the given cell line and drug ids by retrieving the - correct views. - :param cell_line_ids: - :param drug_ids: - :param cell_line_input: - :param drug_input: - :return: + Returns the feature matrices for the given cell line and drug ids by retrieving the correct views. 
+ + :param cell_line_ids: cell line identifiers + :param drug_ids: drug identifiers + :param cell_line_input: cell line omics features + :param drug_input: drug omics features + :returns: dictionary with the feature matrices + :raises ValueError: if the input does not contain the correct views + + This can e.g., done to produce the input for the predict() method for deep learning models: + Example:: + + input_data = self.get_feature_matrices( + cell_line_ids=cell_line_ids, + drug_ids=drug_ids, + cell_line_input=cell_line_input, + drug_input=drug_input, + ) + ( + gene_expression, + mutations, + cnvs + ) = ( + input_data["gene_expression"], + input_data["mutations"], + input_data["copy_number_variation_gistic"] + ) + return self.model.predict(gene_expression, mutations, cnvs) + + Or to produce separate inputs for the train()/predict() method for other models if the model does not operate + on the concatenated input matrix:: + + inputs = self.get_feature_matrices( + cell_line_ids=output.cell_line_ids, + drug_ids=output.drug_ids, + cell_line_input=cell_line_input, + drug_input=drug_input, + ) + ( + gene_expression, + methylation, + mutations, + copy_number_variation_gistic, + fingerprints, + ) = ( + inputs["gene_expression"], + inputs["methylation"], + inputs["mutations"], + inputs["copy_number_variation_gistic"], + inputs["fingerprints"], + ) + self.model.fit( + gene_expression, methylation, mutations, copy_number_variation_gistic, fingerprints, output.response + ) """ cell_line_feature_matrices = {} if cell_line_input is not None: for cell_line_view in self.cell_line_views: - if cell_line_view not in cell_line_input.get_view_names(): + if cell_line_view not in cell_line_input.view_names: raise ValueError(f"Cell line input does not contain view {cell_line_view}") cell_line_feature_matrices[cell_line_view] = cell_line_input.get_feature_matrix( view=cell_line_view, identifiers=cell_line_ids @@ -217,130 +298,8 @@ def get_feature_matrices( drug_feature_matrices = {} if drug_input is not None: for drug_view in self.drug_views: - if drug_view not in drug_input.get_view_names(): + if drug_view not in drug_input.view_names: raise ValueError(f"Drug input does not contain view {drug_view}") drug_feature_matrices[drug_view] = drug_input.get_feature_matrix(view=drug_view, identifiers=drug_ids) return {**cell_line_feature_matrices, **drug_feature_matrices} - - -class SingleDrugModel(DRPModel, ABC): - """ - Abstract wrapper class for single drug response prediction models. - """ - - early_stopping = False - drug_views = [] - - def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: - return None - - -class CompositeDrugModel(DRPModel): - """ - Transforms multiple separate single drug response prediction models into a global model by - applying a seperate model for each drug. - """ - - cell_line_views = None - drug_views = [] - model_name = "CompositeDrugModel" - - def __init__(self, base_model: type[DRPModel], *args, **kwargs): - """ - Creates an instance of a single drug response prediction model. - :param model_name: model name for displaying results - """ - super().__init__(*args, **kwargs) - self.models = {} - self.base_model = base_model - self.cell_line_views = base_model.cell_line_views - self.model_name = base_model.model_name - self.early_stopping = base_model.early_stopping - - def build_model(self, hyperparameters: dict[str, Any]): - """ - Builds the model. 
- """ - for drug in hyperparameters: - self.models[drug] = self.base_model() - self.models[drug].drug_views = self.drug_views - self.models[drug].build_model(hyperparameters[drug]) - - def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: - return list(self.models.values())[0].load_cell_line_features(data_path=data_path, dataset_name=dataset_name) - - def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: - return None - - def train( - self, - output: DrugResponseDataset, - cell_line_input: FeatureDataset, - drug_input=None, - output_earlystopping: Optional[DrugResponseDataset] = None, - ) -> None: - """ - Trains the model. - :param output: Training data associated with the response output - :param cell_line_input: Input associated with the cell line - :param drug_input: Not needed for the single drug models - :param output_earlystopping: Optional. Training data associated with the early stopping - output - """ - drugs = np.unique(output.drug_ids) - for i, drug in enumerate(drugs): - if drug not in self.models: - raise AssertionError( - f"Drug {drug} not in models. Maybe the CompositeDrugModel was not built or drug " - f"missing from train data." - ) - print(f"Training model for drug {drug} ({i+1}/{len(drugs)})") - output_mask = output.drug_ids == drug - output_drug = output.copy() - output_drug.mask(output_mask) - output_earlystopping_drug = None - if output_earlystopping is not None: - output_earlystopping_mask = output_earlystopping.drug_ids == drug - output_earlystopping_drug = output_earlystopping.copy() - output_earlystopping_drug.mask(output_earlystopping_mask) - - self.models[drug].train( - output=output_drug, - cell_line_input=cell_line_input, - output_earlystopping=output_earlystopping_drug, - ) - - def predict( - self, - drug_ids: list[str], - cell_line_ids: list[str], - drug_input=None, - cell_line_input: FeatureDataset = None, - ) -> np.ndarray: - """ - Predicts the response for the given input. - :param drug_ids: list of drug ids - :param cell_line_ids: list of cell line ids - :param cell_line_input: input associated with the cell line - :param drug_input: not needed for the single drug models - :return: predicted response - """ - prediction = np.zeros_like(drug_ids, dtype=float) - for drug in np.unique(drug_ids): - mask = drug_ids == drug - if drug not in self.models: - prediction[mask] = np.nan - else: - prediction[mask] = self.models[drug].predict( - drug_ids=drug, - cell_line_ids=cell_line_ids[mask], - cell_line_input=cell_line_input, - ) - if np.any(np.isnan(prediction)): - warnings.warn( - "SingleDRPModel Warning: Some drugs were not in the training set. Prediction is " - "NaN. 
Maybe a SingleDRPModel was used in an LDO setting.", - stacklevel=2, - ) - return prediction diff --git a/drevalpy/models/simple_neural_network/__init__.py b/drevalpy/models/simple_neural_network/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/drevalpy/models/simple_neural_network/env.yml b/drevalpy/models/simple_neural_network/env.yml deleted file mode 100644 index 7185bf5..0000000 --- a/drevalpy/models/simple_neural_network/env.yml +++ /dev/null @@ -1,185 +0,0 @@ -name: drp -channels: - - conda-forge -dependencies: - - appnope=0.1.4 - - asttokens=2.4.1 - - brotli=1.1.0 - - brotli-bin=1.1.0 - - brotli-python=1.1.0 - - bzip2=1.0.8 - - ca-certificates=2024.2.2 - - certifi=2024.2.2 - - charset-normalizer=3.3.2 - - colorama=0.4.6 - - comm=0.2.1 - - contourpy=1.2.0 - - cycler=0.12.1 - - debugpy=1.8.1 - - decorator=5.1.1 - - exceptiongroup=1.2.0 - - executing=2.0.1 - - filelock=3.13.1 - - flaky=3.8.1 - - fonttools=4.50.0 - - freetype=2.12.1 - - fsspec=2024.2.0 - - gmp=6.3.0 - - gmpy2=2.1.2 - - idna=3.6 - - importlib-metadata=7.0.1 - - importlib_metadata=7.0.1 - - iniconfig=2.0.0 - - ipykernel=6.29.3 - - ipython=8.22.1 - - jedi=0.19.1 - - jinja2=3.1.3 - - joblib=1.3.2 - - jupyter_client=8.6.0 - - jupyter_core=5.7.1 - - kiwisolver=1.4.5 - - lcms2=2.16 - - lerc=4.0.0 - - libabseil=20230802.1 - - libblas=3.9.0 - - libbrotlicommon=1.1.0 - - libbrotlidec=1.1.0 - - libbrotlienc=1.1.0 - - libcblas=3.9.0 - - libcxx=16.0.6 - - libdeflate=1.20 - - libffi=3.4.2 - - libgfortran=5.0.0 - - libgfortran5=13.2.0 - - libjpeg-turbo=3.0.0 - - liblapack=3.9.0 - - libopenblas=0.3.26 - - libpng=1.6.43 - - libprotobuf=4.25.1 - - libsodium=1.0.18 - - libsqlite=3.45.1 - - libtiff=4.6.0 - - libtorch=2.1.2 - - libuv=1.47.0 - - libwebp-base=1.3.2 - - libxcb=1.15 - - libzlib=1.2.13 - - lightning=2.2.0.post0 - - lightning-utilities=0.10.1 - - littleutils=0.2.2 - - llvm-openmp=17.0.6 - - markupsafe=2.1.5 - - matplotlib-base=3.8.3 - - matplotlib-inline=0.1.6 - - mpc=1.3.1 - - mpfr=4.2.1 - - mpmath=1.3.0 - - munkres=1.1.4 - - ncurses=6.4 - - nest-asyncio=1.6.0 - - networkx=3.2.1 - - nomkl=1.0 - - numpy=1.26.4 - - openjpeg=2.5.2 - - openssl=3.2.1 - - outdated=0.2.2 - - packaging=23.2 - - pandas=2.2.1 - - pandas-flavor=0.6.0 - - parso=0.8.3 - - patsy=0.5.6 - - pexpect=4.9.0 - - pickleshare=0.7.5 - - pillow=10.2.0 - - pingouin=0.5.4 - - pip=24.0 - - platformdirs=4.2.0 - - pluggy=1.4.0 - - prompt-toolkit=3.0.42 - - psutil=5.9.8 - - pthread-stubs=0.4 - - ptyprocess=0.7.0 - - pure_eval=0.2.2 - - pygments=2.17.2 - - pyparsing=3.1.2 - - pysocks=1.7.1 - - pytest=8.1.1 - - python=3.10.13 - - python-dateutil=2.8.2 - - python-tzdata=2024.1 - - python_abi=3.10 - - pytorch=2.1.2 - - pytorch-lightning=2.1.3 - - pytz=2024.1 - - pyyaml=6.0.1 - - pyzmq=25.1.2 - - readline=8.2 - - requests=2.31.0 - - scikit-learn=1.4.1.post1 - - scipy=1.12.0 - - seaborn=0.13.2 - - seaborn-base=0.13.2 - - setuptools=69.1.1 - - six=1.16.0 - - sleef=3.5.1 - - stack_data=0.6.2 - - statsmodels=0.14.1 - - sympy=1.12 - - tabulate=0.9.0 - - threadpoolctl=3.3.0 - - tk=8.6.13 - - tomli=2.0.1 - - torchmetrics=1.2.1 - - tornado=6.4 - - tqdm=4.66.2 - - traitlets=5.14.1 - - typing-extensions=4.10.0 - - typing_extensions=4.10.0 - - tzdata=2024a - - unicodedata2=15.1.0 - - urllib3=2.2.1 - - wcwidth=0.2.13 - - wheel=0.42.0 - - xarray=2024.2.0 - - xorg-libxau=1.0.11 - - xorg-libxdmcp=1.1.3 - - xz=5.2.6 - - yaml=0.2.5 - - zeromq=4.3.5 - - zipp=3.17.0 - - zstd=1.5.5 - - pip: - - aiosignal==1.3.1 - - attrs==23.2.0 - - black==24.2.0 - - click==8.1.7 - - 
dill==0.3.8 - - diskcache==5.6.3 - - fastprogress==1.0.3 - - frozenlist==1.4.1 - - gensim==4.3.2 - - importlib-resources==6.4.0 - - jsonschema==4.21.1 - - jsonschema-specifications==2023.12.1 - - msgpack==1.0.7 - - mypy-extensions==1.0.0 - - pathspec==0.12.1 - - plotly==5.20.0 - - protobuf==4.25.3 - - pubchempy==1.0.4 - - pyarrow==15.0.0 - - pyfaidx==0.8.1.1 - - pytoda==1.1.4 - - ray==2.9.3 - - rdkit-pypi==2022.9.5 - - referencing==0.33.0 - - rpds-py==0.18.0 - - selfies==2.1.1 - - smart-open==7.0.4 - - smilespe==0.0.3 - - tenacity==8.2.3 - - tensorboardx==2.6.2.2 - - torch==2.2.1 - - unidecode==1.3.8 - - upfp==0.0.5 - - wrapt==1.16.0 diff --git a/drevalpy/models/simple_neural_network/env_linux.yml b/drevalpy/models/simple_neural_network/env_linux.yml deleted file mode 100644 index 81b037a..0000000 --- a/drevalpy/models/simple_neural_network/env_linux.yml +++ /dev/null @@ -1,211 +0,0 @@ -name: drp -channels: - - conda-forge - - bioconda - - defaults -dependencies: - - _libgcc_mutex=0.1 - - _openmp_mutex=4.5 - - appnope=0.1.4 - - asttokens=2.4.1 - - brotli=1.1.0 - - brotli-bin=1.1.0 - - brotli-python=1.1.0 - - bzip2=1.0.8 - - ca-certificates=2024.2.2 - - certifi=2024.2.2 - - charset-normalizer=3.3.2 - - colorama=0.4.6 - - comm=0.2.1 - - contourpy=1.2.0 - - cycler=0.12.1 - - debugpy=1.8.1 - - decorator=5.1.1 - - exceptiongroup=1.2.0 - - executing=2.0.1 - - filelock=3.13.1 - - flaky=3.8.1 - - fonttools=4.50.0 - - freetype=2.12.1 - - fsspec=2024.2.0 - - gmp=6.3.0 - - gmpy2=2.1.2 - - idna=3.6 - - importlib-metadata=7.0.1 - - importlib_metadata=7.0.1 - - iniconfig=2.0.0 - - ipykernel=6.29.3 - - ipython=8.22.1 - - jedi=0.19.1 - - jinja2=3.1.3 - - joblib=1.3.2 - - jupyter_client=8.6.0 - - jupyter_core=5.7.1 - - keyutils=1.6.1 - - kiwisolver=1.4.5 - - krb5=1.21.2 - - lcms2=2.16 - - ld_impl_linux-64=2.40 - - lerc=4.0.0 - - libabseil=20230802.1 - - libblas=3.9.0 - - libbrotlicommon=1.1.0 - - libbrotlidec=1.1.0 - - libbrotlienc=1.1.0 - - libcblas=3.9.0 - - libcxx=16.0.6 - - libcxxabi=16.0.6 - - libdeflate=1.20 - - libedit=3.1.20191231 - - libffi=3.4.2 - - libgcc-ng=13.2.0 - - libgfortran-ng=13.2.0 - - libgfortran5=13.2.0 - - libgomp=13.2.0 - - libjpeg-turbo=3.0.0 - - liblapack=3.9.0 - - libnsl=2.0.1 - - libopenblas=0.3.26 - - libpng=1.6.43 - - libprotobuf=4.25.1 - - libsodium=1.0.18 - - libsqlite=3.45.1 - - libstdcxx-ng=13.2.0 - - libtiff=4.6.0 - - libtorch=2.1.2 - - libuuid=2.38.1 - - libuv=1.47.0 - - libwebp-base=1.3.2 - - libxcb=1.15 - - libxcrypt=4.4.36 - - libzlib=1.2.13 - - lightning=2.2.0.post0 - - lightning-utilities=0.10.1 - - littleutils=0.2.2 - - llvm-openmp=17.0.6 - - markupsafe=2.1.5 - - matplotlib-base=3.8.3 - - matplotlib-inline=0.1.6 - - mpc=1.3.1 - - mpfr=4.2.1 - - mpmath=1.3.0 - - munkres=1.1.4 - - ncurses=6.4.20240210 - - nest-asyncio=1.6.0 - - networkx=3.2.1 - - nomkl=1.0 - - numpy=1.26.4 - - openjpeg=2.5.2 - - openssl=3.2.1 - - outdated=0.2.2 - - packaging=23.2 - - pandas=2.2.1 - - pandas-flavor=0.6.0 - - parso=0.8.3 - - patsy=0.5.6 - - pexpect=4.9.0 - - pickleshare=0.7.5 - - pillow=10.2.0 - - pingouin=0.5.4 - - pip=24.0 - - platformdirs=4.2.0 - - pluggy=1.4.0 - - prompt-toolkit=3.0.42 - - psutil=5.9.8 - - pthread-stubs=0.4 - - ptyprocess=0.7.0 - - pure_eval=0.2.2 - - pygments=2.17.2 - - pyparsing=3.1.2 - - pysocks=1.7.1 - - pytest=8.1.1 - - python=3.10.13 - - python-dateutil=2.8.2 - - python-tzdata=2024.1 - - python_abi=3.10 - - pytorch-lightning=2.1.3 - - pytz=2024.1 - - pyyaml=6.0.1 - - pyzmq=25.1.2 - - readline=8.2 - - requests=2.31.0 - - scikit-learn=1.4.1.post1 - - 
scipy=1.12.0 - - seaborn=0.13.2 - - seaborn-base=0.13.2 - - setuptools=69.1.1 - - six=1.16.0 - - sleef=3.5.1 - - stack_data=0.6.2 - - statsmodels=0.14.1 - - sympy=1.12 - - tabulate=0.9.0 - - threadpoolctl=3.3.0 - - tk=8.6.13 - - tomli=2.0.1 - - torchmetrics=1.2.1 - - tornado=6.4 - - tqdm=4.66.2 - - traitlets=5.14.1 - - typing-extensions=4.10.0 - - typing_extensions=4.10.0 - - tzdata=2024a - - unicodedata2=15.1.0 - - urllib3=2.2.1 - - wcwidth=0.2.13 - - wheel=0.42.0 - - xarray=2024.2.0 - - xorg-libxau=1.0.11 - - xorg-libxdmcp=1.1.3 - - xz=5.2.6 - - yaml=0.2.5 - - zeromq=4.3.5 - - zipp=3.17.0 - - zstd=1.5.5 - - pip: - - aiosignal==1.3.1 - - attrs==23.2.0 - - black==24.2.0 - - click==8.1.7 - - dill==0.3.8 - - diskcache==5.6.3 - - fastprogress==1.0.3 - - frozenlist==1.4.1 - - gensim==4.3.2 - - importlib-resources==6.4.0 - - jsonschema==4.21.1 - - jsonschema-specifications==2023.12.1 - - msgpack==1.0.7 - - mypy-extensions==1.0.0 - - nvidia-cublas-cu12==12.1.3.1 - - nvidia-cuda-cupti-cu12==12.1.105 - - nvidia-cuda-nvrtc-cu12==12.1.105 - - nvidia-cuda-runtime-cu12==12.1.105 - - nvidia-cudnn-cu12==8.9.2.26 - - nvidia-cufft-cu12==11.0.2.54 - - nvidia-curand-cu12==10.3.2.106 - - nvidia-cusolver-cu12==11.4.5.107 - - nvidia-cusparse-cu12==12.1.0.106 - - nvidia-nccl-cu12==2.19.3 - - nvidia-nvjitlink-cu12==12.4.127 - - nvidia-nvtx-cu12==12.1.105 - - pathspec==0.12.1 - - plotly==5.20.0 - - protobuf==4.25.3 - - pubchempy==1.0.4 - - pyarrow==15.0.0 - - pyfaidx==0.8.1.1 - - pytoda==1.1.4 - - ray==2.9.3 - - rdkit-pypi==2022.9.5 - - referencing==0.33.0 - - rpds-py==0.18.0 - - selfies==2.1.1 - - smart-open==7.0.4 - - smilespe==0.0.3 - - tensorboardx==2.6.2.2 - - torch==2.2.1 - - triton==2.2.0 - - unidecode==1.3.8 - - upfp==0.0.5 - - wrapt==1.16.0 diff --git a/drevalpy/models/utils.py b/drevalpy/models/utils.py index 535e46e..27d1e54 100644 --- a/drevalpy/models/utils.py +++ b/drevalpy/models/utils.py @@ -1,6 +1,4 @@ -""" -Utility functions for loading and processing data. -""" +"""Utility functions for loading and processing data.""" import os.path import warnings @@ -15,9 +13,10 @@ def load_cl_ids_from_csv(path: str, dataset_name: str) -> FeatureDataset: """ Load cell line ids from csv file. - :param path: - :param dataset_name: - :return: + + :param path: path to the data, e.g., data/ + :param dataset_name: name of the dataset, e.g., GDSC2 + :returns: FeatureDataset with the cell line ids """ cl_names = pd.read_csv(f"{path}/{dataset_name}/cell_line_names.csv", index_col=1) return FeatureDataset(features={cl: {"cell_line_id": np.array([cl])} for cl in cl_names.index}) @@ -30,12 +29,14 @@ def load_and_reduce_gene_features( dataset_name: str, ) -> FeatureDataset: """ - Load and reduce gene features. - :param feature_type: - :param gene_list: - :param data_path: - :param dataset_name: - :return: + Load and reduce features of a single feature type. + + :param feature_type: type of feature, e.g., gene_expression, methylation, etc. 
+ :param gene_list: list of genes to include, e.g., landmark_genes + :param data_path: path to the data, e.g., data/ + :param dataset_name: name of the dataset, e.g., GDSC2 + :returns: FeatureDataset with the reduced features + :raises ValueError: if genes from gene_list are missing in the dataset """ ge = pd.read_csv(f"{data_path}/{dataset_name}/{feature_type}.csv", index_col=1) # remove column @@ -53,6 +54,9 @@ def load_and_reduce_gene_features( ) genes_in_list = set(gene_info["Symbol"]) + if cl_features.meta_info is None: + raise ValueError("No meta information available in the dataset.") + genes_in_features = set(cl_features.meta_info[feature_type]) # Ensure that all genes from gene_list are in the dataset missing_genes = genes_in_list - genes_in_features @@ -77,12 +81,13 @@ def load_and_reduce_gene_features( return cl_features -def iterate_features(df: pd.DataFrame, feature_type: str): +def iterate_features(df: pd.DataFrame, feature_type: str) -> dict[str, dict[str, np.ndarray]]: """ Iterate over features. - :param df: - :param feature_type: - :return: + + :param df: DataFrame with the features + :param feature_type: type of feature, e.g., gene_expression, methylation, etc. + :returns: dictionary with the features """ features = {} for cl in df.index: @@ -101,9 +106,10 @@ def iterate_features(df: pd.DataFrame, feature_type: str): def load_drug_ids_from_csv(data_path: str, dataset_name: str) -> FeatureDataset: """ Load drug ids from csv file. - :param data_path: - :param dataset_name: - :return: + + :param data_path: path to the data, e.g., data/ + :param dataset_name: name of the dataset, e.g., GDSC2 + :returns: FeatureDataset with the drug ids """ drug_names = pd.read_csv(f"{data_path}/{dataset_name}/drug_names.csv", index_col=0) return FeatureDataset(features={drug: {"drug_id": np.array([drug])} for drug in drug_names.index}) @@ -112,9 +118,10 @@ def load_drug_ids_from_csv(data_path: str, dataset_name: str) -> FeatureDataset: def load_drug_fingerprint_features(data_path: str, dataset_name: str) -> FeatureDataset: """ Load drug features from fingerprints. - :param data_path: - :param dataset_name: - :return: + + :param data_path: path to the data, e.g., data/ + :param dataset_name: name of the dataset, e.g., GDSC2 + :returns: FeatureDataset with the drug fingerprints """ if dataset_name == "Toy_Data": fingerprints = pd.read_csv(os.path.join(data_path, dataset_name, "fingerprints.csv"), index_col=0) @@ -132,48 +139,49 @@ def get_multiomics_feature_dataset( data_path: str, dataset_name: str, gene_list: Optional[str] = "drug_target_genes_all_drugs", + omics: Optional[list[str]] = None, ) -> FeatureDataset: """ - Get multiomics feature dataset. - :param data_path: - :param dataset_name: - :param gene_list: - :return: + Get multiomics feature dataset for the given list of OMICs. 
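+
+    A usage sketch (the path, dataset name and omics list are illustrative)::
+
+        feature_dataset = get_multiomics_feature_dataset(
+            data_path="data",
+            dataset_name="GDSC2",
+            omics=["gene_expression", "mutations"],
+        )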
+ + :param data_path: path to the data, e.g., data/ + :param dataset_name: name of the dataset, e.g., GDSC2 + :param gene_list: list of genes to include, e.g., landmark_genes + :param omics: list of omics to include, e.g., ["gene_expression", "methylation"] + :returns: FeatureDataset with the multiomics features + :raises ValueError: if no omics features are found """ - ge_dataset = load_and_reduce_gene_features( - feature_type="gene_expression", - gene_list=gene_list, - data_path=data_path, - dataset_name=dataset_name, - ) - me_dataset = load_and_reduce_gene_features( - feature_type="methylation", - gene_list=None, - data_path=data_path, - dataset_name=dataset_name, - ) - mu_dataset = load_and_reduce_gene_features( - feature_type="mutations", - gene_list=gene_list, - data_path=data_path, - dataset_name=dataset_name, - ) - cnv_dataset = load_and_reduce_gene_features( - feature_type="copy_number_variation_gistic", - gene_list=gene_list, - data_path=data_path, - dataset_name=dataset_name, - ) - for fd in [me_dataset, mu_dataset, cnv_dataset]: - ge_dataset._add_features(fd) - return ge_dataset + if omics is None: + omics = ["gene_expression", "methylation", "mutations", "copy_number_variation_gistic"] + feature_dataset = None + for omic in omics: + if feature_dataset is None: + feature_dataset = load_and_reduce_gene_features( + feature_type=omic, + gene_list=None if omic == "methylation" else gene_list, + data_path=data_path, + dataset_name=dataset_name, + ) + else: + feature_dataset.add_features( + load_and_reduce_gene_features( + feature_type=omic, + gene_list=None if omic == "methylation" else gene_list, + data_path=data_path, + dataset_name=dataset_name, + ) + ) + if feature_dataset is None: + raise ValueError("No omics features found.") + return feature_dataset def unique(array): """ - Get unique values ordered by first occurence. - :param array: - :return: + Get unique values ordered by first occurrence. + + :param array: array of values + :returns: unique values ordered by first occurrence """ uniq, index = np.unique(array, return_index=True) return uniq[index.argsort()] diff --git a/drevalpy/pipeline_function.py b/drevalpy/pipeline_function.py new file mode 100644 index 0000000..531ae4e --- /dev/null +++ b/drevalpy/pipeline_function.py @@ -0,0 +1,12 @@ +"""Decorator to mark a function as a pipeline function.""" + + +def pipeline_function(func): + """ + Decorator to mark a function as a pipeline function. 
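+
+    A usage sketch (``my_step`` is a hypothetical function, not part of drevalpy)::
+
+        @pipeline_function
+        def my_step() -> None:
+            ...
+
+        assert my_step.is_pipeline_function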
+ + :param func: function to decorate + :return: function with custom attribute + """ + func.is_pipeline_function = True # Adds a custom attribute to the function + return func diff --git a/drevalpy/utils.py b/drevalpy/utils.py index 863bf0c..a4c7df7 100644 --- a/drevalpy/utils.py +++ b/drevalpy/utils.py @@ -1,20 +1,26 @@ """Utility functions for the evaluation pipeline.""" import argparse +import os +from typing import Optional +from sklearn.base import TransformerMixin from sklearn.preprocessing import MinMaxScaler, RobustScaler, StandardScaler -from drevalpy.datasets import AVAILABLE_DATASETS -from drevalpy.datasets.loader import load_dataset -from drevalpy.evaluation import AVAILABLE_METRICS -from drevalpy.experiment import drug_response_experiment -from drevalpy.models import MODEL_FACTORY +from .datasets import AVAILABLE_DATASETS +from .datasets.dataset import DrugResponseDataset +from .datasets.loader import load_dataset +from .evaluation import AVAILABLE_METRICS +from .experiment import drug_response_experiment, pipeline_function +from .models import MODEL_FACTORY -def get_parser(): +@pipeline_function +def get_parser() -> argparse.ArgumentParser: """ Get the parser for the evaluation pipeline. - :return: + + :returns: parser """ parser = argparse.ArgumentParser(description="Run the drug response prediction model test suite.") parser.add_argument( @@ -157,12 +163,15 @@ def get_parser(): return parser -def check_arguments(args): +@pipeline_function +def check_arguments(args) -> None: """ Check the validity of the arguments for the evaluation pipeline. - :param args: - :return: + :param args: arguments passed from the command line + :raises AssertionError: if any of the arguments is invalid + :raises NotImplementedError: because CurveCurator is not implemented yet + :raises ValueError: if the number of cross-validation splits is less than 1 """ if not args.models: raise AssertionError("At least one model must be specified") @@ -199,6 +208,9 @@ def check_arguments(args): f"AVAILABLE_DATASETS in the response_datasets init." ) + # if the path to args.path_data does not exist, create the directory + os.makedirs(args.path_data, exist_ok=True) + if args.n_cv_splits <= 1: raise ValueError("Number of cross-validation splits must be greater than 1") @@ -218,11 +230,11 @@ def check_arguments(args): ) -def main(args): +def main(args) -> None: """ Main function to run the drug response evaluation pipeline. + :param args: passed from command line - :return: """ check_arguments(args) @@ -266,13 +278,16 @@ def main(args): ) -def get_datasets(dataset_name: str, cross_study_datasets: list, path_data: str = "data"): +def get_datasets( + dataset_name: str, cross_study_datasets: list, path_data: str = "data" +) -> tuple[DrugResponseDataset, Optional[list[DrugResponseDataset]]]: """ Load the response data and cross-study datasets. 
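+
+    A usage sketch (the dataset names are illustrative; any names registered in AVAILABLE_DATASETS can be used)::
+
+        response_data, cross_study = get_datasets(
+            dataset_name="GDSC2", cross_study_datasets=["CCLE"], path_data="data"
+        )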
- :param dataset_name: - :param cross_study_datasets: - :param path_data: - :return: + + :param dataset_name: name of the dataset + :param cross_study_datasets: list of cross-study datasets + :param path_data: path to the data directory, default is "data" + :returns: response data and, potentially, cross-study datasets """ # PIPELINE: LOAD_RESPONSE response_data = load_dataset(dataset_name=dataset_name, path_data=path_data) @@ -281,11 +296,15 @@ def get_datasets(dataset_name: str, cross_study_datasets: list, path_data: str = return response_data, cross_study_datasets -def get_response_transformation(response_transformation: str): +@pipeline_function +def get_response_transformation(response_transformation: str) -> Optional[TransformerMixin]: """ - Get the response transformation object. - :param response_transformation: - :return: + Get the skelarn response transformation object of choice. + + Users can choose from "None", "standard", "minmax", "robust". + :param response_transformation: response transformation to apply + :returns: response transformation object + :raises ValueError: if the response transformation is not recognized """ if response_transformation == "None": return None diff --git a/drevalpy/visualization/__init__.py b/drevalpy/visualization/__init__.py index c64f7b9..f13249b 100644 --- a/drevalpy/visualization/__init__.py +++ b/drevalpy/visualization/__init__.py @@ -1,3 +1,5 @@ +"""Module containing the drevalpy plotly visualizations.""" + __all__ = [ "CorrelationComparisonScatter", "CriticalDifferencePlot", diff --git a/drevalpy/visualization/corr_comp_scatter.py b/drevalpy/visualization/corr_comp_scatter.py index bb8a669..5eb222b 100644 --- a/drevalpy/visualization/corr_comp_scatter.py +++ b/drevalpy/visualization/corr_comp_scatter.py @@ -1,4 +1,6 @@ -from typing import TextIO +"""Contains the code needed to draw the correlation comparison scatter plot.""" + +from io import TextIOWrapper import numpy as np import pandas as pd @@ -7,19 +9,40 @@ from plotly.subplots import make_subplots from scipy import stats -from drevalpy.models import SINGLE_DRUG_MODEL_FACTORY -from drevalpy.visualization.outplot import OutPlot +from ..models import SINGLE_DRUG_MODEL_FACTORY +from ..pipeline_function import pipeline_function +from .outplot import OutPlot class CorrelationComparisonScatter(OutPlot): + """ + Class to draw scatter plots for comparison of correlation metrics between models. + + Produces two types of plots: an overall comparison plot and a dropdown plot for comparison between all models. + If one model is consistently better than the other, the points deviate from the identity line (higher if the + model is on the y-axis, lower if it is on the x-axis. + The dropdown plot allows to select two models for comparison of their per-drug/per-cell-line pearson correlation. + The overall plot facets all models and visualizes the density of the points. + """ + + @pipeline_function def __init__( self, df: pd.DataFrame, color_by: str, lpo_lco_ldo: str, - metric="Pearson", - algorithm="all", + metric: str = "Pearson", + algorithm: str = "all", ): + """ + Initialize the CorrelationComparisonScatter object. + + :param df: evaluation results per group, either drug or cell line + :param color_by: group variable, i.e., drug or cell line + :param lpo_lco_ldo: evaluation setting, e.g., LCO (leave-cell-line-out) + :param metric: correlation metric to be compared. Default is Pearson. + :param algorithm: used to distinguish between per-algorithm plots and per-setting plots (all models then). 
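+
+        A construction sketch (``ev_df`` stands in for the per-drug or per-cell-line evaluation results
+        produced by the pipeline)::
+
+            plot = CorrelationComparisonScatter(df=ev_df, color_by="drug", lpo_lco_ldo="LCO")
+            plot.draw_and_save(out_prefix="results/my_run/corr_comp_scatter/", out_suffix=plot.name)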
+ """ exclude_models = ( {"NaiveDrugMeanPredictor"}.union({model for model in SINGLE_DRUG_MODEL_FACTORY.keys()}) if color_by == "drug" @@ -28,24 +51,22 @@ def __init__( exclude_models.add("NaivePredictor") self.df = df.sort_values("model") + self.name: str | None = None if algorithm == "all": # draw plots for comparison between all models self.df = self.df[ (self.df["LPO_LCO_LDO"] == lpo_lco_ldo) & (self.df["rand_setting"] == "predictions") & (~self.df["algorithm"].isin(exclude_models)) - & # and exclude all lines for which algorithm starts with any element from # exclude_models - (~self.df["algorithm"].str.startswith(tuple(exclude_models))) + & (~self.df["algorithm"].str.startswith(tuple(exclude_models))) ] self.name = f"{color_by}_{lpo_lco_ldo}" elif algorithm not in exclude_models: # draw plots for comparison between all test settings of one model self.df = self.df[(self.df["LPO_LCO_LDO"] == lpo_lco_ldo) & (self.df["algorithm"] == algorithm)] self.name = f"{color_by} {algorithm} {lpo_lco_ldo}" - else: - self.name = None if self.df.empty: print(f"No data found for {self.name}. Skipping ...") return @@ -78,13 +99,21 @@ def __init__( for i in range(len(self.models)): self.fig_overall["layout"]["annotations"][i]["font"]["size"] = 12 self.dropdown_fig = go.Figure() - self.dropdown_buttons_x = list() - self.dropdown_buttons_y = list() + self.dropdown_buttons_x: list[dict] = list() + self.dropdown_buttons_y: list[dict] = list() + @pipeline_function def draw_and_save(self, out_prefix: str, out_suffix: str) -> None: + """ + Draws and saves the scatter plots. + + :param out_prefix: e.g., results/my_run/corr_comp_scatter/ + :param out_suffix: should be self.name + :raises AssertionError: if out_suffix does not match self.name + """ if self.df.empty: return - self.__draw__() + self._draw() if self.name != out_suffix: raise AssertionError(f"Name mismatch: {self.name} != {out_suffix}") path_out = f"{out_prefix}corr_comp_scatter_{out_suffix}.html" @@ -92,9 +121,11 @@ def draw_and_save(self, out_prefix: str, out_suffix: str) -> None: path_out = f"{out_prefix}corr_comp_scatter_overall_{out_suffix}.html" self.fig_overall.write_html(path_out) - def __draw__(self) -> None: + def _draw(self) -> None: + """Draws the scatter plots.""" print("Drawing scatterplots ...") - self.__generate_corr_comp_scatterplots__() + self._generate_corr_comp_scatterplots() + # Set titles self.fig_overall.update_layout( title=f'{str(self.color_by).replace("_", " ").capitalize()}-wise scatter plot of {self.metric} ' f"for each model", @@ -105,6 +136,7 @@ def __draw__(self) -> None: f"for each model", showlegend=False, ) + # Set dropdown menu self.dropdown_fig.update_layout( updatemenus=[ { @@ -131,8 +163,17 @@ def __draw__(self) -> None: self.dropdown_fig.update_yaxes(range=[-1, 1]) @staticmethod - def write_to_html(lpo_lco_ldo: str, f: TextIO, *args, **kwargs) -> TextIO: - files = kwargs.get("files") + def write_to_html(lpo_lco_ldo: str, f: TextIOWrapper, *args, **kwargs) -> TextIOWrapper: + """ + Inserts the generated files into the result HTML file. + + :param lpo_lco_ldo: setting, e.g., LCO + :param f: file to write to + :param args: unused + :param kwargs: used to get all files generated by create_report.py / the pipeline + :returns: the file f + """ + files: list[str] = kwargs.get("files", []) f.write('

Comparison of correlation metrics

\n') for group_by in ["drug", "cell_line"]: plot_list = [f for f in files if f.startswith("corr_comp_scatter") and f.endswith(f"{lpo_lco_ldo}.html")] @@ -167,9 +208,10 @@ def write_to_html(lpo_lco_ldo: str, f: TextIO, *args, **kwargs) -> TextIO: f.write("\n") return f - def __generate_corr_comp_scatterplots__(self): + def _generate_corr_comp_scatterplots(self) -> None: + """Generates the scatter plots.""" # render first scatterplot that is shown in the dropdown plot - first_df = self.__subset_df__(run_id=self.models[0]) + first_df = self._subset_df(run_id=self.models[0]) scatterplot = go.Scatter( x=first_df[self.metric], y=first_df[self.metric], @@ -193,7 +235,7 @@ def __generate_corr_comp_scatterplots__(self): for run_idx in range(len(self.models)): run = self.models[run_idx] - x_df = self.__subset_df__(run_id=run) + x_df = self._subset_df(run_id=run) self.dropdown_buttons_x.append( dict( label=run, @@ -206,9 +248,9 @@ def __generate_corr_comp_scatterplots__(self): ) for run2_idx in range(len(self.models)): run2 = self.models[run2_idx] - y_df = self.__subset_df__(run_id=run2) + y_df = self._subset_df(run_id=run2) - scatterplot = self.__draw_subplot__(x_df, y_df, run, run2) + scatterplot = self._draw_subplot(x_df, y_df, run, run2) self.fig_overall.add_trace(scatterplot, col=run_idx + 1, row=run2_idx + 1) self.fig_overall.add_trace(line_corr, col=run_idx + 1, row=run2_idx + 1) @@ -233,14 +275,29 @@ def __generate_corr_comp_scatterplots__(self): self.fig_overall["layout"][f"yaxis{y_axis_idx}"]["title"] = str(run2).replace("_", "
", 2) self.fig_overall["layout"][f"yaxis{y_axis_idx}"]["title"]["font"]["size"] = 6 - def __subset_df__(self, run_id: str): + def _subset_df(self, run_id: str) -> pd.DataFrame: + """ + Subsets the dataframe for a given run_id to the relevant columns and sets the index to the color_by variable. + + :param run_id: user-defined ID of the whole run + :returns: subsetted dataframe + """ s_df = self.df[self.df["setting"] == run_id][[self.metric, self.color_by, "model"]] s_df.set_index(self.color_by, inplace=True) s_df.sort_index(inplace=True) s_df[self.metric] = s_df[self.metric].fillna(0) return s_df - def __draw_subplot__(self, x_df, y_df, run, run2): + def _draw_subplot(self, x_df, y_df, run, run2) -> go.Scatter: + """ + A subplot of the faceted overall plot. + + :param x_df: dataframe for the x-axis + :param y_df: dataframe for the y-axis + :param run: title for the x-axis + :param run2: title for the y-axis + :returns: scatterplot for the subplot + """ # only retain the common indices common_indices = x_df.index.intersection(y_df.index) x_df_inter = x_df.loc[common_indices] @@ -258,7 +315,7 @@ def __draw_subplot__(self, x_df, y_df, run, run2): "setting_y", ] - density = self.__get_density__(joint_df[f"{self.metric}_x"], joint_df[f"{self.metric}_y"]) + density = self._get_density(joint_df[f"{self.metric}_x"], joint_df[f"{self.metric}_y"]) joint_df["color"] = density custom_text = joint_df.apply( @@ -283,8 +340,14 @@ def __draw_subplot__(self, x_df, y_df, run, run2): return scatterplot @staticmethod - def __get_density__(x: pd.Series, y: pd.Series): - """Get kernal density estimate for each (x, y) point.""" + def _get_density(x: pd.Series, y: pd.Series) -> np.ndarray: + """ + Get kernel density estimate for each (x, y) point. + + :param x: values on the x-axis + :param y: values on the y-axis + :returns: density of the points + """ try: values = np.vstack([x, y]) kernel = stats.gaussian_kde(values) diff --git a/drevalpy/visualization/critical_difference_plot.py b/drevalpy/visualization/critical_difference_plot.py index b5d2e95..1b5fd61 100644 --- a/drevalpy/visualization/critical_difference_plot.py +++ b/drevalpy/visualization/critical_difference_plot.py @@ -1,6 +1,17 @@ +""" +Draw critical difference plot which shows whether a model is significantly better than another model. + +Most code is a modified version of the code available at https://github.com/hfawaz/cd-diagram +Author: Hassan Ismail Fawaz , Germain Forestier , +Jonathan Weber , Lhassane Idoumghar , Pierre-Alain Muller + +License: GPL3 +""" + import math import operator -from typing import TextIO +from io import TextIOWrapper +from typing import Any import matplotlib import matplotlib.pyplot as plt @@ -9,8 +20,9 @@ import pandas as pd from scipy.stats import friedmanchisquare, wilcoxon -from drevalpy.evaluation import MINIMIZATION_METRICS -from drevalpy.visualization.outplot import OutPlot +from ..evaluation import MINIMIZATION_METRICS +from ..pipeline_function import pipeline_function +from .outplot import OutPlot matplotlib.use("agg") matplotlib.rcParams["font.family"] = "sans-serif" @@ -18,7 +30,25 @@ class CriticalDifferencePlot(OutPlot): + """ + Draws the critical difference diagram. + + Used by the pipeline! + + The critical difference diagram is used to compare the performance of multiple classifiers and show whether a + model is significantly better than another model. This is calculated over the average ranks of the classifiers + which is why there need to be at least 3 classifiers to draw the diagram. 
Because the ranks are calculated over + the cross-validation splits and the significance threshold is set to 0.05, e.g., 10 CV folds are advisable. + """ + + @pipeline_function def __init__(self, eval_results_preds: pd.DataFrame, metric="MSE"): + """ + Initializes the critical difference plot. + + :param eval_results_preds: evaluation results subsetted to predictions only (no randomizations etc) + :param metric: to be used to assess the critical difference + """ eval_results_preds = eval_results_preds[["algorithm", "CV_split", metric]].rename( columns={ "algorithm": "classifier_name", @@ -32,31 +62,52 @@ def __init__(self, eval_results_preds: pd.DataFrame, metric="MSE"): self.eval_results_preds = eval_results_preds self.metric = metric + @pipeline_function def draw_and_save(self, out_prefix: str, out_suffix: str) -> None: + """ + Draws the critical difference plot and saves it to a file. + + :param out_prefix: e.g., results/my_run/critical_difference_plots/ + :param out_suffix: e.g., LPO + """ try: - self.__draw__() + self._draw() path_out = f"{out_prefix}critical_difference_algorithms_{out_suffix}.svg" self.fig.savefig(path_out, bbox_inches="tight") except Exception as e: print(f"Error in drawing critical difference plot: {e}") - def __draw__(self) -> None: - self.fig = self.__draw_cd_diagram__( + def _draw(self) -> None: + """Draws the critical difference plot.""" + self.fig = self._draw_cd_diagram( alpha=0.05, title=f"Critical Difference: {self.metric}", labels=True, ) @staticmethod - def write_to_html(lpo_lco_ldo: str, f: TextIO, *args, **kwargs) -> TextIO: + def write_to_html(lpo_lco_ldo: str, f: TextIOWrapper, *args, **kwargs) -> TextIOWrapper: + """ + Inserts the critical difference plot into the HTML report file. + + :param lpo_lco_ldo: setting, e.g., LPO + :param f: HTML report file + :param args: not needed + :param kwargs: not needed + :returns: HTML report file + """ path_out_cd = f"critical_difference_plots/critical_difference_algorithms_{lpo_lco_ldo}.svg" f.write(f" ") return f - def __draw_cd_diagram__(self, alpha=0.05, title=None, labels=False) -> plt.Figure: + def _draw_cd_diagram(self, alpha=0.05, title=None, labels=False) -> plt.Figure: """ - Draws the critical difference diagram given the list of pairwise classifiers that are - significant or not + Draws the critical difference diagram given the list of pairwise classifiers. 
+ + :param alpha: significance level + :param title: title of the plot + :param labels: whether to display the average ranks + :returns: the figure """ # Standard Plotly colors plotly_colors = [ @@ -72,22 +123,17 @@ def __draw_cd_diagram__(self, alpha=0.05, title=None, labels=False) -> plt.Figur "#17becf", ] - p_values, average_ranks, _ = wilcoxon_holm(alpha=alpha, df_perf=self.eval_results_preds) - - print(average_ranks) + p_values, average_ranks, _ = _wilcoxon_holm(df_perf=self.eval_results_preds, alpha=alpha) - for p in p_values: - print(p) - - graph_ranks( - avranks=average_ranks.values, - names=average_ranks.keys(), + _graph_ranks( + avranks=average_ranks.values.tolist(), + names=list(average_ranks.keys()), p_values=p_values, + colors=plotly_colors, reverse=True, - width=9, + width=9.0, textspace=1.5, labels=labels, - colors=plotly_colors, ) font = { @@ -101,39 +147,27 @@ def __draw_cd_diagram__(self, alpha=0.05, title=None, labels=False) -> plt.Figur return plt.gcf() -# The code below is a modified version of the code available at https://github.com/hfawaz/cd-diagram -# Author: Hassan Ismail Fawaz -# Germain Forestier -# Jonathan Weber -# Lhassane Idoumghar -# Pierre-Alain Muller -# License: GPL3 - - # inspired from orange3 https://docs.orange.biolab.si/3/data-mining-library/reference/evaluation.cd.html -def graph_ranks( - avranks, - names, - p_values, - lowv=None, - highv=None, - width=6, - textspace=1, - reverse=False, - labels=False, - colors=None, -): +def _graph_ranks( + avranks: list[float], + names: list[str], + p_values: list[tuple[str, str, float, bool]], + colors: list[str], + lowv: int | None = None, + highv: int | None = None, + width: float = 9.0, + textspace: float = 1.0, + reverse: bool = False, + labels: bool = False, +) -> None: """ - Draws a CD graph, which is used to display the differences in methods' - performance. See Janez Demsar, Statistical Comparisons of Classifiers over - Multiple Data Sets, 7(Jan):1--30, 2006. + Draws a CD graph, which is used to display the differences in methods' performance. - Needs matplotlib to work. + See Janez Demsar, Statistical Comparisons of Classifiers over Multiple Data Sets, 7(Jan):1--30, 2006. - The image is ploted on `plt` imported using + Needs matplotlib to work. The image is ploted on `plt` imported using `import matplotlib.pyplot as plt`. - Args: :param avranks: list of float, average ranks of methods. :param names: list of str, names of methods. :param p_values: list of tuples, p-values of the methods. @@ -146,51 +180,33 @@ def graph_ranks( :param colors: list of str, optional, list of colors for the methods """ - width = float(width) - textspace = float(textspace) - - def nth(data, position): + def nth(data: list[tuple[float, float]], position: int) -> list[float]: """ - Returns only nth elemnt in a list. + Returns only nth element in a list. + + :param data: list (text_space, cline), (width - text_space, cline) + :param position: position to return + :returns: nth element in the list """ position = lloc(data, position) return [a[position] for a in data] - def lloc(data, position): + def lloc(data: list[tuple[float, float]], position: int) -> int: """ List location in list of list structure. + Enable the use of negative locations: -1 is the last element, -2 second last... 
+ + :param data: list (text_space, cline), (width - text_space, cline) + :param position: position to return + :returns: location in the list """ if position < 0: return len(data[0]) + position else: return position - def mxrange(lr): - """ - Multiple xranges. Can be used to traverse matrices. - This function is very slow due to unknown number of - parameters. - - >>> mxrange([3,5]) - [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)] - - >>> mxrange([[3,5,1],[9,0,-3]]) - [(3, 9), (3, 6), (3, 3), (4, 9), (4, 6), (4, 3)] - - """ - if not len(lr): - yield () - else: - # it can work with single numbers - index = lr[0] - if isinstance(index, int): - index = [index] - for a in range(*index): - for b in mxrange(lr[1:]): - yield tuple([a] + list(b)) - sums = avranks nnames = names @@ -208,7 +224,13 @@ def mxrange(lr): linesblank = 0 scalewidth = width - 2 * textspace - def rankpos(rank): + def rankpos(rank: float) -> float: + """ + Calculate the position of the rank. + + :param rank: rank of the method + :returns: textspace + scalewidth / (highv - lowv) * a + """ if not reverse: a = rank - lowv else: @@ -225,16 +247,28 @@ def rankpos(rank): fig = plt.figure(figsize=(width, height)) fig.set_facecolor("white") - ax = fig.add_axes([0, 0, 1, 1]) # reverse y axis + ax = fig.add_axes(rect=(0.0, 0.0, 1.0, 1.0)) # reverse y axis ax.set_axis_off() hf = 1.0 / height # height factor wf = 1.0 / width def hfl(list_input): + """ + List input multiplied by height factor. + + :param list_input: list of floats (cline) + :returns: list of floats + """ return [a * hf for a in list_input] - def wfl(list_input): + def wfl(list_input: list[float]) -> list[float]: + """ + List input multiplied by width factor. + + :param list_input: list of floats (text_space) + :returns: list of floats + """ return [a * wf for a in list_input] # Upper left corner is (0,0). @@ -242,13 +276,26 @@ def wfl(list_input): ax.set_xlim(0, 1) ax.set_ylim(1, 0) - def line(list_input, color="k", **kwargs): + def line(list_input: list[tuple[float, float]], color: str = "k", **kwargs) -> None: """ Input is a list of pairs of points. + + :param list_input: (text_space, cline), (width - text_space, cline) + :param color: color of the line + :param kwargs: additional arguments for plotting """ ax.plot(wfl(nth(list_input, 0)), hfl(nth(list_input, 1)), color=color, **kwargs) - def text(x, y, s, *args, **kwargs): + def text(x: float, y: float, s: str, *args, **kwargs): + """ + Add text to the plot. + + :param x: x position + :param y: y position + :param s: text to display + :param args: additional arguments + :param kwargs: additional keyword arguments + """ ax.text(wf * x, hf * y, s, *args, **kwargs) line( @@ -262,7 +309,6 @@ def text(x, y, s, *args, **kwargs): linewidth = 2.0 linewidth_sign = 4.0 - tick = None for a in list(np.arange(lowv, highv, 0.5)) + [highv]: tick = smalltick if a == int(a): @@ -276,7 +322,7 @@ def text(x, y, s, *args, **kwargs): for a in range(lowv, highv + 1): text( rankpos(a), - cline - tick / 2 - 0.05, + cline - bigtick / 2 - 0.05, str(a), ha="center", va="bottom", @@ -285,7 +331,13 @@ def text(x, y, s, *args, **kwargs): k = len(ssums) - def filter_names(name): + def filter_names(name: str) -> str: + """ + Filter the names. 
+ + :param name: name of the method + :returns: name of the method + """ return name space_between_names = 0.24 @@ -354,7 +406,7 @@ def filter_names(name): # draw no significant lines # get the cliques - cliques = form_cliques(p_values, nnames) + cliques = _form_cliques(p_values, nnames) i = 1 achieved_half = False print(nnames) @@ -378,17 +430,21 @@ def filter_names(name): start += height -def form_cliques(p_values, nnames): +def _form_cliques(p_values: list[tuple[str, str, float, bool]], nnames: list[str]) -> Any: """ - This method forms the cliques + This method forms the cliques. + + :param p_values: list of tuples, p-values of the methods structured as (Method1, Method2, p-value, is_significant) + :param nnames: list of str, names of the methods + :returns: cliques """ # first form the numpy matrix data m = len(nnames) g_data = np.zeros((m, m), dtype=np.int64) for p in p_values: if p[3] is False: - i = np.where(nnames == p[0])[0][0] - j = np.where(nnames == p[1])[0][0] + i = int(np.where(np.array(nnames) == p[0])[0][0]) + j = int(np.where(np.array(nnames) == p[1])[0][0]) min_i = min(i, j) max_j = max(i, j) g_data[min_i, max_j] = 1 @@ -397,10 +453,18 @@ def form_cliques(p_values, nnames): return networkx.find_cliques(g) -def wilcoxon_holm(alpha=0.05, df_perf=None): +def _wilcoxon_holm( + df_perf: pd.DataFrame, alpha: float = 0.05 +) -> tuple[list[tuple[str, str, float, bool]], pd.Series, int]: """ - Applies the wilcoxon signed rank test between each pair of algorithm and then use Holm - to reject the null's hypothesis + Applies the Wilcoxon signed rank test between each algorithm pair and then uses Holm to reject the null hypothesis. + + Returns the p-values in the format (Method1, Method2, p-value, is_significant), the average ranks in the format + pd.Series(Method: avg_rank), and the maximum number of datasets tested (=n_cv_folds). + + :param alpha: significance level + :param df_perf: the dataframe containing the performance of the algorithms + :returns: the p-values, the average ranks, and the maximum number of datasets tested + """ print(pd.unique(df_perf["classifier_name"])) # count the number of tested datasets per classifier diff --git a/drevalpy/visualization/heatmap.py b/drevalpy/visualization/heatmap.py index c3ae8b2..c4e3d37 100644 --- a/drevalpy/visualization/heatmap.py +++ b/drevalpy/visualization/heatmap.py @@ -1,13 +1,27 @@ +"""Plots a heatmap of the evaluation metrics.""" + import numpy as np import pandas as pd import plotly.graph_objects as go from plotly.subplots import make_subplots -from drevalpy.visualization.vioheat import VioHeat +from ..pipeline_function import pipeline_function +from .vioheat import VioHeat class Heatmap(VioHeat): + """Plots a heatmap of the evaluation metrics.""" + + @pipeline_function def __init__(self, df: pd.DataFrame, normalized_metrics=False, whole_name=False): + """ + Initialize the Heatmap class.
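A compact sketch of what `_wilcoxon_holm` computes, assuming the cd-diagram column convention (`classifier_name`, `dataset_name`, `accuracy`), complete data, and a larger-is-better metric; the real helper additionally filters classifiers and handles minimization metrics, so this is illustrative only:

import pandas as pd
from scipy.stats import friedmanchisquare, wilcoxon


def wilcoxon_holm_sketch(df_perf: pd.DataFrame, alpha: float = 0.05):
    """Toy re-implementation: pairwise Wilcoxon tests with Holm correction plus average ranks."""
    # one row per CV split, one column per classifier
    pivot = df_perf.pivot(index="dataset_name", columns="classifier_name", values="accuracy")
    classifiers = list(pivot.columns)

    # global Friedman test; the diagram only makes sense for >= 3 classifiers
    friedman_p = friedmanchisquare(*[pivot[c] for c in classifiers]).pvalue
    print(f"Friedman p-value: {friedman_p:.4f}")

    # pairwise Wilcoxon signed-rank tests
    p_values = []
    for i, clf_1 in enumerate(classifiers):
        for clf_2 in classifiers[i + 1 :]:
            p = wilcoxon(pivot[clf_1], pivot[clf_2], zero_method="pratt").pvalue
            p_values.append([clf_1, clf_2, p, False])

    # Holm correction: walk the sorted p-values, stop rejecting at the first failure
    p_values.sort(key=lambda entry: entry[2])
    m = len(p_values)
    still_rejecting = True
    for k, entry in enumerate(p_values):
        still_rejecting = still_rejecting and entry[2] <= alpha / (m - k)
        entry[3] = still_rejecting

    # average rank per classifier over the CV splits (rank 1 = best, larger metric assumed better)
    average_ranks = pivot.rank(axis=1, ascending=False).mean().sort_values()
    return [tuple(entry) for entry in p_values], average_ranks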
+ + :param df: either containing all predictions for all algorithms or all tests for one algorithm (including + robustness, randomization, … tests then) + :param normalized_metrics: whether the metrics are normalized + :param whole_name: whether the whole name should be displayed + """ super().__init__(df, normalized_metrics, whole_name) self.df = self.df[[col for col in self.df.columns if col in self.all_metrics]] if self.normalized_metrics: @@ -45,15 +59,23 @@ def __init__(self, df: pd.DataFrame, normalized_metrics=False, whole_name=False) vertical_spacing=0.1, ) + @pipeline_function def draw_and_save(self, out_prefix: str, out_suffix: str) -> None: - self.__draw__() + """ + Draw the heatmap and save it to a file. + + :param out_prefix: e.g., results/my_run/heatmaps/ + :param out_suffix: e.g., algorithms_normalized + """ + self._draw() path_out = f"{out_prefix}heatmap_{out_suffix}.html" self.fig.write_html(path_out) - def __draw__(self) -> None: + def _draw(self) -> None: + """Draw the heatmap.""" print("Drawing heatmaps ...") for plot_setting in self.plot_settings: - self.__draw_subplots__(plot_setting) + self._draw_subplots(plot_setting) self.fig.update_layout( height=1000, width=1100, @@ -61,17 +83,23 @@ def __draw__(self) -> None: ) self.fig.update_traces(showscale=False) - def __draw_subplots__(self, plot_setting): + def _draw_subplots(self, plot_setting: str) -> None: + """ + Draw the subplots of the heatmap. + + :param plot_setting: Either "standard_errors", "r2", "correlations", or "errors" + :raises ValueError: If an unknown plot setting is given + """ idx_split = self.df.index.to_series().str.split("_") setting = idx_split.str[0:3].str.join("_") if plot_setting == "standard_errors": - dt = self.df.groupby(setting).apply(lambda x: self.calc_summary_metric(x=x, std_error=True)) + dt = self.df.groupby(setting).apply(lambda x: self._calc_summary_metric(x=x, std_error=True)) row_idx = 1 colorscale = "Pinkyl" elif plot_setting == "r2": r2_columns = [col for col in self.df.columns if "R^2" in col] dt = self.df[r2_columns] - dt = dt.groupby(setting).apply(lambda x: self.calc_summary_metric(x=x, std_error=False)) + dt = dt.groupby(setting).apply(lambda x: self._calc_summary_metric(x=x, std_error=False)) dt = dt.sort_values(by=r2_columns[0], ascending=True) row_idx = 2 colorscale = "Blues" @@ -83,13 +111,13 @@ def __draw_subplots__(self, plot_setting): ] corr_columns.sort() dt = self.df[corr_columns] - dt = dt.groupby(setting).apply(lambda x: self.calc_summary_metric(x=x, std_error=False)) + dt = dt.groupby(setting).apply(lambda x: self._calc_summary_metric(x=x, std_error=False)) dt = dt.sort_values(by=corr_columns[0], ascending=True) row_idx = 3 colorscale = "Viridis" elif plot_setting == "errors": dt = self.df[["MSE", "RMSE", "MAE"]] - dt = dt.groupby(setting).apply(lambda x: self.calc_summary_metric(x=x, std_error=False)) + dt = dt.groupby(setting).apply(lambda x: self._calc_summary_metric(x=x, std_error=False)) dt = dt.sort_values(by="MSE", ascending=False) row_idx = 4 colorscale = "hot" @@ -112,7 +140,14 @@ def __draw_subplots__(self, plot_setting): ) @staticmethod - def calc_summary_metric(x, std_error=False): + def _calc_summary_metric(x: pd.DataFrame, std_error: bool = False): + """ + Calculate the mean or standard error of the metrics. 
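The body of `_calc_summary_metric` is only partly visible in this hunk; a self-contained sketch of the documented behaviour, assuming the usual standard error of the mean (std / sqrt(n)) — the actual implementation may differ in details:

import numpy as np
import pandas as pd


def calc_summary_metric_sketch(x: pd.DataFrame, std_error: bool = False) -> pd.Series:
    """Per metric column: mean over CV splits, or the standard error of that mean."""
    results = pd.Series(index=x.columns, dtype=float)
    for col in x.columns:
        values = x[col].dropna()
        if std_error:
            results[col] = values.std() / np.sqrt(len(values)) if len(values) else np.nan
        else:
            results[col] = values.mean()
    return results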
+ + :param x: DataFrame containing the metrics + :param std_error: whether to calculate the standard error or the mean + :returns: Series containing the mean or standard error of the metrics + """ # make empty results series results = pd.Series(index=x.columns) # iterate over columns diff --git a/drevalpy/visualization/html_tables.py b/drevalpy/visualization/html_tables.py index 0bfeb2c..e150cd9 100644 --- a/drevalpy/visualization/html_tables.py +++ b/drevalpy/visualization/html_tables.py @@ -1,22 +1,42 @@ +"""Renders the evaluation results as HTML tables.""" + import os -from typing import TextIO +from io import TextIOWrapper import pandas as pd -from drevalpy.visualization.outplot import OutPlot +from ..pipeline_function import pipeline_function +from .outplot import OutPlot class HTMLTable(OutPlot): + """Renders the evaluation results as HTML tables.""" + + @pipeline_function def __init__(self, df: pd.DataFrame, group_by: str): + """ + Initialize the HTMLTable class. + + :param df: either all results of a setting or results evaluated by group (cell line, drug) for a setting + :param group_by: all or the group by which the results are evaluated + """ self.df = df self.group_by = group_by + @pipeline_function def draw_and_save(self, out_prefix: str, out_suffix: str) -> None: - self.__draw__() + """ + Draw the table and save it to a file. + + :param out_prefix: e.g., results/my_run/html_tables/ + :param out_suffix: e.g., LPO, LPO_drug + """ + self._draw() path_out = f"{out_prefix}table_{out_suffix}.html" self.df.to_html(path_out, index=False) - def __draw__(self) -> None: + def _draw(self) -> None: + """Draw the table.""" selected_columns = [ "algorithm", "rand_setting", @@ -66,8 +86,18 @@ def __draw__(self) -> None: self.df = self.df[selected_columns] @staticmethod - def write_to_html(lpo_lco_ldo: str, f: TextIO, prefix: str = "", *args, **kwargs) -> TextIO: - files = kwargs.get("files") + def write_to_html(lpo_lco_ldo: str, f: TextIOWrapper, prefix: str = "", *args, **kwargs) -> TextIOWrapper: + """ + Write the evaluation results into the report HTML file. + + :param lpo_lco_ldo: setting, e.g., LPO + :param f: report file + :param prefix: e.g., results/my_run + :param args: additional arguments + :param kwargs: additional keyword arguments + :return: the report file + """ + files: list[str] = kwargs.get("files", []) if prefix != "": prefix = os.path.join(prefix, "html_tables") f.write('

Evaluation Results Table

\n') @@ -85,7 +115,7 @@ def write_to_html(lpo_lco_ldo: str, f: TextIO, prefix: str = "", *args, **kwargs return f -def _write_table(f: TextIO, table: str, prefix: str = ""): +def _write_table(f: TextIOWrapper, table: str, prefix: str = ""): with open(os.path.join(prefix, table)) as eval_f: eval_results = eval_f.readlines() eval_results[0] = eval_results[0].replace( diff --git a/drevalpy/visualization/outplot.py b/drevalpy/visualization/outplot.py index 953cd1c..cbe858a 100644 --- a/drevalpy/visualization/outplot.py +++ b/drevalpy/visualization/outplot.py @@ -1,39 +1,37 @@ +"""Abstract wrapper class for all visualizations.""" + from abc import ABC, abstractmethod -from typing import TextIO +from io import TextIOWrapper class OutPlot(ABC): - """ - Abstract wrapper class for all visualizations - """ + """Abstract wrapper class for all visualizations.""" @abstractmethod def draw_and_save(self, out_prefix: str, out_suffix: str) -> None: """ - Draw and save the plot + Draw and save the plot. + :param out_prefix: path to output directory for python package :param out_suffix: custom suffix for output file - :return: """ pass @abstractmethod - def __draw__(self) -> None: - """ - Draw the plot - :return: - """ + def _draw(self) -> None: + """Draw the plot.""" pass @staticmethod @abstractmethod - def write_to_html(lpo_lco_ldo: str, f: TextIO, *args, **kwargs) -> TextIO: + def write_to_html(lpo_lco_ldo: str, f: TextIOWrapper, *args, **kwargs) -> TextIOWrapper: """ - Write the plot to html - :param lpo_lco_ldo: - :param f: - :param args: - :param kwargs: - :return: + Write the plot to the final report file. + + :param lpo_lco_ldo: LPO, LCO, LDO + :param f: the file to write to + :param args: additional arguments + :param kwargs: additional keyword arguments + :return: the file to write to """ pass diff --git a/drevalpy/visualization/regression_slider_plot.py b/drevalpy/visualization/regression_slider_plot.py index d63d888..349f74f 100644 --- a/drevalpy/visualization/regression_slider_plot.py +++ b/drevalpy/visualization/regression_slider_plot.py @@ -1,14 +1,21 @@ -from typing import TextIO +"""Module for generating regression plots with a slider for Pearson correlation coefficient.""" + +from io import TextIOWrapper import numpy as np import pandas as pd import plotly.express as px +import plotly.graph_objects as go from scipy.stats import pearsonr -from drevalpy.visualization.outplot import OutPlot +from ..pipeline_function import pipeline_function +from .outplot import OutPlot class RegressionSliderPlot(OutPlot): + """Generates regression plots with a slider for the Pearson correlation coefficient.""" + + @pipeline_function def __init__( self, df: pd.DataFrame, @@ -17,11 +24,20 @@ def __init__( group_by: str = "drug", normalize=False, ): + """ + Initialize the RegressionSliderPlot class. + + :param df: true vs. 
predicted values + :param lpo_lco_ldo: setting, e.g., LPO + :param model: model name + :param group_by: either "drug" or "cell_line" + :param normalize: whether to normalize the true and predicted values by the mean of the group + """ self.df = df[(df["LPO_LCO_LDO"] == lpo_lco_ldo) & (df["rand_setting"] == "predictions")] self.df = self.df[(self.df["algorithm"] == model)] self.group_by = group_by self.normalize = normalize - self.fig = None + self.fig = go.Figure() self.model = model if self.normalize: @@ -32,22 +48,39 @@ def __init__( self.df.loc[:, "y_true"] = self.df["y_true"] - self.df["mean_y_true_per_cell_line"] self.df.loc[:, "y_pred"] = self.df["y_pred"] - self.df["mean_y_true_per_cell_line"] + @pipeline_function def draw_and_save(self, out_prefix: str, out_suffix: str) -> None: - self.__draw__() + """ + Draw the regression plot and save it to a file. + + :param out_prefix: e.g., results/my_run/regression_plots/ + :param out_suffix: e.g., LPO_drug_SimpleNeuralNetwork + """ + self._draw() self.fig.write_html(f"{out_prefix}regression_lines_{out_suffix}.html") - def __draw__(self): + def _draw(self): + """Draw the regression plot.""" print(f"Generating regression plots for {self.group_by}, normalize={self.normalize}...") self.df = self.df.groupby(self.group_by).filter(lambda x: len(x) > 1) pccs = self.df.groupby(self.group_by).apply(lambda x: pearsonr(x["y_true"], x["y_pred"])[0]) pccs = pccs.reset_index() pccs.columns = [self.group_by, "pcc"] self.df = self.df.merge(pccs, on=self.group_by) - self.__render_plot__() + self._render_plot() @staticmethod - def write_to_html(lpo_lco_ldo: str, f: TextIO, *args, **kwargs) -> TextIO: - files = kwargs.get("files") + def write_to_html(lpo_lco_ldo: str, f: TextIOWrapper, *args, **kwargs) -> TextIOWrapper: + """ + Write the plot to the final report file. + + :param lpo_lco_ldo: setting, e.g., LPO + :param f: final report file + :param args: additional arguments + :param kwargs: additional keyword arguments, in this case all files + :return: the final report file + """ + files: list[str] = kwargs.get("files", []) f.write('

Regression plots

\n') f.write("
    \n") regr_files = [f for f in files if lpo_lco_ldo in f and f.startswith("regression_lines")] @@ -57,7 +90,8 @@ def write_to_html(lpo_lco_ldo: str, f: TextIO, *args, **kwargs) -> TextIO: f.write("
\n") return f - def __render_plot__(self): + def _render_plot(self): + """Render the regression plot.""" # sort df by group name df = self.df.sort_values(self.group_by) setting_title = self.model + " " + df["LPO_LCO_LDO"].unique()[0] @@ -98,9 +132,14 @@ def __render_plot__(self): max_val = np.max([np.max(df["y_true"]), np.max(df["y_pred"])]) self.fig.update_xaxes(range=[min_val, max_val]) self.fig.update_yaxes(range=[min_val, max_val]) - self.__make_slider__(setting_title) + self._make_slider(setting_title) + + def _make_slider(self, setting_title: str) -> None: + """ + Make a slider for the Pearson correlation coefficient. - def __make_slider__(self, setting_title): + :param setting_title: title of the plot + """ n_ticks = 21 steps = [] # take the range from pcc (-1 - 1) and divide it into n_ticks-1 equal parts diff --git a/drevalpy/visualization/utils.py b/drevalpy/visualization/utils.py index b4ff49e..bdd3cfa 100644 --- a/drevalpy/visualization/utils.py +++ b/drevalpy/visualization/utils.py @@ -1,30 +1,30 @@ -""" -Utility functions for the visualization part of the package. -""" +"""Utility functions for the visualization part of the package.""" import os import pathlib import re import shutil +from typing import Optional, TextIO import importlib_resources import pandas as pd -from drevalpy.datasets.dataset import DrugResponseDataset -from drevalpy.evaluation import AVAILABLE_METRICS, evaluate -from drevalpy.visualization import HTMLTable -from drevalpy.visualization.corr_comp_scatter import CorrelationComparisonScatter -from drevalpy.visualization.critical_difference_plot import CriticalDifferencePlot -from drevalpy.visualization.regression_slider_plot import RegressionSliderPlot -from drevalpy.visualization.vioheat import VioHeat +from ..datasets.dataset import DrugResponseDataset +from ..evaluation import AVAILABLE_METRICS, evaluate +from ..pipeline_function import pipeline_function +from .corr_comp_scatter import CorrelationComparisonScatter +from .critical_difference_plot import CriticalDifferencePlot +from .html_tables import HTMLTable +from .regression_slider_plot import RegressionSliderPlot +from .vioheat import VioHeat -def parse_layout(f, path_to_layout): +def _parse_layout(f: TextIO, path_to_layout: str) -> None: """ Parse the layout file and write it to the output file. - :param f: - :param path_to_layout: - :return: + + :param f: file to write to + :param path_to_layout: path to the layout file """ with open(path_to_layout, encoding="utf-8") as layout_f: layout = layout_f.readlines() @@ -37,11 +37,13 @@ def parse_layout(f, path_to_layout): f.write("".join(layout)) -def parse_results(path_to_results: str): +def parse_results(path_to_results: str) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]: """ Parse the results from the given directory. - :param path_to_results: - :return: + + :param path_to_results: path to the results directory + :returns: evaluation results, evaluation results per drug, evaluation results per cell line, and true vs. 
predicted + values """ print("Generating result tables ...") # generate list of all result files @@ -49,10 +51,12 @@ def parse_results(path_to_results: str): result_files = list(result_dir.rglob("*.csv")) # filter for all files that follow this pattern: # result_dir/*/{predictions|cross_study|randomization|robustness}/*.csv + # Convert the path to a forward-slash version for the regex (for Windows) + result_dir_str = str(result_dir).replace("\\", "/") pattern = re.compile( - rf"{result_dir}/(LPO|LCO|LDO)/[^/]+/(predictions|cross_study|randomization|robustness)/.*\.csv$" + rf"{result_dir_str}/(LPO|LCO|LDO)/[^/]+/(predictions|cross_study|randomization|robustness)/.*\.csv$" ) - result_files = [file for file in result_files if pattern.match(str(file))] + result_files = [file for file in result_files if pattern.match(str(file).replace("\\", "/"))] # inititalize dictionaries to store the evaluation results evaluation_results = None @@ -62,9 +66,11 @@ def parse_results(path_to_results: str): # read every result file and compute the evaluation metrics for file in result_files: - file_parts = os.path.normpath(file).split("/") - lpo_lco_ldo = file_parts[2] - algorithm = file_parts[3] + rel_file = str(os.path.normpath(file.relative_to(result_dir))).replace("\\", "/") + print(f'Evaluating file: "{rel_file}" ...') + file_parts = rel_file.split("/") + lpo_lco_ldo = file_parts[0] + algorithm = file_parts[1] ( overall_eval, eval_results_per_drug, @@ -100,26 +106,26 @@ def parse_results(path_to_results: str): ) -def evaluate_file(pred_file: pathlib.Path, test_mode: str, model_name: str): +@pipeline_function +def evaluate_file( + pred_file: pathlib.Path, test_mode: str, model_name: str +) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, str]: """ Evaluate the predictions from the final models. - :param pred_file: - :param test_mode: - :param model_name: - :return: + + :param pred_file: path to the prediction file + :param test_mode: test mode, e.g., LPO + :param model_name: model name, e.g., SimpleNeuralNetwork + :return: evaluation results, evaluation results per drug, evaluation results per cell line, true vs. 
predicted + values, and model name """ print("Parsing file:", os.path.normpath(pred_file)) - result = pd.read_csv(pred_file) - dataset = DrugResponseDataset( - response=result["response"], - cell_line_ids=result["cell_line_ids"], - drug_ids=result["drug_ids"], - predictions=result["predictions"], - ) - model = generate_model_names(test_mode=test_mode, model_name=model_name, pred_file=pred_file) + dataset = DrugResponseDataset.from_csv(pred_file) + + model = _generate_model_names(test_mode=test_mode, model_name=model_name, pred_file=pred_file) # overall evaluation - overall_eval = {model: evaluate(dataset, AVAILABLE_METRICS.keys())} + overall_eval = {model: evaluate(dataset, list(AVAILABLE_METRICS.keys()))} true_vs_pred = pd.DataFrame( { @@ -133,11 +139,11 @@ def evaluate_file(pred_file: pathlib.Path, test_mode: str, model_name: str): evaluation_results_per_drug = None evaluation_results_per_cl = None - norm_drug_eval_results = {} - norm_cl_eval_results = {} + norm_drug_eval_results: dict[str, dict[str, float]] = {} + norm_cl_eval_results: dict[str, dict[str, float]] = {} if "LPO" in model or "LCO" in model: - norm_drug_eval_results, evaluation_results_per_drug = evaluate_per_group( + norm_drug_eval_results, evaluation_results_per_drug = _evaluate_per_group( df=true_vs_pred, group_by="drug", norm_group_eval_results=norm_drug_eval_results, @@ -145,7 +151,7 @@ def evaluate_file(pred_file: pathlib.Path, test_mode: str, model_name: str): model=model, ) if "LPO" in model or "LDO" in model: - norm_cl_eval_results, evaluation_results_per_cl = evaluate_per_group( + norm_cl_eval_results, evaluation_results_per_cl = _evaluate_per_group( df=true_vs_pred, group_by="cell_line", norm_group_eval_results=norm_cl_eval_results, @@ -154,9 +160,9 @@ def evaluate_file(pred_file: pathlib.Path, test_mode: str, model_name: str): ) overall_eval = pd.DataFrame.from_dict(overall_eval, orient="index") if len(norm_drug_eval_results) > 0: - overall_eval = concat_results(norm_drug_eval_results, "drug", overall_eval) + overall_eval = _concat_results(norm_drug_eval_results, "drug", overall_eval) if len(norm_cl_eval_results) > 0: - overall_eval = concat_results(norm_cl_eval_results, "cell_line", overall_eval) + overall_eval = _concat_results(norm_cl_eval_results, "cell_line", overall_eval) return ( overall_eval, @@ -167,30 +173,37 @@ def evaluate_file(pred_file: pathlib.Path, test_mode: str, model_name: str): ) -def concat_results(norm_group_res, group_by, eval_res): +def _concat_results(norm_group_res: dict[str, dict[str, float]], group_by: str, eval_res: pd.DataFrame) -> pd.DataFrame: """ Concatenate the normalized group results to the evaluation results. 
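A small illustration of the column-suffix convention `_concat_results` applies; run name and metric values below are toy data chosen only for the example:

import pandas as pd

# overall evaluation results, one row per run
eval_res = pd.DataFrame({"Pearson": [0.70]}, index=["MyModel_predictions_LCO_split_0"])

# drug-normalized results for the same run, keyed by run name
norm_group_res = {"MyModel_predictions_LCO_split_0": {"Pearson": 0.45}}

norm_group_df = pd.DataFrame.from_dict(norm_group_res, orient="index")
norm_group_df.columns = [f"{col}: drug normalized" for col in norm_group_df.columns]

# columns end up side by side: "Pearson" and "Pearson: drug normalized"
eval_res = pd.concat([eval_res, norm_group_df], axis=1)
print(eval_res)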
- :param norm_group_res: - :param group_by: - :param eval_res: - :return: + + :param norm_group_res: dictionary with the normalized group results, key: model name, value: evaluation results + :param group_by: either cell line or drug + :param eval_res: overall dataframe + :returns: overall dataframe extended by the normalized group results """ - norm_group_res = pd.DataFrame.from_dict(norm_group_res, orient="index") + norm_group_df = pd.DataFrame.from_dict(norm_group_res, orient="index") # append 'group normalized ' to the column names - norm_group_res.columns = [f"{col}: {group_by} normalized" for col in norm_group_res.columns] - eval_res = pd.concat([eval_res, norm_group_res], axis=1) + norm_group_df.columns = [f"{col}: {group_by} normalized" for col in norm_group_df.columns] + eval_res = pd.concat([eval_res, norm_group_df], axis=1) return eval_res -def prep_results(eval_results, eval_results_per_drug, eval_results_per_cell_line, t_vs_p): +@pipeline_function +def prep_results( + eval_results: pd.DataFrame, + eval_results_per_drug: pd.DataFrame, + eval_results_per_cell_line: pd.DataFrame, + t_vs_p: pd.DataFrame, +) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]: """ - Prepare the results by introducing new columns for algorithm, randomization, setting, split, - CV_split. - :param eval_results: - :param eval_results_per_drug: - :param eval_results_per_cell_line: - :param t_vs_p: - :return: + Prepare the results by introducing new columns for algorithm, randomization, setting, split, CV_split. + + :param eval_results: evaluation results + :param eval_results_per_drug: evaluation results per drug + :param eval_results_per_cell_line: evaluation results per cell line + :param t_vs_p: true vs. predicted values + :returns: the same dataframes with new columns """ # add variables # split the index by "_" into: algorithm, randomization, setting, split, CV_split @@ -224,13 +237,15 @@ def prep_results(eval_results, eval_results_per_drug, eval_results_per_cell_line ) -def generate_model_names(test_mode, model_name, pred_file): +def _generate_model_names(test_mode: str, model_name: str, pred_file: pathlib.Path) -> str: """ Generate the model names based on the prediction file. - :param test_mode: - :param model_name: - :param pred_file: - :return: + + :param test_mode: test mode, e.g., LPO + :param model_name: model name, e.g., SimpleNeuralNetwork + :param pred_file: file containing the predictions + :returns: unique name of run = {model_name}_{pred_setting}_{test_mode}_{split} + :raises ValueError: if the prediction setting is unknown """ file_parts = os.path.basename(pred_file).split("_") pred_rand_rob = file_parts[0] @@ -248,15 +263,22 @@ def generate_model_names(test_mode, model_name, pred_file): return f"{model_name}_{pred_setting}_{test_mode}_{split}" -def evaluate_per_group(df, group_by, norm_group_eval_results, eval_results_per_group, model): +def _evaluate_per_group( + df: pd.DataFrame, + group_by: str, + norm_group_eval_results: dict[str, dict[str, float]], + eval_results_per_group: Optional[pd.DataFrame], + model: str, +) -> tuple[dict[str, dict[str, float]], pd.DataFrame]: """ Evaluate the predictions per group. - :param df: - :param group_by: - :param norm_group_eval_results: - :param eval_results_per_group: - :param model: - :return: + + :param df: true vs. 
predicted values + :param group_by: cell line or drug + :param norm_group_eval_results: dictionary to store the normalized group evaluation results + :param eval_results_per_group: evaluation results per group + :param model: model name + :returns: dictionary with the normalized group evaluation results and the evaluation results per group """ # calculate the mean of y_true per drug print(f"Calculating {group_by}-wise evaluation measures …") @@ -266,36 +288,37 @@ def evaluate_per_group(df, group_by, norm_group_eval_results, eval_results_per_g norm_df["y_pred"] = norm_df["y_pred"] - norm_df[f"mean_y_true_per_{group_by}"] norm_group_eval_results[model] = evaluate( DrugResponseDataset( - response=norm_df["y_true"], - cell_line_ids=norm_df["cell_line"], - drug_ids=norm_df["drug"], - predictions=norm_df["y_pred"], + response=norm_df["y_true"].to_numpy(), + cell_line_ids=norm_df["cell_line"].to_numpy(), + drug_ids=norm_df["drug"].to_numpy(), + predictions=norm_df["y_pred"].to_numpy(), ), - AVAILABLE_METRICS.keys() - {"MSE", "RMSE", "MAE"}, + list(AVAILABLE_METRICS.keys() - {"MSE", "RMSE", "MAE"}), ) # evaluation per group eval_results_per_group = compute_evaluation(df, eval_results_per_group, group_by, model) return norm_group_eval_results, eval_results_per_group -def compute_evaluation(df, return_df, group_by, model): +def compute_evaluation(df: pd.DataFrame, return_df: pd.DataFrame | None, group_by: str, model: str) -> pd.DataFrame: """ Compute the evaluation metrics per group. - :param df: - :param return_df: - :param group_by: - :param model: - :return: + + :param df: true vs. predicted values with mean_y_true_per_{group_by} column + :param return_df: DataFrame to store the results + :param group_by: either cell line or drug + :param model: model name + :returns: dataframe with the evaluation results per group """ - result_per_group = df.groupby(group_by).apply( + result_per_group = df.groupby(group_by)[["y_true", "cell_line", "drug", "y_pred"]].apply( lambda x: evaluate( DrugResponseDataset( - response=x["y_true"], - cell_line_ids=x["cell_line"], - drug_ids=x["drug"], - predictions=x["y_pred"], + response=x["y_true"].to_numpy(), + cell_line_ids=x["cell_line"].to_numpy(), + drug_ids=x["drug"].to_numpy(), + predictions=x["y_pred"].to_numpy(), ), - AVAILABLE_METRICS.keys(), + list(AVAILABLE_METRICS.keys()), ) ) groups = result_per_group.index @@ -309,15 +332,22 @@ def compute_evaluation(df, return_df, group_by, model): return return_df -def write_results(path_out, eval_results, eval_results_per_drug, eval_results_per_cl, t_vs_p): +@pipeline_function +def write_results( + path_out: str, + eval_results: pd.DataFrame, + eval_results_per_drug: pd.DataFrame, + eval_results_per_cl: pd.DataFrame, + t_vs_p: pd.DataFrame, +) -> None: """ Write the results to csv files. - :param path_out: - :param eval_results: - :param eval_results_per_drug: - :param eval_results_per_cl: - :param t_vs_p: - :return: + + :param path_out: path to the output directory, e.g., results/my_run/ + :param eval_results: evaluation results + :param eval_results_per_drug: evaluation results per drug + :param eval_results_per_cl: evaluation results per cell line + :param t_vs_p: true vs. 
predicted values """ eval_results.to_csv(f"{path_out}evaluation_results.csv", index=True) if eval_results_per_drug is not None: @@ -327,13 +357,14 @@ def write_results(path_out, eval_results, eval_results_per_drug, eval_results_pe t_vs_p.to_csv(f"{path_out}true_vs_pred.csv", index=True) -def create_index_html(custom_id: str, test_modes: list[str], prefix_results: str): +@pipeline_function +def create_index_html(custom_id: str, test_modes: list[str], prefix_results: str) -> None: """ Create the index.html file. - :param custom_id: - :param test_modes: - :param prefix_results: - :return: + + :param custom_id: custom id for the results, e.g., my_run + :param test_modes: list of test modes, e.g., ["LPO", "LCO", "LDO"] + :param prefix_results: path to the results directory, e.g., results/my_run """ # copy images to the results directory file_to_copy = [ @@ -357,7 +388,7 @@ def create_index_html(custom_id: str, test_modes: list[str], prefix_results: str ) idx_html_path = os.path.join(prefix_results, "index.html") with open(idx_html_path, "w", encoding="utf-8") as f: - parse_layout(f=f, path_to_layout=layout_path) + _parse_layout(f=f, path_to_layout=layout_path) f.write('
\n') f.write('Logo\n') f.write(f"

Results for {custom_id}

\n") @@ -384,14 +415,15 @@ def create_index_html(custom_id: str, test_modes: list[str], prefix_results: str f.write("\n") -def create_html(run_id: str, lpo_lco_ldo: str, files: list, prefix_results: str): +@pipeline_function +def create_html(run_id: str, lpo_lco_ldo: str, files: list, prefix_results: str) -> None: """ - Create the html file for the given test mode. - :param run_id: - :param lpo_lco_ldo: - :param files: - :param prefix_results: - :return: + Create the html file for the given test mode, e.g., LPO.html. + + :param run_id: custom id for the results, e.g., my_run + :param lpo_lco_ldo: test mode, e.g., LPO + :param files: list of files in the results directory + :param prefix_results: path to the results directory, e.g., results/my_run """ page_layout = os.path.join( str(importlib_resources.files("drevalpy")), @@ -400,7 +432,7 @@ def create_html(run_id: str, lpo_lco_ldo: str, files: list, prefix_results: str) html_path = os.path.join(prefix_results, f"{lpo_lco_ldo}.html") with open(html_path, "w", encoding="utf-8") as f: - parse_layout(f=f, path_to_layout=page_layout) + _parse_layout(f=f, path_to_layout=page_layout) f.write(f"

Results for {run_id}: {lpo_lco_ldo}

\n") # Critical difference plot diff --git a/drevalpy/visualization/vioheat.py b/drevalpy/visualization/vioheat.py index 00b85d3..e10d7a3 100644 --- a/drevalpy/visualization/vioheat.py +++ b/drevalpy/visualization/vioheat.py @@ -1,4 +1,6 @@ -from typing import TextIO +"""Parent class for Violin and Heatmap plots of performance measures over CV runs.""" + +from io import TextIOWrapper import pandas as pd @@ -6,8 +8,16 @@ class VioHeat(OutPlot): + """Parent class for Violin and Heatmap plots of performance measures over CV runs.""" def __init__(self, df: pd.DataFrame, normalized_metrics=False, whole_name=False): + """ + Initialize the VioHeat class. + + :param df: evaluation results, either overall or per algorithm + :param normalized_metrics: whether the metrics are normalized + :param whole_name: whether the whole name should be displayed + """ self.df = df.sort_index() self.all_metrics = [ "R^2", @@ -37,15 +47,30 @@ def __init__(self, df: pd.DataFrame, normalized_metrics=False, whole_name=False) self.all_metrics = [metric for metric in self.all_metrics if "normalized" not in metric] def draw_and_save(self, out_prefix: str, out_suffix: str) -> None: + """ + Draw and save the plot. + + :param out_prefix: e.g., results/my_run/heatmaps/ + :param out_suffix: e.g., algorithms_normalized + """ pass - def __draw__(self) -> None: + def _draw(self) -> None: pass @staticmethod - def write_to_html(lpo_lco_ldo: str, f: TextIO, *args, **kwargs) -> TextIO: - plot = kwargs.get("plot") - files = kwargs.get("files") + def write_to_html(lpo_lco_ldo: str, f: TextIOWrapper, *args, **kwargs) -> TextIOWrapper: + """ + Write the Violin and Heatmap plots into the result HTML file. + + :param lpo_lco_ldo: setting, e.g., LPO + :param f: result HTML file + :param args: additional arguments + :param kwargs: additional keyword arguments, in this case, the plot type and the files + :returns: the result HTML file + """ + plot: str = kwargs.get("plot", "") + files: list[str] = kwargs.get("files", []) if plot == "Violin": nav_id = "violin" diff --git a/drevalpy/visualization/violin.py b/drevalpy/visualization/violin.py index 21267c6..46c73ca 100644 --- a/drevalpy/visualization/violin.py +++ b/drevalpy/visualization/violin.py @@ -1,11 +1,25 @@ +"""Plots a violin plot of the evaluation metrics.""" + import pandas as pd import plotly.graph_objects as go -from drevalpy.visualization.vioheat import VioHeat +from ..pipeline_function import pipeline_function +from .vioheat import VioHeat class Violin(VioHeat): + """Plots a violin plot of the evaluation metrics.""" + + @pipeline_function def __init__(self, df: pd.DataFrame, normalized_metrics=False, whole_name=False): + """ + Initialize the Violin class. 
+ + :param df: either containing all predictions for all algorithms or all tests for one algorithm (including + robustness, randomization, … tests then) + :param normalized_metrics: whether the metrics are normalized + :param whole_name: whether the whole name should be displayed + """ super().__init__(df, normalized_metrics, whole_name) self.df["box"] = self.df["algorithm"] + "_" + self.df["rand_setting"] + "_" + self.df["LPO_LCO_LDO"] # remove columns with only NaN values @@ -13,13 +27,20 @@ def __init__(self, df: pd.DataFrame, normalized_metrics=False, whole_name=False) self.fig = go.Figure() self.occurring_metrics = [metric for metric in self.all_metrics if metric in self.df.columns] + @pipeline_function def draw_and_save(self, out_prefix: str, out_suffix: str) -> None: - self.__draw__() + """ + Draw the violin and save it to a file. + + :param out_prefix: e.g., results/my_run/violin_plots/ + :param out_suffix: e.g., algorithms_normalized + """ + self._draw() path_out = f"{out_prefix}violin_{out_suffix}.html" self.fig.write_html(path_out) - def __draw__(self) -> None: - self.__create_evaluation_violins__() + def _draw(self) -> None: + self._create_evaluation_violins() count_sum = ( self.count_r2 + self.count_pearson @@ -188,7 +209,7 @@ def __draw__(self) -> None: ) self.fig.update_layout(title_text="All Metrics", height=600, width=1100) - def __create_evaluation_violins__(self): + def _create_evaluation_violins(self): print("Drawing Violin plots ...") self.count_r2 = 0 self.count_pearson = 0 @@ -215,9 +236,9 @@ def __create_evaluation_violins__(self): self.count_mse += 1 * len(self.df["box"].unique()) elif "MAE" in metric: self.count_mae += 1 * len(self.df["box"].unique()) - self.__add_violin__(metric) + self._add_violin(metric) - def __add_violin__(self, metric): + def _add_violin(self, metric): for box in self.df["box"].unique(): tmp_df = self.df[self.df["box"] == box] if self.whole_name: diff --git a/noxfile.py b/noxfile.py index cd2def8..fc5cb95 100644 --- a/noxfile.py +++ b/noxfile.py @@ -100,7 +100,11 @@ def activate_virtualenv_in_precommit_hooks(session: Session) -> None: @session(name="pre-commit", python=python_versions) def precommit(session: Session) -> None: - """Lint using pre-commit.""" + """ + Lint using pre-commit. + + :param session: The Session object. + """ args = session.posargs or ["run", "--all-files"] session.install( "black", @@ -123,8 +127,12 @@ def precommit(session: Session) -> None: @session(python=python_versions) def safety(session: Session) -> None: - """Scan dependencies for insecure packages.""" - to_ignore = "--ignore=70612" + """ + Scan dependencies for insecure packages. + + :param session: The Session object. + """ + to_ignore = "--ignore=70612,65189" requirements = session.poetry.export_requirements() session.install("safety") session.run("safety", "check", "--full-report", f"--file={requirements}", to_ignore) @@ -132,16 +140,24 @@ def safety(session: Session) -> None: @session(python=python_versions) def mypy(session: Session) -> None: - """Type-check using mypy.""" + """ + Type-check using mypy. + + :param session: The Session object. + """ args = session.posargs or ["drevalpy", "tests", "docs/conf.py"] session.install(".") - session.install("mypy", "pytest", "types-requests", "types-attrs") + session.install("mypy", "pytest", "types-requests", "types-attrs", "types-PyYAML") session.run("mypy", *args) @session(python=python_versions) def tests(session: Session) -> None: - """Run the test suite.""" + """ + Run the test suite. 
+ + :param session: The Session object. + """ session.install(".") session.install("coverage[toml]", "pytest", "pygments") try: @@ -153,7 +169,11 @@ def tests(session: Session) -> None: @session def coverage(session: Session) -> None: - """Produce the coverage report.""" + """ + Produce the coverage report. + + :param session: The Session object. + """ # Do not use session.posargs unless this is the only session. nsessions = len(session._runner.manifest) # type: ignore[attr-defined] has_args = session.posargs and nsessions == 1 @@ -169,7 +189,11 @@ def coverage(session: Session) -> None: @session(python=python_versions) def typeguard(session: Session) -> None: - """Runtime type checking using Typeguard.""" + """ + Runtime type checking using Typeguard. + + :param session: The Session object. + """ session.install(".") session.install("pytest", "typeguard", "pygments") session.run("pytest", f"--typeguard-packages={package}", *session.posargs) @@ -177,7 +201,11 @@ def typeguard(session: Session) -> None: @session(python=python_versions) def xdoctest(session: Session) -> None: - """Run examples with xdoctest.""" + """ + Run examples with xdoctest. + + :param session: The Session object. + """ args = session.posargs or ["all"] session.install(".") session.install("xdoctest[colors]") @@ -186,7 +214,11 @@ def xdoctest(session: Session) -> None: @session(name="docs-build", python=python_versions) def docs_build(session: Session) -> None: - """Build the documentation.""" + """ + Build the documentation. + + :param session: The Session object. + """ args = session.posargs or ["docs", "docs/_build"] session.install("-r", "./docs/requirements.txt") @@ -199,7 +231,11 @@ def docs_build(session: Session) -> None: @session(python=python_versions) def docs(session: Session) -> None: - """Build and serve the documentation with live reloading on file changes.""" + """ + Build and serve the documentation with live reloading on file changes. + + :param session: The Session object. 
+ """ args = session.posargs or ["--open-browser", "docs", "docs/_build"] session.install(".") session.install( diff --git a/poetry.lock b/poetry.lock index f829a92..ee14f42 100644 --- a/poetry.lock +++ b/poetry.lock @@ -13,112 +13,98 @@ files = [ [[package]] name = "aiohttp" -version = "3.10.10" +version = "3.11.6" description = "Async http client/server framework (asyncio)" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "aiohttp-3.10.10-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:be7443669ae9c016b71f402e43208e13ddf00912f47f623ee5994e12fc7d4b3f"}, - {file = "aiohttp-3.10.10-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7b06b7843929e41a94ea09eb1ce3927865387e3e23ebe108e0d0d09b08d25be9"}, - {file = "aiohttp-3.10.10-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:333cf6cf8e65f6a1e06e9eb3e643a0c515bb850d470902274239fea02033e9a8"}, - {file = "aiohttp-3.10.10-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:274cfa632350225ce3fdeb318c23b4a10ec25c0e2c880eff951a3842cf358ac1"}, - {file = "aiohttp-3.10.10-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d9e5e4a85bdb56d224f412d9c98ae4cbd032cc4f3161818f692cd81766eee65a"}, - {file = "aiohttp-3.10.10-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2b606353da03edcc71130b52388d25f9a30a126e04caef1fd637e31683033abd"}, - {file = "aiohttp-3.10.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab5a5a0c7a7991d90446a198689c0535be89bbd6b410a1f9a66688f0880ec026"}, - {file = "aiohttp-3.10.10-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:578a4b875af3e0daaf1ac6fa983d93e0bbfec3ead753b6d6f33d467100cdc67b"}, - {file = "aiohttp-3.10.10-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8105fd8a890df77b76dd3054cddf01a879fc13e8af576805d667e0fa0224c35d"}, - {file = "aiohttp-3.10.10-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3bcd391d083f636c06a68715e69467963d1f9600f85ef556ea82e9ef25f043f7"}, - {file = "aiohttp-3.10.10-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:fbc6264158392bad9df19537e872d476f7c57adf718944cc1e4495cbabf38e2a"}, - {file = "aiohttp-3.10.10-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:e48d5021a84d341bcaf95c8460b152cfbad770d28e5fe14a768988c461b821bc"}, - {file = "aiohttp-3.10.10-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2609e9ab08474702cc67b7702dbb8a80e392c54613ebe80db7e8dbdb79837c68"}, - {file = "aiohttp-3.10.10-cp310-cp310-win32.whl", hash = "sha256:84afcdea18eda514c25bc68b9af2a2b1adea7c08899175a51fe7c4fb6d551257"}, - {file = "aiohttp-3.10.10-cp310-cp310-win_amd64.whl", hash = "sha256:9c72109213eb9d3874f7ac8c0c5fa90e072d678e117d9061c06e30c85b4cf0e6"}, - {file = "aiohttp-3.10.10-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c30a0eafc89d28e7f959281b58198a9fa5e99405f716c0289b7892ca345fe45f"}, - {file = "aiohttp-3.10.10-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:258c5dd01afc10015866114e210fb7365f0d02d9d059c3c3415382ab633fcbcb"}, - {file = "aiohttp-3.10.10-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:15ecd889a709b0080f02721255b3f80bb261c2293d3c748151274dfea93ac871"}, - {file = "aiohttp-3.10.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3935f82f6f4a3820270842e90456ebad3af15810cf65932bd24da4463bc0a4c"}, - {file = "aiohttp-3.10.10-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:413251f6fcf552a33c981c4709a6bba37b12710982fec8e558ae944bfb2abd38"}, - {file = "aiohttp-3.10.10-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d1720b4f14c78a3089562b8875b53e36b51c97c51adc53325a69b79b4b48ebcb"}, - {file = "aiohttp-3.10.10-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:679abe5d3858b33c2cf74faec299fda60ea9de62916e8b67e625d65bf069a3b7"}, - {file = "aiohttp-3.10.10-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:79019094f87c9fb44f8d769e41dbb664d6e8fcfd62f665ccce36762deaa0e911"}, - {file = "aiohttp-3.10.10-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:fe2fb38c2ed905a2582948e2de560675e9dfbee94c6d5ccdb1301c6d0a5bf092"}, - {file = "aiohttp-3.10.10-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:a3f00003de6eba42d6e94fabb4125600d6e484846dbf90ea8e48a800430cc142"}, - {file = "aiohttp-3.10.10-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:1bbb122c557a16fafc10354b9d99ebf2f2808a660d78202f10ba9d50786384b9"}, - {file = "aiohttp-3.10.10-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:30ca7c3b94708a9d7ae76ff281b2f47d8eaf2579cd05971b5dc681db8caac6e1"}, - {file = "aiohttp-3.10.10-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:df9270660711670e68803107d55c2b5949c2e0f2e4896da176e1ecfc068b974a"}, - {file = "aiohttp-3.10.10-cp311-cp311-win32.whl", hash = "sha256:aafc8ee9b742ce75044ae9a4d3e60e3d918d15a4c2e08a6c3c3e38fa59b92d94"}, - {file = "aiohttp-3.10.10-cp311-cp311-win_amd64.whl", hash = "sha256:362f641f9071e5f3ee6f8e7d37d5ed0d95aae656adf4ef578313ee585b585959"}, - {file = "aiohttp-3.10.10-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9294bbb581f92770e6ed5c19559e1e99255e4ca604a22c5c6397b2f9dd3ee42c"}, - {file = "aiohttp-3.10.10-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a8fa23fe62c436ccf23ff930149c047f060c7126eae3ccea005f0483f27b2e28"}, - {file = "aiohttp-3.10.10-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5c6a5b8c7926ba5d8545c7dd22961a107526562da31a7a32fa2456baf040939f"}, - {file = "aiohttp-3.10.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:007ec22fbc573e5eb2fb7dec4198ef8f6bf2fe4ce20020798b2eb5d0abda6138"}, - {file = "aiohttp-3.10.10-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9627cc1a10c8c409b5822a92d57a77f383b554463d1884008e051c32ab1b3742"}, - {file = "aiohttp-3.10.10-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:50edbcad60d8f0e3eccc68da67f37268b5144ecc34d59f27a02f9611c1d4eec7"}, - {file = "aiohttp-3.10.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a45d85cf20b5e0d0aa5a8dca27cce8eddef3292bc29d72dcad1641f4ed50aa16"}, - {file = "aiohttp-3.10.10-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0b00807e2605f16e1e198f33a53ce3c4523114059b0c09c337209ae55e3823a8"}, - {file = "aiohttp-3.10.10-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f2d4324a98062be0525d16f768a03e0bbb3b9fe301ceee99611dc9a7953124e6"}, - {file = "aiohttp-3.10.10-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:438cd072f75bb6612f2aca29f8bd7cdf6e35e8f160bc312e49fbecab77c99e3a"}, - {file = "aiohttp-3.10.10-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:baa42524a82f75303f714108fea528ccacf0386af429b69fff141ffef1c534f9"}, - {file = "aiohttp-3.10.10-cp312-cp312-musllinux_1_2_s390x.whl", hash = 
"sha256:a7d8d14fe962153fc681f6366bdec33d4356f98a3e3567782aac1b6e0e40109a"}, - {file = "aiohttp-3.10.10-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c1277cd707c465cd09572a774559a3cc7c7a28802eb3a2a9472588f062097205"}, - {file = "aiohttp-3.10.10-cp312-cp312-win32.whl", hash = "sha256:59bb3c54aa420521dc4ce3cc2c3fe2ad82adf7b09403fa1f48ae45c0cbde6628"}, - {file = "aiohttp-3.10.10-cp312-cp312-win_amd64.whl", hash = "sha256:0e1b370d8007c4ae31ee6db7f9a2fe801a42b146cec80a86766e7ad5c4a259cf"}, - {file = "aiohttp-3.10.10-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ad7593bb24b2ab09e65e8a1d385606f0f47c65b5a2ae6c551db67d6653e78c28"}, - {file = "aiohttp-3.10.10-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1eb89d3d29adaf533588f209768a9c02e44e4baf832b08118749c5fad191781d"}, - {file = "aiohttp-3.10.10-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3fe407bf93533a6fa82dece0e74dbcaaf5d684e5a51862887f9eaebe6372cd79"}, - {file = "aiohttp-3.10.10-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50aed5155f819873d23520919e16703fc8925e509abbb1a1491b0087d1cd969e"}, - {file = "aiohttp-3.10.10-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4f05e9727ce409358baa615dbeb9b969db94324a79b5a5cea45d39bdb01d82e6"}, - {file = "aiohttp-3.10.10-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3dffb610a30d643983aeb185ce134f97f290f8935f0abccdd32c77bed9388b42"}, - {file = "aiohttp-3.10.10-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa6658732517ddabe22c9036479eabce6036655ba87a0224c612e1ae6af2087e"}, - {file = "aiohttp-3.10.10-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:741a46d58677d8c733175d7e5aa618d277cd9d880301a380fd296975a9cdd7bc"}, - {file = "aiohttp-3.10.10-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e00e3505cd80440f6c98c6d69269dcc2a119f86ad0a9fd70bccc59504bebd68a"}, - {file = "aiohttp-3.10.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ffe595f10566f8276b76dc3a11ae4bb7eba1aac8ddd75811736a15b0d5311414"}, - {file = "aiohttp-3.10.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:bdfcf6443637c148c4e1a20c48c566aa694fa5e288d34b20fcdc58507882fed3"}, - {file = "aiohttp-3.10.10-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:d183cf9c797a5291e8301790ed6d053480ed94070637bfaad914dd38b0981f67"}, - {file = "aiohttp-3.10.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:77abf6665ae54000b98b3c742bc6ea1d1fb31c394bcabf8b5d2c1ac3ebfe7f3b"}, - {file = "aiohttp-3.10.10-cp313-cp313-win32.whl", hash = "sha256:4470c73c12cd9109db8277287d11f9dd98f77fc54155fc71a7738a83ffcc8ea8"}, - {file = "aiohttp-3.10.10-cp313-cp313-win_amd64.whl", hash = "sha256:486f7aabfa292719a2753c016cc3a8f8172965cabb3ea2e7f7436c7f5a22a151"}, - {file = "aiohttp-3.10.10-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:1b66ccafef7336a1e1f0e389901f60c1d920102315a56df85e49552308fc0486"}, - {file = "aiohttp-3.10.10-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:acd48d5b80ee80f9432a165c0ac8cbf9253eaddb6113269a5e18699b33958dbb"}, - {file = "aiohttp-3.10.10-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3455522392fb15ff549d92fbf4b73b559d5e43dc522588f7eb3e54c3f38beee7"}, - {file = "aiohttp-3.10.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45c3b868724137f713a38376fef8120c166d1eadd50da1855c112fe97954aed8"}, - {file = "aiohttp-3.10.10-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:da1dee8948d2137bb51fbb8a53cce6b1bcc86003c6b42565f008438b806cccd8"}, - {file = "aiohttp-3.10.10-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c5ce2ce7c997e1971b7184ee37deb6ea9922ef5163c6ee5aa3c274b05f9e12fa"}, - {file = "aiohttp-3.10.10-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28529e08fde6f12eba8677f5a8608500ed33c086f974de68cc65ab218713a59d"}, - {file = "aiohttp-3.10.10-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f7db54c7914cc99d901d93a34704833568d86c20925b2762f9fa779f9cd2e70f"}, - {file = "aiohttp-3.10.10-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:03a42ac7895406220124c88911ebee31ba8b2d24c98507f4a8bf826b2937c7f2"}, - {file = "aiohttp-3.10.10-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:7e338c0523d024fad378b376a79faff37fafb3c001872a618cde1d322400a572"}, - {file = "aiohttp-3.10.10-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:038f514fe39e235e9fef6717fbf944057bfa24f9b3db9ee551a7ecf584b5b480"}, - {file = "aiohttp-3.10.10-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:64f6c17757251e2b8d885d728b6433d9d970573586a78b78ba8929b0f41d045a"}, - {file = "aiohttp-3.10.10-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:93429602396f3383a797a2a70e5f1de5df8e35535d7806c9f91df06f297e109b"}, - {file = "aiohttp-3.10.10-cp38-cp38-win32.whl", hash = "sha256:c823bc3971c44ab93e611ab1a46b1eafeae474c0c844aff4b7474287b75fe49c"}, - {file = "aiohttp-3.10.10-cp38-cp38-win_amd64.whl", hash = "sha256:54ca74df1be3c7ca1cf7f4c971c79c2daf48d9aa65dea1a662ae18926f5bc8ce"}, - {file = "aiohttp-3.10.10-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:01948b1d570f83ee7bbf5a60ea2375a89dfb09fd419170e7f5af029510033d24"}, - {file = "aiohttp-3.10.10-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9fc1500fd2a952c5c8e3b29aaf7e3cc6e27e9cfc0a8819b3bce48cc1b849e4cc"}, - {file = "aiohttp-3.10.10-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f614ab0c76397661b90b6851a030004dac502e48260ea10f2441abd2207fbcc7"}, - {file = "aiohttp-3.10.10-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00819de9e45d42584bed046314c40ea7e9aea95411b38971082cad449392b08c"}, - {file = "aiohttp-3.10.10-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:05646ebe6b94cc93407b3bf34b9eb26c20722384d068eb7339de802154d61bc5"}, - {file = "aiohttp-3.10.10-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:998f3bd3cfc95e9424a6acd7840cbdd39e45bc09ef87533c006f94ac47296090"}, - {file = "aiohttp-3.10.10-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9010c31cd6fa59438da4e58a7f19e4753f7f264300cd152e7f90d4602449762"}, - {file = "aiohttp-3.10.10-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ea7ffc6d6d6f8a11e6f40091a1040995cdff02cfc9ba4c2f30a516cb2633554"}, - {file = "aiohttp-3.10.10-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:ef9c33cc5cbca35808f6c74be11eb7f5f6b14d2311be84a15b594bd3e58b5527"}, - {file = "aiohttp-3.10.10-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ce0cdc074d540265bfeb31336e678b4e37316849d13b308607efa527e981f5c2"}, - {file = "aiohttp-3.10.10-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:597a079284b7ee65ee102bc3a6ea226a37d2b96d0418cc9047490f231dc09fe8"}, - {file = "aiohttp-3.10.10-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:7789050d9e5d0c309c706953e5e8876e38662d57d45f936902e176d19f1c58ab"}, - {file = 
"aiohttp-3.10.10-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:e7f8b04d83483577fd9200461b057c9f14ced334dcb053090cea1da9c8321a91"}, - {file = "aiohttp-3.10.10-cp39-cp39-win32.whl", hash = "sha256:c02a30b904282777d872266b87b20ed8cc0d1501855e27f831320f471d54d983"}, - {file = "aiohttp-3.10.10-cp39-cp39-win_amd64.whl", hash = "sha256:edfe3341033a6b53a5c522c802deb2079eee5cbfbb0af032a55064bd65c73a23"}, - {file = "aiohttp-3.10.10.tar.gz", hash = "sha256:0631dd7c9f0822cc61c88586ca76d5b5ada26538097d0f1df510b082bad3411a"}, + {file = "aiohttp-3.11.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7510b3ca2275691875ddf072a5b6cd129278d11fe09301add7d292fc8d3432de"}, + {file = "aiohttp-3.11.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bfab0d2c3380c588fc925168533edb21d3448ad76c3eadc360ff963019161724"}, + {file = "aiohttp-3.11.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cf02dba0f342f3a8228f43fae256aafc21c4bc85bffcf537ce4582e2b1565188"}, + {file = "aiohttp-3.11.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92daedf7221392e7a7984915ca1b0481a94c71457c2f82548414a41d65555e70"}, + {file = "aiohttp-3.11.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2274a7876e03429e3218589a6d3611a194bdce08c3f1e19962e23370b47c0313"}, + {file = "aiohttp-3.11.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8a2e1eae2d2f62f3660a1591e16e543b2498358593a73b193006fb89ee37abc6"}, + {file = "aiohttp-3.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:978ec3fb0a42efcd98aae608f58c6cfcececaf0a50b4e86ee3ea0d0a574ab73b"}, + {file = "aiohttp-3.11.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a51f87b27d9219ed4e202ed8d6f1bb96f829e5eeff18db0d52f592af6de6bdbf"}, + {file = "aiohttp-3.11.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:04d1a02a669d26e833c8099992c17f557e3b2fdb7960a0c455d7b1cbcb05121d"}, + {file = "aiohttp-3.11.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3679d5fcbc7f1ab518ab4993f12f80afb63933f6afb21b9b272793d398303b98"}, + {file = "aiohttp-3.11.6-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:a4b24e03d04893b5c8ec9cd5f2f11dc9c8695c4e2416d2ac2ce6c782e4e5ffa5"}, + {file = "aiohttp-3.11.6-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:d9abdfd35ecff1c95f270b7606819a0e2de9e06fa86b15d9080de26594cf4c23"}, + {file = "aiohttp-3.11.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8b5c3e7928a0ad80887a5eba1c1da1830512ddfe7394d805badda45c03db3109"}, + {file = "aiohttp-3.11.6-cp310-cp310-win32.whl", hash = "sha256:913dd9e9378f3c38aeb5c4fb2b8383d6490bc43f3b427ae79f2870651ae08f22"}, + {file = "aiohttp-3.11.6-cp310-cp310-win_amd64.whl", hash = "sha256:4ac26d482c2000c3a59bf757a77adc972828c9d4177b4bd432a46ba682ca7271"}, + {file = "aiohttp-3.11.6-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:26ac4c960ea8debf557357a172b3ef201f2236a462aefa1bc17683a75483e518"}, + {file = "aiohttp-3.11.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8b1f13ebc99fb98c7c13057b748f05224ccc36d17dee18136c695ef23faaf4ff"}, + {file = "aiohttp-3.11.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4679f1a47516189fab1774f7e45a6c7cac916224c91f5f94676f18d0b64ab134"}, + {file = "aiohttp-3.11.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74491fdb3d140ff561ea2128cb7af9ba0a360067ee91074af899c9614f88a18f"}, + {file = 
"aiohttp-3.11.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f51e1a90412d387e62aa2d243998c5eddb71373b199d811e6ed862a9f34f9758"}, + {file = "aiohttp-3.11.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:72ab89510511c3bb703d0bb5504787b11e0ed8be928ed2a7cf1cda9280628430"}, + {file = "aiohttp-3.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6681c9e046d99646e8059266688374a063da85b2e4c0ebfa078cda414905d080"}, + {file = "aiohttp-3.11.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a17f8a6d3ab72cbbd137e494d1a23fbd3ea973db39587941f32901bb3c5c350"}, + {file = "aiohttp-3.11.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:867affc7612a314b95f74d93aac550ce0909bc6f0b6c658cc856890f4d326542"}, + {file = "aiohttp-3.11.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:00d894ebd609d5a423acef885bd61e7f6a972153f99c5b3ea45fc01fe909196c"}, + {file = "aiohttp-3.11.6-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:614c87be9d0d64477d1e4b663bdc5d1534fc0a7ebd23fb08347ab9fd5fe20fd7"}, + {file = "aiohttp-3.11.6-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:533ed46cf772f28f3bffae81c0573d916a64dee590b5dfaa3f3d11491da05b95"}, + {file = "aiohttp-3.11.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:589884cfbc09813afb1454816b45677e983442e146183143f988f7f5a040791a"}, + {file = "aiohttp-3.11.6-cp311-cp311-win32.whl", hash = "sha256:1da63633ba921669eec3d7e080459d4ceb663752b3dafb2f31f18edd248d2170"}, + {file = "aiohttp-3.11.6-cp311-cp311-win_amd64.whl", hash = "sha256:d778ddda09622e7d83095cc8051698a0084c155a1474bfee9bac27d8613dbc31"}, + {file = "aiohttp-3.11.6-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:943a952df105a5305257984e7a1f5c2d0fd8564ff33647693c4d07eb2315446d"}, + {file = "aiohttp-3.11.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d24ec28b7658970a1f1d98608d67f88376c7e503d9d45ff2ba1949c09f2b358c"}, + {file = "aiohttp-3.11.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6720e809a660fdb9bec7c168c582e11cfedce339af0a5ca847a5d5b588dce826"}, + {file = "aiohttp-3.11.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4252d30da0ada6e6841b325869c7ef5104b488e8dd57ec439892abbb8d7b3615"}, + {file = "aiohttp-3.11.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f65f43ff01b238aa0b5c47962c83830a49577efe31bd37c1400c3d11d8a32835"}, + {file = "aiohttp-3.11.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4dc5933f6c9b26404444d36babb650664f984b8e5fa0694540e7b7315d11a4ff"}, + {file = "aiohttp-3.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5bf546ba0c029dfffc718c4b67748687fd4f341b07b7c8f1719d6a3a46164798"}, + {file = "aiohttp-3.11.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c351d05bbeae30c088009c0bb3b17dda04fd854f91cc6196c448349cc98f71c3"}, + {file = "aiohttp-3.11.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:10499079b063576fad1597898de3f9c0a2ce617c19cc7cd6b62fdcff6b408bf7"}, + {file = "aiohttp-3.11.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:442ee82eda47dd59798d6866ce020fb8d02ea31ac9ac82b3d719ed349e6a9d52"}, + {file = "aiohttp-3.11.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:86fce9127bc317119b34786d9e9ae8af4508a103158828a535f56d201da6ab19"}, + {file = "aiohttp-3.11.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = 
"sha256:973d26a5537ce5d050302eb3cd876457451745b1da0624cbb483217970e12567"}, + {file = "aiohttp-3.11.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:532b8f038a4e001137d3600cea5d3439d1881df41bdf44d0f9651264d562fdf0"}, + {file = "aiohttp-3.11.6-cp312-cp312-win32.whl", hash = "sha256:4863c59f748dbe147da82b389931f2a676aebc9d3419813ed5ca32d057c9cb32"}, + {file = "aiohttp-3.11.6-cp312-cp312-win_amd64.whl", hash = "sha256:5d7f481f82c18ac1f7986e31ba6eea9be8b2e2c86f1ef035b6866179b6c5dd68"}, + {file = "aiohttp-3.11.6-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:40f502350496ba4c6820816d3164f8a0297b9aa4e95d910da31beb189866a9df"}, + {file = "aiohttp-3.11.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:9072669b0bffb40f1f6977d0b5e8a296edc964f9cefca3a18e68649c214d0ce3"}, + {file = "aiohttp-3.11.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:518160ecf4e6ffd61715bc9173da0925fcce44ae6c7ca3d3f098fe42585370fb"}, + {file = "aiohttp-3.11.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f69cc1b45115ac44795b63529aa5caa9674be057f11271f65474127b24fc1ce6"}, + {file = "aiohttp-3.11.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c6be90a6beced41653bda34afc891617c6d9e8276eef9c183f029f851f0a3c3d"}, + {file = "aiohttp-3.11.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:00c22fe2486308770d22ef86242101d7b0f1e1093ce178f2358f860e5149a551"}, + {file = "aiohttp-3.11.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2607ebb783e3aeefa017ec8f34b506a727e6b6ab2c4b037d65f0bc7151f4430a"}, + {file = "aiohttp-3.11.6-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5f761d6819870c2a8537f75f3e2fc610b163150cefa01f9f623945840f601b2c"}, + {file = "aiohttp-3.11.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e44d1bc6c88f5234115011842219ba27698a5f2deee245c963b180080572aaa2"}, + {file = "aiohttp-3.11.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7e0cb6a1b1f499cb2aa0bab1c9f2169ad6913c735b7447e058e0c29c9e51c0b5"}, + {file = "aiohttp-3.11.6-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:a76b4d4ca34254dca066acff2120811e2a8183997c135fcafa558280f2cc53f3"}, + {file = "aiohttp-3.11.6-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:69051c1e45fb18c0ae4d39a075532ff0b015982e7997f19eb5932eb4a3e05c17"}, + {file = "aiohttp-3.11.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:aff2ed18274c0bfe0c1d772781c87d5ca97ae50f439729007cec9644ee9b15fe"}, + {file = "aiohttp-3.11.6-cp313-cp313-win32.whl", hash = "sha256:2fbea25f2d44df809a46414a8baafa5f179d9dda7e60717f07bded56300589b3"}, + {file = "aiohttp-3.11.6-cp313-cp313-win_amd64.whl", hash = "sha256:f77bc29a465c0f9f6573d1abe656d385fa673e34efe615bd4acc50899280ee47"}, + {file = "aiohttp-3.11.6-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:de6123b298d17bca9e53581f50a275b36e10d98e8137eb743ce69ee766dbdfe9"}, + {file = "aiohttp-3.11.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a10200f705f4fff00e148b7f41e5d1d929c7cd4ac523c659171a0ea8284cd6fb"}, + {file = "aiohttp-3.11.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b7776ef6901b54dd557128d96c71e412eec0c39ebc07567e405ac98737995aad"}, + {file = "aiohttp-3.11.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6e5c2a55583cd91936baf73d223807bb93ace6eb1fe54424782690f2707162ab"}, + {file = "aiohttp-3.11.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:b032bd6cf7422583bf44f233f4a1489fee53c6d35920123a208adc54e2aba41e"}, + {file = "aiohttp-3.11.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04fe2d99acbc5cf606f75d7347bf3a027c24c27bc052d470fb156f4cfcea5739"}, + {file = "aiohttp-3.11.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84a79c366375c2250934d1238abe5d5ea7754c823a1c7df0c52bf0a2bfded6a9"}, + {file = "aiohttp-3.11.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c33cbbe97dc94a34d1295a7bb68f82727bcbff2b284f73ae7e58ecc05903da97"}, + {file = "aiohttp-3.11.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:19e4fb9ac727834b003338dcdd27dcfe0de4fb44082b01b34ed0ab67c3469fc9"}, + {file = "aiohttp-3.11.6-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:a97f6b2afbe1d27220c0c14ea978e09fb4868f462ef3d56d810d206bd2e057a2"}, + {file = "aiohttp-3.11.6-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:c3f7afeea03a9bc49be6053dfd30809cd442cc12627d6ca08babd1c1f9e04ccf"}, + {file = "aiohttp-3.11.6-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:0d10967600ce5bb69ddcb3e18d84b278efb5199d8b24c3c71a4959c2f08acfd0"}, + {file = "aiohttp-3.11.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:60f2f631b9fe7aa321fa0f0ff3f5d8b9f7f9b72afd4eecef61c33cf1cfea5d58"}, + {file = "aiohttp-3.11.6-cp39-cp39-win32.whl", hash = "sha256:4d2b75333deb5c5f61bac5a48bba3dbc142eebbd3947d98788b6ef9cc48628ae"}, + {file = "aiohttp-3.11.6-cp39-cp39-win_amd64.whl", hash = "sha256:8908c235421972a2e02abcef87d16084aabfe825d14cc9a1debd609b3cfffbea"}, + {file = "aiohttp-3.11.6.tar.gz", hash = "sha256:fd9f55c1b51ae1c20a1afe7216a64a88d38afee063baa23c7fce03757023c999"}, ] [package.dependencies] aiohappyeyeballs = ">=2.3.0" aiosignal = ">=1.1.2" -async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""} +async-timeout = {version = ">=4.0,<6.0", markers = "python_version < \"3.11\""} attrs = ">=17.3.0" frozenlist = ">=1.1.1" multidict = ">=4.5,<7.0" -yarl = ">=1.12.0,<2.0" +propcache = ">=0.2.0" +yarl = ">=1.17.0,<2.0" [package.extras] speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] @@ -186,13 +172,13 @@ test = ["coverage", "mypy", "pexpect", "ruff", "wheel"] [[package]] name = "async-timeout" -version = "4.0.3" +version = "5.0.1" description = "Timeout context manager for asyncio programs" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, - {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, + {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, + {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, ] [[package]] @@ -671,17 +657,17 @@ flake8 = ">=5.0.0" [[package]] name = "flake8-bugbear" -version = "24.8.19" +version = "24.10.31" description = "A plugin for flake8 finding likely bugs and design problems in your program. Contains warnings that don't belong in pyflakes and pycodestyle." 
optional = false python-versions = ">=3.8.1" files = [ - {file = "flake8_bugbear-24.8.19-py3-none-any.whl", hash = "sha256:25bc3867f7338ee3b3e0916bf8b8a0b743f53a9a5175782ddc4325ed4f386b89"}, - {file = "flake8_bugbear-24.8.19.tar.gz", hash = "sha256:9b77627eceda28c51c27af94560a72b5b2c97c016651bdce45d8f56c180d2d32"}, + {file = "flake8_bugbear-24.10.31-py3-none-any.whl", hash = "sha256:cccf786ccf9b2e1052b1ecfa80fb8f80832d0880425bcbd4cd45d3c8128c2683"}, + {file = "flake8_bugbear-24.10.31.tar.gz", hash = "sha256:435b531c72b27f8eff8d990419697956b9fd25c6463c5ba98b3991591de439db"}, ] [package.dependencies] -attrs = ">=19.2.0" +attrs = ">=22.2.0" flake8 = ">=6.0.0" [package.extras] @@ -734,59 +720,61 @@ files = [ [[package]] name = "fonttools" -version = "4.54.1" +version = "4.55.0" description = "Tools to manipulate font files" optional = false python-versions = ">=3.8" files = [ - {file = "fonttools-4.54.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7ed7ee041ff7b34cc62f07545e55e1468808691dddfd315d51dd82a6b37ddef2"}, - {file = "fonttools-4.54.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:41bb0b250c8132b2fcac148e2e9198e62ff06f3cc472065dff839327945c5882"}, - {file = "fonttools-4.54.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7965af9b67dd546e52afcf2e38641b5be956d68c425bef2158e95af11d229f10"}, - {file = "fonttools-4.54.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:278913a168f90d53378c20c23b80f4e599dca62fbffae4cc620c8eed476b723e"}, - {file = "fonttools-4.54.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:0e88e3018ac809b9662615072dcd6b84dca4c2d991c6d66e1970a112503bba7e"}, - {file = "fonttools-4.54.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:4aa4817f0031206e637d1e685251ac61be64d1adef111060df84fdcbc6ab6c44"}, - {file = "fonttools-4.54.1-cp310-cp310-win32.whl", hash = "sha256:7e3b7d44e18c085fd8c16dcc6f1ad6c61b71ff463636fcb13df7b1b818bd0c02"}, - {file = "fonttools-4.54.1-cp310-cp310-win_amd64.whl", hash = "sha256:dd9cc95b8d6e27d01e1e1f1fae8559ef3c02c76317da650a19047f249acd519d"}, - {file = "fonttools-4.54.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5419771b64248484299fa77689d4f3aeed643ea6630b2ea750eeab219588ba20"}, - {file = "fonttools-4.54.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:301540e89cf4ce89d462eb23a89464fef50915255ece765d10eee8b2bf9d75b2"}, - {file = "fonttools-4.54.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76ae5091547e74e7efecc3cbf8e75200bc92daaeb88e5433c5e3e95ea8ce5aa7"}, - {file = "fonttools-4.54.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82834962b3d7c5ca98cb56001c33cf20eb110ecf442725dc5fdf36d16ed1ab07"}, - {file = "fonttools-4.54.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d26732ae002cc3d2ecab04897bb02ae3f11f06dd7575d1df46acd2f7c012a8d8"}, - {file = "fonttools-4.54.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:58974b4987b2a71ee08ade1e7f47f410c367cdfc5a94fabd599c88165f56213a"}, - {file = "fonttools-4.54.1-cp311-cp311-win32.whl", hash = "sha256:ab774fa225238986218a463f3fe151e04d8c25d7de09df7f0f5fce27b1243dbc"}, - {file = "fonttools-4.54.1-cp311-cp311-win_amd64.whl", hash = "sha256:07e005dc454eee1cc60105d6a29593459a06321c21897f769a281ff2d08939f6"}, - {file = "fonttools-4.54.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:54471032f7cb5fca694b5f1a0aaeba4af6e10ae989df408e0216f7fd6cdc405d"}, - {file = "fonttools-4.54.1-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:8fa92cb248e573daab8d032919623cc309c005086d743afb014c836636166f08"}, - {file = "fonttools-4.54.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a911591200114969befa7f2cb74ac148bce5a91df5645443371aba6d222e263"}, - {file = "fonttools-4.54.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93d458c8a6a354dc8b48fc78d66d2a8a90b941f7fec30e94c7ad9982b1fa6bab"}, - {file = "fonttools-4.54.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5eb2474a7c5be8a5331146758debb2669bf5635c021aee00fd7c353558fc659d"}, - {file = "fonttools-4.54.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c9c563351ddc230725c4bdf7d9e1e92cbe6ae8553942bd1fb2b2ff0884e8b714"}, - {file = "fonttools-4.54.1-cp312-cp312-win32.whl", hash = "sha256:fdb062893fd6d47b527d39346e0c5578b7957dcea6d6a3b6794569370013d9ac"}, - {file = "fonttools-4.54.1-cp312-cp312-win_amd64.whl", hash = "sha256:e4564cf40cebcb53f3dc825e85910bf54835e8a8b6880d59e5159f0f325e637e"}, - {file = "fonttools-4.54.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6e37561751b017cf5c40fce0d90fd9e8274716de327ec4ffb0df957160be3bff"}, - {file = "fonttools-4.54.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:357cacb988a18aace66e5e55fe1247f2ee706e01debc4b1a20d77400354cddeb"}, - {file = "fonttools-4.54.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8e953cc0bddc2beaf3a3c3b5dd9ab7554677da72dfaf46951e193c9653e515a"}, - {file = "fonttools-4.54.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:58d29b9a294573d8319f16f2f79e42428ba9b6480442fa1836e4eb89c4d9d61c"}, - {file = "fonttools-4.54.1-cp313-cp313-win32.whl", hash = "sha256:9ef1b167e22709b46bf8168368b7b5d3efeaaa746c6d39661c1b4405b6352e58"}, - {file = "fonttools-4.54.1-cp313-cp313-win_amd64.whl", hash = "sha256:262705b1663f18c04250bd1242b0515d3bbae177bee7752be67c979b7d47f43d"}, - {file = "fonttools-4.54.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ed2f80ca07025551636c555dec2b755dd005e2ea8fbeb99fc5cdff319b70b23b"}, - {file = "fonttools-4.54.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9dc080e5a1c3b2656caff2ac2633d009b3a9ff7b5e93d0452f40cd76d3da3b3c"}, - {file = "fonttools-4.54.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d152d1be65652fc65e695e5619e0aa0982295a95a9b29b52b85775243c06556"}, - {file = "fonttools-4.54.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8583e563df41fdecef31b793b4dd3af8a9caa03397be648945ad32717a92885b"}, - {file = "fonttools-4.54.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:0d1d353ef198c422515a3e974a1e8d5b304cd54a4c2eebcae708e37cd9eeffb1"}, - {file = "fonttools-4.54.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:fda582236fee135d4daeca056c8c88ec5f6f6d88a004a79b84a02547c8f57386"}, - {file = "fonttools-4.54.1-cp38-cp38-win32.whl", hash = "sha256:e7d82b9e56716ed32574ee106cabca80992e6bbdcf25a88d97d21f73a0aae664"}, - {file = "fonttools-4.54.1-cp38-cp38-win_amd64.whl", hash = "sha256:ada215fd079e23e060157aab12eba0d66704316547f334eee9ff26f8c0d7b8ab"}, - {file = "fonttools-4.54.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f5b8a096e649768c2f4233f947cf9737f8dbf8728b90e2771e2497c6e3d21d13"}, - {file = "fonttools-4.54.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4e10d2e0a12e18f4e2dd031e1bf7c3d7017be5c8dbe524d07706179f355c5dac"}, - {file = 
"fonttools-4.54.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:31c32d7d4b0958600eac75eaf524b7b7cb68d3a8c196635252b7a2c30d80e986"}, - {file = "fonttools-4.54.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c39287f5c8f4a0c5a55daf9eaf9ccd223ea59eed3f6d467133cc727d7b943a55"}, - {file = "fonttools-4.54.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:a7a310c6e0471602fe3bf8efaf193d396ea561486aeaa7adc1f132e02d30c4b9"}, - {file = "fonttools-4.54.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d3b659d1029946f4ff9b6183984578041b520ce0f8fb7078bb37ec7445806b33"}, - {file = "fonttools-4.54.1-cp39-cp39-win32.whl", hash = "sha256:e96bc94c8cda58f577277d4a71f51c8e2129b8b36fd05adece6320dd3d57de8a"}, - {file = "fonttools-4.54.1-cp39-cp39-win_amd64.whl", hash = "sha256:e8a4b261c1ef91e7188a30571be6ad98d1c6d9fa2427244c545e2fa0a2494dd7"}, - {file = "fonttools-4.54.1-py3-none-any.whl", hash = "sha256:37cddd62d83dc4f72f7c3f3c2bcf2697e89a30efb152079896544a93907733bd"}, - {file = "fonttools-4.54.1.tar.gz", hash = "sha256:957f669d4922f92c171ba01bef7f29410668db09f6c02111e22b2bce446f3285"}, + {file = "fonttools-4.55.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:51c029d4c0608a21a3d3d169dfc3fb776fde38f00b35ca11fdab63ba10a16f61"}, + {file = "fonttools-4.55.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bca35b4e411362feab28e576ea10f11268b1aeed883b9f22ed05675b1e06ac69"}, + {file = "fonttools-4.55.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9ce4ba6981e10f7e0ccff6348e9775ce25ffadbee70c9fd1a3737e3e9f5fa74f"}, + {file = "fonttools-4.55.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31d00f9852a6051dac23294a4cf2df80ced85d1d173a61ba90a3d8f5abc63c60"}, + {file = "fonttools-4.55.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e198e494ca6e11f254bac37a680473a311a88cd40e58f9cc4dc4911dfb686ec6"}, + {file = "fonttools-4.55.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7208856f61770895e79732e1dcbe49d77bd5783adf73ae35f87fcc267df9db81"}, + {file = "fonttools-4.55.0-cp310-cp310-win32.whl", hash = "sha256:e7e6a352ff9e46e8ef8a3b1fe2c4478f8a553e1b5a479f2e899f9dc5f2055880"}, + {file = "fonttools-4.55.0-cp310-cp310-win_amd64.whl", hash = "sha256:636caaeefe586d7c84b5ee0734c1a5ab2dae619dc21c5cf336f304ddb8f6001b"}, + {file = "fonttools-4.55.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:fa34aa175c91477485c44ddfbb51827d470011e558dfd5c7309eb31bef19ec51"}, + {file = "fonttools-4.55.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:37dbb3fdc2ef7302d3199fb12468481cbebaee849e4b04bc55b77c24e3c49189"}, + {file = "fonttools-4.55.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5263d8e7ef3c0ae87fbce7f3ec2f546dc898d44a337e95695af2cd5ea21a967"}, + {file = "fonttools-4.55.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f307f6b5bf9e86891213b293e538d292cd1677e06d9faaa4bf9c086ad5f132f6"}, + {file = "fonttools-4.55.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:f0a4b52238e7b54f998d6a56b46a2c56b59c74d4f8a6747fb9d4042190f37cd3"}, + {file = "fonttools-4.55.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3e569711464f777a5d4ef522e781dc33f8095ab5efd7548958b36079a9f2f88c"}, + {file = "fonttools-4.55.0-cp311-cp311-win32.whl", hash = "sha256:2b3ab90ec0f7b76c983950ac601b58949f47aca14c3f21eed858b38d7ec42b05"}, + {file = "fonttools-4.55.0-cp311-cp311-win_amd64.whl", hash = 
"sha256:aa046f6a63bb2ad521004b2769095d4c9480c02c1efa7d7796b37826508980b6"}, + {file = "fonttools-4.55.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:838d2d8870f84fc785528a692e724f2379d5abd3fc9dad4d32f91cf99b41e4a7"}, + {file = "fonttools-4.55.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f46b863d74bab7bb0d395f3b68d3f52a03444964e67ce5c43ce43a75efce9246"}, + {file = "fonttools-4.55.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33b52a9cfe4e658e21b1f669f7309b4067910321757fec53802ca8f6eae96a5a"}, + {file = "fonttools-4.55.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:732a9a63d6ea4a81b1b25a1f2e5e143761b40c2e1b79bb2b68e4893f45139a40"}, + {file = "fonttools-4.55.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7dd91ac3fcb4c491bb4763b820bcab6c41c784111c24172616f02f4bc227c17d"}, + {file = "fonttools-4.55.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1f0e115281a32ff532118aa851ef497a1b7cda617f4621c1cdf81ace3e36fb0c"}, + {file = "fonttools-4.55.0-cp312-cp312-win32.whl", hash = "sha256:6c99b5205844f48a05cb58d4a8110a44d3038c67ed1d79eb733c4953c628b0f6"}, + {file = "fonttools-4.55.0-cp312-cp312-win_amd64.whl", hash = "sha256:f8c8c76037d05652510ae45be1cd8fb5dd2fd9afec92a25374ac82255993d57c"}, + {file = "fonttools-4.55.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8118dc571921dc9e4b288d9cb423ceaf886d195a2e5329cc427df82bba872cd9"}, + {file = "fonttools-4.55.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:01124f2ca6c29fad4132d930da69158d3f49b2350e4a779e1efbe0e82bd63f6c"}, + {file = "fonttools-4.55.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81ffd58d2691f11f7c8438796e9f21c374828805d33e83ff4b76e4635633674c"}, + {file = "fonttools-4.55.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5435e5f1eb893c35c2bc2b9cd3c9596b0fcb0a59e7a14121562986dd4c47b8dd"}, + {file = "fonttools-4.55.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d12081729280c39d001edd0f4f06d696014c26e6e9a0a55488fabc37c28945e4"}, + {file = "fonttools-4.55.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a7ad1f1b98ab6cb927ab924a38a8649f1ffd7525c75fe5b594f5dab17af70e18"}, + {file = "fonttools-4.55.0-cp313-cp313-win32.whl", hash = "sha256:abe62987c37630dca69a104266277216de1023cf570c1643bb3a19a9509e7a1b"}, + {file = "fonttools-4.55.0-cp313-cp313-win_amd64.whl", hash = "sha256:2863555ba90b573e4201feaf87a7e71ca3b97c05aa4d63548a4b69ea16c9e998"}, + {file = "fonttools-4.55.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:00f7cf55ad58a57ba421b6a40945b85ac7cc73094fb4949c41171d3619a3a47e"}, + {file = "fonttools-4.55.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f27526042efd6f67bfb0cc2f1610fa20364396f8b1fc5edb9f45bb815fb090b2"}, + {file = "fonttools-4.55.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8e67974326af6a8879dc2a4ec63ab2910a1c1a9680ccd63e4a690950fceddbe"}, + {file = "fonttools-4.55.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61dc0a13451143c5e987dec5254d9d428f3c2789a549a7cf4f815b63b310c1cc"}, + {file = "fonttools-4.55.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:b2e526b325a903868c62155a6a7e24df53f6ce4c5c3160214d8fe1be2c41b478"}, + {file = "fonttools-4.55.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:b7ef9068a1297714e6fefe5932c33b058aa1d45a2b8be32a4c6dee602ae22b5c"}, + {file = 
"fonttools-4.55.0-cp38-cp38-win32.whl", hash = "sha256:55718e8071be35dff098976bc249fc243b58efa263768c611be17fe55975d40a"}, + {file = "fonttools-4.55.0-cp38-cp38-win_amd64.whl", hash = "sha256:553bd4f8cc327f310c20158e345e8174c8eed49937fb047a8bda51daf2c353c8"}, + {file = "fonttools-4.55.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:3f901cef813f7c318b77d1c5c14cf7403bae5cb977cede023e22ba4316f0a8f6"}, + {file = "fonttools-4.55.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8c9679fc0dd7e8a5351d321d8d29a498255e69387590a86b596a45659a39eb0d"}, + {file = "fonttools-4.55.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd2820a8b632f3307ebb0bf57948511c2208e34a4939cf978333bc0a3f11f838"}, + {file = "fonttools-4.55.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23bbbb49bec613a32ed1b43df0f2b172313cee690c2509f1af8fdedcf0a17438"}, + {file = "fonttools-4.55.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:a656652e1f5d55b9728937a7e7d509b73d23109cddd4e89ee4f49bde03b736c6"}, + {file = "fonttools-4.55.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:f50a1f455902208486fbca47ce33054208a4e437b38da49d6721ce2fef732fcf"}, + {file = "fonttools-4.55.0-cp39-cp39-win32.whl", hash = "sha256:161d1ac54c73d82a3cded44202d0218ab007fde8cf194a23d3dd83f7177a2f03"}, + {file = "fonttools-4.55.0-cp39-cp39-win_amd64.whl", hash = "sha256:ca7fd6987c68414fece41c96836e945e1f320cda56fc96ffdc16e54a44ec57a2"}, + {file = "fonttools-4.55.0-py3-none-any.whl", hash = "sha256:12db5888cd4dd3fcc9f0ee60c6edd3c7e1fd44b7dd0f31381ea03df68f8a153f"}, + {file = "fonttools-4.55.0.tar.gz", hash = "sha256:7636acc6ab733572d5e7eec922b254ead611f1cdad17be3f0be7418e8bfaca71"}, ] [package.extras] @@ -959,13 +947,13 @@ files = [ [[package]] name = "identify" -version = "2.6.1" +version = "2.6.2" description = "File identification library for Python" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "identify-2.6.1-py2.py3-none-any.whl", hash = "sha256:53863bcac7caf8d2ed85bd20312ea5dcfc22226800f6d6881f232d861db5a8f0"}, - {file = "identify-2.6.1.tar.gz", hash = "sha256:91478c5fb7c3aac5ff7bf9b4344f803843dc586832d5f110d672b19aa1984c98"}, + {file = "identify-2.6.2-py2.py3-none-any.whl", hash = "sha256:c097384259f49e372f4ea00a19719d95ae27dd5ff0fd77ad630aa891306b82f3"}, + {file = "identify-2.6.2.tar.gz", hash = "sha256:fab5c716c24d7a789775228823797296a2994b075fb6080ac83a102772a98cbd"}, ] [package.extras] @@ -1254,13 +1242,13 @@ files = [ [[package]] name = "lightning-utilities" -version = "0.11.8" +version = "0.11.9" description = "Lightning toolbox for across the our ecosystem." 
optional = false python-versions = ">=3.8" files = [ - {file = "lightning_utilities-0.11.8-py3-none-any.whl", hash = "sha256:a57edb34a44258f0c61eed8b8b88926766e9052f5e60bbe69e4871a2b2bfd970"}, - {file = "lightning_utilities-0.11.8.tar.gz", hash = "sha256:8dfbdc6c52f9847efc948dc462ab8bebb4f4e9a43bd69c82c1b1da484dac20e6"}, + {file = "lightning_utilities-0.11.9-py3-none-any.whl", hash = "sha256:ac6d4e9e28faf3ff4be997876750fee10dc604753dbc429bf3848a95c5d7e0d2"}, + {file = "lightning_utilities-0.11.9.tar.gz", hash = "sha256:f5052b81344cc2684aa9afd74b7ce8819a8f49a858184ec04548a5a109dfd053"}, ] [package.dependencies] @@ -1772,50 +1760,46 @@ files = [ [[package]] name = "nvidia-cublas-cu12" -version = "12.4.5.8" +version = "12.1.3.1" description = "CUBLAS native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"}, - {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"}, - {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-win_amd64.whl", hash = "sha256:5a796786da89203a0657eda402bcdcec6180254a8ac22d72213abc42069522dc"}, + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, ] [[package]] name = "nvidia-cuda-cupti-cu12" -version = "12.4.127" +version = "12.1.105" description = "CUDA profiling tools runtime libs." optional = false python-versions = ">=3" files = [ - {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"}, - {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"}, - {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"}, + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, ] [[package]] name = "nvidia-cuda-nvrtc-cu12" -version = "12.4.127" +version = "12.1.105" description = "NVRTC native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198"}, - {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338"}, - {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:a961b2f1d5f17b14867c619ceb99ef6fcec12e46612711bcec78eb05068a60ec"}, + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, ] 
[[package]] name = "nvidia-cuda-runtime-cu12" -version = "12.4.127" +version = "12.1.105" description = "CUDA Runtime native Libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"}, - {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"}, - {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:09c2e35f48359752dfa822c09918211844a3d93c100a715d79b59591130c5e1e"}, + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, ] [[package]] @@ -1834,41 +1818,35 @@ nvidia-cublas-cu12 = "*" [[package]] name = "nvidia-cufft-cu12" -version = "11.2.1.3" +version = "11.0.2.54" description = "CUFFT native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399"}, - {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9"}, - {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-win_amd64.whl", hash = "sha256:d802f4954291101186078ccbe22fc285a902136f974d369540fd4a5333d1440b"}, + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, ] -[package.dependencies] -nvidia-nvjitlink-cu12 = "*" - [[package]] name = "nvidia-curand-cu12" -version = "10.3.5.147" +version = "10.3.2.106" description = "CURAND native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9"}, - {file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b"}, - {file = "nvidia_curand_cu12-10.3.5.147-py3-none-win_amd64.whl", hash = "sha256:f307cc191f96efe9e8f05a87096abc20d08845a841889ef78cb06924437f6771"}, + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, ] [[package]] name = "nvidia-cusolver-cu12" -version = "11.6.1.9" +version = "11.4.5.107" description = "CUDA solver native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e"}, - {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260"}, - {file = 
"nvidia_cusolver_cu12-11.6.1.9-py3-none-win_amd64.whl", hash = "sha256:e77314c9d7b694fcebc84f58989f3aa4fb4cb442f12ca1a9bde50f5e8f6d1b9c"}, + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, ] [package.dependencies] @@ -1878,14 +1856,13 @@ nvidia-nvjitlink-cu12 = "*" [[package]] name = "nvidia-cusparse-cu12" -version = "12.3.1.170" +version = "12.1.0.106" description = "CUSPARSE native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3"}, - {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1"}, - {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-win_amd64.whl", hash = "sha256:9bc90fb087bc7b4c15641521f31c0371e9a612fc2ba12c338d3ae032e6b6797f"}, + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, ] [package.dependencies] @@ -1893,47 +1870,47 @@ nvidia-nvjitlink-cu12 = "*" [[package]] name = "nvidia-nccl-cu12" -version = "2.21.5" +version = "2.20.5" description = "NVIDIA Collective Communication Library (NCCL) Runtime" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"}, + {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"}, + {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56"}, ] [[package]] name = "nvidia-nvjitlink-cu12" -version = "12.4.127" +version = "12.6.85" description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"}, - {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"}, - {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"}, + {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:eedc36df9e88b682efe4309aa16b5b4e78c2407eac59e8c10a6a47535164369a"}, + {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41"}, + {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:e61120e52ed675747825cdd16febc6a0730537451d867ee58bee3853b1b13d1c"}, ] [[package]] name = "nvidia-nvtx-cu12" -version = "12.4.127" +version = "12.1.105" description = "NVIDIA Tools Extension" optional = false 
python-versions = ">=3" files = [ - {file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3"}, - {file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a"}, - {file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"}, + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, ] [[package]] name = "packaging" -version = "24.1" +version = "24.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, - {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, + {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, + {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, ] [[package]] @@ -2050,18 +2027,17 @@ files = [ [[package]] name = "patsy" -version = "0.5.6" +version = "1.0.1" description = "A Python package for describing statistical models and for building design matrices." optional = false -python-versions = "*" +python-versions = ">=3.6" files = [ - {file = "patsy-0.5.6-py2.py3-none-any.whl", hash = "sha256:19056886fd8fa71863fa32f0eb090267f21fb74be00f19f5c70b2e9d76c883c6"}, - {file = "patsy-0.5.6.tar.gz", hash = "sha256:95c6d47a7222535f84bff7f63d7303f2e297747a598db89cf5c67f0c0c7d2cdb"}, + {file = "patsy-1.0.1-py2.py3-none-any.whl", hash = "sha256:751fb38f9e97e62312e921a1954b81e1bb2bcda4f5eeabaf94db251ee791509c"}, + {file = "patsy-1.0.1.tar.gz", hash = "sha256:e786a9391eec818c054e359b737bbce692f051aee4c661f4141cc88fb459c0c4"}, ] [package.dependencies] numpy = ">=1.4" -six = "*" [package.extras] test = ["pytest", "pytest-cov", "scipy"] @@ -2077,6 +2053,20 @@ files = [ {file = "pbr-6.1.0.tar.gz", hash = "sha256:788183e382e3d1d7707db08978239965e8b9e4e5ed42669bf4758186734d5f24"}, ] +[[package]] +name = "pep8-naming" +version = "0.14.1" +description = "Check PEP-8 naming conventions, plugin for flake8" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pep8-naming-0.14.1.tar.gz", hash = "sha256:1ef228ae80875557eb6c1549deafed4dabbf3261cfcafa12f773fe0db9be8a36"}, + {file = "pep8_naming-0.14.1-py3-none-any.whl", hash = "sha256:63f514fc777d715f935faf185dedd679ab99526a7f2f503abb61587877f7b1c5"}, +] + +[package.dependencies] +flake8 = ">=5.0.0" + [[package]] name = "pillow" version = "11.0.0" @@ -2401,6 +2391,87 @@ files = [ {file = "protobuf-5.28.3.tar.gz", hash = "sha256:64badbc49180a5e401f373f9ce7ab1d18b63f7dd4a9cdc43c92b9f0b481cef7b"}, ] +[[package]] +name = "psutil" +version = "6.1.0" +description = "Cross-platform lib for process and system monitoring in Python." 
+optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +files = [ + {file = "psutil-6.1.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:ff34df86226c0227c52f38b919213157588a678d049688eded74c76c8ba4a5d0"}, + {file = "psutil-6.1.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:c0e0c00aa18ca2d3b2b991643b799a15fc8f0563d2ebb6040f64ce8dc027b942"}, + {file = "psutil-6.1.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:000d1d1ebd634b4efb383f4034437384e44a6d455260aaee2eca1e9c1b55f047"}, + {file = "psutil-6.1.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:5cd2bcdc75b452ba2e10f0e8ecc0b57b827dd5d7aaffbc6821b2a9a242823a76"}, + {file = "psutil-6.1.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:045f00a43c737f960d273a83973b2511430d61f283a44c96bf13a6e829ba8fdc"}, + {file = "psutil-6.1.0-cp27-none-win32.whl", hash = "sha256:9118f27452b70bb1d9ab3198c1f626c2499384935aaf55388211ad982611407e"}, + {file = "psutil-6.1.0-cp27-none-win_amd64.whl", hash = "sha256:a8506f6119cff7015678e2bce904a4da21025cc70ad283a53b099e7620061d85"}, + {file = "psutil-6.1.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:6e2dcd475ce8b80522e51d923d10c7871e45f20918e027ab682f94f1c6351688"}, + {file = "psutil-6.1.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:0895b8414afafc526712c498bd9de2b063deaac4021a3b3c34566283464aff8e"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9dcbfce5d89f1d1f2546a2090f4fcf87c7f669d1d90aacb7d7582addece9fb38"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:498c6979f9c6637ebc3a73b3f87f9eb1ec24e1ce53a7c5173b8508981614a90b"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d905186d647b16755a800e7263d43df08b790d709d575105d419f8b6ef65423a"}, + {file = "psutil-6.1.0-cp36-cp36m-win32.whl", hash = "sha256:6d3fbbc8d23fcdcb500d2c9f94e07b1342df8ed71b948a2649b5cb060a7c94ca"}, + {file = "psutil-6.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:1209036fbd0421afde505a4879dee3b2fd7b1e14fee81c0069807adcbbcca747"}, + {file = "psutil-6.1.0-cp37-abi3-win32.whl", hash = "sha256:1ad45a1f5d0b608253b11508f80940985d1d0c8f6111b5cb637533a0e6ddc13e"}, + {file = "psutil-6.1.0-cp37-abi3-win_amd64.whl", hash = "sha256:a8fb3752b491d246034fa4d279ff076501588ce8cbcdbb62c32fd7a377d996be"}, + {file = "psutil-6.1.0.tar.gz", hash = "sha256:353815f59a7f64cdaca1c0307ee13558a0512f6db064e92fe833784f08539c7a"}, +] + +[package.extras] +dev = ["black", "check-manifest", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pytest-cov", "requests", "rstcheck", "ruff", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "virtualenv", "wheel"] +test = ["pytest", "pytest-xdist", "setuptools"] + +[[package]] +name = "pyarrow" +version = "17.0.0" +description = "Python library for Apache Arrow" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"}, + {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da1e060b3876faa11cee287839f9cc7cdc00649f475714b8680a05fd9071d545"}, + {file = 
"pyarrow-17.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75c06d4624c0ad6674364bb46ef38c3132768139ddec1c56582dbac54f2663e2"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:fa3c246cc58cb5a4a5cb407a18f193354ea47dd0648194e6265bd24177982fe8"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:f7ae2de664e0b158d1607699a16a488de3d008ba99b3a7aa5de1cbc13574d047"}, + {file = "pyarrow-17.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5984f416552eea15fd9cee03da53542bf4cddaef5afecefb9aa8d1010c335087"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:1c8856e2ef09eb87ecf937104aacfa0708f22dfeb039c363ec99735190ffb977"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e19f569567efcbbd42084e87f948778eb371d308e137a0f97afe19bb860ccb3"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b244dc8e08a23b3e352899a006a26ae7b4d0da7bb636872fa8f5884e70acf15"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b72e87fe3e1db343995562f7fff8aee354b55ee83d13afba65400c178ab2597"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dc5c31c37409dfbc5d014047817cb4ccd8c1ea25d19576acf1a001fe07f5b420"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e3343cb1e88bc2ea605986d4b94948716edc7a8d14afd4e2c097232f729758b4"}, + {file = "pyarrow-17.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:a27532c38f3de9eb3e90ecab63dfda948a8ca859a66e3a47f5f42d1e403c4d03"}, + {file = "pyarrow-17.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9b8a823cea605221e61f34859dcc03207e52e409ccf6354634143e23af7c8d22"}, + {file = "pyarrow-17.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f1e70de6cb5790a50b01d2b686d54aaf73da01266850b05e3af2a1bc89e16053"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0071ce35788c6f9077ff9ecba4858108eebe2ea5a3f7cf2cf55ebc1dbc6ee24a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:757074882f844411fcca735e39aae74248a1531367a7c80799b4266390ae51cc"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ba11c4f16976e89146781a83833df7f82077cdab7dc6232c897789343f7891a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b0c6ac301093b42d34410b187bba560b17c0330f64907bfa4f7f7f2444b0cf9b"}, + {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"}, + {file = 
"pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"}, + {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"}, + {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"}, + {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"}, +] + +[package.dependencies] +numpy = ">=1.16.6" + +[package.extras] +test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] + [[package]] name = "pyarrow" version = "18.0.0" @@ -2677,31 +2748,31 @@ files = [ [[package]] name = "ray" -version = "2.38.0" +version = "2.39.0" description = "Ray provides a simple, universal API for building distributed applications." 
optional = false python-versions = ">=3.9" files = [ - {file = "ray-2.38.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:fe01fce188ddea96ca5c7dfa4a783d2e5d80662318a640fae58d89e6eaf2cd7f"}, - {file = "ray-2.38.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1fa0833cc54ca0c48aebc98b813fa1e990a20c8ee1da857073e11eb72696d316"}, - {file = "ray-2.38.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:3cdd71617f935a0d741864e94061093d14fad659e67271c9a779108878294ac3"}, - {file = "ray-2.38.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:3a7d6f7159bce4117bfe8f9c3d0b65ff27257fe2dd8d737dad0f3869666440da"}, - {file = "ray-2.38.0-cp310-cp310-win_amd64.whl", hash = "sha256:b56c78ebdd7535ab6e8566e66c1f1c65a694432875dd683b1310e3d7b9af79f3"}, - {file = "ray-2.38.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:cce1a39fa91fe08b15d2d62d084052968a155c8528415f248346567aa589580c"}, - {file = "ray-2.38.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:454f576b3dbef2231693e3081ba5bf093add610c72ebf3c17788943f6653fe68"}, - {file = "ray-2.38.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:282a326d2848d411c3ce305e57e2de8357e24cb9becbec7e507e8800572c487e"}, - {file = "ray-2.38.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:ece802cf3a1c102b53f63b8bc90d947971c4b387deaf233c224ed8ef34a1f3cb"}, - {file = "ray-2.38.0-cp311-cp311-win_amd64.whl", hash = "sha256:64f7cd908177dd50089469cf331afbeb22e61e26d0a4be210ad20dccddbf6efb"}, - {file = "ray-2.38.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:10174ac63406b95a0a795a89396aeb8966286f15558087127719b13c367b40e3"}, - {file = "ray-2.38.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ea4148e929c17543378ba8909398fc81ce09d8e2257fc21afa62fc88ba4babc2"}, - {file = "ray-2.38.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:d4efaf1cfc727d60d78cc7112ff8eaa67634a5327e2a84f8dcaab5d167fe7fec"}, - {file = "ray-2.38.0-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:07507d2f9961e8d5390c0eb606df249216ef5afb1ff8185581f3e92041d66293"}, - {file = "ray-2.38.0-cp312-cp312-win_amd64.whl", hash = "sha256:6fdef893cbe617ac9d079e65702e9f1b3f455835f05b6f8b46467cf2184a52dc"}, - {file = "ray-2.38.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:0910eb721943f9825d10ae16d9cd3c7de70f4dde985207e18fddf59c0126770f"}, - {file = "ray-2.38.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3d0bd0d7a116ab79864ca8bf3222758ad85cc9f9421a51136ca33429e8e87ed9"}, - {file = "ray-2.38.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:cdfd910da985bc3c985945b7bbbef5f891473eddd06af9208b8af0d020e3a9a7"}, - {file = "ray-2.38.0-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:e18ac9e23da17393b4447ef2924e11ef95bb8a5d5b561ca8c74c05f2a594a6fe"}, - {file = "ray-2.38.0-cp39-cp39-win_amd64.whl", hash = "sha256:1f0d014f215b25f92041d4a2acfbc4e44abb2a92f43971228f493ba7874ede00"}, + {file = "ray-2.39.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:13d62cead910f433817ca5b41eda75d9c24e81a6b727e0d4e9c5817da86eca5b"}, + {file = "ray-2.39.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74219fade4acaf722d34a2630008220a2a5b2ba856e874cd5a8c24ab2f2b2412"}, + {file = "ray-2.39.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:54ed235b4542ad6d0e317988dc4feaf46af99902f3dfd2097600e0294751bf88"}, + {file = "ray-2.39.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:6298fb981cd0fa8607f1917deb27925ab8add48c60ba5bd0f6cf40d4cc5dace4"}, + {file = "ray-2.39.0-cp310-cp310-win_amd64.whl", hash = 
"sha256:c9d1a26fa3c4d32555c483fab57f54c4ba017f7552732fe9841396aaa24ee6ea"}, + {file = "ray-2.39.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:5547f2e6cf3b5d5aaea8aabea2d223a65c9566db198349c0aac668f454710f1a"}, + {file = "ray-2.39.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7f8a83c2b7719386b3f8d6e3120aae49d9aa4cf49050acaee059b45df92eb281"}, + {file = "ray-2.39.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:413488eb2f8bfced8ecc269b120321f33106cbe412a69c3e23ce20c6d5b6f702"}, + {file = "ray-2.39.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:21aee127ae1a9cf6193001ab41d2551bcc81331ba3b7196d000f16d10f15c705"}, + {file = "ray-2.39.0-cp311-cp311-win_amd64.whl", hash = "sha256:fdcb7ad51883d194f7b49f23533d29b3c96d78034f829b6cde1e24b6783dff9d"}, + {file = "ray-2.39.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:77fbcf0002cfbb673b2832e273ee8a834358a2a2bff77e2ff5c97924fcd2b389"}, + {file = "ray-2.39.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a10cfca3a2f05d285ba1ab3cdd3ce43ec2934b05eb91516a9766bcfc4c070425"}, + {file = "ray-2.39.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:f8d01550f718a65e0be48da578fa2a3f2e1be85a5453d4b98c3576e1cfaab01b"}, + {file = "ray-2.39.0-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:016930e6ba74b91b40117a64b24f7bfff48a6a780f23d2b064a7a3f43bc4e1a2"}, + {file = "ray-2.39.0-cp312-cp312-win_amd64.whl", hash = "sha256:4893cc7fd8b3c48c68c3d90bc5fe2023ee2732f91e9664ee79e8272b18ddb170"}, + {file = "ray-2.39.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:f8291c8b82146b34d5e3989ca9a521a15258aa90b874b0db2fa18592c2e31155"}, + {file = "ray-2.39.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:078a309450be28e4563eda473d726c04eb85826f13c9c846b71fbd01e28367ed"}, + {file = "ray-2.39.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e4917adfaa831dfde2745311d50b4cd22d2d8b7b61219e77331b56724d5755d4"}, + {file = "ray-2.39.0-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:4ed775b2630495ce2a6b086d45b94402a33a23ea3f86c344eeb621d617693b41"}, + {file = "ray-2.39.0-cp39-cp39-win_amd64.whl", hash = "sha256:7b1a4db0a23a3aa5ad49076a04b66e88b7b28263b038d70619301db1c23c2dbf"}, ] [package.dependencies] @@ -2715,26 +2786,29 @@ msgpack = ">=1.0.0,<2.0.0" packaging = "*" pandas = {version = "*", optional = true, markers = "extra == \"tune\""} protobuf = ">=3.15.3,<3.19.5 || >3.19.5" -pyarrow = {version = ">=6.0.1", optional = true, markers = "extra == \"tune\""} +pyarrow = [ + {version = ">=6.0.1,<18", optional = true, markers = "sys_platform == \"darwin\" and platform_machine == \"x86_64\" and extra == \"tune\""}, + {version = ">=6.0.1", optional = true, markers = "sys_platform != \"darwin\" and extra == \"tune\" or platform_machine != \"x86_64\" and extra == \"tune\""}, +] pyyaml = "*" requests = "*" tensorboardX = {version = ">=1.9", optional = true, markers = "extra == \"tune\""} [package.extras] adag = ["cupy-cuda12x"] -air = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "fastapi", "fsspec", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "memray", "numpy (>=1.20)", "opencensus", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "requests", "smart-open", "starlette", "tensorboardX (>=1.9)", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] -all = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "cupy-cuda12x", "dm-tree", "fastapi", "fsspec", "grpcio (!=1.56.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "gymnasium 
(==0.28.1)", "lz4", "memray", "numpy (>=1.20)", "opencensus", "opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyOpenSSL", "pyarrow (>=6.0.1)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "pyyaml", "requests", "rich", "scikit-image", "scipy", "smart-open", "starlette", "tensorboardX (>=1.9)", "typer", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] -all-cpp = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "cupy-cuda12x", "dm-tree", "fastapi", "fsspec", "grpcio (!=1.56.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "gymnasium (==0.28.1)", "lz4", "memray", "numpy (>=1.20)", "opencensus", "opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyOpenSSL", "pyarrow (>=6.0.1)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "pyyaml", "ray-cpp (==2.38.0)", "requests", "rich", "scikit-image", "scipy", "smart-open", "starlette", "tensorboardX (>=1.9)", "typer", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] +air = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "fastapi", "fsspec", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "memray", "numpy (>=1.20)", "opencensus", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (<18)", "pyarrow (>=6.0.1)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "requests", "smart-open", "starlette", "tensorboardX (>=1.9)", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] +all = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "cupy-cuda12x", "dm-tree", "fastapi", "fsspec", "grpcio (!=1.56.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "gymnasium (==1.0.0)", "lz4", "memray", "numpy (>=1.20)", "opencensus", "opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyOpenSSL", "pyarrow (<18)", "pyarrow (>=6.0.1)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "pyyaml", "requests", "rich", "scikit-image", "scipy", "smart-open", "starlette", "tensorboardX (>=1.9)", "typer", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] +all-cpp = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "cupy-cuda12x", "dm-tree", "fastapi", "fsspec", "grpcio (!=1.56.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "gymnasium (==1.0.0)", "lz4", "memray", "numpy (>=1.20)", "opencensus", "opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyOpenSSL", "pyarrow (<18)", "pyarrow (>=6.0.1)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "pyyaml", "ray-cpp (==2.39.0)", "requests", "rich", "scikit-image", "scipy", "smart-open", "starlette", "tensorboardX (>=1.9)", "typer", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] client = ["grpcio (!=1.56.0)"] -cpp = ["ray-cpp (==2.38.0)"] -data = ["fsspec", "numpy (>=1.20)", "pandas (>=1.3)", "pyarrow (>=6.0.1)"] +cpp = ["ray-cpp (==2.39.0)"] +data = ["fsspec", "numpy (>=1.20)", "pandas (>=1.3)", "pyarrow (<18)", "pyarrow (>=6.0.1)"] default = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "memray", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "requests", "smart-open", "virtualenv (>=20.0.24,!=20.21.1)"] observability = ["opentelemetry-api", 
"opentelemetry-exporter-otlp", "opentelemetry-sdk"] -rllib = ["dm-tree", "fsspec", "gymnasium (==0.28.1)", "lz4", "pandas", "pyarrow (>=6.0.1)", "pyyaml", "requests", "rich", "scikit-image", "scipy", "tensorboardX (>=1.9)", "typer"] +rllib = ["dm-tree", "fsspec", "gymnasium (==1.0.0)", "lz4", "pandas", "pyarrow (<18)", "pyarrow (>=6.0.1)", "pyyaml", "requests", "rich", "scikit-image", "scipy", "tensorboardX (>=1.9)", "typer"] serve = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "fastapi", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "memray", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "requests", "smart-open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] serve-grpc = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "fastapi", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "memray", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyOpenSSL", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "requests", "smart-open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] -train = ["fsspec", "pandas", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"] -tune = ["fsspec", "pandas", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"] +train = ["fsspec", "pandas", "pyarrow (<18)", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"] +tune = ["fsspec", "pandas", "pyarrow (<18)", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"] [[package]] name = "referencing" @@ -2787,13 +2861,13 @@ docutils = ">=0.11,<1.0" [[package]] name = "rich" -version = "13.9.3" +version = "13.9.4" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false python-versions = ">=3.8.0" files = [ - {file = "rich-13.9.3-py3-none-any.whl", hash = "sha256:9836f5096eb2172c9e77df411c1b009bace4193d6a481d534fea75ebba758283"}, - {file = "rich-13.9.3.tar.gz", hash = "sha256:bc1e01b899537598cf02579d2b9f4a415104d3fc439313a7a2c165d76557a08e"}, + {file = "rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90"}, + {file = "rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098"}, ] [package.dependencies] @@ -2806,114 +2880,101 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] [[package]] name = "rpds-py" -version = "0.20.0" +version = "0.21.0" description = "Python bindings to Rust's persistent data structures (rpds)" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "rpds_py-0.20.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3ad0fda1635f8439cde85c700f964b23ed5fc2d28016b32b9ee5fe30da5c84e2"}, - {file = "rpds_py-0.20.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9bb4a0d90fdb03437c109a17eade42dfbf6190408f29b2744114d11586611d6f"}, - {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6377e647bbfd0a0b159fe557f2c6c602c159fc752fa316572f012fc0bf67150"}, - {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb851b7df9dda52dc1415ebee12362047ce771fc36914586b2e9fcbd7d293b3e"}, - {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e0f80b739e5a8f54837be5d5c924483996b603d5502bfff79bf33da06164ee2"}, - {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:5a8c94dad2e45324fc74dce25e1645d4d14df9a4e54a30fa0ae8bad9a63928e3"}, - {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8e604fe73ba048c06085beaf51147eaec7df856824bfe7b98657cf436623daf"}, - {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:df3de6b7726b52966edf29663e57306b23ef775faf0ac01a3e9f4012a24a4140"}, - {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:cf258ede5bc22a45c8e726b29835b9303c285ab46fc7c3a4cc770736b5304c9f"}, - {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:55fea87029cded5df854ca7e192ec7bdb7ecd1d9a3f63d5c4eb09148acf4a7ce"}, - {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ae94bd0b2f02c28e199e9bc51485d0c5601f58780636185660f86bf80c89af94"}, - {file = "rpds_py-0.20.0-cp310-none-win32.whl", hash = "sha256:28527c685f237c05445efec62426d285e47a58fb05ba0090a4340b73ecda6dee"}, - {file = "rpds_py-0.20.0-cp310-none-win_amd64.whl", hash = "sha256:238a2d5b1cad28cdc6ed15faf93a998336eb041c4e440dd7f902528b8891b399"}, - {file = "rpds_py-0.20.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ac2f4f7a98934c2ed6505aead07b979e6f999389f16b714448fb39bbaa86a489"}, - {file = "rpds_py-0.20.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:220002c1b846db9afd83371d08d239fdc865e8f8c5795bbaec20916a76db3318"}, - {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d7919548df3f25374a1f5d01fbcd38dacab338ef5f33e044744b5c36729c8db"}, - {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:758406267907b3781beee0f0edfe4a179fbd97c0be2e9b1154d7f0a1279cf8e5"}, - {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3d61339e9f84a3f0767b1995adfb171a0d00a1185192718a17af6e124728e0f5"}, - {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1259c7b3705ac0a0bd38197565a5d603218591d3f6cee6e614e380b6ba61c6f6"}, - {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c1dc0f53856b9cc9a0ccca0a7cc61d3d20a7088201c0937f3f4048c1718a209"}, - {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7e60cb630f674a31f0368ed32b2a6b4331b8350d67de53c0359992444b116dd3"}, - {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dbe982f38565bb50cb7fb061ebf762c2f254ca3d8c20d4006878766e84266272"}, - {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:514b3293b64187172bc77c8fb0cdae26981618021053b30d8371c3a902d4d5ad"}, - {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d0a26ffe9d4dd35e4dfdd1e71f46401cff0181c75ac174711ccff0459135fa58"}, - {file = "rpds_py-0.20.0-cp311-none-win32.whl", hash = "sha256:89c19a494bf3ad08c1da49445cc5d13d8fefc265f48ee7e7556839acdacf69d0"}, - {file = "rpds_py-0.20.0-cp311-none-win_amd64.whl", hash = "sha256:c638144ce971df84650d3ed0096e2ae7af8e62ecbbb7b201c8935c370df00a2c"}, - {file = "rpds_py-0.20.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a84ab91cbe7aab97f7446652d0ed37d35b68a465aeef8fc41932a9d7eee2c1a6"}, - {file = "rpds_py-0.20.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:56e27147a5a4c2c21633ff8475d185734c0e4befd1c989b5b95a5d0db699b21b"}, - {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:2580b0c34583b85efec8c5c5ec9edf2dfe817330cc882ee972ae650e7b5ef739"}, - {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b80d4a7900cf6b66bb9cee5c352b2d708e29e5a37fe9bf784fa97fc11504bf6c"}, - {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:50eccbf054e62a7b2209b28dc7a22d6254860209d6753e6b78cfaeb0075d7bee"}, - {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:49a8063ea4296b3a7e81a5dfb8f7b2d73f0b1c20c2af401fb0cdf22e14711a96"}, - {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea438162a9fcbee3ecf36c23e6c68237479f89f962f82dae83dc15feeceb37e4"}, - {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:18d7585c463087bddcfa74c2ba267339f14f2515158ac4db30b1f9cbdb62c8ef"}, - {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d4c7d1a051eeb39f5c9547e82ea27cbcc28338482242e3e0b7768033cb083821"}, - {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:e4df1e3b3bec320790f699890d41c59d250f6beda159ea3c44c3f5bac1976940"}, - {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2cf126d33a91ee6eedc7f3197b53e87a2acdac63602c0f03a02dd69e4b138174"}, - {file = "rpds_py-0.20.0-cp312-none-win32.whl", hash = "sha256:8bc7690f7caee50b04a79bf017a8d020c1f48c2a1077ffe172abec59870f1139"}, - {file = "rpds_py-0.20.0-cp312-none-win_amd64.whl", hash = "sha256:0e13e6952ef264c40587d510ad676a988df19adea20444c2b295e536457bc585"}, - {file = "rpds_py-0.20.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:aa9a0521aeca7d4941499a73ad7d4f8ffa3d1affc50b9ea11d992cd7eff18a29"}, - {file = "rpds_py-0.20.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4a1f1d51eccb7e6c32ae89243cb352389228ea62f89cd80823ea7dd1b98e0b91"}, - {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a86a9b96070674fc88b6f9f71a97d2c1d3e5165574615d1f9168ecba4cecb24"}, - {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6c8ef2ebf76df43f5750b46851ed1cdf8f109d7787ca40035fe19fbdc1acc5a7"}, - {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b74b25f024b421d5859d156750ea9a65651793d51b76a2e9238c05c9d5f203a9"}, - {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57eb94a8c16ab08fef6404301c38318e2c5a32216bf5de453e2714c964c125c8"}, - {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1940dae14e715e2e02dfd5b0f64a52e8374a517a1e531ad9412319dc3ac7879"}, - {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d20277fd62e1b992a50c43f13fbe13277a31f8c9f70d59759c88f644d66c619f"}, - {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:06db23d43f26478303e954c34c75182356ca9aa7797d22c5345b16871ab9c45c"}, - {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b2a5db5397d82fa847e4c624b0c98fe59d2d9b7cf0ce6de09e4d2e80f8f5b3f2"}, - {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5a35df9f5548fd79cb2f52d27182108c3e6641a4feb0f39067911bf2adaa3e57"}, - {file = "rpds_py-0.20.0-cp313-none-win32.whl", hash = "sha256:fd2d84f40633bc475ef2d5490b9c19543fbf18596dcb1b291e3a12ea5d722f7a"}, - {file = 
"rpds_py-0.20.0-cp313-none-win_amd64.whl", hash = "sha256:9bc2d153989e3216b0559251b0c260cfd168ec78b1fac33dd485750a228db5a2"}, - {file = "rpds_py-0.20.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:f2fbf7db2012d4876fb0d66b5b9ba6591197b0f165db8d99371d976546472a24"}, - {file = "rpds_py-0.20.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1e5f3cd7397c8f86c8cc72d5a791071431c108edd79872cdd96e00abd8497d29"}, - {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce9845054c13696f7af7f2b353e6b4f676dab1b4b215d7fe5e05c6f8bb06f965"}, - {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c3e130fd0ec56cb76eb49ef52faead8ff09d13f4527e9b0c400307ff72b408e1"}, - {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b16aa0107ecb512b568244ef461f27697164d9a68d8b35090e9b0c1c8b27752"}, - {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aa7f429242aae2947246587d2964fad750b79e8c233a2367f71b554e9447949c"}, - {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af0fc424a5842a11e28956e69395fbbeab2c97c42253169d87e90aac2886d751"}, - {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b8c00a3b1e70c1d3891f0db1b05292747f0dbcfb49c43f9244d04c70fbc40eb8"}, - {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:40ce74fc86ee4645d0a225498d091d8bc61f39b709ebef8204cb8b5a464d3c0e"}, - {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:4fe84294c7019456e56d93e8ababdad5a329cd25975be749c3f5f558abb48253"}, - {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:338ca4539aad4ce70a656e5187a3a31c5204f261aef9f6ab50e50bcdffaf050a"}, - {file = "rpds_py-0.20.0-cp38-none-win32.whl", hash = "sha256:54b43a2b07db18314669092bb2de584524d1ef414588780261e31e85846c26a5"}, - {file = "rpds_py-0.20.0-cp38-none-win_amd64.whl", hash = "sha256:a1862d2d7ce1674cffa6d186d53ca95c6e17ed2b06b3f4c476173565c862d232"}, - {file = "rpds_py-0.20.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3fde368e9140312b6e8b6c09fb9f8c8c2f00999d1823403ae90cc00480221b22"}, - {file = "rpds_py-0.20.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9824fb430c9cf9af743cf7aaf6707bf14323fb51ee74425c380f4c846ea70789"}, - {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11ef6ce74616342888b69878d45e9f779b95d4bd48b382a229fe624a409b72c5"}, - {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c52d3f2f82b763a24ef52f5d24358553e8403ce05f893b5347098014f2d9eff2"}, - {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d35cef91e59ebbeaa45214861874bc6f19eb35de96db73e467a8358d701a96c"}, - {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d72278a30111e5b5525c1dd96120d9e958464316f55adb030433ea905866f4de"}, - {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4c29cbbba378759ac5786730d1c3cb4ec6f8ababf5c42a9ce303dc4b3d08cda"}, - {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6632f2d04f15d1bd6fe0eedd3b86d9061b836ddca4c03d5cf5c7e9e6b7c14580"}, - {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = 
"sha256:d0b67d87bb45ed1cd020e8fbf2307d449b68abc45402fe1a4ac9e46c3c8b192b"}, - {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ec31a99ca63bf3cd7f1a5ac9fe95c5e2d060d3c768a09bc1d16e235840861420"}, - {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:22e6c9976e38f4d8c4a63bd8a8edac5307dffd3ee7e6026d97f3cc3a2dc02a0b"}, - {file = "rpds_py-0.20.0-cp39-none-win32.whl", hash = "sha256:569b3ea770c2717b730b61998b6c54996adee3cef69fc28d444f3e7920313cf7"}, - {file = "rpds_py-0.20.0-cp39-none-win_amd64.whl", hash = "sha256:e6900ecdd50ce0facf703f7a00df12374b74bbc8ad9fe0f6559947fb20f82364"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:617c7357272c67696fd052811e352ac54ed1d9b49ab370261a80d3b6ce385045"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9426133526f69fcaba6e42146b4e12d6bc6c839b8b555097020e2b78ce908dcc"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:deb62214c42a261cb3eb04d474f7155279c1a8a8c30ac89b7dcb1721d92c3c02"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcaeb7b57f1a1e071ebd748984359fef83ecb026325b9d4ca847c95bc7311c92"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d454b8749b4bd70dd0a79f428731ee263fa6995f83ccb8bada706e8d1d3ff89d"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d807dc2051abe041b6649681dce568f8e10668e3c1c6543ebae58f2d7e617855"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3c20f0ddeb6e29126d45f89206b8291352b8c5b44384e78a6499d68b52ae511"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b7f19250ceef892adf27f0399b9e5afad019288e9be756d6919cb58892129f51"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:4f1ed4749a08379555cebf4650453f14452eaa9c43d0a95c49db50c18b7da075"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:dcedf0b42bcb4cfff4101d7771a10532415a6106062f005ab97d1d0ab5681c60"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:39ed0d010457a78f54090fafb5d108501b5aa5604cc22408fc1c0c77eac14344"}, - {file = "rpds_py-0.20.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:bb273176be34a746bdac0b0d7e4e2c467323d13640b736c4c477881a3220a989"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f918a1a130a6dfe1d7fe0f105064141342e7dd1611f2e6a21cd2f5c8cb1cfb3e"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f60012a73aa396be721558caa3a6fd49b3dd0033d1675c6d59c4502e870fcf0c"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d2b1ad682a3dfda2a4e8ad8572f3100f95fad98cb99faf37ff0ddfe9cbf9d03"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:614fdafe9f5f19c63ea02817fa4861c606a59a604a77c8cdef5aa01d28b97921"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa518bcd7600c584bf42e6617ee8132869e877db2f76bcdc281ec6a4113a53ab"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:f0475242f447cc6cb8a9dd486d68b2ef7fbee84427124c232bff5f63b1fe11e5"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f90a4cd061914a60bd51c68bcb4357086991bd0bb93d8aa66a6da7701370708f"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:def7400461c3a3f26e49078302e1c1b38f6752342c77e3cf72ce91ca69fb1bc1"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:65794e4048ee837494aea3c21a28ad5fc080994dfba5b036cf84de37f7ad5074"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:faefcc78f53a88f3076b7f8be0a8f8d35133a3ecf7f3770895c25f8813460f08"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:5b4f105deeffa28bbcdff6c49b34e74903139afa690e35d2d9e3c2c2fba18cec"}, - {file = "rpds_py-0.20.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fdfc3a892927458d98f3d55428ae46b921d1f7543b89382fdb483f5640daaec8"}, - {file = "rpds_py-0.20.0.tar.gz", hash = "sha256:d72a210824facfdaf8768cf2d7ca25a042c30320b3020de2fa04640920d4e121"}, + {file = "rpds_py-0.21.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a017f813f24b9df929674d0332a374d40d7f0162b326562daae8066b502d0590"}, + {file = "rpds_py-0.21.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:20cc1ed0bcc86d8e1a7e968cce15be45178fd16e2ff656a243145e0b439bd250"}, + {file = "rpds_py-0.21.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad116dda078d0bc4886cb7840e19811562acdc7a8e296ea6ec37e70326c1b41c"}, + {file = "rpds_py-0.21.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:808f1ac7cf3b44f81c9475475ceb221f982ef548e44e024ad5f9e7060649540e"}, + {file = "rpds_py-0.21.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de552f4a1916e520f2703ec474d2b4d3f86d41f353e7680b597512ffe7eac5d0"}, + {file = "rpds_py-0.21.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:efec946f331349dfc4ae9d0e034c263ddde19414fe5128580f512619abed05f1"}, + {file = "rpds_py-0.21.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b80b4690bbff51a034bfde9c9f6bf9357f0a8c61f548942b80f7b66356508bf5"}, + {file = "rpds_py-0.21.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:085ed25baac88953d4283e5b5bd094b155075bb40d07c29c4f073e10623f9f2e"}, + {file = "rpds_py-0.21.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:daa8efac2a1273eed2354397a51216ae1e198ecbce9036fba4e7610b308b6153"}, + {file = "rpds_py-0.21.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:95a5bad1ac8a5c77b4e658671642e4af3707f095d2b78a1fdd08af0dfb647624"}, + {file = "rpds_py-0.21.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3e53861b29a13d5b70116ea4230b5f0f3547b2c222c5daa090eb7c9c82d7f664"}, + {file = "rpds_py-0.21.0-cp310-none-win32.whl", hash = "sha256:ea3a6ac4d74820c98fcc9da4a57847ad2cc36475a8bd9683f32ab6d47a2bd682"}, + {file = "rpds_py-0.21.0-cp310-none-win_amd64.whl", hash = "sha256:b8f107395f2f1d151181880b69a2869c69e87ec079c49c0016ab96860b6acbe5"}, + {file = "rpds_py-0.21.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:5555db3e618a77034954b9dc547eae94166391a98eb867905ec8fcbce1308d95"}, + {file = "rpds_py-0.21.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:97ef67d9bbc3e15584c2f3c74bcf064af36336c10d2e21a2131e123ce0f924c9"}, + {file = "rpds_py-0.21.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:4ab2c2a26d2f69cdf833174f4d9d86118edc781ad9a8fa13970b527bf8236027"}, + {file = "rpds_py-0.21.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4e8921a259f54bfbc755c5bbd60c82bb2339ae0324163f32868f63f0ebb873d9"}, + {file = "rpds_py-0.21.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8a7ff941004d74d55a47f916afc38494bd1cfd4b53c482b77c03147c91ac0ac3"}, + {file = "rpds_py-0.21.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5145282a7cd2ac16ea0dc46b82167754d5e103a05614b724457cffe614f25bd8"}, + {file = "rpds_py-0.21.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de609a6f1b682f70bb7163da745ee815d8f230d97276db049ab447767466a09d"}, + {file = "rpds_py-0.21.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:40c91c6e34cf016fa8e6b59d75e3dbe354830777fcfd74c58b279dceb7975b75"}, + {file = "rpds_py-0.21.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d2132377f9deef0c4db89e65e8bb28644ff75a18df5293e132a8d67748397b9f"}, + {file = "rpds_py-0.21.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:0a9e0759e7be10109645a9fddaaad0619d58c9bf30a3f248a2ea57a7c417173a"}, + {file = "rpds_py-0.21.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9e20da3957bdf7824afdd4b6eeb29510e83e026473e04952dca565170cd1ecc8"}, + {file = "rpds_py-0.21.0-cp311-none-win32.whl", hash = "sha256:f71009b0d5e94c0e86533c0b27ed7cacc1239cb51c178fd239c3cfefefb0400a"}, + {file = "rpds_py-0.21.0-cp311-none-win_amd64.whl", hash = "sha256:e168afe6bf6ab7ab46c8c375606298784ecbe3ba31c0980b7dcbb9631dcba97e"}, + {file = "rpds_py-0.21.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:30b912c965b2aa76ba5168fd610087bad7fcde47f0a8367ee8f1876086ee6d1d"}, + {file = "rpds_py-0.21.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ca9989d5d9b1b300bc18e1801c67b9f6d2c66b8fd9621b36072ed1df2c977f72"}, + {file = "rpds_py-0.21.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f54e7106f0001244a5f4cf810ba8d3f9c542e2730821b16e969d6887b664266"}, + {file = "rpds_py-0.21.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fed5dfefdf384d6fe975cc026886aece4f292feaf69d0eeb716cfd3c5a4dd8be"}, + {file = "rpds_py-0.21.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:590ef88db231c9c1eece44dcfefd7515d8bf0d986d64d0caf06a81998a9e8cab"}, + {file = "rpds_py-0.21.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f983e4c2f603c95dde63df633eec42955508eefd8d0f0e6d236d31a044c882d7"}, + {file = "rpds_py-0.21.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b229ce052ddf1a01c67d68166c19cb004fb3612424921b81c46e7ea7ccf7c3bf"}, + {file = "rpds_py-0.21.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ebf64e281a06c904a7636781d2e973d1f0926a5b8b480ac658dc0f556e7779f4"}, + {file = "rpds_py-0.21.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:998a8080c4495e4f72132f3d66ff91f5997d799e86cec6ee05342f8f3cda7dca"}, + {file = "rpds_py-0.21.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:98486337f7b4f3c324ab402e83453e25bb844f44418c066623db88e4c56b7c7b"}, + {file = "rpds_py-0.21.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a78d8b634c9df7f8d175451cfeac3810a702ccb85f98ec95797fa98b942cea11"}, + {file = "rpds_py-0.21.0-cp312-none-win32.whl", hash = "sha256:a58ce66847711c4aa2ecfcfaff04cb0327f907fead8945ffc47d9407f41ff952"}, + {file = 
"rpds_py-0.21.0-cp312-none-win_amd64.whl", hash = "sha256:e860f065cc4ea6f256d6f411aba4b1251255366e48e972f8a347cf88077b24fd"}, + {file = "rpds_py-0.21.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:ee4eafd77cc98d355a0d02f263efc0d3ae3ce4a7c24740010a8b4012bbb24937"}, + {file = "rpds_py-0.21.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:688c93b77e468d72579351a84b95f976bd7b3e84aa6686be6497045ba84be560"}, + {file = "rpds_py-0.21.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c38dbf31c57032667dd5a2f0568ccde66e868e8f78d5a0d27dcc56d70f3fcd3b"}, + {file = "rpds_py-0.21.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2d6129137f43f7fa02d41542ffff4871d4aefa724a5fe38e2c31a4e0fd343fb0"}, + {file = "rpds_py-0.21.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:520ed8b99b0bf86a176271f6fe23024323862ac674b1ce5b02a72bfeff3fff44"}, + {file = "rpds_py-0.21.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aaeb25ccfb9b9014a10eaf70904ebf3f79faaa8e60e99e19eef9f478651b9b74"}, + {file = "rpds_py-0.21.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af04ac89c738e0f0f1b913918024c3eab6e3ace989518ea838807177d38a2e94"}, + {file = "rpds_py-0.21.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b9b76e2afd585803c53c5b29e992ecd183f68285b62fe2668383a18e74abe7a3"}, + {file = "rpds_py-0.21.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5afb5efde74c54724e1a01118c6e5c15e54e642c42a1ba588ab1f03544ac8c7a"}, + {file = "rpds_py-0.21.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:52c041802a6efa625ea18027a0723676a778869481d16803481ef6cc02ea8cb3"}, + {file = "rpds_py-0.21.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ee1e4fc267b437bb89990b2f2abf6c25765b89b72dd4a11e21934df449e0c976"}, + {file = "rpds_py-0.21.0-cp313-none-win32.whl", hash = "sha256:0c025820b78817db6a76413fff6866790786c38f95ea3f3d3c93dbb73b632202"}, + {file = "rpds_py-0.21.0-cp313-none-win_amd64.whl", hash = "sha256:320c808df533695326610a1b6a0a6e98f033e49de55d7dc36a13c8a30cfa756e"}, + {file = "rpds_py-0.21.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:2c51d99c30091f72a3c5d126fad26236c3f75716b8b5e5cf8effb18889ced928"}, + {file = "rpds_py-0.21.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cbd7504a10b0955ea287114f003b7ad62330c9e65ba012c6223dba646f6ffd05"}, + {file = "rpds_py-0.21.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6dcc4949be728ede49e6244eabd04064336012b37f5c2200e8ec8eb2988b209c"}, + {file = "rpds_py-0.21.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f414da5c51bf350e4b7960644617c130140423882305f7574b6cf65a3081cecb"}, + {file = "rpds_py-0.21.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9afe42102b40007f588666bc7de82451e10c6788f6f70984629db193849dced1"}, + {file = "rpds_py-0.21.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b929c2bb6e29ab31f12a1117c39f7e6d6450419ab7464a4ea9b0b417174f044"}, + {file = "rpds_py-0.21.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8404b3717da03cbf773a1d275d01fec84ea007754ed380f63dfc24fb76ce4592"}, + {file = "rpds_py-0.21.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e12bb09678f38b7597b8346983d2323a6482dcd59e423d9448108c1be37cac9d"}, + {file = "rpds_py-0.21.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = 
"sha256:58a0e345be4b18e6b8501d3b0aa540dad90caeed814c515e5206bb2ec26736fd"}, + {file = "rpds_py-0.21.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:c3761f62fcfccf0864cc4665b6e7c3f0c626f0380b41b8bd1ce322103fa3ef87"}, + {file = "rpds_py-0.21.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:c2b2f71c6ad6c2e4fc9ed9401080badd1469fa9889657ec3abea42a3d6b2e1ed"}, + {file = "rpds_py-0.21.0-cp39-none-win32.whl", hash = "sha256:b21747f79f360e790525e6f6438c7569ddbfb1b3197b9e65043f25c3c9b489d8"}, + {file = "rpds_py-0.21.0-cp39-none-win_amd64.whl", hash = "sha256:0626238a43152918f9e72ede9a3b6ccc9e299adc8ade0d67c5e142d564c9a83d"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6b4ef7725386dc0762857097f6b7266a6cdd62bfd209664da6712cb26acef035"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:6bc0e697d4d79ab1aacbf20ee5f0df80359ecf55db33ff41481cf3e24f206919"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da52d62a96e61c1c444f3998c434e8b263c384f6d68aca8274d2e08d1906325c"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:98e4fe5db40db87ce1c65031463a760ec7906ab230ad2249b4572c2fc3ef1f9f"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:30bdc973f10d28e0337f71d202ff29345320f8bc49a31c90e6c257e1ccef4333"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:faa5e8496c530f9c71f2b4e1c49758b06e5f4055e17144906245c99fa6d45356"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32eb88c30b6a4f0605508023b7141d043a79b14acb3b969aa0b4f99b25bc7d4a"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a89a8ce9e4e75aeb7fa5d8ad0f3fecdee813802592f4f46a15754dcb2fd6b061"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:241e6c125568493f553c3d0fdbb38c74babf54b45cef86439d4cd97ff8feb34d"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:3b766a9f57663396e4f34f5140b3595b233a7b146e94777b97a8413a1da1be18"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:af4a644bf890f56e41e74be7d34e9511e4954894d544ec6b8efe1e21a1a8da6c"}, + {file = "rpds_py-0.21.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:3e30a69a706e8ea20444b98a49f386c17b26f860aa9245329bab0851ed100677"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:031819f906bb146561af051c7cef4ba2003d28cff07efacef59da973ff7969ba"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:b876f2bc27ab5954e2fd88890c071bd0ed18b9c50f6ec3de3c50a5ece612f7a6"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc5695c321e518d9f03b7ea6abb5ea3af4567766f9852ad1560f501b17588c7b"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b4de1da871b5c0fd5537b26a6fc6814c3cc05cabe0c941db6e9044ffbb12f04a"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:878f6fea96621fda5303a2867887686d7a198d9e0f8a40be100a63f5d60c88c9"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:a8eeec67590e94189f434c6d11c426892e396ae59e4801d17a93ac96b8c02a6c"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff2eba7f6c0cb523d7e9cff0903f2fe1feff8f0b2ceb6bd71c0e20a4dcee271"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a429b99337062877d7875e4ff1a51fe788424d522bd64a8c0a20ef3021fdb6ed"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:d167e4dbbdac48bd58893c7e446684ad5d425b407f9336e04ab52e8b9194e2ed"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:4eb2de8a147ffe0626bfdc275fc6563aa7bf4b6db59cf0d44f0ccd6ca625a24e"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:e78868e98f34f34a88e23ee9ccaeeec460e4eaf6db16d51d7a9b883e5e785a5e"}, + {file = "rpds_py-0.21.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:4991ca61656e3160cdaca4851151fd3f4a92e9eba5c7a530ab030d6aee96ec89"}, + {file = "rpds_py-0.21.0.tar.gz", hash = "sha256:ed6378c9d66d0de903763e7706383d60c33829581f0adff47b6535f1802fa6db"}, ] [[package]] @@ -3099,23 +3160,23 @@ stats = ["scipy (>=1.7)", "statsmodels (>=0.12)"] [[package]] name = "setuptools" -version = "75.3.0" +version = "75.6.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "setuptools-75.3.0-py3-none-any.whl", hash = "sha256:f2504966861356aa38616760c0f66568e535562374995367b4e69c7143cf6bcd"}, - {file = "setuptools-75.3.0.tar.gz", hash = "sha256:fba5dd4d766e97be1b1681d98712680ae8f2f26d7881245f2ce9e40714f1a686"}, + {file = "setuptools-75.6.0-py3-none-any.whl", hash = "sha256:ce74b49e8f7110f9bf04883b730f4765b774ef3ef28f722cce7c273d253aaf7d"}, + {file = "setuptools-75.6.0.tar.gz", hash = "sha256:8199222558df7c86216af4f84c30e9b34a61d8ba19366cc914424cdbd28252f6"}, ] [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.5.2)"] -core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.7.0)"] +core = ["importlib_metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] enabler = ["pytest-enabler (>=2.2)"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] -type = ["importlib-metadata (>=7.0.2)", 
"jaraco.develop (>=7.21)", "mypy (==1.12.*)", "pytest-mypy"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (>=1.12,<1.14)", "pytest-mypy"] [[package]] name = "six" @@ -3245,13 +3306,13 @@ sphinx = ">=4.0" [[package]] name = "sphinx-rtd-theme" -version = "3.0.1" +version = "3.0.2" description = "Read the Docs theme for Sphinx" optional = false python-versions = ">=3.8" files = [ - {file = "sphinx_rtd_theme-3.0.1-py2.py3-none-any.whl", hash = "sha256:921c0ece75e90633ee876bd7b148cfaad136b481907ad154ac3669b6fc957916"}, - {file = "sphinx_rtd_theme-3.0.1.tar.gz", hash = "sha256:a4c5745d1b06dfcb80b7704fe532eb765b44065a8fad9851e4258c8804140703"}, + {file = "sphinx_rtd_theme-3.0.2-py2.py3-none-any.whl", hash = "sha256:422ccc750c3a3a311de4ae327e82affdaf59eb695ba4936538552f3b00f4ee13"}, + {file = "sphinx_rtd_theme-3.0.2.tar.gz", hash = "sha256:b7457bc25dda723b20b086a670b9953c859eab60a2a03ee8eb2bb23e176e5f85"}, ] [package.dependencies] @@ -3372,13 +3433,13 @@ test = ["pytest"] [[package]] name = "starlette" -version = "0.41.2" +version = "0.41.3" description = "The little ASGI library that shines." optional = false python-versions = ">=3.8" files = [ - {file = "starlette-0.41.2-py3-none-any.whl", hash = "sha256:fbc189474b4731cf30fcef52f18a8d070e3f3b46c6a04c97579e85e6ffca942d"}, - {file = "starlette-0.41.2.tar.gz", hash = "sha256:9834fd799d1a87fd346deb76158668cfa0b0d56f85caefe8268e2d97c3468b62"}, + {file = "starlette-0.41.3-py3-none-any.whl", hash = "sha256:44cedb2b7c77a9de33a8b74b2b90e9f50d11fcf25d8270ea525ad71a25374ff7"}, + {file = "starlette-0.41.3.tar.gz", hash = "sha256:0e4ab3d16522a255be6b28260b938eae2482f98ce5cc934cb08dce8dc3ba5835"}, ] [package.dependencies] @@ -3441,13 +3502,13 @@ docs = ["ipykernel", "jupyter-client", "matplotlib", "nbconvert", "nbformat", "n [[package]] name = "stevedore" -version = "5.3.0" +version = "5.4.0" description = "Manage dynamic plugins for Python applications" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "stevedore-5.3.0-py3-none-any.whl", hash = "sha256:1efd34ca08f474dad08d9b19e934a22c68bb6fe416926479ba29e5013bcc8f78"}, - {file = "stevedore-5.3.0.tar.gz", hash = "sha256:9a64265f4060312828151c204efbe9b7a9852a0d9228756344dbc7e4023e375a"}, + {file = "stevedore-5.4.0-py3-none-any.whl", hash = "sha256:b0be3c4748b3ea7b854b265dcb4caa891015e442416422be16f8b31756107857"}, + {file = "stevedore-5.4.0.tar.gz", hash = "sha256:79e92235ecb828fe952b6b8b0c6c87863248631922c8e8e0fa5b17b232c4514d"}, ] [package.dependencies] @@ -3455,13 +3516,13 @@ pbr = ">=2.0.0" [[package]] name = "sympy" -version = "1.13.1" +version = "1.13.3" description = "Computer algebra system (CAS) in Python" optional = false python-versions = ">=3.8" files = [ - {file = "sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8"}, - {file = "sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f"}, + {file = "sympy-1.13.3-py3-none-any.whl", hash = 
"sha256:54612cf55a62755ee71824ce692986f23c88ffa77207b30c1368eda4a7060f73"}, + {file = "sympy-1.13.3.tar.gz", hash = "sha256:b27fd2c6530e0ab39e275fc9b683895367e51d5da91baa8d3d64db2565fec4d9"}, ] [package.dependencies] @@ -3539,13 +3600,13 @@ files = [ [[package]] name = "tomli" -version = "2.0.2" +version = "2.1.0" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" files = [ - {file = "tomli-2.0.2-py3-none-any.whl", hash = "sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38"}, - {file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"}, + {file = "tomli-2.1.0-py3-none-any.whl", hash = "sha256:a5c57c3d1c56f5ccdf89f6523458f60ef716e210fc47c4cfb188c5ba473e0391"}, + {file = "tomli-2.1.0.tar.gz", hash = "sha256:3f646cae2aec94e17d04973e4249548320197cfabdf130015d023de4b74d8ab8"}, ] [[package]] @@ -3561,28 +3622,31 @@ files = [ [[package]] name = "torch" -version = "2.5.1" +version = "2.4.0" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.5.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:71328e1bbe39d213b8721678f9dcac30dfc452a46d586f1d514a6aa0a99d4744"}, - {file = "torch-2.5.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:34bfa1a852e5714cbfa17f27c49d8ce35e1b7af5608c4bc6e81392c352dbc601"}, - {file = "torch-2.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:32a037bd98a241df6c93e4c789b683335da76a2ac142c0973675b715102dc5fa"}, - {file = "torch-2.5.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:23d062bf70776a3d04dbe74db950db2a5245e1ba4f27208a87f0d743b0d06e86"}, - {file = "torch-2.5.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:de5b7d6740c4b636ef4db92be922f0edc425b65ed78c5076c43c42d362a45457"}, - {file = "torch-2.5.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:340ce0432cad0d37f5a31be666896e16788f1adf8ad7be481196b503dad675b9"}, - {file = "torch-2.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:603c52d2fe06433c18b747d25f5c333f9c1d58615620578c326d66f258686f9a"}, - {file = "torch-2.5.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:31f8c39660962f9ae4eeec995e3049b5492eb7360dd4f07377658ef4d728fa4c"}, - {file = "torch-2.5.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:ed231a4b3a5952177fafb661213d690a72caaad97d5824dd4fc17ab9e15cec03"}, - {file = "torch-2.5.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:3f4b7f10a247e0dcd7ea97dc2d3bfbfc90302ed36d7f3952b0008d0df264e697"}, - {file = "torch-2.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:73e58e78f7d220917c5dbfad1a40e09df9929d3b95d25e57d9f8558f84c9a11c"}, - {file = "torch-2.5.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:8c712df61101964eb11910a846514011f0b6f5920c55dbf567bff8a34163d5b1"}, - {file = "torch-2.5.1-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:9b61edf3b4f6e3b0e0adda8b3960266b9009d02b37555971f4d1c8f7a05afed7"}, - {file = "torch-2.5.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:1f3b7fb3cf7ab97fae52161423f81be8c6b8afac8d9760823fd623994581e1a3"}, - {file = "torch-2.5.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:7974e3dce28b5a21fb554b73e1bc9072c25dde873fa00d54280861e7a009d7dc"}, - {file = "torch-2.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:46c817d3ea33696ad3b9df5e774dba2257e9a4cd3c4a3afbf92f6bb13ac5ce2d"}, - {file = "torch-2.5.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:8046768b7f6d35b85d101b4b38cba8aa2f3cd51952bc4c06a49580f2ce682291"}, + {file = 
"torch-2.4.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:4ed94583e244af51d6a8d28701ca5a9e02d1219e782f5a01dd401f90af17d8ac"}, + {file = "torch-2.4.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:c4ca297b7bd58b506bfd6e78ffd14eb97c0e7797dcd7965df62f50bb575d8954"}, + {file = "torch-2.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:2497cbc7b3c951d69b276ca51fe01c2865db67040ac67f5fc20b03e41d16ea4a"}, + {file = "torch-2.4.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:685418ab93730efbee71528821ff54005596970dd497bf03c89204fb7e3f71de"}, + {file = "torch-2.4.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:e743adadd8c8152bb8373543964551a7cb7cc20ba898dc8f9c0cdbe47c283de0"}, + {file = "torch-2.4.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:7334325c0292cbd5c2eac085f449bf57d3690932eac37027e193ba775703c9e6"}, + {file = "torch-2.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:97730014da4c57ffacb3c09298c6ce05400606e890bd7a05008d13dd086e46b1"}, + {file = "torch-2.4.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:f169b4ea6dc93b3a33319611fcc47dc1406e4dd539844dcbd2dec4c1b96e166d"}, + {file = "torch-2.4.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:997084a0f9784d2a89095a6dc67c7925e21bf25dea0b3d069b41195016ccfcbb"}, + {file = "torch-2.4.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:bc3988e8b36d1e8b998d143255d9408d8c75da4ab6dd0dcfd23b623dfb0f0f57"}, + {file = "torch-2.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:3374128bbf7e62cdaed6c237bfd39809fbcfaa576bee91e904706840c3f2195c"}, + {file = "torch-2.4.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:91aaf00bfe1ffa44dc5b52809d9a95129fca10212eca3ac26420eb11727c6288"}, + {file = "torch-2.4.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cc30457ea5489c62747d3306438af00c606b509d78822a88f804202ba63111ed"}, + {file = "torch-2.4.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:a046491aaf96d1215e65e1fa85911ef2ded6d49ea34c8df4d0638879f2402eef"}, + {file = "torch-2.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:688eec9240f3ce775f22e1e1a5ab9894f3d5fe60f3f586deb7dbd23a46a83916"}, + {file = "torch-2.4.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:3af4de2a618fb065e78404c4ba27a818a7b7957eaeff28c6c66ce7fb504b68b8"}, + {file = "torch-2.4.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:618808d3f610d5f180e47a697d4ec90b810953bb1e020f424b2ac7fb0884b545"}, + {file = "torch-2.4.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:ed765d232d23566052ba83632ec73a4fccde00b4c94ad45d63b471b09d63b7a7"}, + {file = "torch-2.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:a2feb98ac470109472fb10dfef38622a7ee08482a16c357863ebc7bc7db7c8f7"}, + {file = "torch-2.4.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:8940fc8b97a4c61fdb5d46a368f21f4a3a562a17879e932eb51a5ec62310cb31"}, ] [package.dependencies] @@ -3590,64 +3654,91 @@ filelock = "*" fsspec = "*" jinja2 = "*" networkx = "*" -nvidia-cublas-cu12 = {version = "12.4.5.8", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-cupti-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-nvrtc-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-runtime-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-cupti-cu12 = {version 
= "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cudnn-cu12 = {version = "9.1.0.70", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cufft-cu12 = {version = "11.2.1.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-curand-cu12 = {version = "10.3.5.147", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusolver-cu12 = {version = "11.6.1.9", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusparse-cu12 = {version = "12.3.1.170", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nccl-cu12 = {version = "2.21.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nvjitlink-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nvtx-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -setuptools = {version = "*", markers = "python_version >= \"3.12\""} -sympy = {version = "1.13.1", markers = "python_version >= \"3.9\""} -triton = {version = "3.1.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\""} +nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +sympy = "*" +triton = {version = "3.0.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\""} typing-extensions = ">=4.8.0" [package.extras] opt-einsum = ["opt-einsum (>=3.3)"] -optree = ["optree (>=0.12.0)"] +optree = ["optree (>=0.11.0)"] + +[[package]] +name = "torch-geometric" +version = "2.6.1" +description = "Graph Neural Network Library for PyTorch" +optional = false +python-versions = ">=3.8" +files = [ + {file = "torch_geometric-2.6.1-py3-none-any.whl", hash = "sha256:8faeb353f9655f7dbec44c5e0b44c721773bdfb279994da96b9b8b12fd30f427"}, + {file = "torch_geometric-2.6.1.tar.gz", hash = "sha256:1f18f9d0fc4d2239d526221e4f22606a4a3895b5d965a9856d27610a3df662c6"}, +] + +[package.dependencies] +aiohttp = "*" +fsspec = "*" +jinja2 = "*" +numpy = "*" +psutil = ">=5.8.0" +pyparsing = "*" +requests = "*" +tqdm = "*" + +[package.extras] +benchmark = ["matplotlib", "networkx", "pandas", "protobuf (<4.21)", "wandb"] +dev = ["ipython", "matplotlib-inline", "pre-commit", "torch_geometric[test]"] +full = ["ase", "captum (<0.7.0)", "graphviz", "h5py", "matplotlib", "networkx", "numba (<0.60.0)", "opt_einsum", "pandas", "pgmpy", "pynndescent", 
"pytorch-memlab", "rdflib", "rdkit", "scikit-image", "scikit-learn", "scipy", "statsmodels", "sympy", "tabulate", "torch_geometric[graphgym,modelhub]", "torchmetrics", "trimesh"] +graphgym = ["protobuf (<4.21)", "pytorch-lightning (<2.3.0)", "yacs"] +modelhub = ["huggingface_hub"] +test = ["onnx", "onnxruntime", "pytest", "pytest-cov"] [[package]] name = "torchmetrics" -version = "1.5.1" +version = "1.6.0" description = "PyTorch native Metrics" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "torchmetrics-1.5.1-py3-none-any.whl", hash = "sha256:1f297aa40958b3d276dddd7494f07e6417cad3efa1366b458b902d6d2ab76def"}, - {file = "torchmetrics-1.5.1.tar.gz", hash = "sha256:9701632cf811bc460abf07bd7b971b79c1ae9c8231e03d495b53a0975e43fe07"}, + {file = "torchmetrics-1.6.0-py3-none-any.whl", hash = "sha256:a508cdd87766cedaaf55a419812bf9f493aff8fffc02cc19df5a8e2e7ccb942a"}, + {file = "torchmetrics-1.6.0.tar.gz", hash = "sha256:aebba248708fb90def20cccba6f55bddd134a58de43fb22b0c5ca0f3a89fa984"}, ] [package.dependencies] lightning-utilities = ">=0.8.0" -numpy = ">1.20.0,<2.0" +numpy = ">1.20.0" packaging = ">17.1" -torch = ">=1.10.0" +torch = ">=2.0.0" [package.extras] -all = ["SciencePlots (>=2.0.0)", "gammatone (>=1.0.0)", "ipadic (>=1.0.0)", "librosa (>=0.9.0)", "matplotlib (>=3.6.0)", "mecab-python3 (>=1.0.6)", "mypy (==1.11.2)", "nltk (>3.8.1)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "torch (==2.5.0)", "torch-fidelity (<=0.4.0)", "torchaudio (>=0.10.0)", "torchvision (>=0.8)", "tqdm (<4.67.0)", "transformers (>4.4.0)", "transformers (>=4.42.3)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] -audio = ["gammatone (>=1.0.0)", "librosa (>=0.9.0)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "pystoi (>=0.4.0)", "requests (>=2.19.0)", "torchaudio (>=0.10.0)"] -detection = ["pycocotools (>2.0.0)", "torchvision (>=0.8)"] -dev = ["SciencePlots (>=2.0.0)", "bert-score (==0.3.13)", "dython (==0.7.6)", "dython (>=0.7.8,<0.8.0)", "fairlearn", "fast-bss-eval (>=0.1.0)", "faster-coco-eval (>=1.6.3)", "gammatone (>=1.0.0)", "huggingface-hub (<0.27)", "ipadic (>=1.0.0)", "jiwer (>=2.3.0)", "kornia (>=0.6.7)", "librosa (>=0.9.0)", "lpips (<=0.1.4)", "matplotlib (>=3.6.0)", "mecab-ko (>=1.0.0,<1.1.0)", "mecab-ko-dic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "mir-eval (>=0.6)", "monai (==1.3.2)", "mypy (==1.11.2)", "netcal (>1.0.0)", "nltk (>3.8.1)", "numpy (<2.2.0)", "onnxruntime (>=1.12.0)", "pandas (>1.4.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "pytorch-msssim (==1.0.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", "rouge-score (>0.1.0)", "sacrebleu (>=2.3.0)", "scikit-image (>=0.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "sewar (>=0.4.4)", "statsmodels (>0.13.5)", "torch (==2.5.0)", "torch-complex (<0.5.0)", "torch-fidelity (<=0.4.0)", "torchaudio (>=0.10.0)", "torchvision (>=0.8)", "tqdm (<4.67.0)", "transformers (>4.4.0)", "transformers (>=4.42.3)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] -image = ["scipy (>1.0.0)", "torch-fidelity (<=0.4.0)", "torchvision (>=0.8)"] +all = ["SciencePlots (>=2.0.0)", "gammatone (>=1.0.0)", "ipadic (>=1.0.0)", "librosa (>=0.10.0)", "matplotlib (>=3.6.0)", "mecab-python3 (>=1.0.6)", "mypy 
(==1.13.0)", "nltk (>3.8.1)", "numpy (<2.0)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "torch (==2.5.1)", "torch-fidelity (<=0.4.0)", "torchaudio (>=2.0.1)", "torchvision (>=0.15.1)", "tqdm (<4.68.0)", "transformers (>4.4.0)", "transformers (>=4.42.3)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] +audio = ["gammatone (>=1.0.0)", "librosa (>=0.10.0)", "numpy (<2.0)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "pystoi (>=0.4.0)", "requests (>=2.19.0)", "torchaudio (>=2.0.1)"] +detection = ["pycocotools (>2.0.0)", "torchvision (>=0.15.1)"] +dev = ["PyTDC (==0.4.1)", "SciencePlots (>=2.0.0)", "bert-score (==0.3.13)", "dython (==0.7.6)", "dython (>=0.7.8,<0.8.0)", "fairlearn", "fast-bss-eval (>=0.1.0)", "faster-coco-eval (>=1.6.3)", "gammatone (>=1.0.0)", "huggingface-hub (<0.27)", "ipadic (>=1.0.0)", "jiwer (>=2.3.0)", "kornia (>=0.6.7)", "librosa (>=0.10.0)", "lpips (<=0.1.4)", "matplotlib (>=3.6.0)", "mecab-ko (>=1.0.0,<1.1.0)", "mecab-ko-dic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "mir-eval (>=0.6)", "monai (==1.3.2)", "monai (==1.4.0)", "mypy (==1.13.0)", "netcal (>1.0.0)", "nltk (>3.8.1)", "numpy (<2.0)", "numpy (<2.2.0)", "onnxruntime (>=1.12.0)", "pandas (>1.4.0)", "permetrics (==2.0.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "pytorch-msssim (==1.0.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", "rouge-score (>0.1.0)", "sacrebleu (>=2.3.0)", "scikit-image (>=0.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "sewar (>=0.4.4)", "statsmodels (>0.13.5)", "torch (==2.5.1)", "torch-complex (<0.5.0)", "torch-fidelity (<=0.4.0)", "torchaudio (>=2.0.1)", "torchvision (>=0.15.1)", "tqdm (<4.68.0)", "transformers (>4.4.0)", "transformers (>=4.42.3)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] +image = ["scipy (>1.0.0)", "torch-fidelity (<=0.4.0)", "torchvision (>=0.15.1)"] multimodal = ["piq (<=0.8.0)", "transformers (>=4.42.3)"] -text = ["ipadic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "nltk (>3.8.1)", "regex (>=2021.9.24)", "sentencepiece (>=0.2.0)", "tqdm (<4.67.0)", "transformers (>4.4.0)"] -typing = ["mypy (==1.11.2)", "torch (==2.5.0)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] +text = ["ipadic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "nltk (>3.8.1)", "regex (>=2021.9.24)", "sentencepiece (>=0.2.0)", "tqdm (<4.68.0)", "transformers (>4.4.0)"] +typing = ["mypy (==1.13.0)", "torch (==2.5.1)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] visual = ["SciencePlots (>=2.0.0)", "matplotlib (>=3.6.0)"] [[package]] name = "tqdm" -version = "4.66.6" +version = "4.67.0" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" files = [ - {file = "tqdm-4.66.6-py3-none-any.whl", hash = "sha256:223e8b5359c2efc4b30555531f09e9f2f3589bcd7fdd389271191031b49b7a63"}, - {file = "tqdm-4.66.6.tar.gz", hash = "sha256:4bdd694238bef1485ce839d67967ab50af8f9272aab687c0d7702a01da0be090"}, + {file = "tqdm-4.67.0-py3-none-any.whl", hash = "sha256:0cd8af9d56911acab92182e88d763100d4788bdf421d251616040cc4d44863be"}, + {file = "tqdm-4.67.0.tar.gz", hash = 
"sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a"}, ] [package.dependencies] @@ -3655,22 +3746,23 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} [package.extras] dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"] +discord = ["requests"] notebook = ["ipywidgets (>=6)"] slack = ["slack-sdk"] telegram = ["requests"] [[package]] name = "triton" -version = "3.1.0" +version = "3.0.0" description = "A language and compiler for custom Deep Learning operations" optional = false python-versions = "*" files = [ - {file = "triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b0dd10a925263abbe9fa37dcde67a5e9b2383fc269fdf59f5657cac38c5d1d8"}, - {file = "triton-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f34f6e7885d1bf0eaaf7ba875a5f0ce6f3c13ba98f9503651c1e6dc6757ed5c"}, - {file = "triton-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8182f42fd8080a7d39d666814fa36c5e30cc00ea7eeeb1a2983dbb4c99a0fdc"}, - {file = "triton-3.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dadaca7fc24de34e180271b5cf864c16755702e9f63a16f62df714a8099126a"}, - {file = "triton-3.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aafa9a20cd0d9fee523cd4504aa7131807a864cd77dcf6efe7e981f18b8c6c11"}, + {file = "triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a"}, + {file = "triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c"}, + {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, + {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, + {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, ] [package.dependencies] @@ -3681,6 +3773,31 @@ build = ["cmake (>=3.20)", "lit"] tests = ["autopep8", "flake8", "isort", "llnl-hatchet", "numpy", "pytest", "scipy (>=1.7.1)"] tutorials = ["matplotlib", "pandas", "tabulate"] +[[package]] +name = "types-pyyaml" +version = "6.0.12.20240917" +description = "Typing stubs for PyYAML" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-PyYAML-6.0.12.20240917.tar.gz", hash = "sha256:d1405a86f9576682234ef83bcb4e6fff7c9305c8b1fbad5e0bcd4f7dbdc9c587"}, + {file = "types_PyYAML-6.0.12.20240917-py3-none-any.whl", hash = "sha256:392b267f1c0fe6022952462bf5d6523f31e37f6cea49b14cee7ad634b6301570"}, +] + +[[package]] +name = "types-requests" +version = "2.32.0.20241016" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-requests-2.32.0.20241016.tar.gz", hash = "sha256:0d9cad2f27515d0e3e3da7134a1b6f28fb97129d86b867f24d9c726452634d95"}, + {file = "types_requests-2.32.0.20241016-py3-none-any.whl", hash = "sha256:4195d62d6d3e043a4eaaf08ff8a62184584d2e8684e9d2aa178c7915a7da3747"}, +] + +[package.dependencies] +urllib3 = ">=2" + [[package]] name = "typing-extensions" version = "4.12.2" @@ -3722,13 +3839,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = 
"uvicorn" -version = "0.32.0" +version = "0.32.1" description = "The lightning-fast ASGI server." optional = false python-versions = ">=3.8" files = [ - {file = "uvicorn-0.32.0-py3-none-any.whl", hash = "sha256:60b8f3a5ac027dcd31448f411ced12b5ef452c646f76f02f8cc3f25d8d26fd82"}, - {file = "uvicorn-0.32.0.tar.gz", hash = "sha256:f78b36b143c16f54ccdb8190d0a26b5f1901fe5a3c777e1ab29f26391af8551e"}, + {file = "uvicorn-0.32.1-py3-none-any.whl", hash = "sha256:82ad92fd58da0d12af7482ecdb5f2470a04c9c9a53ced65b9bbb4a205377602e"}, + {file = "uvicorn-0.32.1.tar.gz", hash = "sha256:ee9519c246a72b1c084cea8d3b44ed6026e78a4a309cbedae9c37e4cb9fbb175"}, ] [package.dependencies] @@ -3737,7 +3854,7 @@ h11 = ">=0.8" typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} [package.extras] -standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] +standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] [[package]] name = "virtualenv" @@ -3856,97 +3973,80 @@ anyio = ">=3.0.0" [[package]] name = "websockets" -version = "13.1" +version = "14.1" description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "websockets-13.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f48c749857f8fb598fb890a75f540e3221d0976ed0bf879cf3c7eef34151acee"}, - {file = "websockets-13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c7e72ce6bda6fb9409cc1e8164dd41d7c91466fb599eb047cfda72fe758a34a7"}, - {file = "websockets-13.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f779498eeec470295a2b1a5d97aa1bc9814ecd25e1eb637bd9d1c73a327387f6"}, - {file = "websockets-13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4676df3fe46956fbb0437d8800cd5f2b6d41143b6e7e842e60554398432cf29b"}, - {file = "websockets-13.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7affedeb43a70351bb811dadf49493c9cfd1ed94c9c70095fd177e9cc1541fa"}, - {file = "websockets-13.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1971e62d2caa443e57588e1d82d15f663b29ff9dfe7446d9964a4b6f12c1e700"}, - {file = "websockets-13.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5f2e75431f8dc4a47f31565a6e1355fb4f2ecaa99d6b89737527ea917066e26c"}, - {file = "websockets-13.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:58cf7e75dbf7e566088b07e36ea2e3e2bd5676e22216e4cad108d4df4a7402a0"}, - {file = "websockets-13.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c90d6dec6be2c7d03378a574de87af9b1efea77d0c52a8301dd831ece938452f"}, - {file = "websockets-13.1-cp310-cp310-win32.whl", hash = "sha256:730f42125ccb14602f455155084f978bd9e8e57e89b569b4d7f0f0c17a448ffe"}, - {file = "websockets-13.1-cp310-cp310-win_amd64.whl", hash = "sha256:5993260f483d05a9737073be197371940c01b257cc45ae3f1d5d7adb371b266a"}, - {file = "websockets-13.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:61fc0dfcda609cda0fc9fe7977694c0c59cf9d749fbb17f4e9483929e3c48a19"}, - {file = "websockets-13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ceec59f59d092c5007e815def4ebb80c2de330e9588e101cf8bd94c143ec78a5"}, - {file = 
"websockets-13.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c1dca61c6db1166c48b95198c0b7d9c990b30c756fc2923cc66f68d17dc558fd"}, - {file = "websockets-13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:308e20f22c2c77f3f39caca508e765f8725020b84aa963474e18c59accbf4c02"}, - {file = "websockets-13.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62d516c325e6540e8a57b94abefc3459d7dab8ce52ac75c96cad5549e187e3a7"}, - {file = "websockets-13.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87c6e35319b46b99e168eb98472d6c7d8634ee37750d7693656dc766395df096"}, - {file = "websockets-13.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5f9fee94ebafbc3117c30be1844ed01a3b177bb6e39088bc6b2fa1dc15572084"}, - {file = "websockets-13.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:7c1e90228c2f5cdde263253fa5db63e6653f1c00e7ec64108065a0b9713fa1b3"}, - {file = "websockets-13.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6548f29b0e401eea2b967b2fdc1c7c7b5ebb3eeb470ed23a54cd45ef078a0db9"}, - {file = "websockets-13.1-cp311-cp311-win32.whl", hash = "sha256:c11d4d16e133f6df8916cc5b7e3e96ee4c44c936717d684a94f48f82edb7c92f"}, - {file = "websockets-13.1-cp311-cp311-win_amd64.whl", hash = "sha256:d04f13a1d75cb2b8382bdc16ae6fa58c97337253826dfe136195b7f89f661557"}, - {file = "websockets-13.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9d75baf00138f80b48f1eac72ad1535aac0b6461265a0bcad391fc5aba875cfc"}, - {file = "websockets-13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9b6f347deb3dcfbfde1c20baa21c2ac0751afaa73e64e5b693bb2b848efeaa49"}, - {file = "websockets-13.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:de58647e3f9c42f13f90ac7e5f58900c80a39019848c5547bc691693098ae1bd"}, - {file = "websockets-13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1b54689e38d1279a51d11e3467dd2f3a50f5f2e879012ce8f2d6943f00e83f0"}, - {file = "websockets-13.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf1781ef73c073e6b0f90af841aaf98501f975d306bbf6221683dd594ccc52b6"}, - {file = "websockets-13.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d23b88b9388ed85c6faf0e74d8dec4f4d3baf3ecf20a65a47b836d56260d4b9"}, - {file = "websockets-13.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3c78383585f47ccb0fcf186dcb8a43f5438bd7d8f47d69e0b56f71bf431a0a68"}, - {file = "websockets-13.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:d6d300f8ec35c24025ceb9b9019ae9040c1ab2f01cddc2bcc0b518af31c75c14"}, - {file = "websockets-13.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a9dcaf8b0cc72a392760bb8755922c03e17a5a54e08cca58e8b74f6902b433cf"}, - {file = "websockets-13.1-cp312-cp312-win32.whl", hash = "sha256:2f85cf4f2a1ba8f602298a853cec8526c2ca42a9a4b947ec236eaedb8f2dc80c"}, - {file = "websockets-13.1-cp312-cp312-win_amd64.whl", hash = "sha256:38377f8b0cdeee97c552d20cf1865695fcd56aba155ad1b4ca8779a5b6ef4ac3"}, - {file = "websockets-13.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a9ab1e71d3d2e54a0aa646ab6d4eebfaa5f416fe78dfe4da2839525dc5d765c6"}, - {file = "websockets-13.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b9d7439d7fab4dce00570bb906875734df13d9faa4b48e261c440a5fec6d9708"}, - {file = "websockets-13.1-cp313-cp313-macosx_11_0_arm64.whl", hash 
= "sha256:327b74e915cf13c5931334c61e1a41040e365d380f812513a255aa804b183418"}, - {file = "websockets-13.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:325b1ccdbf5e5725fdcb1b0e9ad4d2545056479d0eee392c291c1bf76206435a"}, - {file = "websockets-13.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:346bee67a65f189e0e33f520f253d5147ab76ae42493804319b5716e46dddf0f"}, - {file = "websockets-13.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91a0fa841646320ec0d3accdff5b757b06e2e5c86ba32af2e0815c96c7a603c5"}, - {file = "websockets-13.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:18503d2c5f3943e93819238bf20df71982d193f73dcecd26c94514f417f6b135"}, - {file = "websockets-13.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:a9cd1af7e18e5221d2878378fbc287a14cd527fdd5939ed56a18df8a31136bb2"}, - {file = "websockets-13.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:70c5be9f416aa72aab7a2a76c90ae0a4fe2755c1816c153c1a2bcc3333ce4ce6"}, - {file = "websockets-13.1-cp313-cp313-win32.whl", hash = "sha256:624459daabeb310d3815b276c1adef475b3e6804abaf2d9d2c061c319f7f187d"}, - {file = "websockets-13.1-cp313-cp313-win_amd64.whl", hash = "sha256:c518e84bb59c2baae725accd355c8dc517b4a3ed8db88b4bc93c78dae2974bf2"}, - {file = "websockets-13.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c7934fd0e920e70468e676fe7f1b7261c1efa0d6c037c6722278ca0228ad9d0d"}, - {file = "websockets-13.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:149e622dc48c10ccc3d2760e5f36753db9cacf3ad7bc7bbbfd7d9c819e286f23"}, - {file = "websockets-13.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a569eb1b05d72f9bce2ebd28a1ce2054311b66677fcd46cf36204ad23acead8c"}, - {file = "websockets-13.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95df24ca1e1bd93bbca51d94dd049a984609687cb2fb08a7f2c56ac84e9816ea"}, - {file = "websockets-13.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8dbb1bf0c0a4ae8b40bdc9be7f644e2f3fb4e8a9aca7145bfa510d4a374eeb7"}, - {file = "websockets-13.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:035233b7531fb92a76beefcbf479504db8c72eb3bff41da55aecce3a0f729e54"}, - {file = "websockets-13.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:e4450fc83a3df53dec45922b576e91e94f5578d06436871dce3a6be38e40f5db"}, - {file = "websockets-13.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:463e1c6ec853202dd3657f156123d6b4dad0c546ea2e2e38be2b3f7c5b8e7295"}, - {file = "websockets-13.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6d6855bbe70119872c05107e38fbc7f96b1d8cb047d95c2c50869a46c65a8e96"}, - {file = "websockets-13.1-cp38-cp38-win32.whl", hash = "sha256:204e5107f43095012b00f1451374693267adbb832d29966a01ecc4ce1db26faf"}, - {file = "websockets-13.1-cp38-cp38-win_amd64.whl", hash = "sha256:485307243237328c022bc908b90e4457d0daa8b5cf4b3723fd3c4a8012fce4c6"}, - {file = "websockets-13.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9b37c184f8b976f0c0a231a5f3d6efe10807d41ccbe4488df8c74174805eea7d"}, - {file = "websockets-13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:163e7277e1a0bd9fb3c8842a71661ad19c6aa7bb3d6678dc7f89b17fbcc4aeb7"}, - {file = "websockets-13.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4b889dbd1342820cc210ba44307cf75ae5f2f96226c0038094455a96e64fb07a"}, - {file = 
"websockets-13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:586a356928692c1fed0eca68b4d1c2cbbd1ca2acf2ac7e7ebd3b9052582deefa"}, - {file = "websockets-13.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7bd6abf1e070a6b72bfeb71049d6ad286852e285f146682bf30d0296f5fbadfa"}, - {file = "websockets-13.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2aad13a200e5934f5a6767492fb07151e1de1d6079c003ab31e1823733ae79"}, - {file = "websockets-13.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:df01aea34b6e9e33572c35cd16bae5a47785e7d5c8cb2b54b2acdb9678315a17"}, - {file = "websockets-13.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:e54affdeb21026329fb0744ad187cf812f7d3c2aa702a5edb562b325191fcab6"}, - {file = "websockets-13.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9ef8aa8bdbac47f4968a5d66462a2a0935d044bf35c0e5a8af152d58516dbeb5"}, - {file = "websockets-13.1-cp39-cp39-win32.whl", hash = "sha256:deeb929efe52bed518f6eb2ddc00cc496366a14c726005726ad62c2dd9017a3c"}, - {file = "websockets-13.1-cp39-cp39-win_amd64.whl", hash = "sha256:7c65ffa900e7cc958cd088b9a9157a8141c991f8c53d11087e6fb7277a03f81d"}, - {file = "websockets-13.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:5dd6da9bec02735931fccec99d97c29f47cc61f644264eb995ad6c0c27667238"}, - {file = "websockets-13.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:2510c09d8e8df777177ee3d40cd35450dc169a81e747455cc4197e63f7e7bfe5"}, - {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1c3cf67185543730888b20682fb186fc8d0fa6f07ccc3ef4390831ab4b388d9"}, - {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bcc03c8b72267e97b49149e4863d57c2d77f13fae12066622dc78fe322490fe6"}, - {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:004280a140f220c812e65f36944a9ca92d766b6cc4560be652a0a3883a79ed8a"}, - {file = "websockets-13.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e2620453c075abeb0daa949a292e19f56de518988e079c36478bacf9546ced23"}, - {file = "websockets-13.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9156c45750b37337f7b0b00e6248991a047be4aa44554c9886fe6bdd605aab3b"}, - {file = "websockets-13.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:80c421e07973a89fbdd93e6f2003c17d20b69010458d3a8e37fb47874bd67d51"}, - {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82d0ba76371769d6a4e56f7e83bb8e81846d17a6190971e38b5de108bde9b0d7"}, - {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9875a0143f07d74dc5e1ded1c4581f0d9f7ab86c78994e2ed9e95050073c94d"}, - {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a11e38ad8922c7961447f35c7b17bffa15de4d17c70abd07bfbe12d6faa3e027"}, - {file = "websockets-13.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:4059f790b6ae8768471cddb65d3c4fe4792b0ab48e154c9f0a04cefaabcd5978"}, - {file = "websockets-13.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:25c35bf84bf7c7369d247f0b8cfa157f989862c49104c5cf85cb5436a641d93e"}, - {file = 
"websockets-13.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:83f91d8a9bb404b8c2c41a707ac7f7f75b9442a0a876df295de27251a856ad09"}, - {file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a43cfdcddd07f4ca2b1afb459824dd3c6d53a51410636a2c7fc97b9a8cf4842"}, - {file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48a2ef1381632a2f0cb4efeff34efa97901c9fbc118e01951ad7cfc10601a9bb"}, - {file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:459bf774c754c35dbb487360b12c5727adab887f1622b8aed5755880a21c4a20"}, - {file = "websockets-13.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:95858ca14a9f6fa8413d29e0a585b31b278388aa775b8a81fa24830123874678"}, - {file = "websockets-13.1-py3-none-any.whl", hash = "sha256:a9a396a6ad26130cdae92ae10c36af09d9bfe6cafe69670fd3b6da9b07b4044f"}, - {file = "websockets-13.1.tar.gz", hash = "sha256:a3b3366087c1bc0a2795111edcadddb8b3b59509d5db5d7ea3fdd69f954a8878"}, + {file = "websockets-14.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a0adf84bc2e7c86e8a202537b4fd50e6f7f0e4a6b6bf64d7ccb96c4cd3330b29"}, + {file = "websockets-14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90b5d9dfbb6d07a84ed3e696012610b6da074d97453bd01e0e30744b472c8179"}, + {file = "websockets-14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2177ee3901075167f01c5e335a6685e71b162a54a89a56001f1c3e9e3d2ad250"}, + {file = "websockets-14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f14a96a0034a27f9d47fd9788913924c89612225878f8078bb9d55f859272b0"}, + {file = "websockets-14.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f874ba705deea77bcf64a9da42c1f5fc2466d8f14daf410bc7d4ceae0a9fcb0"}, + {file = "websockets-14.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9607b9a442392e690a57909c362811184ea429585a71061cd5d3c2b98065c199"}, + {file = "websockets-14.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:bea45f19b7ca000380fbd4e02552be86343080120d074b87f25593ce1700ad58"}, + {file = "websockets-14.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:219c8187b3ceeadbf2afcf0f25a4918d02da7b944d703b97d12fb01510869078"}, + {file = "websockets-14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ad2ab2547761d79926effe63de21479dfaf29834c50f98c4bf5b5480b5838434"}, + {file = "websockets-14.1-cp310-cp310-win32.whl", hash = "sha256:1288369a6a84e81b90da5dbed48610cd7e5d60af62df9851ed1d1d23a9069f10"}, + {file = "websockets-14.1-cp310-cp310-win_amd64.whl", hash = "sha256:e0744623852f1497d825a49a99bfbec9bea4f3f946df6eb9d8a2f0c37a2fec2e"}, + {file = "websockets-14.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:449d77d636f8d9c17952628cc7e3b8faf6e92a17ec581ec0c0256300717e1512"}, + {file = "websockets-14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a35f704be14768cea9790d921c2c1cc4fc52700410b1c10948511039be824aac"}, + {file = "websockets-14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b1f3628a0510bd58968c0f60447e7a692933589b791a6b572fcef374053ca280"}, + {file = "websockets-14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c3deac3748ec73ef24fc7be0b68220d14d47d6647d2f85b2771cb35ea847aa1"}, + {file = 
"websockets-14.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7048eb4415d46368ef29d32133134c513f507fff7d953c18c91104738a68c3b3"}, + {file = "websockets-14.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6cf0ad281c979306a6a34242b371e90e891bce504509fb6bb5246bbbf31e7b6"}, + {file = "websockets-14.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:cc1fc87428c1d18b643479caa7b15db7d544652e5bf610513d4a3478dbe823d0"}, + {file = "websockets-14.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:f95ba34d71e2fa0c5d225bde3b3bdb152e957150100e75c86bc7f3964c450d89"}, + {file = "websockets-14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9481a6de29105d73cf4515f2bef8eb71e17ac184c19d0b9918a3701c6c9c4f23"}, + {file = "websockets-14.1-cp311-cp311-win32.whl", hash = "sha256:368a05465f49c5949e27afd6fbe0a77ce53082185bbb2ac096a3a8afaf4de52e"}, + {file = "websockets-14.1-cp311-cp311-win_amd64.whl", hash = "sha256:6d24fc337fc055c9e83414c94e1ee0dee902a486d19d2a7f0929e49d7d604b09"}, + {file = "websockets-14.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ed907449fe5e021933e46a3e65d651f641975a768d0649fee59f10c2985529ed"}, + {file = "websockets-14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:87e31011b5c14a33b29f17eb48932e63e1dcd3fa31d72209848652310d3d1f0d"}, + {file = "websockets-14.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bc6ccf7d54c02ae47a48ddf9414c54d48af9c01076a2e1023e3b486b6e72c707"}, + {file = "websockets-14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9777564c0a72a1d457f0848977a1cbe15cfa75fa2f67ce267441e465717dcf1a"}, + {file = "websockets-14.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a655bde548ca98f55b43711b0ceefd2a88a71af6350b0c168aa77562104f3f45"}, + {file = "websockets-14.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3dfff83ca578cada2d19e665e9c8368e1598d4e787422a460ec70e531dbdd58"}, + {file = "websockets-14.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6a6c9bcf7cdc0fd41cc7b7944447982e8acfd9f0d560ea6d6845428ed0562058"}, + {file = "websockets-14.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4b6caec8576e760f2c7dd878ba817653144d5f369200b6ddf9771d64385b84d4"}, + {file = "websockets-14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eb6d38971c800ff02e4a6afd791bbe3b923a9a57ca9aeab7314c21c84bf9ff05"}, + {file = "websockets-14.1-cp312-cp312-win32.whl", hash = "sha256:1d045cbe1358d76b24d5e20e7b1878efe578d9897a25c24e6006eef788c0fdf0"}, + {file = "websockets-14.1-cp312-cp312-win_amd64.whl", hash = "sha256:90f4c7a069c733d95c308380aae314f2cb45bd8a904fb03eb36d1a4983a4993f"}, + {file = "websockets-14.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:3630b670d5057cd9e08b9c4dab6493670e8e762a24c2c94ef312783870736ab9"}, + {file = "websockets-14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:36ebd71db3b89e1f7b1a5deaa341a654852c3518ea7a8ddfdf69cc66acc2db1b"}, + {file = "websockets-14.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5b918d288958dc3fa1c5a0b9aa3256cb2b2b84c54407f4813c45d52267600cd3"}, + {file = "websockets-14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00fe5da3f037041da1ee0cf8e308374e236883f9842c7c465aa65098b1c9af59"}, + {file = 
"websockets-14.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8149a0f5a72ca36720981418eeffeb5c2729ea55fa179091c81a0910a114a5d2"}, + {file = "websockets-14.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77569d19a13015e840b81550922056acabc25e3f52782625bc6843cfa034e1da"}, + {file = "websockets-14.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:cf5201a04550136ef870aa60ad3d29d2a59e452a7f96b94193bee6d73b8ad9a9"}, + {file = "websockets-14.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:88cf9163ef674b5be5736a584c999e98daf3aabac6e536e43286eb74c126b9c7"}, + {file = "websockets-14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:836bef7ae338a072e9d1863502026f01b14027250a4545672673057997d5c05a"}, + {file = "websockets-14.1-cp313-cp313-win32.whl", hash = "sha256:0d4290d559d68288da9f444089fd82490c8d2744309113fc26e2da6e48b65da6"}, + {file = "websockets-14.1-cp313-cp313-win_amd64.whl", hash = "sha256:8621a07991add373c3c5c2cf89e1d277e49dc82ed72c75e3afc74bd0acc446f0"}, + {file = "websockets-14.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:01bb2d4f0a6d04538d3c5dfd27c0643269656c28045a53439cbf1c004f90897a"}, + {file = "websockets-14.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:414ffe86f4d6f434a8c3b7913655a1a5383b617f9bf38720e7c0799fac3ab1c6"}, + {file = "websockets-14.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8fda642151d5affdee8a430bd85496f2e2517be3a2b9d2484d633d5712b15c56"}, + {file = "websockets-14.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cd7c11968bc3860d5c78577f0dbc535257ccec41750675d58d8dc66aa47fe52c"}, + {file = "websockets-14.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a032855dc7db987dff813583d04f4950d14326665d7e714d584560b140ae6b8b"}, + {file = "websockets-14.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b7e7ea2f782408c32d86b87a0d2c1fd8871b0399dd762364c731d86c86069a78"}, + {file = "websockets-14.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:39450e6215f7d9f6f7bc2a6da21d79374729f5d052333da4d5825af8a97e6735"}, + {file = "websockets-14.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ceada5be22fa5a5a4cdeec74e761c2ee7db287208f54c718f2df4b7e200b8d4a"}, + {file = "websockets-14.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:3fc753451d471cff90b8f467a1fc0ae64031cf2d81b7b34e1811b7e2691bc4bc"}, + {file = "websockets-14.1-cp39-cp39-win32.whl", hash = "sha256:14839f54786987ccd9d03ed7f334baec0f02272e7ec4f6e9d427ff584aeea8b4"}, + {file = "websockets-14.1-cp39-cp39-win_amd64.whl", hash = "sha256:d9fd19ecc3a4d5ae82ddbfb30962cf6d874ff943e56e0c81f5169be2fda62979"}, + {file = "websockets-14.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:e5dc25a9dbd1a7f61eca4b7cb04e74ae4b963d658f9e4f9aad9cd00b688692c8"}, + {file = "websockets-14.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:04a97aca96ca2acedf0d1f332c861c5a4486fdcba7bcef35873820f940c4231e"}, + {file = "websockets-14.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df174ece723b228d3e8734a6f2a6febbd413ddec39b3dc592f5a4aa0aff28098"}, + {file = "websockets-14.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:034feb9f4286476f273b9a245fb15f02c34d9586a5bc936aff108c3ba1b21beb"}, + {file = 
"websockets-14.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:660c308dabd2b380807ab64b62985eaccf923a78ebc572bd485375b9ca2b7dc7"}, + {file = "websockets-14.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5a42d3ecbb2db5080fc578314439b1d79eef71d323dc661aa616fb492436af5d"}, + {file = "websockets-14.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ddaa4a390af911da6f680be8be4ff5aaf31c4c834c1a9147bc21cbcbca2d4370"}, + {file = "websockets-14.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:a4c805c6034206143fbabd2d259ec5e757f8b29d0a2f0bf3d2fe5d1f60147a4a"}, + {file = "websockets-14.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:205f672a6c2c671a86d33f6d47c9b35781a998728d2c7c2a3e1cf3333fcb62b7"}, + {file = "websockets-14.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5ef440054124728cc49b01c33469de06755e5a7a4e83ef61934ad95fc327fbb0"}, + {file = "websockets-14.1-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7591d6f440af7f73c4bd9404f3772bfee064e639d2b6cc8c94076e71b2471c1"}, + {file = "websockets-14.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:25225cc79cfebc95ba1d24cd3ab86aaa35bcd315d12fa4358939bd55e9bd74a5"}, + {file = "websockets-14.1-py3-none-any.whl", hash = "sha256:4d4fc827a20abe6d544a119896f6b78ee13fe81cbfef416f3f2ddf09a03f0e2e"}, + {file = "websockets-14.1.tar.gz", hash = "sha256:398b10c77d471c0aab20a845e7a60076b6390bfdaac7a6d2edb0d2c59d75e8d8"}, ] [[package]] @@ -3975,93 +4075,93 @@ viz = ["matplotlib", "nc-time-axis", "seaborn"] [[package]] name = "yarl" -version = "1.17.0" +version = "1.17.2" description = "Yet another URL library" optional = false python-versions = ">=3.9" files = [ - {file = "yarl-1.17.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:2d8715edfe12eee6f27f32a3655f38d6c7410deb482158c0b7d4b7fad5d07628"}, - {file = "yarl-1.17.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1803bf2a7a782e02db746d8bd18f2384801bc1d108723840b25e065b116ad726"}, - {file = "yarl-1.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e66589110e20c2951221a938fa200c7aa134a8bdf4e4dc97e6b21539ff026d4"}, - {file = "yarl-1.17.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7069d411cfccf868e812497e0ec4acb7c7bf8d684e93caa6c872f1e6f5d1664d"}, - {file = "yarl-1.17.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cbf70ba16118db3e4b0da69dcde9d4d4095d383c32a15530564c283fa38a7c52"}, - {file = "yarl-1.17.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0bc53cc349675b32ead83339a8de79eaf13b88f2669c09d4962322bb0f064cbc"}, - {file = "yarl-1.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d6aa18a402d1c80193ce97c8729871f17fd3e822037fbd7d9b719864018df746"}, - {file = "yarl-1.17.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d89c5bc701861cfab357aa0cd039bc905fe919997b8c312b4b0c358619c38d4d"}, - {file = "yarl-1.17.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b728bdf38ca58f2da1d583e4af4ba7d4cd1a58b31a363a3137a8159395e7ecc7"}, - {file = "yarl-1.17.0-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:5542e57dc15d5473da5a39fbde14684b0cc4301412ee53cbab677925e8497c11"}, - {file = "yarl-1.17.0-cp310-cp310-musllinux_1_2_i686.whl", hash = 
"sha256:e564b57e5009fb150cb513804d7e9e9912fee2e48835638f4f47977f88b4a39c"}, - {file = "yarl-1.17.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:eb3c4cff524b4c1c1dba3a6da905edb1dfd2baf6f55f18a58914bbb2d26b59e1"}, - {file = "yarl-1.17.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:05e13f389038842da930d439fbed63bdce3f7644902714cb68cf527c971af804"}, - {file = "yarl-1.17.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:153c38ee2b4abba136385af4467459c62d50f2a3f4bde38c7b99d43a20c143ef"}, - {file = "yarl-1.17.0-cp310-cp310-win32.whl", hash = "sha256:4065b4259d1ae6f70fd9708ffd61e1c9c27516f5b4fae273c41028afcbe3a094"}, - {file = "yarl-1.17.0-cp310-cp310-win_amd64.whl", hash = "sha256:abf366391a02a8335c5c26163b5fe6f514cc1d79e74d8bf3ffab13572282368e"}, - {file = "yarl-1.17.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:19a4fe0279626c6295c5b0c8c2bb7228319d2e985883621a6e87b344062d8135"}, - {file = "yarl-1.17.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cadd0113f4db3c6b56868d6a19ca6286f5ccfa7bc08c27982cf92e5ed31b489a"}, - {file = "yarl-1.17.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:60d6693eef43215b1ccfb1df3f6eae8db30a9ff1e7989fb6b2a6f0b468930ee8"}, - {file = "yarl-1.17.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bb8bf3843e1fa8cf3fe77813c512818e57368afab7ebe9ef02446fe1a10b492"}, - {file = "yarl-1.17.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d2a5b35fd1d8d90443e061d0c8669ac7600eec5c14c4a51f619e9e105b136715"}, - {file = "yarl-1.17.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c5bf17b32f392df20ab5c3a69d37b26d10efaa018b4f4e5643c7520d8eee7ac7"}, - {file = "yarl-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48f51b529b958cd06e78158ff297a8bf57b4021243c179ee03695b5dbf9cb6e1"}, - {file = "yarl-1.17.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5fcaa06bf788e19f913d315d9c99a69e196a40277dc2c23741a1d08c93f4d430"}, - {file = "yarl-1.17.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:32f3ee19ff0f18a7a522d44e869e1ebc8218ad3ae4ebb7020445f59b4bbe5897"}, - {file = "yarl-1.17.0-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:a4fb69a81ae2ec2b609574ae35420cf5647d227e4d0475c16aa861dd24e840b0"}, - {file = "yarl-1.17.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:7bacc8b77670322132a1b2522c50a1f62991e2f95591977455fd9a398b4e678d"}, - {file = "yarl-1.17.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:437bf6eb47a2d20baaf7f6739895cb049e56896a5ffdea61a4b25da781966e8b"}, - {file = "yarl-1.17.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:30534a03c87484092080e3b6e789140bd277e40f453358900ad1f0f2e61fc8ec"}, - {file = "yarl-1.17.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b30df4ff98703649915144be6f0df3b16fd4870ac38a09c56d5d9e54ff2d5f96"}, - {file = "yarl-1.17.0-cp311-cp311-win32.whl", hash = "sha256:263b487246858e874ab53e148e2a9a0de8465341b607678106829a81d81418c6"}, - {file = "yarl-1.17.0-cp311-cp311-win_amd64.whl", hash = "sha256:07055a9e8b647a362e7d4810fe99d8f98421575e7d2eede32e008c89a65a17bd"}, - {file = "yarl-1.17.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:84095ab25ba69a8fa3fb4936e14df631b8a71193fe18bd38be7ecbe34d0f5512"}, - {file = "yarl-1.17.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:02608fb3f6df87039212fc746017455ccc2a5fc96555ee247c45d1e9f21f1d7b"}, - {file = 
"yarl-1.17.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:13468d291fe8c12162b7cf2cdb406fe85881c53c9e03053ecb8c5d3523822cd9"}, - {file = "yarl-1.17.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8da3f8f368fb7e2f052fded06d5672260c50b5472c956a5f1bd7bf474ae504ab"}, - {file = "yarl-1.17.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ec0507ab6523980bed050137007c76883d941b519aca0e26d4c1ec1f297dd646"}, - {file = "yarl-1.17.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:08fc76df7fd8360e9ff30e6ccc3ee85b8dbd6ed5d3a295e6ec62bcae7601b932"}, - {file = "yarl-1.17.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d522f390686acb6bab2b917dd9ca06740c5080cd2eaa5aef8827b97e967319d"}, - {file = "yarl-1.17.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:147c527a80bb45b3dcd6e63401af8ac574125d8d120e6afe9901049286ff64ef"}, - {file = "yarl-1.17.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:24cf43bcd17a0a1f72284e47774f9c60e0bf0d2484d5851f4ddf24ded49f33c6"}, - {file = "yarl-1.17.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:c28a44b9e0fba49c3857360e7ad1473fc18bc7f6659ca08ed4f4f2b9a52c75fa"}, - {file = "yarl-1.17.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:350cacb2d589bc07d230eb995d88fcc646caad50a71ed2d86df533a465a4e6e1"}, - {file = "yarl-1.17.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:fd1ab1373274dea1c6448aee420d7b38af163b5c4732057cd7ee9f5454efc8b1"}, - {file = "yarl-1.17.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:4934e0f96dadc567edc76d9c08181633c89c908ab5a3b8f698560124167d9488"}, - {file = "yarl-1.17.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8d0a278170d75c88e435a1ce76557af6758bfebc338435b2eba959df2552163e"}, - {file = "yarl-1.17.0-cp312-cp312-win32.whl", hash = "sha256:61584f33196575a08785bb56db6b453682c88f009cd9c6f338a10f6737ce419f"}, - {file = "yarl-1.17.0-cp312-cp312-win_amd64.whl", hash = "sha256:9987a439ad33a7712bd5bbd073f09ad10d38640425fa498ecc99d8aa064f8fc4"}, - {file = "yarl-1.17.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8deda7b8eb15a52db94c2014acdc7bdd14cb59ec4b82ac65d2ad16dc234a109e"}, - {file = "yarl-1.17.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:56294218b348dcbd3d7fce0ffd79dd0b6c356cb2a813a1181af730b7c40de9e7"}, - {file = "yarl-1.17.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1fab91292f51c884b290ebec0b309a64a5318860ccda0c4940e740425a67b6b7"}, - {file = "yarl-1.17.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cf93fa61ff4d9c7d40482ce1a2c9916ca435e34a1b8451e17f295781ccc034f"}, - {file = "yarl-1.17.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:261be774a0d71908c8830c33bacc89eef15c198433a8cc73767c10eeeb35a7d0"}, - {file = "yarl-1.17.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:deec9693b67f6af856a733b8a3e465553ef09e5e8ead792f52c25b699b8f9e6e"}, - {file = "yarl-1.17.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c804b07622ba50a765ca7fb8145512836ab65956de01307541def869e4a456c9"}, - {file = "yarl-1.17.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d013a7c9574e98c14831a8f22d27277688ec3b2741d0188ac01a910b009987a"}, - {file = "yarl-1.17.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = 
"sha256:e2cfcba719bd494c7413dcf0caafb51772dec168c7c946e094f710d6aa70494e"}, - {file = "yarl-1.17.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:c068aba9fc5b94dfae8ea1cedcbf3041cd4c64644021362ffb750f79837e881f"}, - {file = "yarl-1.17.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:3616df510ffac0df3c9fa851a40b76087c6c89cbcea2de33a835fc80f9faac24"}, - {file = "yarl-1.17.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:755d6176b442fba9928a4df787591a6a3d62d4969f05c406cad83d296c5d4e05"}, - {file = "yarl-1.17.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:c18f6e708d1cf9ff5b1af026e697ac73bea9cb70ee26a2b045b112548579bed2"}, - {file = "yarl-1.17.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5b937c216b6dee8b858c6afea958de03c5ff28406257d22b55c24962a2baf6fd"}, - {file = "yarl-1.17.0-cp313-cp313-win32.whl", hash = "sha256:d0131b14cb545c1a7bd98f4565a3e9bdf25a1bd65c83fc156ee5d8a8499ec4a3"}, - {file = "yarl-1.17.0-cp313-cp313-win_amd64.whl", hash = "sha256:01c96efa4313c01329e88b7e9e9e1b2fc671580270ddefdd41129fa8d0db7696"}, - {file = "yarl-1.17.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0d44f67e193f0a7acdf552ecb4d1956a3a276c68e7952471add9f93093d1c30d"}, - {file = "yarl-1.17.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:16ea0aa5f890cdcb7ae700dffa0397ed6c280840f637cd07bffcbe4b8d68b985"}, - {file = "yarl-1.17.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cf5469dc7dcfa65edf5cc3a6add9f84c5529c6b556729b098e81a09a92e60e51"}, - {file = "yarl-1.17.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e662bf2f6e90b73cf2095f844e2bc1fda39826472a2aa1959258c3f2a8500a2f"}, - {file = "yarl-1.17.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8260e88f1446904ba20b558fa8ce5d0ab9102747238e82343e46d056d7304d7e"}, - {file = "yarl-1.17.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5dc16477a4a2c71e64c5d3d15d7ae3d3a6bb1e8b955288a9f73c60d2a391282f"}, - {file = "yarl-1.17.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46027e326cecd55e5950184ec9d86c803f4f6fe4ba6af9944a0e537d643cdbe0"}, - {file = "yarl-1.17.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fc95e46c92a2b6f22e70afe07e34dbc03a4acd07d820204a6938798b16f4014f"}, - {file = "yarl-1.17.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:16ca76c7ac9515320cd09d6cc083d8d13d1803f6ebe212b06ea2505fd66ecff8"}, - {file = "yarl-1.17.0-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:eb1a5b97388f2613f9305d78a3473cdf8d80c7034e554d8199d96dcf80c62ac4"}, - {file = "yarl-1.17.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:41fd5498975418cdc34944060b8fbeec0d48b2741068077222564bea68daf5a6"}, - {file = "yarl-1.17.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:146ca582ed04a5664ad04b0e0603934281eaab5c0115a5a46cce0b3c061a56a1"}, - {file = "yarl-1.17.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:6abb8c06107dbec97481b2392dafc41aac091a5d162edf6ed7d624fe7da0587a"}, - {file = "yarl-1.17.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:4d14be4613dd4f96c25feb4bd8c0d8ce0f529ab0ae555a17df5789e69d8ec0c5"}, - {file = "yarl-1.17.0-cp39-cp39-win32.whl", hash = "sha256:174d6a6cad1068f7850702aad0c7b1bca03bcac199ca6026f84531335dfc2646"}, - {file = "yarl-1.17.0-cp39-cp39-win_amd64.whl", hash = "sha256:6af417ca2c7349b101d3fd557ad96b4cd439fdb6ab0d288e3f64a068eea394d0"}, - {file = "yarl-1.17.0-py3-none-any.whl", hash = 
"sha256:62dd42bb0e49423f4dd58836a04fcf09c80237836796025211bbe913f1524993"}, - {file = "yarl-1.17.0.tar.gz", hash = "sha256:d3f13583f378930377e02002b4085a3d025b00402d5a80911726d43a67911cd9"}, + {file = "yarl-1.17.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:93771146ef048b34201bfa382c2bf74c524980870bb278e6df515efaf93699ff"}, + {file = "yarl-1.17.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8281db240a1616af2f9c5f71d355057e73a1409c4648c8949901396dc0a3c151"}, + {file = "yarl-1.17.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:170ed4971bf9058582b01a8338605f4d8c849bd88834061e60e83b52d0c76870"}, + {file = "yarl-1.17.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc61b005f6521fcc00ca0d1243559a5850b9dd1e1fe07b891410ee8fe192d0c0"}, + {file = "yarl-1.17.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:871e1b47eec7b6df76b23c642a81db5dd6536cbef26b7e80e7c56c2fd371382e"}, + {file = "yarl-1.17.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3a58a2f2ca7aaf22b265388d40232f453f67a6def7355a840b98c2d547bd037f"}, + {file = "yarl-1.17.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:736bb076f7299c5c55dfef3eb9e96071a795cb08052822c2bb349b06f4cb2e0a"}, + {file = "yarl-1.17.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8fd51299e21da709eabcd5b2dd60e39090804431292daacbee8d3dabe39a6bc0"}, + {file = "yarl-1.17.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:358dc7ddf25e79e1cc8ee16d970c23faee84d532b873519c5036dbb858965795"}, + {file = "yarl-1.17.2-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:50d866f7b1a3f16f98603e095f24c0eeba25eb508c85a2c5939c8b3870ba2df8"}, + {file = "yarl-1.17.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:8b9c4643e7d843a0dca9cd9d610a0876e90a1b2cbc4c5ba7930a0d90baf6903f"}, + {file = "yarl-1.17.2-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:d63123bfd0dce5f91101e77c8a5427c3872501acece8c90df457b486bc1acd47"}, + {file = "yarl-1.17.2-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:4e76381be3d8ff96a4e6c77815653063e87555981329cf8f85e5be5abf449021"}, + {file = "yarl-1.17.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:734144cd2bd633a1516948e477ff6c835041c0536cef1d5b9a823ae29899665b"}, + {file = "yarl-1.17.2-cp310-cp310-win32.whl", hash = "sha256:26bfb6226e0c157af5da16d2d62258f1ac578d2899130a50433ffee4a5dfa673"}, + {file = "yarl-1.17.2-cp310-cp310-win_amd64.whl", hash = "sha256:76499469dcc24759399accd85ec27f237d52dec300daaca46a5352fcbebb1071"}, + {file = "yarl-1.17.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:792155279dc093839e43f85ff7b9b6493a8eaa0af1f94f1f9c6e8f4de8c63500"}, + {file = "yarl-1.17.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:38bc4ed5cae853409cb193c87c86cd0bc8d3a70fd2268a9807217b9176093ac6"}, + {file = "yarl-1.17.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4a8c83f6fcdc327783bdc737e8e45b2e909b7bd108c4da1892d3bc59c04a6d84"}, + {file = "yarl-1.17.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c6d5fed96f0646bfdf698b0a1cebf32b8aae6892d1bec0c5d2d6e2df44e1e2d"}, + {file = "yarl-1.17.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:782ca9c58f5c491c7afa55518542b2b005caedaf4685ec814fadfcee51f02493"}, + {file = "yarl-1.17.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:ff6af03cac0d1a4c3c19e5dcc4c05252411bf44ccaa2485e20d0a7c77892ab6e"}, + {file = "yarl-1.17.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a3f47930fbbed0f6377639503848134c4aa25426b08778d641491131351c2c8"}, + {file = "yarl-1.17.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1fa68a3c921365c5745b4bd3af6221ae1f0ea1bf04b69e94eda60e57958907f"}, + {file = "yarl-1.17.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:187df91395c11e9f9dc69b38d12406df85aa5865f1766a47907b1cc9855b6303"}, + {file = "yarl-1.17.2-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:93d1c8cc5bf5df401015c5e2a3ce75a5254a9839e5039c881365d2a9dcfc6dc2"}, + {file = "yarl-1.17.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:11d86c6145ac5c706c53d484784cf504d7d10fa407cb73b9d20f09ff986059ef"}, + {file = "yarl-1.17.2-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:c42774d1d1508ec48c3ed29e7b110e33f5e74a20957ea16197dbcce8be6b52ba"}, + {file = "yarl-1.17.2-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:0c8e589379ef0407b10bed16cc26e7392ef8f86961a706ade0a22309a45414d7"}, + {file = "yarl-1.17.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1056cadd5e850a1c026f28e0704ab0a94daaa8f887ece8dfed30f88befb87bb0"}, + {file = "yarl-1.17.2-cp311-cp311-win32.whl", hash = "sha256:be4c7b1c49d9917c6e95258d3d07f43cfba2c69a6929816e77daf322aaba6628"}, + {file = "yarl-1.17.2-cp311-cp311-win_amd64.whl", hash = "sha256:ac8eda86cc75859093e9ce390d423aba968f50cf0e481e6c7d7d63f90bae5c9c"}, + {file = "yarl-1.17.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:dd90238d3a77a0e07d4d6ffdebc0c21a9787c5953a508a2231b5f191455f31e9"}, + {file = "yarl-1.17.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c74f0b0472ac40b04e6d28532f55cac8090e34c3e81f118d12843e6df14d0909"}, + {file = "yarl-1.17.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4d486ddcaca8c68455aa01cf53d28d413fb41a35afc9f6594a730c9779545876"}, + {file = "yarl-1.17.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f25b7e93f5414b9a983e1a6c1820142c13e1782cc9ed354c25e933aebe97fcf2"}, + {file = "yarl-1.17.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3a0baff7827a632204060f48dca9e63fbd6a5a0b8790c1a2adfb25dc2c9c0d50"}, + {file = "yarl-1.17.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:460024cacfc3246cc4d9f47a7fc860e4fcea7d1dc651e1256510d8c3c9c7cde0"}, + {file = "yarl-1.17.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5870d620b23b956f72bafed6a0ba9a62edb5f2ef78a8849b7615bd9433384171"}, + {file = "yarl-1.17.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2941756754a10e799e5b87e2319bbec481ed0957421fba0e7b9fb1c11e40509f"}, + {file = "yarl-1.17.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9611b83810a74a46be88847e0ea616794c406dbcb4e25405e52bff8f4bee2d0a"}, + {file = "yarl-1.17.2-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:cd7e35818d2328b679a13268d9ea505c85cd773572ebb7a0da7ccbca77b6a52e"}, + {file = "yarl-1.17.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:6b981316fcd940f085f646b822c2ff2b8b813cbd61281acad229ea3cbaabeb6b"}, + {file = "yarl-1.17.2-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:688058e89f512fb7541cb85c2f149c292d3fa22f981d5a5453b40c5da49eb9e8"}, + {file = "yarl-1.17.2-cp312-cp312-musllinux_1_2_s390x.whl", hash = 
"sha256:56afb44a12b0864d17b597210d63a5b88915d680f6484d8d202ed68ade38673d"}, + {file = "yarl-1.17.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:17931dfbb84ae18b287279c1f92b76a3abcd9a49cd69b92e946035cff06bcd20"}, + {file = "yarl-1.17.2-cp312-cp312-win32.whl", hash = "sha256:ff8d95e06546c3a8c188f68040e9d0360feb67ba8498baf018918f669f7bc39b"}, + {file = "yarl-1.17.2-cp312-cp312-win_amd64.whl", hash = "sha256:4c840cc11163d3c01a9d8aad227683c48cd3e5be5a785921bcc2a8b4b758c4f3"}, + {file = "yarl-1.17.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:3294f787a437cb5d81846de3a6697f0c35ecff37a932d73b1fe62490bef69211"}, + {file = "yarl-1.17.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f1e7fedb09c059efee2533119666ca7e1a2610072076926fa028c2ba5dfeb78c"}, + {file = "yarl-1.17.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:da9d3061e61e5ae3f753654813bc1cd1c70e02fb72cf871bd6daf78443e9e2b1"}, + {file = "yarl-1.17.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91c012dceadc695ccf69301bfdccd1fc4472ad714fe2dd3c5ab4d2046afddf29"}, + {file = "yarl-1.17.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f11fd61d72d93ac23718d393d2a64469af40be2116b24da0a4ca6922df26807e"}, + {file = "yarl-1.17.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:46c465ad06971abcf46dd532f77560181387b4eea59084434bdff97524444032"}, + {file = "yarl-1.17.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef6eee1a61638d29cd7c85f7fd3ac7b22b4c0fabc8fd00a712b727a3e73b0685"}, + {file = "yarl-1.17.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4434b739a8a101a837caeaa0137e0e38cb4ea561f39cb8960f3b1e7f4967a3fc"}, + {file = "yarl-1.17.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:752485cbbb50c1e20908450ff4f94217acba9358ebdce0d8106510859d6eb19a"}, + {file = "yarl-1.17.2-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:17791acaa0c0f89323c57da7b9a79f2174e26d5debbc8c02d84ebd80c2b7bff8"}, + {file = "yarl-1.17.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:5c6ea72fe619fee5e6b5d4040a451d45d8175f560b11b3d3e044cd24b2720526"}, + {file = "yarl-1.17.2-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:db5ac3871ed76340210fe028f535392f097fb31b875354bcb69162bba2632ef4"}, + {file = "yarl-1.17.2-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:7a1606ba68e311576bcb1672b2a1543417e7e0aa4c85e9e718ba6466952476c0"}, + {file = "yarl-1.17.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9bc27dd5cfdbe3dc7f381b05e6260ca6da41931a6e582267d5ca540270afeeb2"}, + {file = "yarl-1.17.2-cp313-cp313-win32.whl", hash = "sha256:52492b87d5877ec405542f43cd3da80bdcb2d0c2fbc73236526e5f2c28e6db28"}, + {file = "yarl-1.17.2-cp313-cp313-win_amd64.whl", hash = "sha256:8e1bf59e035534ba4077f5361d8d5d9194149f9ed4f823d1ee29ef3e8964ace3"}, + {file = "yarl-1.17.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c556fbc6820b6e2cda1ca675c5fa5589cf188f8da6b33e9fc05b002e603e44fa"}, + {file = "yarl-1.17.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f2f44a4247461965fed18b2573f3a9eb5e2c3cad225201ee858726cde610daca"}, + {file = "yarl-1.17.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3a3ede8c248f36b60227eb777eac1dbc2f1022dc4d741b177c4379ca8e75571a"}, + {file = "yarl-1.17.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2654caaf5584449d49c94a6b382b3cb4a246c090e72453493ea168b931206a4d"}, + {file = 
"yarl-1.17.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0d41c684f286ce41fa05ab6af70f32d6da1b6f0457459a56cf9e393c1c0b2217"}, + {file = "yarl-1.17.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2270d590997445a0dc29afa92e5534bfea76ba3aea026289e811bf9ed4b65a7f"}, + {file = "yarl-1.17.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18662443c6c3707e2fc7fad184b4dc32dd428710bbe72e1bce7fe1988d4aa654"}, + {file = "yarl-1.17.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75ac158560dec3ed72f6d604c81090ec44529cfb8169b05ae6fcb3e986b325d9"}, + {file = "yarl-1.17.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:1fee66b32e79264f428dc8da18396ad59cc48eef3c9c13844adec890cd339db5"}, + {file = "yarl-1.17.2-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:585ce7cd97be8f538345de47b279b879e091c8b86d9dbc6d98a96a7ad78876a3"}, + {file = "yarl-1.17.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:c019abc2eca67dfa4d8fb72ba924871d764ec3c92b86d5b53b405ad3d6aa56b0"}, + {file = "yarl-1.17.2-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:c6e659b9a24d145e271c2faf3fa6dd1fcb3e5d3f4e17273d9e0350b6ab0fe6e2"}, + {file = "yarl-1.17.2-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:d17832ba39374134c10e82d137e372b5f7478c4cceeb19d02ae3e3d1daed8721"}, + {file = "yarl-1.17.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:bc3003710e335e3f842ae3fd78efa55f11a863a89a72e9a07da214db3bf7e1f8"}, + {file = "yarl-1.17.2-cp39-cp39-win32.whl", hash = "sha256:f5ffc6b7ace5b22d9e73b2a4c7305740a339fbd55301d52735f73e21d9eb3130"}, + {file = "yarl-1.17.2-cp39-cp39-win_amd64.whl", hash = "sha256:48e424347a45568413deec6f6ee2d720de2cc0385019bedf44cd93e8638aa0ed"}, + {file = "yarl-1.17.2-py3-none-any.whl", hash = "sha256:dd7abf4f717e33b7487121faf23560b3a50924f80e4bef62b22dab441ded8f3b"}, + {file = "yarl-1.17.2.tar.gz", hash = "sha256:753eaaa0c7195244c84b5cc159dc8204b7fd99f716f11198f999f2332a86b178"}, ] [package.dependencies] @@ -4071,13 +4171,13 @@ propcache = ">=0.2.0" [[package]] name = "zipp" -version = "3.20.2" +version = "3.21.0" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "zipp-3.20.2-py3-none-any.whl", hash = "sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350"}, - {file = "zipp-3.20.2.tar.gz", hash = "sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29"}, + {file = "zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931"}, + {file = "zipp-3.21.0.tar.gz", hash = "sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4"}, ] [package.extras] @@ -4091,4 +4191,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<=3.13" -content-hash = "bb788711115ac19d4d00540acf0a5ffc621bd2b8cce5bb80546703a87f6eb403" +content-hash = "9b76fe3f4aaf979f00ea371363124781c106ce3da04241a5c4309cbfbc0a07e7" diff --git a/pyproject.toml b/pyproject.toml index 97045bc..c9a21da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "drevalpy" -version = "1.0.9" +version = "1.0.10" description = "Drug response evaluation of cancer cell line drug response models in a fair setting" authors = ["DrEvalPy development team"] license = "GPL-3.0" @@ -8,13 +8,15 @@ readme = "README.md" [tool.poetry.dependencies] python = 
">=3.9,<=3.13" -numpy = ">=1.20,<2.1" +numpy = ">=1.20,<2.0" scipy = "*" -scikit-learn = ">=1.4" +scikit-learn = ">=1.4,<1.6" pandas = "*" networkx = "*" pyyaml = "*" pytorch-lightning = "*" +torch = ">=2.1,<=2.4" +torch-geometric = "*" flaky = "*" requests = "*" pingouin = "*" @@ -24,11 +26,11 @@ matplotlib = "*" importlib-resources = "*" -[tool.poetry.group.dev.dependencies] +[tool.poetry.group.development.dependencies] sphinx-autodoc-typehints = "<3.0" sphinx = ">=4.0.2" sphinx-autobuild = ">=2021.3.14" -sphinx-rtd-theme = ">=1.0.0" +sphinx-rtd-theme = ">=1.0.0,<3.0.3" sphinx-click = ">=3.0.0" pytest = "*" nox = "*" @@ -44,6 +46,9 @@ darglint = "*" pre-commit = "*" pre-commit-hooks = "*" pyupgrade = "*" +pep8-naming = "*" +types-requests = "*" +types-PyYAML = "*" [tool.black] line-length = 120 diff --git a/requirements.txt b/requirements.txt index d0d3454..e7afaf4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ aiohappyeyeballs==2.4.3 ; python_version >= "3.9" and python_full_version <= "3.13.0" -aiohttp==3.10.10 ; python_version >= "3.9" and python_full_version <= "3.13.0" +aiohttp==3.11.4 ; python_version >= "3.9" and python_full_version <= "3.13.0" aiosignal==1.3.1 ; python_version >= "3.9" and python_full_version <= "3.13.0" -async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11" +async-timeout==5.0.1 ; python_version >= "3.9" and python_version < "3.11" attrs==24.2.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" certifi==2024.8.30 ; python_version >= "3.9" and python_full_version <= "3.13.0" charset-normalizer==3.4.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" @@ -11,7 +11,7 @@ contourpy==1.3.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" cycler==0.12.1 ; python_version >= "3.9" and python_full_version <= "3.13.0" filelock==3.16.1 ; python_version >= "3.9" and python_full_version <= "3.13.0" flaky==3.8.1 ; python_version >= "3.9" and python_full_version <= "3.13.0" -fonttools==4.54.1 ; python_version >= "3.9" and python_full_version <= "3.13.0" +fonttools==4.55.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" frozenlist==1.5.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" fsspec==2024.10.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" fsspec[http]==2024.10.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" @@ -30,55 +30,57 @@ msgpack==1.1.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" multidict==6.1.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" networkx==3.2.1 ; python_version >= "3.9" and python_full_version <= "3.13.0" numpy==1.26.4 ; python_version >= "3.9" and python_full_version <= "3.13.0" -nvidia-cublas-cu12==12.4.5.8 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" -nvidia-cuda-cupti-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" -nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" -nvidia-cuda-runtime-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" +nvidia-cublas-cu12==12.1.3.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" 
+nvidia-cuda-cupti-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" +nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" +nvidia-cuda-runtime-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" nvidia-cudnn-cu12==9.1.0.70 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" -nvidia-cufft-cu12==11.2.1.3 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" -nvidia-curand-cu12==10.3.5.147 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" -nvidia-cusolver-cu12==11.6.1.9 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" -nvidia-cusparse-cu12==12.3.1.170 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" -nvidia-nccl-cu12==2.21.5 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" -nvidia-nvjitlink-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" -nvidia-nvtx-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" -packaging==24.1 ; python_version >= "3.9" and python_full_version <= "3.13.0" +nvidia-cufft-cu12==11.0.2.54 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" +nvidia-curand-cu12==10.3.2.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" +nvidia-cusolver-cu12==11.4.5.107 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" +nvidia-cusparse-cu12==12.1.0.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" +nvidia-nccl-cu12==2.20.5 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" +nvidia-nvjitlink-cu12==12.6.77 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" +nvidia-nvtx-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" +packaging==24.2 ; python_version >= "3.9" and python_full_version <= "3.13.0" pandas-flavor==0.6.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" pandas==2.2.3 ; python_version >= "3.9" and python_full_version <= "3.13.0" -patsy==0.5.6 ; python_version >= "3.9" and python_full_version <= "3.13.0" +patsy==1.0.1 ; python_version >= "3.9" and python_full_version <= "3.13.0" pillow==11.0.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" pingouin==0.5.5 ; python_version >= "3.9" and python_full_version <= "3.13.0" plotly==5.24.1 ; python_version >= "3.9" and python_full_version <= "3.13.0" propcache==0.2.0 ; 
python_version >= "3.9" and python_full_version <= "3.13.0" protobuf==5.28.3 ; python_version >= "3.9" and python_full_version <= "3.13.0" -pyarrow==18.0.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" +psutil==6.1.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" +pyarrow==18.0.0 ; sys_platform != "darwin" and python_version >= "3.9" and python_full_version <= "3.13.0" or platform_machine != "x86_64" and python_version >= "3.9" and python_full_version <= "3.13.0" pyparsing==3.2.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" python-dateutil==2.9.0.post0 ; python_version >= "3.9" and python_full_version <= "3.13.0" pytorch-lightning==2.4.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" pytz==2024.2 ; python_version >= "3.9" and python_full_version <= "3.13.0" pyyaml==6.0.2 ; python_version >= "3.9" and python_full_version <= "3.13.0" -ray[tune]==2.38.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" +ray[tune]==2.39.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" referencing==0.35.1 ; python_version >= "3.9" and python_full_version <= "3.13.0" requests==2.32.3 ; python_version >= "3.9" and python_full_version <= "3.13.0" -rpds-py==0.20.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" +rpds-py==0.21.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" scikit-learn==1.5.2 ; python_version >= "3.9" and python_full_version <= "3.13.0" scipy==1.13.1 ; python_version >= "3.9" and python_full_version <= "3.13.0" seaborn==0.13.2 ; python_version >= "3.9" and python_full_version <= "3.13.0" -setuptools==75.3.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" +setuptools==75.5.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" six==1.16.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" statsmodels==0.14.4 ; python_version >= "3.9" and python_full_version <= "3.13.0" -sympy==1.13.1 ; python_version >= "3.9" and python_full_version <= "3.13.0" +sympy==1.13.3 ; python_version >= "3.9" and python_full_version <= "3.13.0" tabulate==0.9.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" tenacity==9.0.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" tensorboardx==2.6.2.2 ; python_version >= "3.9" and python_full_version <= "3.13.0" threadpoolctl==3.5.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" -torch==2.5.1 ; python_version >= "3.9" and python_full_version <= "3.13.0" -torchmetrics==1.5.1 ; python_version >= "3.9" and python_full_version <= "3.13.0" -tqdm==4.66.6 ; python_version >= "3.9" and python_full_version <= "3.13.0" -triton==3.1.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13" and python_version >= "3.9" +torch-geometric==2.6.1 ; python_version >= "3.9" and python_full_version <= "3.13.0" +torch==2.4.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" +torchmetrics==1.6.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" +tqdm==4.67.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" +triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13" and python_version >= "3.9" typing-extensions==4.12.2 ; python_version >= "3.9" and python_full_version <= "3.13.0" tzdata==2024.2 ; python_version >= "3.9" and python_full_version <= "3.13.0" urllib3==2.2.3 ; python_version >= "3.9" and python_full_version <= "3.13.0" xarray==2024.7.0 ; python_version >= "3.9" and python_full_version <= 
"3.13.0" -yarl==1.17.0 ; python_version >= "3.9" and python_full_version <= "3.13.0" -zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.10" +yarl==1.17.2 ; python_version >= "3.9" and python_full_version <= "3.13.0" +zipp==3.21.0 ; python_version >= "3.9" and python_version < "3.10" diff --git a/run_suite.py b/run_suite.py index a9e2d1c..8677f3f 100644 --- a/run_suite.py +++ b/run_suite.py @@ -1,6 +1,4 @@ -""" -Main script to run the drug response evaluation pipeline. -""" +"""Main script to run the drug response evaluation pipeline.""" from drevalpy.utils import get_parser, main diff --git a/setup.cfg b/setup.cfg index 14eee71..93ebefd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,11 +9,10 @@ balanced_wrapping=true line_length=120 profile = "black" [flake8] -select = B,B9,C,D,DAR,E,F,N,RST,S,W -ignore = DAR,D100,D103,D212,D max-line-length = 120 max-complexity = 10 docstring-convention = google per-file-ignores = tests/*:S101,S301,S403 + drevalpy/datasets/toy.py:S403,S301 docstring_style = sphinx diff --git a/setup.py b/setup.py index 2846232..a69b787 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,5 @@ +"""Setup file for the drevalpy package.""" + from setuptools import find_packages, setup setup( @@ -45,7 +47,7 @@ package_data={ "": [ "models/baselines/hyperparameters.yaml", - "models/simple_neural_network/hyperparameters.yaml", + "models/SimpleNeuralNetwork/hyperparameters.yaml", "visualization/style_utils/favicon.png", "visualization/style_utils/index_layout.html", "visualization/style_utils/LCO.png", diff --git a/tests/conftest.py b/tests/conftest.py index d58eb79..47774d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,11 @@ +"""Pytest configuration file for the tests directory.""" + import os import pytest @pytest.hookimpl(tryfirst=True) -def pytest_configure(config): +def pytest_configure() -> None: + """Change to the tests directory.""" os.chdir(os.path.dirname(os.path.abspath(__file__))) diff --git a/tests/individual_models/__init__.py b/tests/individual_models/__init__.py index e69de29..58c79e8 100644 --- a/tests/individual_models/__init__.py +++ b/tests/individual_models/__init__.py @@ -0,0 +1 @@ +"""Tests for individual models.""" diff --git a/tests/individual_models/conftest.py b/tests/individual_models/conftest.py index 0896a40..ea1025e 100644 --- a/tests/individual_models/conftest.py +++ b/tests/individual_models/conftest.py @@ -1,3 +1,5 @@ +"""Sample_dataset fixture for testing individual models.""" + import pytest from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset @@ -12,13 +14,18 @@ @pytest.fixture(scope="session") def sample_dataset() -> tuple[DrugResponseDataset, FeatureDataset, FeatureDataset]: + """ + Sample dataset for testing individual models. 
+ + :returns: drug_response, cell_line_input, drug_input + """ path_data = "../data" drug_response = load_toy(path_data) cell_line_input = get_multiomics_feature_dataset(data_path=path_data, dataset_name="Toy_Data", gene_list=None) cell_line_ids = load_cl_ids_from_csv(path=path_data, dataset_name="Toy_Data") - cell_line_input._add_features(cell_line_ids) + cell_line_input.add_features(cell_line_ids) # Load the drug features drug_ids = load_drug_ids_from_csv(data_path=path_data, dataset_name="Toy_Data") drug_input = load_drug_fingerprint_features(data_path=path_data, dataset_name="Toy_Data") - drug_input._add_features(drug_ids) + drug_input.add_features(drug_ids) return drug_response, cell_line_input, drug_input diff --git a/tests/individual_models/test_baselines.py b/tests/individual_models/test_baselines.py index 83d7319..0d38331 100644 --- a/tests/individual_models/test_baselines.py +++ b/tests/individual_models/test_baselines.py @@ -1,7 +1,12 @@ +"""Tests for the baselines in the models module.""" + +from typing import cast + import numpy as np import pytest from sklearn.linear_model import ElasticNet, Ridge +from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset from drevalpy.evaluation import evaluate, pearson from drevalpy.models import ( MODEL_FACTORY, @@ -10,9 +15,8 @@ NaivePredictor, SingleDrugRandomForest, ) - -from .conftest import sample_dataset -from .utils import call_save_and_load +from drevalpy.models.baselines.sklearn_models import SklearnModel +from drevalpy.models.drp_model import DRPModel @pytest.mark.parametrize( @@ -29,12 +33,22 @@ ], ) @pytest.mark.parametrize("test_mode", ["LPO", "LCO", "LDO"]) -def test_baselines(sample_dataset, model_name, test_mode): +def test_baselines( + sample_dataset: tuple[DrugResponseDataset, FeatureDataset, FeatureDataset], model_name: str, test_mode: str +) -> None: + """ + Test the baselines. 
+ + :param sample_dataset: from conftest.py + :param model_name: name of the model + :param test_mode: either LPO, LCO, or LDO + """ drug_response, cell_line_input, drug_input = sample_dataset drug_response.split_dataset( n_cv_splits=5, mode=test_mode, ) + assert drug_response.cv_splits is not None split = drug_response.cv_splits[0] train_dataset = split["train"] val_dataset = split["validation"] @@ -50,9 +64,9 @@ def test_baselines(sample_dataset, model_name, test_mode): print(f"Reduced val dataset from {len_pred_before} to {len(val_dataset)}") if model_name == "NaivePredictor": - call_naive_predictor(train_dataset, val_dataset, test_mode) + _call_naive_predictor(train_dataset, val_dataset, cell_line_input, test_mode) elif model_name == "NaiveDrugMeanPredictor": - call_naive_group_predictor( + _call_naive_group_predictor( "drug", train_dataset, val_dataset, @@ -61,7 +75,7 @@ def test_baselines(sample_dataset, model_name, test_mode): test_mode, ) elif model_name == "NaiveCellLineMeanPredictor": - call_naive_group_predictor( + _call_naive_group_predictor( "cell_line", train_dataset, val_dataset, @@ -70,56 +84,86 @@ def test_baselines(sample_dataset, model_name, test_mode): test_mode, ) else: - call_other_baselines( + _call_other_baselines( model_name, train_dataset, val_dataset, cell_line_input, drug_input, - test_mode, ) @pytest.mark.parametrize("model_name", ["SingleDrugRandomForest"]) @pytest.mark.parametrize("test_mode", ["LPO", "LCO"]) -def test_single_drug_baselines(sample_dataset, model_name, test_mode): +def test_single_drug_baselines( + sample_dataset: tuple[DrugResponseDataset, FeatureDataset, FeatureDataset], model_name: str, test_mode: str +) -> None: + """ + Test the SingleDrugRandomForest model, can also test other baseline single drug models. 
+ + :param sample_dataset: from conftest.py + :param model_name: model name + :param test_mode: either LPO or LCO + """ drug_response, cell_line_input, drug_input = sample_dataset drug_response.split_dataset( n_cv_splits=5, mode=test_mode, ) + assert drug_response.cv_splits is not None split = drug_response.cv_splits[0] train_dataset = split["train"] val_dataset = split["validation"] + all_unique_drugs = np.unique(train_dataset.drug_ids) + # randomly sample a drug to speed up testing + np.random.seed(42) + np.random.shuffle(all_unique_drugs) + random_drug = all_unique_drugs[:1] + all_predictions = np.zeros_like(val_dataset.drug_ids, dtype=float) - for drug in np.unique(train_dataset.drug_ids): - model = SingleDrugRandomForest() - hpam_combi = model.get_hyperparameter_set()[0] - model.build_model(hpam_combi) - output_mask = train_dataset.drug_ids == drug - drug_train = train_dataset.copy() - drug_train.mask(output_mask) - model.train(output=drug_train, cell_line_input=cell_line_input) - - val_mask = val_dataset.drug_ids == drug - all_predictions[val_mask] = model.predict( - drug_ids=drug, - cell_line_ids=val_dataset.cell_line_ids[val_mask], - cell_line_input=cell_line_input, - ) - pcc_drug = pearson(val_dataset.response[val_mask], all_predictions[val_mask]) - print(f"{test_mode}: Performance of {model_name} for drug {drug}: PCC = {pcc_drug}") - val_dataset.predictions = all_predictions - metrics = evaluate(val_dataset, metric=["Pearson"]) - print(f"{test_mode}: Collapsed performance of {model_name}: PCC = {metrics['Pearson']}") - assert metrics["Pearson"] > 0.0 + model = SingleDrugRandomForest() + hpam_combi = model.get_hyperparameter_set()[0] + hpam_combi["n_estimators"] = 2 # reduce test time + hpam_combi["max_depth"] = 2 # reduce test time + model.build_model(hpam_combi) + output_mask = train_dataset.drug_ids == random_drug + drug_train = train_dataset.copy() + drug_train.mask(output_mask) + model.train(output=drug_train, cell_line_input=cell_line_input) + + val_mask = val_dataset.drug_ids == random_drug + all_predictions[val_mask] = model.predict( + drug_ids=random_drug, + cell_line_ids=val_dataset.cell_line_ids[val_mask], + cell_line_input=cell_line_input, + ) + pcc_drug = pearson(val_dataset.response[val_mask], all_predictions[val_mask]) + print(f"{test_mode}: Performance of {model_name} for drug {random_drug}: PCC = {pcc_drug}") + + assert pcc_drug > 0.0 -def call_naive_predictor(train_dataset, val_dataset, test_mode): + +def _call_naive_predictor( + train_dataset: DrugResponseDataset, + val_dataset: DrugResponseDataset, + cell_line_input: FeatureDataset, + test_mode: str, +) -> None: + """ + Call the NaivePredictor model. 
+ + :param train_dataset: training dataset + :param val_dataset: validation dataset + :param cell_line_input: features cell lines + :param test_mode: either LPO, LCO, or LDO + """ naive = NaivePredictor() - naive.train(output=train_dataset) - val_dataset.predictions = naive.predict(cell_line_ids=val_dataset.cell_line_ids) + naive.train(output=train_dataset, cell_line_input=cell_line_input, drug_input=None) + val_dataset._predictions = naive.predict( + cell_line_ids=val_dataset.cell_line_ids, drug_ids=val_dataset.drug_ids, cell_line_input=cell_line_input + ) assert val_dataset.predictions is not None train_mean = train_dataset.response.mean() assert train_mean == naive.dataset_mean @@ -127,18 +171,39 @@ def call_naive_predictor(train_dataset, val_dataset, test_mode): metrics = evaluate(val_dataset, metric=["Pearson"]) assert metrics["Pearson"] == 0.0 print(f"{test_mode}: Performance of NaivePredictor: PCC = {metrics['Pearson']}") - call_save_and_load(naive) -def assert_group_mean(train_dataset, val_dataset, group_ids, naive_means): +def _assert_group_mean( + train_dataset: DrugResponseDataset, + val_dataset: DrugResponseDataset, + group_ids: dict[str, np.ndarray], + naive_means: dict[int, float], +) -> None: + """ + Assert the group mean. + + :param train_dataset: training dataset + :param val_dataset: validation dataset + :param group_ids: group ids + :param naive_means: means + """ common_ids = np.intersect1d(group_ids["train"], group_ids["val"]) random_id = np.random.choice(common_ids) group_mean = train_dataset.response[group_ids["train"] == random_id].mean() assert group_mean == naive_means[random_id] + assert val_dataset.predictions is not None assert np.all(val_dataset.predictions[group_ids["val"] == random_id] == group_mean) -def call_naive_group_predictor(group, train_dataset, val_dataset, cell_line_input, drug_input, test_mode): +def _call_naive_group_predictor( + group: str, + train_dataset: DrugResponseDataset, + val_dataset: DrugResponseDataset, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset, + test_mode: str, +) -> None: + naive: NaiveDrugMeanPredictor | NaiveCellLineMeanPredictor if group == "drug": naive = NaiveDrugMeanPredictor() else: @@ -148,14 +213,17 @@ def call_naive_group_predictor(group, train_dataset, val_dataset, cell_line_inpu cell_line_input=cell_line_input, drug_input=drug_input, ) - val_dataset.predictions = naive.predict(cell_line_ids=val_dataset.cell_line_ids, drug_ids=val_dataset.drug_ids) + val_dataset._predictions = naive.predict( + cell_line_ids=val_dataset.cell_line_ids, drug_ids=val_dataset.drug_ids, cell_line_input=cell_line_input + ) assert val_dataset.predictions is not None train_mean = train_dataset.response.mean() assert train_mean == naive.dataset_mean if (group == "drug" and test_mode == "LDO") or (group == "cell_line" and test_mode == "LCO"): assert np.all(val_dataset.predictions == train_mean) elif group == "drug": - assert_group_mean( + assert isinstance(naive, NaiveDrugMeanPredictor) + _assert_group_mean( train_dataset, val_dataset, group_ids={ @@ -165,7 +233,8 @@ def call_naive_group_predictor(group, train_dataset, val_dataset, cell_line_inpu naive_means=naive.drug_means, ) else: # group == "cell_line" - assert_group_mean( + assert isinstance(naive, NaiveCellLineMeanPredictor) + _assert_group_mean( train_dataset, val_dataset, group_ids={ @@ -175,32 +244,53 @@ def call_naive_group_predictor(group, train_dataset, val_dataset, cell_line_inpu naive_means=naive.cell_line_means, ) metrics = evaluate(val_dataset, 
metric=["Pearson"]) - print(f"{test_mode}: Performance of {naive.model_name}: PCC = {metrics['Pearson']}") + print(f"{test_mode}: Performance of {naive.get_model_name()}: PCC = {metrics['Pearson']}") if (group == "drug" and test_mode == "LDO") or (group == "cell_line" and test_mode == "LCO"): assert metrics["Pearson"] == 0.0 - call_save_and_load(naive) -def call_other_baselines(model, train_dataset, val_dataset, cell_line_input, drug_input, test_mode): - model_class = MODEL_FACTORY[model] +def _call_other_baselines( + model: str, + train_dataset: DrugResponseDataset, + val_dataset: DrugResponseDataset, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset, +) -> None: + """ + Call the other baselines. + + :param model: model name + :param train_dataset: training + :param val_dataset: validation + :param cell_line_input: features cell lines + :param drug_input: features drugs + """ + model_class = cast(type[DRPModel], MODEL_FACTORY[model]) hpams = model_class.get_hyperparameter_set() - if len(hpams) > 3: - hpams = hpams[:3] + if len(hpams) > 2: + hpams = hpams[:2] model_instance = model_class() + assert isinstance(model_instance, SklearnModel) for hpam_combi in hpams: + if model == "RandomForest" or model == "GradientBoosting": + hpam_combi["n_estimators"] = 2 + hpam_combi["max_depth"] = 2 + if model == "GradientBoosting": + hpam_combi["subsample"] = 0.1 model_instance.build_model(hpam_combi) if model == "ElasticNet": if hpam_combi["l1_ratio"] == 0.0: assert issubclass(type(model_instance.model), Ridge) else: assert issubclass(type(model_instance.model), ElasticNet) - + # smaller dataset for faster testing + train_dataset.remove_rows(indices=np.array([list(range(len(train_dataset) - 1000))])) model_instance.train( output=train_dataset, cell_line_input=cell_line_input, drug_input=drug_input, ) - val_dataset.predictions = model_instance.predict( + val_dataset._predictions = model_instance.predict( drug_ids=val_dataset.drug_ids, cell_line_ids=val_dataset.cell_line_ids, drug_input=drug_input, @@ -208,6 +298,4 @@ def call_other_baselines(model, train_dataset, val_dataset, cell_line_input, dru ) assert val_dataset.predictions is not None metrics = evaluate(val_dataset, metric=["Pearson"]) - print(f"{test_mode}: Performance of {model}, hpams: {hpam_combi}: PCC = {metrics['Pearson']}") - assert metrics["Pearson"] > -0.1 - call_save_and_load(model_instance) + assert metrics["Pearson"] >= -1 diff --git a/tests/individual_models/test_literature_models.py b/tests/individual_models/test_literature_models.py new file mode 100644 index 0000000..926b98a --- /dev/null +++ b/tests/individual_models/test_literature_models.py @@ -0,0 +1,145 @@ +"""Test the MOLIR and SuperFELTR models.""" + +from typing import cast + +import numpy as np +import pytest + +from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset +from drevalpy.evaluation import evaluate, pearson +from drevalpy.models import MODEL_FACTORY +from drevalpy.models.drp_model import DRPModel + + +@pytest.mark.parametrize("test_mode", ["LCO"]) +@pytest.mark.parametrize("model_name", ["SuperFELTR", "MOLIR"]) +def test_molir_superfeltr( + sample_dataset: tuple[DrugResponseDataset, FeatureDataset, FeatureDataset], model_name: str, test_mode: str +) -> None: + """ + Test the MOLIR, SuperFELTR. 
+ + :param sample_dataset: from conftest.py + :param model_name: model name + :param test_mode: LCO + """ + drug_response, cell_line_input, drug_input = sample_dataset + drug_response.split_dataset( + n_cv_splits=5, + mode=test_mode, + ) + assert drug_response.cv_splits is not None + split = drug_response.cv_splits[0] + train_dataset = split["train"] + all_unique_drugs = np.unique(train_dataset.drug_ids) + # randomly sample drugs to speed up testing + np.random.seed(42) + np.random.shuffle(all_unique_drugs) + random_drug = all_unique_drugs[:1] + val_es_dataset = split["validation_es"] + es_dataset = split["early_stopping"] + + cell_lines_to_keep = cell_line_input.identifiers + drugs_to_keep = drug_input.identifiers + + len_train_before = len(train_dataset) + len_pred_before = len(val_es_dataset) + len_es_before = len(es_dataset) + train_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) + val_es_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) + es_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) + print(f"Reduced training dataset from {len_train_before} to {len(train_dataset)}") + print(f"Reduced val_es dataset from {len_pred_before} to {len(val_es_dataset)}") + print(f"Reduced es dataset from {len_es_before} to {len(es_dataset)}") + + all_predictions = np.zeros_like(val_es_dataset.drug_ids, dtype=float) + model_class = cast(type[DRPModel], MODEL_FACTORY[model_name]) + model = model_class() + hpam_combi = model.get_hyperparameter_set()[0] + hpam_combi["epochs"] = 1 + model.build_model(hpam_combi) + + output_mask = train_dataset.drug_ids == random_drug + drug_train = train_dataset.copy() + drug_train.mask(output_mask) + es_mask = es_dataset.drug_ids == random_drug + es_dataset_drug = es_dataset.copy() + es_dataset_drug.mask(es_mask) + # smaller dataset for faster testing + drug_train.remove_rows(indices=np.array([list(range(len(drug_train) - 100))])) + model.train( + output=drug_train, + cell_line_input=cell_line_input, + drug_input=None, + output_earlystopping=es_dataset_drug, + ) + + val_mask = val_es_dataset.drug_ids == random_drug + all_predictions[val_mask] = model.predict( + drug_ids=random_drug, + cell_line_ids=val_es_dataset.cell_line_ids[val_mask], + cell_line_input=cell_line_input, + ) + pcc_drug = pearson(val_es_dataset.response[val_mask], all_predictions[val_mask]) + assert pcc_drug >= -1 + + # subset the dataset to only the drugs that were used + val_es_mask = np.isin(val_es_dataset.drug_ids, random_drug) + val_es_dataset._cell_line_ids = val_es_dataset.cell_line_ids[val_es_mask] + val_es_dataset._drug_ids = val_es_dataset.drug_ids[val_es_mask] + val_es_dataset._response = val_es_dataset.response[val_es_mask] + val_es_dataset._predictions = all_predictions[val_es_mask] + metrics = evaluate(val_es_dataset, metric=["Pearson"]) + print(f"{test_mode}: Collapsed performance of {model_name}: PCC = {metrics['Pearson']}") + assert metrics["Pearson"] >= -1.0 + + +@pytest.mark.parametrize("test_mode", ["LCO"]) +@pytest.mark.parametrize("model_name", ["DIPK"]) +def test_dipk( + sample_dataset: tuple[DrugResponseDataset, FeatureDataset, FeatureDataset], model_name: str, test_mode: str +) -> None: + """Test the DIPK model. 
+ + :param sample_dataset: from conftest.py + :param model_name: model name + :param test_mode: LCO + """ + drug_response, cell_line_input, drug_input = sample_dataset + drug_response.split_dataset( + n_cv_splits=5, + mode=test_mode, + ) + assert drug_response.cv_splits is not None + split = drug_response.cv_splits[0] + train_dataset = split["train"] + val_es_dataset = split["validation_es"] + model = MODEL_FACTORY[model_name]() + hpam_combi = model.get_hyperparameter_set()[0] + hpam_combi["epochs"] = 1 + hpam_combi["epochs_autoencoder"] = 1 + model.build_model(hpam_combi) + drug_input = model.load_drug_features(data_path="../data", dataset_name="Toy_Data") # type: ignore + cell_line_input = model.load_cell_line_features(data_path="../data", dataset_name="Toy_Data") + + cell_lines_to_keep = cell_line_input.identifiers + drugs_to_keep = drug_input.identifiers + + train_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) + val_es_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) + + model.train( + output=train_dataset, + cell_line_input=cell_line_input, + drug_input=drug_input, + output_earlystopping=val_es_dataset, + ) + out = model.predict( + cell_line_ids=val_es_dataset.cell_line_ids, + drug_ids=val_es_dataset.drug_ids, + cell_line_input=cell_line_input, + drug_input=drug_input, + ) + val_es_dataset._predictions = out + metrics = evaluate(val_es_dataset, metric=["Pearson"]) + assert metrics["Pearson"] >= -1.0 diff --git a/tests/individual_models/test_molir_superfeltr.py b/tests/individual_models/test_molir_superfeltr.py deleted file mode 100644 index a8422b1..0000000 --- a/tests/individual_models/test_molir_superfeltr.py +++ /dev/null @@ -1,77 +0,0 @@ -import numpy as np -import pytest - -from drevalpy.evaluation import evaluate, pearson -from drevalpy.models import MODEL_FACTORY - -from .conftest import sample_dataset - - -@pytest.mark.parametrize("test_mode", ["LCO"]) -@pytest.mark.parametrize("model_name", ["MOLIR", "SuperFELTR"]) -def test_molir_superfeltr(sample_dataset, model_name, test_mode): - drug_response, cell_line_input, drug_input = sample_dataset - drug_response.split_dataset( - n_cv_splits=5, - mode=test_mode, - ) - split = drug_response.cv_splits[0] - train_dataset = split["train"] - all_unique_drugs = np.unique(train_dataset.drug_ids) - # randomly sample 3 - np.random.seed(42) - np.random.shuffle(all_unique_drugs) - all_unique_drugs = all_unique_drugs[:3] - val_es_dataset = split["validation_es"] - es_dataset = split["early_stopping"] - - cell_lines_to_keep = cell_line_input.identifiers - drugs_to_keep = drug_input.identifiers - - len_train_before = len(train_dataset) - len_pred_before = len(val_es_dataset) - len_es_before = len(es_dataset) - train_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) - val_es_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) - es_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) - print(f"Reduced training dataset from {len_train_before} to {len(train_dataset)}") - print(f"Reduced val_es dataset from {len_pred_before} to {len(val_es_dataset)}") - print(f"Reduced es dataset from {len_es_before} to {len(es_dataset)}") - - all_predictions = np.zeros_like(val_es_dataset.drug_ids, dtype=float) - for drug in all_unique_drugs: - model = MODEL_FACTORY[model_name]() - hpam_combi = model.get_hyperparameter_set()[0] - model.build_model(hpam_combi) - - output_mask = train_dataset.drug_ids == drug - drug_train = 
train_dataset.copy() - drug_train.mask(output_mask) - es_mask = es_dataset.drug_ids == drug - es_dataset_drug = es_dataset.copy() - es_dataset_drug.mask(es_mask) - - model.train( - output=drug_train, - cell_line_input=cell_line_input, - drug_input=None, - output_earlystopping=es_dataset_drug, - ) - - val_mask = val_es_dataset.drug_ids == drug - all_predictions[val_mask] = model.predict( - drug_ids=drug, - cell_line_ids=val_es_dataset.cell_line_ids[val_mask], - cell_line_input=cell_line_input, - ) - pcc_drug = pearson(val_es_dataset.response[val_mask], all_predictions[val_mask]) - print(f"{test_mode}: Performance of {model_name} for drug {drug}: PCC = {pcc_drug}") - # subset the dataset to only the drugs that were used - val_es_mask = np.isin(val_es_dataset.drug_ids, all_unique_drugs) - val_es_dataset.cell_line_ids = val_es_dataset.cell_line_ids[val_es_mask] - val_es_dataset.drug_ids = val_es_dataset.drug_ids[val_es_mask] - val_es_dataset.response = val_es_dataset.response[val_es_mask] - val_es_dataset.predictions = all_predictions[val_es_mask] - metrics = evaluate(val_es_dataset, metric=["Pearson"]) - print(f"{test_mode}: Collapsed performance of {model_name}: PCC = {metrics['Pearson']}") - assert metrics["Pearson"] > 0.0 diff --git a/tests/individual_models/test_simple_neural_network.py b/tests/individual_models/test_simple_neural_network.py index 411137c..0e031b3 100644 --- a/tests/individual_models/test_simple_neural_network.py +++ b/tests/individual_models/test_simple_neural_network.py @@ -1,41 +1,54 @@ +"""Test the SimpleNeuralNetwork model.""" + +from typing import cast + +import numpy as np import pytest +from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset from drevalpy.evaluation import evaluate from drevalpy.models import MODEL_FACTORY - -from .conftest import sample_dataset -from .utils import call_save_and_load +from drevalpy.models.drp_model import DRPModel @pytest.mark.parametrize("test_mode", ["LPO"]) @pytest.mark.parametrize("model_name", ["SRMF", "SimpleNeuralNetwork", "MultiOmicsNeuralNetwork"]) -def test_simple_neural_network(sample_dataset, model_name, test_mode): +def test_simple_neural_network( + sample_dataset: tuple[DrugResponseDataset, FeatureDataset, FeatureDataset], model_name: str, test_mode: str +) -> None: + """ + Test the SimpleNeuralNetwork model. 
+ + :param sample_dataset: from conftest.py + :param model_name: either SRMF, SimpleNeuralNetwork, or MultiOmicsNeuralNetwork + :param test_mode: LPO + """ drug_response, cell_line_input, drug_input = sample_dataset drug_response.split_dataset( n_cv_splits=5, mode=test_mode, ) + assert drug_response.cv_splits is not None split = drug_response.cv_splits[0] train_dataset = split["train"] + # smaller dataset for faster testing + train_dataset.remove_rows(indices=np.array([list(range(len(train_dataset) - 1000))])) + val_es_dataset = split["validation_es"] es_dataset = split["early_stopping"] cell_lines_to_keep = cell_line_input.identifiers drugs_to_keep = drug_input.identifiers - len_train_before = len(train_dataset) - len_pred_before = len(val_es_dataset) - len_es_before = len(es_dataset) train_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) val_es_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) es_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) - print(f"Reduced training dataset from {len_train_before} to {len(train_dataset)}") - print(f"Reduced val_es dataset from {len_pred_before} to {len(val_es_dataset)}") - print(f"Reduced es dataset from {len_es_before} to {len(es_dataset)}") - model = MODEL_FACTORY[model_name]() + model_class = cast(type[DRPModel], MODEL_FACTORY[model_name]) + model = model_class() hpams = model.get_hyperparameter_set() hpam_combi = hpams[0] + hpam_combi["units_per_layer"] = [2, 2] model.build_model(hyperparameters=hpam_combi) model.train( output=train_dataset, @@ -44,7 +57,7 @@ def test_simple_neural_network(sample_dataset, model_name, test_mode): output_earlystopping=es_dataset, ) - val_es_dataset.predictions = model.predict( + val_es_dataset._predictions = model.predict( drug_ids=val_es_dataset.drug_ids, cell_line_ids=val_es_dataset.cell_line_ids, drug_input=drug_input, @@ -52,7 +65,4 @@ def test_simple_neural_network(sample_dataset, model_name, test_mode): ) metrics = evaluate(val_es_dataset, metric=["Pearson"]) - print(f"{test_mode}: Performance of {model}, hpams: {hpam_combi}: PCC = {metrics['Pearson']}") - assert metrics["Pearson"] > 0.0 - - call_save_and_load(model) + assert metrics["Pearson"] >= -1 diff --git a/tests/individual_models/utils.py b/tests/individual_models/utils.py deleted file mode 100644 index df0fb80..0000000 --- a/tests/individual_models/utils.py +++ /dev/null @@ -1,11 +0,0 @@ -import tempfile - -import pytest - - -def call_save_and_load(model): - tmp = tempfile.NamedTemporaryFile() - with pytest.raises(NotImplementedError): - model.save(path=tmp.name) - with pytest.raises(NotImplementedError): - model.load(path=tmp.name) diff --git a/tests/test_available_data.py b/tests/test_available_data.py index b8e904a..febef25 100644 --- a/tests/test_available_data.py +++ b/tests/test_available_data.py @@ -1,11 +1,12 @@ -import tempfile +"""Tests for the available datasets.""" -import pytest +import tempfile from drevalpy.datasets import AVAILABLE_DATASETS -def test_factory(): +def test_factory() -> None: + """Test the dataset factory.""" assert "GDSC1" in AVAILABLE_DATASETS assert "GDSC2" in AVAILABLE_DATASETS assert "CCLE" in AVAILABLE_DATASETS @@ -13,24 +14,22 @@ def test_factory(): assert len(AVAILABLE_DATASETS) == 4 -def test_gdsc1(): +def test_gdsc1() -> None: + """Test the GDSC1 dataset.""" tempdir = tempfile.TemporaryDirectory() gdsc1 = AVAILABLE_DATASETS["GDSC1"](path_data=tempdir.name) assert len(gdsc1) == 292849 def test_gdsc2(): + """Test the GDSC2 dataset.""" 
tempdir = tempfile.TemporaryDirectory() gdsc2 = AVAILABLE_DATASETS["GDSC2"](path_data=tempdir.name) assert len(gdsc2) == 131108 def test_ccle(): + """Test the CCLE dataset.""" tempdir = tempfile.TemporaryDirectory() ccle = AVAILABLE_DATASETS["CCLE"](path_data=tempdir.name) assert len(ccle) == 8478 - - -# Run the tests -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 21f8947..96eb38c 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,3 +1,5 @@ +"""Tests for the DrugResponseDataset and the FeatureDataset class.""" + import os import tempfile @@ -12,8 +14,8 @@ # Tests for the DrugResponseDataset class -# Test if the dataset loads correctly from CSV files -def test_response_dataset_load(): +def test_response_dataset_load() -> None: + """Test if the dataset loads correctly from CSV files.""" # Create a temporary CSV file with mock data data = { "cell_line_id": np.array([1, 2, 3]), @@ -28,8 +30,7 @@ def test_response_dataset_load(): dataset.save("dataset.csv") del dataset # Load the dataset - dataset = DrugResponseDataset() - dataset.load("dataset.csv") + dataset = DrugResponseDataset.from_csv("dataset.csv") os.remove("dataset.csv") @@ -39,7 +40,8 @@ def test_response_dataset_load(): assert np.allclose(dataset.response, data["response"]) -def test_response_dataset_add_rows(): +def test_response_dataset_add_rows() -> None: + """Test if the add_rows method works correctly.""" dataset1 = DrugResponseDataset( response=np.array([1, 2, 3]), cell_line_ids=np.array([101, 102, 103]), @@ -57,7 +59,8 @@ def test_response_dataset_add_rows(): assert np.array_equal(dataset1.drug_ids, np.array(["A", "B", "C", "D", "E", "F"])) -def test_remove_nan_responses(): +def test_remove_nan_responses() -> None: + """Test if the remove_nan_responses method works correctly.""" dataset = DrugResponseDataset( response=np.array([1, 2, 3, np.nan, 5, 6]), cell_line_ids=np.array([101, 102, 103, 104, 105, 106]), @@ -70,6 +73,7 @@ def test_remove_nan_responses(): def test_response_dataset_shuffle(): + """Test if the shuffle method works correctly.""" # Create a dataset with known values dataset = DrugResponseDataset( response=np.array([1, 2, 3, 4, 5, 6]), @@ -92,6 +96,7 @@ def test_response_dataset_shuffle(): def test_response_data_remove_drugs_and_cell_lines(): + """Test if the remove_drugs and remove_cell_lines methods work correctly.""" # Create a dataset with known values dataset = DrugResponseDataset( response=np.array([1, 2, 3, 4, 5]), @@ -100,8 +105,8 @@ def test_response_data_remove_drugs_and_cell_lines(): ) # Remove specific drugs and cell lines - dataset.remove_drugs(["A", "C"]) - dataset.remove_cell_lines([101, 103]) + dataset._remove_drugs(["A", "C"]) + dataset._remove_cell_lines([101, 103]) # Check if the removed drugs and cell lines are not present in the dataset assert "A" not in dataset.drug_ids @@ -116,18 +121,20 @@ def test_response_data_remove_drugs_and_cell_lines(): def test_remove_rows(): + """Test if the remove_rows method works correctly.""" dataset = DrugResponseDataset( response=np.array([1, 2, 3, 4, 5]), cell_line_ids=np.array([101, 102, 103, 104, 105]), drug_ids=np.array(["A", "B", "C", "D", "E"]), ) - dataset.remove_rows([0, 2, 4]) + dataset.remove_rows(np.array([0, 2, 4])) assert np.array_equal(dataset.response, np.array([2, 4])) assert np.array_equal(dataset.cell_line_ids, np.array([102, 104])) assert np.array_equal(dataset.drug_ids, np.array(["B", "D"])) def test_response_dataset_reduce_to(): + """Test if the 
reduce_to method works correctly.""" # Create a dataset with known values dataset = DrugResponseDataset( response=np.array([1, 2, 3, 4, 5]), @@ -136,7 +143,7 @@ def test_response_dataset_reduce_to(): ) # Reduce the dataset to a subset of cell line IDs and drug IDs - dataset.reduce_to(cell_line_ids=[102, 104], drug_ids=["B", "D"]) + dataset.reduce_to(cell_line_ids=np.array([102, 104]), drug_ids=np.array(["B", "D"])) # Check if only the rows corresponding to the specified cell line IDs and drug IDs remain assert all(cell_line_id in [102, 104] for cell_line_id in dataset.cell_line_ids) @@ -150,7 +157,13 @@ def test_response_dataset_reduce_to(): @pytest.mark.parametrize("mode", ["LPO", "LCO", "LDO"]) @pytest.mark.parametrize("split_validation", [True, False]) -def test_split_response_dataset(mode, split_validation): +def test_split_response_dataset(mode: str, split_validation: bool) -> None: + """ + Test if the split_dataset method works correctly. + + :param mode: setting, either LPO, LCO, or LDO + :param split_validation: whether to split the dataset into validation and early stopping sets + """ # Create a dataset with known values dataset = DrugResponseDataset( response=np.random.random(100), @@ -240,7 +253,13 @@ def test_split_response_dataset(mode, split_validation): @pytest.mark.parametrize("resp_transform", ["standard", "minmax", "robust"]) -def test_transform(resp_transform): +def test_transform(resp_transform: str): + """ + Test if the fit_transform and inverse_transform methods work correctly. + + :param resp_transform: response transformation method + :raises ValueError: if an invalid response transformation method is provided + """ from sklearn.preprocessing import MinMaxScaler, RobustScaler, StandardScaler dataset = DrugResponseDataset( @@ -256,6 +275,8 @@ def test_transform(resp_transform): scaler = MinMaxScaler() elif resp_transform == "robust": scaler = RobustScaler() + else: + raise ValueError("Invalid response transformation method.") vals = scaler.fit_transform(np.array([1, 2, 3, 4, 5]).reshape(-1, 1)) assert np.allclose(dataset.response, vals.flatten()) @@ -267,7 +288,12 @@ def test_transform(resp_transform): @pytest.fixture -def sample_dataset(): +def sample_dataset() -> FeatureDataset: + """ + Create a sample FeatureDataset for testing. + + :returns: a sample FeatureDataset + """ features = { "drug1": { "fingerprints": np.random.rand(5), @@ -303,7 +329,13 @@ def sample_dataset(): return FeatureDataset(features=features, meta_info=meta_info) -def random_power_law_graph(size=20): +def random_power_law_graph(size: int = 20) -> nx.Graph: + """ + Create a random graph with power law degree distribution. + + :param size: size of the graph + :returns: a random graph with power law degree distribution + """ # make a graph with degrees distributed as a power law graph = nx.Graph() degrees = np.round(nx.utils.powerlaw_sequence(size, 2.5)) @@ -319,7 +351,12 @@ def random_power_law_graph(size=20): @pytest.fixture -def graph_dataset(): +def graph_dataset() -> FeatureDataset: + """ + Create a sample FeatureDataset with molecular graphs for testing. 
+ + :returns: a sample FeatureDataset with molecular graphs + """ features = { "drug1": { "molecular_graph": random_power_law_graph(), @@ -343,19 +380,34 @@ def graph_dataset(): return FeatureDataset(features=features, meta_info=meta_info) -def test_feature_dataset_get_ids(sample_dataset): - assert np.all(sample_dataset.get_ids() == ["drug1", "drug2", "drug3", "drug4", "drug5"]) +def test_feature_dataset_get_ids(sample_dataset: FeatureDataset) -> None: + """ + Test if the get_ids method works correctly. + + :param sample_dataset: sample FeatureDataset + """ + assert np.all(sample_dataset.identifiers == ["drug1", "drug2", "drug3", "drug4", "drug5"]) -def test_feature_dataset_get_view_names(sample_dataset): - assert sample_dataset.get_view_names() == [ +def test_feature_dataset_get_view_names(sample_dataset: FeatureDataset) -> None: + """ + Test if the get_view_names method works correctly. + + :param sample_dataset: sample FeatureDataset + """ + assert sample_dataset.view_names == [ "fingerprints", "chemical_features", ] -def test_feature_dataset_get_feature_matrix(sample_dataset): - feature_matrix = sample_dataset.get_feature_matrix("fingerprints", ["drug1", "drug2"]) +def test_feature_dataset_get_feature_matrix(sample_dataset: FeatureDataset) -> None: + """ + Test if the get_feature_matrix method works correctly. + + :param sample_dataset: sample FeatureDataset + """ + feature_matrix = sample_dataset.get_feature_matrix("fingerprints", np.array(["drug1", "drug2"])) assert feature_matrix.shape == (2, 5) assert np.allclose( feature_matrix, @@ -369,7 +421,12 @@ def test_feature_dataset_get_feature_matrix(sample_dataset): assert isinstance(feature_matrix, np.ndarray) -def test_feature_dataset_copy(sample_dataset): +def test_feature_dataset_copy(sample_dataset: FeatureDataset) -> None: + """ + Test if the copy method works correctly. + + :param sample_dataset: sample FeatureDataset + """ copied_dataset = sample_dataset.copy() assert copied_dataset.features["drug1"]["fingerprints"] is not sample_dataset.features["drug1"]["fingerprints"] assert np.allclose( @@ -385,7 +442,12 @@ def test_feature_dataset_copy(sample_dataset): @flaky(max_runs=25) # permutation randomization might map to the same feature vector for some tries -def test_permutation_randomization(sample_dataset): +def test_permutation_randomization(sample_dataset: FeatureDataset) -> None: + """ + Test if the permutation randomization works correctly. + + :param sample_dataset: sample FeatureDataset + """ views_to_randomize, randomization_type = "fingerprints", "permutation" start_sample_dataset = sample_dataset.copy() sample_dataset.randomize_features(views_to_randomize, randomization_type) @@ -397,7 +459,12 @@ def test_permutation_randomization(sample_dataset): @flaky(max_runs=25) # permutation randomization might map to the same feature vector for some tries -def test_permutation_randomization_graph(graph_dataset): +def test_permutation_randomization_graph(graph_dataset: FeatureDataset) -> None: + """ + Test if the permutation randomization works correctly for molecular graphs. 
+
+    :param graph_dataset: sample FeatureDataset with molecular graphs
+    """
     views_to_randomize, randomization_type = "molecular_graph", "permutation"
     start_graph_dataset = graph_dataset.copy()
     graph_dataset.randomize_features(views_to_randomize, randomization_type)
@@ -409,7 +476,12 @@ def test_permutation_randomization_graph(graph_dataset):
     )
 
 
-def test_invariant_randomization_array(sample_dataset):
+def test_invariant_randomization_array(sample_dataset: FeatureDataset) -> None:
+    """
+    Test if the invariant randomization works correctly.
+
+    :param sample_dataset: sample FeatureDataset
+    """
     views_to_randomize, randomization_type = "chemical_features", "invariant"
     start_sample_dataset = sample_dataset.copy()
     sample_dataset.randomize_features(views_to_randomize, randomization_type)
@@ -421,7 +493,12 @@ def test_invariant_randomization_array(sample_dataset):
 
 
 @flaky(max_runs=5)  # expected degree randomization might produce the same graph
-def test_invariant_randomization_graph(graph_dataset):
+def test_invariant_randomization_graph(graph_dataset: FeatureDataset) -> None:
+    """
+    Test if the invariant randomization works correctly for molecular graphs.
+
+    :param graph_dataset: sample FeatureDataset with molecular graphs
+    """
     views_to_randomize, randomization_type = "molecular_graph", "invariant"
     start_graph_dataset = graph_dataset.copy()
     graph_dataset.randomize_features(views_to_randomize, randomization_type)
@@ -432,21 +509,28 @@ def test_invariant_randomization_graph(graph_dataset):
     )
 
 
-def test_feature_dataset_save_and_load(sample_dataset):
+def test_feature_dataset_save_and_load(sample_dataset: FeatureDataset) -> None:
+    """
+    Test if the save and load methods work correctly.
+
+    :param sample_dataset: sample FeatureDataset
+    """
     tmp = tempfile.NamedTemporaryFile()
     with pytest.raises(NotImplementedError):
         sample_dataset.save(path=tmp.name)
     with pytest.raises(NotImplementedError):
-        sample_dataset.load(path=tmp.name)
-
+        FeatureDataset.from_csv(tmp.name)
 
 
-def test_add_features(sample_dataset, graph_dataset):
-    sample_dataset._add_features(graph_dataset)
-    assert "molecular_graph" in sample_dataset.meta_info
-    assert "molecular_graph" in sample_dataset.get_view_names()
+def test_add_features(sample_dataset: FeatureDataset, graph_dataset: FeatureDataset) -> None:
+    """
+    Test if the add_features method works correctly.
 
-
-# Run the tests
-if __name__ == "__main__":
-    pytest.main([__file__])
+    :param sample_dataset: sample FeatureDataset
+    :param graph_dataset: sample FeatureDataset with molecular graphs
+    """
+    sample_dataset.add_features(graph_dataset)
+    assert sample_dataset.meta_info is not None
+    assert "molecular_graph" in sample_dataset.meta_info
+    assert "molecular_graph" in sample_dataset.view_names
diff --git a/tests/test_drp_model.py b/tests/test_drp_model.py
index 7231442..97a0274 100644
--- a/tests/test_drp_model.py
+++ b/tests/test_drp_model.py
@@ -1,5 +1,8 @@
+"""Tests for the DRPModel."""
+
 import os
 import tempfile
+from typing import Optional
 
 import numpy as np
 import pandas as pd
@@ -17,7 +20,8 @@
 )
 
 
-def test_factory():
+def test_factory() -> None:
+    """Test the model factory."""
     assert "NaivePredictor" in MODEL_FACTORY
     assert "NaiveDrugMeanPredictor" in MODEL_FACTORY
     assert "NaiveCellLineMeanPredictor" in MODEL_FACTORY
@@ -32,23 +36,33 @@ def test_factory():
     assert "GradientBoosting" in MODEL_FACTORY
     assert "MOLIR" in MODEL_FACTORY
     assert "SuperFELTR" in MODEL_FACTORY
-    assert len(MODEL_FACTORY) == 14
+    assert "DIPK" in MODEL_FACTORY
+    assert len(MODEL_FACTORY) == 15
 
 
-def test_load_cl_ids_from_csv():
+def test_load_cl_ids_from_csv() -> None:
+    """Test the loading of cell line identifiers from a CSV file."""
     temp = tempfile.TemporaryDirectory()
     os.mkdir(os.path.join(temp.name, "GDSC1_small"))
     temp_file = os.path.join(temp.name, "GDSC1_small", "cell_line_names.csv")
     with open(temp_file, "w") as f:
-        f.write("cellosaurus_id,CELL_LINE_NAME\nCVCL_X481,201T\nCVCL_1045,22Rv1\n" "CVCL_1046,23132/87\nCVCL_1798,42-MG-BA\n")
+        f.write(
+            "cellosaurus_id,CELL_LINE_NAME\nCVCL_X481,201T\nCVCL_1045,22Rv1\n"
+            "CVCL_1046,23132/87\nCVCL_1798,42-MG-BA\n"
+        )
     cl_ids_gdsc1 = load_cl_ids_from_csv(temp.name, "GDSC1_small")
     assert len(cl_ids_gdsc1.features) == 4
     assert cl_ids_gdsc1.identifiers[0] == "201T"
 
 
-def write_gene_list(temp_dir, gene_list):
+def _write_gene_list(temp_dir: tempfile.TemporaryDirectory, gene_list: Optional[str] = None) -> None:
+    """
+    Write a gene list to a temporary directory.
+
+    :param temp_dir: temporary directory
+    :param gene_list: either None, landmark_genes, drug_target_genes_all_drugs, or gene_list_paccmann_network_prop
+    """
     os.mkdir(os.path.join(temp_dir.name, "GDSC1_small", "gene_lists"))
     temp_file = os.path.join(temp_dir.name, "GDSC1_small", "gene_lists", f"{gene_list}.csv")
     if gene_list == "landmark_genes":
@@ -78,7 +92,12 @@ def write_gene_list(temp_dir, gene_list):
         "gene_list_paccmann_network_prop",
     ],
 )
-def test_load_and_reduce_gene_features(gene_list):
+def test_load_and_reduce_gene_features(gene_list: Optional[str]) -> None:
+    """
+    Test the loading and reduction of gene features.
+
+    :param gene_list: either None, landmark_genes, drug_target_genes_all_drugs, or gene_list_paccmann_network_prop
+    """
     temp = tempfile.TemporaryDirectory()
     os.mkdir(os.path.join(temp.name, "GDSC1_small"))
     temp_file = os.path.join(temp.name, "GDSC1_small", "gene_expression.csv")
@@ -97,7 +116,7 @@ def test_load_and_reduce_gene_features(gene_list):
         "3.54519297942073,3.9337949618623704,2.8629939819029904\n"
     )
     if gene_list is not None:
-        write_gene_list(temp, gene_list)
+        _write_gene_list(temp, gene_list)
 
     if gene_list == "gene_list_paccmann_network_prop":
         with pytest.raises(ValueError) as valerr:
@@ -106,17 +125,20 @@ def test_load_and_reduce_gene_features(gene_list):
     gene_features_gdsc1 = load_and_reduce_gene_features("gene_expression", gene_list, temp.name, "GDSC1_small")
     if gene_list is None:
         assert len(gene_features_gdsc1.features) == 5
+        assert gene_features_gdsc1.meta_info is not None
         assert len(gene_features_gdsc1.meta_info["gene_expression"]) == 7
         gene_names = ["TSPAN6", "TNMD", "BRCA1", "SCYL3", "HDAC1", "INSIG1", "FOXO3"]
         assert np.all(gene_features_gdsc1.meta_info["gene_expression"] == gene_names)
     elif gene_list == "landmark_genes":
         assert len(gene_features_gdsc1.features) == 5
+        assert gene_features_gdsc1.meta_info is not None
         assert len(gene_features_gdsc1.meta_info["gene_expression"]) == 4
         colnames = gene_features_gdsc1.meta_info["gene_expression"]
         colnames.sort()
         assert np.all(colnames == ["BRCA1", "FOXO3", "INSIG1", "SCYL3"])
     elif gene_list == "drug_target_genes_all_drugs":
         assert len(gene_features_gdsc1.features) == 5
+        assert gene_features_gdsc1.meta_info is not None
         assert len(gene_features_gdsc1.meta_info["gene_expression"]) == 3
         colnames = gene_features_gdsc1.meta_info["gene_expression"]
         colnames.sort()
@@ -125,7 +147,8 @@ def test_load_and_reduce_gene_features(gene_list):
         assert "The following genes are missing from the dataset GDSC1_small" in str(valerr.value)
 
 
-def test_iterate_features():
+def test_iterate_features() -> None:
+    """Test the iteration over features."""
     df = pd.DataFrame({"GeneA": [1, 2, 3, 2], "GeneB": [4, 5, 6, 2], "GeneC": [7, 8, 9, 2]})
     df.index = ["CellLine1", "CellLine2", "CellLine3", "CellLine1"]
     with pytest.warns(UserWarning):
@@ -134,7 +157,8 @@ def test_iterate_features():
     assert np.all(features["CellLine1"]["gene_expression"] == [1, 4, 7])
 
 
-def test_load_drug_ids_from_csv():
+def test_load_drug_ids_from_csv() -> None:
+    """Test the loading of drug identifiers from a CSV file."""
     temp = tempfile.TemporaryDirectory()
     os.mkdir(os.path.join(temp.name, "GDSC1_small"))
     temp_file = os.path.join(temp.name, "GDSC1_small", "drug_names.csv")
@@ -145,7 +169,8 @@ def test_load_drug_ids_from_csv():
     assert drug_ids_gdsc1.identifiers[0] == "(5Z)-7-Oxozeaenol"
 
 
-def test_load_drugs_from_fingerprints():
+def test_load_drugs_from_fingerprints() -> None:
+    """Test the loading of drugs from fingerprints."""
     temp = tempfile.TemporaryDirectory()
     os.mkdir(os.path.join(temp.name, "GDSC1_small"))
     os.mkdir(os.path.join(temp.name, "GDSC1_small", "drug_fingerprints"))
@@ -185,7 +210,12 @@ def test_load_drugs_from_fingerprints():
         "gene_list_paccmann_network_prop",
     ],
 )
-def test_get_multiomics_feature_dataset(gene_list):
+def test_get_multiomics_feature_dataset(gene_list: Optional[str]) -> None:
+    """
+    Test the loading of multiomics features.
+
+    :param gene_list: either None, landmark_genes, drug_target_genes_all_drugs, or gene_list_paccmann_network_prop
+    """
     temp = tempfile.TemporaryDirectory()
     os.mkdir(os.path.join(temp.name, "GDSC1_small"))
     # gene expression
@@ -240,7 +270,7 @@ def test_get_multiomics_feature_dataset(gene_list):
         "CVCL_1045,22Rv1,1.0,1.0,-1.0,1.0,1.0,1.0,1.0\n"
     )
     if gene_list is not None:
-        write_gene_list(temp, gene_list)
+        _write_gene_list(temp, gene_list)
     if gene_list == "gene_list_paccmann_network_prop":
         with pytest.raises(ValueError) as valerr:
             dataset = get_multiomics_feature_dataset(
@@ -255,48 +285,48 @@ def test_get_multiomics_feature_dataset(gene_list):
             gene_list=gene_list,
         )
     assert len(dataset.features) == 2
-    common_cls = dataset.get_ids()
+    common_cls = dataset.identifiers
     common_cls.sort()
     assert np.all(common_cls == ["22Rv1", "CAL-120"])
+    assert dataset.meta_info is not None
     assert len(dataset.meta_info) == 4
     if gene_list is None:
+        assert dataset.meta_info is not None
         assert np.all(
             dataset.meta_info["gene_expression"] == ["TSPAN6", "TNMD", "BRCA1", "SCYL3", "HDAC1", "INSIG1", "FOXO3"]
         )
         for key in dataset.meta_info:
             assert len(dataset.meta_info[key]) == 7
-    elif gene_list == "landmark_genes":
-        feature_names = []
-        for key in dataset.meta_info:
-            if key == "methylation":
-                assert len(dataset.meta_info[key]) == 7
-            else:
-                assert len(dataset.meta_info[key]) == 4
-            if len(feature_names) == 0:
-                feature_names = dataset.meta_info[key]
+    else:
+        feature_names: list[str] = []
+        if gene_list == "landmark_genes":
+            assert dataset.meta_info is not None
+            for key in dataset.meta_info:
+                if key == "methylation":
+                    assert len(dataset.meta_info[key]) == 7
                 else:
-                assert np.all(dataset.meta_info[key] == feature_names)
-    elif gene_list == "drug_target_genes_all_drugs":
-        feature_names = []
-        for key in dataset.meta_info:
-            if key == "methylation":
-                assert len(dataset.meta_info[key]) == 7
-            else:
-                assert len(dataset.meta_info[key]) == 3
-            if len(feature_names) == 0:
-                feature_names = dataset.meta_info[key]
+                    assert len(dataset.meta_info[key]) == 4
+                if len(feature_names) == 0:
+                    feature_names = dataset.meta_info[key]
+                else:
+                    assert np.all(dataset.meta_info[key] == feature_names)
+        elif gene_list == "drug_target_genes_all_drugs":
+            assert dataset.meta_info is not None
+            for key in dataset.meta_info:
+                if key == "methylation":
+                    assert len(dataset.meta_info[key]) == 7
                 else:
-                assert np.all(dataset.meta_info[key] == feature_names)
-    elif gene_list == "gene_list_paccmann_network_prop":
-        assert "The following genes are missing from the dataset GDSC1_small" in str(valerr.value)
+                    assert len(dataset.meta_info[key]) == 3
+                if len(feature_names) == 0:
+                    feature_names = dataset.meta_info[key]
+                else:
+                    assert np.all(dataset.meta_info[key] == feature_names)
+        elif gene_list == "gene_list_paccmann_network_prop":
+            assert "The following genes are missing from the dataset GDSC1_small" in str(valerr.value)
 
 
-def test_unique():
+def test_unique() -> None:
+    """Test the unique function."""
     array = np.array([1, 9, 3, 2, 1, 4, 5, 6, 7, 8, 9, 2, 1, 2, 3, 4, 5, 6, 7, 8, 9])
     unique_array = unique(array)
     assert np.all(unique_array == np.array([1, 9, 3, 2, 4, 5, 6, 7, 8]))
-
-
-# Run the tests
-if __name__ == "__main__":
-    pytest.main([__file__])
diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py
index 452dfcf..2fcbba7 100644
--- a/tests/test_evaluation.py
+++ b/tests/test_evaluation.py
@@ -1,3 +1,5 @@
+"""Tests for evaluation.py."""
+
 import numpy as np
 import pandas as pd
 import pytest
@@ -8,7 +10,8 @@ from drevalpy.evaluation import evaluate, kendall, partial_correlation, pearson, spearman
 
 
-def test_evaluate():
+def test_evaluate() -> None:
+    """Test the evaluate function."""
     # Create mock dataset
     predictions = np.array([1, 2, 3, 4, 5])
     response = np.array([1.1, 2.2, 3.3, 4.4, 5.5])
@@ -37,7 +40,12 @@
 
 # Mock dataset generation function
 @pytest.fixture
-def generate_mock_data_drug_mean():
+def generate_mock_data_drug_mean() -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """
+    Generate mock data with a mean response per drug.
+
+    :returns: response, cell_line_ids, drug_ids
+    """
     response_list = []
     drug_ids = []
     cell_line_ids = []
@@ -52,34 +60,59 @@
 
 
 @pytest.fixture
-def generate_mock_data_constant_prediction():
+def generate_mock_data_constant_prediction() -> tuple[np.ndarray, np.ndarray]:
+    """
+    Generate mock data with constant prediction.
+
+    :returns: y_pred, response
+    """
     response = np.arange(2e6)
     y_pred = np.ones_like(response, dtype=float)
     return y_pred, response
 
 
 @pytest.fixture
-def generate_mock_anticorrelated_data():
+def generate_mock_anticorrelated_data() -> tuple[np.ndarray, np.ndarray]:
+    """
+    Generate mock data with anticorrelated prediction.
+
+    :returns: y_pred, response
+    """
     response = np.arange(2e6, 0, -1)
     y_pred = response[::-1]
     return y_pred, response
 
 
 @pytest.fixture
-def generate_mock_uncorrelated_data():
+def generate_mock_uncorrelated_data() -> tuple[np.ndarray, np.ndarray]:
+    """
+    Generate mock data with uncorrelated prediction.
+
+    :returns: y_pred, response
+    """
     response = np.arange(2e6)
     y_pred = np.random.permutation(response)
     return y_pred, response
 
 
 @pytest.fixture
-def generate_mock_correlated_data():
+def generate_mock_correlated_data() -> tuple[np.ndarray, np.ndarray]:
+    """
+    Generate mock data with correlated prediction.
+
+    :returns: y_pred, response
+    """
     response = np.arange(2e6)
     y_pred = response
     return y_pred, response
 
 
-def test_partial_correlation(generate_mock_data_drug_mean):
+def test_partial_correlation(generate_mock_data_drug_mean: tuple[np.ndarray, np.ndarray, np.ndarray]) -> None:
+    """
+    Test the partial correlation function.
+
+    :param generate_mock_data_drug_mean: mock data generator
+    """
     response, cell_line_ids, drug_ids = generate_mock_data_drug_mean
 
     df = pd.DataFrame(
@@ -102,84 +135,130 @@ def test_partial_correlation(generate_mock_data_drug_mean):
     assert np.isclose(pc, 0.0, atol=0.1)
 
 
-def test_pearson_correlated(generate_mock_correlated_data):
+def test_pearson_correlated(generate_mock_correlated_data: tuple[np.ndarray, np.ndarray]) -> None:
+    """
+    Test the Pearson correlation function with correlated data.
+
+    :param generate_mock_correlated_data: mock data generator
+    """
     y_pred, response = generate_mock_correlated_data
     pc = pearson(y_pred, response)
-    assert bool(np.isclose(pc, 1.0, atol=1e-3))
+    assert np.isclose(pc, 1.0, atol=1e-3)
 
 
-def test_pearson_anticorrelated(generate_mock_anticorrelated_data):
+def test_pearson_anticorrelated(generate_mock_anticorrelated_data: tuple[np.ndarray, np.ndarray]) -> None:
+    """
+    Test the Pearson correlation function with anticorrelated data.
+
+    :param generate_mock_anticorrelated_data: mock data generator
+    """
     y_pred, response = generate_mock_anticorrelated_data
     pc = pearson(y_pred, response)
-    assert bool(np.isclose(pc, -1.0, atol=1e-1))
+    assert np.isclose(pc, -1.0, atol=1e-1)
 
 
 @flaky(max_runs=3)
-def test_pearson_uncorrelated(generate_mock_uncorrelated_data):
+def test_pearson_uncorrelated(generate_mock_uncorrelated_data: tuple[np.ndarray, np.ndarray]) -> None:
+    """
+    Test the Pearson correlation function with uncorrelated data.
+
+    :param generate_mock_uncorrelated_data: mock data generator
+    """
     y_pred, response = generate_mock_uncorrelated_data
     pc = pearson(y_pred, response)
-    assert bool(np.isclose(pc, 0.0, atol=1e-3))
+    assert np.isclose(pc, 0.0, atol=1e-3)
 
 
-def test_spearman_correlated(generate_mock_correlated_data):
+def test_spearman_correlated(generate_mock_correlated_data: tuple[np.ndarray, np.ndarray]) -> None:
+    """
+    Test the Spearman correlation function with correlated data.
+
+    :param generate_mock_correlated_data: mock data generator
+    """
     y_pred, response = generate_mock_correlated_data
     sp = spearman(y_pred, response)
-    assert bool(np.isclose(sp, 1.0, atol=1e-3))
+    assert np.isclose(sp, 1.0, atol=1e-3)
+
+
+def test_spearman_anticorrelated(generate_mock_anticorrelated_data: tuple[np.ndarray, np.ndarray]) -> None:
+    """
+    Test the Spearman correlation function with anticorrelated data.
 
-def test_spearman_anticorrelated(generate_mock_anticorrelated_data):
+    :param generate_mock_anticorrelated_data: mock data generator
+    """
     y_pred, response = generate_mock_anticorrelated_data
     sp = spearman(y_pred, response)
-    assert bool(np.isclose(sp, -1.0, atol=1e-1))
+    assert np.isclose(sp, -1.0, atol=1e-1)
 
 
 @flaky(max_runs=3)
-def test_spearman_uncorrelated(generate_mock_uncorrelated_data):
+def test_spearman_uncorrelated(generate_mock_uncorrelated_data: tuple[np.ndarray, np.ndarray]) -> None:
+    """
+    Test the Spearman correlation function with uncorrelated data.
+
+    :param generate_mock_uncorrelated_data: mock data generator
+    """
     y_pred, response = generate_mock_uncorrelated_data
     sp = spearman(y_pred, response)
     print(sp)
-    assert bool(np.isclose(sp, 0.0, atol=1e-3))
+    assert np.isclose(sp, 0.0, atol=1e-3)
 
 
-def test_kendall_correlated(generate_mock_correlated_data):
+def test_kendall_correlated(generate_mock_correlated_data: tuple[np.ndarray, np.ndarray]) -> None:
+    """
+    Test the Kendall correlation function with correlated data.
+
+    :param generate_mock_correlated_data: mock data generator
+    """
     y_pred, response = generate_mock_correlated_data
     kd = kendall(y_pred, response)
-    assert bool(np.isclose(kd, 1.0, atol=1e-3))
+    assert np.isclose(kd, 1.0, atol=1e-3)
+
+
+def test_kendall_anticorrelated(generate_mock_anticorrelated_data: tuple[np.ndarray, np.ndarray]) -> None:
+    """
+    Test the Kendall correlation function with anticorrelated data.
 
-def test_kendall_anticorrelated(generate_mock_anticorrelated_data):
+    :param generate_mock_anticorrelated_data: mock data generator
+    """
     y_pred, response = generate_mock_anticorrelated_data
     kd = kendall(y_pred, response)
-    assert bool(np.isclose(kd, -1.0, atol=1e-1))
+    assert np.isclose(kd, -1.0, atol=1e-1)
 
 
 @flaky(max_runs=3)
 def test_kendall_uncorrelated(generate_mock_uncorrelated_data):
+    """
+    Test the Kendall correlation function with uncorrelated data.
+
+    :param generate_mock_uncorrelated_data: mock data generator
+    """
     y_pred, response = generate_mock_uncorrelated_data
     kd = kendall(y_pred, response)
-    assert bool(np.isclose(kd, 0.0, atol=1e-3))
+    assert np.isclose(kd, 0.0, atol=1e-3)
 
 
 def test_correlations_constant_prediction(
-    generate_mock_data_constant_prediction,
-):
+    generate_mock_data_constant_prediction: tuple[np.ndarray, np.ndarray]
+) -> None:
+    """
+    Test the correlation functions with constant prediction.
+
+    :param generate_mock_data_constant_prediction: mock data generator
+    """
     y_pred, response = generate_mock_data_constant_prediction
     pc = pearson(y_pred, response)
     sp = spearman(y_pred, response)
     kd = kendall(y_pred, response)
-    assert bool(np.isclose(pc, 0.0, atol=1e-3))
-    assert bool(np.isclose(sp, 0.0, atol=1e-3))
-    assert bool(np.isclose(kd, 0.0, atol=1e-3))
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])
+    assert np.isclose(pc, 0.0, atol=1e-3)
+    assert np.isclose(sp, 0.0, atol=1e-3)
+    assert np.isclose(kd, 0.0, atol=1e-3)
diff --git a/tests/test_run_suite.py b/tests/test_run_suite.py
index e7da441..f309ced 100644
--- a/tests/test_run_suite.py
+++ b/tests/test_run_suite.py
@@ -7,6 +7,7 @@
 import pytest
 
 from drevalpy.utils import main
+from drevalpy.visualization.utils import parse_results, prep_results
 
 
 @pytest.mark.parametrize(
@@ -36,14 +37,14 @@ def test_run_suite(args):
     """
     Tests run_suite.py, i.e., all functionality of the main package.
 
-    :param args: TODO
+    :param args: arguments for the main function
     """
     temp_dir = tempfile.TemporaryDirectory()
     args["path_out"] = temp_dir.name
     args = Namespace(**args)
     main(args)
     assert os.listdir(temp_dir.name) == ["test_run"]
-    """
+    (
         evaluation_results,
         evaluation_results_per_drug,
@@ -62,7 +63,6 @@ def test_run_suite(args):
         evaluation_results_per_cell_line,
         true_vs_pred,
     )
-    # TODO: needs fixing
     assert len(evaluation_results.columns) == 22
     assert len(evaluation_results_per_drug.columns) == 15
     assert len(evaluation_results_per_cell_line.columns) == 15
@@ -84,4 +84,3 @@ def test_run_suite(args):
     assert all(test_mode in evaluation_results.LPO_LCO_LDO.unique() for test_mode in args.test_mode)
     assert evaluation_results.CV_split.astype(int).max() == (args.n_cv_splits - 1)
     assert evaluation_results.Pearson.astype(float).max() > 0.5
-    """