-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
18 changed files
with
223 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,10 @@ | ||
from tensorneko_util.backend import parallel, run_blocking, VisualLib, AudioLib | ||
from tensorneko_util.backend import parallel, run_blocking, VisualLib, AudioLib, import_tqdm_auto, import_tqdm | ||
|
||
__all__ = [ | ||
"parallel", | ||
"run_blocking", | ||
"VisualLib", | ||
"AudioLib", | ||
"import_tqdm_auto", | ||
"import_tqdm", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .weight_reader import WeightReader | ||
from .weight_writer import WeightWriter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from typing import OrderedDict | ||
|
||
import torch | ||
|
||
from ...util import Device | ||
|
||
|
||
class WeightReader: | ||
"""WeightReader for read model weights (checkpoints, state_dict, etc).""" | ||
|
||
@classmethod | ||
def of_pt(cls, path: str, map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]: | ||
""" | ||
Reads PyTorch model weights from a `.pt` or `.pth` file. | ||
Args: | ||
path (``str``): The path of the `.pt` or `.pth` file. | ||
map_location (:class:`torch.device` | ``str``): The location to load the model weights. Default: "cpu" | ||
Returns: | ||
:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights. | ||
""" | ||
return torch.load(path, map_location=map_location) | ||
|
||
@classmethod | ||
def of_ckpt(cls, path: str, map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]: | ||
""" | ||
Reads PyTorch model weights from a `.ckpt` file. | ||
Args: | ||
path (``str``): The path of the `.ckpt` file. | ||
map_location (:class:`torch.device` | ``str``): The location to load the model weights. Default: "cpu" | ||
Returns: | ||
:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights. | ||
""" | ||
return torch.load(path, map_location=map_location)["state_dict"] | ||
|
||
@classmethod | ||
def of_safetensors(cls, path: str, map_location: str = "cpu") -> OrderedDict[str, torch.Tensor]: | ||
""" | ||
Reads model weights from a `.safetensors` file. | ||
Args: | ||
path (``str``): The path of the `.safetensors` file. | ||
map_location (``str``): The location to load the model weights. Default: "cpu" | ||
Returns: | ||
:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights. | ||
""" | ||
import safetensors | ||
from collections import OrderedDict | ||
tensors = OrderedDict() | ||
with safetensors.safe_open(path, framework="pt", device=map_location) as f: | ||
for key in f.keys(): | ||
tensors[key] = f.get_tensor(key) | ||
return tensors | ||
|
||
@classmethod | ||
def of(cls, path: str, map_location: Device = "cpu") -> OrderedDict[str, torch.Tensor]: | ||
""" | ||
Reads model weights from a file. | ||
Args: | ||
path (``str``): The path of the file. | ||
map_location (:class:`torch.device` | ``str``): The location to load the model weights. Default: "cpu" | ||
Returns: | ||
:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]: The model weights. | ||
""" | ||
|
||
if path.endswith(".pt") or path.endswith(".pth"): | ||
return cls.of_pt(path, map_location) | ||
elif path.endswith(".ckpt"): | ||
return cls.of_ckpt(path, map_location) | ||
elif path.endswith(".safetensors"): | ||
if isinstance(map_location, torch.device): | ||
map_location = str(map_location) | ||
return cls.of_safetensors(path, map_location) | ||
else: | ||
raise ValueError("Unknown file type. Supported types: .pt, .pth, .ckpt, .safetensors") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from typing import Dict | ||
|
||
import torch | ||
|
||
|
||
class WeightWriter: | ||
"""WeightWriter for write model weights (checkpoints, state_dict, etc).""" | ||
|
||
@classmethod | ||
def to_pt(cls, path: str, weights: Dict[str, torch.Tensor]) -> None: | ||
""" | ||
Writes PyTorch model weights to a `.pt` or `.pth` file. | ||
Args: | ||
path (``str``): The path of the `.pt` or `.pth` file. | ||
weights (:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]): The model weights. | ||
""" | ||
torch.save(weights, path) | ||
|
||
@classmethod | ||
def to_safetensors(cls, path: str, weights: Dict[str, torch.Tensor]) -> None: | ||
""" | ||
Writes model weights to a `.safetensors` file. | ||
Args: | ||
path (``str``): The path of the `.safetensors` file. | ||
weights (:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]): The model weights. | ||
""" | ||
import safetensors.torch | ||
safetensors.torch.save_file(weights, path) | ||
|
||
@classmethod | ||
def to(cls, path: str, weights: Dict[str, torch.Tensor]) -> None: | ||
""" | ||
Writes model weights to a file. | ||
Args: | ||
path (``str``): The path of the file. | ||
weights (:class:`collections.OrderedDict`[``str``, :class:`torch.Tensor`]): The model weights. | ||
""" | ||
file_type = path.split(".")[-1] | ||
|
||
if file_type == "pt": | ||
cls.to_pt(path, weights) | ||
elif file_type == "safetensors": | ||
cls.to_safetensors(path, weights) | ||
else: | ||
raise ValueError("Unknown file type. Supported types: .pt, .safetensors") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,14 @@ | ||
from .visual_lib import VisualLib | ||
from .audio_lib import AudioLib | ||
from .blocking import run_blocking | ||
from .tqdm import import_tqdm_auto, import_tqdm | ||
from . import parallel | ||
|
||
__all__ = [ | ||
"VisualLib", | ||
"AudioLib", | ||
"run_blocking", | ||
"import_tqdm_auto", | ||
"import_tqdm", | ||
"parallel", | ||
] |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
0.3.8 | ||
0.3.9 |