Skip to content

Commit e7f754e

Browse files
committed
use tempfile to avoid name conflict
1 parent e184579 commit e7f754e

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

ml_logger/VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.8.32
1+
0.8.33

ml_logger/ml_logger/ml_logger.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,14 +1605,16 @@ def download_file(self, *keys, path=None, to, relative=False):
16051605
f.write(buf.getbuffer())
16061606

16071607
def load_torch(self, *keys, path=None, map_location=None, **kwargs):
1608-
import torch
1608+
import torch, tempfile
16091609
path = pJoin(*keys, path)
16101610
if path.lower().startswith('s3://'):
1611-
fn_or_buff = os.path.basename(path)
1612-
self.download_s3(path[5:], to=fn_or_buff)
1611+
postfix = os.path.basename(path)
1612+
with tempfile.NamedTemporaryFile(suffix=f'.{postfix}') as ntp:
1613+
self.download_s3(path[5:], to=ntp.name)
1614+
return torch.load(ntp, map_location=map_location, **kwargs)
16131615
else:
16141616
fn_or_buff = self.load_file(path)
1615-
return torch.load(fn_or_buff, map_location=map_location, **kwargs)
1617+
return torch.load(fn_or_buff, map_location=map_location, **kwargs)
16161618

16171619
torch_load = load_torch
16181620

0 commit comments

Comments
 (0)