Skip to content

Commit 88f4f76

Browse files
committed
Add pyright + fix caught typos
1 parent 0c1718c commit 88f4f76

File tree

11 files changed

+62
-17
lines changed

11 files changed

+62
-17
lines changed

.github/workflows/pyright.yml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
name: pyright
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
branches: [main]
8+
9+
jobs:
10+
pyright:
11+
runs-on: ubuntu-latest
12+
strategy:
13+
matrix:
14+
python-version: ["3.12"]
15+
16+
steps:
17+
- uses: actions/checkout@v2
18+
- name: Set up Python ${{ matrix.python-version }}
19+
uses: actions/setup-python@v1
20+
with:
21+
python-version: ${{ matrix.python-version }}
22+
- name: Install dependencies
23+
run: |
24+
pip install uv
25+
uv pip install --system -e .
26+
uv pip install --system jax
27+
uv pip install --system git+https://github.com/brentyi/jaxls.git
28+
uv pip install --system pyright
29+
- name: Run pyright
30+
run: |
31+
pyright .

3_aria_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def main(args: Args) -> None:
188188
hamer_detections,
189189
aria_detections,
190190
points_data=points_data,
191-
splat_path=splat_path,
191+
splat_path=traj_paths.splat_path,
192192
floor_z=floor_z,
193193
)
194194
while True:

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ dependencies = [
2828
"tensorboardX",
2929
"loguru",
3030
"projectaria-tools[all]",
31-
"opencv-python"
31+
"opencv-python",
32+
"gdown",
3233
]
3334

3435
[tool.setuptools.package-data]

src/egoallo/data/amass.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from pathlib import Path
2-
from typing import Any, Literal, cast
2+
from typing import Any, Literal, assert_never, cast
33

44
import h5py
55
import numpy as np
66
import torch
77
import torch.utils
88
import torch.utils.data
9-
from typing_extensions import assert_never
109

1110
from .dataclass import EgoTrainingData
1211

src/egoallo/data/dataclass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def joints_wrt_world(self) -> Tensor:
5858

5959
@staticmethod
6060
def load_from_npz(
61-
body_model: fncsmpl.SmplModel,
61+
body_model: fncsmpl.SmplhModel,
6262
path: Path,
6363
include_hands: bool,
6464
) -> EgoTrainingData:

src/egoallo/network.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from dataclasses import dataclass
44
from functools import cache, cached_property
5-
from typing import Literal
5+
from typing import Literal, assert_never
66

77
import numpy as np
88
import torch
@@ -11,7 +11,6 @@
1111
from loguru import logger
1212
from rotary_embedding_torch import RotaryEmbedding
1313
from torch import Tensor, nn
14-
from typing_extensions import assert_never
1514

1615
from .fncsmpl import SmplhModel, SmplhShapedAndPosed
1716
from .tensor_dataclass import TensorDataclass

src/egoallo/tensor_dataclass.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import dataclasses
2-
from typing import Any, Callable
2+
from typing import Any, Callable, Self, dataclass_transform
33

44
import torch
5-
from typing_extensions import Self, dataclass_transform
65

76

87
@dataclass_transform()

src/egoallo/training_utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,18 @@
88
import time
99
import traceback as tb
1010
from pathlib import Path
11-
from typing import Any, Dict, Generator, Iterable, Protocol, Sized, overload
11+
from typing import (
12+
Any,
13+
Dict,
14+
Generator,
15+
Iterable,
16+
Protocol,
17+
Sized,
18+
get_type_hints,
19+
overload,
20+
)
1221

1322
import torch
14-
from typing_extensions import get_type_hints
1523

1624

1725
def flattened_hparam_dict_from_dataclass(

src/egoallo/transforms/_base.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
11
import abc
2-
from typing import ClassVar, Generic, Tuple, Type, TypeVar, Union, overload
2+
from typing import (
3+
ClassVar,
4+
Generic,
5+
Self,
6+
Tuple,
7+
Type,
8+
TypeVar,
9+
Union,
10+
final,
11+
overload,
12+
override,
13+
)
314

415
import numpy as onp
516
import torch
617
from torch import Tensor
7-
from typing_extensions import Self, final, override
818

919
GroupType = TypeVar("GroupType", bound="MatrixLieGroup")
1020
SEGroupType = TypeVar("SEGroupType", bound="SEBase")

src/egoallo/transforms/_se3.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass
4-
from typing import cast
4+
from typing import Union, cast, override
55

66
import numpy as np
77
import torch
88
from torch import Tensor
9-
from typing_extensions import Union, override
109

1110
from . import _base
1211
from ._so3 import SO3

0 commit comments

Comments
 (0)