Skip to content

Commit db2b552

Browse files
author
Alexandre de Siqueira
authored
Refactoring hash-dependent functions (#63)
* Fixing PermissionError on NamedTempFile * Revert "Fixing PermissionError on NamedTempFile" This reverts commit d6eb146. * Improving how to deal with hashes
1 parent fc56abd commit db2b552

7 files changed

+36
-98
lines changed

butterfly/connection.py

Lines changed: 36 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from pathlib import Path
22
from pooch import retrieve
3+
from urllib import request
34

5+
import hashlib
46
import socket
57

68

@@ -12,15 +14,9 @@
1214
}
1315

1416
URL_HASH = {
15-
'id_gender' : 'https://gitlab.com/alexdesiqueira/mothra-models/-/raw/main/models/id_gender/.SHA256SUM_ONLINE-id_gender',
16-
'id_position' : 'https://gitlab.com/alexdesiqueira/mothra-models/-/raw/main/models/id_position/.SHA256SUM_ONLINE-id_position',
17-
'segmentation' : 'https://gitlab.com/alexdesiqueira/mothra-models/-/raw/main/models/segmentation/.SHA256SUM_ONLINE-segmentation'
18-
}
19-
20-
LOCAL_HASH = {
21-
'id_gender' : Path('./models/SHA256SUM-id_gender'),
22-
'id_position' : Path('./models/SHA256SUM-id_position'),
23-
'segmentation' : Path('./models/SHA256SUM-segmentation')
17+
'id_gender' : 'https://gitlab.com/alexdesiqueira/mothra-models/-/raw/main/models/id_gender/SHA256SUM-id_gender',
18+
'id_position' : 'https://gitlab.com/alexdesiqueira/mothra-models/-/raw/main/models/id_position/SHA256SUM-id_position',
19+
'segmentation' : 'https://gitlab.com/alexdesiqueira/mothra-models/-/raw/main/models/segmentation/SHA256SUM-segmentation'
2420
}
2521

2622

@@ -39,39 +35,8 @@ def _get_model_info(weights):
3935
URL of the file for the latest model.
4036
url_hash : str
4137
URL of the hash file for the latest model.
42-
local_hash : pathlib.Path
43-
Path of the local hash file.
44-
"""
45-
return (URL_MODEL.get(weights.stem), URL_HASH.get(weights.stem),
46-
LOCAL_HASH.get(weights.stem))
47-
48-
49-
def _check_hashes(weights):
50-
"""Helping function. Downloads hashes for `weights` if they are not
51-
present.
52-
53-
Parameters
54-
----------
55-
weights : str or pathlib.Path
56-
Path of the file containing weights.
57-
58-
Returns
59-
-------
60-
None
6138
"""
62-
_, url_hash, local_hash = _get_model_info(weights)
63-
64-
if not local_hash.is_file():
65-
download_hash_from_url(url_hash=url_hash, filename=local_hash)
66-
67-
# creating filename to save url_hash.
68-
filename = local_hash.parent/Path(url_hash).name
69-
70-
if not filename.is_file():
71-
download_hash_from_url(url_hash=url_hash, filename=filename)
72-
73-
74-
return None
39+
return (URL_MODEL.get(weights.stem), URL_HASH.get(weights.stem))
7540

7641

7742
def download_weights(weights):
@@ -86,9 +51,7 @@ def download_weights(weights):
8651
-------
8752
None
8853
"""
89-
# check if hashes are in disk, then get info from the model.
90-
_check_hashes(weights)
91-
_, url_hash, local_hash = _get_model_info(weights)
54+
_, url_hash = _get_model_info(weights)
9255

9356
# check if weights is in its folder. If not, download the file.
9457
if not weights.is_file():
@@ -97,34 +60,15 @@ def download_weights(weights):
9760
# file exists: check if we have the last version; download if not.
9861
else:
9962
if has_internet():
100-
local_hash_val = read_hash_local(filename=local_hash)
101-
url_hash_val = read_hash_from_url(path=local_hash.parent,
102-
url_hash=url_hash)
63+
local_hash_val = read_hash_local(weights)
64+
url_hash_val = read_hash_from_url(url_hash)
10365
if local_hash_val != url_hash_val:
10466
print('New training data available. Downloading...')
10567
fetch_data(weights)
10668

10769
return None
10870

10971

110-
def download_hash_from_url(url_hash, filename):
111-
"""Downloads hash from `url_hash`.
112-
113-
Parameters
114-
----------
115-
url_hash : str
116-
URL of the SHA256 hash.
117-
filename : str
118-
Filename to save the SHA256 hash locally.
119-
120-
Returns
121-
-------
122-
None
123-
"""
124-
retrieve(url=url_hash, known_hash=None, fname=filename, path='.')
125-
return None
126-
127-
12872
def fetch_data(weights):
12973
"""Downloads and checks the hash of `weights`, according to its filename.
13074
@@ -137,15 +81,11 @@ def fetch_data(weights):
13781
-------
13882
None
13983
"""
140-
url_model, url_hash, local_hash = _get_model_info(weights)
84+
url_model, url_hash = _get_model_info(weights)
14185

142-
# creating filename to save url_hash.
143-
filename = local_hash.parent/Path(url_hash).name
144-
145-
download_hash_from_url(url_hash=url_hash, filename=filename)
146-
local_hash_val = read_hash_local(local_hash)
86+
url_hash_val = read_hash_from_url(url_hash)
14787
retrieve(url=url_model,
148-
known_hash=f'sha256:{local_hash_val}',
88+
known_hash=f'sha256:{url_hash_val}',
14989
fname=weights,
15090
path='.')
15191

@@ -166,40 +106,44 @@ def has_internet():
166106
return socket.gethostbyname(socket.gethostname()) != '127.0.0.1'
167107

168108

169-
def read_hash_local(filename):
170-
"""Reads local SHA256 hash file.
109+
def read_hash_local(weights):
110+
"""Reads local SHA256 hash from weights.
171111
172112
Parameters
173113
----------
174-
filename : pathlib.Path
175-
Path of the hash file.
114+
weights : str or pathlib.Path
115+
Path of the file containing weights.
176116
177117
Returns
178118
-------
179-
local_hash : str
180-
SHA256 hash.
119+
local_hash : str or None
120+
SHA256 hash of weights file.
181121
182122
Notes
183123
-----
184124
Returns None if file is not found.
185125
"""
126+
BUFFER_SIZE = 65536
127+
sha256 = hashlib.sha256()
128+
186129
try:
187-
with open(filename, 'r') as file_hash:
188-
hashes = [line for line in file_hash]
189-
# expecting only one hash, and not interested in the filename:
190-
local_hash, _ = hashes[0].split()
130+
with open(weights, 'rb') as file_weights:
131+
while True:
132+
data = file_weights.read(BUFFER_SIZE)
133+
if not data:
134+
break
135+
sha256.update(data)
136+
local_hash = sha256.hexdigest()
191137
except FileNotFoundError:
192138
local_hash = None
193139
return local_hash
194140

195141

196-
def read_hash_from_url(path, url_hash):
197-
"""Downloads and returns the SHA256 hash online for the file in `url_hash`.
142+
def read_hash_from_url(url_hash):
143+
"""Returns the SHA256 hash online for the file in `url_hash`.
198144
199145
Parameters
200146
----------
201-
path : str
202-
Where to look for the hash file.
203147
url_hash : str
204148
URL of the hash file for the latest model.
205149
@@ -208,14 +152,14 @@ def read_hash_from_url(path, url_hash):
208152
online_hash : str
209153
SHA256 hash for the file in `url_hash`.
210154
"""
211-
filename = Path(url_hash).name
212-
latest_hash = Path(f'{path}/{filename}')
155+
user_agent = 'Mozilla/5.0 (Windows; U; Windows NT 5.1; en-US; rv:1.9.0.7) Gecko/2009021910 Firefox/3.0.7'
156+
headers = {'User-Agent':user_agent,}
213157

214-
download_hash_from_url(url_hash=url_hash, filename=filename)
215-
with open(latest_hash, 'r') as file_hash:
216-
hashes = [line for line in file_hash]
158+
aux_req = request.Request(url_hash, None, headers)
159+
response = request.urlopen(aux_req)
160+
hashes = response.read()
217161

218162
# expecting only one hash, and not interested in the filename:
219-
online_hash, _ = hashes[0].split()
163+
online_hash, _ = hashes.decode('ascii').split()
220164

221165
return online_hash

models/.SHA256SUM_ONLINE-id_gender

Lines changed: 0 additions & 1 deletion
This file was deleted.

models/.SHA256SUM_ONLINE-id_position

Lines changed: 0 additions & 1 deletion
This file was deleted.

models/.SHA256SUM_ONLINE-segmentation

Lines changed: 0 additions & 1 deletion
This file was deleted.

models/SHA256SUM-id_gender

Lines changed: 0 additions & 1 deletion
This file was deleted.

models/SHA256SUM-id_position

Lines changed: 0 additions & 1 deletion
This file was deleted.

models/SHA256SUM-segmentation

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments (0)