Skip to content

Commit b3c171f

Browse files
committed
Up
Signed-off-by: Peter Jung <[email protected]>
1 parent 81089da commit b3c171f

File tree

6 files changed

+81
-78
lines changed

6 files changed

+81
-78
lines changed

mlflow_export_import/common/dump_run.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,6 @@
99
INDENT = " "
1010
MAX_LEVEL = 1
1111
TS_FORMAT = "%Y-%m-%d_%H:%M:%S"
12-
client = mlflow.tracking.MlflowClient()
13-
print("MLflow Tracking URI:", mlflow.get_tracking_uri())
1412

1513
def dump_run(run, max_level=1, indent=""):
1614
dump_run_info(run.info,indent)
@@ -27,17 +25,19 @@ def dump_run(run, max_level=1, indent=""):
2725
num_bytes, num_artifacts = dump_artifacts(run.info.run_id, "", 0, max_level, indent+INDENT)
2826
print(f"{indent}Total: bytes: {num_bytes} artifacts: {num_artifacts}")
2927
return run, num_bytes, num_artifacts
30-
28+
3129
def dump_run_id(run_id, max_level=1, indent=""):
30+
client = mlflow.tracking.MlflowClient()
3231
run = client.get_run(run_id)
3332
return dump_run(run,max_level,indent)
3433

3534
def dump_run_info(info, indent=""):
3635
print("{}RunInfo:".format(indent))
36+
client = mlflow.tracking.MlflowClient()
3737
exp = client.get_experiment(info.experiment_id)
3838
if exp is None:
3939
print(f"ERROR: Cannot find experiment ID '{info.experiment_id}'")
40-
return
40+
return
4141
print("{} name: {}".format(indent,exp.name))
4242
for k,v in sorted(info.__dict__.items()):
4343
if not k.endswith("_time"):
@@ -58,8 +58,9 @@ def _dump_time(info, k, indent=""):
5858
return v
5959

6060
def dump_artifacts(run_id, path, level, max_level, indent):
61-
if level+1 > max_level:
61+
if level+1 > max_level:
6262
return 0,0
63+
client = mlflow.tracking.MlflowClient()
6364
artifacts = client.list_artifacts(run_id,path)
6465
num_bytes, num_artifacts = (0,0)
6566
for j,art in enumerate(artifacts):

mlflow_export_import/common/find_artifacts.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -7,16 +7,15 @@
77
import click
88
import mlflow
99

10-
client = mlflow.tracking.MlflowClient()
11-
print("MLflow Tracking URI:", mlflow.get_tracking_uri())
1210

1311
def find_artifacts(run_id, path, target, max_level=sys.maxsize):
1412
return _find_artifacts(run_id, path, target, max_level, 0, [])
1513

1614
def _find_artifacts(run_id, path, target, max_level, level, matches):
17-
if level+1 > max_level:
15+
if level+1 > max_level:
1816
return matches
19-
artifacts = client.list_artifacts(run_id,path)
17+
client = mlflow.tracking.MlflowClient()
18+
artifacts = client.list_artifacts(run_id, path)
2019
for art in artifacts:
2120
#print(f"art_path: {art.path}")
2221
filename = os.path.basename(art.path)
@@ -40,5 +39,5 @@ def main(run_id, path, target, max_level): # pragma: no cover
4039
for x in matches:
4140
print(" ",x)
4241

43-
if __name__ == "__main__":
42+
if __name__ == "__main__":
4443
main()

mlflow_export_import/model/export_model.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,18 +11,19 @@
1111
from mlflow_export_import import utils, click_doc
1212

1313
class ModelExporter():
14-
def __init__(self, mlflow_client, export_source_tags=False, notebook_formats=None, stages=None, versions=None, export_run=True):
14+
def __init__(self, mlflow_client, export_source_tags=False, notebook_formats=None, stages=None, versions=None, export_run=True, host=None):
1515
"""
1616
:param mlflow_client: MLflow client or if None create default client.
1717
:param export_source_tags: Export source run metadata tags.
1818
:param notebook_formats: List of notebook formats to export. Values are SOURCE, HTML, JUPYTER or DBC.
1919
:param stages: Stages to export. Default is all stages. Values are Production, Staging, Archived and None.
2020
:param versions: Versions to export. Default is all versions. Values are valid integer numbers.
2121
:param export_run: Export the run that generated a registered model's version.
22+
:param host: Pass host to the MlflowHttpClient and RunExporter.
2223
"""
2324
self.mlflow_client = mlflow_client
24-
self.http_client = MlflowHttpClient()
25-
self.run_exporter = RunExporter(self.mlflow_client, export_source_tags=export_source_tags, notebook_formats=notebook_formats)
25+
self.http_client = MlflowHttpClient(host=host)
26+
self.run_exporter = RunExporter(self.mlflow_client, export_source_tags=export_source_tags, notebook_formats=notebook_formats, host=host)
2627
self.stages = self._normalize_stages(stages)
2728
self.versions = self._normalize_versions(versions)
2829
self.export_run = export_run
@@ -66,7 +67,7 @@ def _export_model(self, model_name, output_dir):
6667
run = self.mlflow_client.get_run(run_id)
6768
dct = dict(vr)
6869
dct["_run_artifact_uri"] = run.info.artifact_uri
69-
experiment = mlflow.get_experiment(run.info.experiment_id)
70+
experiment = self.mlflow_client.get_experiment(run.info.experiment_id)
7071
dct["_experiment_name"] = experiment.name
7172
model["registered_model"]["latest_versions"].append(dct)
7273
exported_versions += 1

mlflow_export_import/model/import_model.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -13,14 +13,15 @@
1313

1414
class BaseModelImporter():
1515
""" Base class of ModelImporter subclasses. """
16-
def __init__(self, mlflow_client, run_importer=None, await_creation_for=None):
16+
def __init__(self, mlflow_client, run_importer=None, await_creation_for=None, host=None):
1717
"""
1818
:param mlflow_client: MLflow client or if None create default client.
1919
:param run_importer: RunImporter instance.
2020
:param await_creation_for: Seconds to wait for model version creation.
21+
:param host: Pass host to the RunImporter.
2122
"""
2223
self.mlflow_client = mlflow_client
23-
self.run_importer = run_importer if run_importer else RunImporter(self.mlflow_client, mlmodel_fix=True)
24+
self.run_importer = run_importer if run_importer else RunImporter(self.mlflow_client, mlmodel_fix=True, host=host)
2425
self.await_creation_for = await_creation_for
2526

2627
def _import_version(self, model_name, src_vr, dst_run_id, dst_source, sleep_time):
@@ -77,27 +78,29 @@ def _import_model(self, model_name, input_dir, delete_model=False, verbose=False
7778

7879
class ModelImporter(BaseModelImporter):
7980
""" Low-level 'point' model importer """
80-
def __init__(self, mlflow_client, run_importer=None, await_creation_for=None):
81-
super().__init__(mlflow_client, run_importer, await_creation_for=await_creation_for)
81+
def __init__(self, mlflow_client, run_importer=None, await_creation_for=None, host=None):
82+
super().__init__(mlflow_client, run_importer, await_creation_for=await_creation_for, host=host)
8283

8384
def import_model(self, model_name, input_dir, experiment_name, delete_model=False, verbose=False, sleep_time=30):
8485
"""
8586
:param model_name: Model name.
8687
:param input_dir: Input directory.
87-
:param experiment_name: The name of the experiment
88+
:param experiment_name: The name of the experiment.
8889
:param delete_model: Delete current model before importing versions.
8990
:param verbose: Verbose.
9091
:param sleep_time: Seconds to wait for model version creation.
9192
:return: Model import manifest.
9293
"""
9394
model_dct = self._import_model(model_name, input_dir, delete_model, verbose, sleep_time)
94-
mlflow.set_experiment(experiment_name)
9595
print("Importing versions:")
96+
imported_run_ids = []
9697
for vr in model_dct["latest_versions"]:
9798
run_id = self._import_run(input_dir, experiment_name, vr)
99+
imported_run_ids.append(run_id)
98100
self.import_version(model_name, vr, run_id, sleep_time)
99101
if verbose:
100102
model_utils.dump_model_versions(self.mlflow_client, model_name)
103+
return imported_run_ids
101104

102105
def _import_run(self, input_dir, experiment_name, vr):
103106
run_id = vr["run_id"]
@@ -149,7 +152,6 @@ def import_model(self, model_name, input_dir, delete_model=False, verbose=False,
149152
for vr in model_dct["latest_versions"]:
150153
src_run_id = vr["run_id"]
151154
dst_run_id = self.run_info_map[src_run_id].run_id
152-
mlflow.set_experiment(vr["_experiment_name"])
153155
self.import_version(model_name, vr, dst_run_id, sleep_time)
154156
if verbose:
155157
model_utils.dump_model_versions(self.mlflow_client, model_name)

mlflow_export_import/run/export_run.py

Lines changed: 22 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
"""
1+
"""
22
Exports a run to a directory.
33
"""
44

@@ -14,20 +14,19 @@
1414
from mlflow_export_import.common import MlflowExportImportException
1515
from mlflow_export_import import utils, click_doc
1616

17-
print("MLflow Version:", mlflow.version.VERSION)
18-
print("MLflow Tracking URI:", mlflow.get_tracking_uri())
1917

2018
class RunExporter:
21-
def __init__(self, mlflow_client, export_source_tags=False, notebook_formats=None):
19+
def __init__(self, mlflow_client, export_source_tags=False, notebook_formats=None, host=None):
2220
"""
2321
:param mlflow_client: MLflow client.
2422
:param export_source_tags: Export source run metadata tags.
2523
:param notebook_formats: List of notebook formats to export. Values are SOURCE, HTML, JUPYTER or DBC.
24+
:param host: Pass host to the DatabricksHttpClient.
2625
"""
2726
if notebook_formats is None:
2827
notebook_formats = []
2928
self.mlflow_client = mlflow_client
30-
self.dbx_client = DatabricksHttpClient()
29+
self.dbx_client = DatabricksHttpClient(host=host)
3130
print("Databricks REST client:", self.dbx_client)
3231
self.export_source_tags = export_source_tags
3332
self.notebook_formats = notebook_formats
@@ -38,7 +37,7 @@ def _get_metrics_with_steps(self, run):
3837
metric_history = self.mlflow_client.get_metric_history(run.info.run_id,metric)
3938
lst = [utils.strip_underscores(m) for m in metric_history]
4039
for x in lst:
41-
del x["key"]
40+
del x["key"]
4241
metrics_with_steps[metric] = lst
4342
return metrics_with_steps
4443

@@ -91,8 +90,8 @@ def _export_notebook(self, output_dir, notebook, tags, fs):
9190
revision_id = tags["mlflow.databricks.notebookRevisionID"]
9291
notebook_path = tags["mlflow.databricks.notebookPath"]
9392
notebook_name = os.path.basename(notebook_path)
94-
manifest = {
95-
"mlflow.databricks.notebookRevisionID": revision_id,
93+
manifest = {
94+
"mlflow.databricks.notebookRevisionID": revision_id,
9695
"mlflow.databricks.notebookPath": notebook_path,
9796
"mlflow.databricks.export-notebook-revision": revision_id }
9897
path = os.path.join(notebook_dir, "manifest.json")
@@ -101,11 +100,11 @@ def _export_notebook(self, output_dir, notebook, tags, fs):
101100
self._export_notebook_format(notebook_dir, notebook, format, format.lower(), notebook_name, revision_id)
102101

103102
def _export_notebook_format(self, notebook_dir, notebook, format, extension, notebook_name, revision_id):
104-
params = {
105-
"path": notebook,
103+
params = {
104+
"path": notebook,
106105
"direct_download": True,
107106
"format": format,
108-
"revision_timestamp": revision_id
107+
"revision_timestamp": revision_id
109108
}
110109
try:
111110
rsp = self.dbx_client._get("workspace/export", params)
@@ -115,26 +114,26 @@ def _export_notebook_format(self, notebook_dir, notebook, format, extension, not
115114
print(f"WARNING: Cannot save notebook '{notebook}'. {e}")
116115

117116
@click.command()
118-
@click.option("--run-id",
119-
help="Run ID.",
117+
@click.option("--run-id",
118+
help="Run ID.",
120119
type=str,
121120
required=True
122121
)
123-
@click.option("--output-dir",
124-
help="Output directory.",
122+
@click.option("--output-dir",
123+
help="Output directory.",
125124
type=str,
126125
required=True
127126
)
128-
@click.option("--export-source-tags",
129-
help=click_doc.export_source_tags,
130-
type=bool,
131-
default=False,
127+
@click.option("--export-source-tags",
128+
help=click_doc.export_source_tags,
129+
type=bool,
130+
default=False,
132131
show_default=True
133132
)
134-
@click.option("--notebook-formats",
135-
help=click_doc.notebook_formats,
133+
@click.option("--notebook-formats",
134+
help=click_doc.notebook_formats,
136135
type=str,
137-
default="",
136+
default="",
138137
show_default=True
139138
)
140139

@@ -145,7 +144,7 @@ def main(run_id, output_dir, export_source_tags, notebook_formats):
145144
client = mlflow.tracking.MlflowClient()
146145
exporter = RunExporter(
147146
client,
148-
export_source_tags=export_source_tags,
147+
export_source_tags=export_source_tags,
149148
notebook_formats=utils.string_to_list(notebook_formats))
150149
exporter.export_run(run_id, output_dir)
151150

0 commit comments

Comments
 (0)