Fix some Pylance errors (#259)
* Ignore IDE files

Signed-off-by: Jan Bielak <[email protected]>

* Fix typing errors

Signed-off-by: Jan Bielak <[email protected]>

* Ignore devcontainer files

Signed-off-by: Jan Bielak <[email protected]>

* Avoid import from private module

Signed-off-by: Jan Bielak <[email protected]>

* Apply @timmoon10's suggestions

Signed-off-by: Jan Bielak <[email protected]>

---------

Signed-off-by: Jan Bielak <[email protected]>
janekb04 authored Jun 2, 2023
1 parent 80825fd commit 144e488
Showing 7 changed files with 16 additions and 10 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -16,6 +16,11 @@ build/
 __pycache__
 .ycm_extra_conf.py
 .vimrc
+.vs
+.vscode
+.cache
+.hypothesis
+.devcontainer.json
 tests/cpp/build/
 docs/_build
 .ipynb_checkpoints
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions.py
@@ -1501,7 +1501,7 @@ def get_batch_per_block(k_seqlen: int) -> int:
     pow2 = 1 << (k_seqlen - 1).bit_length()
     warp_size = pow2 if pow2 < threads_per_warp else threads_per_warp
     batches_per_warp = 2 if pow2 <= 128 else 1
-    warps_per_block = threads_per_block / warp_size
+    warps_per_block = threads_per_block // warp_size
     batches_per_block = warps_per_block * batches_per_warp
     return batches_per_block

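The `/` to `//` change (also applied to transformer_engine/pytorch/softmax.py below) fixes a real typing mismatch: in Python 3, `/` is true division and always yields a `float`, so the function could not satisfy its declared `-> int` return type. A minimal sketch of the distinction, with hypothetical names not taken from the commit:

# Hypothetical sketch: why `/` violates a `-> int` annotation.
THREADS_PER_BLOCK = 128

def bad(warp_size: int) -> int:
    # True division returns a float; Pylance flags this against `-> int`.
    return THREADS_PER_BLOCK / warp_size

def good(warp_size: int) -> int:
    # Floor division keeps the result an int.
    return THREADS_PER_BLOCK // warp_size

print(bad(32))   # 4.0 (float)
print(good(32))  # 4   (int)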
4 changes: 2 additions & 2 deletions transformer_engine/jax/fp8.py
@@ -281,7 +281,7 @@ def generate_fp8_max_array(num_of_meta):
         return jnp.vstack([fp8_max_per_gemm] * num_of_gemm)
 
     @staticmethod
-    def get_fp8_meta_indices(gemm_idx: int) -> Tuple[int]:
+    def get_fp8_meta_indices(gemm_idx: int) -> Tuple[int, int, int]:
         """
         Obtain the index about FP8 metas by the given GEMM index.
         """
@@ -453,7 +453,7 @@ def get_delayed_scaling():
         """
         amax_compute_algo = "max" if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX \
             else "most_recent"
-        return DelayedScaling(margin=FP8Helper.MARGIN,
+        return DelayedScaling(margin=int(FP8Helper.MARGIN),
                               interval=FP8Helper.UPDATE_FP8META_INTERVAL,
                               fp8_format=FP8Helper.FP8_FORMAT,
                               amax_history_len=FP8Helper.AMAX_HISTORY_LEN,
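The `int(FP8Helper.MARGIN)` cast is the usual way to satisfy a checker when a value whose static type is wider (for example `float`) is passed to a parameter declared `int`. A minimal sketch under that assumption; the `Recipe` dataclass here is hypothetical, not the library's `DelayedScaling`:

from dataclasses import dataclass

@dataclass
class Recipe:
    margin: int = 0

MARGIN = 0.0                         # statically a float
recipe = Recipe(margin=int(MARGIN))  # explicit cast; the checker is satisfied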
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/constants.py
@@ -4,6 +4,7 @@
 
 """Enums for e2e transformer"""
 import torch
+import torch.distributed
 import transformer_engine_extensions as tex
 
 
@@ -29,4 +30,4 @@
 
 GemmParallelModes = ("row", "column", None)
 
-dist_group_type = torch._C._distributed_c10d.ProcessGroup
+dist_group_type = torch.distributed.ProcessGroup
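`torch._C._distributed_c10d` is a private extension module with no stability or typing guarantees, while `torch.distributed.ProcessGroup` is the public alias for the same class. Note the added `import torch.distributed` above: a submodule is not guaranteed to be reachable as an attribute of `torch` unless it has been imported. A short sketch of the pattern, assuming a PyTorch build with distributed support; the function name is hypothetical:

import torch.distributed

# Public, type-checker-visible alias for the process-group class.
dist_group_type = torch.distributed.ProcessGroup

def uses_group(group: dist_group_type) -> None:
    # Annotating with the public name keeps Pylance happy.
    ...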
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/module/base.py
@@ -7,7 +7,7 @@
 import pickle
 import warnings
 from abc import ABC, abstractmethod
-from typing import Union, Optional, Tuple, Dict, Any, List
+from typing import Generator, Union, Optional, Tuple, Dict, Any, List
 from functools import partial
 from contextlib import contextmanager
 
@@ -86,7 +86,7 @@ def _prepare_backward(
     tp_group: dist_group_type,
     tp_size: int,
     name: str = ""
-) -> None:
+) -> Generator[None, None, None]:
     """Checks and prep for BWD."""
     if fp8:
         global _amax_reduce_handle_bwd
@@ -542,7 +542,7 @@ def prepare_forward(
         inp: torch.Tensor,
         is_first_microbatch: Union[bool, None],
         num_gemms: int = 1,
-    ) -> None:
+    ) -> Generator[torch.Tensor, None, None]:
         """Checks and prep for FWD.
         The context manager is needed because there isn't a way for a module to know
         if it's the last FP8 module in the forward autocast. It is useful
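A function decorated with `@contextlib.contextmanager` is a generator function, so its return annotation should describe the generator, `Generator[YieldType, None, None]`, rather than `None`; the yielded value is what `with ... as x` binds. The same fix is applied to transformer_engine/tensorflow/fp8.py below. A self-contained sketch with hypothetical names:

from contextlib import contextmanager
from typing import Generator

@contextmanager
def prepare(value: int) -> Generator[int, None, None]:
    # Setup runs before the `with` body...
    yield value + 1  # ...this is what `as x` binds...
    # ...teardown runs after the body exits.

with prepare(41) as x:
    print(x)  # 42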
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/softmax.py
@@ -342,6 +342,6 @@ def get_batch_per_block(key_seq_len: int) -> int:
     pow2 = 1 << (key_seq_len - 1).bit_length()
     warp_size = pow2 if pow2 < THREADS_PER_WARP else THREADS_PER_WARP
     batches_per_warp = 2 if pow2 <= 128 else 1
-    warps_per_block = THREADS_PER_BLOCK / warp_size
+    warps_per_block = THREADS_PER_BLOCK // warp_size
     batches_per_block = warps_per_block * batches_per_warp
     return batches_per_block
4 changes: 2 additions & 2 deletions transformer_engine/tensorflow/fp8.py
@@ -4,7 +4,7 @@
 
 """FP8 utilies for TransformerEngine"""
 from contextlib import contextmanager
-from typing import Optional, Dict, Any
+from typing import Generator, Optional, Dict, Any
 
 import tensorflow as tf
 import transformer_engine_tensorflow as tex
@@ -69,7 +69,7 @@ def get_default_fp8_recipe():
 def fp8_autocast(
     enabled: bool = False,
     fp8_recipe: Optional[DelayedScaling] = None,
-) -> None:
+) -> Generator[None, None, None]:
     """
     Context manager for FP8 usage.
