For the list of functions already built into Spark, see:
Spark 1.5 DataFrame API Highlights: Date/Time/String Handling, Time Intervals, and UDAFs
If the built-in functions are not enough and you need a UDF or UDAF, see the following example:
package cn.com.systex

import java.sql.{Date, Timestamp}
import java.util.Calendar

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions.{callUDF, lit, udf}
import org.apache.spark.sql.types._
// A closed date range. java.sql.Date and java.sql.Timestamp both extend
// java.util.Date, so they can be compared directly.
case class DateRange(startDate: Timestamp, endDate: Timestamp) {
  // Inclusive on both ends, so sales falling exactly on a boundary date are counted.
  def in(targetDate: Date): Boolean = {
    !targetDate.before(startDate) && !targetDate.after(endDate)
  }

  override def toString: String = startDate.toString + " ~ " + endDate.toString
}
// UDAF that sums a metric over the current period and over the same period one
// year earlier, and returns the year-over-year growth rate in percent.
class YearOnYearCompare(current: DateRange) extends UserDefinedAggregateFunction {
  val previous: DateRange = DateRange(subtractOneYear(current.startDate), subtractOneYear(current.endDate))

  println(current)
  println(previous)

  // Input: the metric to aggregate and the date deciding which period the row belongs to.
  def inputSchema: StructType = StructType(
    StructField("metric", DoubleType) :: StructField("timeCategory", DateType) :: Nil)

  // Intermediate state: one running sum per period.
  def bufferSchema: StructType = StructType(
    StructField("sumOfCurrent", DoubleType) :: StructField("sumOfPrevious", DoubleType) :: Nil)

  def dataType: DataType = DoubleType

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0.0)
    buffer.update(1, 0.0)
  }
  // Add each row's metric to the running sum of whichever period its date falls into.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (current.in(input.getAs[Date](1))) {
      buffer(0) = buffer.getAs[Double](0) + input.getAs[Double](0)
    }
    if (previous.in(input.getAs[Date](1))) {
      buffer(1) = buffer.getAs[Double](1) + input.getAs[Double](0)
    }
  }
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Double](0) + buffer2.getAs[Double](0)
    buffer1(1) = buffer1.getAs[Double](1) + buffer2.getAs[Double](1)
  }

  // Year-over-year growth rate in percent; 0.0 when there is no previous-period data.
  def evaluate(buffer: Row): Any = {
    if (buffer.getDouble(1) == 0.0) {
      0.0
    } else {
      (buffer.getDouble(0) - buffer.getDouble(1)) / buffer.getDouble(1) * 100
    }
  }
  // Shift a timestamp back exactly one calendar year.
  private def subtractOneYear(date: Timestamp): Timestamp = {
    val cal = Calendar.getInstance()
    cal.setTimeInMillis(date.getTime)
    cal.add(Calendar.YEAR, -1)
    new Timestamp(cal.getTimeInMillis)
  }
}
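// Besides registering it by name for SQL, a UserDefinedAggregateFunction can also be
// applied through the DataFrame API via its apply(Column*) method (available since
// Spark 1.5). A sketch, assuming salesDF and current from the demo below are in scope;
// the string column is cast to DateType explicitly to match inputSchema:
//
//   salesDF.agg(new YearOnYearCompare(current)($"sales", $"saleDate".cast(DateType)).as("yearOnYear"))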
object SimpleDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("sqltest"))
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    import sqlContext.implicits._

    // Sample data: id, name, address, age.
    val df = sqlContext.createDataFrame(Seq(
      (1, "张三峰", "广东 广州 天河", 24),
      (2, "李四", "广东 广州 东山", 36),
      (3, "王五", "广东 广州 越秀", 48),
      (4, "赵六", "广东 广州 海珠", 29))).toDF("id", "name", "addr", "age")
    // UDFs used through the DataFrame API.
    def splitAddrFunc: String => Seq[String] = {
      _.toLowerCase.split("\\s")
    }
    val longLength = udf((str: String, length: Int) => str.length > length)
    val len = udf((str: String) => str.length)

    val df2 = df.withColumn("addr-ex", callUDF(splitAddrFunc, ArrayType(StringType, containsNull = true), df("addr")))
    val df3 = df2.withColumn("name-len", len($"name")).filter(longLength($"name", lit(2)))

    println("DF schema and processing results:")
    df.printSchema()
    df3.printSchema()
    df3.foreach { println }
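    // With the sample data above, only "张三峰" (3 characters) satisfies
    // longLength(name, 2), so df3 ends up with a single row.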
    // The same logic registered as named UDFs, usable from SQL and expression strings.
    def slen(str: String): Int = str.length
    def slengthLongerThan(str: String, length: Int): Boolean = str.length > length
    sqlContext.udf.register("len", slen _)
    sqlContext.udf.register("longLength", slengthLongerThan _)
    df.registerTempTable("user")

    println("SQL query results:")
    sqlContext.sql("select name, len(name) from user where longLength(name, 2)").foreach(println)

    println("DataFrame filter results:")
    df.filter("longLength(name, 2)").foreach(println)
    // Sample sales data for the UDAF: id, name, sales, discount, state, saleDate.
    val salesDF = sqlContext.createDataFrame(Seq(
      (1, "Widget Co", 1000.00, 0.00, "AZ", "2014-01-02"),
      (2, "Acme Widgets", 2000.00, 500.00, "CA", "2014-02-01"),
      (3, "Widgetry", 1000.00, 200.00, "CA", "2015-01-11"),
      (4, "Widgets R Us", 5000.00, 0.0, "CA", "2015-02-19"),
      (5, "Ye Olde Widgete", 4200.00, 0.0, "MA", "2015-02-18"))).toDF("id", "name", "sales", "discount", "state", "saleDate")
    salesDF.registerTempTable("sales")

    // Compare 2015 sales against the same window one year earlier.
    val current = DateRange(Timestamp.valueOf("2015-01-01 00:00:00"), Timestamp.valueOf("2015-12-31 00:00:00"))
    val yearOnYear = new YearOnYearCompare(current)
    sqlContext.udf.register("yearOnYear", yearOnYear)
    val dataFrame = sqlContext.sql("select yearOnYear(sales, saleDate) as yearOnYear from sales")

    salesDF.printSchema()
    dataFrame.show()
  }
}
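For the sample data the result can be checked by hand: the 2015 window sums 1000.00 + 5000.00 + 4200.00 = 10200.00, the 2014 window sums 1000.00 + 2000.00 = 3000.00, so the query should report (10200 - 3000) / 3000 * 100 = 240.0 as the year-over-year growth rate.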