diff --git a/.github/workflows/unit_tests_inference_experimental.yml b/.github/workflows/unit_tests_inference_experimental.yml index aa3c4bd4b5..990e4c6b84 100644 --- a/.github/workflows/unit_tests_inference_experimental.yml +++ b/.github/workflows/unit_tests_inference_experimental.yml @@ -44,4 +44,4 @@ jobs: working-directory: inference_models run: | source .venv/bin/activate - python -m pytest -n auto tests/unit_tests + python -m pytest -n auto -m "not gpu_only" tests/unit_tests diff --git a/inference/core/version.py b/inference/core/version.py index ebee388c8a..995cc4939b 100644 --- a/inference/core/version.py +++ b/inference/core/version.py @@ -1,4 +1,4 @@ -__version__ = "1.2.10" +__version__ = "1.2.11" if __name__ == "__main__": diff --git a/inference_models/docs/changelog.md b/inference_models/docs/changelog.md index 03631ba684..66b568eb80 100644 --- a/inference_models/docs/changelog.md +++ b/inference_models/docs/changelog.md @@ -1,5 +1,11 @@ # Changelog +## `0.29.0` + +### Added + +- Custom Triton Kernel for RFDETR pre-processing (Codeflash). Numerical parity with `PIL → F.resize → F.to_tensor → F.normalize` chain. + ## `0.28.1` ### Fixed diff --git a/inference_models/inference_models/configuration.py b/inference_models/inference_models/configuration.py index 3b7fb89bce..3bf9179712 100644 --- a/inference_models/inference_models/configuration.py +++ b/inference_models/inference_models/configuration.py @@ -458,3 +458,8 @@ "ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND" ) DEFAULT_ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND = False + +DEFAULT_USE_TRITON_FOR_PREPROCESSING = False + +USE_TRITON_FOR_PREPROCESSING = get_boolean_from_env( + variable_name="USE_TRITON_FOR_PREPROCESSING", default=DEFAULT_USE_TRITON_FOR_PREPROCESSING) diff --git a/inference_models/inference_models/models/rfdetr/pre_processing.py b/inference_models/inference_models/models/rfdetr/pre_processing.py index 3e2ceaad09..9c6a02cb3f 100644 --- a/inference_models/inference_models/models/rfdetr/pre_processing.py +++ b/inference_models/inference_models/models/rfdetr/pre_processing.py @@ -9,8 +9,14 @@ torch.Tensor inputs (advanced caller, float CHW [0, 1]): tensor F.resize → F.normalize + +Triton path: for the common case (single-stage resize, no contrast) +we invoke a single PIL-exact Triton kernel that writes the normalized +fp32 tensor directly to `target_device`. """ +from functools import lru_cache +import warnings from typing import List, Optional, Tuple, Union import numpy as np @@ -18,6 +24,7 @@ import torchvision.transforms.functional as TF from PIL import Image +from inference_models.configuration import USE_TRITON_FOR_PREPROCESSING from inference_models import PreProcessingOverrides from inference_models.entities import ColorFormat, ImageDimensions from inference_models.errors import ModelRuntimeError @@ -40,6 +47,43 @@ pre_process_numpy_image, ) +try: + from inference_models.models.rfdetr.triton_preprocess import ( + TRITON_AVAILABLE as _TRITON_AVAILABLE, + build_resample_tables, + triton_preprocess_rfdetr_stretch, ResampleTables, +) +except ImportError: + _TRITON_AVAILABLE = False + build_resample_tables = None + triton_preprocess_rfdetr_stretch = None + +if USE_TRITON_FOR_PREPROCESSING and not _TRITON_AVAILABLE: + warnings.warn( + "USE_TRITON_FOR_PREPROCESSING is enabled, but Triton is not available; " + "RF-DETR Triton preprocessing will be disabled.", + RuntimeWarning, + stacklevel=2, + ) + +@lru_cache(maxsize=50) +def _get_resample_tables_cached( + device_str: str, src_h: int, src_w: int, th: int, tw: int +) -> ResampleTables: + return build_resample_tables( + src_h=src_h, + src_w=src_w, + target_h=th, + target_w=tw, + device=torch.device(device_str), + ) + + +def get_resample_tables( + device: torch.device, src_h: int, src_w: int, th: int, tw: int +) -> ResampleTables: + return _get_resample_tables_cached(str(device), src_h, src_w, th, tw) + def pre_process_network_input( images: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]], @@ -92,6 +136,23 @@ def pre_process_network_input( else: image_list = [images] + if triton_path_eligible( + image_list=image_list, + image_pre_processing=image_pre_processing, + network_input=network_input, + target_device=target_device, + pre_processing_overrides=pre_processing_overrides, + ): + return triton_path_preprocess( + image_list=image_list, + image_pre_processing=image_pre_processing, + network_input=network_input, + target_size=target_size, + target_device=target_device, + input_color_mode=input_color_mode, + pre_processing_overrides=pre_processing_overrides, + ) + tensors: List[torch.Tensor] = [] metadata: List[PreProcessingMetadata] = [] for img in image_list: @@ -129,6 +190,181 @@ def pre_process_network_input( return batch, metadata +def triton_path_eligible( + image_list: List[Union[np.ndarray, torch.Tensor]], + image_pre_processing: ImagePreProcessing, + network_input: NetworkInputDefinition, + target_device: torch.device, + pre_processing_overrides: Optional[PreProcessingOverrides], +) -> bool: + """True when every item in `image_list` can go through the fused Triton + kernel. Predicate is intentionally conservative — any miss → PIL path.""" + if not USE_TRITON_FOR_PREPROCESSING: + return False + if not _TRITON_AVAILABLE: + return False + # Kernel runs on CUDA. + if target_device.type != "cuda": + return False + # grayscale / contrast aren't implemented in the kernel; static_crop is + # (as a load-time offset + effective-dims substitution). + ipp = image_pre_processing + if (ipp.contrast is not None and ipp.contrast.enabled) or ( + ipp.grayscale is not None and ipp.grayscale.enabled + ): + return False + # Two-stage dataset-version resize isn't in the kernel. + ni = network_input + if _needs_two_step_resize(ni): + return False + if ni.input_channels != 3: + return False + if ni.scaling_factor not in (None, 255): + return False + if ni.normalization is None: + return False + if ni.resize_mode not in ( + ResizeMode.STRETCH_TO, + ResizeMode.LETTERBOX, + ResizeMode.CENTER_CROP, + ResizeMode.LETTERBOX_REFLECT_EDGES, + ): + return False + if not image_list: + return False + # `pre_process_network_input` already unbinds 4D inputs (batch tensors + # and 4D numpy arrays) into a list of 3D items before calling us, so + # the per-item check below only needs to handle 3D. + for img in image_list: + if isinstance(img, np.ndarray): + if img.dtype != np.uint8 or img.ndim != 3: + return False + if img.shape[2] != 3 and not looks_like_chw(img.shape): + return False + elif isinstance(img, torch.Tensor): + # Only uint8 3-channel images (HWC or CHW). Float tensors keep + # the existing tensor branch (F.resize bilinear *without* PIL's + # antialias — a caller-accepted divergence we don't silently + # change). + if img.dtype != torch.uint8 or img.ndim != 3: + return False + if img.shape[-1] != 3 and not looks_like_chw(img.shape): + return False + else: + return False + return True + + +def looks_like_chw(shape) -> bool: + """Matches _tensor_to_hwc_uint8's CHW heuristic: first dim is 1/3/4 and + last dim is not. Catches torchvision.io.read_image's CHW output.""" + return ( + len(shape) == 3 + and shape[0] in (1, 3, 4) + and shape[-1] not in (1, 3, 4) + ) + + +def as_hwc_uint8_cuda( + img: Union[np.ndarray, torch.Tensor], device: torch.device +) -> torch.Tensor: + """Return a contiguous (H, W, 3) uint8 CUDA tensor, copying if needed. + Accepts HWC or CHW 3D inputs (CHW is torchvision.io.read_image's layout).""" + if isinstance(img, torch.Tensor): + if looks_like_chw(img.shape): + img = img.permute(1, 2, 0) + if img.device != device: + img = img.to(device=device, non_blocking=True) + return img.contiguous() + if looks_like_chw(img.shape): + img = np.transpose(img, (1, 2, 0)) + t = torch.from_numpy(np.ascontiguousarray(img)) + return t.to(device=device, non_blocking=True) + + +def triton_path_preprocess( + image_list: List[Union[np.ndarray, torch.Tensor]], + image_pre_processing: ImagePreProcessing, + network_input: NetworkInputDefinition, + target_size: ImageDimensions, + target_device: torch.device, + input_color_mode: Optional[ColorMode], + pre_processing_overrides: Optional[PreProcessingOverrides], +) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]: + """Per-image Triton launch + stack. Assumes `triton_path_eligible` passed.""" + means, stds = network_input.normalization + means_t = (float(means[0]), float(means[1]), float(means[2])) + stds_t = (float(stds[0]), float(stds[1]), float(stds[2])) + th, tw = target_size.height, target_size.width + + # Match _pre_process_numpy's swap semantics: the PIL path does + swap_rb = input_color_mode != network_input.color_mode + + # Resolve whether static_crop is active for this call. Computed once + # because the config + override are call-level. + static_crop_overridden = ( + pre_processing_overrides is not None + and pre_processing_overrides.disable_static_crop is True + ) + crop_cfg = image_pre_processing.static_crop + crop_active = ( + crop_cfg is not None and crop_cfg.enabled and not static_crop_overridden + ) + + outs: List[torch.Tensor] = [] + metas: List[PreProcessingMetadata] = [] + for img in image_list: + src_gpu = as_hwc_uint8_cuda(img, target_device) + sh, sw = int(src_gpu.shape[0]), int(src_gpu.shape[1]) + + if crop_active: + # Matches apply_static_crop_to_numpy_image: percentage-based. + x0 = int(crop_cfg.x_min / 100 * sw) + y0 = int(crop_cfg.y_min / 100 * sh) + x1 = int(crop_cfg.x_max / 100 * sw) + y1 = int(crop_cfg.y_max / 100 * sh) + crop_w = x1 - x0 + crop_h = y1 - y0 + else: + x0 = y0 = 0 + crop_w, crop_h = sw, sh + + tables = get_resample_tables(target_device, crop_h, crop_w, th, tw) + out = triton_preprocess_rfdetr_stretch( + src=src_gpu, + tables=tables, + target_h=th, + target_w=tw, + means=means_t, + stds=stds_t, + swap_rb=swap_rb, + crop_offset_y=y0, + crop_offset_x=x0, + crop_h=crop_h, + crop_w=crop_w, + ) + outs.append(out[0]) # drop leading batch dim to match stack below + metas.append( + PreProcessingMetadata( + pad_left=0, + pad_top=0, + pad_right=0, + pad_bottom=0, + original_size=ImageDimensions(width=sw, height=sh), + size_after_pre_processing=ImageDimensions(width=crop_w, height=crop_h), + inference_size=ImageDimensions(width=tw, height=th), + scale_width=tw / crop_w, + scale_height=th / crop_h, + static_crop_offset=StaticCropOffset( + offset_x=x0, offset_y=y0, crop_width=crop_w, crop_height=crop_h + ), + ) + ) + + batch = torch.stack(outs).contiguous() + return batch, metas + + def _pre_process_numpy( image: np.ndarray, image_pre_processing: ImagePreProcessing, diff --git a/inference_models/inference_models/models/rfdetr/triton_preprocess.py b/inference_models/inference_models/models/rfdetr/triton_preprocess.py new file mode 100644 index 0000000000..04d8d156d3 --- /dev/null +++ b/inference_models/inference_models/models/rfdetr/triton_preprocess.py @@ -0,0 +1,420 @@ +"""Fused Triton preprocessing kernel for RF-DETR. + +Byte-exact port of PIL's separable bilinear-antialias resize (the algorithm +torchvision's `TF.resize(pil, ..., antialias=True)` uses on PIL inputs), with +the subsequent `/255` + ImageNet normalize fused into the same pass. + +PIL's scheme (src/libImaging/Resample.c): + + PRECISION_BITS = 22 + scale = in_size / out_size + filterscale = max(1.0, scale) + support = 1.0 * filterscale # triangle radius = 1 + ksize = ceil(support) * 2 + 1 + center(o) = (o + 0.5) * scale + xmin(o) = int(center - support + 0.5) clipped to [0, in] + xmax(o) = int(center + support + 0.5) clipped to [0, in] + w_f(o, k) = triangle((k + xmin - center + 0.5) / filterscale) + w_f normalised to sum to 1 per output pixel + w_i(o, k) = round(w_f(o, k) * (1 << PRECISION_BITS)) int32 + out(o) = clamp((Σ w_i(o, k) * src_u8) + (1 << (PRECISION_BITS-1)) >> PRECISION_BITS, 0, 255) + +Single fused kernel: the horizontal uint8 intermediate lives in registers +rather than a DRAM scratch buffer. For each output tile we loop over +KSIZE_Y source rows; for each contributing source row we recompute the +horizontal convolution (int32 fixed-point, uint8 quantize) on the fly, +multiply by the vertical weight, and accumulate. Final: uint8 quantize, +BGR↔RGB swap, /255, ImageNet normalize, fp32 CHW store. +""" + +from __future__ import annotations + +import math +from typing import Optional, Tuple + +import numpy as np +import torch + +from inference_models.errors import ( + MissingDependencyError, + ModelInputError, + ModelRuntimeError, +) + +try: + import triton + import triton.language as tl + + TRITON_AVAILABLE = True +except ImportError: # pragma: no cover + triton = None + tl = None + TRITON_AVAILABLE = False + + +PRECISION_BITS = 22 + + +def _bilinear_antialias_weights_1d_int( + in_size: int, out_size: int +) -> Tuple[np.ndarray, np.ndarray, int]: + """PIL's precompute_coeffs, int32 fixed-point form.""" + scale = in_size / out_size + filterscale = max(1.0, scale) + support = filterscale + ksize = int(math.ceil(support)) * 2 + 1 + + starts = np.zeros(out_size, dtype=np.int32) + weights_fp = np.zeros((out_size, ksize), dtype=np.float64) + inv_fs = 1.0 / filterscale + + for o in range(out_size): + center = (o + 0.5) * scale + xmin = int(center - support + 0.5) + if xmin < 0: + xmin = 0 + xmax = int(center + support + 0.5) + if xmax > in_size: + xmax = in_size + actual = xmax - xmin + starts[o] = xmin + total = 0.0 + for k in range(actual): + t = (k + xmin - center + 0.5) * inv_fs + t_abs = -t if t < 0.0 else t + w = 1.0 - t_abs if t_abs < 1.0 else 0.0 + weights_fp[o, k] = w + total += w + if total != 0.0: + weights_fp[o, :actual] /= total + + weights_int = np.rint(weights_fp * (1 << PRECISION_BITS)).astype(np.int32) + return starts, weights_int, ksize + + +if TRITON_AVAILABLE: + + _HALF = 1 << (PRECISION_BITS - 1) + + @triton.jit + def fused_resize_normalize_kernel( + src_ptr, + dst_ptr, + ymin_ptr, + xmin_ptr, + wy_ptr, + wx_ptr, + src_h, + src_w, + src_stride_h, + src_stride_w, + crop_offset_y, + crop_offset_x, + dst_stride_c, + dst_stride_h, + target_h, + target_w, + inv_std_255_r, + inv_std_255_g, + inv_std_255_b, + offset_r, + offset_g, + offset_b, + CH_R: tl.constexpr, + CH_G: tl.constexpr, + CH_B: tl.constexpr, + KSIZE_Y: tl.constexpr, + KSIZE_X: tl.constexpr, + PRECISION_BITS_C: tl.constexpr, + HALF_C: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_W: tl.constexpr, + ): + """One kernel per (tile_y, tile_x) over target image. + + In : src uint8 HWC (src_h, src_w, 3), source color order. The + resample tables are built against `(crop_h, crop_w)` — the + logical source size after a possible static crop — which the + caller passes as `src_h`/`src_w`. `crop_offset_{y,x}` is the + load-time offset into the raw HWC buffer. + Out: dst fp32 CHW (1, 3, target_h, target_w), network color order, + (pixel/255 - mean)/std. + """ + pid_y = tl.program_id(0) + pid_x = tl.program_id(1) + + offs_y = pid_y * BLOCK_H + tl.arange(0, BLOCK_H) + offs_x = pid_x * BLOCK_W + tl.arange(0, BLOCK_W) + mask_y = offs_y < target_h + mask_x = offs_x < target_w + mask_out = mask_y[:, None] & mask_x[None, :] + + ymin = tl.load(ymin_ptr + offs_y, mask=mask_y, other=0) + xmin = tl.load(xmin_ptr + offs_x, mask=mask_x, other=0) + + # Vertical pass accumulators (int32 fixed-point) for 3 channels. + vacc_0 = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32) + vacc_1 = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32) + vacc_2 = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32) + + for ky in tl.static_range(KSIZE_Y): + # Source row (after static crop) contributing to each output row. + sy = ymin + ky + sy_c = tl.maximum(tl.minimum(sy, src_h - 1), 0) + crop_offset_y + wy = tl.load(wy_ptr + offs_y * KSIZE_Y + ky, mask=mask_y, other=0) + + # Horizontal pass for (output_rows_in_tile, output_cols_in_tile): + # for each source column in the kernel, gather src[sy_c, sx_c, :] + # and accumulate with wx[output_col, kx]. + hacc_0 = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32) + hacc_1 = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32) + hacc_2 = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32) + + for kx in tl.static_range(KSIZE_X): + sx = xmin + kx + sx_c = tl.maximum(tl.minimum(sx, src_w - 1), 0) + crop_offset_x + wx = tl.load(wx_ptr + offs_x * KSIZE_X + kx, mask=mask_x, other=0) + base = sy_c[:, None] * src_stride_h + sx_c[None, :] * src_stride_w + p0 = tl.load(src_ptr + base + 0, mask=mask_out, other=0).to(tl.int32) + p1 = tl.load(src_ptr + base + 1, mask=mask_out, other=0).to(tl.int32) + p2 = tl.load(src_ptr + base + 2, mask=mask_out, other=0).to(tl.int32) + wx_2d = wx[None, :] + hacc_0 += p0 * wx_2d + hacc_1 += p1 * wx_2d + hacc_2 += p2 * wx_2d + + # Horizontal uint8 quantization (byte-exact to PIL's intermediate). + hacc_0 = (hacc_0 + HALF_C) >> PRECISION_BITS_C + hacc_1 = (hacc_1 + HALF_C) >> PRECISION_BITS_C + hacc_2 = (hacc_2 + HALF_C) >> PRECISION_BITS_C + hacc_0 = tl.minimum(tl.maximum(hacc_0, 0), 255) + hacc_1 = tl.minimum(tl.maximum(hacc_1, 0), 255) + hacc_2 = tl.minimum(tl.maximum(hacc_2, 0), 255) + + wy_2d = wy[:, None] + vacc_0 += hacc_0 * wy_2d + vacc_1 += hacc_1 * wy_2d + vacc_2 += hacc_2 * wy_2d + + # Vertical uint8 quantization. + q_0 = (vacc_0 + HALF_C) >> PRECISION_BITS_C + q_1 = (vacc_1 + HALF_C) >> PRECISION_BITS_C + q_2 = (vacc_2 + HALF_C) >> PRECISION_BITS_C + q_0 = tl.minimum(tl.maximum(q_0, 0), 255) + q_1 = tl.minimum(tl.maximum(q_1, 0), 255) + q_2 = tl.minimum(tl.maximum(q_2, 0), 255) + + # Source-to-output channel remap (triton requires constexpr branches). + if CH_R == 0: + q_r = q_0 + elif CH_R == 1: + q_r = q_1 + else: + q_r = q_2 + if CH_G == 0: + q_g = q_0 + elif CH_G == 1: + q_g = q_1 + else: + q_g = q_2 + if CH_B == 0: + q_b = q_0 + elif CH_B == 1: + q_b = q_1 + else: + q_b = q_2 + + # (pixel/255 - mean)/std == pixel * (1/(255*std)) + (-mean/std) + out_r = q_r.to(tl.float32) * inv_std_255_r + offset_r + out_g = q_g.to(tl.float32) * inv_std_255_g + offset_g + out_b = q_b.to(tl.float32) * inv_std_255_b + offset_b + + out_row = offs_y[:, None] * dst_stride_h + offs_x[None, :] + tl.store(dst_ptr + 0 * dst_stride_c + out_row, out_r, mask=mask_out) + tl.store(dst_ptr + 1 * dst_stride_c + out_row, out_g, mask=mask_out) + tl.store(dst_ptr + 2 * dst_stride_c + out_row, out_b, mask=mask_out) + + +class ResampleTables: + """Cache of per-axis PIL-int32 weight tables for one (src, dst) pair.""" + + __slots__ = ( + "ymin_gpu", + "xmin_gpu", + "wy_gpu", + "wx_gpu", + "ksize_y", + "ksize_x", + ) + + def __init__( + self, + ymin_gpu: torch.Tensor, + xmin_gpu: torch.Tensor, + wy_gpu: torch.Tensor, + wx_gpu: torch.Tensor, + ksize_y: int, + ksize_x: int, + ) -> None: + self.ymin_gpu = ymin_gpu + self.xmin_gpu = xmin_gpu + self.wy_gpu = wy_gpu + self.wx_gpu = wx_gpu + self.ksize_y = ksize_y + self.ksize_x = ksize_x + + +def build_resample_tables( + src_h: int, + src_w: int, + target_h: int, + target_w: int, + device: torch.device, +) -> ResampleTables: + ymin, wy, ksize_y = _bilinear_antialias_weights_1d_int(src_h, target_h) + xmin, wx, ksize_x = _bilinear_antialias_weights_1d_int(src_w, target_w) + return ResampleTables( + ymin_gpu=torch.from_numpy(ymin).to(device=device, non_blocking=True), + xmin_gpu=torch.from_numpy(xmin).to(device=device, non_blocking=True), + wy_gpu=torch.from_numpy(wy.ravel()).to(device=device, non_blocking=True), + wx_gpu=torch.from_numpy(wx.ravel()).to(device=device, non_blocking=True), + ksize_y=ksize_y, + ksize_x=ksize_x, + ) + + +def triton_preprocess_rfdetr_stretch( + src: torch.Tensor, + tables: ResampleTables, + target_h: int, + target_w: int, + means: Tuple[float, float, float] = (0.485, 0.456, 0.406), + stds: Tuple[float, float, float] = (0.229, 0.224, 0.225), + swap_rb: bool = True, + crop_offset_y: int = 0, + crop_offset_x: int = 0, + crop_h: Optional[int] = None, + crop_w: Optional[int] = None, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Fused PIL-exact resize + color swap + normalize. + + Args: + src: uint8 CUDA tensor, shape (H, W, 3), HWC layout. + tables: precomputed int32 resample tables sized against the *cropped* + source `(crop_h, crop_w)` → `(target_h, target_w)`. + target_h, target_w: output spatial dims. + means, stds: normalization in output channel order (R, G, B for + network_input.color_mode == 'rgb'). + swap_rb: if True, source channel 0 → output B (BGR input, RGB network). + crop_offset_y/_x: load-time offset into `src` for a static crop. 0 + means no crop. + crop_h/_w: effective source dims after crop. Defaults to src dims + when no crop is configured. + out: optional preallocated fp32 (1, 3, H, W) CUDA tensor. + + Returns: + fp32 (1, 3, target_h, target_w) on the same device as `src`. + """ + if not TRITON_AVAILABLE: + raise MissingDependencyError( + message="triton is not installed", + help_url="https://inference-models.roboflow.com/errors/runtime-environment/#missingdependencyerror", + ) + if not src.is_cuda: + raise ModelInputError( + message=f"expected CUDA src tensor, got device={src.device}", + help_url="https://inference-models.roboflow.com/errors/input-validation/#modelinputerror", + ) + if src.dtype != torch.uint8: + raise ModelInputError( + message=f"expected uint8 src, got {src.dtype}", + help_url="https://inference-models.roboflow.com/errors/input-validation/#modelinputerror", + ) + if src.ndim != 3 or src.shape[2] != 3: + raise ModelInputError( + message=f"expected HWC 3-channel, got shape={tuple(src.shape)}", + help_url="https://inference-models.roboflow.com/errors/input-validation/#modelinputerror", + ) + + src = src.contiguous() + raw_src_h, raw_src_w = int(src.shape[0]), int(src.shape[1]) + src_h = crop_h if crop_h is not None else raw_src_h + src_w = crop_w if crop_w is not None else raw_src_w + src_stride_h = int(src.stride(0)) + src_stride_w = int(src.stride(1)) + + if out is None: + out = torch.empty( + (1, 3, target_h, target_w), dtype=torch.float32, device=src.device + ) + else: + if tuple(out.shape) != (1, 3, target_h, target_w): + raise ModelRuntimeError( + message=( + f"out has shape {tuple(out.shape)}, expected " + f"(1, 3, {target_h}, {target_w})" + ), + help_url="https://inference-models.roboflow.com/errors/models-runtime/#modelruntimeerror", + ) + if out.dtype != torch.float32 or not out.is_cuda: + raise ModelRuntimeError( + message="out must be fp32 CUDA tensor", + help_url="https://inference-models.roboflow.com/errors/models-runtime/#modelruntimeerror", + ) + + dst_stride_c = target_h * target_w + dst_stride_h = target_w + + if swap_rb: + ch_r, ch_g, ch_b = 2, 1, 0 + else: + ch_r, ch_g, ch_b = 0, 1, 2 + + inv_std_255_r = 1.0 / (255.0 * stds[0]) + inv_std_255_g = 1.0 / (255.0 * stds[1]) + inv_std_255_b = 1.0 / (255.0 * stds[2]) + offset_r = -means[0] / stds[0] + offset_g = -means[1] / stds[1] + offset_b = -means[2] / stds[2] + + BLOCK_H = 16 + BLOCK_W = 16 + grid = ( + (target_h + BLOCK_H - 1) // BLOCK_H, + (target_w + BLOCK_W - 1) // BLOCK_W, + ) + fused_resize_normalize_kernel[grid]( + src, + out, + tables.ymin_gpu, + tables.xmin_gpu, + tables.wy_gpu, + tables.wx_gpu, + src_h, + src_w, + src_stride_h, + src_stride_w, + int(crop_offset_y), + int(crop_offset_x), + dst_stride_c, + dst_stride_h, + target_h, + target_w, + float(inv_std_255_r), + float(inv_std_255_g), + float(inv_std_255_b), + float(offset_r), + float(offset_g), + float(offset_b), + CH_R=ch_r, + CH_G=ch_g, + CH_B=ch_b, + KSIZE_Y=tables.ksize_y, + KSIZE_X=tables.ksize_x, + PRECISION_BITS_C=PRECISION_BITS, + HALF_C=_HALF, + BLOCK_H=BLOCK_H, + BLOCK_W=BLOCK_W, + ) + return out diff --git a/inference_models/pyproject.toml b/inference_models/pyproject.toml index eee7e77496..a50e4447f8 100644 --- a/inference_models/pyproject.toml +++ b/inference_models/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "inference-models" -version = "0.28.2" +version = "0.28.4" description = "The new inference engine for Computer Vision models" readme = "README.md" requires-python = ">=3.10,<3.13" @@ -35,10 +35,8 @@ dependencies = [ "pybase64~=1.0.0", "rf-segment-anything==1.0", "rf-sam-2==1.0.3", - "sam3==0.1.3; sys_platform != 'darwin'", "pycocotools>=2.0.10,<2.1.0", # transitive-dependencies pinning - "triton<4.0.0; sys_platform != 'darwin'", "urllib3>=2.7.0,<3.0.0", "pillow>=12.2.0,<13.0.0", "GitPython>=3.1.50,<4.0.0", @@ -53,35 +51,41 @@ torch-cpu = [ "torchvision" ] torch-cu118 = [ + "sam3==0.1.4; sys_platform != 'darwin'", "torch>=2.0.0,<3.0.0", "torchvision", "pycuda>=2025.0.0,<2026.0.0; platform_system != 'darwin' and python_version >= '3.10'", "Mako>=1.3.11", ] torch-cu124 = [ + "sam3==0.1.4; sys_platform != 'darwin'", "torch>=2.0.0,<3.0.0", "torchvision", "pycuda>=2025.0.0,<2026.0.0; platform_system != 'darwin' and python_version >= '3.10'", "Mako>=1.3.11", ] torch-cu126 = [ + "sam3==0.1.4; sys_platform != 'darwin'", "torch>=2.0.0,<3.0.0", "torchvision", "pycuda>=2025.0.0,<2026.0.0; platform_system != 'darwin' and python_version >= '3.10'", ] torch-cu128 = [ + "sam3==0.1.4; sys_platform != 'darwin'", "torch>=2.0.0,<3.0.0", "torchvision", "pycuda>=2025.0.0,<2026.0.0; platform_system != 'darwin' and python_version >= '3.10'", "Mako>=1.3.11", ] torch-cu130 = [ + "sam3==0.1.4; sys_platform != 'darwin'", "torch>=2.0.0,<3.0.0", "torchvision", "pycuda>=2025.0.0,<=2026.1.0; platform_system != 'darwin' and python_version >= '3.10'", "Mako>=1.3.11", ] torch-jp6-cu126 = [ + "sam3==0.1.4; sys_platform != 'darwin'", "numpy<2.0.0", "torch>=2.0.0,<3.0.0", "torchvision", diff --git a/inference_models/tests/integration_tests/models/test_sam3_predictions.py b/inference_models/tests/integration_tests/models/test_sam3_predictions.py index fc67352748..d31bb71623 100644 --- a/inference_models/tests/integration_tests/models/test_sam3_predictions.py +++ b/inference_models/tests/integration_tests/models/test_sam3_predictions.py @@ -10,11 +10,12 @@ Sam3ImageEmbeddingsInMemoryCache, Sam3LowResolutionMasksInMemoryCache, ) -from inference_models.models.sam3.sam3_torch import SAM3Torch @pytest.fixture(scope="module") -def sam3_model(sam3_package: str) -> SAM3Torch: +def sam3_model(sam3_package: str): + from inference_models.models.sam3.sam3_torch import SAM3Torch + model = SAM3Torch.from_pretrained( sam3_package, device=DEFAULT_DEVICE, @@ -45,7 +46,7 @@ def _free_gpu_after_test(): @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_embeddings_numpy( - sam3_model: SAM3Torch, truck_image_numpy: np.ndarray + sam3_model, truck_image_numpy: np.ndarray ) -> None: # given model = sam3_model @@ -63,7 +64,7 @@ def test_sam3_embeddings_numpy( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_embeddings_torch( - sam3_model: SAM3Torch, truck_image_torch: torch.Tensor + sam3_model, truck_image_torch: torch.Tensor ) -> None: # given model = sam3_model @@ -80,7 +81,7 @@ def test_sam3_embeddings_torch( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_embeddings_batch_numpy( - sam3_model: SAM3Torch, truck_image_numpy: np.ndarray + sam3_model, truck_image_numpy: np.ndarray ) -> None: # given model = sam3_model @@ -98,7 +99,7 @@ def test_sam3_embeddings_batch_numpy( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_embeddings_caching( - sam3_model: SAM3Torch, truck_image_numpy: np.ndarray + sam3_model, truck_image_numpy: np.ndarray ) -> None: # given model = sam3_model @@ -116,7 +117,7 @@ def test_sam3_embeddings_caching( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_images_without_prompting_numpy( - sam3_model: SAM3Torch, truck_image_numpy: np.ndarray + sam3_model, truck_image_numpy: np.ndarray ) -> None: # given model = sam3_model @@ -135,7 +136,7 @@ def test_sam3_segment_images_without_prompting_numpy( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_images_without_prompting_batch_numpy( - sam3_model: SAM3Torch, truck_image_numpy: np.ndarray + sam3_model, truck_image_numpy: np.ndarray ) -> None: # given model = sam3_model @@ -153,7 +154,7 @@ def test_sam3_segment_images_without_prompting_batch_numpy( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_images_with_point_prompting( - sam3_model: SAM3Torch, + sam3_model, truck_image_numpy: np.ndarray, ) -> None: # given @@ -185,7 +186,7 @@ def test_sam3_segment_images_with_point_prompting( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_images_with_multiple_points( - sam3_model: SAM3Torch, + sam3_model, truck_image_numpy: np.ndarray, ) -> None: # given @@ -209,7 +210,7 @@ def test_sam3_segment_images_with_multiple_points( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_images_with_embeddings( - sam3_model: SAM3Torch, truck_image_numpy: np.ndarray + sam3_model, truck_image_numpy: np.ndarray ) -> None: # given model = sam3_model @@ -234,7 +235,7 @@ def test_sam3_segment_images_with_embeddings( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_images_with_box_prompting( - sam3_model: SAM3Torch, + sam3_model, truck_image_numpy: np.ndarray, ) -> None: # given @@ -257,7 +258,7 @@ def test_sam3_segment_images_with_box_prompting( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_images_with_box_prompting_and_embeddings( - sam3_model: SAM3Torch, + sam3_model, truck_image_numpy: np.ndarray, ) -> None: # given @@ -281,7 +282,7 @@ def test_sam3_segment_images_with_box_prompting_and_embeddings( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_images_with_combined_prompting( - sam3_model: SAM3Torch, truck_image_numpy: np.ndarray + sam3_model, truck_image_numpy: np.ndarray ) -> None: # given model = sam3_model @@ -306,7 +307,7 @@ def test_sam3_segment_images_with_combined_prompting( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_images_with_mask_prompting( - sam3_model: SAM3Torch, truck_image_numpy: np.ndarray + sam3_model, truck_image_numpy: np.ndarray ) -> None: # given model = sam3_model @@ -336,7 +337,7 @@ def test_sam3_segment_images_with_mask_prompting( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_images_raises_on_missing_input( - sam3_model: SAM3Torch, + sam3_model, ) -> None: # given model = sam3_model @@ -350,7 +351,7 @@ def test_sam3_segment_images_raises_on_missing_input( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_images_with_misaligned_batch_sizes( - sam3_model: SAM3Torch, truck_image_numpy: np.ndarray + sam3_model, truck_image_numpy: np.ndarray ) -> None: # given model = sam3_model @@ -370,7 +371,7 @@ def test_sam3_segment_images_with_misaligned_batch_sizes( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_with_text_single_prompt( - sam3_model: SAM3Torch, + sam3_model, truck_image_numpy: np.ndarray, ) -> None: # given @@ -396,7 +397,7 @@ def test_sam3_segment_with_text_single_prompt( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_with_text_multiple_prompts( - sam3_model: SAM3Torch, + sam3_model, truck_image_numpy: np.ndarray, ) -> None: # given @@ -425,7 +426,7 @@ def test_sam3_segment_with_text_multiple_prompts( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_with_text_visual_prompt( - sam3_model: SAM3Torch, + sam3_model, truck_image_numpy: np.ndarray, ) -> None: # given @@ -455,7 +456,7 @@ def test_sam3_segment_with_text_visual_prompt( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_with_text_batch_images( - sam3_model: SAM3Torch, + sam3_model, truck_image_numpy: np.ndarray, ) -> None: # given @@ -479,7 +480,7 @@ def test_sam3_segment_with_text_batch_images( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_images_multi_mask_output( - sam3_model: SAM3Torch, + sam3_model, truck_image_numpy: np.ndarray, ) -> None: # given @@ -514,7 +515,7 @@ def test_sam3_segment_images_multi_mask_output( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_images_return_logits( - sam3_model: SAM3Torch, + sam3_model, truck_image_numpy: np.ndarray, ) -> None: # given @@ -550,7 +551,7 @@ def test_sam3_segment_images_return_logits( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_caching_disabled( - sam3_model: SAM3Torch, truck_image_numpy: np.ndarray + sam3_model, truck_image_numpy: np.ndarray ) -> None: # given model = sam3_model @@ -570,7 +571,7 @@ def test_sam3_caching_disabled( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_embed_then_segment_with_client_hash( - sam3_model: SAM3Torch, truck_image_numpy: np.ndarray + sam3_model, truck_image_numpy: np.ndarray ) -> None: """SAM2-style embed-then-segment flow: caller provides image_id at embed time and reuses it at segment time without re-sending image bytes.""" @@ -603,7 +604,7 @@ def test_sam3_embed_then_segment_with_client_hash( @pytest.mark.torch_models @pytest.mark.gpu_only def test_sam3_segment_with_unknown_hash_raises( - sam3_model: SAM3Torch, + sam3_model, ) -> None: """segment_with_visual_prompts with a hash never seen by the cache must raise.""" # given diff --git a/inference_models/tests/unit_tests/models/rfdetr/test_pre_processing.py b/inference_models/tests/unit_tests/models/rfdetr/test_pre_processing.py index 2bc433814c..d55c8644e7 100644 --- a/inference_models/tests/unit_tests/models/rfdetr/test_pre_processing.py +++ b/inference_models/tests/unit_tests/models/rfdetr/test_pre_processing.py @@ -5,6 +5,7 @@ import torchvision.transforms.functional as TF from PIL import Image +import inference_models.models.rfdetr.pre_processing as rfdetr_pre_processing from inference_models.entities import ImageDimensions from inference_models.models.common.roboflow.model_packages import ( ColorMode, @@ -360,3 +361,226 @@ def test_batched_input_produces_per_image_metadata() -> None: torch.testing.assert_close( batch_tensor[1], expected_b.squeeze(0), atol=1e-6, rtol=0 ) + + +def enable_triton_fast_path(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + rfdetr_pre_processing, "USE_TRITON_FOR_PREPROCESSING", True + ) + monkeypatch.setattr(rfdetr_pre_processing, "_TRITON_AVAILABLE", True) + + +def skip_if_triton_gpu_path_unavailable() -> None: + pytest.importorskip("triton") + if not torch.cuda.is_available(): # pragma: no cover - host-dependent + pytest.skip("CUDA not available") + + +@pytest.mark.gpu_only +def test_triton_fast_path_matches_reference_pipeline_for_rgb_numpy( + monkeypatch: pytest.MonkeyPatch, +) -> None: + skip_if_triton_gpu_path_unavailable() + enable_triton_fast_path(monkeypatch=monkeypatch) + + image_pre_processing = ImagePreProcessing() + network_input = _build_network_input( + training_h=64, + training_w=64, + resize_mode=ResizeMode.STRETCH_TO, + dataset_version_dims=None, + ) + rng = np.random.default_rng(seed=123) + image = rng.integers(0, 256, size=(192, 168, 3), dtype=np.uint8) + + actual_tensor, actual_meta = pre_process_network_input( + images=image, + image_pre_processing=image_pre_processing, + network_input=network_input, + target_device=torch.device("cuda"), + input_color_format="rgb", + ) + + expected = _reference_pipeline(Image.fromarray(image), target_h=64, target_w=64) + torch.testing.assert_close(actual_tensor.cpu(), expected, atol=1e-5, rtol=0) + assert actual_meta[0].original_size == ImageDimensions(width=168, height=192) + assert actual_meta[0].inference_size == ImageDimensions(width=64, height=64) + + +@pytest.mark.gpu_only +def test_triton_fast_path_matches_reference_pipeline_for_bgr_numpy( + monkeypatch: pytest.MonkeyPatch, +) -> None: + skip_if_triton_gpu_path_unavailable() + enable_triton_fast_path(monkeypatch=monkeypatch) + + image_pre_processing = ImagePreProcessing() + network_input = _build_network_input( + training_h=64, + training_w=64, + resize_mode=ResizeMode.STRETCH_TO, + dataset_version_dims=None, + ) + rng = np.random.default_rng(seed=456) + rgb_image = rng.integers(0, 256, size=(96, 144, 3), dtype=np.uint8) + bgr_image = rgb_image[:, :, ::-1].copy() + + actual_tensor = pre_process_network_input( + images=bgr_image, + image_pre_processing=image_pre_processing, + network_input=network_input, + target_device=torch.device("cuda"), + input_color_format="bgr", + )[0] + + expected = _reference_pipeline(Image.fromarray(rgb_image), target_h=64, target_w=64) + torch.testing.assert_close(actual_tensor.cpu(), expected, atol=1e-5, rtol=0) + + +@pytest.mark.gpu_only +def test_triton_fast_path_matches_reference_pipeline_for_platform_style_stretch( + monkeypatch: pytest.MonkeyPatch, +) -> None: + skip_if_triton_gpu_path_unavailable() + enable_triton_fast_path(monkeypatch=monkeypatch) + + image_pre_processing = ImagePreProcessing() + network_input = _build_network_input( + training_h=64, + training_w=64, + resize_mode=ResizeMode.STRETCH_TO, + dataset_version_dims=TrainingInputSize(height=48, width=80), + ) + rng = np.random.default_rng(seed=789) + image = rng.integers(0, 256, size=(128, 200, 3), dtype=np.uint8) + + actual_tensor = pre_process_network_input( + images=image, + image_pre_processing=image_pre_processing, + network_input=network_input, + target_device=torch.device("cuda"), + input_color_format="rgb", + )[0] + + expected = _reference_pipeline(Image.fromarray(image), target_h=64, target_w=64) + torch.testing.assert_close(actual_tensor.cpu(), expected, atol=1e-5, rtol=0) + + +def test_pre_process_network_input_dispatches_to_triton_fast_path_when_eligible( + monkeypatch: pytest.MonkeyPatch, +) -> None: + enable_triton_fast_path(monkeypatch=monkeypatch) + + image_pre_processing = ImagePreProcessing() + network_input = _build_network_input( + training_h=64, + training_w=64, + resize_mode=ResizeMode.STRETCH_TO, + dataset_version_dims=TrainingInputSize(height=48, width=80), + ) + image = np.zeros((32, 48, 3), dtype=np.uint8) + sentinel_tensor = torch.empty((1, 3, 64, 64), dtype=torch.float32) + sentinel_meta = ["sentinel-meta"] + captured_kwargs = {} + + def fake_triton_path_preprocess(**kwargs): + captured_kwargs.update(kwargs) + return sentinel_tensor, sentinel_meta + + monkeypatch.setattr( + rfdetr_pre_processing, + "triton_path_preprocess", + fake_triton_path_preprocess, + ) + + actual_tensor, actual_meta = pre_process_network_input( + images=image, + image_pre_processing=image_pre_processing, + network_input=network_input, + target_device=torch.device("cuda"), + input_color_format="rgb", + ) + + assert actual_tensor is sentinel_tensor + assert actual_meta is sentinel_meta + assert captured_kwargs["target_device"] == torch.device("cuda") + assert captured_kwargs["target_size"] == ImageDimensions(width=64, height=64) + assert captured_kwargs["input_color_mode"] == ColorMode.RGB + assert len(captured_kwargs["image_list"]) == 1 + + +@pytest.mark.parametrize( + "image_pre_processing, network_input, expected", + [ + ( + ImagePreProcessing(), + _build_network_input( + training_h=64, + training_w=64, + resize_mode=ResizeMode.STRETCH_TO, + dataset_version_dims=None, + ), + True, + ), + ( + ImagePreProcessing(), + _build_network_input( + training_h=64, + training_w=64, + resize_mode=ResizeMode.STRETCH_TO, + dataset_version_dims=TrainingInputSize(height=48, width=80), + ), + True, + ), + ( + ImagePreProcessing(), + _build_network_input( + training_h=64, + training_w=64, + resize_mode=ResizeMode.LETTERBOX, + dataset_version_dims=TrainingInputSize(height=48, width=80), + ), + False, + ), + ( + ImagePreProcessing.model_validate( + { + "contrast": { + "enabled": True, + "type": "Contrast Stretching", + } + } + ), + _build_network_input( + training_h=64, + training_w=64, + resize_mode=ResizeMode.STRETCH_TO, + dataset_version_dims=None, + ), + False, + ), + ], + ids=[ + "stretch_without_dataset_dims", + "stretch_with_dataset_dims", + "two_step_letterbox", + "contrast_enabled", + ], +) +def test_triton_path_eligible_respects_configuration( + image_pre_processing: ImagePreProcessing, + network_input: NetworkInputDefinition, + expected: bool, + monkeypatch: pytest.MonkeyPatch, +) -> None: + enable_triton_fast_path(monkeypatch=monkeypatch) + + actual = rfdetr_pre_processing.triton_path_eligible( + image_list=[np.zeros((32, 48, 3), dtype=np.uint8)], + image_pre_processing=image_pre_processing, + network_input=network_input, + target_device=torch.device("cuda"), + pre_processing_overrides=None, + ) + + assert actual is expected diff --git a/inference_models/tests/unit_tests/models/test_sam3_client_hashes.py b/inference_models/tests/unit_tests/models/test_sam3_client_hashes.py index aacc42ef3a..64c4b53ce2 100644 --- a/inference_models/tests/unit_tests/models/test_sam3_client_hashes.py +++ b/inference_models/tests/unit_tests/models/test_sam3_client_hashes.py @@ -9,12 +9,15 @@ Sam3ImageEmbeddingsCacheNullObject, Sam3LowResolutionMasksCacheNullObject, ) -from inference_models.models.sam3.sam3_torch import SAM3Torch -pytestmark = pytest.mark.torch_models +pytestmark = [pytest.mark.torch_models, pytest.mark.gpu_only] -def _make_model(allow_client_hashes: bool) -> SAM3Torch: +def _make_model(allow_client_hashes: bool): + # sam3 is GPU-only and absent from CPU/vino builds, so import the model class + # inside the helper (these tests are gpu_only and run where sam3 is installed). + from inference_models.models.sam3.sam3_torch import SAM3Torch + return SAM3Torch( model=MagicMock(), transform=MagicMock(), diff --git a/inference_models/uv.lock b/inference_models/uv.lock index a2ecb59fcc..2f376750d6 100644 --- a/inference_models/uv.lock +++ b/inference_models/uv.lock @@ -526,18 +526,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, ] -[[package]] -name = "decord" -version = "0.6.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/11/79/936af42edf90a7bd4e41a6cac89c913d4b47fa48a26b042d5129a9242ee3/decord-0.6.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:51997f20be8958e23b7c4061ba45d0efcd86bffd5fe81c695d0befee0d442976", size = 13602299, upload-time = "2021-06-14T21:30:55.486Z" }, - { url = "https://files.pythonhosted.org/packages/6c/be/e15b5b866da452e62635a7b27513f31cb581fa2ea9cc9b768b535d62a955/decord-0.6.0-py3-none-win_amd64.whl", hash = "sha256:02665d7c4f1193a330205a791bc128f7e108eb6ae5b67144437a02f700943bad", size = 24733380, upload-time = "2021-06-14T21:30:57.766Z" }, -] - [[package]] name = "defusedxml" version = "0.7.1" @@ -925,7 +913,7 @@ wheels = [ [[package]] name = "inference-models" -version = "0.28.2" +version = "0.28.4" source = { virtual = "." } dependencies = [ { name = "accelerate" }, @@ -955,7 +943,6 @@ dependencies = [ { name = "rf-sam-2" }, { name = "rf-segment-anything" }, { name = "rich" }, - { name = "sam3", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, { name = "scikit-image" }, { name = "segmentation-models-pytorch" }, { name = "sentencepiece" }, @@ -979,9 +966,6 @@ dependencies = [ { name = "torchvision", version = "0.23.0", source = { registry = "https://pypi.jetson-ai-lab.io/jp6/cu126/+simple" }, marker = "(extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, { name = "torchvision", version = "0.23.0", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu130') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128' and extra != 'extra-16-inference-models-torch-jp6-cu126') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu126' and extra != 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu126' and extra != 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128' and extra != 'extra-16-inference-models-torch-jp6-cu126')" }, { name = "transformers" }, - { name = "triton", version = "3.2.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, - { name = "triton", version = "3.3.1", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and sys_platform != 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 's390x' and sys_platform == 'linux') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128')" }, - { name = "triton", version = "3.4.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu') or (platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu') or (platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, { name = "urllib3" }, ] @@ -1043,23 +1027,27 @@ torch-cpu = [ torch-cu118 = [ { name = "mako" }, { name = "pycuda", marker = "platform_system != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, + { name = "sam3", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, { name = "torch", version = "2.7.1+cu118", source = { registry = "https://download.pytorch.org/whl/cu118" } }, { name = "torchvision", version = "0.22.1+cu118", source = { registry = "https://download.pytorch.org/whl/cu118" } }, ] torch-cu124 = [ { name = "mako" }, { name = "pycuda", marker = "platform_system != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, + { name = "sam3", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" } }, { name = "torchvision", version = "0.21.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" } }, ] torch-cu126 = [ { name = "pycuda", marker = "platform_system != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, + { name = "sam3", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" } }, { name = "torchvision", version = "0.23.0", source = { registry = "https://pypi.org/simple" } }, ] torch-cu128 = [ { name = "mako" }, { name = "pycuda", marker = "platform_system != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, + { name = "sam3", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, { name = "torch", version = "2.7.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" } }, { name = "torchvision", version = "0.22.1", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, { name = "torchvision", version = "0.22.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, @@ -1067,6 +1055,7 @@ torch-cu128 = [ torch-cu130 = [ { name = "mako" }, { name = "pycuda", marker = "platform_system != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, + { name = "sam3", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" } }, { name = "torchvision", version = "0.23.0", source = { registry = "https://pypi.org/simple" } }, ] @@ -1074,6 +1063,7 @@ torch-jp6-cu126 = [ { name = "mako" }, { name = "numpy" }, { name = "pycuda" }, + { name = "sam3", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, { name = "torch", version = "2.8.0", source = { registry = "https://pypi.jetson-ai-lab.io/jp6/cu126/+simple" } }, { name = "torchvision", version = "0.23.0", source = { registry = "https://pypi.jetson-ai-lab.io/jp6/cu126/+simple" } }, ] @@ -1163,7 +1153,12 @@ requires-dist = [ { name = "rf-sam-2", specifier = "==1.0.3" }, { name = "rf-segment-anything", specifier = "==1.0" }, { name = "rich", specifier = ">=14.1.0,<15.0.0" }, - { name = "sam3", marker = "sys_platform != 'darwin'", specifier = "==0.1.3" }, + { name = "sam3", marker = "sys_platform != 'darwin' and extra == 'torch-cu118'", specifier = "==0.1.4" }, + { name = "sam3", marker = "sys_platform != 'darwin' and extra == 'torch-cu124'", specifier = "==0.1.4" }, + { name = "sam3", marker = "sys_platform != 'darwin' and extra == 'torch-cu126'", specifier = "==0.1.4" }, + { name = "sam3", marker = "sys_platform != 'darwin' and extra == 'torch-cu128'", specifier = "==0.1.4" }, + { name = "sam3", marker = "sys_platform != 'darwin' and extra == 'torch-cu130'", specifier = "==0.1.4" }, + { name = "sam3", marker = "sys_platform != 'darwin' and extra == 'torch-jp6-cu126'", specifier = "==0.1.4" }, { name = "scikit-image", specifier = ">=0.24.0,<0.26.0" }, { name = "segmentation-models-pytorch", specifier = ">=0.5.0,<1.0.0" }, { name = "sentencepiece", specifier = ">=0.2.1,<0.3.0" }, @@ -1190,7 +1185,6 @@ requires-dist = [ { name = "torchvision", marker = "extra == 'torch-jp6-cu126'", index = "https://pypi.jetson-ai-lab.io/jp6/cu126/+simple", conflict = { package = "inference-models", extra = "torch-jp6-cu126" } }, { name = "tornado", marker = "extra == 'docs'", specifier = ">=6.5.5" }, { name = "transformers", specifier = "~=5.5" }, - { name = "triton", marker = "sys_platform != 'darwin'", specifier = "<4.0.0" }, { name = "urllib3", specifier = ">=2.7.0,<3.0.0" }, ] provides-extras = ["torch-cpu", "torch-cu118", "torch-cu124", "torch-cu126", "torch-cu128", "torch-cu130", "torch-jp6-cu126", "onnx-cpu", "onnx-cu118", "onnx-cu12", "onnx-jp6-cu126", "trt10", "test", "docs"] @@ -4249,24 +4243,23 @@ wheels = [ [[package]] name = "sam3" -version = "0.1.3" +version = "0.1.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "decord", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, - { name = "einops", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, - { name = "ftfy", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, - { name = "huggingface-hub", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, - { name = "iopath", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, - { name = "numpy", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, - { name = "pycocotools", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, - { name = "regex", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, - { name = "timm", version = "1.0.27", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, - { name = "tqdm", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, - { name = "typing-extensions", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, + { name = "einops", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, + { name = "ftfy", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, + { name = "huggingface-hub", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, + { name = "iopath", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, + { name = "numpy", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, + { name = "pycocotools", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, + { name = "regex", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, + { name = "timm", version = "1.0.27", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, + { name = "tqdm", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, + { name = "typing-extensions", marker = "sys_platform != 'darwin' or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0f/83/0a7808fdf9ceda86f1f301c8f08af3fd3bc70e6a390adb74fbbea8d7211b/sam3-0.1.3.tar.gz", hash = "sha256:dce27ab32a11bcd25a8fc25fae31a6d5ca6a3222465127cc7a46ec3335a62996", size = 6183016, upload-time = "2026-03-09T18:39:34.092Z" } +sdist = { url = "https://files.pythonhosted.org/packages/69/59/2ad52462e60a70955aa667afbef642151c4bd9dfdfee4a46731a20ca7f58/sam3-0.1.4.tar.gz", hash = "sha256:dfbc06206efc3ddb5805fd9a717317ed9fff167f45424f90313a325a63d7cd1e", size = 6183330, upload-time = "2026-05-26T11:48:18.004Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c4/a9/e1b32092b11dadfd0776e604932d0ba938a562fb8c61c1306d0f7802f9ba/sam3-0.1.3-py3-none-any.whl", hash = "sha256:1790c200c6ecf1ec3dad14a76b151f8886432cb6ad405d5c26451f144421d46c", size = 506762, upload-time = "2026-03-09T18:39:32.535Z" }, + { url = "https://files.pythonhosted.org/packages/f3/73/16864549e382f23b9c0bc738b8bca5270770bbc88552e9fdaf9a95cee867/sam3-0.1.4-py3-none-any.whl", hash = "sha256:dc2089fccbcd183896df8cf9b3bf53da9f6a23f17e49f5b643ade993322d2949", size = 506751, upload-time = "2026-05-26T11:48:13.442Z" }, ] [[package]] @@ -5549,7 +5542,7 @@ resolution-markers = [ "python_full_version < '3.11' and platform_machine == 's390x' and sys_platform != 'darwin'", ] dependencies = [ - { name = "setuptools", marker = "(platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and sys_platform != 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 's390x' and sys_platform == 'linux') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128')" }, + { name = "setuptools", marker = "(platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine == 's390x' and sys_platform != 'darwin' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/8d/a9/549e51e9b1b2c9b854fd761a1d23df0ba2fbc60bd0c13b489ffa518cfcb7/triton-3.3.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b74db445b1c562844d3cfad6e9679c72e93fdfb1a90a24052b03bb5c49d1242e", size = 155600257, upload-time = "2025-05-29T23:39:36.085Z" }, @@ -5567,7 +5560,7 @@ resolution-markers = [ "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux') or (python_full_version < '3.11' and platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux')", ] dependencies = [ - { name = "setuptools", marker = "(platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu') or (platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu') or (platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, + { name = "setuptools", marker = "(platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128' and extra != 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128' and extra != 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128' and extra != 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128' and extra != 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 's390x' and sys_platform != 'linux' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 'aarch64' and sys_platform == 'linux' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra != 'extra-16-inference-models-torch-cpu' and extra != 'extra-16-inference-models-torch-cu118' and extra != 'extra-16-inference-models-torch-cu124' and extra != 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine == 's390x' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (platform_machine != 'aarch64' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'darwin' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu118') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cpu' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-cu12') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cu118' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-onnx-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'linux' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-jp6-cu126' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'linux' and extra != 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu118') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cpu' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu124') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu118' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu124' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu128') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu126' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-cu130') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu128' and extra == 'extra-16-inference-models-torch-jp6-cu126') or (sys_platform == 'darwin' and extra != 'extra-16-inference-models-onnx-cpu' and extra != 'extra-16-inference-models-onnx-cu118' and extra != 'extra-16-inference-models-onnx-cu12' and extra == 'extra-16-inference-models-torch-cu130' and extra == 'extra-16-inference-models-torch-jp6-cu126')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069, upload-time = "2025-07-30T19:58:21.715Z" }, diff --git a/requirements/requirements.cpu.txt b/requirements/requirements.cpu.txt index 8ea73fae2e..810011e42e 100644 --- a/requirements/requirements.cpu.txt +++ b/requirements/requirements.cpu.txt @@ -1,3 +1,3 @@ onnxruntime>=1.15.1,<1.22.0 nvidia-ml-py<13.0.0 -inference-models[torch-cpu,onnx-cpu]~=0.28.1 # keep in sync between requirements.gpu.txt, requirements.cpu.txt, requirements.vino.txt +inference-models[torch-cpu,onnx-cpu]~=0.28.4 # keep in sync between requirements.gpu.txt, requirements.cpu.txt, requirements.vino.txt diff --git a/requirements/requirements.gpu.txt b/requirements/requirements.gpu.txt index 6f53990793..6e67a9b4ad 100644 --- a/requirements/requirements.gpu.txt +++ b/requirements/requirements.gpu.txt @@ -1,2 +1,2 @@ onnxruntime-gpu>=1.15.1,<1.22.0 -inference-models[torch-cu124,onnx-cu12]~=0.28.1 # keep in sync between requirements.jetson requirements.gpu.txt, requirements.cpu.txt, requirements.vino.txt +inference-models[torch-cu124,onnx-cu12]~=0.28.4 # keep in sync between requirements.jetson requirements.gpu.txt, requirements.cpu.txt, requirements.vino.txt diff --git a/requirements/requirements.jetson.txt b/requirements/requirements.jetson.txt index 4d9308e509..7ec79eaf98 100644 --- a/requirements/requirements.jetson.txt +++ b/requirements/requirements.jetson.txt @@ -1,4 +1,4 @@ pypdfium2>=4.11.0,<5.0.0 jupyterlab>=4.3.0,<5.0.0 PyYAML~=6.0.0 -inference-models~=0.28.1 # keep in sync between requirements.jetson requirements.gpu.txt, requirements.cpu.txt, requirements.vino.txt +inference-models~=0.28.4 # keep in sync between requirements.jetson requirements.gpu.txt, requirements.cpu.txt, requirements.vino.txt diff --git a/requirements/requirements.sam3.txt b/requirements/requirements.sam3.txt index b5630fdf36..5a69aa0a1c 100644 --- a/requirements/requirements.sam3.txt +++ b/requirements/requirements.sam3.txt @@ -1 +1 @@ -sam3==0.1.3 \ No newline at end of file +sam3==0.1.4 \ No newline at end of file diff --git a/requirements/requirements.vino.txt b/requirements/requirements.vino.txt index 872f409906..0474bcbf20 100644 --- a/requirements/requirements.vino.txt +++ b/requirements/requirements.vino.txt @@ -1,2 +1,2 @@ onnxruntime-openvino>=1.15.0,<1.22.0 -inference-models[torch-cpu]~=0.28.1 # keep in sync between requirements.jetson requirements.gpu.txt, requirements.cpu.txt, requirements.vino.txt +inference-models[torch-cpu]~=0.28.4 # keep in sync between requirements.jetson requirements.gpu.txt, requirements.cpu.txt, requirements.vino.txt diff --git a/tests/inference/models_predictions_tests/test_rfdetr_seg_preprocessing_parity.py b/tests/inference/models_predictions_tests/test_rfdetr_seg_preprocessing_parity.py new file mode 100644 index 0000000000..be15d9142a --- /dev/null +++ b/tests/inference/models_predictions_tests/test_rfdetr_seg_preprocessing_parity.py @@ -0,0 +1,336 @@ +"""End-to-end RF-DETR segmentation parity for Triton preprocessing. + +This test uses a small set of images +already present in the repository and toggles the Triton path in-process. +""" + +import json +import os +from pathlib import Path + +import numpy as np +import pytest + +torch = pytest.importorskip("torch") +cv2 = pytest.importorskip("cv2") +pytest.importorskip("triton") +if not torch.cuda.is_available(): # pragma: no cover - host-dependent + pytest.skip("CUDA not available", allow_module_level=True) + +REPO_ROOT = Path(__file__).resolve().parents[3] +ASSETS_DIR = REPO_ROOT / "tests" / "inference" / "models_predictions_tests" / "assets" +DEFAULT_MODEL_ID = "rfdetr-seg-nano" +DEFAULT_CONFIDENCE = 0.4 +BACKENDS = ("torch", "onnx", "trt") +IMAGE_PATHS = [ + ASSETS_DIR / "beer.jpg", + ASSETS_DIR / "person_image.jpg", + ASSETS_DIR / "truck.jpg", + ASSETS_DIR / "melee.jpg", +] +MATCH_IOU_THRESHOLD = 0.99 +MIN_BOX_IOU = 0.99 +MIN_MEAN_BOX_IOU = 0.99 +MAX_SCORE_DELTA = 0.003 +MIN_MASK_IOU = 0.99 +MIN_MEAN_MASK_IOU = 0.99 + + +def ensure_backend_dependencies_available(backend: str) -> None: + if backend == "onnx": + onnxruntime = pytest.importorskip("onnxruntime") + available_providers = set(onnxruntime.get_available_providers()) + if "CUDAExecutionProvider" not in available_providers: + pytest.skip("onnxruntime CUDAExecutionProvider not available") + elif backend == "trt": + pytest.importorskip("tensorrt") + pytest.importorskip("pycuda.driver") + + +def configure_backend_environment( + backend: str, monkeypatch: pytest.MonkeyPatch +) -> None: + if backend == "onnx" and "ONNXRUNTIME_EXECUTION_PROVIDERS" not in os.environ: + monkeypatch.setenv( + "ONNXRUNTIME_EXECUTION_PROVIDERS", + "[CUDAExecutionProvider,CPUExecutionProvider]", + ) + elif backend == "trt" and "ONNXRUNTIME_EXECUTION_PROVIDERS" not in os.environ: + monkeypatch.setenv( + "ONNXRUNTIME_EXECUTION_PROVIDERS", + "[TensorrtExecutionProvider,CUDAExecutionProvider,CPUExecutionProvider]", + ) + + +def load_parity_model(backend: str): + from inference_models import AutoModel + + return AutoModel.from_pretrained( + DEFAULT_MODEL_ID, + backend=backend, + device="cuda", + trt_engine_host_code_allowed=True, + ) + + +def collect_parity_records(model, backend: str, use_triton: bool) -> dict: + import inference_models.models.rfdetr.pre_processing as pre_processing + + original_triton_kernel = pre_processing.triton_preprocess_rfdetr_stretch + if use_triton: + assert ( + original_triton_kernel is not None + ), "RF-DETR Triton preprocessing kernel is unavailable." + + triton_calls = {"count": 0} + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + pre_processing, + "USE_TRITON_FOR_PREPROCESSING", + use_triton, + ) + if use_triton: + + def counting(*args, **kwargs): + triton_calls["count"] += 1 + return original_triton_kernel(*args, **kwargs) + + monkeypatch.setattr( + pre_processing, + "triton_preprocess_rfdetr_stretch", + counting, + ) + + records = [] + for image_path in IMAGE_PATHS: + image = cv2.imread(str(image_path), cv2.IMREAD_COLOR) + if image is None: + raise RuntimeError(f"Could not load test image: {image_path}") + pre_processed, metadata = model.pre_process(image) + raw_predictions = model.forward(pre_processed) + detections = model.post_process( + raw_predictions, + metadata, + confidence=DEFAULT_CONFIDENCE, + )[0] + detections_count = int(detections.class_id.numel()) + records.append( + { + "path": str(image_path), + "xyxy": ( + detections.xyxy.detach().cpu().numpy() + if detections_count + else np.zeros((0, 4), dtype=np.float32) + ), + "conf": ( + detections.confidence.detach().cpu().numpy() + if detections_count + else np.zeros((0,), dtype=np.float32) + ), + "cls": ( + detections.class_id.detach().cpu().numpy() + if detections_count + else np.zeros((0,), dtype=np.int32) + ), + "mask": ( + detections.mask.detach().to(torch.bool).cpu().numpy() + if detections_count and detections.mask is not None + else None + ), + } + ) + + return { + "backend": backend, + "records": records, + "triton_calls": triton_calls["count"], + "use_triton_for_preprocessing": use_triton, + } + + +def iou_box(a: np.ndarray, b: np.ndarray) -> float: + x0 = max(float(a[0]), float(b[0])) + y0 = max(float(a[1]), float(b[1])) + x1 = min(float(a[2]), float(b[2])) + y1 = min(float(a[3]), float(b[3])) + inter_w = max(0.0, x1 - x0) + inter_h = max(0.0, y1 - y0) + inter = inter_w * inter_h + area_a = max(0.0, float(a[2]) - float(a[0])) * max( + 0.0, float(a[3]) - float(a[1]) + ) + area_b = max(0.0, float(b[2]) - float(b[0])) * max( + 0.0, float(b[3]) - float(b[1]) + ) + union = area_a + area_b - inter + return inter / union if union > 0 else 0.0 + + +def summarize_parity(triton_run: dict, reference_run: dict) -> dict: + triton_records = triton_run["records"] + reference_records = reference_run["records"] + if len(triton_records) != len(reference_records): + raise AssertionError( + "Collected different numbers of images between parity runs: " + f"{len(triton_records)} != {len(reference_records)}" + ) + + total_triton_detections = 0 + total_reference_detections = 0 + matched_detections = 0 + count_mismatch_images = 0 + class_disagreements = 0 + mask_presence_mismatches = 0 + box_ious = [] + score_deltas = [] + mask_ious = [] + + for triton_record, reference_record in zip(triton_records, reference_records): + if triton_record["path"] != reference_record["path"]: + raise AssertionError( + "Image order mismatch between parity runs: " + f"{triton_record['path']} != {reference_record['path']}" + ) + + triton_boxes = triton_record["xyxy"] + reference_boxes = reference_record["xyxy"] + triton_scores = triton_record["conf"] + reference_scores = reference_record["conf"] + triton_classes = triton_record["cls"] + reference_classes = reference_record["cls"] + triton_masks = triton_record["mask"] + reference_masks = reference_record["mask"] + + triton_count = len(triton_boxes) + reference_count = len(reference_boxes) + total_triton_detections += triton_count + total_reference_detections += reference_count + if triton_count != reference_count: + count_mismatch_images += 1 + if triton_count == 0 and reference_count == 0: + continue + + used_triton_indices = set() + for reference_index in range(reference_count): + best_triton_index = -1 + best_iou = MATCH_IOU_THRESHOLD + for triton_index in range(triton_count): + if triton_index in used_triton_indices: + continue + iou = iou_box( + triton_boxes[triton_index], + reference_boxes[reference_index], + ) + if iou > best_iou: + best_iou = iou + best_triton_index = triton_index + if best_triton_index < 0: + continue + + used_triton_indices.add(best_triton_index) + matched_detections += 1 + box_ious.append(best_iou) + score_deltas.append( + abs( + float(triton_scores[best_triton_index]) + - float(reference_scores[reference_index]) + ) + ) + if int(triton_classes[best_triton_index]) != int( + reference_classes[reference_index] + ): + class_disagreements += 1 + + if (triton_masks is None) != (reference_masks is None): + mask_presence_mismatches += 1 + elif triton_masks is not None and reference_masks is not None: + triton_mask = triton_masks[best_triton_index] + reference_mask = reference_masks[reference_index] + intersection = np.logical_and(triton_mask, reference_mask).sum() + union = np.logical_or(triton_mask, reference_mask).sum() + mask_ious.append(float(intersection) / float(union) if union else 0.0) + + unmatched_reference_detections = total_reference_detections - matched_detections + return { + "backend": reference_run["backend"], + "images": len(reference_records), + "triton_calls_enabled": int(triton_run["triton_calls"]), + "triton_calls_disabled": int(reference_run["triton_calls"]), + "total_triton_detections": int(total_triton_detections), + "total_reference_detections": int(total_reference_detections), + "matched_detections": int(matched_detections), + "unmatched_reference_detections": int(unmatched_reference_detections), + "count_mismatch_images": int(count_mismatch_images), + "class_disagreements": int(class_disagreements), + "mask_presence_mismatches": int(mask_presence_mismatches), + "mean_box_iou": float(np.mean(box_ious)) if box_ious else None, + "min_box_iou": float(np.min(box_ious)) if box_ious else None, + "mean_abs_score_delta": ( + float(np.mean(score_deltas)) if score_deltas else None + ), + "max_abs_score_delta": float(np.max(score_deltas)) if score_deltas else None, + "mean_mask_iou": float(np.mean(mask_ious)) if mask_ious else None, + "min_mask_iou": float(np.min(mask_ious)) if mask_ious else None, + } + + +def format_summary(summary: dict) -> str: + return json.dumps(summary, indent=2, sort_keys=True) + + +@pytest.mark.timeout(1200) +@pytest.mark.slow +@pytest.mark.parametrize("backend", BACKENDS, ids=BACKENDS) +def test_rfdetr_seg_nano_triton_preprocessing_matches_reference_path( + backend: str, + monkeypatch: pytest.MonkeyPatch, +) -> None: + missing_images = [ + str(image_path) for image_path in IMAGE_PATHS if not image_path.is_file() + ] + assert not missing_images, f"Missing parity images: {missing_images}" + + ensure_backend_dependencies_available(backend=backend) + configure_backend_environment(backend=backend, monkeypatch=monkeypatch) + model = load_parity_model(backend=backend) + + try: + enabled_run = collect_parity_records( + model=model, + backend=backend, + use_triton=True, + ) + disabled_run = collect_parity_records( + model=model, + backend=backend, + use_triton=False, + ) + finally: + del model + torch.cuda.empty_cache() + + summary = summarize_parity( + triton_run=enabled_run, + reference_run=disabled_run, + ) + summary_text = format_summary(summary) + + assert enabled_run["backend"] == disabled_run["backend"] == backend + assert enabled_run["triton_calls"] == len(IMAGE_PATHS), summary_text + assert disabled_run["triton_calls"] == 0, summary_text + assert summary["total_reference_detections"] > 0, summary_text + assert summary["count_mismatch_images"] == 0, summary_text + assert summary["class_disagreements"] == 0, summary_text + assert summary["mask_presence_mismatches"] == 0, summary_text + assert summary["unmatched_reference_detections"] == 0, summary_text + assert summary["mean_box_iou"] is not None, summary_text + assert summary["min_box_iou"] is not None, summary_text + assert summary["mean_abs_score_delta"] is not None, summary_text + assert summary["max_abs_score_delta"] is not None, summary_text + assert summary["mean_mask_iou"] is not None, summary_text + assert summary["min_mask_iou"] is not None, summary_text + assert summary["mean_box_iou"] >= MIN_MEAN_BOX_IOU, summary_text + assert summary["min_box_iou"] >= MIN_BOX_IOU, summary_text + assert summary["max_abs_score_delta"] <= MAX_SCORE_DELTA, summary_text + assert summary["mean_mask_iou"] >= MIN_MEAN_MASK_IOU, summary_text + assert summary["min_mask_iou"] >= MIN_MASK_IOU, summary_text