Change TypedDataset#apply syntax to use a function #110

Closed · wants to merge 7 commits (diff shown from 3 commits)
3 changes: 2 additions & 1 deletion dataset/src/main/scala/frameless/TypedColumn.scala
@@ -271,6 +271,7 @@ object TypedColumn {
lgen: LabelledGeneric.Aux[T, H],
selector: Selector.Aux[H, K, V]
): Exists[T, K, V] = new Exists[T, K, V] {}

}

implicit class OrderedTypedColumnSyntax[T, U: CatalystOrdered](col: TypedColumn[T, U]) {
@@ -279,4 +280,4 @@ object TypedColumn {
def >(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = (col.untyped > other.untyped).typed
def >=(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = (col.untyped >= other.untyped).typed
}
}
}
6 changes: 2 additions & 4 deletions dataset/src/main/scala/frameless/TypedDataset.scala
@@ -62,11 +62,9 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
*
* It is statically checked that a column with this name exists and has type `A`.
*/
def apply[A](column: Witness.Lt[Symbol])(
implicit
exists: TypedColumn.Exists[T, column.T, A],
def apply[A](selector: T => A)(implicit
encoder: TypedEncoder[A]
): TypedColumn[T, A] = col(column)
): TypedColumn[T, A] = macro frameless.column.ColumnMacros.fromFunction[T, A]

/** Returns `TypedColumn` of type `A` given its name.
*
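For orientation, a sketch of the call sites the new signature enables. This is an editor's example, not part of the diff; it assumes an `X1` fixture like the one in the tests, and the Symbol-based form survives as `col`:

```scala
import frameless.TypedDataset

case class X1[A](a: A)

val ds = TypedDataset.create(X1(1) :: X1(2) :: Nil)

ds(_.a)       // new function syntax, expanded by ColumnMacros.fromFunction
ds[Int](_.a)  // the column type can also be pinned explicitly
ds.col('a)    // Symbol-based selection remains available through col
```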
70 changes: 70 additions & 0 deletions dataset/src/main/scala/frameless/column/ColumnMacros.scala
@@ -0,0 +1,70 @@
package frameless.column

import frameless.{TypedColumn, TypedEncoder, TypedExpressionEncoder}

import scala.collection.immutable.Queue
import scala.reflect.macros.whitebox

class ColumnMacros(val c: whitebox.Context) {
import c.universe._

// could be used to reintroduce apply('foo)
// $COVERAGE-OFF$ Currently unused
def fromSymbol[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[scala.Symbol])(encoder: c.Expr[TypedEncoder[B]]): Tree = {
Contributor:
Why would we ever want to use this macro instead of the shapeless one?

Contributor (Author):
The reason I left this here was in case we wanted to support either ds('a) or ds(_.a) at the same time. We can't do that with overloading, because it will ruin type inference for the function syntax. So if we really wanted to allow both, I thought we could have the macro figure it all out instead.

There are other problems with this, though - I would prefer to just embrace the function syntax because it has better type inference and about 95% smaller bytecode (after implicit expansion is all said and done).
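A generic illustration of the inference problem described above; this is an editor's sketch, not code from the PR. With two apply-style overloads in scope, Scala can no longer use the expected type to infer the lambda's parameter type:

```scala
object OverloadInferenceDemo {
  // Two apply-style overloads that differ in parameter shape:
  def pick[A](selector: Int => A): A = selector(42)
  def pick[A](name: Symbol): A = sys.error("symbol lookup elided")

  // With both overloads present, the shorthand no longer compiles:
  // pick(_ + 1)                    // error: missing parameter type
  val ok = pick((i: Int) => i + 1) // works only with an explicit annotation
}
```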

val B = weakTypeOf[B].dealias
val witness = c.typecheck(q"_root_.shapeless.Witness.apply(${selector.tree})")
c.typecheck(q"${c.prefix}.col[$B]($witness)")
}
// $COVERAGE-ON$

def fromFunction[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[A => B])(encoder: c.Expr[TypedEncoder[B]]): Tree = {
def fail(tree: Tree) = {
val err =
s"Could not create a column identifier from $tree - try using _.a.b syntax"
c.abort(tree.pos, err)
}

val A = weakTypeOf[A].dealias
val B = weakTypeOf[B].dealias

val selectorStr = selector.tree match {
case Function(List(ValDef(_, ArgName(argName), argTyp, _)), body) => body match {
Contributor:
I'm not sure this reads better than:

case Function(List(ValDef(_, name, argTyp, _)), body) =>
  NameExtractor(name).unapply(body) match {
    case Some(strs) => strs.mkString(".")
    case None       => fail(body)
  }

case `argName`(strs) => strs.mkString(".")
case other => fail(other)
}
// $COVERAGE-OFF$ - cannot be reached as typechecking will fail in this case before macro is even invoked
case other => fail(other)
// $COVERAGE-ON$
}

val typedCol = appliedType(
weakTypeOf[TypedColumn[_, _]].typeConstructor, A, B
)

val TEEObj = reify(TypedExpressionEncoder)

val datasetCol = c.typecheck(
q"${c.prefix}.dataset.col($selectorStr).as[$B]($TEEObj.apply[$B]($encoder))"
Contributor:
I don't understand why you need an .as here

Contributor (Author):
Hmm.. it's to go from o.a.s.s.Column to o.a.s.s.TypedColumn. But you're right, it looks like you can make a frameless.TypedColumn from an ordinary Column. Can't remember what I thought the advantage would be in doing this.

Contributor:
I see, if it's Spark's .as then it's not a problem! I thought it was one of ours that triggers implicit search & co, but it's not.
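For reference, the `.as` under discussion is Spark's own `Column#as`, which converts a `Column` into Spark's `TypedColumn` given an `Encoder`; a minimal sketch of that API as this editor understands it:

```scala
import org.apache.spark.sql.{Column, Encoder, TypedColumn}

// Spark's Column#as[U] requires an implicit Encoder[U] and returns
// org.apache.spark.sql.TypedColumn[Any, U]; the quasiquote above wraps
// that value in a frameless.TypedColumn.
def sparkTyped[U: Encoder](c: Column): TypedColumn[Any, U] = c.as[U]
```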

)

c.typecheck(q"new $typedCol($datasetCol)")
}

case class NameExtractor(name: TermName) {
private val This = this
def unapply(tree: Tree): Option[Queue[String]] = {
tree match {
case Ident(`name`) => Some(Queue.empty)
case Select(This(strs), nested) => Some(strs enqueue nested.toString)
// $COVERAGE-OFF$ - Not sure if this case can ever come up and Encoder will still work
case Apply(This(strs), List()) => Some(strs)
// $COVERAGE-ON$
case _ => None
}
}
}

object ArgName {
def unapply(name: TermName): Option[NameExtractor] = Some(NameExtractor(name))
}
}
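To make the mechanics concrete, here is a rough sketch of the rewrite `fromFunction` performs, paraphrased from the quasiquotes above; the real macro emits typechecked trees, so the exact shape differs:

```scala
// Editor's sketch of the expansion for x4x4(_.xa.b) (illustrative only):
//
// 1. NameExtractor folds the selector body into a dotted path:
//      Select(Select(Ident(item), "xa"), "b")  ~>  Queue("xa", "b")  ~>  "xa.b"
//
// 2. fromFunction then generates, roughly:
//      new TypedColumn[X4X4[Int, String, Long, Boolean], String](
//        x4x4.dataset.col("xa.b").as[String](TypedExpressionEncoder[String])
//      )
```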
36 changes: 36 additions & 0 deletions dataset/src/test/scala/frameless/ColTests.scala
@@ -29,6 +29,42 @@ class ColTests extends TypedDatasetSuite {
()
}

test("colApply") {
val x4 = TypedDataset.create[X4[Int, String, Long, Boolean]](Nil)
val t4 = TypedDataset.create[(Int, String, Long, Boolean)](Nil)
val x4x4 = TypedDataset.create[X4X4[Int, String, Long, Boolean]](Nil)

x4(_.a)
t4(_._1)
x4[Int](_.a)
t4[Int](_._1)

illTyped("x4[String](_.a)", "type mismatch;\n found : Int\n required: String")

x4(_.b)
t4(_._2)

x4[String](_.b)
t4[String](_._2)

illTyped("x4[Int](_.b)", "type mismatch;\n found : String\n required: Int")

x4x4(_.xa.a)
x4x4[Int](_.xa.a)
x4x4(_.xa.b)
x4x4[String](_.xa.b)

x4x4(_.xb.a)
x4x4[Int](_.xb.a)
x4x4(_.xb.b)
x4x4[String](_.xb.b)

illTyped("x4x4[String](_.xa.a)", "type mismatch;\n found : Int\n required: String")
illTyped("x4x4(item => item.xa.a * 20)", "Could not create a column identifier from item\\.xa\\.a\\.\\*\\(20\\) - try using _\\.a\\.b syntax")

()
}

test("colMany") {
type X2X2 = X2[X2[Int, String], X2[Long, Boolean]]
val x2x2 = TypedDataset.create[X2X2](Nil)
6 changes: 3 additions & 3 deletions dataset/src/test/scala/frameless/FilterTests.scala
@@ -22,7 +22,7 @@ class FilterTests extends TypedDatasetSuite {
test("filter with arithmetic expressions: addition") {
check(forAll { (data: Vector[X1[Int]]) =>
val ds = TypedDataset.create(data)
val res = ds.filter((ds('a) + 1) === (ds('a) + 1)).collect().run().toVector
val res = ds.filter((ds(_.a) + 1) === (ds(_.a) + 1)).collect().run().toVector
res ?= data
})
}
@@ -31,7 +31,7 @@
val t = X1(1) :: X1(2) :: X1(3) :: Nil
val tds: TypedDataset[X1[Int]] = TypedDataset.create(t)

assert(tds.filter(tds('a) * 2 === 2).collect().run().toVector === Vector(X1(1)))
assert(tds.filter(tds('a) * 3 === 3).collect().run().toVector === Vector(X1(1)))
assert(tds.filter(tds(_.a) * 2 === 2).collect().run().toVector === Vector(X1(1)))
assert(tds.filter(tds(_.a) * 3 === 3).collect().run().toVector === Vector(X1(1)))
}
}
16 changes: 8 additions & 8 deletions dataset/src/test/scala/frameless/SelectTests.scala
@@ -297,7 +297,7 @@
): Prop = {
val ds = TypedDataset.create(data)

val dataset2 = ds.select(ds('a) + const).collect().run().toVector
val dataset2 = ds.select(ds(_.a) + const).collect().run().toVector
val data2 = data.map { case X1(a) => num.plus(a, const) }

dataset2 ?= data2
@@ -319,7 +319,7 @@
): Prop = {
val ds = TypedDataset.create(data)

val dataset2 = ds.select(ds('a) * const).collect().run().toVector
val dataset2 = ds.select(ds(_.a) * const).collect().run().toVector
val data2 = data.map { case X1(a) => num.times(a, const) }

dataset2 ?= data2
@@ -341,7 +341,7 @@
): Prop = {
val ds = TypedDataset.create(data)

val dataset2 = ds.select(ds('a) - const).collect().run().toVector
val dataset2 = ds.select(ds(_.a) - const).collect().run().toVector
val data2 = data.map { case X1(a) => num.minus(a, const) }

dataset2 ?= data2
@@ -363,7 +363,7 @@
val ds = TypedDataset.create(data)

if (const != 0) {
val dataset2 = ds.select(ds('a) / const).collect().run().toVector.asInstanceOf[Vector[A]]
val dataset2 = ds.select(ds(_.a) / const).collect().run().toVector.asInstanceOf[Vector[A]]
val data2 = data.map { case X1(a) => frac.div(a, const) }
dataset2 ?= data2
} else 0 ?= 0
@@ -379,17 +379,17 @@
assert(t.select(t.col('_1)).collect().run().toList === List(2))
// Issue #54
val fooT = t.select(t.col('_1)).map(x => Tuple1.apply(x)).as[Foo]
assert(fooT.select(fooT('i)).collect().run().toList === List(2))
assert(fooT.select(fooT(_.i)).collect().run().toList === List(2))
}

test("unary - on arithmetic") {
val e = TypedDataset.create[(Int, String, Long)]((1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil)
assert(e.select(-e('_1)).collect().run().toVector === Vector(-1, -2, -2))
assert(e.select(-(e('_1) + e('_3))).collect().run().toVector === Vector(-3L, -6L, -3L))
assert(e.select(-e(_._1)).collect().run().toVector === Vector(-1, -2, -2))
assert(e.select(-(e(_._1) + e(_._3))).collect().run().toVector === Vector(-3L, -6L, -3L))
}

test("unary - on strings should not type check") {
val e = TypedDataset.create[(Int, String, Long)]((1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil)
illTyped("""e.select( -e('_2) )""")
illTyped("""e.select( -e(_._2) )""")
}
}
2 changes: 2 additions & 0 deletions dataset/src/test/scala/frameless/XN.scala
@@ -64,3 +64,5 @@ object X5 {
implicit def ordering[A: Ordering, B: Ordering, C: Ordering, D: Ordering, E: Ordering]: Ordering[X5[A, B, C, D, E]] =
Ordering.Tuple5[A, B, C, D, E].on(x => (x.a, x.b, x.c, x.d, x.e))
}

case class X4X4[A, B, C, D](xa: X4[A, B, C, D], xb: X4[A, B, C, D])
34 changes: 17 additions & 17 deletions docs/src/main/tut/GettingStarted.md
@@ -63,13 +63,13 @@ val apartmentsTypedDS2 = spark.createDataset(apartments).typed
This is how we select a particular column from a `TypedDataset`:

```tut:book
val cities: TypedDataset[String] = apartmentsTypedDS.select(apartmentsTypedDS('city))
val cities: TypedDataset[String] = apartmentsTypedDS.select(apartmentsTypedDS(_.city))
```

This is completely safe, for instance suppose we misspell `city`:

```tut:book:fail
apartmentsTypedDS.select(apartmentsTypedDS('citi))
apartmentsTypedDS.select(apartmentsTypedDS(_.citi))
```

This gets caught at compile-time, whereas with traditional Spark `Dataset` the error appears at run-time.
@@ -81,7 +81,7 @@ apartmentsDS.select('citi)
`select()` supports arbitrary column operations:

```tut:book
apartmentsTypedDS.select(apartmentsTypedDS('surface) * 10, apartmentsTypedDS('surface) + 2).show().run()
apartmentsTypedDS.select(apartmentsTypedDS(_.surface) * 10, apartmentsTypedDS(_.surface) + 2).show().run()
```

*Note that unlike the standard Spark api, here `show()` is lazy. It requires to apply `run()` for the
@@ -91,14 +91,14 @@
Let us now try to compute the price by surface unit:

```tut:book:fail
val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS('price)/apartmentsTypedDS('surface)) ^
val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS(_.price)/apartmentsTypedDS(_.surface)) ^
```

Argh! Looks like we can't divide a `TypedColumn` of `Double` by `Int`.
Well, we can cast our `Int`s to `Double`s explicitly to proceed with the computation.

```tut:book
val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS('price)/apartmentsTypedDS('surface).cast[Double])
val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS(_.price)/apartmentsTypedDS(_.surface).cast[Double])
priceBySurfaceUnit.collect().run()
```

@@ -107,15 +107,15 @@ Alternatively, we can perform the cast implicitly:
```tut:book
import frameless.implicits.widen._

val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS('price)/apartmentsTypedDS('surface))
val priceBySurfaceUnit = apartmentsTypedDS.select(apartmentsTypedDS(_.price)/apartmentsTypedDS(_.surface))
priceBySurfaceUnit.collect.run()
```

Looks like it worked, but that `cast` looks unsafe right? Actually it is safe.
Let's try to cast a `TypedColumn` of `String` to `Double`:

```tut:book:fail
apartmentsTypedDS('city).cast[Double]
apartmentsTypedDS(_.city).cast[Double]
```

The compile-time error tells us that to perform the cast, an evidence (in the form of `CatalystCast[String, Double]`) must be available.
@@ -136,15 +136,15 @@ The cast is valid and the expression compiles:

```tut:book
case class UpdatedSurface(city: String, surface: Int)
val updated = apartmentsTypedDS.select(apartmentsTypedDS('city), apartmentsTypedDS('surface) + 2).as[UpdatedSurface]
val updated = apartmentsTypedDS.select(apartmentsTypedDS(_.city), apartmentsTypedDS(_.surface) + 2).as[UpdatedSurface]
updated.show(2).run()
```

Next we try to cast a `(String, String)` to an `UpdatedSurface` (which has types `String`, `Int`).
The cast is not valid and the expression does not compile:

```tut:book:fail
apartmentsTypedDS.select(apartmentsTypedDS('city), apartmentsTypedDS('city)).as[UpdatedSurface]
apartmentsTypedDS.select(apartmentsTypedDS(_.city), apartmentsTypedDS(_.city)).as[UpdatedSurface]
```

### Projections
@@ -161,7 +161,7 @@ import frameless.implicits.widen._
val aptds = apartmentsTypedDS // For shorter expressions

case class ApartmentDetails(city: String, price: Double, surface: Int, ratio: Double)
val aptWithRatio = aptds.select(aptds('city), aptds('price), aptds('surface), aptds('price) / aptds('surface)).as[ApartmentDetails]
val aptWithRatio = aptds.select(aptds(_.city), aptds(_.price), aptds(_.surface), aptds(_.price) / aptds(_.surface)).as[ApartmentDetails]
```

Suppose we only want to work with `city` and `ratio`:
@@ -222,30 +222,30 @@ val udf = apartmentsTypedDS.makeUDF(priceModifier)

val aptds = apartmentsTypedDS // For shorter expressions

val adjustedPrice = aptds.select(aptds('city), udf(aptds('city), aptds('price)))
val adjustedPrice = aptds.select(aptds(_.city), udf(aptds(_.city), aptds(_.price)))

adjustedPrice.show().run()
```

## GroupBy and Aggregations
Let's suppose we wanted to retrieve the average apartment price in each city
```tut:book
val priceByCity = apartmentsTypedDS.groupBy(apartmentsTypedDS('city)).agg(avg(apartmentsTypedDS('price)))
val priceByCity = apartmentsTypedDS.groupBy(apartmentsTypedDS(_.city)).agg(avg(apartmentsTypedDS(_.price)))
priceByCity.collect().run()
```
Again if we try to aggregate a column that can't be aggregated, we get a compilation error
```tut:book:fail
apartmentsTypedDS.groupBy(apartmentsTypedDS('city)).agg(avg(apartmentsTypedDS('city))) ^
apartmentsTypedDS.groupBy(apartmentsTypedDS(_.city)).agg(avg(apartmentsTypedDS(_.city))) ^
```

Next, we combine `select` and `groupBy` to calculate the average price/surface ratio per city:

```tut:book
val aptds = apartmentsTypedDS // For shorter expressions

val cityPriceRatio = aptds.select(aptds('city), aptds('price) / aptds('surface))
val cityPriceRatio = aptds.select(aptds(_.city), aptds(_.price) / aptds(_.surface))

cityPriceRatio.groupBy(cityPriceRatio('_1)).agg(avg(cityPriceRatio('_2))).show().run()
cityPriceRatio.groupBy(cityPriceRatio(_._1)).agg(avg(cityPriceRatio(_._2))).show().run()
```

## Joins
@@ -265,7 +265,7 @@ val citiInfoTypedDS = TypedDataset.create(cityInfo)
Here is how to join the population information to the apartment's dataset.

```tut:book
val withCityInfo = apartmentsTypedDS.join(citiInfoTypedDS, apartmentsTypedDS('city), citiInfoTypedDS('name))
val withCityInfo = apartmentsTypedDS.join(citiInfoTypedDS, apartmentsTypedDS(_.city), citiInfoTypedDS(_.name))

withCityInfo.show().run()
```
@@ -278,7 +278,7 @@ We can then select which information we want to continue to work with:
case class AptPriceCity(city: String, aptPrice: Double, cityPopulation: Int)

withCityInfo.select(
withCityInfo.colMany('_2, 'name), withCityInfo.colMany('_1, 'price), withCityInfo.colMany('_2, 'population)
withCityInfo(_._2.name), withCityInfo(_._1.price), withCityInfo(_._2.population)
).as[AptPriceCity].show().run
```

6 changes: 3 additions & 3 deletions docs/src/main/tut/TypedDatasetVsSparkDataset.md
@@ -116,19 +116,19 @@ with a fully optimized query plan.
import frameless.TypedDataset
val fds = TypedDataset.create(ds)

fds.filter( fds('i) === 10 ).select( fds('i) ).show().run()
fds.filter( fds(_.i) === 10 ).select( fds(_.i) ).show().run()
```

And the optimized Physical Plan:

```tut:book
fds.filter( fds('i) === 10 ).select( fds('i) ).explain()
fds.filter( fds(_.i) === 10 ).select( fds(_.i) ).explain()
```

And the compiler is our friend.

```tut:fail
fds.filter( fds('i) === 10 ).select( fds('x) )
fds.filter( fds(_.i) === 10 ).select( fds(_.x) )
```

```tut:invisible