Skip to content

Commit

Permalink
Merge branch 'release/0.3.9'
Browse files Browse the repository at this point in the history
  • Loading branch information
ControlNet committed Feb 18, 2024
2 parents 06becfb + 2ee59e4 commit b74131a
Show file tree
Hide file tree
Showing 18 changed files with 223 additions and 15 deletions.
3 changes: 2 additions & 1 deletion requirements_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ numpy >= 1.20.1
einops >= 0.3.0
matplotlib >= 3.4.2
pyyaml == 6.0
toml >= 0.10.0
toml >= 0.10.0
safetensors >= 0.3.1
11 changes: 8 additions & 3 deletions src/tensorneko/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
import os.path

from tensorneko_util import io
from . import backend
from . import callback
from . import dataset
from . import debug
from . import evaluation
from . import io
from . import layer
from . import module
from . import notebook
from . import optim
from . import preprocess
from . import util
from . import visualization
from .io import read, write
from .neko_model import NekoModel
from .neko_module import NekoModule
from .neko_trainer import NekoTrainer

__version__ = io.read.text(os.path.join(util.get_tensorneko_path(), "version.txt"))

__all__ = [
"callback",
"dataset",
Expand All @@ -33,7 +36,9 @@
"debug",
"NekoModel",
"NekoTrainer",
"NekoModule"
"NekoModule",
"read",
"write",
]

__version__ = io.read.text(os.path.join(util.get_tensorneko_path(), "version.txt"))

4 changes: 3 additions & 1 deletion src/tensorneko/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from tensorneko_util.backend import parallel, run_blocking, VisualLib, AudioLib
from tensorneko_util.backend import parallel, run_blocking, VisualLib, AudioLib, import_tqdm_auto, import_tqdm

__all__ = [
"parallel",
"run_blocking",
"VisualLib",
"AudioLib",
"import_tqdm_auto",
"import_tqdm",
]
3 changes: 2 additions & 1 deletion src/tensorneko/evaluation/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
from torchvision.transforms.functional import resize

from tensorneko_util.backend import VisualLib
from tensorneko_util.backend._tqdm import import_tqdm_auto
from tensorneko_util.backend.tqdm import import_tqdm_auto

try:
from typing import Literal
TypeOption = Literal["video", "image"]
except ImportError:
Literal = None
TypeOption = str


Expand Down
2 changes: 2 additions & 0 deletions src/tensorneko/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Type

from tensorneko_util.io.reader import Reader as BaseReader
from .weight import WeightReader

try:
from .mesh import MeshReader
Expand All @@ -15,6 +16,7 @@ class Reader(BaseReader):

def __init__(self):
super().__init__()
self.weight = WeightReader
self._mesh = None

@property
Expand Down
2 changes: 2 additions & 0 deletions src/tensorneko/io/weight/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .weight_reader import WeightReader
from .weight_writer import WeightWriter
81 changes: 81 additions & 0 deletions src/tensorneko/io/weight/weight_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import OrderedDict

import torch

from ...util import Device


class WeightReader:
"""WeightReader for read model weights (checkpoints, state_dict, etc)."""

@classmethod
def of_pt(cls, path: str, map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]:
"""
Reads PyTorch model weights from a `.pt` or `.pth` file.
Args:
path (``str``): The path of the `.pt` or `.pth` file.
map_location (:class:`torch.device` | ``str``): The location to load the model weights. Default: "cpu"
Returns:
:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights.
"""
return torch.load(path, map_location=map_location)

@classmethod
def of_ckpt(cls, path: str, map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]:
"""
Reads PyTorch model weights from a `.ckpt` file.
Args:
path (``str``): The path of the `.ckpt` file.
map_location (:class:`torch.device` | ``str``): The location to load the model weights. Default: "cpu"
Returns:
:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights.
"""
return torch.load(path, map_location=map_location)["state_dict"]

@classmethod
def of_safetensors(cls, path: str, map_location: str = "cpu") -> OrderedDict[str, torch.Tensor]:
"""
Reads model weights from a `.safetensors` file.
Args:
path (``str``): The path of the `.safetensors` file.
map_location (``str``): The location to load the model weights. Default: "cpu"
Returns:
:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights.
"""
import safetensors
from collections import OrderedDict
tensors = OrderedDict()
with safetensors.safe_open(path, framework="pt", device=map_location) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
return tensors

@classmethod
def of(cls, path: str, map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]:
"""
Reads model weights from a file.
Args:
path (``str``): The path of the file.
map_location (:class:`torch.device` | ``str``): The location to load the model weights. Default: "cpu"
Returns:
:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights.
"""

if path.endswith(".pt") or path.endswith(".pth"):
return cls.of_pt(path, map_location)
elif path.endswith(".ckpt"):
return cls.of_ckpt(path, map_location)
elif path.endswith(".safetensors"):
if isinstance(map_location, torch.device):
map_location = str(map_location)
return cls.of_safetensors(path, map_location)
else:
raise ValueError("Unknown file type. Supported types: .pt, .pth, .ckpt, .safetensors")
48 changes: 48 additions & 0 deletions src/tensorneko/io/weight/weight_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Dict

import torch


class WeightWriter:
"""WeightWriter for write model weights (checkpoints, state_dict, etc)."""

@classmethod
def to_pt(cls, path: str, weights: Dict[str, torch.Tensor]) -> None:
"""
Writes PyTorch model weights to a `.pt` or `.pth` file.
Args:
path (``str``): The path of the `.pt` or `.pth` file.
weights (:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]): The model weights.
"""
torch.save(weights, path)

@classmethod
def to_safetensors(cls, path: str, weights: Dict[str, torch.Tensor]) -> None:
"""
Writes model weights to a `.safetensors` file.
Args:
path (``str``): The path of the `.safetensors` file.
weights (:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]): The model weights.
"""
import safetensors.torch
safetensors.torch.save_file(weights, path)

@classmethod
def to(cls, path: str, weights: Dict[str, torch.Tensor]) -> None:
"""
Writes model weights to a file.
Args:
path (``str``): The path of the file.
weights (:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]): The model weights.
"""
file_type = path.split(".")[-1]

if file_type == "pt":
cls.to_pt(path, weights)
elif file_type == "safetensors":
cls.to_safetensors(path, weights)
else:
raise ValueError("Unknown file type. Supported types: .pt, .safetensors")
2 changes: 2 additions & 0 deletions src/tensorneko/io/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Type

from tensorneko_util.io.writer import Writer as BaseWriter
from .weight import WeightWriter

try:
from .mesh import MeshWriter
Expand All @@ -16,6 +17,7 @@ class Writer(BaseWriter):

def __init__(self):
super().__init__()
self.weight = WeightWriter
self._mesh = None

@property
Expand Down
2 changes: 1 addition & 1 deletion src/tensorneko/layer/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class GaussianNoise(NekoModule):
from https://discuss.pytorch.org/t/writing-a-simple-gaussian-noise-layer-in-pytorch/4694/3
"""

def __init__(self, sigma=0.1, device: Union[Device, str] = "cuda"):
def __init__(self, sigma=0.1, device: Device = "cuda"):
super().__init__()
self.sigma = sigma
self.noise = torch.tensor(0.).to(device)
Expand Down
2 changes: 1 addition & 1 deletion src/tensorneko/util/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"""The module builder type of ``() -> torch.nn.Module | (int) -> torch.nn.Module``"""


Device = device
Device = Union[str, device]
"""Device type of :class:`torch.device`"""


Expand Down
3 changes: 3 additions & 0 deletions src/tensorneko_util/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from .visual_lib import VisualLib
from .audio_lib import AudioLib
from .blocking import run_blocking
from .tqdm import import_tqdm_auto, import_tqdm
from . import parallel

__all__ = [
"VisualLib",
"AudioLib",
"run_blocking",
"import_tqdm_auto",
"import_tqdm",
"parallel",
]
File renamed without changes.
67 changes: 64 additions & 3 deletions src/tensorneko_util/io/json/json_reader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Union, Optional
from typing import Union, Optional, List

from ...util.type import T

Expand Down Expand Up @@ -52,6 +52,67 @@ def of(cls, path: str, clazz: Optional[T] = None, encoding: str = "UTF-8") -> Un

return obj


@classmethod
def of_jsonl(cls, path: str, clazz: Optional[T] = None, encoding: str = "UTF-8") -> List[Union[T, dict, list]]:
"""
Read jsonl files to ``list`` or ``dict``.
Args:
path (``str``): Jsonl file path.
clazz: (``T``, optional): The object of the jsonl read for. The type should be decorated by
:func:`json_data`. This should be ``T`` or ``List[T]`` or ``List[List[T]]``
encoding (``str``, optional): The encoding for python ``open`` function. Default: "UTF-8"
Returns:
``List[dict]`` | ``List[list]`` | ``List[T]``: The object of given jsonl.
"""
if clazz is not None:
if "typing.List[" in str(clazz):
with open(path, "r", encoding=encoding) as file:
obj: list = [json.loads(line) for line in file]

inner_type = clazz.__args__[0]
if "typing.List[" in str(inner_type):
inner_inner_type = inner_type.__args__[0]
try:
is_json_data = inner_type.is_json_data
except AttributeError:
is_json_data = False
if is_json_data:
obj = [[inner_inner_type(e) for e in each] for each in obj]

else:
try:
is_json_data = inner_type.is_json_data
except AttributeError:
is_json_data = False

if is_json_data:
obj = list(map(inner_type, obj))
else:
with open(path, "r", encoding=encoding) as file:
obj = clazz([json.loads(line) for line in file])
else:
with open(path, "r", encoding=encoding) as file:
obj = [json.loads(line) for line in file]

return obj

def __new__(cls, path: str, clazz: T = None, encoding: str = "UTF-8") -> Union[T, dict, list]:
"""Alias of :meth:`~tensorneko_util.io.json.json_reader.JsonReader.of`."""
return cls.of(path, clazz, encoding)
"""
Read json or jsonl file smartly.
Args:
path (``str``): Json or jsonl file path.
clazz: (``T``, optional): The object of the json read for. The type should be decorated by
:func:`json_data`. This should be ``T`` or ``List[T]`` or ``List[List[T]]``
encoding (``str``, optional): The encoding for python ``open`` function. Default: "UTF-8"
Returns:
``dict`` | ``list`` | ``object``: The object of given json or jsonl.
"""
if path.endswith(".jsonl"):
return cls.of_jsonl(path, clazz, encoding)
else:
return cls.of(path, clazz, encoding)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from numpy import ndarray

from ...backend._tqdm import import_tqdm_auto
from ...backend.tqdm import import_tqdm_auto


class AbstractFaceDetector(ABC):
Expand Down
2 changes: 1 addition & 1 deletion src/tensorneko_util/util/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional
from urllib.request import urlretrieve

from ..backend._tqdm import import_tqdm_auto
from ..backend.tqdm import import_tqdm_auto

try:
auto = import_tqdm_auto()
Expand Down
2 changes: 1 addition & 1 deletion src/tensorneko_util/util/fp/array/abstract_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ..monad.monad import Monad
from ...type import T, R
from ....backend._tqdm import import_tqdm_auto
from ....backend.tqdm import import_tqdm_auto
from ....backend.parallel import ParallelType


Expand Down
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.8
0.3.9

0 comments on commit b74131a

Please sign in to comment.