Skip to content

Commit 3a33a6d

Browse files
committed
Fixed bug where not all mlflow system tags were being exported
1 parent 162bb5e commit 3a33a6d

File tree

2 files changed

+27
-17
lines changed

2 files changed

+27
-17
lines changed

mlflow_export_import/run/export_run.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ def export_run(self, run_id, output_dir):
5151
fs = _filesystem.get_filesystem(output_dir)
5252
run = self.mlflow_client.get_run(run_id)
5353
fs.mkdirs(output_dir)
54-
tags = utils.create_tags_for_metadata(self.mlflow_client, run, self.export_source_tags)
54+
tags = utils.create_source_tags(self.mlflow_client, run, self.export_source_tags)
5555
dct = {
5656
"export_info": {
5757
"mlflow_version": mlflow.__version__,

mlflow_export_import/utils.py

Lines changed: 26 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -12,51 +12,50 @@
1212
TAG_PREFIX_MLFLOW = "mlflow."
1313
TAG_PARENT_ID = "mlflow.parentRunId"
1414

15+
1516
# Databricks tags that cannot be set
1617
_databricks_skip_tags = set([
1718
"mlflow.user",
1819
"mlflow.log-model.history",
1920
"mlflow.rootRunId"
2021
])
2122

23+
2224
def create_mlflow_tags_for_databricks_import(tags):
2325
if importing_into_databricks():
2426
tags = { k:v for k,v in tags.items() if not k in _databricks_skip_tags }
2527
return tags
2628

27-
def _create_metadata_tags(src_client, tags, run):
28-
tags[TAG_PREFIX_EXPORT_IMPORT_METADATA+".mlflow_version"] = mlflow.__version__
29+
30+
def _create_source_tags(src_client, tags, run):
31+
tags[f"{TAG_PREFIX_EXPORT_IMPORT_METADATA}.mlflow_version"] = mlflow.__version__
2932
uri = mlflow.tracking.get_tracking_uri()
30-
tags[TAG_PREFIX_EXPORT_IMPORT_METADATA+".tracking_uri"] = uri
33+
tags[f"{TAG_PREFIX_EXPORT_IMPORT_METADATA}.tracking_uri"] = uri
3134
dbx_host = os.environ.get("DATABRICKS_HOST",None)
3235
if dbx_host is not None:
33-
tags[TAG_PREFIX_EXPORT_IMPORT_METADATA+".DATABRICKS_HOST"] = dbx_host
34-
now = int(time.time()+.5)
36+
tags[f"{TAG_PREFIX_EXPORT_IMPORT_METADATA}.DATABRICKS_HOST"] = dbx_host
37+
now = round(time.time())
3538
snow = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(now))
36-
tags[TAG_PREFIX_EXPORT_IMPORT_METADATA+".timestamp"] = str(now)
37-
tags[TAG_PREFIX_EXPORT_IMPORT_METADATA+".timestamp_nice"] = snow
39+
tags[f"{TAG_PREFIX_EXPORT_IMPORT_METADATA}.timestamp"] = str(now)
40+
tags[f"{TAG_PREFIX_EXPORT_IMPORT_METADATA}.timestamp_nice"] = snow
3841
exp = src_client.get_experiment(run.info.experiment_id)
39-
tags[TAG_PREFIX_EXPORT_IMPORT_METADATA+".experiment_name"] = exp.name
42+
tags[f"{TAG_PREFIX_EXPORT_IMPORT_METADATA}.experiment_name"] = exp.name
43+
4044

41-
def create_tags_for_metadata(src_client, run, export_source_tags):
45+
def create_source_tags(src_client, run, export_source_tags):
4246
""" Create destination tags from source run """
4347
mlflow_system_tags = { k:v for k,v in run.data.tags.items() if k.startswith(TAG_PREFIX_MLFLOW) }
4448
tags = run.data.tags.copy()
45-
if importing_into_databricks():
46-
for k in _databricks_skip_tags:
47-
tags.pop(k, None)
4849
if export_source_tags:
49-
_create_metadata_tags(src_client, tags, run)
50-
50+
_create_source_tags(src_client, tags, run)
5151
for k,v in strip_underscores(run.info).items():
5252
tags[f"{TAG_PREFIX_EXPORT_IMPORT_RUN_INFO}.{k}"] = str(v) # NOTE: tag values must be strings
53-
5453
for k,v in mlflow_system_tags.items():
5554
tags[k.replace(TAG_PREFIX_MLFLOW,TAG_PREFIX_EXPORT_IMPORT_MLFLOW+".")] = v
56-
5755
tags = { k:v for k,v in sorted(tags.items()) }
5856
return tags
5957

58+
6059
def set_dst_user_id(tags, user_id, use_src_user_id):
6160
if importing_into_databricks():
6261
return
@@ -65,35 +64,43 @@ def set_dst_user_id(tags, user_id, use_src_user_id):
6564
user_id = user_id if use_src_user_id else get_user_id()
6665
tags.append(RunTag(MLFLOW_USER,user_id ))
6766

67+
6868
def get_now_nice():
6969
now = int(time.time()+.5)
7070
return time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(now))
7171

72+
7273
def strip_underscores(obj):
7374
return { k[1:]:v for (k,v) in obj.__dict__.items() }
7475

76+
7577
def write_json_file(fs, path, dct):
7678
fs.write(path, json.dumps(dct,indent=2)+"\n")
7779

80+
7881
def write_file(path, content):
7982
with open(mk_local_path(path), 'wb') as f:
8083
f.write(content)
8184

85+
8286
def read_json_file(path):
8387
with open(mk_local_path(path), "r") as f:
8488
return json.loads(f.read())
8589

90+
8691
def string_to_list(list_as_string):
8792
if list_as_string == None:
8893
return []
8994
lst = list_as_string.split(",")
9095
if "" in lst: lst.remove("")
9196
return lst
9297

98+
9399
def get_user_id():
94100
from mlflow.tracking.context.default_context import _get_user
95101
return _get_user()
96102

103+
97104
def nested_tags(dst_client, run_ids_mapping):
98105
"""
99106
Set the new parentRunId for new imported child runs.
@@ -105,9 +112,11 @@ def nested_tags(dst_client, run_ids_mapping):
105112
dst_parent_run_id = run_ids_mapping[src_parent_run_id]["dst_run_id"]
106113
dst_client.set_tag(dst_run_id, "mlflow.parentRunId", dst_parent_run_id)
107114

115+
108116
def importing_into_databricks():
109117
return mlflow.tracking.get_tracking_uri().startswith("databricks")
110118

119+
111120
def create_common_manifest(duration):
112121
return {
113122
"info": {
@@ -118,6 +127,7 @@ def create_common_manifest(duration):
118127
}
119128
}
120129

130+
121131
def show_table(title, lst, columns):
122132
print(title)
123133
df = pd.DataFrame(lst, columns = columns)

0 commit comments

Comments
 (0)