diff --git a/k8s/airflow/values-prod.yaml b/k8s/airflow/values-prod.yaml
index f01e6e714a..1cee1ed482 100644
--- a/k8s/airflow/values-prod.yaml
+++ b/k8s/airflow/values-prod.yaml
@@ -9,7 +9,7 @@ images:
   repositories:
     initContainer: eu.gcr.io/airqo-250220/airqo-apache-airflow-xcom
     containers: eu.gcr.io/airqo-250220/airqo-apache-airflow
-  tag: prod-d4165e1e-1695022368
+  tag: prod-74273167-1695028772
 nameOverride: ''
 fullnameOverride: ''
 podAnnotations: {}
diff --git a/k8s/airflow/values-stage.yaml b/k8s/airflow/values-stage.yaml
index 20127f0f6d..2cfb0ff91e 100644
--- a/k8s/airflow/values-stage.yaml
+++ b/k8s/airflow/values-stage.yaml
@@ -9,7 +9,7 @@ images:
   repositories:
     initContainer: eu.gcr.io/airqo-250220/airqo-stage-apache-airflow-xcom
     containers: eu.gcr.io/airqo-250220/airqo-stage-apache-airflow
-  tag: stage-d1aaf3c2-1694766672
+  tag: stage-586026aa-1695189812
 nameOverride: ''
 fullnameOverride: ''
 podAnnotations: {}
diff --git a/k8s/analytics/values-prod.yaml b/k8s/analytics/values-prod.yaml
index 56b7b6b3e5..aa95a02534 100644
--- a/k8s/analytics/values-prod.yaml
+++ b/k8s/analytics/values-prod.yaml
@@ -8,7 +8,7 @@ images:
   celeryWorker: eu.gcr.io/airqo-250220/airqo-analytics-celery-worker
   reportJob: eu.gcr.io/airqo-250220/airqo-analytics-report-job
   devicesSummaryJob: eu.gcr.io/airqo-250220/airqo-analytics-devices-summary-job
-  tag: prod-d4165e1e-1695022368
+  tag: prod-74273167-1695028772
 api:
   name: airqo-analytics-api
   label: analytics-api
diff --git a/k8s/auth-service/values-prod.yaml b/k8s/auth-service/values-prod.yaml
index 620252ed57..b7b8ea2057 100644
--- a/k8s/auth-service/values-prod.yaml
+++ b/k8s/auth-service/values-prod.yaml
@@ -6,7 +6,7 @@ app:
 replicaCount: 3
 image:
   repository: eu.gcr.io/airqo-250220/airqo-auth-api
-  tag: prod-d4165e1e-1695022368
+  tag: prod-74273167-1695028772
 nameOverride: ''
 fullnameOverride: ''
 podAnnotations: {}
diff --git a/k8s/auth-service/values-stage.yaml b/k8s/auth-service/values-stage.yaml
index 8a1b64d1b8..341ee6b68a 100644
--- a/k8s/auth-service/values-stage.yaml
+++ b/k8s/auth-service/values-stage.yaml
@@ -6,7 +6,7 @@ app:
 replicaCount: 2
 image:
   repository: eu.gcr.io/airqo-250220/airqo-stage-auth-api
-  tag: stage-b17fbb54-1694524327
+  tag: stage-5513f226-1695028756
 nameOverride: ''
 fullnameOverride: ''
 podAnnotations: {}
diff --git a/k8s/data-mgt/values-stage.yaml b/k8s/data-mgt/values-stage.yaml
index 7b2e05ff8a..b700dd54ec 100644
--- a/k8s/data-mgt/values-stage.yaml
+++ b/k8s/data-mgt/values-stage.yaml
@@ -6,7 +6,7 @@ app:
 replicaCount: 2
 image:
   repository: eu.gcr.io/airqo-250220/airqo-stage-data-mgt-api
-  tag: stage-e2c1d558-1691937865
+  tag: stage-d808bb92-1695279279
 nameOverride: ''
 fullnameOverride: ''
 podAnnotations: {}
diff --git a/k8s/device-registry/values-prod.yaml b/k8s/device-registry/values-prod.yaml
index ba45a082f1..11dfc9e0c2 100644
--- a/k8s/device-registry/values-prod.yaml
+++ b/k8s/device-registry/values-prod.yaml
@@ -6,7 +6,7 @@ app:
 replicaCount: 3
 image:
   repository: eu.gcr.io/airqo-250220/airqo-device-registry-api
-  tag: prod-80ea615f-1694585638
+  tag: prod-74273167-1695028772
 nameOverride: ''
 fullnameOverride: ''
 podAnnotations: {}
diff --git a/k8s/device-registry/values-stage.yaml b/k8s/device-registry/values-stage.yaml
index a6e50f65cd..3af90282ae 100644
--- a/k8s/device-registry/values-stage.yaml
+++ b/k8s/device-registry/values-stage.yaml
@@ -6,7 +6,7 @@ app:
 replicaCount: 2
 image:
   repository: eu.gcr.io/airqo-250220/airqo-stage-device-registry-api
-  tag: stage-5e65174e-1695028544
+  tag: stage-33cbc445-1695129549
nameOverride: '' fullnameOverride: '' podAnnotations: {} diff --git a/k8s/exceedance/values-prod-airqo.yaml b/k8s/exceedance/values-prod-airqo.yaml index 835a06a025..8a9e92895e 100644 --- a/k8s/exceedance/values-prod-airqo.yaml +++ b/k8s/exceedance/values-prod-airqo.yaml @@ -4,6 +4,6 @@ app: configmap: env-exceedance-production image: repository: eu.gcr.io/airqo-250220/airqo-exceedance-job - tag: prod-d4165e1e-1695022368 + tag: prod-74273167-1695028772 nameOverride: '' fullnameOverride: '' diff --git a/k8s/exceedance/values-prod-kcca.yaml b/k8s/exceedance/values-prod-kcca.yaml index d8306e316e..f1b68b74fe 100644 --- a/k8s/exceedance/values-prod-kcca.yaml +++ b/k8s/exceedance/values-prod-kcca.yaml @@ -4,6 +4,6 @@ app: configmap: env-exceedance-production image: repository: eu.gcr.io/airqo-250220/kcca-exceedance-job - tag: prod-d4165e1e-1695022368 + tag: prod-74273167-1695028772 nameOverride: '' fullnameOverride: '' diff --git a/k8s/incentives/values-prod.yaml b/k8s/incentives/values-prod.yaml index 91b285ee99..2526ab40a1 100644 --- a/k8s/incentives/values-prod.yaml +++ b/k8s/incentives/values-prod.yaml @@ -6,7 +6,7 @@ app: replicaCount: 3 image: repository: eu.gcr.io/airqo-250220/airqo-incentives-api - tag: prod-d4165e1e-1695022368 + tag: prod-74273167-1695028772 nameOverride: '' fullnameOverride: '' podAnnotations: {} diff --git a/k8s/incentives/values-stage.yaml b/k8s/incentives/values-stage.yaml index 004bc0de37..6f124f67d1 100644 --- a/k8s/incentives/values-stage.yaml +++ b/k8s/incentives/values-stage.yaml @@ -6,7 +6,7 @@ app: replicaCount: 2 image: repository: eu.gcr.io/airqo-250220/airqo-stage-incentives-api - tag: stage-f7ce8287-1693130445 + tag: stage-d808bb92-1695279279 nameOverride: '' fullnameOverride: '' podAnnotations: {} diff --git a/k8s/predict/values-prod.yaml b/k8s/predict/values-prod.yaml index 3b64007dd0..cfe7b4056c 100644 --- a/k8s/predict/values-prod.yaml +++ b/k8s/predict/values-prod.yaml @@ -7,7 +7,7 @@ images: predictJob: eu.gcr.io/airqo-250220/airqo-predict-job trainJob: eu.gcr.io/airqo-250220/airqo-train-job predictPlaces: eu.gcr.io/airqo-250220/airqo-predict-places-air-quality - tag: prod-d4165e1e-1695022368 + tag: prod-74273167-1695028772 api: name: airqo-prediction-api label: prediction-api diff --git a/k8s/predict/values-stage.yaml b/k8s/predict/values-stage.yaml index 79f9f7908d..0c1220227a 100644 --- a/k8s/predict/values-stage.yaml +++ b/k8s/predict/values-stage.yaml @@ -7,7 +7,7 @@ images: predictJob: eu.gcr.io/airqo-250220/stage-airqo-predict-job trainJob: eu.gcr.io/airqo-250220/stage-airqo-train-job predictPlaces: eu.gcr.io/airqo-250220/stage-airqo-predict-places-air-quality - tag: stage-84518356-1693167908 + tag: stage-defae719-1695039035 api: name: airqo-stage-prediction-api label: prediction-api diff --git a/src/airflow/airflow-requirements.txt b/src/airflow/airflow-requirements.txt index 977410828f..45dd7c89e4 100644 --- a/src/airflow/airflow-requirements.txt +++ b/src/airflow/airflow-requirements.txt @@ -6,4 +6,5 @@ apache-airflow[sentry] lightgbm mlflow gcsfs -pymongo \ No newline at end of file +pymongo +optuna \ No newline at end of file diff --git a/src/airflow/airqo_etl_utils/air_beam_api.py b/src/airflow/airqo_etl_utils/air_beam_api.py index 57adacfcf9..657d69dbf4 100644 --- a/src/airflow/airqo_etl_utils/air_beam_api.py +++ b/src/airflow/airqo_etl_utils/air_beam_api.py @@ -24,25 +24,25 @@ def get_stream_ids( username: str, pollutant: str, ): - params={ - "q": json.dumps( - { - "time_from": int(start_date_time.timestamp()), - "time_to": 
int(end_date_time.timestamp()), - "tags": "", - "usernames": username, - "west": 10.581214853439886, - "east": 38.08577769782265, - "south": -36.799337832603314, - "north": -19.260169583742446, - "limit": 100, - "offset": 0, - "sensor_name": f"airbeam3-{pollutant}", - "measurement_type": "Particulate Matter", - "unit_symbol": "µg/m³", - } - ) - } + params = { + "q": json.dumps( + { + "time_from": int(start_date_time.timestamp()), + "time_to": int(end_date_time.timestamp()), + "tags": "", + "usernames": username, + "west": 10.581214853439886, + "east": 38.08577769782265, + "south": -36.799337832603314, + "north": -19.260169583742446, + "limit": 100, + "offset": 0, + "sensor_name": f"airbeam3-{pollutant}", + "measurement_type": "Particulate Matter", + "unit_symbol": "µg/m³", + } + ) + } request = self.__request( endpoint=f"mobile/sessions.json", params=params, @@ -65,32 +65,32 @@ def get_measurements( endpoint=f"measurements.json", params=params, ) - - def __request(self, endpoint, params): + def __request(self, endpoint, params): url = f"{self.AIR_BEAM_BASE_URL}{endpoint}" retry_strategy = Retry( total=5, backoff_factor=5, ) - + http = urllib3.PoolManager(retries=retry_strategy) - + try: response = http.request( - "GET", - url, - fields=params,) - + "GET", + url, + fields=params, + ) + response_data = response.data print(response._request_url) - + if response.status == 200: return json.loads(response_data) else: Utils.handle_api_error(response) return None - + except urllib3.exceptions.HTTPError as e: print(f"HTTPError: {e}") return None diff --git a/src/airflow/airqo_etl_utils/airnow_api.py b/src/airflow/airqo_etl_utils/airnow_api.py index 851afe89cb..a1e65b8189 100644 --- a/src/airflow/airqo_etl_utils/airnow_api.py +++ b/src/airflow/airqo_etl_utils/airnow_api.py @@ -56,20 +56,20 @@ def __request(self, endpoint, params, api_key): total=5, backoff_factor=5, ) - + http = urllib3.PoolManager(retries=retry_strategy) - + try: response = http.request("GET", url, fields=params) response_data = response.data print(response._request_url) - + if response.status == 200: return json.loads(response_data) else: Utils.handle_api_error(response) return None - + except urllib3.exceptions.HTTPError as e: print(f"HTTPError: {e}") return None diff --git a/src/airflow/airqo_etl_utils/airqo_api.py b/src/airflow/airqo_etl_utils/airqo_api.py index 876f630a9d..e80cd7bf39 100644 --- a/src/airflow/airqo_etl_utils/airqo_api.py +++ b/src/airflow/airqo_etl_utils/airqo_api.py @@ -322,7 +322,7 @@ def __request(self, endpoint, params=None, body=None, method=None, base_url=None params.update({"token": self.AIRQO_API_TOKEN}) retry_strategy = Retry( - total=5, + total=5, backoff_factor=5, ) @@ -338,21 +338,21 @@ def __request(self, endpoint, params=None, body=None, method=None, base_url=None encoded_args = urlencode(params) url = url + "?" + encoded_args response = http.request( - "PUT", + "PUT", url, - headers=headers, - body=simplejson.dumps(body, ignore_nan=True) - ) + headers=headers, + body=simplejson.dumps(body, ignore_nan=True), + ) elif method == "post": headers["Content-Type"] = "application/json" encoded_args = urlencode(params) url = url + "?" 
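# Note: the refactored __request helpers above (air_beam_api, airnow_api, airqo_api,
# and the other clients later in this diff) all follow the same urllib3 pattern:
# a PoolManager configured with a Retry strategy, a GET whose query parameters are
# passed via `fields`, and JSON decoding only on HTTP 200. A minimal, self-contained
# sketch of that pattern follows; the URL and params are placeholders, not values
# from this repository.
import json

import urllib3
from urllib3.util.retry import Retry


def fetch_json(url, params):
    # Retry up to 5 times with exponential backoff, mirroring the clients above.
    retry_strategy = Retry(total=5, backoff_factor=5)
    http = urllib3.PoolManager(retries=retry_strategy)
    try:
        response = http.request("GET", url, fields=params)
        if response.status == 200:
            return json.loads(response.data)
        return None
    except urllib3.exceptions.HTTPError as e:
        print(f"HTTPError: {e}")
        return None


# Example with a placeholder endpoint: fetch_json("https://example.com/api.json", {"limit": 100})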
+ encoded_args response = http.request( - "POST", + "POST", url, - headers=headers, - body=simplejson.dumps(body, ignore_nan=True) - ) + headers=headers, + body=simplejson.dumps(body, ignore_nan=True), + ) else: handle_api_error("Invalid") return None @@ -368,4 +368,3 @@ def __request(self, endpoint, params=None, body=None, method=None, base_url=None except urllib3.exceptions.HTTPError as e: print(f"HTTPError: {e}") return None - diff --git a/src/airflow/airqo_etl_utils/bigquery_api.py b/src/airflow/airqo_etl_utils/bigquery_api.py index 4a2ed54181..2e66b9fc10 100644 --- a/src/airflow/airqo_etl_utils/bigquery_api.py +++ b/src/airflow/airqo_etl_utils/bigquery_api.py @@ -615,20 +615,37 @@ def fetch_raw_readings(self) -> pd.DataFrame: except Exception as e: raise e - def fetch_data(self, start_date_time: str, historical: bool = False): - # historical is for the actual jobs, not training + def fetch_data( + self, + start_date_time: str, + ) -> pd.DataFrame: + try: + pd.to_datetime(start_date_time) + except ValueError: + raise ValueError(f"Invalid start date time: {start_date_time}") query = f""" - SELECT DISTINCT timestamp as created_at, {"site_id," if historical else ""} device_number, pm2_5_calibrated_value as pm2_5 - FROM `{configuration.BIGQUERY_HOURLY_EVENTS_TABLE_PROD}` - WHERE DATE(timestamp) >= '{start_date_time}' and device_number IS NOT NULL - ORDER BY created_at, device_number - """ + SELECT DISTINCT + t1.device_id, + t1.timestamp, + t1.site_id, + t1.pm2_5_calibrated_value as pm2_5, + t2.latitude, + t2.longitude, + t3.device_category + FROM `{self.hourly_measurements_table_prod}` t1 + JOIN `{self.sites_table}` t2 on t1.site_id = t2.id + JOIN `{self.devices_table}` t3 on t1.device_id = t3.device_id + WHERE date(t1.timestamp) >= '{start_date_time}' and t1.device_id IS NOT NULL + ORDER BY device_id, timestamp""" job_config = bigquery.QueryJobConfig() job_config.use_query_cache = True - df = self.client.query(f"{query}", job_config).result().to_dataframe() - return df + try: + df = self.client.query(query, job_config).result().to_dataframe() + return df + except Exception as e: + print("Error fetching data from bigquery") @staticmethod def save_forecasts_to_bigquery(df, table): diff --git a/src/airflow/airqo_etl_utils/config.py b/src/airflow/airqo_etl_utils/config.py index 065efecdb5..3f2768bf8a 100644 --- a/src/airflow/airqo_etl_utils/config.py +++ b/src/airflow/airqo_etl_utils/config.py @@ -1,6 +1,7 @@ import os from pathlib import Path +import pymongo as pm import urllib3 from dotenv import load_dotenv @@ -170,6 +171,10 @@ class Config: FORECAST_MODELS_BUCKET = os.getenv("FORECAST_MODELS_BUCKET") MONGO_URI = os.getenv("MONGO_URI") MONGO_DATABASE_NAME = os.getenv("MONGO_DATABASE_NAME") + ENVIRONMENT = os.getenv("ENVIRONMENT") configuration = Config() + +client = pm.MongoClient(configuration.MONGO_URI) +db = client[configuration.MONGO_DATABASE_NAME] diff --git a/src/airflow/airqo_etl_utils/ml_utils.py b/src/airflow/airqo_etl_utils/ml_utils.py index 8a7adb8e5a..53bf7af901 100644 --- a/src/airflow/airqo_etl_utils/ml_utils.py +++ b/src/airflow/airqo_etl_utils/ml_utils.py @@ -1,606 +1,676 @@ -from datetime import datetime +import json +import random +from datetime import datetime, timedelta import gcsfs import joblib import mlflow import numpy as np +import optuna import pandas as pd -import pymongo as pm from lightgbm import LGBMRegressor, early_stopping -from scipy.stats import skew from sklearn.metrics import mean_squared_error -from .config import configuration +from .config import 
configuration, db -fixed_columns = ["site_id"] project_id = configuration.GOOGLE_CLOUD_PROJECT_ID bucket = configuration.FORECAST_MODELS_BUCKET +environment = configuration.ENVIRONMENT +pd.options.mode.chained_assignment = None -def get_trained_model_from_gcs(project_name, bucket_name, source_blob_name): - fs = gcsfs.GCSFileSystem(project=project_name) - fs.ls(bucket_name) - with fs.open(bucket_name + "/" + source_blob_name, "rb") as handle: - job = joblib.load(handle) - return job +class GCSUtils: + """Utility class for saving and retrieving models from GCS""" -def upload_trained_model_to_gcs( - trained_model, project_name, bucket_name, source_blob_name -): - fs = gcsfs.GCSFileSystem(project=project_name) - - # backup previous model - try: - fs.rename( - f"{bucket_name}/{source_blob_name}", - f"{bucket_name}/{datetime.now()}-{source_blob_name}", - ) - print("Bucket: previous model is backed up") - except: - print("Bucket: No file to updated") - - # store new model - with fs.open(bucket_name + "/" + source_blob_name, "wb") as handle: - job = joblib.dump(trained_model, handle) - + # TODO: In future, save and retrieve models from mlflow instead of GCS + @staticmethod + def get_trained_model_from_gcs(project_name, bucket_name, source_blob_name): + fs = gcsfs.GCSFileSystem(project=project_name) + fs.ls(bucket_name) + with fs.open(bucket_name + "/" + source_blob_name, "rb") as handle: + job = joblib.load(handle) + return job -class ForecastUtils: - ###FORECAST MODEL TRAINING UTILS#### @staticmethod - def preprocess_training_data(data, frequency): - data["created_at"] = pd.to_datetime(data["created_at"]) - data["device_number"] = data["device_number"].astype(str) - data["pm2_5"] = data.groupby("device_number")["pm2_5"].transform( - lambda x: x.interpolate(method="linear", limit_direction="both") - ) - if frequency == "daily": - data = ( - data.groupby(["device_number"]) - .resample("D", on="created_at") - .mean(numeric_only=True) + def upload_trained_model_to_gcs( + trained_model, project_name, bucket_name, source_blob_name + ): + fs = gcsfs.GCSFileSystem(project=project_name) + try: + fs.rename( + f"{bucket_name}/{source_blob_name}", + f"{bucket_name}/{datetime.now()}-{source_blob_name}", ) - data.reset_index(inplace=True) - data["pm2_5"] = data.groupby("device_number")["pm2_5"].transform( - lambda x: x.interpolate(method="linear", limit_direction="both") - ) - data["device_number"] = data["device_number"].astype(int) - data = data.dropna(subset=["pm2_5"]) - return data + print("Bucket: previous model is backed up") + except: + print("Bucket: No file to updated") + + with fs.open(bucket_name + "/" + source_blob_name, "wb") as handle: + job = joblib.dump(trained_model, handle) @staticmethod - def feature_eng_training_data(data, target_column, frequency): - def get_lag_features(df, target_col, freq): - df = df.sort_values(by=["device_number", "created_at"]) - - if freq == "daily": - shifts = [1, 2] - for s in shifts: - df[f"pm2_5_last_{s}_day"] = df.groupby(["device_number"])[ - target_col - ].shift(s) - - shifts = [3, 7, 14, 30] - functions = ["mean", "std", "max", "min"] - for s in shifts: - for f in functions: - df[f"pm2_5_{f}_{s}_day"] = ( - df.groupby(["device_number"])[target_col] - .shift(1) - .rolling(s) - .agg(f) - ) - elif freq == "hourly": - shifts = [ - 1, - 2, - ] # TODO: Review to increase these both in training and the actual job - for s in shifts: - df[f"pm2_5_last_{s}_hour"] = df.groupby(["device_number"])[ - target_col - ].shift(s) - - # lag features - shifts = [6, 12, 24, 48] - 
functions = ["mean", "std", "median", "skew"] - for s in shifts: - for f in functions: - df[f"pm2_5_{f}_{s}_hour"] = ( - df.groupby(["device_number"])[target_col] - .shift(1) - .rolling(s) - .agg(f) - ) - else: - raise ValueError("Invalid frequency") - - return df - - def get_other_features(df_tmp, freq): - # TODO: Experiment on impact of features - attributes = ["year", "month", "day", "dayofweek"] - if freq == "hourly": - attributes.extend(["hour", "minute"]) - for a in attributes: - df_tmp[a] = df_tmp["created_at"].dt.__getattribute__(a) - df_tmp["week"] = df_tmp["created_at"].dt.isocalendar().week.astype(int) - - print("Additional features added") - return df_tmp - - data["created_at"] = pd.to_datetime(data["created_at"]) - df_tmp = get_other_features(data, frequency) - df_tmp = get_lag_features(df_tmp, target_column, frequency) - - return df_tmp + def upload_mapping_to_gcs( + mapping_dict, project_name, bucket_name, source_blob_name + ): + fs = gcsfs.GCSFileSystem(project=project_name) + mapping_dict = json.dumps(mapping_dict) + with fs.open(bucket_name + "/" + source_blob_name, "w") as f: + f.write(mapping_dict) @staticmethod - def train_and_save_hourly_forecast_model(train): # separate code for hourly model - """ - Perform the actual training for hourly data - """ - train["created_at"] = pd.to_datetime(train["created_at"]) - train = train.sort_values(by=["device_number", "created_at"]) - features = [c for c in train.columns if c not in ["created_at", "pm2_5"]] - print(features) - target_col = "pm2_5" - train_data, test_data = pd.DataFrame(), pd.DataFrame() - for device_number in train["device_number"].unique(): - device_df = train[train["device_number"] == device_number] - device_df = device_df.sort_values(by="created_at") - months = device_df["created_at"].dt.month.unique() - train_months = months[:4] - test_months = months[4:] - train_df = device_df[device_df["created_at"].dt.month.isin(train_months)] - test_df = device_df[device_df["created_at"].dt.month.isin(test_months)] - train_data = pd.concat([train_data, train_df]) - test_data = pd.concat([test_data, test_df]) + def get_mapping_from_gcs(project_name, bucket_name, source_blob_name): + fs = gcsfs.GCSFileSystem(project=project_name) + with fs.open(bucket_name + "/" + source_blob_name, "r") as f: + mapping_dict = json.load(f) + return mapping_dict - train_data.drop(columns=["created_at"], axis=1, inplace=True) - test_data.drop(columns=["created_at"], axis=1, inplace=True) - train_target, test_target = train_data[target_col], test_data[target_col] +class DecodingUtils: + """Utility class for encoding and decoding categorical features""" - with mlflow.start_run(): - print("Model training started.....") - n_estimators = 5000 - learning_rate = 0.05 - colsample_bytree = 0.4 - reg_alpha = 0 - reg_lambda = 1 - max_depth = 1 - random_state = 1 + @staticmethod + def decode_categorical_features_pred(df, frequency): + columns = ["device_id", "site_id", "device_category"] + mapping = {} + for col in columns: + if frequency == "hourly": + mapping = GCSUtils.get_mapping_from_gcs( + project_id, bucket, f"hourly_{col}_mapping.json" + ) + elif frequency == "daily": + mapping = GCSUtils.get_mapping_from_gcs( + project_id, bucket, f"daily_{col}_mapping.json" + ) + df[col] = df[col].map(mapping) + return df - clf = LGBMRegressor( - n_estimators=n_estimators, - learning_rate=learning_rate, - colsample_bytree=colsample_bytree, - reg_alpha=reg_alpha, - reg_lambda=reg_lambda, - max_depth=max_depth, - random_state=random_state, - ) + @staticmethod + 
def decode_categorical_features_before_save(df, frequency): + columns = ["device_id", "site_id", "device_category"] + mapping = {} + for col in columns: + if frequency == "hourly": + mapping = GCSUtils.get_mapping_from_gcs( + project_id, bucket, f"hourly_{col}_mapping.json" + ) + elif frequency == "daily": + mapping = GCSUtils.get_mapping_from_gcs( + project_id, bucket, f"daily_{col}_mapping.json" + ) + df[col] = df[col].map({v: k for k, v in mapping.items()}) + return df - clf.fit( - train_data[features], - train_target, - eval_set=[(test_data[features], test_target)], - callbacks=[early_stopping(stopping_rounds=150)], - eval_metric="rmse", - ) - print("Model training completed.....") - - # Log parameters - mlflow.log_param("n_estimators", n_estimators) - mlflow.log_param("learning_rate", learning_rate) - mlflow.log_param("colsample_bytree", colsample_bytree) - mlflow.log_param("reg_alpha", reg_alpha) - mlflow.log_param("reg_lamba", reg_lambda) - mlflow.log_param("max_depth", max_depth) - mlflow.log_param("random_state", random_state) - - # Log moder - mlflow.sklearn.log_model( - sk_model=clf, - artifact_path="hourly_forecast_model", - registered_model_name=f"LGBM_hourly_forecast_model_development", + @staticmethod + def encode_categorical_training_features(df, freq): + df["timestamp"] = pd.to_datetime(df["timestamp"]) + df1 = df.copy() + columns = ["device_id", "site_id", "device_category"] + mappings = [] + for col in columns: + mapping = {} + for val in df1[col].unique(): + num = random.randint(0, 1000) + while num in mapping.values(): + num = random.randint(0, 1000) + mapping[val] = num + df1[col] = df1[col].map(mapping) + mappings.append(mapping) + for i, col in enumerate(columns): + GCSUtils.upload_mapping_to_gcs( + mappings[i], + project_id, + bucket, + f"{freq}_{col}_mapping.json", ) + return df1 - print("Being model validation.....") - - val_preds = clf.predict(test_data[features]) - rmse_val = mean_squared_error(test_data[target_col], val_preds) ** 0.5 - - print("Model validation completed.....") - print(f"Validation RMSE is {rmse_val}") - - # Log metrics - mlflow.log_metric("VAL_RMSE", rmse_val) - - best_iter = clf.best_iteration_ - clf = LGBMRegressor( - n_estimators=best_iter, - learning_rate=0.05, - colsample_bytree=0.4, - reg_alpha=2, - reg_lambda=1, - max_depth=-1, - random_state=1, - verbosity=2, - ) - train["device_number"] = train["device_number"].astype(int) - clf.fit(train[features], train[target_col]) - upload_trained_model_to_gcs(clf, project_id, bucket, "hourly_forecast_model") +class ForecastUtils: @staticmethod - def train_and_save_daily_forecast_model(train): # separate code for monthly model - train["created_at"] = pd.to_datetime(train["created_at"]) - train = train.sort_values(by=["device_number", "created_at"]) - features = [c for c in train.columns if c not in ["created_at", "pm2_5"]] - print(features) - target_col = "pm2_5" - train_data, test_data = pd.DataFrame(), pd.DataFrame() - - for device_number in train["device_number"].unique(): - device_df = train[train["device_number"] == device_number] - device_df = device_df.sort_values(by="created_at") - months = device_df["created_at"].dt.month.unique() - train_months = months[:8] - test_months = months[8:] - train_df = device_df[device_df["created_at"].dt.month.isin(train_months)] - test_df = device_df[device_df["created_at"].dt.month.isin(test_months)] - train_data = pd.concat([train_data, train_df]) - test_data = pd.concat([test_data, test_df]) - - train_data.drop(columns=["created_at"], axis=1, 
inplace=True) - test_data.drop(columns=["created_at"], axis=1, inplace=True) - - train_target, test_target = train_data[target_col], test_data[target_col] - with mlflow.start_run(): - print("Model training started.....") - n_estimators = 5000 - learning_rate = 0.05 - colsample_bytree = 0.4 - reg_alpha = 0 - reg_lambda = 1 - max_depth = 1 - random_state = 1 - - clf = LGBMRegressor( - n_estimators=n_estimators, - learning_rate=learning_rate, - colsample_bytree=colsample_bytree, - reg_alpha=reg_alpha, - reg_lambda=reg_lambda, - max_depth=max_depth, - random_state=random_state, - ) - - clf.fit( - train_data[features], - train_target, - eval_set=[(test_data[features], test_target)], - callbacks=[early_stopping(stopping_rounds=150)], - eval_metric="rmse", + def preprocess_data(data, data_frequency): + required_columns = { + "device_id", + "site_id", + "device_category", + "pm2_5", + "timestamp", + } + if not required_columns.issubset(data.columns): + missing_columns = required_columns.difference(data.columns) + raise ValueError( + f"Provided dataframe missing necessary columns: {', '.join(missing_columns)}" ) - print("Model training completed.....") - - # Log parameters - mlflow.log_param("n_estimators", n_estimators) - mlflow.log_param("learning_rate", learning_rate) - mlflow.log_param("colsample_bytree", colsample_bytree) - mlflow.log_param("reg_alpha", reg_alpha) - mlflow.log_param("reg_lamba", reg_lambda) - mlflow.log_param("max_depth", max_depth) - mlflow.log_param("random_state", random_state) - - # Log model - mlflow.sklearn.log_model( - sk_model=clf, - artifact_path="daily_forecast_model", - registered_model_name=f"LGBM_daily_forecast_model_development", - ) - - # model validation - print("Being model validation.....") - - val_preds = clf.predict(test_data[features]) - rmse_val = mean_squared_error(test_data[target_col], val_preds) ** 0.5 - - print("Model validation completed.....") - print(f"Validation RMSE is {rmse_val}") - - # Log metrics - mlflow.log_metric("VAL_RMSE", rmse_val) - - best_iter = clf.best_iteration_ - clf = LGBMRegressor( - n_estimators=best_iter, - learning_rate=0.05, - colsample_bytree=0.4, - reg_alpha=2, - reg_lambda=1, - max_depth=-1, - random_state=1, + try: + data["timestamp"] = pd.to_datetime(data["timestamp"]) + except ValueError as e: + raise ValueError( + "datetime conversion error, please provide timestamp in valid format" ) - clf.fit(train[features], train[target_col]) - upload_trained_model_to_gcs(clf, project_id, bucket, "daily_forecast_model.pkl") - print("Model saved successfully") - - #### FORECAST JOB UTILS #### - - @staticmethod - def preprocess_historical_data(data, frequency): - data["created_at"] = pd.to_datetime(data["created_at"]) - data["device_number"] = data["device_number"].astype(str) - data["pm2_5"] = data.groupby(fixed_columns + ["device_number"])[ + data["pm2_5"] = data.groupby(["device_id", "site_id", "device_category"])[ "pm2_5" ].transform(lambda x: x.interpolate(method="linear", limit_direction="both")) - if frequency == "hourly": - data.sort_values( - by=fixed_columns + ["device_number", "created_at"], inplace=True - ) - elif frequency == "daily": + if data_frequency == "daily": data = ( - data.groupby(fixed_columns + ["device_number"]) - .resample("D", on="created_at") + data.groupby(["device_id", "site_id", "device_category"]) + .resample("D", on="timestamp") .mean(numeric_only=True) ) data.reset_index(inplace=True) - data["pm2_5"] = data.groupby(fixed_columns + ["device_number"])[ + data["pm2_5"] = data.groupby(["device_id", 
"site_id", "device_category"])[ "pm2_5" ].transform( lambda x: x.interpolate(method="linear", limit_direction="both") ) - data.sort_values( - by=fixed_columns + ["device_number", "created_at"], inplace=True - ) - else: - raise ValueError("Invalid frequency argument") - data["device_number"] = data["device_number"].astype(int) data = data.dropna(subset=["pm2_5"]) return data @staticmethod - def get_lag_features(df_tmp, TARGET_COL, frequency): - df_tmp["created_at"] = pd.to_datetime(df_tmp["created_at"]) - df_tmp = df_tmp.sort_values(by=fixed_columns + ["device_number", "created_at"]) - if frequency == "hourly": - shifts = [1, 2] + def get_lag_and_roll_features(df, target_col, freq): + if df.empty: + raise ValueError("Empty dataframe provided") + + if ( + target_col not in df.columns + or "timestamp" not in df.columns + or "device_id" not in df.columns + ): + raise ValueError("Required columns missing") + + df["timestamp"] = pd.to_datetime(df["timestamp"]) + + df1 = df.copy() # use copy to prevent terminal warning + if freq == "daily": + shifts = [1, 2, 3, 7] for s in shifts: - df_tmp[f"pm2_5_last_{s}_hour"] = df_tmp.groupby(["device_number"])[ - TARGET_COL + df1[f"pm2_5_last_{s}_day"] = df1.groupby(["device_id"])[ + target_col ].shift(s) - - shifts = [6, 12, 24, 48] - functions = ["mean", "std", "median", "skew"] + shifts = [2, 3, 7] + functions = ["mean", "std", "max", "min"] for s in shifts: for f in functions: - df_tmp[f"pm2_5_{f}_{s}_hour"] = ( - df_tmp.groupby(["device_number"])[TARGET_COL] + df1[f"pm2_5_{f}_{s}_day"] = ( + df1.groupby(["device_id"])[target_col] .shift(1) .rolling(s) .agg(f) ) - elif frequency == "daily": - shifts = [1, 2] + elif freq == "hourly": + shifts = [1, 2, 6, 12] for s in shifts: - df_tmp[f"pm2_5_last_{s}_day"] = df_tmp.groupby(["device_number"])[ - TARGET_COL + df1[f"pm2_5_last_{s}_hour"] = df1.groupby(["device_id"])[ + target_col ].shift(s) - shifts = [3, 7, 14, 30] - functions = ["mean", "std", "max", "min"] + shifts = [3, 6, 12, 24] + functions = ["mean", "std", "median", "skew"] for s in shifts: for f in functions: - df_tmp[f"pm2_5_{f}_{s}_day"] = ( - df_tmp.groupby(["device_number"])[TARGET_COL] + df1[f"pm2_5_{f}_{s}_hour"] = ( + df1.groupby(["device_id"])[target_col] .shift(1) .rolling(s) .agg(f) ) else: - raise ValueError("Invalid frequency argument") - print("Adding lag features") - return df_tmp + raise ValueError("Invalid frequency") + return df1 @staticmethod - def get_time_features(df_tmp, frequency): - df_tmp["created_at"] = pd.to_datetime(df_tmp["created_at"]) + def get_time_and_cyclic_features(df, freq): + if df.empty: + raise ValueError("Empty dataframe provided") + + if "timestamp" not in df.columns: + raise ValueError("Required columns missing") + + df["timestamp"] = pd.to_datetime(df["timestamp"]) + + if freq not in ["daily", "hourly"]: + raise ValueError("Invalid frequency") + df["timestamp"] = pd.to_datetime(df["timestamp"]) + df1 = df.copy() attributes = ["year", "month", "day", "dayofweek"] - if frequency == "hourly": - attributes.extend(["hour", "minute"]) - for a in attributes: - df_tmp[a] = df_tmp["created_at"].dt.__getattribute__(a) + max_vals = [2023, 12, 30, 7] + if freq == "hourly": + attributes.append("hour") + max_vals.append(23) + for a, m in zip(attributes, max_vals): + df1[a] = df1["timestamp"].dt.__getattribute__(a) + df1[a + "_sin"] = np.sin(2 * np.pi * df1[a] / m) + df1[a + "_cos"] = np.cos(2 * np.pi * df1[a] / m) + + df1["week"] = df1["timestamp"].dt.isocalendar().week + df1["week_sin"] = np.sin(2 * np.pi * df1["week"] / 
52) + df1["week_cos"] = np.cos(2 * np.pi * df1["week"] / 52) + df1.drop(columns=attributes + ["week"], inplace=True) + return df1 - df_tmp["week"] = df_tmp["created_at"].dt.isocalendar().week - print("Adding other features") - return df_tmp + @staticmethod + def get_location_features(df): + if df.empty: + raise ValueError("Empty dataframe provided") + + for column_name in ["timestamp", "latitude", "longitude"]: + if column_name not in df.columns: + raise ValueError(f"{column_name} column is missing") + + df["timestamp"] = pd.to_datetime(df["timestamp"]) + + df["x_cord"] = np.cos(df["latitude"]) * np.cos(df["longitude"]) + df["y_cord"] = np.cos(df["latitude"]) * np.sin(df["longitude"]) + df["z_cord"] = np.sin(df["latitude"]) + + return df + + # df_tmp = get_lag_features(df_tmp, target_column, data_frequency) + # df_tmp = get_time_and_cyclic_features(df_tmp, data_frequency) + # df_tmp = get_location_cord(df_tmp) + # if job_type == "train": + # df_tmp = DecodingUtils.encode_categorical_training_features( + # df_tmp, data_frequency + # ) + # elif job_type == "predict": + # df_tmp = DecodingUtils.decode_categorical_features_pred( + # df_tmp, data_frequency + # ) + # df_tmp.dropna( + # subset=["device_id", "site_id", "device_category"], inplace=True + # ) # only 1 row, not sure why + # + # df_tmp["device_id"] = df_tmp["device_id"].astype(int) + # df_tmp["site_id"] = df_tmp["site_id"].astype(int) + # df_tmp["device_category"] = df_tmp["device_category"].astype(int) + # + # return df_tmp @staticmethod - def generate_hourly_forecasts(data, project_name, bucket_name, source_blob_name): - data["created_at"] = pd.to_datetime(data["created_at"]) - - def get_new_row(df, device1, model): - last_row = df[df["device_number"] == device1].iloc[-1] - new_row = pd.Series(index=last_row.index, dtype="float64") - for i in fixed_columns: - new_row[i] = last_row[i] - new_row["created_at"] = last_row["created_at"] + pd.Timedelta(hours=1) - new_row["device_number"] = device1 - new_row[f"pm2_5_last_1_hour"] = last_row["pm2_5"] - new_row[f"pm2_5_last_2_hour"] = last_row[f"pm2_5_last_{1}_hour"] - - shifts = [6, 12, 24, 48] - functions = ["mean", "std", "median", "skew"] - for s in shifts: - for f in functions: - if f == "mean": - new_row[f"pm2_5_{f}_{s}_hour"] = ( - last_row["pm2_5"] - + last_row[f"pm2_5_{f}_{s}_hour"] * (s - 1) - ) / s - elif f == "std": - new_row[f"pm2_5_{f}_{s}_hour"] = ( - np.sqrt( - (last_row["pm2_5"] - last_row[f"pm2_5_mean_{s}_hour"]) - ** 2 - + (last_row[f"pm2_5_{f}_{s}_hour"] ** 2 * (s - 1)) - ) - / s - ) - elif f == "median": - new_row[f"pm2_5_{f}_{s}_hour"] = np.median( - np.append( - last_row["pm2_5"], last_row[f"pm2_5_{f}_{s}_hour"] - ) - ) - elif f == "skew": - new_row[f"pm2_5_{f}_{s}_hour"] = skew( - np.append( - last_row["pm2_5"], last_row[f"pm2_5_{f}_{s}_hour"] - ) - ) - - attributes = ["year", "month", "day", "dayofweek", "hour", "minute"] - for a in attributes: - new_row[a] = new_row["created_at"].__getattribute__(a) - new_row["week"] = new_row["created_at"].isocalendar().week - - new_row["pm2_5"] = model.predict( - new_row.drop(fixed_columns + ["created_at", "pm2_5"]).values.reshape( - 1, -1 - ) - )[0] - return new_row + def train_and_save_forecast_models(training_data, frequency): + """ + Perform the actual training for hourly data + """ + training_data.dropna( + subset=["device_id", "site_id", "device_category"], inplace=True + ) - forecasts = pd.DataFrame() - forecast_model = get_trained_model_from_gcs( - project_name, bucket_name, source_blob_name + training_data["device_id"] = 
training_data["device_id"].astype(int) + training_data["site_id"] = training_data["site_id"].astype(int) + training_data["device_category"] = training_data["device_category"].astype(int) + + training_data["timestamp"] = pd.to_datetime(training_data["timestamp"]) + features = [ + c + for c in training_data.columns + if c not in ["timestamp", "pm2_5", "latitude", "longitude"] + ] + print(features) + target_col = "pm2_5" + train_data = validation_data = test_data = pd.DataFrame() + for device in training_data["device_id"].unique(): + device_df = training_data[training_data["device_id"] == device] + months = device_df["timestamp"].dt.month.unique() + train_months = months[:8] + val_months = months[8:9] + test_months = months[9:] + train_df = device_df[device_df["timestamp"].dt.month.isin(train_months)] + val_df = device_df[device_df["timestamp"].dt.month.isin(val_months)] + test_df = device_df[device_df["timestamp"].dt.month.isin(test_months)] + train_data = pd.concat([train_data, train_df]) + validation_data = pd.concat([validation_data, val_df]) + test_data = pd.concat([test_data, test_df]) + + train_data.drop(columns=["timestamp"], axis=1, inplace=True) + validation_data.drop(columns=["timestamp"], axis=1, inplace=True) + test_data.drop(columns=["timestamp"], axis=1, inplace=True) + + train_target, validation_target, test_target = ( + train_data[target_col], + validation_data[target_col], + test_data[target_col], ) - df_tmp = data.copy() - for device in df_tmp["device_number"].unique(): - test_copy = df_tmp[df_tmp["device_number"] == device] - for i in range(int(configuration.HOURLY_FORECAST_HORIZON)): - new_row = get_new_row(test_copy, device, forecast_model) - test_copy = pd.concat( - [test_copy, new_row.to_frame().T], ignore_index=True + + sampler = optuna.samplers.TPESampler() + pruner = optuna.pruners.SuccessiveHalvingPruner( + min_resource=10, reduction_factor=2, min_early_stopping_rate=0 + ) + study = optuna.create_study( + direction="minimize", study_name="LGBM", sampler=sampler, pruner=pruner + ) + + def objective(trial): + param_grid = { + "colsample_bytree": trial.suggest_float("colsample_bytree", 0.1, 1), + "reg_alpha": trial.suggest_float("reg_alpha", 0, 10), + "reg_lambda": trial.suggest_float("reg_lambda", 0, 10), + "n_estimators": trial.suggest_categorical("n_estimators", [50]), + "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3), + "num_leaves": trial.suggest_int("num_leaves", 20, 50), + "max_depth": trial.suggest_int("max_depth", 4, 7), + } + score = 0 + for step in range(4): + lgb_reg = LGBMRegressor( + objective="regression", + random_state=42, + **param_grid, + verbosity=2, + ) + lgb_reg.fit( + train_data[features], + train_target, + categorical_feature=["device_id", "site_id", "device_category"], + eval_set=[(test_data[features], test_target)], + eval_metric="rmse", + callbacks=[early_stopping(stopping_rounds=150)], ) - forecasts = pd.concat([forecasts, test_copy], ignore_index=True) - forecasts["device_number"] = forecasts["device_number"].astype(int) - forecasts["pm2_5"] = forecasts["pm2_5"].astype(float) - forecasts.rename(columns={"created_at": "time"}, inplace=True) - forecasts["time"] = pd.to_datetime(forecasts["time"], utc=True) - current_time = datetime.utcnow() - current_time_utc = pd.Timestamp(current_time, tz="UTC") - result = forecasts[fixed_columns + ["time", "pm2_5", "device_number"]][ - forecasts["time"] >= current_time_utc - ] + val_preds = lgb_reg.predict(validation_data[features]) + score = mean_squared_error(validation_target, 
val_preds) + if trial.should_prune(): + raise optuna.TrialPruned() - return result + return score + + study.optimize(objective, n_trials=15) + + mlflow.set_tracking_uri(configuration.MLFLOW_TRACKING_URI) + mlflow.set_experiment(f"{frequency}_forecast_model_{environment}") + registered_model_name = f"{frequency}_forecast_model_{environment}" + + mlflow.lightgbm.autolog( + registered_model_name=registered_model_name, log_datasets=False + ) + with mlflow.start_run(): + best_params = study.best_params + print(f"Best params are {best_params}") + clf = LGBMRegressor( + n_estimators=best_params["n_estimators"], + learning_rate=best_params["learning_rate"], + colsample_bytree=best_params["colsample_bytree"], + reg_alpha=best_params["reg_alpha"], + reg_lambda=best_params["reg_lambda"], + max_depth=best_params["max_depth"], + random_state=42, + verbosity=2, + ) + + clf.fit( + train_data[features], + train_target, + eval_set=[(test_data[features], test_target)], + eval_metric="rmse", + categorical_feature=["device_id", "site_id", "device_category"], + callbacks=[early_stopping(stopping_rounds=150)], + ) + + GCSUtils.upload_trained_model_to_gcs( + clf, project_id, bucket, f"{frequency}_forecast_model.pkl" + ) + + # def create_error_df(data, target, preds): + # error_df = pd.DataFrame( + # { + # "actual_values": target, + # "predicted_values": preds, + # } + # ) + # error_df["errors"] = ( + # error_df["predicted_values"] - error_df["actual_values"] + # ) + # error_df = pd.concat([error_df, data], axis=1) + # error_df.drop(["actual_values", "pm2_5"], axis=1, inplace=True) + # error_df.rename(columns={"predicted_values": "pm2_5"}, inplace=True) + # + # return error_df + # + # error_df1 = create_error_df( + # train_data, train_target, clf.predict(train_data[features]) + # ) + # error_df2 = create_error_df( + # test_data, test_target, clf.predict(test_data[features]) + # ) + # + # error_features1 = [c for c in error_df1.columns if c not in ["errors"]] + # error_features2 = [c for c in error_df2.columns if c not in ["errors"]] + # + # error_target1 = error_df1["errors"] + # error_target2 = error_df2["errors"] + # + # error_clf = LGBMRegressor( + # n_estimators=31, + # colsample_bytree=1, + # learning_rate=0.1, + # metric="rmse", + # max_depth=5, + # random_state=42, + # verbosity=2, + # ) + # + # error_clf.fit( + # error_df1[error_features1], + # error_target1, + # eval_set=[(error_df2[error_features2], error_target2)], + # categorical_feature=["device_id", "site_id", "device_category"], + # callbacks=[early_stopping(stopping_rounds=150)], + # ) + # + # GCSUtils.upload_trained_model_to_gcs( + # error_clf, project_id, bucket, f"{frequency}_error_model.pkl" + # ) + + # TODO: quantile regression approach + # alphas = [0.025, 0.975] + # models = [] + # names = [ + # f"{frequency}_lower_quantile_model", + # f"{frequency}_upper_quantile_model", + # ] + # + # for alpha in alphas: + # clf = LGBMRegressor( + # n_estimators=best_params["n_estimators"], + # learning_rate=best_params["learning_rate"], + # colsample_bytree=best_params["colsample_bytree"], + # reg_alpha=best_params["reg_alpha"], + # reg_lambda=best_params["reg_lambda"], + # max_depth=best_params["max_depth"], + # random_state=42, + # verbosity=2, + # objective="quantile", + # alpha=alpha, + # metric="quantile", + # ) + # clf.fit( + # train_data[features], + # train_target, + # eval_set=[(test_data[features], test_target)], + # categorical_feature=["device_id", "site_id", "device_category"], + # ) + # models.append(clf) + # for n, m in zip(names, models): 
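# Note: train_and_save_forecast_models above wraps an LGBMRegressor objective in an
# Optuna study (TPE sampler + successive-halving pruner) and then refits with
# study.best_params. A compressed, self-contained sketch of that loop follows; the
# synthetic data and feature matrix are illustrative only and do not use the
# project's schema or categorical features.
import numpy as np
import optuna
from lightgbm import LGBMRegressor, early_stopping
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(42)
X = rng.normal(size=(500, 5))  # illustrative features
y = X @ rng.normal(size=5) + rng.normal(scale=0.1, size=500)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)


def objective(trial):
    params = {
        "colsample_bytree": trial.suggest_float("colsample_bytree", 0.1, 1),
        "reg_alpha": trial.suggest_float("reg_alpha", 0, 10),
        "reg_lambda": trial.suggest_float("reg_lambda", 0, 10),
        "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3),
        "num_leaves": trial.suggest_int("num_leaves", 20, 50),
        "max_depth": trial.suggest_int("max_depth", 4, 7),
    }
    model = LGBMRegressor(objective="regression", n_estimators=50, random_state=42, **params)
    model.fit(
        X_train,
        y_train,
        eval_set=[(X_val, y_val)],
        eval_metric="rmse",
        callbacks=[early_stopping(stopping_rounds=10)],
    )
    preds = model.predict(X_val)
    return mean_squared_error(y_val, preds)


study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(),
    pruner=optuna.pruners.SuccessiveHalvingPruner(),
)
study.optimize(objective, n_trials=15)
# Refit on the training split with the best hyperparameters found by the study.
best_model = LGBMRegressor(n_estimators=50, random_state=42, **study.best_params).fit(X_train, y_train)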
+ # upload_trained_model_to_gcs(m, project_id, bucket, f"{n}.pkl") @staticmethod - def generate_daily_forecasts(data, project_name, bucket_name, source_blob_name): - data["created_at"] = pd.to_datetime(data["created_at"]) - - def get_new_row(df_tmp, device, model): - last_row = df_tmp[df_tmp["device_number"] == device].iloc[-1] - new_row = pd.Series(index=last_row.index, dtype="float64") - for i in fixed_columns: - new_row[i] = last_row[i] - new_row["created_at"] = last_row["created_at"] + pd.Timedelta(days=1) - new_row["device_number"] = device - new_row[f"pm2_5_last_1_day"] = last_row["pm2_5"] - new_row[f"pm2_5_last_2_day"] = last_row[f"pm2_5_last_{1}_day"] - - shifts = [3, 7, 14, 30] - functions = ["mean", "std", "max", "min"] - for s in shifts: - for f in functions: - if f == "mean": - new_row[f"pm2_5_{f}_{s}_day"] = ( - last_row["pm2_5"] + last_row[f"pm2_5_{f}_{s}_day"] * (s - 1) - ) / s - elif f == "std": - new_row[f"pm2_5_{f}_{s}_day"] = ( - np.sqrt( - (last_row["pm2_5"] - last_row[f"pm2_5_mean_{s}_day"]) - ** 2 - + (last_row[f"pm2_5_{f}_{s}_day"] ** 2 * (s - 1)) - ) - / s - ) - elif f == "max": - new_row[f"pm2_5_{f}_{s}_day"] = max( - last_row["pm2_5"], last_row[f"pm2_5_{f}_{s}_day"] - ) - elif f == "min": - new_row[f"pm2_5_{f}_{s}_day"] = min( - last_row["pm2_5"], last_row[f"pm2_5_{f}_{s}_day"] - ) - - # Use the date of the new row to create other features - attributes = ["year", "month", "day", "dayofweek"] - for a in attributes: - new_row[a] = new_row["created_at"].__getattribute__(a) - new_row["week"] = new_row["created_at"].isocalendar().week - - new_row["pm2_5"] = model.predict( - new_row.drop(fixed_columns + ["created_at", "pm2_5"]).values.reshape( - 1, -1 + def generate_forecasts(data, project_name, bucket_name, frequency): + data = data.dropna(subset=["device_id"]) + data["timestamp"] = pd.to_datetime(data["timestamp"]) + data.columns = data.columns.str.strip() + # data["margin_of_error"] = data["adjusted_forecast"] = 0 + + def get_forecasts( + df_tmp, + forecast_model, + frequency, + horizon, + ): + """This method generates forecasts for a given device dataframe basing on horizon provided""" + for i in range(int(horizon)): + df_tmp = pd.concat([df_tmp, df_tmp.iloc[-1:]], ignore_index=True) + df_tmp_no_ts = df_tmp.drop("timestamp", axis=1, inplace=False) + # daily frequency + if frequency == "daily": + df_tmp.tail(1)["timestamp"] += timedelta(days=1) + shifts1 = [1, 2, 3, 7] + for s in shifts1: + df_tmp[f"pm2_5_last_{s}_day"] = df_tmp.shift(s, axis=0)["pm2_5"] + # rolling features + shifts2 = [2, 3, 7] + functions = ["mean", "std", "max", "min"] + for s in shifts2: + for f in functions: + df_tmp[f"pm2_5_{f}_{s}_day"] = ( + df_tmp_no_ts.shift(1, axis=0).rolling(s).agg(f) + )["pm2_5"] + + elif frequency == "hourly": + df_tmp.iloc[-1, df_tmp.columns.get_loc("timestamp")] = df_tmp.iloc[ + -2, df_tmp.columns.get_loc("timestamp") + ] + pd.Timedelta(hours=1) + + # lag features + shifts1 = [1, 2, 6, 12] + for s in shifts1: + df_tmp[f"pm2_5_last_{s}_hour"] = df_tmp.shift(s, axis=0)[ + "pm2_5" + ] + + # rolling features + shifts2 = [3, 6, 12, 24] + functions = ["mean", "std", "median", "skew"] + for s in shifts2: + for f in functions: + df_tmp[f"pm2_5_{f}_{s}_hour"] = ( + df_tmp_no_ts.shift(1, axis=0).rolling(s).agg(f) + )["pm2_5"] + + attributes = ["year", "month", "day", "dayofweek"] + max_vals = [2023, 12, 30, 7] + if frequency == "hourly": + attributes.append("hour") + max_vals.append(23) + for a, m in zip(attributes, max_vals): + df_tmp.tail(1)[f"{a}_sin"] = np.sin( + 2 + * 
np.pi + * df_tmp.tail(1)["timestamp"].dt.__getattribute__(a) + / m + ) + df_tmp.tail(1)[f"{a}_cos"] = np.cos( + 2 + * np.pi + * df_tmp.tail(1)["timestamp"].dt.__getattribute__(a) + / m + ) + df_tmp.tail(1)["week_sin"] = np.sin( + 2 * np.pi * df_tmp.tail(1)["timestamp"].dt.isocalendar().week / 52 + ) + df_tmp.tail(1)["week_cos"] = np.cos( + 2 * np.pi * df_tmp.tail(1)["timestamp"].dt.isocalendar().week / 52 ) - )[0] - return new_row - forecasts = pd.DataFrame() + excluded_columns = [ + "pm2_5", + "timestamp", + "latitude", + "longitude", + # "margin_of_error", + # "adjusted_forecast", + ] + # excluded_columns_2 = [ + # "timestamp", + # "margin_of_error", + # "adjusted_forecast", + # ] + df_tmp.loc[df_tmp.index[-1], "pm2_5"] = forecast_model.predict( + df_tmp.drop(excluded_columns, axis=1).tail(1).values.reshape(1, -1) + ) + # df_tmp.loc[df_tmp.index[-1], "margin_of_error"] = error_model.predict( + # df_tmp.drop(excluded_columns_2, axis=1) + # .tail(1) + # .values.reshape(1, -1) + # ) + # df_tmp.loc[df_tmp.index[-1], "adjusted_forecast"] = ( + # df_tmp.loc[df_tmp.index[-1], "pm2_5"] + # + df_tmp.loc[df_tmp.index[-1], "margin_of_error"] + # ) + + return df_tmp.iloc[-int(horizon) :, :] - forecast_model = get_trained_model_from_gcs( - project_name, bucket_name, source_blob_name + forecasts = pd.DataFrame() + forecast_model = GCSUtils.get_trained_model_from_gcs( + project_name, bucket_name, f"{frequency}_forecast_model.pkl" ) + # error_model = GCSUtils.get_trained_model_from_gcs( + # project_name, bucket_name, f"{frequency}_error_model.pkl" + # ) df_tmp = data.copy() - for device in df_tmp["device_number"].unique(): - test_copy = df_tmp[df_tmp["device_number"] == device] - for i in range(int(configuration.DAILY_FORECAST_HORIZON)): - new_row = get_new_row( - test_copy, - device, - forecast_model, - ) - test_copy = pd.concat( - [test_copy, new_row.to_frame().T], ignore_index=True - ) - forecasts = pd.concat([forecasts, test_copy], ignore_index=True) - forecasts["device_number"] = forecasts["device_number"].astype(int) + for device in df_tmp["device_id"].unique(): + test_copy = df_tmp[df_tmp["device_id"] == device] + horizon = ( + configuration.HOURLY_FORECAST_HORIZON + if frequency == "hourly" + else configuration.DAILY_FORECAST_HORIZON + ) + device_forecasts = get_forecasts( + test_copy, + forecast_model, + frequency, + horizon, + ) + + forecasts = pd.concat([forecasts, device_forecasts], ignore_index=True) + forecasts["pm2_5"] = forecasts["pm2_5"].astype(float) - forecasts.rename(columns={"created_at": "time"}, inplace=True) - current_time = datetime.utcnow() - current_time_utc = pd.Timestamp(current_time, tz="UTC") - result = forecasts[fixed_columns + ["time", "pm2_5", "device_number"]][ - forecasts["time"] >= current_time_utc + # forecasts["margin_of_error"] = forecasts["margin_of_error"].astype(float) + + DecodingUtils.decode_categorical_features_before_save(forecasts, frequency) + forecasts = forecasts[ + [ + "device_id", + "site_id", + "timestamp", + "pm2_5", + # "margin_of_error", + # "adjusted_forecast", + ] ] - - return result + return forecasts @staticmethod def save_forecasts_to_mongo(data, frequency): + device_ids = data["device_id"].unique() created_at = pd.to_datetime(datetime.now()).isoformat() - device_numbers = data["device_number"].unique() - forecast_results = [ - { - field: data[data["device_number"] == i][field].tolist()[0] - if field != "pm2_5" and field != "time" and field != "health_tips" - else data[data["device_number"] == i][field].tolist() - for field in data.columns + 
+ forecast_results = [] + for i in device_ids: + doc = { + "device_id": i, + "created_at": created_at, + "pm2_5": data[data["device_id"] == i]["pm2_5"].tolist(), + "timestamp": data[data["device_id"] == i]["timestamp"].tolist(), } - | {"created_at": created_at} - for i in device_numbers - ] - client = pm.MongoClient(configuration.MONGO_URI) - db = client[configuration.MONGO_DATABASE_NAME] + forecast_results.append(doc) + if frequency == "hourly": - db.hourly_forecasts.insert_many(forecast_results) + collection = db.hourly_forecasts elif frequency == "daily": - db.daily_forecasts.insert_many(forecast_results) + collection = db.daily_forecasts else: raise ValueError("Invalid frequency argument") + + for doc in forecast_results: + try: + filter_query = {"device_id": doc["device_id"]} + update_query = { + "$set": { + "pm2_5": doc["pm2_5"], + "timestamp": doc["timestamp"], + "created_at": doc["created_at"], + } + } + collection.update_one(filter_query, update_query, upsert=True) + except Exception as e: + print( + f"Failed to update forecast for device {doc['device_id']}: {str(e)}" + ) diff --git a/src/airflow/airqo_etl_utils/plume_labs_api.py b/src/airflow/airqo_etl_utils/plume_labs_api.py index b3ba62e0ac..def07e7114 100644 --- a/src/airflow/airqo_etl_utils/plume_labs_api.py +++ b/src/airflow/airqo_etl_utils/plume_labs_api.py @@ -182,24 +182,25 @@ def __request(self, endpoint, params): total=5, backoff_factor=5, ) - + http = urllib3.PoolManager(retries=retry_strategy) - + try: response = http.request( - "GET", - url, - fields=params,) - + "GET", + url, + fields=params, + ) + response_data = response.data print(response._request_url) - + if response.status == 200: return json.loads(response_data) else: Utils.handle_api_error(response) return None - + except urllib3.exceptions.HTTPError as e: print(f"HTTPError: {e}") return None diff --git a/src/airflow/airqo_etl_utils/purple_air_api.py b/src/airflow/airqo_etl_utils/purple_air_api.py index b6dd0ec4d4..025f26283d 100644 --- a/src/airflow/airqo_etl_utils/purple_air_api.py +++ b/src/airflow/airqo_etl_utils/purple_air_api.py @@ -32,31 +32,31 @@ def get_data( return response if response else {} def __request(self, endpoint, params): - url = f"{self.PURPLE_AIR_BASE_URL}{endpoint}" retry_strategy = Retry( total=5, backoff_factor=5, ) - + http = urllib3.PoolManager(retries=retry_strategy) - + try: response = http.request( - "GET", - url, + "GET", + url, fields=params, - headers={"x-api-key": self.PURPLE_AIR_API_KEY},) - + headers={"x-api-key": self.PURPLE_AIR_API_KEY}, + ) + response_data = response.data print(response._request_url) - + if response.status == 200: return json.loads(response_data) else: Utils.handle_api_error(response) return None - + except urllib3.exceptions.HTTPError as e: print(f"HTTPError: {e}") return None diff --git a/src/airflow/airqo_etl_utils/tahmo_api.py b/src/airflow/airqo_etl_utils/tahmo_api.py index 9f9db8a7f8..07a3e771eb 100644 --- a/src/airflow/airqo_etl_utils/tahmo_api.py +++ b/src/airflow/airqo_etl_utils/tahmo_api.py @@ -54,32 +54,29 @@ def get_measurements(self, start_time, end_time, station_codes=None): return measurements.to_dict(orient="records") def __request(self, endpoint, params): - url = f"{self.BASE_URL}{endpoint}" retry_strategy = Retry( total=5, backoff_factor=5, ) - + http = urllib3.PoolManager(retries=retry_strategy) - + try: - headers = urllib3.util.make_headers(basic_auth=f"{self.API_KEY}:{self.API_SECRET}") - response = http.request( - "GET", - url, - fields=params, - headers=headers) - + headers = 
urllib3.util.make_headers( + basic_auth=f"{self.API_KEY}:{self.API_SECRET}" + ) + response = http.request("GET", url, fields=params, headers=headers) + response_data = response.data print("Tahmo API request: %s" % response._request_url) - + if response.status == 200: return json.loads(response_data) else: Utils.handle_api_error(response) return None - + except urllib3.exceptions.HTTPError as e: print(f"HTTPError: {e}") return None diff --git a/src/airflow/airqo_etl_utils/tests/airqo_utils_tests.py b/src/airflow/airqo_etl_utils/tests/airqo_utils_tests.py index 6e3792f74e..0d1541df67 100644 --- a/src/airflow/airqo_etl_utils/tests/airqo_utils_tests.py +++ b/src/airflow/airqo_etl_utils/tests/airqo_utils_tests.py @@ -10,7 +10,6 @@ from airqo_etl_utils.tests.conftest import FaultDetectionFixtures -# TODO: Convert to pytest class TestAirQoDataUtils(unittest.TestCase): def test_map_site_ids_to_historical_data(self): logs = pd.DataFrame( diff --git a/src/airflow/airqo_etl_utils/tests/big_query_api_tests.py b/src/airflow/airqo_etl_utils/tests/big_query_api_tests.py index e519ad033d..2be61e9415 100644 --- a/src/airflow/airqo_etl_utils/tests/big_query_api_tests.py +++ b/src/airflow/airqo_etl_utils/tests/big_query_api_tests.py @@ -1,4 +1,6 @@ # Import pytest and other modules as needed +from unittest import mock + import pandas as pd import pytest @@ -6,7 +8,7 @@ @pytest.fixture -def mock_bigquery_client(mocker): +def mock_bigquery_client1(mocker): mock_client = mocker.Mock() mock_client.query.return_value.result.return_value.to_dataframe.return_value = ( pd.DataFrame( @@ -21,66 +23,110 @@ def mock_bigquery_client(mocker): return mock_client +@pytest.fixture +def mock_bigquery_client2(): + """A fixture that mocks the bigquery.Client object.""" + + fake_client = mock.Mock() + + sample_df = pd.DataFrame( + { + "device_id": ["A", "A", "B", "B"], + "timestamp": [ + "2023-01-01 00:00:00", + "2023-01-01 01:00:00", + "2023-01-01 00:00:00", + "2023-01-01 01:00:00", + ], + "site_id": [1, 1, 2, 2], + "pm2_5": [10.0, 12.0, 15.0, 18.0], + "latitude": [10.0, 10.0, 20.0, 20.0], + "longitude": [10.0, 10.0, 20.0, 20.0], + "device_category": ["A", "A", "B", "B"], + } + ) + + fake_data_empty_result = pd.DataFrame() + + fake_error = "Fake error" + + def fake_query(query, job_config): + fake_job = mock.Mock() + + if "2023-01-01" in query: + fake_job.result.return_value.to_dataframe.return_value = sample_df + elif "2023-01-02" in query: + fake_job.result.return_value.to_dataframe.return_value = ( + fake_data_empty_result + ) + elif "2023-01-03" in query: + fake_job.result.side_effect = fake_error + else: + raise ValueError("Invalid date") + + return fake_job + + fake_client.query.side_effect = fake_query + + return fake_client + + @pytest.mark.parametrize( - "method", + "start_date_time, expected_df", [ - BigQueryApi.fetch_hourly_forecast_training_data, - BigQueryApi.fetch_daily_forecast_training_data, + ( + "2023-01-01", + pd.DataFrame( + { + "device_id": ["A", "A", "B", "B"], + "timestamp": [ + "2023-01-01 00:00:00", + "2023-01-01 01:00:00", + "2023-01-01 00:00:00", + "2023-01-01 01:00:00", + ], + "site_id": [1, 1, 2, 2], + "pm2_5": [10.0, 12.0, 15.0, 18.0], + "latitude": [10.0, 10.0, 20.0, 20.0], + "longitude": [10.0, 10.0, 20.0, 20.0], + "device_category": ["A", "A", "B", "B"], + } + ), + ), + ("2023-01-02", pd.DataFrame()), ], ) -def test_fetch_data_columns(method, mock_bigquery_client): - api = BigQueryApi() - api.client = mock_bigquery_client - df = method(api) - assert list(df.columns) == ["created_at", 
"device_number", "pm2_5"] - assert isinstance(df, pd.DataFrame) - assert not df.empty +def test_fetch_data_correct_se(mock_bigquery_client2, start_date_time, expected_df): + """Tests the fetch_data method for scenarios when correct data is retrieved.""" + bq_api = BigQueryApi() + bq_api.client = mock_bigquery_client2 -def test_fetch_hourly_forecast_training_data_exception(mock_bigquery_client): - api = BigQueryApi() - api.client = mock_bigquery_client - api.client.query.side_effect = Exception("Bigquery error") - with pytest.raises(Exception) as e: - df = api.fetch_hourly_forecast_training_data() - assert "Bigquery error" in str(e.value) + actual_df = bq_api.fetch_data(start_date_time) + pd.testing.assert_frame_equal(actual_df, expected_df) -def test_fetch_hourly_forecast_training_data_null(): - api = BigQueryApi() - api.client = mock_bigquery_client() - api.client.query.return_value.result.return_value.to_dataframe.return_value = ( - pd.DataFrame( - { - "created_at": ["2021-01-01 00:00:00", "2021-01-01 01:00:00"], - "device_number": [1, 2], - "pm2_5": [None, None], - } - ) - ) - with pytest.raises(Exception) as e: - df = api.fetch_hourly_forecast_training_data() - assert "pm2_5 column cannot be null" in str(e.value) +@pytest.mark.parametrize("start_date_time", ["2023-13-01", "2023-01-32", "invalid"]) +def test_fetch_data_invalid_date(mock_bigquery_client2, start_date_time): + """Tests the fetch_data method for the scenario where an invalid date string is passed.""" + bq_api = BigQueryApi() + bq_api.client = mock_bigquery_client2 -def test_fetch_daily_forecast_training_data_date_range(mock_bigquery_client): - api = BigQueryApi() - api.client = mock_bigquery_client - api.client.query.return_value.result.return_value.to_dataframe.return_value = ( - pd.DataFrame( - { - "created_at": [ - "2020-01-01 00:00:00", - "2020-06-01 00:00:00", - "2020-12-01 00:00:00", - ], - "device_number": [1, 2, 3], - "pm2_5": [10, 20, 30], - } - ) - ) - df = api.fetch_daily_forecast_training_data() - assert df["created_at"].min() >= pd.Timestamp.now() - pd.DateOffset(months=12) + with pytest.raises(ValueError): + bq_api.fetch_data(start_date_time) + + +@pytest.mark.parametrize("start_date_time", ["2023-01-03"]) +def test_fetch_data_bigquery_error(mock_bigquery_client2, start_date_time): + """Tests the fetch_data method for the scenario where a bigquery.GoogleAPIError is raised.""" + + # Create an instance of BigQueryApi with the mocked client + bq_api = BigQueryApi() + bq_api.client = mock_bigquery_client2 + + with pytest.raises(Exception): + bq_api.fetch_data(start_date_time) def test_fetch_raw_readings_empty(mock_bigquery_client): diff --git a/src/airflow/airqo_etl_utils/tests/conftest.py b/src/airflow/airqo_etl_utils/tests/conftest.py index 8ec08842b2..cfb5b15bb5 100644 --- a/src/airflow/airqo_etl_utils/tests/conftest.py +++ b/src/airflow/airqo_etl_utils/tests/conftest.py @@ -1,7 +1,11 @@ +from datetime import datetime +from unittest.mock import MagicMock + import numpy as np import pandas as pd import pytest -from datetime import datetime + +from airqo_etl_utils.config import configuration def pytest_configure(config): @@ -13,82 +17,89 @@ def pytest_configure(config): class ForecastFixtures: @staticmethod @pytest.fixture(scope="session") - def hourly_data(): - return pd.DataFrame( + def preprocessing_sample_df(): + data = pd.DataFrame( { - "device_number": [1, 1, 1, 2, 2, 2], - "created_at": [ - "2021-08-01 00:00:00", - "2021-08-01 01:00:00", - "2021-08-01 02:00:00", - "2021-08-01 00:00:00", - "2021-08-01 
01:00:00", - "2021-08-01 02:00:00", - ], - "pm2_5": [10.0, np.nan, 12.0, 15.0, np.nan, np.nan], + "device_id": ["A", "B"], + "site_id": ["X", "Y"], + "device_category": ["LOWCOST", "BAM"], + "pm2_5": [1, 2], + "timestamp": ["2023-01-01", "2023-02-01"], } ) + return data @staticmethod - @pytest.fixture(scope="session") - def daily_data(): - return pd.DataFrame( - { - "device_number": [1, 1, 1, 2, 2, 2], - "created_at": [ - "2021-08-01 00:00:00", - "2021-08-02 00:00:00", - "2021-08-03 00:00:00", - "2021-08-01 00:00:00", - "2021-08-02 00:00:00", - "2021-08-03 00:00:00", - ], - "pm2_5": [10.0, np.nan, 12.0, 15.0, np.nan, np.nan], - } - ) + @pytest.fixture + def feat_eng_sample_df_daily(): + data = { + "timestamp": pd.date_range(end=pd.Timestamp.now(), periods=365).tolist(), + "device_id": ["device1"] * 365, + "pm2_5": range(1, 366), + } + return pd.DataFrame(data) @staticmethod - @pytest.fixture(scope="session") - def hourly_output(): + @pytest.fixture + def feat_eng_sample_df_hourly(): + data = { + "timestamp": pd.date_range( + end=pd.Timestamp.now(), periods=24 * 14, freq="H" + ).tolist(), + "device_id": ["device1"] * 24 * 14, + "pm2_5": range(1, 24 * 14 + 1), + } + return pd.DataFrame(data) + + @staticmethod + @pytest.fixture + def sample_dataframe_for_location_features(): + data = { + "timestamp": pd.date_range(end=pd.Timestamp.now(), periods=100).tolist(), + "device_id": ["device1"] * 100, + "latitude": np.random.uniform(-90, 90, 100), + "longitude": np.random.uniform(-180, 180, 100), + } + return pd.DataFrame(data) + + @staticmethod + @pytest.fixture + def sample_hourly_forecast_data(): return pd.DataFrame( { - "device_number": [1, 1, 1, 2, 2, 2], - "created_at": [ - "2021-08-01 00:00:00", - "2021-08-01 01:00:00", - "2021-08-01 02:00:00", - "2021-08-01 00:00:00", - "2021-08-01 01:00:00", - "2021-08-01 02:00:00", + "device_id": ["dev1", "dev1", "dev2"], + "pm2_5": [10, 15, 20], + "timestamp": [ + datetime(2023, 1, 1, 0), + datetime(2023, 1, 1, 1), + datetime(2023, 1, 1, 2), ], - "pm2_5": [10.0, 11.0, 12.0, 15.0, 16.0, 17.0], } ) @staticmethod - @pytest.fixture(scope="session") - def daily_output(): + @pytest.fixture + def sample_daily_forecast_data(): return pd.DataFrame( { - "device_number": [1, 1, 1, 2, 2, 2], - "created_at": [ - "2021-08-01 00:00:00", - "2021-08-02 00:00:00", - "2021-08-03 00:00:00", - "2021-08-01 00:00:00", - "2021-08-02 00:00:00", - "2021-08-03 00:00:00", + "device_id": ["dev1", "dev1", "dev2"], + "pm2_5": [10, 15, 20], + "timestamp": [ + datetime(2023, 1, 1), + datetime(2023, 1, 2), + datetime(2023, 1, 3), ], - "pm2_5": [10.0, 11.0, 12.0, 15.0, 16.0, 17.0], } ) - -@pytest.fixture(scope="session") -def mongo_fixture(): - from airqo_etl_utils.mongo_client import MongoClient - - return MongoClient(uri="mongodb://localhost:27017", db_name="test_db") + @staticmethod + @pytest.fixture + def mock_db(): + mock_client = MagicMock() + mock_db = mock_client[configuration.MONGO_DATABASE_NAME] + mock_db.hourly_forecasts = MagicMock() + mock_db.daily_forecasts = MagicMock() + return mock_db class FaultDetectionFixtures: diff --git a/src/airflow/airqo_etl_utils/tests/ml_utils_tests.py b/src/airflow/airqo_etl_utils/tests/ml_utils_tests.py index e28e592cea..d03b02d7bc 100644 --- a/src/airflow/airqo_etl_utils/tests/ml_utils_tests.py +++ b/src/airflow/airqo_etl_utils/tests/ml_utils_tests.py @@ -1,32 +1,159 @@ -# TODO: Add tests for ml_utils.py - import pandas as pd +import pytest -from airqo_etl_utils.ml_utils import ForecastUtils +from airqo_etl_utils.ml_utils import ForecastUtils as 
FUtils from airqo_etl_utils.tests.conftest import ForecastFixtures -class ForecastTests(ForecastFixtures): - def test_preprocess_hourly_training_data(self, hourly_data, hourly_output): - assert isinstance( - ForecastUtils.preprocess_hourly_training_data(hourly_data), pd.DataFrame - ) - assert ( - ForecastUtils.preprocess_hourly_training_data(hourly_data).shape[0] - == hourly_output.shape[0] +class TestsForecasts(ForecastFixtures): + # Preprocess data tests + def test_preprocess_data_typical_case(self, preprocessing_sample_df): + result = FUtils.preprocess_data(preprocessing_sample_df, "daily") + assert "pm2_5" in result.columns + + def test_preprocess_data_invalid_input(self, preprocessing_sample_df): + df = preprocessing_sample_df.drop(columns=["device_id"]) + with pytest.raises(ValueError): + FUtils.preprocess_data(df, "daily") + + def test_preprocess_data_invalid_timestamp(self, preprocessing_sample_df): + df = preprocessing_sample_df.copy() + df["timestamp"] = "invalid" + with pytest.raises(ValueError): + FUtils.preprocess_data(df, "daily") + + # Feature engineering tests + # get_lag_and_rolling_features tests + + def test_empty_df(self): + with pytest.raises(ValueError, match="Empty dataframe provided"): + FUtils.get_lag_and_roll_features(pd.DataFrame(), "pm2_5", "daily") + + def test_missing_columns(self, feat_eng_sample_df_daily): + del feat_eng_sample_df_daily[ + "device_id" + ] # Test for case where 'device_id' is missing + with pytest.raises(ValueError, match="Required columns missing"): + FUtils.get_lag_and_roll_features(feat_eng_sample_df_daily, "pm2_5", "daily") + + def test_invalid_frequency(self, feat_eng_sample_df_daily): + with pytest.raises(ValueError, match="Invalid frequency"): + FUtils.get_lag_and_roll_features( + feat_eng_sample_df_daily, "pm2_5", "annually" + ) + + def test_hourly_freq(self, sample_hourly_dataframe): + hourly_df = FUtils.get_lag_and_roll_features( + sample_hourly_dataframe, "pm2_5", "hourly" ) - assert ForecastUtils.preprocess_hourly_training_data(hourly_data)[ - "pm2_5" - ].equals(hourly_output["pm2_5"]) + for s in [1, 2, 6, 12]: + assert f"pm2_5_last_{s}_hour" in hourly_df.columns + for s in [3, 6, 12, 24]: + for f in ["mean", "std", "median", "skew"]: + assert f"pm2_5_{f}_{s}_hour" in hourly_df.columns - def test_preprocess_daily_training_data(self, daily_data, daily_output): - assert isinstance( - ForecastUtils.preprocess_daily_training_data(daily_data), pd.DataFrame + def test_daily_freq(self, feat_eng_sample_df_daily): + daily_df = FUtils.get_lag_and_roll_features( + feat_eng_sample_df_daily, "pm2_5", "daily" ) - assert ( - ForecastUtils.preprocess_daily_training_data(daily_data).shape[0] - == daily_output.shape[0] + for s in [1, 2, 3, 7, 14]: + assert f"pm2_5_last_{s}_day" in daily_df.columns + for s in [2, 3, 7, 14]: + for f in ["mean", "std", "max", "min"]: + assert f"pm2_5_{f}_{s}_day" in daily_df.columns + + def test_empty_df_for_time_and_cyclic_features(self): + with pytest.raises(ValueError, match="Empty dataframe provided"): + FUtils.get_time_and_cyclic_features(pd.DataFrame(), "daily") + + def test_missing_columns_for_time_and_cyclic_features( + self, feat_eng_sample_df_daily + ): + with pytest.raises(ValueError, match="Required columns missing"): + FUtils.get_time_and_cyclic_features(feat_eng_sample_df_daily, "daily") + + def test_invalid_frequency_for_time_and_cyclic_features( + self, feat_eng_sample_df_daily + ): + with pytest.raises(ValueError, match="Invalid frequency"): + 
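+            # "annually" is outside the supported frequency values ("hourly" / "daily"),
+            # so get_time_and_cyclic_features is expected to reject it with a ValueError.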
FUtils.get_time_and_cyclic_features(feat_eng_sample_df_daily, "annually") + + # For 'daily' frequency + def test_daily_freq_for_time_and_cyclic_features(self, feat_eng_sample_df_daily): + daily_df = FUtils.get_time_and_cyclic_features( + feat_eng_sample_df_daily, "daily" ) - assert ForecastUtils.preprocess_daily_training_data(daily_data)["pm2_5"].equals( - daily_output["pm2_5"] + for a in ["year", "month", "day", "dayofweek", "week"]: + for t in ["_sin", "_cos"]: + assert f"{a}{t}" in daily_df.columns + + # For 'hourly' frequency + def test_hourly_freq_for_time_and_cyclic_features(self, feat_eng_sample_df_hourly): + hourly_df = FUtils.get_time_and_cyclic_features( + feat_eng_sample_df_hourly, "hourly" ) + for a in ["year", "month", "day", "dayofweek", "hour", "week"]: + for t in ["_sin", "_cos"]: + assert f"{a}{t}" in hourly_df.columns + + def test_empty_df_for_location_features( + self, sample_dataframe_for_location_features + ): + with pytest.raises(ValueError, match="Empty dataframe provided"): + FUtils.get_location_features(pd.DataFrame()) + + def test_missing_timestamp_for_location_features( + self, + sample_dataframe_for_location_features, + ): + del sample_dataframe_for_location_features["timestamp"] + with pytest.raises(ValueError, match="timestamp column is missing"): + FUtils.get_location_features(sample_dataframe_for_location_features) + + # For missing 'latitude' column + def test_missing_latitude_for_location_features( + self, sample_dataframe_for_location_features + ): + del sample_dataframe_for_location_features[ + "latitude" + ] # Test for missing 'latitude' + with pytest.raises(ValueError, match="latitude column is missing"): + FUtils.get_location_features(sample_dataframe_for_location_features) + + def test_missing_longitude_for_location_features( + self, sample_dataframe_for_location_features + ): + del sample_dataframe_for_location_features[ + "longitude" + ] # Test for missing 'longitude' + with pytest.raises(ValueError, match="longitude column is missing"): + FUtils.get_location_features(sample_dataframe_for_location_features) + + # Test the normal procedure + def test_get_location_features(self, sample_dataframe_for_location_features): + df = FUtils.get_location_features(sample_dataframe_for_location_features) + for cord in ["x_cord", "y_cord", "z_cord"]: + assert cord in df.columns + + @pytest.mark.xfail + @pytest.mark.parametrize( + "frequency,collection_name", + [ + ("hourly", "hourly_forecasts"), + ("daily", "daily_forecasts"), + # ("invalid", None), + ], + ) + def test_save_forecasts_to_mongo_frequency( + self, mock_db, frequency, collection_name, sample_dataframe_db + ): + if frequency == "invalid": + # Expect a ValueError for an invalid frequency + with pytest.raises(ValueError) as e: + FUtils.save_forecasts_to_mongo(sample_dataframe_db, frequency) + assert str(e.value) == f"Invalid frequency argument: {frequency}" + else: + # Expect no exception for a valid frequency + FUtils.save_forecasts_to_mongo(sample_dataframe_db, frequency) + mock_collection = getattr(mock_db, collection_name) + assert mock_collection.update_one.call_count == 0 diff --git a/src/airflow/dags/data_warehouse.py b/src/airflow/dags/data_warehouse.py index af9c20a70b..09316c73e1 100644 --- a/src/airflow/dags/data_warehouse.py +++ b/src/airflow/dags/data_warehouse.py @@ -142,7 +142,6 @@ def load(data: pd.DataFrame): load(clean_consolidated_data) - @dag( "Historical-Consolidated-Data-ETL", schedule=None, @@ -185,7 +184,7 @@ def extract_hourly_weather_data(**kwargs): from airqo_etl_utils.date 
import DateUtils start_date_time, end_date_time = DateUtils.get_dag_date_time_values( - historical=True, **kwargs + historical=True, **kwargs ) return DataWarehouseUtils.extract_hourly_weather_data( @@ -238,6 +237,7 @@ def load(data: pd.DataFrame): ) load(merged_data) + @dag( "Historical-Cleanup-Consolidated-Data", schedule=None, @@ -280,6 +280,7 @@ def load(data: pd.DataFrame): clean_consolidated_data = remove_duplicates(consolidated_data) load(clean_consolidated_data) + data_warehouse_consolidated_data() data_warehouse_cleanup_consolidated_data() data_warehouse_historical_consolidated_data() diff --git a/src/airflow/dags/ml_prediction_jobs.py b/src/airflow/dags/ml_prediction_jobs.py index 2f48d19d68..f68a8d4c8d 100644 --- a/src/airflow/dags/ml_prediction_jobs.py +++ b/src/airflow/dags/ml_prediction_jobs.py @@ -3,7 +3,7 @@ from airqo_etl_utils.airflow_custom_utils import AirflowUtils from airqo_etl_utils.bigquery_api import BigQueryApi from airqo_etl_utils.config import configuration -from airqo_etl_utils.ml_utils import ForecastUtils +from airqo_etl_utils.ml_utils import ForecastUtils, DecodingUtils @dag( @@ -27,24 +27,32 @@ def get_historical_data_for_hourly_forecasts(): from airqo_etl_utils.date import date_to_str start_date = date_to_str(start_date, str_format="%Y-%m-%d") - return BigQueryApi().fetch_data(start_date, historical=True) + return BigQueryApi().fetch_data(start_date) @task() def preprocess_historical_data_hourly_forecast(data): - return ForecastUtils.preprocess_historical_data(data, "hourly") + return ForecastUtils.preprocess_data(data, "hourly") + + @task + def generate_lag_and_rolling_features_hourly_forecast(data): + return ForecastUtils.get_lag_and_roll_features(data, "pm2_5", "hourly") @task() - def add_lag_features_historical_data_hourly_forecast(data): - return ForecastUtils.get_lag_features(data, "pm2_5", frequency="hourly") + def get_time_and_cyclic_features_hourly_forecast(data): + return ForecastUtils.get_time_and_cyclic_features(data, "hourly") - @task - def add_timestep_features_historical_data_hourly_forecasts(data): - return ForecastUtils.get_time_features(data, frequency="hourly") + @task() + def get_location_features_hourly_forecast(data): + return ForecastUtils.get_location_features(data) + + @task() + def encode_hourly_categorical_features(data): + return DecodingUtils.decode_categorical_features_pred(data, "hourly") @task() def make_hourly_forecasts(data): - return ForecastUtils.generate_hourly_forecasts( - data, project_id, bucket, "hourly_forecast_model.pkl" + return ForecastUtils.generate_forecasts( + data=data, project_name=project_id, bucket_name=bucket, frequency="hourly" ) @task() @@ -67,24 +75,32 @@ def get_historical_data_for_daily_forecasts(): days=int(configuration.DAILY_FORECAST_PREDICTION_JOB_SCOPE) ) start_date = date_to_str(start_date, str_format="%Y-%m-%d") - return BigQueryApi().fetch_data(start_date, historical=True) + return BigQueryApi().fetch_data(start_date) @task() def preprocess_historical_data_daily_forecast(data): - return ForecastUtils.preprocess_historical_data(data, "daily") + return ForecastUtils.preprocess_data(data, "daily") @task() - def add_lag_features_historical_data_daily_forecast(data): - return ForecastUtils.get_lag_features(data, "pm2_5", frequency="daily") + def generate_lag_and_rolling_features_daily_forecast(data): + return ForecastUtils.get_lag_and_roll_features(data, "pm2_5", "daily") @task() - def add_timestep_features_historical_data_daily_forecast(data): - return ForecastUtils.get_time_features(data, 
"daily") + def get_time_and_cyclic_features_daily_forecast(data): + return ForecastUtils.get_time_and_cyclic_features(data, "daily") + + @task() + def get_location_features_daily_forecast(data): + return ForecastUtils.get_location_features(data) + + @task() + def encode_daily_categorical_features(data): + return DecodingUtils.decode_categorical_features_pred(data, "daily") @task() def make_daily_forecasts(data): - return ForecastUtils.generate_daily_forecasts( - data, project_id, bucket, "daily_forecast_model.pkl" + return ForecastUtils.generate_forecasts( + data=data, project_name=project_id, bucket_name=bucket, frequency="daily" ) @task() @@ -97,27 +113,39 @@ def save_daily_forecasts_to_bigquery(data): def save_daily_forecasts_to_mongo(data): ForecastUtils.save_forecasts_to_mongo(data, "daily") + # Hourly forecast pipeline hourly_data = get_historical_data_for_hourly_forecasts() - preprocessed_hourly_data = preprocess_historical_data_hourly_forecast(hourly_data) - lagged_hourly_data = add_lag_features_historical_data_hourly_forecast( - preprocessed_hourly_data + hourly_preprocessed_data = preprocess_historical_data_hourly_forecast(hourly_data) + hourly_lag_and_roll_features = generate_lag_and_rolling_features_hourly_forecast( + hourly_preprocessed_data ) - time_features_hourly_data = add_timestep_features_historical_data_hourly_forecasts( - lagged_hourly_data + hourly_time_and_cyclic_features = get_time_and_cyclic_features_hourly_forecast( + hourly_lag_and_roll_features ) - hourly_forecasts = make_hourly_forecasts(time_features_hourly_data) + hourly_location_features = get_location_features_hourly_forecast( + hourly_time_and_cyclic_features + ) + hourly_encoded_features = encode_hourly_categorical_features( + hourly_location_features + ) + hourly_forecasts = make_hourly_forecasts(hourly_encoded_features) save_hourly_forecasts_to_bigquery(hourly_forecasts) save_hourly_forecasts_to_mongo(hourly_forecasts) + # Daily forecast pipeline daily_data = get_historical_data_for_daily_forecasts() - preprocessed_daily_data = preprocess_historical_data_daily_forecast(daily_data) - lagged_daily_data = add_lag_features_historical_data_daily_forecast( - preprocessed_daily_data + daily_preprocessed_data = preprocess_historical_data_daily_forecast(daily_data) + daily_lag_and_roll_features = generate_lag_and_rolling_features_daily_forecast( + daily_preprocessed_data + ) + daily_time_and_cyclic_features = get_time_and_cyclic_features_daily_forecast( + daily_lag_and_roll_features ) - time_features_daily_data = add_timestep_features_historical_data_daily_forecast( - lagged_daily_data + daily_location_features = get_location_features_daily_forecast( + daily_time_and_cyclic_features ) - daily_forecasts = make_daily_forecasts(time_features_daily_data) + daily_encoded_features = encode_daily_categorical_features(daily_location_features) + daily_forecasts = make_daily_forecasts(daily_encoded_features) save_daily_forecasts_to_bigquery(daily_forecasts) save_daily_forecasts_to_mongo(daily_forecasts) diff --git a/src/airflow/dags/ml_training_jobs.py b/src/airflow/dags/ml_training_jobs.py index 180f7f7ef2..40c563e86e 100644 --- a/src/airflow/dags/ml_training_jobs.py +++ b/src/airflow/dags/ml_training_jobs.py @@ -1,10 +1,13 @@ +from datetime import datetime + from airflow.decorators import dag, task +from dateutil.relativedelta import relativedelta from airqo_etl_utils.airflow_custom_utils import AirflowUtils from airqo_etl_utils.bigquery_api import BigQueryApi from airqo_etl_utils.config import configuration from 
airqo_etl_utils.date import date_to_str -from airqo_etl_utils.ml_utils import ForecastUtils +from airqo_etl_utils.ml_utils import ForecastUtils, DecodingUtils @dag( @@ -15,11 +18,9 @@ tags=["airqo", "hourly-forecast", "daily-forecast", "training-job"], ) def train_forecasting_models(): + # Hourly forecast tasks @task() def fetch_training_data_for_hourly_forecast_model(): - from dateutil.relativedelta import relativedelta - from datetime import datetime - current_date = datetime.today() start_date = current_date - relativedelta( months=int(configuration.HOURLY_FORECAST_TRAINING_JOB_SCOPE) @@ -29,16 +30,31 @@ def fetch_training_data_for_hourly_forecast_model(): @task() def preprocess_training_data_for_hourly_forecast_model(data): - return ForecastUtils.preprocess_training_data(data, "hourly") + return ForecastUtils.preprocess_data(data, "hourly") @task() - def feature_engineer_training_data_for_hourly_forecast_model(data): - return ForecastUtils.feature_eng_training_data(data, "pm2_5", "hourly") + def get_hourly_lag_and_rolling_features(data): + return ForecastUtils.get_lag_and_roll_features(data, "pm2_5", "hourly") + + @task() + def get_hourly_time_and_cyclic_features(data): + return ForecastUtils.get_time_and_cyclic_features(data, "hourly") + + @task() + def get_location_features(data): + return ForecastUtils.get_location_features(data) + + @task() + def encode_categorical_features(data): + return DecodingUtils.encode_categorical_training_features(data, "daily") @task() def train_and_save_hourly_forecast_model(train_data): - return ForecastUtils.train_and_save_hourly_forecast_model(train_data) + return ForecastUtils.train_and_save_forecast_models( + train_data, frequency="hourly" + ) + # Daily forecast tasks @task() def fetch_training_data_for_daily_forecast_model(): from dateutil.relativedelta import relativedelta @@ -53,24 +69,42 @@ def fetch_training_data_for_daily_forecast_model(): @task() def preprocess_training_data_for_daily_forecast_model(data): - return ForecastUtils.preprocess_training_data(data, "daily") + return ForecastUtils.preprocess_data(data, "daily") + + @task() + def get_daily_lag_and_rolling_features(data): + return ForecastUtils.get_lag_and_roll_features(data, "pm2_5", "daily") + + @task() + def get_daily_time_and_cylic_features(data): + return ForecastUtils.get_time_and_cyclic_features(data, "daily") + + @task() + def get_location_features(data): + return ForecastUtils.get_location_features(data) @task() - def feature_engineer_data_for_daily_forecast_model(data): - return ForecastUtils.feature_eng_training_data(data, "pm2_5", "daily") + def encode_categorical_features(data): + return DecodingUtils.encode_categorical_training_features(data, "daily") @task() def train_and_save_daily_model(train_data): - return ForecastUtils.train_and_save_daily_forecast_model(train_data) + return ForecastUtils.train_and_save_forecast_models(train_data, "daily") hourly_data = fetch_training_data_for_hourly_forecast_model() hourly_data = preprocess_training_data_for_hourly_forecast_model(hourly_data) - hourly_data = feature_engineer_training_data_for_hourly_forecast_model(hourly_data) + hourly_data = get_hourly_lag_and_rolling_features(hourly_data) + hourly_data = get_hourly_time_and_cyclic_features(hourly_data) + hourly_data = get_location_features(hourly_data) + hourly_data = encode_categorical_features(hourly_data) train_and_save_hourly_forecast_model(hourly_data) daily_data = fetch_training_data_for_daily_forecast_model() daily_data = 
preprocess_training_data_for_daily_forecast_model(daily_data) - daily_data = feature_engineer_data_for_daily_forecast_model(daily_data) + daily_data = get_daily_lag_and_rolling_features(daily_data) + daily_data = get_daily_time_and_cylic_features(daily_data) + daily_data = get_location_features(daily_data) + daily_data = encode_categorical_features(daily_data) train_and_save_daily_model(daily_data) diff --git a/src/airflow/dev-requirements.txt b/src/airflow/dev-requirements.txt index 81c23b0562..d37188bf8d 100644 --- a/src/airflow/dev-requirements.txt +++ b/src/airflow/dev-requirements.txt @@ -3,6 +3,7 @@ apache-airflow-providers-slack confluent-avro google-cloud-bigquery google-cloud-storage +optuna pyarrow sentry-sdk pandas @@ -17,5 +18,4 @@ db_dtypes mlflow lightgbm gcsfs -pymongo pytest \ No newline at end of file diff --git a/src/airflow/requirements.txt b/src/airflow/requirements.txt index c79865c3cf..f5df386d68 100644 --- a/src/airflow/requirements.txt +++ b/src/airflow/requirements.txt @@ -7,7 +7,7 @@ kafka-python simplejson~=3.19.1 sentry-sdk geopy -mlflow~=2.5.0 +mlflow lightgbm~=4.0.0 setuptools~=68.0.0 urllib3~=1.26.16 @@ -16,6 +16,6 @@ joblib~=1.3.1 scikit-learn~=1.3.0 gcsfs pymongo~=4.4.1 - +optuna pytest~=7.4.0 scipy~=1.11.1 \ No newline at end of file diff --git a/src/data-mgt/node/utils/log.js b/src/data-mgt/node/utils/log.js index 8814967bcd..9a78b39f6e 100644 --- a/src/data-mgt/node/utils/log.js +++ b/src/data-mgt/node/utils/log.js @@ -29,7 +29,7 @@ const logError = (error) => { // console.error(e); if (process.env.NODE_ENV !== "production") { console.log("an unhandled promise rejection" + ": "); - console.error(e); + console.error(error); } return "log deactivated in prod and stage"; }; diff --git a/src/data-mgt/node/utils/test/ut_date.js b/src/data-mgt/node/utils/test/ut_date.js index e75c7e3996..42b457b961 100644 --- a/src/data-mgt/node/utils/test/ut_date.js +++ b/src/data-mgt/node/utils/test/ut_date.js @@ -1 +1,96 @@ -require("module-alias/register"); +const { expect } = require('chai'); +const DateUtil = require('../date'); + +describe('Date Util', () => { + describe('generateDateFormat', () => { + it('should return a formatted date string with hours', async () => { + const ISODate = '2023-09-21T12:34:56Z'; + const result = await DateUtil.generateDateFormat(ISODate); + expect(result).to.equal('2023-09-21-12'); + }); + }); + + describe('isTimeEmpty', () => { + it('should return false for a valid time', () => { + const dateTime = '2023-09-21T12:34:56Z'; + const result = DateUtil.isTimeEmpty(dateTime); + expect(result).to.be.false; + }); + + it('should return true for an empty time', () => { + const dateTime = '2023-09-21'; + const result = DateUtil.isTimeEmpty(dateTime); + expect(result).to.be.true; + }); + }); + + describe('generateDateFormatWithoutHrs', () => { + it('should return a formatted date string without hours', () => { + const ISODate = '2023-09-21T12:34:56Z'; + const result = DateUtil.generateDateFormatWithoutHrs(ISODate); + expect(result).to.equal('2023-09-21'); + }); + }); + + describe('isDate', () => { + it('should return true for date strings with "-" or "/"', () => { + expect(DateUtil.isDate('2023-09-21')).to.be.true; + expect(DateUtil.isDate('09/21/2023')).to.be.true; + }); + + it('should return false for non-date strings', () => { + expect(DateUtil.isDate('2023')).to.be.false; + expect(DateUtil.isDate('Hello, World!')).to.be.false; + }); + }); + + describe('addMonthsToProvideDateTime', () => { + it('should add months to a provided date/time', () => 
{ + const dateTime = '2023-09-21T12:34:56Z'; + const number = 3; + const result = DateUtil.addMonthsToProvideDateTime(dateTime, number); + expect(result).to.be.a('Date'); + }); + + it('should handle empty time and add months to date', () => { + const date = '2023-09-21'; + const number = 3; + const result = DateUtil.addMonthsToProvideDateTime(date, number); + expect(result).to.be.a('Date'); + }); + }); + + describe('monthsInfront', () => { + it('should return a date in the future with the given number of months', () => { + const number = 3; + const result = DateUtil.monthsInfront(number); + expect(result).to.be.a('Date'); + }); + }); + + describe('addDays', () => { + it('should add days to the current date', () => { + const number = 7; + const result = DateUtil.addDays(number); + expect(result).to.be.a('Date'); + }); + }); + + describe('getDifferenceInMonths', () => { + it('should calculate the difference in months between two dates', () => { + const date1 = '2023-09-21'; + const date2 = '2024-01-15'; + const result = DateUtil.getDifferenceInMonths(date1, date2); + expect(result).to.equal(4); + }); + }); + + describe('threeMonthsFromNow', () => { + it('should return a date three months from the provided date', () => { + const date = '2023-09-21'; + const result = DateUtil.threeMonthsFromNow(date); + expect(result).to.be.a('Date'); + }); + }); + +}); diff --git a/src/data-mgt/node/utils/test/ut_errors.js b/src/data-mgt/node/utils/test/ut_errors.js index e75c7e3996..6c13a790d5 100644 --- a/src/data-mgt/node/utils/test/ut_errors.js +++ b/src/data-mgt/node/utils/test/ut_errors.js @@ -1 +1,83 @@ -require("module-alias/register"); +const chai = require("chai"); +const { expect } = chai; +const sinon = require("sinon"); +const HTTPStatus = require("http-status"); +const errors = require("../errors"); + +describe("Errors Utility Functions", () => { + describe("convertErrorArrayToObject", () => { + it("should convert an array of errors to an object", () => { + const errorArray = [ + { param: "field1", msg: "Field 1 is required" }, + { param: "field2", msg: "Field 2 must be a number" }, + ]; + + const result = errors.convertErrorArrayToObject(errorArray); + + expect(result).to.deep.equal({ + field1: "Field 1 is required", + field2: "Field 2 must be a number", + }); + }); + }); + + describe("errorResponse", () => { + it("should send an error response with default status code", () => { + const res = { + status: sinon.stub().returnsThis(), + json: sinon.spy(), + }; + + errors.errorResponse({ res, message: "An error occurred" }); + + expect(res.status.calledWith(HTTPStatus.INTERNAL_SERVER_ERROR)).to.be.true; + expect(res.json.calledWithMatch({ + success: false, + message: "An error occurred", + error: { + statusCode: HTTPStatus.INTERNAL_SERVER_ERROR, + message: "An error occurred", + error: {}, + }, + })).to.be.true; + }); + + it("should send an error response with a custom status code", () => { + const res = { + status: sinon.stub().returnsThis(), + json: sinon.spy(), + }; + + errors.errorResponse({ res, message: "Bad request", statusCode: HTTPStatus.BAD_REQUEST }); + + expect(res.status.calledWith(HTTPStatus.BAD_REQUEST)).to.be.true; + expect(res.json.calledWithMatch({ + success: false, + message: "Bad request", + error: { + statusCode: HTTPStatus.BAD_REQUEST, + message: "Bad request", + error: {}, + }, + })).to.be.true; + }); + }); + + describe("badRequest", () => { + it("should send a bad request response", () => { + const res = { + status: sinon.stub().returnsThis(), + json: sinon.spy(), + }; + + 
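+      // `status` is stubbed with returnsThis() so the Express-style chain
+      // res.status(...).json(...) presumably used inside the utility keeps working,
+      // while the `json` spy records the payload asserted on below.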
errors.badRequest(res, "Bad request", { field: "Invalid input" }); + + expect(res.status.calledWith(HTTPStatus.BAD_REQUEST)).to.be.true; + expect(res.json.calledWithMatch({ + success: false, + message: "Bad request", + errors: { field: "Invalid input" }, + })).to.be.true; + }); + }); +}); diff --git a/src/data-mgt/node/utils/test/ut_log.js b/src/data-mgt/node/utils/test/ut_log.js index 74b1a1f804..e21e7d11a8 100644 --- a/src/data-mgt/node/utils/test/ut_log.js +++ b/src/data-mgt/node/utils/test/ut_log.js @@ -1,2 +1,66 @@ -require("module-alias/register"); -s; +const chai = require("chai"); +const { expect } = chai; +const sinon = require("sinon"); +const { logText, logElement, logObject, logError } = require("../log"); + +describe("Logging Utility Functions", () => { + describe("logText", () => { + it("should log a message when not in production", () => { + const consoleLogStub = sinon.stub(console, "log"); + process.env.NODE_ENV = "development"; + const result = logText("Test log message"); + expect(consoleLogStub.calledOnce).to.be.true; + expect(consoleLogStub.calledWith("Test log message")).to.be.true; + consoleLogStub.restore(); + process.env.NODE_ENV = "test"; + }); + + it("should return a log deactivation message in production", () => { + const consoleLogStub = sinon.stub(console, "log"); + process.env.NODE_ENV = "production"; + const result = logText("Test log message"); + expect(consoleLogStub.notCalled).to.be.true; + expect(result).to.equal("log deactivated in prod and stage"); + consoleLogStub.restore(); + process.env.NODE_ENV = "test"; + }); + }); + + describe("logElement", () => { + it("should log an element when not in production", () => { + const consoleLogStub = sinon.stub(console, "log"); + process.env.NODE_ENV = "development"; + const result = logElement("Test", "Element"); + expect(consoleLogStub.calledOnce).to.be.true; + expect(consoleLogStub.calledWith("Test: Element")).to.be.true; + consoleLogStub.restore(); + process.env.NODE_ENV = "test"; + }); + }); + + describe("logObject", () => { + it("should log an object when not in production", () => { + const consoleLogStub = sinon.stub(console, "log"); + process.env.NODE_ENV = "development"; + const result = logObject("Test", { key: "value" }); + expect(consoleLogStub.calledOnce).to.be.true; + expect(consoleLogStub.calledWith("Test: ")).to.be.true; + consoleLogStub.restore(); + process.env.NODE_ENV = "test"; + }); + }); + + describe("logError", () => { + it("should log an error when not in production", () => { + const consoleErrorStub = sinon.stub(console, "error"); + process.env.NODE_ENV = "development"; + const error = new Error("Test error message"); + const result = logError(error); + expect(consoleErrorStub.calledOnce).to.be.true; + expect(consoleErrorStub.calledWith(error)).to.be.true; + consoleErrorStub.restore(); + process.env.NODE_ENV = "test"; + }); + + }); +}); diff --git a/src/device-registry/routes/v2/kya.js b/src/device-registry/routes/v2/kya.js index b1158a28eb..5cfff363cc 100644 --- a/src/device-registry/routes/v2/kya.js +++ b/src/device-registry/routes/v2/kya.js @@ -79,6 +79,16 @@ router.get( }), ], ]), + oneOf([ + [ + query("language") + .optional() + .notEmpty() + .withMessage("the language cannot be empty when provided") + .bail() + .trim() + ], + ]), knowYourAirController.listLessons ); @@ -97,6 +107,16 @@ router.get( .withMessage("the tenant value is not among the expected ones"), ], ]), + oneOf([ + [ + query("language") + .optional() + .notEmpty() + .withMessage("the language cannot be empty when 
provided") + .bail() + .trim() + ], + ]), oneOf([ [ @@ -1183,6 +1203,16 @@ router.get( }), ], ]), + oneOf([ + [ + query("language") + .optional() + .notEmpty() + .withMessage("the language cannot be empty when provided") + .bail() + .trim() + ], + ]), knowYourAirController.listQuizzes ); @@ -1201,6 +1231,16 @@ router.get( .withMessage("the tenant value is not among the expected ones"), ], ]), + oneOf([ + [ + query("language") + .optional() + .notEmpty() + .withMessage("the language cannot be empty when provided") + .bail() + .trim() + ], + ]), oneOf([ [ diff --git a/src/device-registry/utils/create-event.js b/src/device-registry/utils/create-event.js index 48d8144797..26b76022e1 100644 --- a/src/device-registry/utils/create-event.js +++ b/src/device-registry/utils/create-event.js @@ -595,7 +595,7 @@ const createEvent = { if (language !== undefined && constants.ENVIRONMENT === "STAGING ENVIRONMENT") { let data = responseFromListEvents.data[0].data; for (const event of data) { - let translatedHealthTips = await translateUtil.translate(event.health_tips, language); + let translatedHealthTips = await translateUtil.translateTips(event.health_tips, language); if (translatedHealthTips.success === true) { event.health_tips = translatedHealthTips.data; } diff --git a/src/device-registry/utils/create-health-tips.js b/src/device-registry/utils/create-health-tips.js index 87a198b98b..fdac17a412 100644 --- a/src/device-registry/utils/create-health-tips.js +++ b/src/device-registry/utils/create-health-tips.js @@ -34,7 +34,7 @@ const createHealthTips = { skip, }); if (language !== undefined) { - translatedHealthTips = await translateUtil.translate(responseFromListHealthTips.data, language); + translatedHealthTips = await translateUtil.translateTips(responseFromListHealthTips.data, language); responseFromListHealthTips = translatedHealthTips; } diff --git a/src/device-registry/utils/create-know-your-air.js b/src/device-registry/utils/create-know-your-air.js index 10a02af088..5d8a821862 100644 --- a/src/device-registry/utils/create-know-your-air.js +++ b/src/device-registry/utils/create-know-your-air.js @@ -13,6 +13,7 @@ const { logObject, logElement, logText } = require("./log"); const generateFilter = require("./generate-filter"); const log4js = require("log4js"); const logger = log4js.getLogger(`${constants.ENVIRONMENT} -- create-kya-util`); +const translateUtil = require("./translate"); const mongoose = require("mongoose").set("debug", true); const ObjectId = mongoose.Types.ObjectId; @@ -44,6 +45,7 @@ const createKnowYourAir = { const { user_id } = request.params; const limit = parseInt(request.query.limit, 0); const skip = parseInt(request.query.skip, 0); + const language = request.query.language; const filter = generateFilter.kyalessons(request); if (filter.success && filter.success === false) { return filter; @@ -57,6 +59,12 @@ const createKnowYourAir = { user_id: user_id, } ); + if (language !== undefined) { + const translatedLessons = await translateUtil.translateLessons(responseFromListLessons.data, language); + if (translatedLessons.success === true) { + return translatedLessons; + } + } logObject("responseFromListLessons", responseFromListLessons); return responseFromListLessons; } catch (error) { @@ -938,6 +946,7 @@ const createKnowYourAir = { const { user_id } = request.params; const limit = parseInt(request.query.limit, 0); const skip = parseInt(request.query.skip, 0); + const language = request.query.language; const filter = generateFilter.kyaquizzes(request); if (filter.success && 
filter.success === false) { return filter; @@ -949,6 +958,12 @@ const createKnowYourAir = { skip, user_id: user_id, }); + if (language !== undefined) { + const translatedQuizzes = await translateUtil.translateQuizzes(responseFromListQuizzes.data, language); + if (translatedQuizzes.success === true) { + return translatedQuizzes; + } + } logObject("responseFromListQuizzes", responseFromListQuizzes); return responseFromListQuizzes; } catch (error) { diff --git a/src/device-registry/utils/test/ut_create-know-your-air.js b/src/device-registry/utils/test/ut_create-know-your-air.js index fcb5499776..2c934b1d44 100644 --- a/src/device-registry/utils/test/ut_create-know-your-air.js +++ b/src/device-registry/utils/test/ut_create-know-your-air.js @@ -42,6 +42,28 @@ describe("createKnowYourAir Utility Functions", () => { listStub.restore(); }); + it("should return a list of translated lessons successfully", async () => { + const request = { + query: { tenant: "your-tenant" }, + params: { user_id: "user-id" }, + query: { limit: 10, skip: 0, language: "fr" }, + }; + + // Stub KnowYourAirLessonModel.list + const listStub = sinon + .stub(KnowYourAirLessonModel("your-tenant"), "list") + .resolves({ success: true, data: [], status: httpStatus.OK }); + + const result = await createKnowYourAir.listLesson(request); + + expect(result.success).to.be.true; + expect(result.data).to.deep.equal([]); + expect(result.status).to.equal(httpStatus.OK); + + // Restore the stub + listStub.restore(); + }); + it("should handle filter failure", async () => { const request = { query: { tenant: "your-tenant" }, @@ -1639,6 +1661,27 @@ describe("createKnowYourAir Utility Functions", () => { KnowYourAirQuizModel("your-tenant").list.restore(); }); + it("should list translated quizzes", async () => { + const request = { + query: { tenant: "your-tenant" }, + params: { user_id: "user-id" }, + query: { limit: 10, skip: 0, language: "fr" }, + }; + + // Stub KnowYourAirQuizModel(tenant).list to return quiz data + const quizListStub = sinon + .stub(KnowYourAirQuizModel("your-tenant"), "list") + .resolves({ success: true /* other response properties */ }); + + const result = await createKnowYourAir.listQuiz(request); + + expect(result.success).to.be.true; + // Your other assertions here + + // Restore the stub + KnowYourAirQuizModel("your-tenant").list.restore(); + }); + it("should handle filter failure", async () => { const request = { query: { tenant: "your-tenant" }, diff --git a/src/device-registry/utils/test/ut_translate.js b/src/device-registry/utils/test/ut_translate.js index 665dff3b4f..3c990a4125 100644 --- a/src/device-registry/utils/test/ut_translate.js +++ b/src/device-registry/utils/test/ut_translate.js @@ -9,45 +9,195 @@ const httpStatus = require("http-status"); const translateUtil = require("@utils/translate"); describe('translateUtil', () => { - it('should translate health tips to the target language', async () => { - const healthTips = [ - { - title: 'Hello', - description: 'World', - }, - { - title: 'Good', - description: 'Morning', - }, - ]; - const targetLanguage = 'fr'; - - const expectedTranslations = [ - { - title: 'Bonjour', - description: 'Monde', - }, - { - title: 'Bien', - description: 'Matin', - }, - ]; - - const result = await translateUtil.translate(healthTips, targetLanguage); - - - expect(result).to.have.property('success', true); - for (let i = 0; i < result.data.length; i++) { - expect(result.data[i].title).to.equal(expectedTranslations[i].title); - 
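+      // These assertions run against the live @google-cloud/translate client (nothing is
+      // stubbed in this spec), which is presumably why the case carries a 10s timeout;
+      // wording returned by the API can drift over time, so the exact-equality checks
+      // may need relaxing, or the client stubbing, for fully deterministic runs.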
expect(result.data[i].description).to.equal(expectedTranslations[i].description); - } - }).timeout(10000); - - it('should handle translation errors gracefully', async () => { - - const healthTips = null; - const targetLanguage = 'fr'; - const result = await translateUtil.translate(healthTips, targetLanguage); + describe("translateTips", () => { + it('should translate health tips to the target language', async () => { + const healthTips = [ + { + title: 'Hello', + description: 'World', + }, + { + title: 'Good', + description: 'Morning', + }, + ]; + const targetLanguage = 'fr'; + + const expectedTranslations = [ + { + title: 'Bonjour', + description: 'Monde', + }, + { + title: 'Bien', + description: 'Matin', + }, + ]; + + const result = await translateUtil.translateTips(healthTips, targetLanguage); + + + expect(result).to.have.property('success', true); + for (let i = 0; i < result.data.length; i++) { + expect(result.data[i].title).to.equal(expectedTranslations[i].title); + expect(result.data[i].description).to.equal(expectedTranslations[i].description); + } + }).timeout(10000); + + it('should handle translation errors gracefully', async () => { + + const healthTips = null; + const targetLanguage = 'fr'; + const result = await translateUtil.translateTips(healthTips, targetLanguage); + + expect(result).to.have.property('success', false); + expect(result).to.have.property('message', 'Internal Server Error'); + expect(result).to.have.property('status', 500); + expect(result).to.have.property('errors'); + expect(result.errors).to.have.property('message'); + }); + }) + + describe("translateLessons", () => { + it('should translate Kya lessons to the target language', async () => { + const kyaLessons = [ + { + "_id": "testId", + "title": "Actions you can take to reduce air pollution", + "completion_message": "You just finished your first Know Your Air Lesson", + "image": "https://testimage", + "tasks": [ + { + "_id": "testId", + "title": "Use public transport", + "content": "Vehicle exhaust is a major source of air pollution. Less cars on the road results in less emissions.", + "image": "https://testimage", + "task_position": 2 + }, + ] + } + ]; + const targetLanguage = 'fr'; + + const expectedTranslations = [ + { + "_id": "testId", + "title": "Mesures que vous pouvez prendre pour réduire la pollution de l’air", + "completion_message": "Vous venez de terminer votre première leçon Know Your Air.", + "image": "https://testimage", + "tasks": [ + { + "_id": "testId", + "title": "Utilisez les transports en commun", + "content": "Les gaz d’échappement des véhicules constituent une source majeure de pollution atmosphérique. 
Moins de voitures sur la route entraîne moins d’émissions.", + "image": "https://testimage", + "task_position": 2 + }, + ] + } + ]; + + const result = await translateUtil.translateLessons(kyaLessons, targetLanguage); + + + expect(result).to.have.property('success', true); + for (let i = 0; i < result.data.length; i++) { + expect(result.data[i].title).to.equal(expectedTranslations[i].title); + expect(result.data[i].completion_message).to.equal(expectedTranslations[i].completion_message); + expect(result.data[i].tasks).to.deep.equal(expectedTranslations[i].tasks); + } + }).timeout(10000); + + it('should handle translation errors gracefully', async () => { + + const lessons = null; + const targetLanguage = 'fr'; + const result = await translateUtil.translateLessons(lessons, targetLanguage); + + expect(result).to.have.property('success', false); + expect(result).to.have.property('message', 'Internal Server Error'); + expect(result).to.have.property('status', 500); + expect(result).to.have.property('errors'); + expect(result.errors).to.have.property('message'); + }); + }); + describe("translateQuizzes", () => { + it('should translate Kya Quizzes to the target language', async () => { + const kyaQuizzes = [ + { + "_id": "testId", + "title": "Get personalised air quality recommendations", + "description": "Tell us more about Air Quality conditions in your environment & get personalised tips.", + "completion_message": "Way to go🎊. You have unlocked personalised air quality recommendations to empower you on your clean air journey.", + "image": "https//testImage", + "questions": [ + { + "title": "Where is your home environment situated?", + "context": "Home environment", + "question_position": 1, + "answers": [ + { + "content": [ + "Cooking with firewood can emit significant amounts of air pollutants.", + "Cook in a well-ventilated kitchen with good airflow or set up an outdoor kitchen if possible.", + "Use an efficient stove designed to burn firewood more cleanly and with less smoke.", + "Consider switching to improved cookstoves that reduce emissions and increase fuel efficiency." + ], + "title": "Firewood", + } + ] + }, + ], + }, + ]; + + const targetLanguage = 'fr'; + + const expectedTranslations = [ + { + "_id": "testId", + "title": "Obtenez des recommandations personnalisées sur la qualité de l'air", + "description": "Dites-nous en plus sur les conditions de qualité de l'air dans votre environnement et obtenez des conseils personnalisés.", + "completion_message": "Bravo🎊. Vous avez débloqué des recommandations personnalisées sur la qualité de l'air pour vous aider dans votre voyage vers un air pur.", + "image": "https//testImage", + "questions": [ + { + "title": "Où se situe votre environnement domestique ?", + "context": "Environnement de la maison", + "question_position": 1, + "answers": [ + { + "content": [ + "Cuisiner avec du bois de chauffage peut émettre des quantités importantes de polluants atmosphériques.", + "Cuisinez dans une cuisine bien ventilée avec une bonne circulation d’air ou installez une cuisine extérieure si possible.", + "Utilisez un poêle efficace conçu pour brûler du bois de chauffage plus proprement et avec moins de fumée.", + "Envisagez de passer à des cuisinières améliorées qui réduisent les émissions et augmentent le rendement énergétique." 
+ ], + "title": "Bois de chauffage", + } + ] + }, + ], + }, + ]; + + const result = await translateUtil.translateQuizzes(kyaQuizzes, targetLanguage); + + + expect(result).to.have.property('success', true); + for (let i = 0; i < result.data.length; i++) { + expect(result.data[i].title).to.equal(expectedTranslations[i].title); + expect(result.data[i].completion_message).to.equal(expectedTranslations[i].completion_message); + expect(result.data[i].questions).to.deep.equal(expectedTranslations[i].questions); + expect(result.data[i].questions.answers).to.deep.equal(expectedTranslations[i].questions.answers); + } + }).timeout(10000); + + it('should handle translation errors gracefully', async () => { + + const kyaQuizzes = null; + const targetLanguage = 'fr'; + const result = await translateUtil.translateQuizzes(kyaQuizzes, targetLanguage); expect(result).to.have.property('success', false); expect(result).to.have.property('message', 'Internal Server Error'); @@ -55,4 +205,5 @@ describe('translateUtil', () => { expect(result).to.have.property('errors'); expect(result.errors).to.have.property('message'); }); +}); }); \ No newline at end of file diff --git a/src/device-registry/utils/translate.js b/src/device-registry/utils/translate.js index 9f885e78c5..eba61b275e 100644 --- a/src/device-registry/utils/translate.js +++ b/src/device-registry/utils/translate.js @@ -9,7 +9,7 @@ const { Translate } = require('@google-cloud/translate').v2; const translate = new Translate(); const translateUtil = { - translate: async (healthTips, targetLanguage) => { + translateTips: async (healthTips, targetLanguage) => { try { const translatedHealthTips = []; @@ -39,6 +39,101 @@ const translateUtil = { }; } }, + + translateLessons: async (lessons, targetLanguage) => { + try { + const translatedLessons = []; + + for (const lesson of lessons) { + const translatedLesson = { ...lesson }; + translatedLesson.title = await translateText(lesson.title, targetLanguage); + translatedLesson.completion_message = await translateText(lesson.completion_message, targetLanguage); + const translatedTasks = []; + for (const task of lesson.tasks) { + const translatedTask = { ...task }; + translatedTask.title = await translateText(task.title, targetLanguage); + translatedTask.content = await translateText(task.content, targetLanguage); + translatedTasks.push(translatedTask); + } + translatedLesson.tasks = translatedTasks + translatedLessons.push(translatedLesson); + } + + return { + success: true, + message: "Translated KYA returned Successfully", + data: translatedLessons, + status: httpStatus.OK, + }; + } catch (error) { + logger.error(`internal server error -- ${error.message}`); + console.log(`internal server error -- ${error.message}`); + + return { + success: false, + message: "Internal Server Error", + status: httpStatus.INTERNAL_SERVER_ERROR, + errors: { + message: error.message, + }, + }; + } + }, + + translateQuizzes: async (quizzes, targetLanguage) => { + try { + const translatedQuizzes = []; + + for (const quiz of quizzes) { + const translatedQuiz = { ...quiz }; + translatedQuiz.title = await translateText(quiz.title, targetLanguage); + translatedQuiz.description = await translateText(quiz.description, targetLanguage); + translatedQuiz.completion_message = await translateText(quiz.completion_message, targetLanguage); + const translatedQuestions = []; + for (const question of quiz.questions) { + const translatedQuestion = { ...question }; + translatedQuestion.title = await translateText(question.title, targetLanguage); + 
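+          // Each nested field goes through the same shared translateText helper declared at
+          // the bottom of this module. Its body is not part of this diff; a minimal sketch,
+          // assuming the @google-cloud/translate v2 client instantiated above, would be:
+          //   async function translateText(text, target) {
+          //     const [translation] = await translate.translate(text, target);
+          //     return translation;
+          //   }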
translatedQuestion.context = await translateText(question.context, targetLanguage); + const translatedAnswers = []; + for (const answer of question.answers) { + const translatedAnswer = { ...answer }; + translatedAnswer.title = await translateText(answer.title, targetLanguage); + const translatedContent = []; + for (const contentItem of answer.content) { + const translatedItem = await translateText(contentItem, targetLanguage); + translatedContent.push(translatedItem); + } + translatedAnswer.content = translatedContent; + + translatedAnswers.push(translatedAnswer); + } + translatedQuestion.answers = translatedAnswers; + translatedQuestions.push(translatedQuestion); + } + translatedQuiz.questions = translatedQuestions + translatedQuizzes.push(translatedQuiz); + } + + return { + success: true, + message: "Translated KYA returned Successfully", + data: translatedQuizzes, + status: httpStatus.OK, + }; + } catch (error) { + logger.error(`internal server error -- ${error.message}`); + return { + success: false, + message: "Internal Server Error", + status: httpStatus.INTERNAL_SERVER_ERROR, + errors: { + message: error.message, + }, + }; + } + }, + + }; async function translateText(text, target) { diff --git a/src/incentives/utils/test/ut_create-transaction.js b/src/incentives/utils/test/ut_create-transaction.js index ee9c90f403..25bd1e4a6b 100644 --- a/src/incentives/utils/test/ut_create-transaction.js +++ b/src/incentives/utils/test/ut_create-transaction.js @@ -4,7 +4,7 @@ const chai = require("chai"); const { expect } = chai; const httpStatus = require("http-status"); -const TransactionModel = require("@models/Transaction"); +const TransactionModel = require("@models/transaction"); const createTransaction = require("@utils/create-transaction"); const axios = require("axios"); @@ -509,7 +509,7 @@ describe("createTransaction", () => { }); // Execute the function - const response = await getTransactionDetails(request); + const response = await createTransaction.getTransactionDetails(request); // Assert the response expect(response).to.deep.equal(expectedResponse); @@ -537,7 +537,7 @@ describe("createTransaction", () => { .rejects(new Error("Network Error")); // Execute the function - const response = await getTransactionDetails(request); + const response = await createTransaction.getTransactionDetails(request); // Assert the response expect(response).to.deep.equal({ @@ -601,7 +601,7 @@ describe("createTransaction", () => { }); // Execute the function - const response = await loadDataBundle(request); + const response = await createTransaction.loadDataBundle(request); // Assert the response expect(response).to.deep.equal(expectedResponse); @@ -660,7 +660,7 @@ describe("createTransaction", () => { .rejects(new Error("Network Error")); // Execute the function - const response = await loadDataBundle(request); + const response = await createTransaction.loadDataBundle(request); // Assert the response expect(response).to.deep.equal({ @@ -721,7 +721,7 @@ describe("createTransaction", () => { }; // Execute the function - const response = await checkRemainingDataBundleBalance(request); + const response = await createTransaction.checkRemainingDataBundleBalance(request); // Assert the response expect(response).to.deep.equal(expectedResponse); @@ -745,7 +745,7 @@ describe("createTransaction", () => { const throwStub = chai.spy.on(errorStub, "throw"); // Execute the function - const response = await checkRemainingDataBundleBalance(request); + const response = await 
createTransaction.checkRemainingDataBundleBalance(request); // Assert the response expect(response).to.deep.equal({ diff --git a/src/incentives/utils/test/ut_generate-filter.js b/src/incentives/utils/test/ut_generate-filter.js index 3e0da9e6d5..871f37bb43 100644 --- a/src/incentives/utils/test/ut_generate-filter.js +++ b/src/incentives/utils/test/ut_generate-filter.js @@ -5,6 +5,7 @@ const { expect } = chai; const generateFilter = require("@utils/generate-filter"); const mongoose = require("mongoose"); const ObjectId = mongoose.Types.ObjectId; +const httpStatus = require("http-status"); describe("generateFilter", () => { describe("hosts", () => { diff --git a/src/predict/api/helpers.py b/src/predict/api/helpers.py index e8bce1ef23..342fff780b 100644 --- a/src/predict/api/helpers.py +++ b/src/predict/api/helpers.py @@ -85,8 +85,8 @@ def geo_coordinates_cache_key(): def get_health_tips() -> list[dict]: try: response = requests.get( - f"{Config.AIRQO_BASE_URL}/api/v2/devices/tips?token={Config.AIRQO_API_AUTH_TOKEN}", - timeout=3, + f"{Config.AIRQO_BASE_URL}api/v2/devices/tips?token={Config.AIRQO_API_AUTH_TOKEN}", + timeout=10, ) if response.status_code == 200: result = response.json() @@ -200,6 +200,7 @@ def get_predictions_by_geo_coordinates_v2(latitude: float, longitude: float) -> @cache.memoize(timeout=Config.CACHE_TIMEOUT) def get_forecasts( db_name, + device_id=None, site_id=None, site_name=None, parish=None, @@ -210,6 +211,7 @@ def get_forecasts( ): query = {} params = { + "device_id": device_id, "site_id": site_id, "site_name": site_name, "parish": parish, @@ -227,12 +229,18 @@ def get_forecasts( results = [] if site_forecasts: - for time, pm2_5 in zip( - site_forecasts[0]["time"], + for time, pm2_5, margin_of_error, adjusted_forecast in zip( + site_forecasts[0]["timestamp"], site_forecasts[0]["pm2_5"], + site_forecasts[0]["margin_of_error"], + site_forecasts[0]["adjusted_forecast"], ): result = { - key: value for key, value in zip(["time", "pm2_5"], [time, pm2_5]) + key: value + for key, value in zip( + ["time", "pm2_5", "margin_of_error", "adjusted_forecast"], + [time, pm2_5, margin_of_error, adjusted_forecast], + ) } results.append(result) diff --git a/src/predict/api/prediction.py b/src/predict/api/prediction.py index c28673180f..588d477b2e 100644 --- a/src/predict/api/prediction.py +++ b/src/predict/api/prediction.py @@ -78,10 +78,6 @@ def get_next_24hr_forecasts(): """ Get forecasts for the next 24 hours from specified start time. """ - - """ - Get forecasts for the next 1 week from specified start day. - """ params = { name: request.args.get(name, default=None, type=str) for name in [ @@ -129,6 +125,7 @@ def get_next_1_week_forecasts(): params = { name: request.args.get(name, default=None, type=str) for name in [ + "device_id", "site_id", "site_name", "parish",