Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotations in Fishtest #2140

Open
vdbergh opened this issue Nov 9, 2024 · 6 comments
Open

Type annotations in Fishtest #2140

vdbergh opened this issue Nov 9, 2024 · 6 comments

Comments

@vdbergh
Copy link
Contributor

vdbergh commented Nov 9, 2024

To stay up to date, I adapted vtjson to work well with type annotations. This is what the new runs schema would look like:

import copy
import math
from datetime import datetime, timezone
from typing import Annotated, Literal, NotRequired, TypedDict

from bson.objectid import ObjectId

from vtjson import (
    at_most_one_of,
    div,
    fields,
    ge,
    glob,
    gt,
    ifthen,
    intersect,
    ip_address,
    keys,
    lax,
    one_of,
    quote,
    regex,
    skip_first,
    url,
)

# Reusable constrained types. In Annotated[T, *constraints] the first argument
# is the static type seen by type checkers and the remaining arguments are
# vtjson validators.
# NOTE(review): skip_first presumably tells vtjson to skip validating against
# the leading static type and apply only the following constraints, and the
# regex comments below assume vtjson's regex must match the whole string —
# confirm both in the vtjson documentation.
username = Annotated[str, regex(r"[!-~][ -~]{0,30}[!-~]", name="username"), skip_first]
# Network file name such as "nn-0123456789ab.nnue". The dot is escaped so that
# it only matches a literal "." (an unescaped "." matches any character).
net_name = Annotated[str, regex(r"nn-[a-f0-9]{12}\.nnue", name="net_name"), skip_first]
# Time control such as "10+0.1" or "40/60": optional "moves/", a base time and
# an optional "+increment".
tc = Annotated[
    str, regex(r"([1-9]\d*/)?\d+(\.\d+)?(\+\d+(\.\d+)?)?", name="tc"), skip_first
]
# Positive decimal integer written as a string, without leading zeros.
str_int = Annotated[str, regex(r"[1-9]\d*", name="str_int"), skip_first]
# 40 lowercase hexadecimal characters.
sha = Annotated[str, regex(r"[a-f0-9]{40}", name="sha"), skip_first]
# Two capital letters.
country_code = Annotated[str, regex(r"[A-Z][A-Z]", name="country_code"), skip_first]
# Any string accepted by ObjectId.is_valid.
run_id = Annotated[str, ObjectId.is_valid]
# UUID-like key; the first group is relaxed to two or more alphanumerics.
uuid = Annotated[
    str,
    regex(r"[0-9a-zA-Z]{2,}(-[a-f0-9]{4}){3}-[a-f0-9]{12}", name="uuid"),
    skip_first,
]
# File names matching the glob patterns "*.epd" and "*.pgn".
epd_file = Annotated[str, glob("*.epd", name="epd_file"), skip_first]
pgn_file = Annotated[str, glob("*.pgn", name="pgn_file"), skip_first]
# Integer divisible by 2.
even = Annotated[int, div(2, name="even"), skip_first]
# datetime whose tzinfo attribute equals timezone.utc.
datetime_utc = Annotated[datetime, fields({"tzinfo": timezone.utc})]

# Numeric ranges: "u" means >= 0, "su" means strictly > 0.
uint = Annotated[int, ge(0)]
suint = Annotated[int, gt(0)]
ufloat = Annotated[float, ge(0)]
sufloat = Annotated[float, gt(0)]


class results_type(TypedDict):
    """Game counters of a run (``results``) or of a task (``stats``).

    Every counter is a non-negative integer. Cross-field consistency is not
    checked here; results_schema adds the valid_results check on top.
    """

    wins: uint
    losses: uint
    draws: uint
    crashes: uint
    time_losses: uint
    # Fixed-length list schema: exactly five non-negative integers.
    pentanomial: Annotated[list[int], [uint, uint, uint, uint, uint], skip_first]


def valid_results(R: results_type) -> bool:
    """Return True if the win/draw/loss counts of *R* agree with its pentanomial.

    The checks are consistent with ``R["pentanomial"][i]`` counting game
    pairs worth ``i`` half-points (LL, LD, DD or WL, WD, WW). This reading is
    inferred from the formulas, not stated elsewhere.

    Three conditions must all hold:
    * the number of games is twice the number of pairs;
    * the win-minus-loss balance matches the pentanomial;
    * the draw count lies between the number of pairs that must contain
      exactly one draw, and that number plus two draws per 1-point pair.
    """
    losses = R["losses"]
    draws = R["draws"]
    wins = R["wins"]
    pent = R["pentanomial"]

    # Every pentanomial entry accounts for two games.
    if losses + draws + wins != 2 * sum(pent):
        return False

    # Net wins implied by the pentanomial.
    if wins - losses != 2 * pent[4] + pent[3] - pent[1] - 2 * pent[0]:
        return False

    # Draws: at least one per 0.5/1.5-point pair, at most two more per 1-point pair.
    one_draw_pairs = pent[3] + pent[1]
    return one_draw_pairs + 2 * pent[2] >= draws >= one_draw_pairs


# Full game-counter schema: the results_type fields plus the cross-field
# consistency check performed by valid_results.
results_schema = Annotated[
    results_type,
    valid_results,
]


class worker_info_schema(TypedDict):
    """Description of the worker attached to a task (``worker_info`` of a task)."""

    uname: str
    # Exactly two strings.
    architecture: Annotated[list[str], [str, str], skip_first]
    # Strictly positive; summed over active tasks by cores_must_match.
    concurrency: suint
    max_memory: uint
    min_threads: suint
    # NOTE(review): plain str here, whereas args_type.username uses the
    # stricter username alias — confirm this difference is intended.
    username: str
    version: uint
    # Exactly three non-negative integers (presumably major, minor, patch).
    python_version: Annotated[list[int], [uint, uint, uint], skip_first]
    gcc_version: Annotated[list[int], [uint, uint, uint], skip_first]
    compiler: Literal["clang++", "g++"]
    # Must match the uuid pattern defined above.
    unique_key: uuid
    modified: bool
    ARCH: str
    nps: ufloat
    near_github_api_limit: bool
    # Validated by vtjson's ip_address.
    remote_addr: Annotated[str, ip_address]
    # Two capital letters, or "?" (presumably "unknown").
    country_code: country_code | Literal["?"]


class overshoot_type(TypedDict):
    """State stored under the optional ``overshoot`` key of sprt_type.

    NOTE(review): the meaning of the individual fields is not visible in this
    schema; the names suggest paired statistics for two references (suffixes
    0 and 1) — confirm against the code that updates them.
    """

    last_update: uint
    skipped_updates: uint
    ref0: float
    m0: float
    sq0: ufloat  # non-negative
    ref1: float
    m1: float
    sq1: ufloat  # non-negative


class sprt_type(TypedDict):
    """Parameters and state of an SPRT test, stored as ``args["sprt"]``.

    In the ``Annotated[float, <constant>, skip_first]`` fields the constraint
    is the constant itself, so the stored value presumably has to equal it
    (vtjson constant matching — confirm in the vtjson documentation).
    """

    alpha: Annotated[float, 0.05, skip_first]  # fixed at 0.05
    beta: Annotated[float, 0.05, skip_first]  # fixed at 0.05
    elo0: float
    elo1: float
    elo_model: Literal["normalized"]  # only "normalized" is accepted
    state: Literal["", "accepted", "rejected"]  # "" presumably means undecided
    llr: float
    batch_size: suint
    # -log(19) == log(0.05 / 0.95) and log(19) == log(0.95 / 0.05); presumably
    # the SPRT bounds log(beta / (1 - alpha)) and log((1 - beta) / alpha) for
    # the fixed alpha and beta above.
    lower_bound: Annotated[float, -math.log(19), skip_first]
    upper_bound: Annotated[float, math.log(19), skip_first]
    lost_samples: NotRequired[uint]
    illegal_update: NotRequired[uint]
    overshoot: NotRequired[overshoot_type]


# Full SPRT schema: sprt_type plus a one_of constraint on "overshoot" and
# "lost_samples" (presumably exactly one of the two keys must be present, as
# opposed to the at_most_one_of used for the run arguments).
sprt_schema = Annotated[
    sprt_type,
    one_of("overshoot", "lost_samples"),
]


class param_schema(TypedDict):
    """One parameter entry of an SPSA test (element of ``spsa["params"]``)."""

    name: str
    start: float
    min: float
    max: float
    c_end: sufloat  # strictly positive
    r_end: ufloat
    c: sufloat  # strictly positive
    a_end: ufloat
    a: ufloat
    theta: float  # presumably the current parameter value — verify


class param_history_schema(TypedDict):
    """One per-parameter record inside ``spsa["param_history"]``."""

    theta: float
    R: ufloat
    c: ufloat


class spsa_schema(TypedDict):
    """Configuration and state of an SPSA test, stored as ``args["spsa"]``."""

    A: ufloat
    alpha: ufloat
    gamma: ufloat
    raw_params: str  # presumably the unparsed text form of params — verify
    iter: uint
    num_iter: uint
    params: list[param_schema]
    # Optional list of lists; presumably one inner list per recorded snapshot,
    # with one record per parameter — verify against the writer.
    param_history: NotRequired[list[list[param_history_schema]]]


class args_type(TypedDict):
    """Arguments of a run, stored as the run's ``args`` key."""

    base_tag: str
    new_tag: str
    # Lists of network file names ("nn-<12 hex digits>.nnue").
    base_nets: list[net_name]
    new_nets: list[net_name]
    # Non-negative and even.
    num_games: Annotated[uint, even]
    tc: tc
    new_tc: tc
    # Either a "*.epd" or a "*.pgn" file name.
    book: epd_file | pgn_file
    # Positive integer written as a string.
    book_depth: str_int
    threads: suint
    # 40-character lowercase hexadecimal SHAs.
    resolved_base: sha
    resolved_new: sha
    master_sha: sha
    official_master_sha: sha
    msg_base: str
    msg_new: str
    base_options: str
    new_options: str
    info: str
    # Positive integers written as strings.
    base_signature: str_int
    new_signature: str_int
    username: username
    # Validated by vtjson's url.
    tests_repo: Annotated[str, url, skip_first]
    auto_purge: bool
    throughput: ufloat
    itp: ufloat
    priority: float
    adjudication: bool
    # Both optional; args_schema allows at most one of them.
    sprt: NotRequired[sprt_schema]
    spsa: NotRequired[spsa_schema]


# Full argument schema: a run may carry "sprt", "spsa", or neither, but not both.
args_schema = Annotated[
    args_type,
    at_most_one_of("sprt", "spsa"),
]


class task_type(TypedDict):
    """One task of a run (element of the run's ``tasks`` list)."""

    # Non-negative and even.
    num_games: Annotated[uint, even]
    active: bool
    last_updated: datetime_utc
    start: uint
    residual: float
    residual_color: NotRequired[str]
    # When present, the value is always True.
    bad: NotRequired[Literal[True]]
    stats: results_schema
    worker_info: worker_info_schema


# All-zero game counters. Used below as the required stats of a bad task and
# deep-copied by final_results_must_match as the starting point of its sum;
# mutating it in place would affect both uses.
zero_results: results_type = {
    "wins": 0,
    "draws": 0,
    "losses": 0,
    "crashes": 0,
    "time_losses": 0,
    "pentanomial": 5 * [0],
}

# A task that has the "bad" key must be inactive and have all-zero stats.
# lax presumably lets the task's other keys through, and quote presumably makes
# vtjson compare against zero_results literally rather than treating it as a
# schema — confirm in the vtjson documentation.
if_bad_then_zero_stats_and_not_active = ifthen(
    keys("bad"), lax({"active": False, "stats": quote(zero_results)})
)

# Full task schema: task_type plus the "bad task" rule above.
task_schema = Annotated[
    task_type,
    if_bad_then_zero_stats_and_not_active,
]


class bad_task_schema(TypedDict):
    """Entry of the run's optional ``bad_tasks`` list.

    Same fields as task_type, except that ``active`` must be False,
    ``residual_color`` and ``bad`` (always True) are required, and ``task_id``
    is added.

    NOTE(review): unlike the "bad" rule in task_schema, ``stats`` here is not
    required to be zero (presumably it keeps the discarded results) — verify.
    """

    num_games: Annotated[uint, even]
    active: Literal[False]
    last_updated: datetime_utc
    start: uint
    residual: float
    residual_color: str
    bad: Literal[True]
    task_id: uint
    stats: results_schema
    worker_info: worker_info_schema


class results_info_schema(TypedDict):
    """Extra information stored under a run's optional ``results_info`` key.

    NOTE(review): presumably presentation data (a style plus text lines) — verify.
    """

    style: str
    info: list[str]


class runs_type(TypedDict):
    """Static type of a run; runs_schema adds the cross-field checks."""

    # Optional (presumably absent before the run is first stored — verify).
    _id: NotRequired[ObjectId]
    version: uint
    start_time: datetime_utc
    last_updated: datetime_utc
    tc_base: ufloat
    base_same_as_master: bool
    # Optional; a string accepted by ObjectId.is_valid.
    rescheduled_from: NotRequired[run_id]
    approved: bool
    # A username, or "" (runs_schema ties the choice to ``approved``).
    approver: username | Literal[""]
    finished: bool
    deleted: bool
    failed: bool
    is_green: bool
    is_yellow: bool
    # Must equal the number of active tasks (workers_must_match).
    workers: uint
    # Must equal the summed concurrency of active tasks (cores_must_match).
    cores: uint
    # Must equal the sum of all task stats (final_results_must_match).
    results: results_schema
    results_info: NotRequired[results_info_schema]
    args: args_schema
    tasks: list[task_schema]
    bad_tasks: NotRequired[list[bad_task_schema]]


def final_results_must_match(run: runs_type) -> bool:
    """Check that ``run["results"]`` is the sum of the stats of all its tasks.

    Every task in ``run["tasks"]`` is included, active or not.

    Returns:
        True when the aggregated counters equal ``run["results"]``.

    Raises:
        Exception: on mismatch, with both the stored and computed results in
            the message.
    """
    totals = copy.deepcopy(zero_results)
    pent_totals = totals["pentanomial"]
    for task in run["tasks"]:
        stats = task["stats"]
        # The keys are spelled out one by one because mypy does not accept a
        # variable key for a TypedDict.
        totals["wins"] += stats["wins"]
        totals["losses"] += stats["losses"]
        totals["draws"] += stats["draws"]
        totals["crashes"] += stats["crashes"]
        totals["time_losses"] += stats["time_losses"]
        for idx, count in enumerate(stats["pentanomial"]):
            pent_totals[idx] += count

    if totals == run["results"]:
        return True
    raise Exception(
        f"The final results {run['results']} do not match the computed results {totals}"
    )


def cores_must_match(run: runs_type) -> bool:
    """Check that ``run["cores"]`` equals the summed concurrency of active tasks.

    Returns:
        True when the counts agree.

    Raises:
        Exception: on mismatch, reporting both counts.
    """
    active_cores = sum(
        task["worker_info"]["concurrency"] for task in run["tasks"] if task["active"]
    )
    if active_cores == run["cores"]:
        return True
    raise Exception(
        f"Cores mismatch. Cores from tasks: {active_cores}. Cores from "
        f"run: {run['cores']}"
    )


def workers_must_match(run: runs_type) -> bool:
    """Check that ``run["workers"]`` equals the number of active tasks.

    Returns:
        True when the counts agree.

    Raises:
        Exception: on mismatch, reporting both counts.
    """
    active_workers = sum(1 for task in run["tasks"] if task["active"])
    if active_workers == run["workers"]:
        return True
    raise Exception(
        f"Workers mismatch. Workers from tasks: {active_workers}. Workers from "
        f"run: {run['workers']}"
    )


# Consistency checks between a run's aggregate fields and its tasks. Each check
# returns True or raises an Exception describing the mismatch.
valid_aggregated_data = intersect(
    final_results_must_match,
    cores_must_match,
    workers_must_match,
)

# Full run schema: runs_type plus cross-field rules. Each lax(ifthen(A, B[, C]))
# presumably means "if the run matches A it must match B (otherwise C)", with
# lax allowing keys that A/B/C do not mention — confirm in the vtjson docs.
runs_schema = Annotated[
    runs_type,
    # Approved runs name a username as approver; unapproved runs have "".
    lax(ifthen({"approved": True}, {"approver": username}, {"approver": ""})),
    # Green and yellow are mutually exclusive.
    lax(ifthen({"is_green": True}, {"is_yellow": False})),
    lax(ifthen({"is_yellow": True}, {"is_green": False})),
    # Failed or deleted runs must also be finished.
    lax(ifthen({"failed": True}, {"finished": True})),
    lax(ifthen({"deleted": True}, {"finished": True})),
    # Finished runs have no workers or cores ...
    lax(ifthen({"finished": True}, {"workers": 0, "cores": 0})),
    # ... and (presumably, via the trailing ...) every task is inactive.
    lax(ifthen({"finished": True}, {"tasks": [{"active": False}, ...]})),
    valid_aggregated_data,
]
@vdbergh
Copy link
Contributor Author

vdbergh commented Dec 4, 2024

I have now created comprehensive documentation for vtjson. See https://www.cantate.be/vtjson/ (canonical reference) or https://vtjson.readthedocs.io (if you don't mind some ads).

@ppigazzini
Copy link
Collaborator

Curiosities:

  • the example code uses snake case instead than camel case for the class naming
  • are the classes used only for validation, or are they also used in the code in place of the dictionaries?

@vdbergh
Copy link
Contributor Author

vdbergh commented Jan 6, 2025

The classes are used both for type checking and for validation. TypedDict is the standard Python typing construct for dictionaries with a fixed set of typed keys.

class Foo(TypedDict):
    """Example schema; for vtjson it is equivalent to {"baz": int, "baz2": str}."""

    baz: int
    baz2: str

is functionally equivalent to

{"baz": int, "baz2": str} 

as far as vtjson is concerned. However, the class Foo can also be used by static type checkers such as mypy for compile-time checking.

The point is that if we receive some untyped JSON boo_untyped from an API, then we can write

boo_typed = safe_cast(Foo, boo_untyped)

This will accomplish two things:

  • boo_untyped will be checked at run time to verify that it really corresponds to the schema Foo.
  • mypy will assign the type Foo to boo_typed and use it for further static type checking.

@vdbergh
Copy link
Contributor Author

vdbergh commented Jan 6, 2025

the example code uses snake case instead than camel case for the class naming

I used lower case for the classes since they are really vtjson schemas in disguise, and the current schemas are written in lower case... I don't feel strongly about this, though.

@ppigazzini
Copy link
Collaborator

ppigazzini commented Jan 7, 2025

Very interesting! I can use Annotated in my FastAPI projects to validate the data as well (until now I was too lazy to read up on Pydantic's more advanced features).

from typing import Annotated

from pydantic import (
    AnyUrl,
    BaseModel,
    EmailStr,
    Field,
    IPvAnyAddress,
    ValidationError,
    conint,
    conlist,
    constr,
)


class AddressModel(BaseModel):
    """Postal address nested inside UserModel.

    NOTE(review): the bare Field() adds no constraint, and pydantic v2 docs
    recommend Annotated[str, StringConstraints(...)] over constr() — consider
    switching.
    """

    # One or more digits, one whitespace character, then letters/spaces/dots.
    street: Annotated[constr(pattern=r"^\d+\s[A-Za-z\s\.]+$"), Field()]
    # Letters and whitespace only.
    city: Annotated[constr(pattern=r"^[A-Za-z\s]+$"), Field()]
    # Exactly five digits.
    zipcode: Annotated[constr(pattern=r"^\d{5}$"), Field()]


class UserModel(BaseModel):
    """User record nested under NestedDictModel.user."""

    # At least two characters.
    name: Annotated[constr(min_length=2), Field()]
    # Integer with 0 < age < 120; strict=True presumably rejects coercible
    # values such as the string "40" (pydantic strict mode).
    age: Annotated[conint(strict=True, gt=0, lt=120), Field()]
    address: AddressModel
    email: EmailStr
    # Optional; None when omitted.
    urls: list[AnyUrl] | None = None
    # Exactly two IP addresses (IPvAnyAddress accepts IPv4 and IPv6).
    ips: Annotated[conlist(IPvAnyAddress, min_length=2, max_length=2), Field()]


class NestedDictModel(BaseModel):
    """Top-level model checked by validate_data."""

    user: UserModel
    other_key: str


def validate_data(data: dict) -> dict:
    """Validate *data* against NestedDictModel and return it as a plain dict.

    ``model_validate`` is used instead of ``NestedDictModel(**data)``. With
    ``**data``, input that is not a str-keyed mapping would escape as a
    TypeError from argument unpacking. With ``model_validate``, it is reported
    as a ValidationError like any other schema violation.

    Raises:
        ValidationError: if *data* does not match the schema. The error's
            JSON description is printed before re-raising.
    """
    try:
        model = NestedDictModel.model_validate(data)
    except ValidationError as e:
        print(e.json())
        raise
    return model.model_dump()


# Example usage
# Every field satisfies the model constraints, so validation succeeds.
good_data = {
    "user": {
        "name": "John Doe",
        "age": 40,
        "address": {"street": "123 Main St.", "city": "Anytown", "zipcode": "12345"},
        # Must be a syntactically valid address, otherwise EmailStr rejects it.
        "email": "john.doe@example.com",
        "urls": ["https://example.com", "https://example.org"],
        "ips": ["192.168.1.1", "2001:0db8:85a3:0000:0000:8a2e:0370:7334"],
    },
    "other_key": "other_value",
}

try:
    validated_data = validate_data(good_data)
    print("Validated data:", validated_data)
except ValidationError as e:
    print("Validation error:", e)

# Deliberately invalid: every user field breaks a constraint and the required
# "other_key" is missing.
bad_data = {
    "user": {
        "name": "John",
        "age": "ggggg",
        "address": {"street": "Main St.", "city": 123, "zipcode": "123456"},
        "email": "john.doe@invalid",
        "urls": ["invalid-url"],
        "ips": ["999.999"],
    },
}

try:
    validated_data = validate_data(bad_data)
    print("Validated data:", validated_data)
except ValidationError as e:
    print("Validation error:", e)

@vdbergh
Copy link
Contributor Author

vdbergh commented Jan 7, 2025

So pydantic seems to be somewhat similar to vtjson... :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants