flink-issues mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [flink] wuchong commented on a change in pull request #8244: [FLINK-11945] [table-runtime-blink] Support over aggregation for blink streaming runtime
Date Sun, 05 May 2019 13:53:05 GMT
wuchong commented on a change in pull request #8244: [FLINK-11945] [table-runtime-blink] Support
over aggregation for blink streaming runtime
URL: https://github.com/apache/flink/pull/8244#discussion_r281019195
 
 

 ##########
 File path: flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecOverAggregate.scala
 ##########
 @@ -132,8 +143,350 @@ class StreamExecOverAggregate(
     replaceInput(ordinalInParent, newInputNode.asInstanceOf[RelNode])
   }
 
-  override protected def translateToPlanInternal(
+  override def translateToPlanInternal(
       tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = {
-    throw new TableException("Implements this")
+    val tableConfig = tableEnv.getConfig
+
+    if (logicWindow.groups.size > 1) {
+      throw new TableException(
+        TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+          "All aggregates must be computed on the same window."))
+    }
+
+    val overWindow: org.apache.calcite.rel.core.Window.Group = logicWindow.groups.get(0)
+
+    val orderKeys = overWindow.orderKeys.getFieldCollations
+
+    if (orderKeys.size() != 1) {
+      throw new TableException(
+        TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+          "The window can only be ordered by a single time column."))
+    }
+    val orderKey = orderKeys.get(0)
+
+    if (!orderKey.direction.equals(ASCENDING)) {
+      throw new TableException(
+        TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+          "The window can only be ordered in ASCENDING mode."))
+    }
+
+    val inputDS = getInputNodes.get(0).translateToPlan(tableEnv)
+      .asInstanceOf[StreamTransformation[BaseRow]]
+
+    val inputIsAccRetract = StreamExecRetractionRules.isAccRetract(input)
+
+    if (inputIsAccRetract) {
+      throw new TableException(
+        TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+          "Retraction on Over window aggregation is not supported yet. " +
+            "Note: Over window aggregation should not follow a non-windowed GroupBy aggregation."))
+    }
+
+    if (!logicWindow.groups.get(0).keys.isEmpty && tableConfig.getMinIdleStateRetentionTime
< 0) {
+      LOG.warn(
+        "No state retention interval configured for a query which accumulates state. " +
+          "Please provide a query configuration with valid retention interval to prevent
" +
+          "excessive state size. You may specify a retention time of 0 to not clean up the
state.")
+    }
+
+    val timeType = outputRowType.getFieldList.get(orderKey.getFieldIndex).getType
+
+    // check time field
+    if (!FlinkTypeFactory.isRowtimeIndicatorType(timeType)
+      && !FlinkTypeFactory.isProctimeIndicatorType(timeType)) {
+      throw new TableException(
+        "OVER windows' ordering in stream mode must be defined on a time attribute.")
+    }
+
+    // identify window rowtime attribute
+    val rowTimeIdx: Option[Int] = if (FlinkTypeFactory.isRowtimeIndicatorType(timeType))
{
+      Some(orderKey.getFieldIndex)
+    } else if (FlinkTypeFactory.isProctimeIndicatorType(timeType)) {
+      None
+    } else {
+      throw new TableException(
+        TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+          "OVER windows can only be applied on time attributes."))
+    }
+
+    val config = tableEnv.getConfig
+    val codeGenCtx = CodeGeneratorContext(config)
+    val aggregateCalls = logicWindow.groups.get(0).getAggregateCalls(logicWindow).asScala
+    val isRowsClause = overWindow.isRows
+    val constants = logicWindow.constants.asScala
+    val constantTypes = constants.map(c => FlinkTypeFactory.toInternalType(c.getType))
+
+    val fieldNames = inputRowType.getFieldNames.asScala
+    val fieldTypes = inputRowType.getFieldList.asScala
+      .map(c => FlinkTypeFactory.toInternalType(c.getType))
+
+    val inRowType = FlinkTypeFactory.toInternalRowType(inputRel.getRowType)
+    val outRowType = FlinkTypeFactory.toInternalRowType(outputRowType)
+
+    val aggInputType = tableEnv.getTypeFactory.buildRelDataType(
+      fieldNames ++ constants.indices.map(i => "TMP" + i),
+      fieldTypes ++ constantTypes)
+
+    val overProcessFunction = if (overWindow.lowerBound.isPreceding
+      && overWindow.lowerBound.isUnbounded
+      && overWindow.upperBound.isCurrentRow) {
+
+      // unbounded OVER window
+      createUnboundedOverProcessFunction(
+        codeGenCtx,
+        aggregateCalls,
+        constants,
+        aggInputType,
+        rowTimeIdx,
+        isRowsClause,
+        tableConfig,
+        tableEnv.getRelBuilder,
+        config.getNullCheck)
+
+    } else if (overWindow.lowerBound.isPreceding
+      && !overWindow.lowerBound.isUnbounded
+      && overWindow.upperBound.isCurrentRow) {
+
+      val boundValue = OverAggregateUtil.getBoundary(logicWindow, overWindow.lowerBound)
+
+      if (boundValue.isInstanceOf[BigDecimal]) {
+        throw new TableException(
+          TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+            "the specific value is decimal which haven not supported yet."))
+      }
+      // bounded OVER window
+      val precedingOffset = -1 * boundValue.asInstanceOf[Long] + (if (isRowsClause) 1 else
0)
+
+      createBoundedOverProcessFunction(
+        codeGenCtx,
+        aggregateCalls,
+        constants,
+        aggInputType,
+        rowTimeIdx,
+        isRowsClause,
+        precedingOffset,
+        tableConfig,
+        tableEnv.getRelBuilder,
+        config.getNullCheck)
+
+    } else {
+      throw new TableException(
+        TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+          "OVER RANGE FOLLOWING windows are not supported yet."))
+    }
+
+    val partitionKeys: Array[Int] = overWindow.keys.toArray
+    val inputTypeInfo = inRowType.toTypeInfo
+
+    val selector = KeySelectorUtil.getBaseRowSelector(partitionKeys, inputTypeInfo)
+
+    val returnTypeInfo = outRowType.toTypeInfo
+      .asInstanceOf[BaseRowTypeInfo]
+    // partitioned aggregation
+
+    val operator = new KeyedProcessOperator(overProcessFunction)
+
+    val ret = new OneInputTransformation(
+      inputDS,
+      getOperatorName,
+      operator,
+      returnTypeInfo,
+      inputDS.getParallelism)
+
+    if (partitionKeys.isEmpty) {
+      ret.setParallelism(1)
+      ret.setMaxParallelism(1)
+    }
+
+    // set KeyType and Selector for state
+    ret.setStateKeySelector(selector)
+    ret.setStateKeyType(selector.getProducedType)
+    ret
+  }
+
+  /**
+    * Create a ProcessFunction for an unbounded OVER window to evaluate the final aggregate value.
+    *
+    * @param ctx            code generator context
+    * @param aggregateCalls physical calls to aggregate functions and their output field
names
+    * @param constants      the constants in aggregates parameters, such as sum(1)
+    * @param aggInputType   physical type of the input row which consist of input and constants.
+    * @param rowTimeIdx     the index of the rowtime field or None in case of processing
time.
+    * @param isRowsClause   it is a tag that indicates whether the OVER clause is ROWS clause
+    */
+  private def createUnboundedOverProcessFunction(
+      ctx: CodeGeneratorContext,
+      aggregateCalls: Seq[AggregateCall],
+      constants: Seq[RexLiteral],
+      aggInputType: RelDataType,
+      rowTimeIdx: Option[Int],
+      isRowsClause: Boolean,
+      tableConfig: TableConfig,
+      relBuilder: RelBuilder,
+      nullCheck: Boolean): KeyedProcessFunction[BaseRow, BaseRow, BaseRow] = {
+
+    val needRetraction = false
+    val aggInfoList = transformToStreamAggregateInfoList(
+      aggregateCalls,
+      // use aggInputType which considers constants as input instead of inputSchema.relDataType
+      aggInputType,
+      Array.fill(aggregateCalls.size)(needRetraction),
+      needInputCount = needRetraction,
+      isStateBackendDataViews = true)
+
+    val fieldTypes = inputRowType.getFieldList.asScala.
+      map(c => FlinkTypeFactory.toInternalType(c.getType)).toArray
+
+    val generator = new AggsHandlerCodeGenerator(
+      ctx,
+      relBuilder,
+      fieldTypes,
+      needRetraction,
+      copyInputField = false)
+
+    val genAggsHandler = generator
+      // over agg code gen must pass the constants
+      .withConstants(constants)
+      .generateAggsHandler("UnboundedOverAggregateHelper", aggInfoList)
+
+    val flattenAccTypes = aggInfoList.getAccTypes.map(
+      TypeConverters.createInternalTypeFromTypeInfo)
+
+    if (rowTimeIdx.isDefined) {
+      if (isRowsClause) {
+        // ROWS unbounded over process function
+        new RowTimeUnboundedRowsOver(
+          genAggsHandler,
+          flattenAccTypes,
+          fieldTypes,
+          rowTimeIdx.get,
+          tableConfig)
+      } else {
+        // RANGE unbounded over process function
+        new RowTimeUnboundedRangeOver(
+          genAggsHandler,
+          flattenAccTypes,
+          fieldTypes,
+          rowTimeIdx.get,
+          tableConfig)
+      }
+    } else {
+      new ProcTimeUnboundedOver(
+        genAggsHandler,
+        flattenAccTypes,
+        tableConfig)
+    }
+  }
+
+  /**
+    * Create a ProcessFunction for a ROWS clause bounded OVER window to evaluate the final
+    * aggregate value.
+    *
+    * @param ctx            code generator context
+    * @param aggregateCalls physical calls to aggregate functions and their output field
names
+    * @param constants      the constants in aggregates parameters, such as sum(1)
+    * @param aggInputType   physical type of the input row which consist of input and constants.
+    * @param rowTimeIdx     the index of the rowtime field or None in case of processing
time.
+    * @param isRowsClause   it is a tag that indicates whether the OVER clause is ROWS clause
+    */
+  private def createBoundedOverProcessFunction(
+      ctx: CodeGeneratorContext,
+      aggregateCalls: Seq[AggregateCall],
+      constants: Seq[RexLiteral],
+      aggInputType: RelDataType,
+      rowTimeIdx: Option[Int],
+      isRowsClause: Boolean,
+      precedingOffset: Long,
+      tableConfig: TableConfig,
+      relBuilder: RelBuilder,
+      nullCheck: Boolean): KeyedProcessFunction[BaseRow, BaseRow, BaseRow] = {
+
+    val needRetraction = true
+    val aggInfoList = transformToStreamAggregateInfoList(
+      aggregateCalls,
+      // use aggInputType which considers constants as input instead of inputSchema.relDataType
+      aggInputType,
+      Array.fill(aggregateCalls.size)(needRetraction),
+      needInputCount = needRetraction,
+      isStateBackendDataViews = true)
+
+    val fieldTypes = inputRowType.getFieldList.asScala.
+      map(c => FlinkTypeFactory.toInternalType(c.getType)).toArray
+
+    val generator = new AggsHandlerCodeGenerator(
+      ctx,
+      relBuilder,
+      fieldTypes,
+      needRetraction,
+      copyInputField = false)
+
+    val genAggsHandler = generator
+      // over agg code gen must pass the constants
+      .withConstants(constants)
+      .generateAggsHandler("BoundedOverAggregateHelper", aggInfoList)
+
+    val flattenAccTypes = aggInfoList.getAccTypes.map(
+      TypeConverters.createInternalTypeFromTypeInfo)
+
+    if (rowTimeIdx.isDefined) {
+      if (isRowsClause) {
+        new RowTimeBoundedRowsOver(
+          genAggsHandler,
+          flattenAccTypes,
+          fieldTypes,
+          precedingOffset,
+          rowTimeIdx.get,
+          tableConfig)
+      } else {
+        new RowTimeBoundedRangeOver(
+          genAggsHandler,
+          flattenAccTypes,
+          fieldTypes,
+          precedingOffset,
+          rowTimeIdx.get,
+          tableConfig)
+      }
+    } else {
+      if (isRowsClause) {
+        new ProcTimeBoundedRowsOver(
+          genAggsHandler,
+          flattenAccTypes,
+          fieldTypes,
+          precedingOffset,
+          tableConfig)
+      } else {
+        new ProcTimeBoundedRangeOver(
+          genAggsHandler,
+          flattenAccTypes,
+          fieldTypes,
+          precedingOffset,
+          tableConfig)
+      }
+    }
+  }
+
+  private def getOperatorName = {
 
 Review comment:
  The logic of the operator name and `explainTerms` is duplicated. We are planning to reuse `explainTerms`
for the operator name. You can simply use `"OverAggregate"` as the operator name, so that we can
avoid introducing this fat method and the changes in `OverAggregateUtil`.
   
   In order to avoid duplicate 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message