[94] Support for strings (#97)
* support for strings

* removing type tags, they do not work with scala 2.10

* removing type tags, they do not work with scala 2.10

* removing type tags, they do not work with scala 2.10

* cleanup

* revert changes

* cleanup

* cleanups

* cleanup

* comments
thunterdb authored Apr 24, 2017
1 parent e9a31a6 commit b393bf3
Showing 17 changed files with 270 additions and 111 deletions.
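
The thread running through the files below: column metadata and tensor handling now work in terms of an internal ScalarType instead of Spark's NumericType, with SupportedOperations translating between the two, which is what lets string data (exposed as BinaryType on the SQL side) pass through. A minimal sketch of that kind of two-way mapping follows; the case objects and method names are illustrative stand-ins, not the actual org.tensorframes.impl API.

import org.apache.spark.sql.types._

// Hypothetical stand-ins for the ScalarType / SupportedOperations pair referenced in the diffs.
sealed trait ScalarTypeSketch
case object ScalarDouble extends ScalarTypeSketch
case object ScalarInt    extends ScalarTypeSketch
case object ScalarBinary extends ScalarTypeSketch // carries string/byte data

object SupportedOperationsSketch {
  // Internal scalar type -> Catalyst type exposed to Spark SQL (used when writing metadata).
  def sqlTypeFor(st: ScalarTypeSketch): DataType = st match {
    case ScalarDouble => DoubleType
    case ScalarInt    => IntegerType
    case ScalarBinary => BinaryType
  }

  // Catalyst type -> internal scalar type, if supported (used when reading a DataFrame schema).
  def scalarTypeFor(dt: DataType): Option[ScalarTypeSketch] = dt match {
    case DoubleType  => Some(ScalarDouble)
    case IntegerType => Some(ScalarInt)
    case BinaryType  => Some(ScalarBinary)
    case _           => None
  }
}
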
36 changes: 21 additions & 15 deletions src/main/scala/org/tensorframes/ColumnInformation.scala
@@ -2,6 +2,8 @@ package org.tensorframes

import org.apache.spark.sql.types._

import org.tensorframes.impl.{ScalarType, SupportedOperations}


class ColumnInformation private (
val field: StructField,
@@ -15,7 +17,9 @@ class ColumnInformation private (
val b = new MetadataBuilder().withMetadata(field.metadata)
for (info <- stf) {
b.putLongArray(shapeKey, info.shape.dims.toArray)
b.putString(tensorStructType, info.dataType.toString)
// Keep the SQL name, so that we do not leak internal details.
val dt = SupportedOperations.opsFor(info.dataType).sqlType
b.putString(tensorStructType, dt.toString)
}
val meta = b.build()
field.copy(metadata = meta)
@@ -73,15 +77,15 @@ object ColumnInformation extends Logging {
* @param scalarType the data type
* @param blockShape the shape of the block
*/
def structField(name: String, scalarType: NumericType, blockShape: Shape): StructField = {
def structField(name: String, scalarType: ScalarType, blockShape: Shape): StructField = {
val i = SparkTFColInfo(blockShape, scalarType)
val f = StructField(name, sqlType(scalarType, blockShape.tail), nullable = false)
ColumnInformation(f, i).merged
}

private def sqlType(scalarType: NumericType, shape: Shape): DataType = {
private def sqlType(scalarType: ScalarType, shape: Shape): DataType = {
if (shape.dims.isEmpty) {
scalarType
SupportedOperations.opsFor(scalarType).sqlType
} else {
ArrayType(sqlType(scalarType, shape.tail), containsNull = false)
}
@@ -102,11 +106,14 @@ object ColumnInformation extends Logging {
for {
s <- shape
t <- tpe
} yield SparkTFColInfo(s, t)
ops <- SupportedOperations.getOps(t)
} yield SparkTFColInfo(s, ops.scalarType)
}

private def getType(s: String): Option[NumericType] = {
supportedTypes.find(_.toString == s)
private def getType(s: String): Option[DataType] = {
val res = supportedTypes.find(_.toString == s)
logInfo(s"getType: $s -> $res")
res
}

/**
@@ -115,19 +122,18 @@ object ColumnInformation extends Logging {
* @return
*/
private def extractFromRow(dt: DataType): Option[SparkTFColInfo] = dt match {
case x: NumericType if MetadataConstants.supportedTypes.contains(dt) =>
logTrace("numerictype: " + x)
// It is a basic type that we understand
Some(SparkTFColInfo(Shape(Unknown), x))
case x: ArrayType =>
logTrace("arraytype: " + x)
// Look into the array to figure out the type.
extractFromRow(x.elementType).map { info =>
SparkTFColInfo(info.shape.prepend(Unknown), info.dataType)
}
case _ =>
logTrace("not understood: " + dt)
// Not understood.
None
case _ => SupportedOperations.getOps(dt) match {
case Some(ops) =>
logTrace("numerictype: " + ops.scalarType)
// It is a basic type that we understand
Some(SparkTFColInfo(Shape(Unknown), ops.scalarType))
case None => None
}
}
}
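
The sqlType helper above derives the Catalyst type of a tensor column by wrapping the scalar SQL type in one ArrayType per cell dimension. A standalone sketch of that recursion (nestedSqlType and SqlTypeSketch are illustrative names, not part of TensorFrames):

import org.apache.spark.sql.types._

object SqlTypeSketch {
  // One ArrayType level per cell dimension, scalar SQL type at the leaf.
  def nestedSqlType(scalar: DataType, dims: Seq[Int]): DataType =
    if (dims.isEmpty) scalar
    else ArrayType(nestedSqlType(scalar, dims.tail), containsNull = false)

  // A column of double cells with shape (10, 5) is stored as
  // ArrayType(ArrayType(DoubleType, containsNull = false), containsNull = false).
  val example: DataType = nestedSqlType(DoubleType, Seq(10, 5))
}
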
7 changes: 4 additions & 3 deletions src/main/scala/org/tensorframes/ExperimentalOperations.scala
@@ -3,7 +3,8 @@ package org.tensorframes
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{ArrayType, DataType, NumericType}
import org.tensorframes.impl.SupportedOperations

import org.tensorframes.impl.{ScalarType, SupportedOperations}

/**
* Some useful methods for operating on dataframes that are not part of the official API (and thus may change anytime).
@@ -109,8 +110,8 @@ private[tensorframes] object ExtraOperations extends ExperimentalOperations with
DataFrameInfo(allInfo)
}

private def extractBasicType(dt: DataType): Option[NumericType] = dt match {
case x: NumericType => Some(x)
private def extractBasicType(dt: DataType): Option[ScalarType] = dt match {
case x: NumericType => Some(SupportedOperations.opsFor(x).scalarType)
case x: ArrayType => extractBasicType(x.elementType)
case _ => None
}
6 changes: 3 additions & 3 deletions src/main/scala/org/tensorframes/MetadataConstants.scala
@@ -1,7 +1,7 @@
package org.tensorframes

import org.apache.spark.sql.types.NumericType
import org.tensorframes.impl.SupportedOperations
import org.apache.spark.sql.types.{DataType, NumericType}
import org.tensorframes.impl.{ScalarType, SupportedOperations}

/**
* Metadata annotations that get embedded in dataframes to express tensor information.
@@ -29,5 +29,5 @@ object MetadataConstants {
/**
* All the SQL types supported by SparkTF.
*/
val supportedTypes: Seq[NumericType] = SupportedOperations.sqlTypes
val supportedTypes: Seq[DataType] = SupportedOperations.sqlTypes
}
16 changes: 12 additions & 4 deletions src/main/scala/org/tensorframes/Shape.scala
@@ -1,9 +1,11 @@
package org.tensorframes

import org.apache.spark.sql.types.NumericType
import org.apache.spark.sql.types.{BinaryType, DataType, NumericType}
import org.tensorflow.framework.TensorShapeProto

import scala.collection.JavaConverters._
import org.tensorframes.Shape.DimType
import org.tensorframes.impl.ScalarType
import org.{tensorflow => tf}


@@ -36,6 +38,11 @@ class Shape private (private val ds: Array[DimType]) extends Serializable {

def prepend(x: Int): Shape = Shape(x.toLong +: ds)

/**
* Drops the most inner dimension of the shape.
*/
def dropInner: Shape = Shape(ds.dropRight(1))

/**
* A shape with the first dimension dropped.
*/
@@ -104,15 +111,16 @@ object Shape {

/**
* SparkTF information. This is the information generally required to work on a tensor.
* @param shape
* @param dataType
* @param shape the shape of the column (including the number of rows). May contain some unknowns.
* @param dataType the datatype of the scalar. Note that it is either NumericType or BinaryType.
*/
// TODO(tjh) the types supported by TF are much richer (uint8, etc.) but it is not clear
// if they all map to a Catalyst memory representation
// TODO(tjh) support later basic structures for sparse types?
case class SparkTFColInfo(
shape: Shape,
dataType: NumericType) extends Serializable
dataType: ScalarType) extends Serializable {
}

/**
* Exception thrown when the user requests tensors of high order.
11 changes: 7 additions & 4 deletions src/main/scala/org/tensorframes/dsl/DslImpl.scala
@@ -1,13 +1,14 @@
package org.tensorframes.dsl

import javax.annotation.Nullable

import org.tensorflow.framework.{AttrValue, DataType, GraphDef, TensorShapeProto}

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.NumericType

import org.tensorframes.{Logging, ColumnInformation, Shape}
import org.tensorframes.impl.DenseTensor
import org.tensorframes.{ColumnInformation, Logging, Shape}
import org.tensorframes.impl.{DenseTensor, SupportedOperations}


/**
@@ -75,8 +76,9 @@ private[dsl] object DslImpl extends Logging with DefaultConversions {

def build_constant(dt: DenseTensor): Node = {
val a = AttrValue.newBuilder().setTensor(DenseTensor.toTensorProto(dt))
val dt2 = SupportedOperations.opsFor(dt.dtype).sqlType.asInstanceOf[NumericType]
build("Const", isOp = false,
shape = dt.shape, dtype = dt.dtype,
shape = dt.shape, dtype = dt2,
extraAttrs = Map("value" -> a.build()))
}

@@ -100,7 +102,8 @@ private[dsl] object DslImpl extends Logging with DefaultConversions {
s"tensorframes: $schema")
}
val shape = if (block) { stf.shape } else { stf.shape.tail }
DslImpl.placeholder(stf.dataType, shape).named(tfName)
val dt = SupportedOperations.opsFor(stf.dataType).sqlType.asInstanceOf[NumericType]
DslImpl.placeholder(dt, shape).named(tfName)
}

private def commonShape(shapes: Seq[Shape]): Shape = {
6 changes: 2 additions & 4 deletions src/main/scala/org/tensorframes/dsl/package.scala
@@ -1,10 +1,8 @@
package org.tensorframes

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.IntegerType

import org.apache.spark.sql.types.{IntegerType, NumericType}
import org.tensorframes.impl.SupportedOperations

/**
@@ -45,7 +43,7 @@ package object dsl {

def placeholder[T : Numeric : TypeTag](shape: Int*): Operation = {
val ops = SupportedOperations.getOps[T]()
DslImpl.placeholder(ops.sqlType, Shape(shape: _*))
DslImpl.placeholder(ops.sqlType.asInstanceOf[NumericType], Shape(shape: _*))
}

def constant[T : ConvertibleToDenseTensor](x: T): Operation = {
4 changes: 2 additions & 2 deletions src/main/scala/org/tensorframes/impl/DataOps.scala
@@ -5,7 +5,7 @@ import scala.reflect.ClassTag

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types.{NumericType, StructType}
import org.apache.spark.sql.types.StructType

import org.tensorframes.{Logging, Shape}
import org.tensorframes.Shape.DimType
@@ -145,7 +145,7 @@ object DataOps extends Logging {

def getColumnFast0(
reshapeShape: Shape,
scalaType: NumericType,
scalaType: ScalarType,
allDataBuffer: mutable.WrappedArray[_]): Iterable[Any] = {
reshapeShape.dims match {
case Seq() =>
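
getColumnFast0 above rebuilds a column from a flat data buffer by regrouping it according to the target shape. A simplified sketch of that regrouping on plain Scala collections (reshape and the Seq[Int] shape are illustrative; the real code works on WrappedArray and handles the scalar type explicitly):

object ReshapeSketch {
  // Regroup a flat buffer into nested sequences, outermost dimension first.
  def reshape(flat: Seq[Any], dims: Seq[Int]): Seq[Any] = dims match {
    case Seq() | Seq(_) => flat // scalar or innermost dimension: values stay flat
    case _ +: rest =>
      val innerSize = rest.product
      flat.grouped(innerSize).map(chunk => reshape(chunk, rest)).toSeq
  }

  // Six values with shape (3, 2) -> Seq(Seq(1, 2), Seq(3, 4), Seq(5, 6))
  val nested = reshape(Seq(1, 2, 3, 4, 5, 6), Seq(3, 2))
}
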
37 changes: 19 additions & 18 deletions src/main/scala/org/tensorframes/impl/DebugRowOps.scala
@@ -1,20 +1,23 @@
package org.tensorframes.impl

import scala.collection.mutable
import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}

import org.apache.commons.lang3.SerializationUtils
import org.tensorflow.framework.GraphDef
import org.tensorflow.{Session, Tensor}

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.{DataFrame, RelationalGroupedDataset, Row}
import org.tensorflow.framework.GraphDef
import org.tensorflow.{Session, Tensor}

import org.tensorframes._
import org.tensorframes.test.DslOperations

import scala.collection.mutable
import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}

/**
* The different schemas required for the block reduction.
@@ -322,17 +325,17 @@ class DebugRowOps
throw new Exception(
s"Data column ${f.name} has not been analyzed yet, cannot run TF on this dataframe")
}
if (! stf.shape.checkMorePreciseThan(in.shape)) {
throw new Exception(
s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" +
s" ${in.shape} requested by the TF graph")
}
// We do not support autocasting for now.
if (stf.dataType != in.scalarType) {
throw new Exception(
s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data type " +
s"of the column (${in.scalarType})")
}
if (! stf.shape.checkMorePreciseThan(in.shape)) {
throw new Exception(
s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" +
s" ${in.shape} requested by the TF graph")
}
// The input has to be either a constant or a placeholder
if (! in.isPlaceholder) {
throw new Exception(
@@ -414,16 +417,16 @@ class DebugRowOps
val stf = get(ColumnInformation(f).stf,
s"Data column ${f.name} has not been analyzed yet, cannot run TF on this dataframe")

check(stf.dataType == in.scalarType,
s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data type " +
s"of the column (${in.scalarType})")

val cellShape = stf.shape.tail
// No check for unknowns: we allow unknowns in the first dimension of the cell shape.
check(cellShape.checkMorePreciseThan(in.shape),
s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" +
s" ${in.shape} requested by the TF graph")

check(stf.dataType == in.scalarType,
s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data type " +
s"of the column (${in.scalarType})")

check(in.isPlaceholder,
s"Invalid type for input node ${in.name}. It has to be a placeholder")
}
@@ -532,7 +535,8 @@ class DebugRowOps
val f = col.field
builder.append(s"$prefix-- ${f.name}: ${f.dataType.typeName} (nullable = ${f.nullable})")
val stf = col.stf.map { s =>
s" ${s.dataType.typeName}${s.shape}"
val dt = SupportedOperations.opsFor(s.dataType).sqlType
s" ${dt.typeName}${s.shape}"
} .getOrElse(" <no tensor info>")
builder.append(stf)
builder.append("\n")
@@ -725,9 +729,6 @@ object DebugRowOpsImpl extends Logging {
}
}

// Trying to get around some frequent crashes within TF.
private[this] val tfLock = new Object

private[impl] def reducePair(
schema: StructType,
gbc: Broadcast[SerializedGraph]): (Row, Row) => Row = {
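
The validation above checks both that the column's scalar type matches the placeholder's and that the column's shape is at least as precise as the shape requested by the graph (checkMorePreciseThan). A rough sketch of that shape test, under the assumption that a dimension is either a known size or an unknown marker (-1 here, Unknown in the real Shape class):

object ShapeCheckSketch {
  // A data shape satisfies a requested shape when ranks agree and every dimension
  // the request pins down is matched exactly; -1 marks an unknown dimension.
  // (Simplified stand-in for Shape.checkMorePreciseThan, which uses Long dims.)
  def morePreciseThan(data: Seq[Int], requested: Seq[Int]): Boolean =
    data.length == requested.length &&
      data.zip(requested).forall { case (d, r) => r == -1 || d == r }

  val ok  = morePreciseThan(Seq(10, 5), Seq(-1, 5)) // column (10, 5) vs placeholder (-1, 5)
  val bad = morePreciseThan(Seq(10, 5), Seq(-1, 3)) // mismatched inner dimension
}
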
20 changes: 12 additions & 8 deletions src/main/scala/org/tensorframes/impl/DenseTensor.scala
@@ -17,27 +17,31 @@ import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, NumericTy
*/
private[tensorframes] class DenseTensor private(
val shape: Shape,
val dtype: NumericType,
val dtype: ScalarType,
private val data: Array[Byte]) {

override def toString(): String = s"DenseTensor($shape, $dtype, " +
s"${data.length / dtype.defaultSize} elements)"
s"${data.length} bytes)"
}

private[tensorframes] object DenseTensor {
def apply[T](x: T)(implicit ev2: TypeTag[T]): DenseTensor = {
val ops = SupportedOperations.getOps[T]()
new DenseTensor(Shape.empty, ops.sqlType, convert(x))
apply(Shape.empty, ops.sqlType.asInstanceOf[NumericType], convert(x))
}

def apply[T](xs: Seq[T])(implicit ev1: Numeric[T], ev2: TypeTag[T]): DenseTensor = {
val ops = SupportedOperations.getOps[T]()
new DenseTensor(Shape(xs.size), ops.sqlType, convert1(xs))
apply(Shape(xs.size), ops.sqlType.asInstanceOf[NumericType], convert1(xs))
}

def apply(shape: Shape, dtype: NumericType, data: Array[Byte]): DenseTensor = {
new DenseTensor(shape, SupportedOperations.opsFor(dtype).scalarType, data)
}

def matrix[T](xs: Seq[Seq[T]])(implicit ev1: Numeric[T], ev2: TypeTag[T]): DenseTensor = {
val ops = SupportedOperations.getOps[T]()
new DenseTensor(Shape(xs.size, xs.head.size), ops.sqlType, convert2(xs))
apply(Shape(xs.size, xs.head.size), ops.sqlType.asInstanceOf[NumericType], convert2(xs))
}

private def convert[T](x: T)(implicit ev2: TypeTag[T]): Array[Byte] = {
@@ -98,15 +102,15 @@ private[tensorframes] object DenseTensor {
val shape = Shape.from(proto.getTensorShape)
val data = ops.sqlType match {
case DoubleType =>
val coll = proto.getDoubleValList.asScala.toSeq.map(_.doubleValue())
val coll = proto.getDoubleValList.asScala.map(_.doubleValue())
convert(coll)
case IntegerType =>
val coll = proto.getIntValList.asScala.toSeq.map(_.intValue())
val coll = proto.getIntValList.asScala.map(_.intValue())
convert(coll)
case _ =>
throw new IllegalArgumentException(
s"Cannot convert type ${ops.sqlType}")
}
new DenseTensor(shape, ops.sqlType, data)
new DenseTensor(shape, ops.scalarType, data)
}
}
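
DenseTensor stores its data as a raw byte array, and the convert helpers above pack a sequence of scalars into that form. A rough sketch for doubles only (the byte-buffer layout and native byte order are assumptions here; the real implementation covers every supported scalar type):

import java.nio.{ByteBuffer, ByteOrder}

object ConvertSketch {
  // Pack a sequence of doubles into a flat byte array, 8 bytes per element.
  def packDoubles(xs: Seq[Double]): Array[Byte] = {
    val buf = ByteBuffer.allocate(8 * xs.size).order(ByteOrder.nativeOrder())
    xs.foreach(x => buf.putDouble(x))
    buf.array()
  }

  // Three doubles -> 24 bytes of raw tensor content.
  val bytes: Array[Byte] = packDoubles(Seq(1.0, 2.0, 3.0))
}
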