Catalog item flags & OCI spot instances (#118)
- Introduce the concept of catalog item flags so
  that items that break compatibility are only
  returned when their flags are explicitly
  allowed in the query.
- Add OCI spot instances as an example of catalog
  items with flags.
- Publish new catalogs under the /v2 prefix.
- Convert /v2 catalogs to /v1 catalogs for legacy
  users. /v1 catalogs will not contain any items
  with flags, as such items can break
  compatibility.
jvstme authored Feb 21, 2025
1 parent a1340d6 commit 279d660
Showing 16 changed files with 267 additions and 38 deletions.
47 changes: 30 additions & 17 deletions .github/workflows/catalogs.yml
@@ -1,11 +1,12 @@
name: Collect and publish catalogs
run-name: Collect and publish catalogs${{ inputs.staging && ' (staging)' || '' }}
on:
workflow_dispatch:
inputs:
channel:
description: 'Channel to publish catalogs to'
required: true
default: stgn
staging:
description: Staging
type: boolean
default: true
schedule:
- cron: '5 * * * *' # Run every hour at HH:05

@@ -255,30 +256,42 @@ jobs:
needs: [ test-catalog ]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Install gpuhunt
run: pip install .
- name: Install AWS CLI
run: pip install awscli
- uses: actions/download-artifact@v4
with:
pattern: catalogs-*
merge-multiple: true
path: v2/
- name: Build legacy v1 catalogs
run: |
mkdir v1
for catalog_path in $(find v2/*.csv); do
file=$(basename "$catalog_path")
python -m gpuhunt.scripts.catalog_v1 --input "v2/$file" --output "v1/$file"
done
- name: Write version
run: echo "$(date +%Y%m%d)-${{ github.run_number }}" > version
run: echo "$(date +%Y%m%d)-${{ github.run_number }}" | tee v2/version | tee v1/version
- name: Package catalogs
run: zip catalog.zip *.csv version
- name: Set channel
run: |
if [[ ${{ github.event_name == 'workflow_dispatch' }} == true ]]; then
CHANNEL=${{ inputs.channel }}
else
CHANNEL=${{ vars.CHANNEL }}
fi
echo "CHANNEL=$CHANNEL" >> $GITHUB_ENV
zip -j v2/catalog.zip v2/*.csv v2/version
zip -j v1/catalog.zip v1/*.csv v1/version
- name: Upload to S3
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
BUCKET: s3://dstack-gpu-pricing${{ github.event_name == 'workflow_dispatch' && inputs.staging && '/stgn' || '' }}
run: |
VERSION=$(cat version)
aws s3 cp catalog.zip "s3://dstack-gpu-pricing/$CHANNEL/$VERSION/catalog.zip" --acl public-read
cat version | aws s3 cp - "s3://dstack-gpu-pricing/$CHANNEL/version" --acl public-read
aws s3 cp "s3://dstack-gpu-pricing/$CHANNEL/$VERSION/catalog.zip" "s3://dstack-gpu-pricing/$CHANNEL/latest/catalog.zip" --acl public-read
VERSION=$(cat v2/version)
aws s3 cp v2/catalog.zip "$BUCKET/v2/$VERSION/catalog.zip" --acl public-read
aws s3 cp v1/catalog.zip "$BUCKET/v1/$VERSION/catalog.zip" --acl public-read
echo $VERSION | aws s3 cp - "$BUCKET/v2/version" --acl public-read
echo $VERSION | aws s3 cp - "$BUCKET/v1/version" --acl public-read
aws s3 cp "$BUCKET/v2/$VERSION/catalog.zip" "$BUCKET/v2/latest/catalog.zip" --acl public-read
aws s3 cp "$BUCKET/v1/$VERSION/catalog.zip" "$BUCKET/v1/latest/catalog.zip" --acl public-read
8 changes: 2 additions & 6 deletions src/gpuhunt/__main__.py
@@ -1,9 +1,9 @@
import argparse
import logging
import os
import sys

import gpuhunt._internal.storage as storage
from gpuhunt._internal.utils import configure_logging


def main():
@@ -27,11 +27,7 @@ def main():
parser.add_argument("--output", required=True)
parser.add_argument("--no-filter", action="store_true")
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
stream=sys.stdout,
format="%(asctime)s %(levelname)s %(message)s",
)
configure_logging()

if args.provider == "aws":
from gpuhunt.providers.aws import AWSProvider
8 changes: 6 additions & 2 deletions src/gpuhunt/_internal/catalog.py
@@ -7,6 +7,7 @@
import time
import urllib.request
import zipfile
from collections.abc import Container
from concurrent.futures import ThreadPoolExecutor, wait
from pathlib import Path
from typing import Optional, Union
@@ -17,8 +18,8 @@
from gpuhunt.providers import AbstractProvider

logger = logging.getLogger(__name__)
version_url = "https://dstack-gpu-pricing.s3.eu-west-1.amazonaws.com/v1/version"
catalog_url = "https://dstack-gpu-pricing.s3.eu-west-1.amazonaws.com/v1/{version}/catalog.zip"
version_url = "https://dstack-gpu-pricing.s3.eu-west-1.amazonaws.com/v2/version"
catalog_url = "https://dstack-gpu-pricing.s3.eu-west-1.amazonaws.com/v2/{version}/catalog.zip"
OFFLINE_PROVIDERS = ["aws", "azure", "datacrunch", "gcp", "lambdalabs", "oci", "runpod"]
ONLINE_PROVIDERS = ["cudo", "tensordock", "vastai", "vultr"]
RELOAD_INTERVAL = 15 * 60 # 15 minutes
@@ -60,6 +61,7 @@ def query(
min_compute_capability: Optional[Union[str, tuple[int, int]]] = None,
max_compute_capability: Optional[Union[str, tuple[int, int]]] = None,
spot: Optional[bool] = None,
allowed_flags: Optional[Container[str]] = None,
) -> list[CatalogItem]:
"""
Query the catalog for matching offers
@@ -84,6 +86,7 @@
min_compute_capability: minimum compute capability of the GPU
max_compute_capability: maximum compute capability of the GPU
spot: if `False`, only ondemand offers will be returned. If `True`, only spot offers will be returned
allowed_flags: only offers whose flags are all allowed will be returned. `None` allows all flags
Returns:
list of matching offers
@@ -114,6 +117,7 @@
min_compute_capability=parse_compute_capability(min_compute_capability),
max_compute_capability=parse_compute_capability(max_compute_capability),
spot=spot,
allowed_flags=allowed_flags,
)

if query_filter.provider is not None:
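
A minimal sketch of the new gating from the caller's side, assuming the package-level `gpuhunt.query()` helper forwards these keyword arguments to `Catalog.query()`:

    import gpuhunt

    # allowed_flags=None (the default) imposes no restriction: flagged and
    # unflagged offers are both returned.
    everything = gpuhunt.query()

    # An empty container reproduces the legacy behavior: only offers without
    # flags are returned, so nothing that could break older dstack versions.
    legacy_safe = gpuhunt.query(allowed_flags=[])

    # Naming flags explicitly opts in to the items that carry them.
    with_oci_spot = gpuhunt.query(allowed_flags=["oci-spot"])
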
3 changes: 3 additions & 0 deletions src/gpuhunt/_internal/constraints.py
@@ -94,6 +94,9 @@ def matches(i: CatalogItem, q: QueryFilter) -> bool:
if i.disk_size is not None:
if not is_between(i.disk_size, q.min_disk_size, q.max_disk_size):
return False
if q.allowed_flags is not None:
if any(flag not in q.allowed_flags for flag in i.flags):
return False
return True


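
The check above reads: every flag carried by the item must be allowed, and `None` disables the filter entirely. A self-contained restatement of just that predicate (a hypothetical helper, not part of the module):

    from collections.abc import Container
    from typing import Optional

    def flags_allowed(item_flags: list[str], allowed: Optional[Container[str]]) -> bool:
        # Mirrors the check in matches(): None allows everything; otherwise
        # every flag on the item must be present in the allowed container.
        if allowed is None:
            return True
        return all(flag in allowed for flag in item_flags)

    assert flags_allowed([], allowed=[])                      # unflagged items always pass
    assert flags_allowed(["oci-spot"], allowed=None)          # None allows all flags
    assert not flags_allowed(["oci-spot"], allowed=[])        # flagged items need an opt-in
    assert flags_allowed(["oci-spot"], allowed=["oci-spot"])  # explicit opt-in
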
24 changes: 22 additions & 2 deletions src/gpuhunt/_internal/models.py
@@ -1,5 +1,6 @@
import enum
from dataclasses import asdict, dataclass, fields
from collections.abc import Container
from dataclasses import asdict, dataclass, field, fields
from typing import (
ClassVar,
Optional,
@@ -42,6 +43,11 @@ def cast(cls, value: Union["AcceleratorVendor", str]) -> "AcceleratorVendor":

@dataclass
class RawCatalogItem:
"""
An item stored in the catalog.
See `CatalogItem` for field descriptions.
"""

instance_name: Optional[str]
location: Optional[str]
price: Optional[float]
@@ -53,6 +59,7 @@ class RawCatalogItem:
spot: Optional[bool]
disk_size: Optional[float]
gpu_vendor: Optional[str] = None
flags: list[str] = field(default_factory=list)

def __post_init__(self) -> None:
# This heuristic will be required indefinitely since we support historical catalogs.
@@ -87,15 +94,20 @@ def from_dict(v: dict) -> "RawCatalogItem":
gpu_memory=empty_as_none(v.get("gpu_memory"), loader=float),
spot=empty_as_none(v.get("spot"), loader=bool_loader),
disk_size=empty_as_none(v.get("disk_size"), loader=float),
flags=v.get("flags", "").split(),
)

def dict(self) -> dict[str, Union[str, int, float, bool, None]]:
return asdict(self)
return {
**asdict(self),
"flags": " ".join(self.flags),
}


@dataclass
class CatalogItem:
"""
An item returned by `Catalog.query`.
Attributes:
instance_name: name of the instance
location: region or zone
@@ -108,6 +120,11 @@ class CatalogItem:
spot: whether the instance is a spot instance
provider: name of the provider
disk_size: size of disk in GB
flags: list of flags. If a catalog item breaks existing dstack versions,
add a flag to hide the item from those versions. Newer dstack versions
will have to request this flag explicitly to get the catalog item.
If you are adding a new provider, leave the flags empty.
Flag names should be in kebab-case.
"""

instance_name: str
@@ -122,6 +139,7 @@
disk_size: Optional[float]
provider: str
gpu_vendor: Optional[AcceleratorVendor] = None
flags: list[str] = field(default_factory=list)

def __post_init__(self) -> None:
gpu_vendor = self.gpu_vendor
@@ -167,6 +185,7 @@ class QueryFilter:
min_compute_capability: minimum compute capability of the GPU
max_compute_capability: maximum compute capability of the GPU
spot: if `False`, only ondemand offers will be returned. If `True`, only spot offers will be returned
allowed_flags: only offers whose flags are all allowed will be returned. `None` allows all flags
"""

provider: Optional[list[str]] = None # strings can have mixed case
@@ -189,6 +208,7 @@
min_compute_capability: Optional[tuple[int, int]] = None
max_compute_capability: Optional[tuple[int, int]] = None
spot: Optional[bool] = None
allowed_flags: Optional[Container[str]] = None

def __repr__(self) -> str:
"""
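
In the CSV catalogs the new `flags` field is stored as a space-separated string: `from_dict()` splits it into a list and `dict()` joins it back. A sketch of the round trip, assuming `from_dict()` accepts a plain CSV-style row as produced by `csv.DictReader` (all values below are made up):

    from gpuhunt._internal.models import RawCatalogItem

    row = {
        "instance_name": "VM.GPU.A10.1",
        "location": "eu-frankfurt-1",
        "price": "1.0",
        "cpu": "15",
        "memory": "240",
        "gpu_count": "1",
        "gpu_name": "A10",
        "gpu_memory": "24",
        "spot": "True",
        "disk_size": "",
        "gpu_vendor": "nvidia",
        "flags": "oci-spot",
    }
    item = RawCatalogItem.from_dict(row)
    assert item.flags == ["oci-spot"]           # parsed into a list
    assert item.dict()["flags"] == "oci-spot"   # serialized back as a space-separated string
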
28 changes: 20 additions & 8 deletions src/gpuhunt/_internal/storage.py
@@ -1,10 +1,22 @@
import csv
import dataclasses
from collections.abc import Iterable
from typing import TypeVar

from gpuhunt._internal.models import RawCatalogItem

CATALOG_V1_FIELDS = [
"instance_name",
"location",
"price",
"cpu",
"memory",
"gpu_count",
"gpu_name",
"gpu_memory",
"spot",
"disk_size",
"gpu_vendor",
]
T = TypeVar("T", bound=RawCatalogItem)


@@ -16,11 +28,11 @@ def dump(items: list[T], path: str, *, cls: type[T] = RawCatalogItem):
writer.writerow(item.dict())


def load(path: str, *, cls: type[T] = RawCatalogItem) -> list[T]:
items = []
with open(path, newline="") as f:
reader: Iterable[dict[str, str]] = csv.DictReader(f)
def convert_catalog_v2_to_v1(path_v2: str, path_v1: str) -> None:
with open(path_v2) as f_v2, open(path_v1, "w") as f_v1:
reader = csv.DictReader(f_v2)
writer = csv.DictWriter(f_v1, fieldnames=CATALOG_V1_FIELDS, extrasaction="ignore")
writer.writeheader()
for row in reader:
item = cls.from_dict(row)
items.append(item)
return items
if not row.get("flags"):
writer.writerow(row)
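
A sketch of what the conversion does to a small catalog: rows carrying flags are dropped and the `flags` column itself is omitted, so legacy readers see exactly the old schema (file contents are made up):

    import csv
    import tempfile
    from pathlib import Path

    from gpuhunt._internal import storage

    tmp = Path(tempfile.mkdtemp())
    v2_path, v1_path = tmp / "oci_v2.csv", tmp / "oci_v1.csv"

    header = storage.CATALOG_V1_FIELDS + ["flags"]
    rows = [
        # on-demand offer, no flags -> kept in the v1 catalog
        ["VM.GPU.A10.1", "eu-frankfurt-1", "2.0", "15", "240", "1", "A10", "24", "False", "", "nvidia", ""],
        # preemptible offer, flagged -> dropped from the v1 catalog
        ["VM.GPU.A10.1", "eu-frankfurt-1", "1.0", "15", "240", "1", "A10", "24", "True", "", "nvidia", "oci-spot"],
    ]
    with open(v2_path, "w", newline="") as f:
        csv.writer(f).writerows([header] + rows)

    storage.convert_catalog_v2_to_v1(path_v2=str(v2_path), path_v1=str(v1_path))

    with open(v1_path, newline="") as f:
        converted = list(csv.DictReader(f))
    assert len(converted) == 1            # only the unflagged row survives
    assert "flags" not in converted[0]    # and the flags column is gone
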
10 changes: 10 additions & 0 deletions src/gpuhunt/_internal/utils.py
@@ -1,6 +1,16 @@
import logging
import sys
from typing import Callable, Optional, Union


def configure_logging() -> None:
logging.basicConfig(
level=logging.INFO,
stream=sys.stdout,
format="%(asctime)s %(levelname)s %(message)s",
)


def empty_as_none(value: Optional[str], loader: Optional[Callable] = None):
if value is None or value == "":
return None
22 changes: 19 additions & 3 deletions src/gpuhunt/providers/oci.py
@@ -1,3 +1,4 @@
import copy
import logging
import re
from collections.abc import Iterable
@@ -64,7 +65,7 @@ def get(
)
continue

catalog_item = RawCatalogItem(
on_demand_item = RawCatalogItem(
instance_name=shape.name,
location=None,
price=resources.total_price(),
@@ -77,17 +78,31 @@
spot=False,
disk_size=None,
)
result.extend(self._duplicate_item_in_regions(catalog_item, regions))
item_variations = [on_demand_item]
if shape.allow_preemptible:
item_variations.append(self._make_spot_item(on_demand_item))
for item in item_variations:
result.extend(self._duplicate_item_in_regions(item, regions))

return sorted(result, key=lambda i: i.price)

@staticmethod
def _make_spot_item(item: RawCatalogItem) -> RawCatalogItem:
item = copy.deepcopy(item)
item.spot = True
# > Preemptible capacity costs 50% less than on-demand capacity
# https://docs.oracle.com/en-us/iaas/Content/Compute/Concepts/preemptible.htm#howitworks__billing
item.price *= 0.5
item.flags.append("oci-spot")
return item

@staticmethod
def _duplicate_item_in_regions(
item: RawCatalogItem, regions: Iterable[Region]
) -> list[RawCatalogItem]:
result = []
for region in regions:
regional_item = RawCatalogItem(**item.dict())
regional_item = copy.deepcopy(item)
regional_item.location = region.name
result.append(regional_item)
return result
@@ -110,6 +125,7 @@ class CostEstimatorShape(BaseModel):
name: str
hidden: bool
status: str
allow_preemptible: bool
bundle_memory_qty: int
gpu_qty: Optional[int]
gpu_memory_qty: Optional[int]
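
The preemptible variant inherits everything from the on-demand offer except the spot attribute, the halved price, and the `oci-spot` flag, so it stays invisible to legacy consumers. A sketch of how a flags-aware caller might surface these offers, assuming the package-level `gpuhunt.query()` helper accepts the same filters as `Catalog.query()`:

    import gpuhunt

    # Both conditions are needed: spot=True selects spot offers, and the
    # "oci-spot" flag must be allowed, otherwise the flagged items are filtered out.
    oci_spot = gpuhunt.query(provider=["oci"], spot=True, allowed_flags=["oci-spot"])
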
Empty file added src/gpuhunt/scripts/__init__.py
Empty file.
Empty file.
31 changes: 31 additions & 0 deletions src/gpuhunt/scripts/catalog_v1/__main__.py
@@ -0,0 +1,31 @@
import argparse
import logging
from collections.abc import Sequence
from pathlib import Path
from textwrap import dedent
from typing import Optional

from gpuhunt._internal import storage
from gpuhunt._internal.utils import configure_logging


def main(args: Optional[Sequence[str]] = None):
configure_logging()
parser = argparse.ArgumentParser(
description=dedent(
"""
Convert a v2 catalog to a v1 catalog. Legacy v1 catalogs are used by older
gpuhunt versions that do not respect the `flags` field. Any catalog items
with flags are filtered out when converting to v1.
"""
)
)
parser.add_argument("--input", type=Path, required=True, help="The v2 catalog file to read")
parser.add_argument("--output", type=Path, required=True, help="The v1 catalog file to write")
args = parser.parse_args(args)
storage.convert_catalog_v2_to_v1(path_v2=args.input, path_v1=args.output)
logging.info("Converted %s -> %s", args.input, args.output)


if __name__ == "__main__":
main()
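
Besides the CI invocation shown in the workflow above, `main()` takes an optional argument list, so the conversion can also be driven programmatically (the paths here are made up):

    # Equivalent to: python -m gpuhunt.scripts.catalog_v1 --input v2/oci.csv --output v1/oci.csv
    from gpuhunt.scripts.catalog_v1.__main__ import main

    main(["--input", "v2/oci.csv", "--output", "v1/oci.csv"])
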
9 changes: 9 additions & 0 deletions src/integrity_tests/test_oci.py
@@ -20,6 +20,15 @@ def test_on_demand_present(data_rows: list[dict]):
assert "False" in map(itemgetter("spot"), data_rows)


def test_spot_present(data_rows: list[dict]):
assert "True" in map(itemgetter("spot"), data_rows)


def test_spots_contain_flag(data_rows: list[dict]):
for row in data_rows:
assert (row["spot"] == "True") == ("oci-spot" in row["flags"]), row


@pytest.mark.parametrize("prefix", ["VM.Standard", "BM.Standard", "VM.GPU", "BM.GPU"])
def test_family_present(prefix: str, data_rows: list[dict]):
assert any(name.startswith(prefix) for name in map(itemgetter("instance_name"), data_rows))