Skip to content
Open
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.ml

import org.apache.spark.ml.regression.ArimaRegression
import org.apache.spark.sql.SparkSession

object ArimaRegressionExample {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder.appName("ARIMA Example").getOrCreate()
import spark.implicits._

val tsData = Seq(1.2, 2.3, 3.1, 4.0, 5.5).toDF("y")

val arima = new ArimaRegression().setP(1).setD(0).setQ(1)
val model = arima.fit(tsData)

val result = model.transform(tsData)
result.show()

spark.stop()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.regression

import org.apache.spark.ml.param.{IntParam, Params}

/**
* Shared parameters for ARIMA models.
*/
trait ArimaParams extends Params {

/** The autoregressive order (p). */
final val p: IntParam = new IntParam(this, "p", "Autoregressive order (p)")

/** The differencing order (d). */
final val d: IntParam = new IntParam(this, "d", "Differencing order (d)")

/** The moving average order (q). */
final val q: IntParam = new IntParam(this, "q", "Moving average order (q)")

setDefault(p -> 1, d -> 1, q -> 1)

def getP: Int = $(p)
def getD: Int = $(d)
def getQ: Int = $(q)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.regression

import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

/**
* ARIMA (AutoRegressive Integrated Moving Average) model implementation
* for univariate time series forecasting.
*
* This implementation leverages PySpark Pandas UDF with statsmodels to
* fit ARIMA(p, d, q) models in a distributed fashion.
*
* Input column: "y" (DoubleType)
* Output column: "prediction" (DoubleType)
*/
class ArimaRegression(override val uid: String)
extends Estimator[ArimaRegressionModel]
with ArimaParams
with DefaultParamsWritable {

def this() = this(Identifiable.randomUID("arimaReg"))

def setP(value: Int): this.type = set(p, value)
def setD(value: Int): this.type = set(d, value)
def setQ(value: Int): this.type = set(q, value)

/**
* Fits an ARIMA model using Python statsmodels via Pandas UDF.
* The UDF runs ARIMA(p,d,q) on each time series partition or entire dataset.
*/
override def fit(dataset: Dataset[_]): ArimaRegressionModel = {
val spark = dataset.sparkSession
import spark.implicits._

require(dataset.columns.contains("y"),
"Input dataset must contain a 'y' column of DoubleType representing the time series values.")

// Define the ARIMA Pandas UDF (Python side using statsmodels)
val udfScript =
s"""
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType
import pandas as pd
from statsmodels.tsa.arima.model import ARIMA

@pandas_udf("double")
def arima_forecast_udf(y: pd.Series) -> pd.Series:
try:
model = ARIMA(y, order=(${getOrDefault(p)}, ${getOrDefault(d)}, ${getOrDefault(q)}))
fitted = model.fit()
forecast = fitted.forecast(steps=1)
return pd.Series([forecast.iloc[0]] * len(y))
except Exception:
return pd.Series([float('nan')] * len(y))
"""

// Register the UDF dynamically
spark.udf.registerPython("arima_forecast_udf", udfScript)

// Apply the ARIMA forecast UDF
val predicted = dataset.withColumn("prediction", call_udf("arima_forecast_udf", col("y")))

// Create the model instance
val model = new ArimaRegressionModel(uid)
.setParent(this)
.setP($(p))
.setD($(d))
.setQ($(q))
.setFittedData(predicted)

model
}

override def copy(extra: ParamMap): ArimaRegression = defaultCopy(extra)

override def transformSchema(schema: StructType): StructType = {
require(schema.fieldNames.contains("y"),
"Input schema must contain 'y' column of DoubleType.")
StructType(schema.fields :+ StructField("prediction", DoubleType, nullable = true))
}
}

object ArimaRegression extends DefaultParamsReadable[ArimaRegression] {
override def load(path: String): ArimaRegression = super.load(path)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.regression

import org.apache.spark.ml.Model
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

class ArimaRegressionModel(override val uid: String)
extends Model[ArimaRegressionModel]
with ArimaParams
with MLWritable {

private var fittedData: DataFrame = _
def setFittedData(df: DataFrame): this.type = { this.fittedData = df; this }

override def copy(extra: ParamMap): ArimaRegressionModel = defaultCopy(extra)

override def transform(dataset: DataFrame): DataFrame = {
require(fittedData != null, "ARIMA model not fitted.")
fittedData
}

override def transformSchema(schema: StructType): StructType = {
schema.add("prediction", org.apache.spark.sql.types.DoubleType, nullable = true)
}

override def write: MLWriter = new DefaultParamsWriter(this)
}

object ArimaRegressionModel extends DefaultParamsReadable[ArimaRegressionModel]
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.regression

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._

/**
* Unit tests for ArimaRegression and ArimaRegressionModel.
*/
class ArimaRegressionSuite extends SparkFunSuite with org.apache.spark.sql.test.SharedSparkSession {

import testImplicits._

test("ARIMA model basic fit and transform") {
val spark = sparkSession
import spark.implicits._

val data = Seq(
(1, 100.0),
(2, 102.0),
(3, 101.0),
(4, 103.0),
(5, 104.0)
).toDF("t", "y")

val arima = new ArimaRegression()
.setP(1)
.setD(1)
.setQ(1)

val model = arima.fit(data)
val transformed = model.transform(data)

assert(transformed.columns.contains("prediction"), "Output should include 'prediction' column.")
assert(transformed.count() == data.count(), "Output row count should match input.")
}

test("ARIMA model schema validation and parameter setting") {
val arima = new ArimaRegression()
.setP(2)
.setD(1)
.setQ(1)

assert(arima.getP == 2)
assert(arima.getD == 1)
assert(arima.getQ == 1)

val schema = org.apache.spark.sql.types.StructType.fromDDL("y DOUBLE")
val outputSchema = arima.transformSchema(schema)
assert(outputSchema.fieldNames.contains("prediction"))
}

test("ARIMA model copy and persistence") {
val spark = sparkSession
import spark.implicits._

val data = Seq(
(1, 10.0),
(2, 12.0),
(3, 11.0)
).toDF("t", "y")

val arima = new ArimaRegression().setP(1).setD(1).setQ(1)
val model = arima.fit(data)

val copied = model.copy(org.apache.spark.ml.param.ParamMap.empty)
assert(copied.getP == model.getP)
assert(copied.getD == model.getD)
assert(copied.getQ == model.getQ)
}

test("ARIMA model handles missing y column gracefully") {
val spark = sparkSession
import spark.implicits._
val invalidDF = Seq((1, 2.0)).toDF("t", "value")
val arima = new ArimaRegression()

intercept[IllegalArgumentException] {
arima.fit(invalidDF)
}
}
}
2 changes: 2 additions & 0 deletions python/docs/source/reference/pyspark.ml.rst
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ Regression
RandomForestRegressionModel
FMRegressor
FMRegressionModel
ArimaRegression
ArimaRegressionModel


Statistics
Expand Down
Loading