-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1435cc2
commit 394663a
Showing
3 changed files
with
54 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
115 changes: 50 additions & 65 deletions
115
integrations/model-training/mlflow/mlflow-hello-world/mlflow-hello-world.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,86 +1,71 @@ | ||
# coding: utf-8 | ||
"""Trains and evaluate a simple MLP | ||
on the Reuters newswire topic classification task. | ||
""" | ||
from __future__ import print_function | ||
import os | ||
|
||
# The following imports are the only additions to code required | ||
# to automatically log metrics and parameters to Comet. | ||
import comet_ml # noqa | ||
import comet_ml | ||
|
||
import mlflow.keras | ||
import numpy as np | ||
# You can use 'tensorflow', 'torch' or 'jax' as backend. Make sure to set the | ||
# environment variable before importing. | ||
os.environ["KERAS_BACKEND"] = "tensorflow" | ||
|
||
import keras | ||
|
||
# The following import and function call are the only additions to code required | ||
# to automatically log metrics and parameters to MLflow. | ||
import mlflow | ||
from keras.datasets import reuters | ||
from keras.layers import Activation, Dense, Dropout | ||
from keras.models import Sequential | ||
from keras.preprocessing.text import Tokenizer | ||
import mlflow.keras # noqa: E402 | ||
import numpy as np # noqa: E402 | ||
|
||
# The sqlite store is needed for the model registry | ||
mlflow.set_tracking_uri("sqlite:///db.sqlite") | ||
import keras # noqa: E402 | ||
|
||
# We need to create a run before calling keras or MLFlow will end the run by itself | ||
mlflow.set_experiment("comet-example-mlflow-hello-world") | ||
mlflow.start_run() | ||
# Login to Comet if necessary | ||
comet_ml.login(project_name="comet-example-mlflow-hello-world") | ||
|
||
mlflow.keras.autolog() | ||
# Load dataset | ||
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() | ||
x_train = np.expand_dims(x_train, axis=3) | ||
x_test = np.expand_dims(x_test, axis=3) | ||
x_train[0].shape | ||
|
||
max_words = 1000 | ||
batch_size = 32 | ||
epochs = 5 | ||
# Build model | ||
NUM_CLASSES = 10 | ||
INPUT_SHAPE = (28, 28, 1) | ||
|
||
|
||
def initialize_model(): | ||
return keras.Sequential( | ||
[ | ||
keras.Input(shape=INPUT_SHAPE), | ||
keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"), | ||
keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"), | ||
keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"), | ||
keras.layers.GlobalAveragePooling2D(), | ||
keras.layers.Dense(NUM_CLASSES, activation="softmax"), | ||
] | ||
) | ||
|
||
print("Loading data...") | ||
(x_train, y_train), (x_test, y_test) = reuters.load_data( | ||
num_words=max_words, test_split=0.2 | ||
) | ||
|
||
print(len(x_train), "train sequences") | ||
print(len(x_test), "test sequences") | ||
model = initialize_model() | ||
model.summary() | ||
|
||
num_classes = np.max(y_train) + 1 | ||
print(num_classes, "classes") | ||
# Train model | ||
|
||
print("Vectorizing sequence data...") | ||
tokenizer = Tokenizer(num_words=max_words) | ||
x_train = tokenizer.sequences_to_matrix(x_train, mode="binary") | ||
x_test = tokenizer.sequences_to_matrix(x_test, mode="binary") | ||
print("x_train shape:", x_train.shape) | ||
print("x_test shape:", x_test.shape) | ||
BATCH_SIZE = 64 # adjust this based on the memory of your machine | ||
EPOCHS = 3 | ||
|
||
print( | ||
"Convert class vector to binary class matrix " | ||
"(for use with categorical_crossentropy)" | ||
model = initialize_model() | ||
|
||
model.compile( | ||
loss=keras.losses.SparseCategoricalCrossentropy(), | ||
optimizer=keras.optimizers.Adam(), | ||
metrics=["accuracy"], | ||
) | ||
y_train = keras.utils.to_categorical(y_train, num_classes) | ||
y_test = keras.utils.to_categorical(y_test, num_classes) | ||
print("y_train shape:", y_train.shape) | ||
print("y_test shape:", y_test.shape) | ||
|
||
print("Building model...") | ||
model = Sequential() | ||
model.add(Dense(512, input_shape=(max_words,))) | ||
model.add(Activation("relu")) | ||
model.add(Dropout(0.5)) | ||
model.add(Dense(num_classes)) | ||
model.add(Activation("softmax")) | ||
|
||
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"]) | ||
|
||
history = model.fit( | ||
|
||
run = mlflow.start_run() | ||
model.fit( | ||
x_train, | ||
y_train, | ||
batch_size=batch_size, | ||
epochs=epochs, | ||
verbose=1, | ||
batch_size=BATCH_SIZE, | ||
epochs=EPOCHS, | ||
validation_split=0.1, | ||
callbacks=[mlflow.keras.MlflowCallback(run)], | ||
) | ||
score = model.evaluate(x_test, y_test, batch_size=batch_size, verbose=1) | ||
print("Test score:", score[0]) | ||
print("Test accuracy:", score[1]) | ||
|
||
mlflow.keras.log_model(model, "model", registered_model_name="Test Model") | ||
|
||
mlflow.end_run() |
4 changes: 2 additions & 2 deletions
4
integrations/model-training/mlflow/mlflow-hello-world/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
alembic<1.10.0 # workaround ValueError: Constraint must have a name | ||
comet_ml | ||
comet_ml>=3.44.0 | ||
keras | ||
mlflow | ||
tensorflow |