Skip to content

Commit 1d6e7b8

Browse files
committed
SAITS mejora a Transformer y Baseline en Periodo 1 y Periodo 2
1 parent c8195a3 commit 1d6e7b8

13 files changed

+9844
-701
lines changed

memoria/main.pdf

-543 KB
Binary file not shown.

notebooks/.ipynb_checkpoints/Notebook_Imputation-checkpoint.ipynb

Lines changed: 653 additions & 225 deletions
Large diffs are not rendered by default.

notebooks/.ipynb_checkpoints/Pipeline3-checkpoint.ipynb

Lines changed: 8249 additions & 0 deletions
Large diffs are not rendered by default.

notebooks/Notebook_Imputation.ipynb

Lines changed: 321 additions & 416 deletions
Large diffs are not rendered by default.

pampaneira_imputation/.ipynb_checkpoints/config-checkpoint.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@
6767
TIMEZONE = "UTC"
6868

6969
# Fechas del Periodo 1 para los datos de camiones
70-
PERIOD_1_START = pd.to_datetime("2023-01-17 17:00:00+00:00", utc=True)
71-
PERIOD_1_END = pd.to_datetime("2023-03-14 11:00:00+00:00", utc=True)
70+
PERIOD_1_START = pd.to_datetime("2023-01-17 00:00:00+00:00", utc=True)
71+
PERIOD_1_END = pd.to_datetime("2023-03-14 23:00:00+00:00", utc=True)
7272
PERIOD_1_PADDING_START = pd.to_datetime("2023-01-17 00:00:00", utc=True)
7373
PERIOD_1_PADDING_END = pd.to_datetime("2023-03-14 23:00:00", utc=True)
7474

@@ -106,19 +106,19 @@
106106
SAITS_PARAMS = {
107107
"n_steps": N_STEPS,
108108
# n_features se establecerá dinámicamente
109-
"n_layers": 3,
110-
"d_model": 128,
109+
"n_layers": 4,
110+
"d_model": 192,
111111
"d_ffn": 256,
112112
"n_heads": 4,
113113
"d_k": 32,
114114
"d_v": 32,
115115
"dropout": 0.3,
116116
"attn_dropout": 0.2,
117117
"diagonal_attention_mask": True,
118-
"ORT_weight": 1,
119-
"MIT_weight": 1,
118+
"ORT_weight": 0.8,
119+
"MIT_weight": 1.2,
120120
"batch_size": 64,
121-
"epochs": 10, # Considera reducir para pruebas/depuración más rápidas
121+
"epochs": 50, # Considera reducir para pruebas/depuración más rápidas
122122
"patience": 5,
123123
"num_workers": 0,
124124
"device": None, # Autodetecta (CPU o GPU si disponible)
@@ -149,7 +149,7 @@
149149
'ORT_weight': 1.0, # Weight for ORT (Observed Reconstruction Term)
150150
'MIT_weight': 1.0, # Weight for MIT (Missing Imputation Term)
151151
'batch_size': 64, # Batch size for training
152-
'epochs': 10, # Maximum epochs for training
152+
'epochs': 50, # Maximum epochs for training
153153
'patience': 5, # Early stopping patience
154154
'num_workers': 0, # Number of workers for data loading
155155
'device': None, # Device to use (None for auto-detection)
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# pampaneira_imputation/data_loader.py
2+
import pandas as pd
3+
from typing import List
4+
from . import config
5+
6+
pd.options.mode.chained_assignment = None
7+
8+
def load_traffic_data(filepath: str = config.TRAFFIC_FILE,
9+
columns_to_use: List[str] = config.PAM_BUB_TRAFFIC_COLS,
10+
date_col: str = config.DATE_COL,
11+
timezone: str = config.TIMEZONE) -> pd.DataFrame:
12+
"""
13+
Carga datos de tráfico generales, selecciona columnas relevantes, convierte la fecha.
14+
15+
Args:
16+
filepath (str, optional): Ruta al archivo CSV de tráfico (por defecto: config.TRAFFIC_FILE).
17+
columns_to_use (List[str], optional): Lista de columnas de tráfico a usar (por defecto: config.PAM_BUB_TRAFFIC_COLS).
18+
date_col (str, optional): Nombre de la columna de fecha (por defecto: config.DATE_COL).
19+
timezone (str, optional): Zona horaria para las fechas (por defecto: config.TIMEZONE).
20+
21+
Returns:
22+
pd.DataFrame: DataFrame con datos de tráfico cargados y preprocesados.
23+
24+
Raises:
25+
FileNotFoundError: Si el archivo especificado no se encuentra.
26+
KeyError: Si una columna especificada no se encuentra en el archivo.
27+
"""
28+
try:
29+
df = pd.read_csv(filepath)
30+
df[date_col] = pd.to_datetime(df[date_col])
31+
# Asegura que la zona horaria UTC sea consciente si aún no lo es
32+
if df[date_col].dt.tz is None:
33+
df[date_col] = df[date_col].dt.tz_localize(timezone)
34+
elif df[date_col].dt.tz.zone != timezone:
35+
df[date_col] = df[date_col].dt.tz_convert(timezone)
36+
37+
# Selecciona solo las columnas requeridas más la columna de fecha
38+
df_filtered = df[[date_col] + columns_to_use]
39+
40+
# Convierte columnas enteras a float64 como en el script original
41+
int_cols = df_filtered.select_dtypes(include='int64').columns
42+
df_filtered[int_cols] = df_filtered[int_cols].astype('float64')
43+
44+
return df_filtered
45+
46+
except FileNotFoundError:
47+
print(f"Error: No se encontró el archivo en {filepath}")
48+
raise
49+
except KeyError as e:
50+
print(f"Error: No se encontró la columna {e} en {filepath}.")
51+
raise
52+
53+
54+
def load_intersection_data(filepath: str = config.INTERSECTION_FILE,
55+
date_col_original: str = "Date", # Nombre original en CSV
56+
date_col_target: str = config.DATE_COL,
57+
truck_pos_col: str = config.TRUCK_POS_COL,
58+
target_truck_pos: str = config.TARGET_TRUCK_POS,
59+
timezone: str = config.TIMEZONE) -> pd.DataFrame:
60+
"""
61+
Carga datos de intersección, filtra por posición de camión, convierte la fecha.
62+
63+
Args:
64+
filepath (str, optional): Ruta al archivo CSV de intersección (por defecto: config.INTERSECTION_FILE).
65+
date_col_original (str, optional): Nombre original de la columna de fecha en el CSV (por defecto: "Date").
66+
date_col_target (str, optional): Nombre objetivo de la columna de fecha (por defecto: config.DATE_COL).
67+
truck_pos_col (str, optional): Nombre de la columna de posición del camión (por defecto: config.TRUCK_POS_COL).
68+
target_truck_pos (str, optional): Posición objetivo del camión para filtrar (por defecto: config.TARGET_TRUCK_POS).
69+
timezone (str, optional): Zona horaria para las fechas (por defecto: config.TIMEZONE).
70+
71+
Returns:
72+
pd.DataFrame: DataFrame con datos de intersección cargados, filtrados y preprocesados.
73+
74+
Raises:
75+
FileNotFoundError: Si el archivo especificado no se encuentra.
76+
KeyError: Si una columna especificada no se encuentra o falla el cambio de nombre en el archivo.
77+
"""
78+
try:
79+
df = pd.read_csv(filepath)
80+
df.rename(columns={date_col_original: date_col_target}, inplace=True)
81+
df[date_col_target] = pd.to_datetime(df[date_col_target])
82+
83+
# Asegura que la zona horaria UTC sea consciente si aún no lo es
84+
if df[date_col_target].dt.tz is None:
85+
df[date_col_target] = df[date_col_target].dt.tz_localize(timezone)
86+
elif df[date_col_target].dt.tz.zone != timezone:
87+
df[date_col_target] = df[date_col_target].dt.tz_convert(timezone)
88+
89+
# Filtra por posición de camión
90+
df_filtered = df[df[truck_pos_col] == target_truck_pos].copy() # Usa .copy()
91+
92+
# Convierte columnas enteras a float64
93+
int_cols = df_filtered.select_dtypes(include='int64').columns
94+
df_filtered[int_cols] = df_filtered[int_cols].astype('float64')
95+
96+
# Selecciona solo las columnas de características finales + fecha (definidas en config)
97+
# Esto asume que el archivo de intersección contiene todas las FEATURE_COLUMNS
98+
# Si no, ajusta config.FEATURE_COLUMNS o esta lógica de selección
99+
cols_to_keep = [date_col_target] + config.FEATURE_COLUMNS
100+
# Asegura que solo mantenemos las columnas presentes en el dataframe
101+
cols_present = [col for col in cols_to_keep if col in df_filtered.columns]
102+
df_final = df_filtered[cols_present]
103+
104+
105+
return df_final
106+
107+
except FileNotFoundError:
108+
print(f"Error: No se encontró el archivo en {filepath}")
109+
raise
110+
except KeyError as e:
111+
print(f"Error: No se encontró la columna {e} o falló el renombrado en {filepath}.")
112+
raise

pampaneira_imputation/.ipynb_checkpoints/data_preprocessor-checkpoint.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,39 +16,22 @@ def fill_missing_timestamps(
1616
) -> pd.DataFrame:
1717
"""
1818
Rellena las marcas de tiempo horarias faltantes en un DataFrame con NaNs.
19-
20-
Args:
21-
df (pd.DataFrame): DataFrame de entrada con una columna de fecha.
22-
start_date (pd.Timestamp): Fecha de inicio del rango completo.
23-
end_date (pd.Timestamp): Fecha de fin del rango completo.
24-
freq (str, optional): Frecuencia para el rango de fechas (por defecto: 'h').
25-
date_col (str, optional): Nombre de la columna de fecha (por defecto: config.DATE_COL).
26-
27-
Returns:
28-
pd.DataFrame: DataFrame con marcas de tiempo horarias completas y NaNs
29-
para los datos faltantes.
3019
"""
31-
if not pd.api.types.is_datetime64_any_dtype(df[date_col]):
32-
df[date_col] = pd.to_datetime(df[date_col])
33-
if df[date_col].dt.tz is None:
34-
df[date_col] = df[date_col].dt.tz_localize(
35-
config.TIMEZONE
36-
) # Asegura la zona horaria
37-
df = df.set_index(date_col)
20+
3821
full_date_range = pd.date_range(
3922
start=start_date, end=end_date, freq=freq, tz=config.TIMEZONE
4023
)
41-
df_reindexed = df.reindex(full_date_range)
42-
# No reinicies el índice si quieres preservar el DatetimeIndex
43-
return df_reindexed
24+
df_reindexed = df.set_index(date_col).reindex(full_date_range) # Removed reset_index()
25+
df_reindexed.index.name = None # Set index name to None
4426

27+
return df_reindexed
4528

4629
def split_by_period(
4730
df: pd.DataFrame,
48-
period_1_start: pd.Timestamp = config.PERIOD_1_START,
49-
period_1_end: pd.Timestamp = config.PERIOD_1_END,
50-
period_2_start: pd.Timestamp = config.PERIOD_2_START,
51-
period_2_end: pd.Timestamp = config.PERIOD_2_END,
31+
period_1_start: pd.Timestamp = config.PERIOD_1_PADDING_START,
32+
period_1_end: pd.Timestamp = config.PERIOD_1_PADDING_END,
33+
period_2_start: pd.Timestamp = config.PERIOD_2_PADDING_START,
34+
period_2_end: pd.Timestamp = config.PERIOD_2_PADDING_END,
5235
date_col: str = config.DATE_COL,
5336
) -> Tuple[pd.DataFrame, pd.DataFrame]:
5437
"""
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# pampaneira_imputation/evaluation.py
2+
import numpy as np
3+
import pandas as pd
4+
from pypots.nn.functional import calc_mae, calc_mse, calc_rmse, calc_mre
5+
from typing import Dict, Tuple, List
6+
from . import config
7+
8+
def calculate_imputation_metrics(y_true: np.ndarray,
9+
y_pred: np.ndarray,
10+
indicating_mask: np.ndarray) -> Dict[str, float]:
11+
"""
12+
Calcula MAE, MSE, RMSE, MRE para valores imputados donde la máscara es 1.
13+
14+
Args:
15+
y_true (np.ndarray): Datos verdaderos (potencialmente con NaNs donde faltaban originalmente).
16+
y_pred (np.ndarray): Datos imputados.
17+
indicating_mask (np.ndarray): Máscara donde 1 indica un valor faltante que fue imputado,
18+
0 indica un valor observado.
19+
20+
Returns:
21+
Dict[str, float]: Diccionario que contiene 'mae', 'mse', 'rmse', 'mre'.
22+
"""
23+
# Asegura que las entradas sean arrays numpy
24+
y_true = np.asarray(y_true)
25+
y_pred = np.asarray(y_pred)
26+
indicating_mask = np.asarray(indicating_mask)
27+
28+
# Reemplaza NaNs en la verdad fundamental con 0 para el cálculo donde la máscara es 1
29+
# Esto es necesario porque las funciones pypots esperan una verdad fundamental sin NaN
30+
# Solo evaluamos donde indicating_mask es 1, por lo que este reemplazo es seguro.
31+
y_true_filled = np.nan_to_num(y_true, nan=0.0)
32+
33+
if y_true.shape != y_pred.shape or y_true.shape != indicating_mask.shape:
34+
raise ValueError(f"Desajuste de forma: y_true={y_true.shape}, "
35+
f"y_pred={y_pred.shape}, mask={indicating_mask.shape}")
36+
37+
# Verifica si la suma de la máscara es cero (no hay valores para evaluar)
38+
if indicating_mask.sum() == 0:
39+
print("Advertencia: La suma de la máscara indicadora es 0. No hay valores imputados para evaluar.")
40+
return {'mae': np.nan, 'mse': np.nan, 'rmse': np.nan, 'mre': np.nan}
41+
42+
try:
43+
mae = calc_mae(y_pred, y_true_filled, indicating_mask)
44+
mse = calc_mse(y_pred, y_true_filled, indicating_mask)
45+
rmse = calc_rmse(y_pred, y_true_filled, indicating_mask)
46+
mre = calc_mre(y_pred, y_true_filled, indicating_mask) # Precaución con MRE si los valores verdaderos están cerca de cero
47+
48+
return {'mae': mae, 'mse': mse, 'rmse': rmse, 'mre': mre}
49+
except Exception as e:
50+
print(f"Error durante el cálculo de métricas: {e}")
51+
# Añade más información de depuración si es necesario
52+
print(f"Formas: y_pred={y_pred.shape}, y_true_filled={y_true_filled.shape}, mask={indicating_mask.shape}")
53+
print(f"Suma de máscara: {indicating_mask.sum()}")
54+
print(f"Conteo de NaN: pred={np.isnan(y_pred).sum()}, true_filled={np.isnan(y_true_filled).sum()}, mask={np.isnan(indicating_mask).sum()}")
55+
# Considera verificar también por infinitos
56+
return {'mae': np.nan, 'mse': np.nan, 'rmse': np.nan, 'mre': np.nan}
57+
58+
59+
def evaluate_all_methods(preprocessed_data: Dict,
60+
imputed_results: Dict[str, np.ndarray],
61+
methods_to_evaluate: list = ['median', 'mean', 'linear', 'ffill_bfill', 'bfill_ffill', 'saits']) -> pd.DataFrame:
62+
"""
63+
Evalúa múltiples métodos de imputación usando los resultados del conjunto de prueba.
64+
65+
Args:
66+
preprocessed_data (Dict): Diccionario de preprocess_for_imputation.
67+
imputed_results (Dict[str, np.ndarray]): Diccionario que mapea nombres de métodos a arrays NumPy imputados (conjunto de prueba).
68+
methods_to_evaluate (list, optional): Lista de claves en imputed_results para evaluar.
69+
(por defecto: ['median', 'mean', 'linear', 'ffill_bfill', 'bfill_ffill', 'saits'])
70+
71+
Returns:
72+
pd.DataFrame: DataFrame de Pandas que resume MAE, MSE, RMSE, MRE para cada método.
73+
"""
74+
results = []
75+
y_true = preprocessed_data['test_X_ori']
76+
indicating_mask = preprocessed_data['test_indicating_mask']
77+
78+
# Maneja la posible eliminación de columnas para ffill/bfill si es necesario
79+
# Esta lógica asume que WS/WD se eliminaron *antes* de la imputación para ffill/bfill
80+
cols_to_drop_indices = [config.FEATURE_COLUMNS.index(col) for col in config.COLS_TO_DROP_FOR_BASELINE if col in config.FEATURE_COLUMNS]
81+
82+
for method_name in methods_to_evaluate:
83+
if method_name not in imputed_results:
84+
print(f"Advertencia: No se encontraron resultados imputados para el método '{method_name}'. Saltando.")
85+
continue
86+
87+
y_pred = imputed_results[method_name]
88+
current_y_true = y_true
89+
current_mask = indicating_mask
90+
91+
# Manejo específico para métodos donde las columnas podrían haberse eliminado
92+
if method_name in ['ffill_bfill', 'bfill_ffill'] and cols_to_drop_indices:
93+
print(f"Ajustando datos verdaderos y máscara para {method_name} debido a columnas eliminadas.")
94+
current_y_true = np.delete(y_true, cols_to_drop_indices, axis=2)
95+
current_mask = np.delete(indicating_mask, cols_to_drop_indices, axis=2)
96+
# y_pred para estos métodos ya debería tener las columnas eliminadas
97+
98+
print(f"\nCalculando métricas para: {method_name}")
99+
metrics = calculate_imputation_metrics(current_y_true, y_pred, current_mask)
100+
101+
results.append({
102+
"Method": method_name.replace('_', ' ').title(), # Nombre más bonito
103+
"RMSE": metrics.get('rmse', np.nan),
104+
"MSE": metrics.get('mse', np.nan),
105+
"MAE": metrics.get('mae', np.nan),
106+
"MRE": metrics.get('mre', np.nan)
107+
})
108+
109+
error_table = pd.DataFrame.from_records(results)
110+
error_table = error_table.set_index("Method").round(4)
111+
return error_table

0 commit comments

Comments
 (0)