Revert "HBASE-15572 Adding optional timestamp semantics to HBase-Spark (Weiqing Yang)"
[hbase.git] hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.hbase.spark
import java.util
import java.util.concurrent.ConcurrentLinkedQueue

import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.io.ImmutableBytesWritable
import org.apache.hadoop.hbase.mapred.TableOutputFormat
import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf
import org.apache.hadoop.hbase.spark.datasources.HBaseTableScanRDD
import org.apache.hadoop.hbase.spark.datasources.SerializableConfiguration
import org.apache.hadoop.hbase.types._
import org.apache.hadoop.hbase.util.{Bytes, PositionedByteRange, SimplePositionedMutableByteRange}
import org.apache.hadoop.hbase._
import org.apache.hadoop.mapred.JobConf
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.datasources.hbase.{Utils, Field, HBaseTableCatalog}
import org.apache.spark.sql.{DataFrame, SaveMode, Row, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._

import scala.collection.mutable
/**
 * DefaultSource for integration with Spark's dataframe datasources.
 * This class will produce a relationProvider based on input given to it from Spark.
 *
 * In all, this DefaultSource supports the following datasource functionality:
 * - Scan range pruning through filter push down logic based on rowKeys
 * - Filter push down logic on HBase Cells
 * - Qualifier filtering based on columns used in the SparkSQL statement
 * - Type conversions of basic SQL types.  All conversions will be
 *   through the HBase Bytes object commands.
 */
class DefaultSource extends RelationProvider with CreatableRelationProvider with Logging {
  /**
   * Is given input from SparkSQL to construct a BaseRelation
   *
   * @param sqlContext SparkSQL context
   * @param parameters Parameters given to us from SparkSQL
   * @return           A BaseRelation Object
   */
  override def createRelation(sqlContext: SQLContext,
                              parameters: Map[String, String]):
  BaseRelation = {
    new HBaseRelation(parameters, None)(sqlContext)
  }

  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {
    val relation = HBaseRelation(parameters, Some(data.schema))(sqlContext)
    relation.createTable()
    relation.insert(data, false)
    relation
  }
}
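
// Usage sketch (illustrative, not part of the original source).  Reading through this
// provider is normally done with the DataFrame reader and a catalog JSON that maps
// SparkSQL columns to HBase column families/qualifiers; the table name "table1" and
// the column names below are made-up examples:
//
//   val cat = s"""{
//                |"table":{"namespace":"default", "name":"table1"},
//                |"rowkey":"key",
//                |"columns":{
//                  |"col0":{"cf":"rowkey", "col":"key", "type":"string"},
//                  |"col1":{"cf":"cf1", "col":"col1", "type":"int"}
//                |}
//              |}""".stripMargin
//
//   val df = sqlContext.read
//     .options(Map(HBaseTableCatalog.tableCatalog -> cat))
//     .format("org.apache.hadoop.hbase.spark")
//     .load()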
/**
 * Implementation of Spark BaseRelation that will build up our scan logic,
 * do the scan pruning, filter push down, and value conversions
 *
 * @param sqlContext SparkSQL context
 */
case class HBaseRelation (
    @transient parameters: Map[String, String],
    userSpecifiedSchema: Option[StructType]
  )(@transient val sqlContext: SQLContext)
  extends BaseRelation with PrunedFilteredScan with InsertableRelation with Logging {
  val catalog = HBaseTableCatalog(parameters)
  def tableName = catalog.name
  val configResources = parameters.getOrElse(HBaseSparkConf.HBASE_CONFIG_RESOURCES_LOCATIONS, "")
  val useHBaseContext = parameters.get(HBaseSparkConf.USE_HBASE_CONTEXT).map(_.toBoolean).getOrElse(true)
  val usePushDownColumnFilter = parameters.get(HBaseSparkConf.PUSH_DOWN_COLUMN_FILTER)
    .map(_.toBoolean).getOrElse(true)
  // The user supplied per table parameter will overwrite global ones in SparkConf
  val blockCacheEnable = parameters.get(HBaseSparkConf.BLOCK_CACHE_ENABLE).map(_.toBoolean)
    .getOrElse(
      sqlContext.sparkContext.getConf.getBoolean(
        HBaseSparkConf.BLOCK_CACHE_ENABLE, HBaseSparkConf.defaultBlockCacheEnable))
  val cacheSize = parameters.get(HBaseSparkConf.CACHE_SIZE).map(_.toInt)
    .getOrElse(
      sqlContext.sparkContext.getConf.getInt(
        HBaseSparkConf.CACHE_SIZE, HBaseSparkConf.defaultCachingSize))
  val batchNum = parameters.get(HBaseSparkConf.BATCH_NUM).map(_.toInt)
    .getOrElse(sqlContext.sparkContext.getConf.getInt(
      HBaseSparkConf.BATCH_NUM, HBaseSparkConf.defaultBatchNum))

  val bulkGetSize = parameters.get(HBaseSparkConf.BULKGET_SIZE).map(_.toInt)
    .getOrElse(sqlContext.sparkContext.getConf.getInt(
      HBaseSparkConf.BULKGET_SIZE, HBaseSparkConf.defaultBulkGetSize))

  //create or get latest HBaseContext
  val hbaseContext: HBaseContext = if (useHBaseContext) {
    LatestHBaseContextCache.latest
  } else {
    val config = HBaseConfiguration.create()
    configResources.split(",").foreach( r => config.addResource(r))
    new HBaseContext(sqlContext.sparkContext, config)
  }

  val wrappedConf = new SerializableConfiguration(hbaseContext.config)
  def hbaseConf = wrappedConf.value
  /**
   * Generates a Spark SQL schema object so Spark SQL knows what is being
   * provided by this BaseRelation
   *
   * @return schema generated from the SCHEMA_COLUMNS_MAPPING_KEY value
   */
  override val schema: StructType = userSpecifiedSchema.getOrElse(catalog.toDataType)
  def createTable() {
    val numReg = parameters.get(HBaseTableCatalog.newTable).map(x => x.toInt).getOrElse(0)
    val startKey = Bytes.toBytes(
      parameters.get(HBaseTableCatalog.regionStart)
        .getOrElse(HBaseTableCatalog.defaultRegionStart))
    val endKey = Bytes.toBytes(
      parameters.get(HBaseTableCatalog.regionEnd)
        .getOrElse(HBaseTableCatalog.defaultRegionEnd))
    if (numReg > 3) {
      val tName = TableName.valueOf(catalog.name)
      val cfs = catalog.getColumnFamilies
      val connection = ConnectionFactory.createConnection(hbaseConf)
      // Initialize hBase table if necessary
      val admin = connection.getAdmin()
      try {
        if (!admin.isTableAvailable(tName)) {
          val tableDesc = new HTableDescriptor(tName)
          cfs.foreach { x =>
            val cf = new HColumnDescriptor(x.getBytes())
            logDebug(s"add family $x to ${catalog.name}")
            tableDesc.addFamily(cf)
          }
          val splitKeys = Bytes.split(startKey, endKey, numReg)
          admin.createTable(tableDesc, splitKeys)
        }
      } finally {
        admin.close()
        connection.close()
      }
    } else {
      logInfo(
        s"""${HBaseTableCatalog.newTable}
           |is not defined or no larger than 3, skip the create table""".stripMargin)
    }
  }
  /**
   * Write the given DataFrame out to the HBase table described by the catalog.
   *
   * @param data      DataFrame to write to HBase
   * @param overwrite not used; rows are always written as Puts
   */
  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    val jobConfig: JobConf = new JobConf(hbaseConf, this.getClass)
    jobConfig.setOutputFormat(classOf[TableOutputFormat])
    jobConfig.set(TableOutputFormat.OUTPUT_TABLE, catalog.name)
    var count = 0
    val rkFields = catalog.getRowKey
    val rkIdxedFields = rkFields.map{ case x =>
      (schema.fieldIndex(x.colName), x)
    }
    val colsIdxedFields = schema
      .fieldNames
      .partition( x => rkFields.map(_.colName).contains(x))
      ._2.map(x => (schema.fieldIndex(x), catalog.getField(x)))
    val rdd = data.rdd
    def convertToPut(row: Row) = {
      // construct bytes for row key
      val rowBytes = rkIdxedFields.map { case (x, y) =>
        Utils.toBytes(row(x), y)
      }
      val rLen = rowBytes.foldLeft(0) { case (x, y) =>
        x + y.length
      }
      val rBytes = new Array[Byte](rLen)
      var offset = 0
      rowBytes.foreach { x =>
        System.arraycopy(x, 0, rBytes, offset, x.length)
        offset += x.length
      }
      val put = new Put(rBytes)

      colsIdxedFields.foreach { case (x, y) =>
        val b = Utils.toBytes(row(x), y)
        put.addColumn(Bytes.toBytes(y.cf), Bytes.toBytes(y.col), b)
      }
      count += 1
      (new ImmutableBytesWritable, put)
    }
    rdd.map(convertToPut(_)).saveAsHadoopDataset(jobConfig)
  }
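
  // Usage sketch (illustrative, not part of the original source).  The write path is
  // normally driven through the DataFrame writer, which lands in createRelation/insert
  // above; HBaseTableCatalog.newTable asks createTable() to pre-split the table and is
  // only honoured when it is larger than 3:
  //
  //   df.write
  //     .options(Map(HBaseTableCatalog.tableCatalog -> cat,  // same catalog JSON as above
  //       HBaseTableCatalog.newTable -> "5"))
  //     .format("org.apache.hadoop.hbase.spark")
  //     .save()
  //
  // Each Row is converted by convertToPut into a Put whose key is the concatenation of
  // the row key fields and whose remaining columns become cf/qualifier cells.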
  def getIndexedProjections(requiredColumns: Array[String]): Seq[(Field, Int)] = {
    requiredColumns.map(catalog.sMap.getField(_)).zipWithIndex
  }

  /**
   * Takes a HBase Row object and parses all of the fields from it.
   * This is independent of which fields were requested from the key;
   * because we have all the data it's less complex to parse everything.
   *
   * @param row       the retrieved row from hbase.
   * @param keyFields all of the fields in the row key, ORDERED by their order in the row key.
   */
  def parseRowKey(row: Array[Byte], keyFields: Seq[Field]): Map[Field, Any] = {
    keyFields.foldLeft((0, Seq[(Field, Any)]()))((state, field) => {
      val idx = state._1
      val parsed = state._2
      if (field.length != -1) {
        val value = Utils.hbaseFieldToScalaType(field, row, idx, field.length)
        // Return the new index and appended value
        (idx + field.length, parsed ++ Seq((field, value)))
      } else {
        field.dt match {
          case StringType =>
            val pos = row.indexOf(HBaseTableCatalog.delimiter, idx)
            if (pos == -1 || pos > row.length) {
              // this is at the last dimension
              val value = Utils.hbaseFieldToScalaType(field, row, idx, row.length)
              (row.length + 1, parsed ++ Seq((field, value)))
            } else {
              val value = Utils.hbaseFieldToScalaType(field, row, idx, pos - idx)
              (pos, parsed ++ Seq((field, value)))
            }
          // We don't know the length, assume it extends to the end of the rowkey.
          case _ => (row.length + 1, parsed ++ Seq((field, Utils.hbaseFieldToScalaType(field, row, idx, row.length))))
        }
      }
    })._2.toMap
  }
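
  // Illustrative walk-through (not part of the original source).  For a catalog whose
  // row key is the concatenation of an int field (fixed length Bytes.SIZEOF_INT) and a
  // string field (variable length), a stored key built as
  //
  //   val key = Bytes.add(Bytes.toBytes(42), Bytes.toBytes("abc"))
  //
  // is parsed left to right: the int field consumes bytes [0, 4) and decodes to 42, then
  // the string field (length == -1) finds no further delimiter and consumes the rest,
  // decoding to "abc".  parseRowKey returns a Map from each key Field to its decoded
  // value, which buildRow below merges with the non-key cells.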
  def buildRow(fields: Seq[Field], result: Result): Row = {
    val r = result.getRow
    val keySeq = parseRowKey(r, catalog.getRowKey)
    val valueSeq = fields.filter(!_.isRowKey).map { x =>
      val kv = result.getColumnLatestCell(Bytes.toBytes(x.cf), Bytes.toBytes(x.col))
      if (kv == null || kv.getValueLength == 0) {
        (x, null)
      } else {
        val v = CellUtil.cloneValue(kv)
        (x, Utils.hbaseFieldToScalaType(x, v, 0, v.length))
      }
    }.toMap
    val unionedRow = keySeq ++ valueSeq
    // Return the row ordered by the requested order
    Row.fromSeq(fields.map(unionedRow.get(_).getOrElse(null)))
  }
  /**
   * Here we are building the functionality to populate the resulting RDD[Row]
   * Here is where we will do the following:
   * - Filter push down
   * - Scan or GetList pruning
   * - Executing our scan(s) and/or GetList to generate result
   *
   * @param requiredColumns The columns that are being requested by the requesting query
   * @param filters         The filters that are being applied by the requesting query
   * @return                RDD with all the results from HBase needed for SparkSQL to
   *                        execute the query on
   */
  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {

    val pushDownTuple = buildPushDownPredicatesResource(filters)
    val pushDownRowKeyFilter = pushDownTuple._1
    var pushDownDynamicLogicExpression = pushDownTuple._2
    val valueArray = pushDownTuple._3

    if (!usePushDownColumnFilter) {
      pushDownDynamicLogicExpression = null
    }

    logDebug("pushDownRowKeyFilter: " + pushDownRowKeyFilter.ranges)
    if (pushDownDynamicLogicExpression != null) {
      logDebug("pushDownDynamicLogicExpression: " +
        pushDownDynamicLogicExpression.toExpressionString)
    }
    logDebug("valueArray: " + valueArray.length)

    val requiredQualifierDefinitionList =
      new mutable.MutableList[Field]

    requiredColumns.foreach( c => {
      val field = catalog.getField(c)
      requiredQualifierDefinitionList += field
    })

    //retain the information for unit testing checks
    DefaultSourceStaticUtils.populateLatestExecutionRules(pushDownRowKeyFilter,
      pushDownDynamicLogicExpression)

    val getList = new util.ArrayList[Get]()
    val rddList = new util.ArrayList[RDD[Row]]()

    //add points to getList
    pushDownRowKeyFilter.points.foreach(p => {
      val get = new Get(p)
      requiredQualifierDefinitionList.foreach( d => {
        if (d.isRowKey)
          get.addColumn(d.cfBytes, d.colBytes)
      })
      getList.add(get)
    })

    val pushDownFilterJava = if (usePushDownColumnFilter && pushDownDynamicLogicExpression != null) {
      Some(new SparkSQLPushDownFilter(pushDownDynamicLogicExpression,
        valueArray, requiredQualifierDefinitionList))
    } else {
      None
    }

    val hRdd = new HBaseTableScanRDD(this, hbaseContext, pushDownFilterJava, requiredQualifierDefinitionList.seq)
    pushDownRowKeyFilter.points.foreach(hRdd.addPoint(_))
    pushDownRowKeyFilter.ranges.foreach(hRdd.addRange(_))

    var resultRDD: RDD[Row] = {
      val tmp = hRdd.map{ r =>
        val indexedFields = getIndexedProjections(requiredColumns).map(_._1)
        buildRow(indexedFields, r)
      }
      if (tmp.partitions.size > 0) {
        tmp
      } else {
        null
      }
    }

    if (resultRDD == null) {
      val scan = new Scan()
      scan.setCacheBlocks(blockCacheEnable)
      scan.setBatch(batchNum)
      scan.setCaching(cacheSize)
      requiredQualifierDefinitionList.foreach( d =>
        scan.addColumn(d.cfBytes, d.colBytes))

      val rdd = hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan).map(r => {
        val indexedFields = getIndexedProjections(requiredColumns).map(_._1)
        buildRow(indexedFields, r._2)
      })
      resultRDD = rdd
    }
    resultRDD
  }
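
  // Pushdown sketch (illustrative, not part of the original source).  Given the DataFrame
  // from the read example above, an equality predicate on the row key column is pruned to
  // a point (served by a Get), while range predicates become ScanRanges on the scan:
  //
  //   df.filter(df("col0") === "row005").show()                          // point / Get
  //   df.filter(df("col0") > "row005" && df("col0") < "row009").show()   // ScanRange
  //
  // The RowKeyFilter and DynamicLogicExpression of the last five queries are recorded via
  // DefaultSourceStaticUtils.populateLatestExecutionRules so unit tests can assert on what
  // was actually pushed down.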
  def buildPushDownPredicatesResource(filters: Array[Filter]):
  (RowKeyFilter, DynamicLogicExpression, Array[Array[Byte]]) = {
    var superRowKeyFilter: RowKeyFilter = null
    val queryValueList = new mutable.MutableList[Array[Byte]]
    var superDynamicLogicExpression: DynamicLogicExpression = null

    filters.foreach( f => {
      val rowKeyFilter = new RowKeyFilter()
      val logicExpression = transverseFilterTree(rowKeyFilter, queryValueList, f)
      if (superDynamicLogicExpression == null) {
        superDynamicLogicExpression = logicExpression
        superRowKeyFilter = rowKeyFilter
      } else {
        superDynamicLogicExpression =
          new AndLogicExpression(superDynamicLogicExpression, logicExpression)
        superRowKeyFilter.mergeIntersect(rowKeyFilter)
      }
    })

    val queryValueArray = queryValueList.toArray

    if (superRowKeyFilter == null) {
      superRowKeyFilter = new RowKeyFilter
    }

    (superRowKeyFilter, superDynamicLogicExpression, queryValueArray)
  }

  def transverseFilterTree(parentRowKeyFilter: RowKeyFilter,
                           valueArray: mutable.MutableList[Array[Byte]],
                           filter: Filter): DynamicLogicExpression = {
    filter match {

      case EqualTo(attr, value) =>
        val field = catalog.getField(attr)
        if (field != null) {
          if (field.isRowKey) {
            parentRowKeyFilter.mergeIntersect(new RowKeyFilter(
              DefaultSourceStaticUtils.getByteValue(field,
                value.toString), null))
          }
          val byteValue =
            DefaultSourceStaticUtils.getByteValue(field, value.toString)
          valueArray += byteValue
        }
        new EqualLogicExpression(attr, valueArray.length - 1, false)

      case LessThan(attr, value) =>
        val field = catalog.getField(attr)
        if (field != null) {
          if (field.isRowKey) {
            parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null,
              new ScanRange(DefaultSourceStaticUtils.getByteValue(field,
                value.toString), false,
                new Array[Byte](0), true)))
          }
          val byteValue =
            DefaultSourceStaticUtils.getByteValue(catalog.getField(attr),
              value.toString)
          valueArray += byteValue
        }
        new LessThanLogicExpression(attr, valueArray.length - 1)

      case GreaterThan(attr, value) =>
        val field = catalog.getField(attr)
        if (field != null) {
          if (field.isRowKey) {
            parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null,
              new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(field,
                value.toString), false)))
          }
          val byteValue =
            DefaultSourceStaticUtils.getByteValue(field,
              value.toString)
          valueArray += byteValue
        }
        new GreaterThanLogicExpression(attr, valueArray.length - 1)

      case LessThanOrEqual(attr, value) =>
        val field = catalog.getField(attr)
        if (field != null) {
          if (field.isRowKey) {
            parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null,
              new ScanRange(DefaultSourceStaticUtils.getByteValue(field,
                value.toString), true,
                new Array[Byte](0), true)))
          }
          val byteValue =
            DefaultSourceStaticUtils.getByteValue(catalog.getField(attr),
              value.toString)
          valueArray += byteValue
        }
        new LessThanOrEqualLogicExpression(attr, valueArray.length - 1)

      case GreaterThanOrEqual(attr, value) =>
        val field = catalog.getField(attr)
        if (field != null) {
          if (field.isRowKey) {
            parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null,
              new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(field,
                value.toString), true)))
          }
          val byteValue =
            DefaultSourceStaticUtils.getByteValue(catalog.getField(attr),
              value.toString)
          valueArray += byteValue
        }
        new GreaterThanOrEqualLogicExpression(attr, valueArray.length - 1)

      case Or(left, right) =>
        val leftExpression = transverseFilterTree(parentRowKeyFilter, valueArray, left)
        val rightSideRowKeyFilter = new RowKeyFilter
        val rightExpression = transverseFilterTree(rightSideRowKeyFilter, valueArray, right)

        parentRowKeyFilter.mergeUnion(rightSideRowKeyFilter)

        new OrLogicExpression(leftExpression, rightExpression)

      case And(left, right) =>
        val leftExpression = transverseFilterTree(parentRowKeyFilter, valueArray, left)
        val rightSideRowKeyFilter = new RowKeyFilter
        val rightExpression = transverseFilterTree(rightSideRowKeyFilter, valueArray, right)
        parentRowKeyFilter.mergeIntersect(rightSideRowKeyFilter)

        new AndLogicExpression(leftExpression, rightExpression)

      case IsNull(attr) =>
        new IsNullLogicExpression(attr, false)

      case IsNotNull(attr) =>
        new IsNullLogicExpression(attr, true)

      case _ =>
        new PassThroughLogicExpression
    }
  }
}
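
// Illustrative trace (not part of the original source).  For a SparkSQL predicate such as
// WHERE col0 = 'row005' AND col1 > 10, Spark hands buildScan the filters
// Array(EqualTo("col0", "row005"), GreaterThan("col1", 10)).  buildPushDownPredicatesResource
// traverses each filter with its own RowKeyFilter and ANDs the pieces together, roughly:
//
//   val values = new mutable.MutableList[Array[Byte]]
//   val eqFilter = new RowKeyFilter()
//   val eq = transverseFilterTree(eqFilter, values, EqualTo("col0", "row005"))
//   val gtFilter = new RowKeyFilter()
//   val gt = transverseFilterTree(gtFilter, values, GreaterThan("col1", 10))
//   val expr = new AndLogicExpression(eq, gt)   // references values(0) and values(1) by index
//   eqFilter.mergeIntersect(gtFilter)
//
// If col0 is the row key, the merged RowKeyFilter carries the single point "row005" for
// Get-style pruning, while expr is shipped to the region servers as a SparkSQLPushDownFilter.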
/**
 * Construct to contain a single scan range's information.  Also
 * provides functions to merge with other scan ranges through AND
 * or OR operators
 *
 * @param upperBound          Upper bound of scan
 * @param isUpperBoundEqualTo Include upper bound value in the results
 * @param lowerBound          Lower bound of scan
 * @param isLowerBoundEqualTo Include lower bound value in the results
 */
class ScanRange(var upperBound:Array[Byte], var isUpperBoundEqualTo:Boolean,
                var lowerBound:Array[Byte], var isLowerBoundEqualTo:Boolean)
  extends Serializable {

  /**
   * Function to merge another scan object through an AND operation
   *
   * @param other Other scan object
   */
  def mergeIntersect(other:ScanRange): Unit = {
    val upperBoundCompare = compareRange(upperBound, other.upperBound)
    val lowerBoundCompare = compareRange(lowerBound, other.lowerBound)

    upperBound = if (upperBoundCompare < 0) upperBound else other.upperBound
    lowerBound = if (lowerBoundCompare > 0) lowerBound else other.lowerBound

    isLowerBoundEqualTo = if (lowerBoundCompare == 0)
      isLowerBoundEqualTo && other.isLowerBoundEqualTo
    else isLowerBoundEqualTo

    isUpperBoundEqualTo = if (upperBoundCompare == 0)
      isUpperBoundEqualTo && other.isUpperBoundEqualTo
    else isUpperBoundEqualTo
  }

  /**
   * Function to merge another scan object through an OR operation
   *
   * @param other Other scan object
   */
  def mergeUnion(other:ScanRange): Unit = {

    val upperBoundCompare = compareRange(upperBound, other.upperBound)
    val lowerBoundCompare = compareRange(lowerBound, other.lowerBound)

    upperBound = if (upperBoundCompare > 0) upperBound else other.upperBound
    lowerBound = if (lowerBoundCompare < 0) lowerBound else other.lowerBound

    isLowerBoundEqualTo = if (lowerBoundCompare == 0)
      isLowerBoundEqualTo || other.isLowerBoundEqualTo
    else if (lowerBoundCompare < 0) isLowerBoundEqualTo else other.isLowerBoundEqualTo

    isUpperBoundEqualTo = if (upperBoundCompare == 0)
      isUpperBoundEqualTo || other.isUpperBoundEqualTo
    else if (upperBoundCompare < 0) other.isUpperBoundEqualTo else isUpperBoundEqualTo
  }
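
  // Worked example (illustrative, not part of the original source) using string keys:
  // take r1 = ["b", "f") and r2 = ["d", "z"]:
  //
  //   val r1 = new ScanRange(Bytes.toBytes("f"), false, Bytes.toBytes("b"), true)
  //   val r2 = new ScanRange(Bytes.toBytes("z"), true, Bytes.toBytes("d"), true)
  //   r1.mergeIntersect(r2)   // r1 is now ["d", "f"): lower bounds take the larger value,
  //                           // upper bounds the smaller
  //
  // mergeUnion does the opposite (smaller lower bound, larger upper bound), so the same
  // two inputs would instead widen r1 to ["b", "z"].  A null upper bound and an empty
  // byte[] lower bound stand for "unbounded" at either end.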
  /**
   * Common function to see if this scan overlaps with another
   *
   * Reference Visual
   *
   * A                           B
   * |---------------------------|
   *    LL--------------LU
   *        RL--------------RU
   *
   * A = lowest value is byte[0]
   * B = highest value is null
   * LL = Left Lower Bound
   * LU = Left Upper Bound
   * RL = Right Lower Bound
   * RU = Right Upper Bound
   *
   * @param other Other scan object
   * @return      the overlapping ScanRange if the two ranges overlap, null if they do not
   */
  def getOverLapScanRange(other:ScanRange): ScanRange = {

    var leftRange:ScanRange = null
    var rightRange:ScanRange = null

    //First identify the Left range
    // Also lower bound can't be null
    if (compareRange(lowerBound, other.lowerBound) < 0 ||
      compareRange(upperBound, other.upperBound) < 0) {
      leftRange = this
      rightRange = other
    } else {
      leftRange = other
      rightRange = this
    }

    //Then see if leftRange goes to null or if leftRange.upperBound
    // is greater than or equal to rightRange.lowerBound
    if (leftRange.upperBound == null ||
      Bytes.compareTo(leftRange.upperBound, rightRange.lowerBound) >= 0) {
      new ScanRange(leftRange.upperBound, leftRange.isUpperBoundEqualTo,
        rightRange.lowerBound, rightRange.isLowerBoundEqualTo)
    } else {
      null
    }
  }
  /**
   * Special compare logic because we can have null values
   * for left or right bound
   *
   * @param left  Left byte array
   * @param right Right byte array
   * @return      0 if equal, 1 if left is greater, -1 if right is greater
   */
  def compareRange(left:Array[Byte], right:Array[Byte]): Int = {
    if (left == null && right == null) 0
    else if (left == null && right != null) 1
    else if (left != null && right == null) -1
    else Bytes.compareTo(left, right)
  }

  /**
   * Check whether a single row key falls inside this range
   *
   * @param point the row key to test
   * @return      true if the point is within the bounds of this range
   */
  def containsPoint(point:Array[Byte]): Boolean = {
    val lowerCompare = compareRange(point, lowerBound)
    val upperCompare = compareRange(point, upperBound)

    ((isLowerBoundEqualTo && lowerCompare >= 0) ||
      (!isLowerBoundEqualTo && lowerCompare > 0)) &&
      ((isUpperBoundEqualTo && upperCompare <= 0) ||
        (!isUpperBoundEqualTo && upperCompare < 0))
  }

  override def toString:String = {
    "ScanRange:(upperBound:" + Bytes.toString(upperBound) +
      ",isUpperBoundEqualTo:" + isUpperBoundEqualTo + ",lowerBound:" +
      Bytes.toString(lowerBound) + ",isLowerBoundEqualTo:" + isLowerBoundEqualTo + ")"
  }
}
/**
 * Contains information related to the filters for a given column.
 * This can contain many ranges or points.
 *
 * @param currentPoint the initial point when the filter is created
 * @param currentRange the initial scanRange when the filter is created
 */
class ColumnFilter (currentPoint:Array[Byte] = null,
                    currentRange:ScanRange = null,
                    var points:mutable.MutableList[Array[Byte]] =
                    new mutable.MutableList[Array[Byte]](),
                    var ranges:mutable.MutableList[ScanRange] =
                    new mutable.MutableList[ScanRange]() ) extends Serializable {
  //Collection of ranges
  if (currentRange != null ) ranges.+=(currentRange)

  //Collection of points
  if (currentPoint != null) points.+=(currentPoint)

  /**
   * This will validate a given value through the filter's points and/or ranges;
   * the result will be if the value passed the filter
   *
   * @param value       Value to be validated
   * @param valueOffSet The offset of the value
   * @param valueLength The length of the value
   * @return            True if the value passes the filter, false if not
   */
  def validate(value:Array[Byte], valueOffSet:Int, valueLength:Int):Boolean = {
    var result = false

    points.foreach( p => {
      if (Bytes.equals(p, 0, p.length, value, valueOffSet, valueLength)) {
        result = true
      }
    })

    ranges.foreach( r => {
      val upperBoundPass = r.upperBound == null ||
        (r.isUpperBoundEqualTo &&
          Bytes.compareTo(r.upperBound, 0, r.upperBound.length,
            value, valueOffSet, valueLength) >= 0) ||
        (!r.isUpperBoundEqualTo &&
          Bytes.compareTo(r.upperBound, 0, r.upperBound.length,
            value, valueOffSet, valueLength) > 0)

      val lowerBoundPass = r.lowerBound == null || r.lowerBound.length == 0 ||
        (r.isLowerBoundEqualTo &&
          Bytes.compareTo(r.lowerBound, 0, r.lowerBound.length,
            value, valueOffSet, valueLength) <= 0) ||
        (!r.isLowerBoundEqualTo &&
          Bytes.compareTo(r.lowerBound, 0, r.lowerBound.length,
            value, valueOffSet, valueLength) < 0)

      result = result || (upperBoundPass && lowerBoundPass)
    })
    result
  }
  /**
   * This will allow us to merge filter logic that is joined to the existing filter
   * through an OR operator
   *
   * @param other Filter to merge
   */
  def mergeUnion(other:ColumnFilter): Unit = {
    other.points.foreach( p => points += p)

    other.ranges.foreach( otherR => {
      var doesOverLap = false
      ranges.foreach{ r =>
        if (r.getOverLapScanRange(otherR) != null) {
          r.mergeUnion(otherR)
          doesOverLap = true
        }
      }
      if (!doesOverLap) ranges.+=(otherR)
    })
  }

  /**
   * This will allow us to merge filter logic that is joined to the existing filter
   * through an AND operator
   *
   * @param other Filter to merge
   */
  def mergeIntersect(other:ColumnFilter): Unit = {
    val survivingPoints = new mutable.MutableList[Array[Byte]]()
    points.foreach( p => {
      other.points.foreach( otherP => {
        if (Bytes.equals(p, otherP)) {
          survivingPoints.+=(p)
        }
      })
    })
    points = survivingPoints

    val survivingRanges = new mutable.MutableList[ScanRange]()

    other.ranges.foreach( otherR => {
      ranges.foreach( r => {
        if (r.getOverLapScanRange(otherR) != null) {
          r.mergeIntersect(otherR)
          survivingRanges += r
        }
      })
    })
    ranges = survivingRanges
  }

  override def toString:String = {
    val strBuilder = new StringBuilder
    strBuilder.append("(points:(")
    var isFirst = true
    points.foreach( p => {
      if (isFirst) isFirst = false
      else strBuilder.append(",")
      strBuilder.append(Bytes.toString(p))
    })
    strBuilder.append("),ranges:")
    isFirst = true
    ranges.foreach( r => {
      if (isFirst) isFirst = false
      else strBuilder.append(",")
      strBuilder.append(r)
    })
    strBuilder.append("))")
    strBuilder.toString()
  }
}
/**
 * A collection of ColumnFilters indexed by column names.
 *
 * Also contains merge commands that will consolidate the filters
 * per column name
 */
class ColumnFilterCollection {
  val columnFilterMap = new mutable.HashMap[String, ColumnFilter]

  def clear(): Unit = {
    columnFilterMap.clear()
  }

  /**
   * This will allow us to merge filter logic that is joined to the existing filter
   * through an OR operator.  This will merge a single column's filter
   *
   * @param column The column to be merged
   * @param other  The other ColumnFilter object to merge
   */
  def mergeUnion(column:String, other:ColumnFilter): Unit = {
    val existingFilter = columnFilterMap.get(column)
    if (existingFilter.isEmpty) {
      columnFilterMap.+=((column, other))
    } else {
      existingFilter.get.mergeUnion(other)
    }
  }

  /**
   * This will allow us to merge all filters in the existing collection
   * to the filters in the other collection.  All merges are done as a result
   * of an OR operator
   *
   * @param other The other Column Filter Collection to be merged
   */
  def mergeUnion(other:ColumnFilterCollection): Unit = {
    other.columnFilterMap.foreach( e => {
      mergeUnion(e._1, e._2)
    })
  }

  /**
   * This will allow us to merge all filters in the existing collection
   * to the filters in the other collection.  All merges are done as a result
   * of an AND operator
   *
   * @param other The column filter from the other collection
   */
  def mergeIntersect(other:ColumnFilterCollection): Unit = {
    other.columnFilterMap.foreach( e => {
      val existingColumnFilter = columnFilterMap.get(e._1)
      if (existingColumnFilter.isEmpty) {
        columnFilterMap += e
      } else {
        existingColumnFilter.get.mergeIntersect(e._2)
      }
    })
  }

  override def toString:String = {
    val strBuilder = new StringBuilder
    columnFilterMap.foreach( e => strBuilder.append(e))
    strBuilder.toString()
  }
}
/**
 * Status object to store static functions but also to hold last executed
 * information that can be used for unit testing.
 */
object DefaultSourceStaticUtils {

  val rawInteger = new RawInteger
  val rawLong = new RawLong
  val rawFloat = new RawFloat
  val rawDouble = new RawDouble
  val rawString = RawString.ASCENDING

  val byteRange = new ThreadLocal[PositionedByteRange] {
    override def initialValue(): PositionedByteRange = {
      val range = new SimplePositionedMutableByteRange()
      range.setOffset(0)
      range.setPosition(0)
    }
  }

  def getFreshByteRange(bytes: Array[Byte]): PositionedByteRange = {
    getFreshByteRange(bytes, 0, bytes.length)
  }

  def getFreshByteRange(bytes: Array[Byte], offset: Int = 0, length: Int):
  PositionedByteRange = {
    byteRange.get().set(bytes).setLength(length).setOffset(offset)
  }
  //This will contain the last 5 filters and required fields used in buildScan
  // These values can be used in unit testing to make sure we are converting
  // the Spark SQL input correctly
  val lastFiveExecutionRules =
    new ConcurrentLinkedQueue[ExecutionRuleForUnitTesting]()

  /**
   * This method is to populate the lastFiveExecutionRules for unit test purposes
   * This method is not thread safe.
   *
   * @param rowKeyFilter           The rowKey Filter logic used in the last query
   * @param dynamicLogicExpression The dynamicLogicExpression used in the last query
   */
  def populateLatestExecutionRules(rowKeyFilter: RowKeyFilter,
                                   dynamicLogicExpression: DynamicLogicExpression): Unit = {
    lastFiveExecutionRules.add(new ExecutionRuleForUnitTesting(
      rowKeyFilter, dynamicLogicExpression))
    while (lastFiveExecutionRules.size() > 5) {
      lastFiveExecutionRules.poll()
    }
  }
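
  // Sketch of how tests can consume this (illustrative, not part of the original source):
  //
  //   val rule = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
  //   // inspect rule.rowKeyFilter.points / rule.rowKeyFilter.ranges and
  //   // rule.dynamicLogicExpression.toExpressionString for the last query executed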
  /**
   * This method will convert the result content from HBase into the
   * SQL value type that is requested by the Spark SQL schema definition
   *
   * @param field The structure of the SparkSQL Column
   * @param r     The result object from HBase
   * @return      The converted object type
   */
  def getValue(field: Field,
      r: Result): Any = {
    if (field.isRowKey) {
      val row = r.getRow

      field.dt match {
        case IntegerType => rawInteger.decode(getFreshByteRange(row))
        case LongType => rawLong.decode(getFreshByteRange(row))
        case FloatType => rawFloat.decode(getFreshByteRange(row))
        case DoubleType => rawDouble.decode(getFreshByteRange(row))
        case StringType => rawString.decode(getFreshByteRange(row))
        case TimestampType => rawLong.decode(getFreshByteRange(row))
        case _ => Bytes.toString(row)
      }
    } else {
      val cellByteValue =
        r.getColumnLatestCell(field.cfBytes, field.colBytes)
      if (cellByteValue == null) null
      else field.dt match {
        case IntegerType => rawInteger.decode(getFreshByteRange(cellByteValue.getValueArray,
          cellByteValue.getValueOffset, cellByteValue.getValueLength))
        case LongType => rawLong.decode(getFreshByteRange(cellByteValue.getValueArray,
          cellByteValue.getValueOffset, cellByteValue.getValueLength))
        case FloatType => rawFloat.decode(getFreshByteRange(cellByteValue.getValueArray,
          cellByteValue.getValueOffset, cellByteValue.getValueLength))
        case DoubleType => rawDouble.decode(getFreshByteRange(cellByteValue.getValueArray,
          cellByteValue.getValueOffset, cellByteValue.getValueLength))
        case StringType => Bytes.toString(cellByteValue.getValueArray,
          cellByteValue.getValueOffset, cellByteValue.getValueLength)
        case TimestampType => rawLong.decode(getFreshByteRange(cellByteValue.getValueArray,
          cellByteValue.getValueOffset, cellByteValue.getValueLength))
        case _ => Bytes.toString(cellByteValue.getValueArray,
          cellByteValue.getValueOffset, cellByteValue.getValueLength)
      }
    }
  }
  /**
   * This will convert the value from SparkSQL to be stored into HBase using the
   * right byte Type
   *
   * @param field The structure of the SparkSQL Column
   * @param value String value from SparkSQL
   * @return      Returns the byte array to go into HBase
   */
  def getByteValue(field: Field,
      value: String): Array[Byte] = {
    field.dt match {
      case IntegerType =>
        val result = new Array[Byte](Bytes.SIZEOF_INT)
        val localDataRange = getFreshByteRange(result)
        rawInteger.encode(localDataRange, value.toInt)
        localDataRange.getBytes
      case LongType =>
        val result = new Array[Byte](Bytes.SIZEOF_LONG)
        val localDataRange = getFreshByteRange(result)
        rawLong.encode(localDataRange, value.toLong)
        localDataRange.getBytes
      case FloatType =>
        val result = new Array[Byte](Bytes.SIZEOF_FLOAT)
        val localDataRange = getFreshByteRange(result)
        rawFloat.encode(localDataRange, value.toFloat)
        localDataRange.getBytes
      case DoubleType =>
        val result = new Array[Byte](Bytes.SIZEOF_DOUBLE)
        val localDataRange = getFreshByteRange(result)
        rawDouble.encode(localDataRange, value.toDouble)
        localDataRange.getBytes
      case StringType =>
        Bytes.toBytes(value)
      case TimestampType =>
        val result = new Array[Byte](Bytes.SIZEOF_LONG)
        val localDataRange = getFreshByteRange(result)
        rawLong.encode(localDataRange, value.toLong)
        localDataRange.getBytes

      case _ => Bytes.toBytes(value)
    }
  }
}
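
// Illustrative sketch (not part of the original source).  getByteValue turns the string
// form of a SparkSQL literal into the bytes HBase actually stores, so pushed-down
// comparisons operate on the stored representation.  Assuming "col1" is declared as an
// int in the catalog:
//
//   val f = catalog.getField("col1")
//   val bytes = DefaultSourceStaticUtils.getByteValue(f, "42")
//   // bytes.length == Bytes.SIZEOF_INT, and decoding it with rawInteger yields 42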
/**
 * Contains information related to the filter applied to the row key.
 * This can contain many ranges or points.
 *
 * @param currentPoint the initial point when the filter is created
 * @param currentRange the initial scanRange when the filter is created
 */
class RowKeyFilter (currentPoint:Array[Byte] = null,
                    currentRange:ScanRange =
                    new ScanRange(null, true, new Array[Byte](0), true),
                    var points:mutable.MutableList[Array[Byte]] =
                    new mutable.MutableList[Array[Byte]](),
                    var ranges:mutable.MutableList[ScanRange] =
                    new mutable.MutableList[ScanRange]() ) extends Serializable {
  //Collection of ranges
  if (currentRange != null ) ranges.+=(currentRange)

  //Collection of points
  if (currentPoint != null) points.+=(currentPoint)
  /**
   * This will validate a given value through the filter's points and/or ranges;
   * the result will be if the value passed the filter
   *
   * @param value       Value to be validated
   * @param valueOffSet The offset of the value
   * @param valueLength The length of the value
   * @return            True if the value passes the filter, false if not
   */
  def validate(value:Array[Byte], valueOffSet:Int, valueLength:Int):Boolean = {
    var result = false

    points.foreach( p => {
      if (Bytes.equals(p, 0, p.length, value, valueOffSet, valueLength)) {
        result = true
      }
    })

    ranges.foreach( r => {
      val upperBoundPass = r.upperBound == null ||
        (r.isUpperBoundEqualTo &&
          Bytes.compareTo(r.upperBound, 0, r.upperBound.length,
            value, valueOffSet, valueLength) >= 0) ||
        (!r.isUpperBoundEqualTo &&
          Bytes.compareTo(r.upperBound, 0, r.upperBound.length,
            value, valueOffSet, valueLength) > 0)

      val lowerBoundPass = r.lowerBound == null || r.lowerBound.length == 0 ||
        (r.isLowerBoundEqualTo &&
          Bytes.compareTo(r.lowerBound, 0, r.lowerBound.length,
            value, valueOffSet, valueLength) <= 0) ||
        (!r.isLowerBoundEqualTo &&
          Bytes.compareTo(r.lowerBound, 0, r.lowerBound.length,
            value, valueOffSet, valueLength) < 0)

      result = result || (upperBoundPass && lowerBoundPass)
    })
    result
  }
  /**
   * This will allow us to merge filter logic that is joined to the existing filter
   * through an OR operator
   *
   * @param other Filter to merge
   */
  def mergeUnion(other:RowKeyFilter): Unit = {
    other.points.foreach( p => points += p)

    other.ranges.foreach( otherR => {
      var doesOverLap = false
      ranges.foreach{ r =>
        if (r.getOverLapScanRange(otherR) != null) {
          r.mergeUnion(otherR)
          doesOverLap = true
        }
      }
      if (!doesOverLap) ranges.+=(otherR)
    })
  }

  /**
   * This will allow us to merge filter logic that is joined to the existing filter
   * through an AND operator
   *
   * @param other Filter to merge
   */
  def mergeIntersect(other:RowKeyFilter): Unit = {
    val survivingPoints = new mutable.MutableList[Array[Byte]]()
    val didntSurviveFirstPassPoints = new mutable.MutableList[Array[Byte]]()
    if (points == null || points.length == 0) {
      other.points.foreach( otherP => {
        didntSurviveFirstPassPoints += otherP
      })
    } else {
      points.foreach(p => {
        if (other.points.length == 0) {
          didntSurviveFirstPassPoints += p
        } else {
          other.points.foreach(otherP => {
            if (Bytes.equals(p, otherP)) {
              survivingPoints += p
            } else {
              didntSurviveFirstPassPoints += p
            }
          })
        }
      })
    }

    val survivingRanges = new mutable.MutableList[ScanRange]()

    if (ranges.length == 0) {
      didntSurviveFirstPassPoints.foreach(p => {
        survivingPoints += p
      })
    } else {
      ranges.foreach(r => {
        other.ranges.foreach(otherR => {
          val overLapScanRange = r.getOverLapScanRange(otherR)
          if (overLapScanRange != null) {
            survivingRanges += overLapScanRange
          }
        })
        didntSurviveFirstPassPoints.foreach(p => {
          if (r.containsPoint(p)) {
            survivingPoints += p
          }
        })
      })
    }
    points = survivingPoints
    ranges = survivingRanges
  }
  override def toString:String = {
    val strBuilder = new StringBuilder
    strBuilder.append("(points:(")
    var isFirst = true
    points.foreach( p => {
      if (isFirst) isFirst = false
      else strBuilder.append(",")
      strBuilder.append(Bytes.toString(p))
    })
    strBuilder.append("),ranges:")
    isFirst = true
    ranges.foreach( r => {
      if (isFirst) isFirst = false
      else strBuilder.append(",")
      strBuilder.append(r)
    })
    strBuilder.append("))")
    strBuilder.toString()
  }
}

class ExecutionRuleForUnitTesting(val rowKeyFilter: RowKeyFilter,
                                  val dynamicLogicExpression: DynamicLogicExpression)