beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From "ASF GitHub Bot (JIRA)" <j...@apache.org>
Subject [jira] [Work logged] (BEAM-5056) [SQL] Nullability of aggregation expressions isn't inferred properly
Date Fri, 03 Aug 2018 21:37:00 GMT

     [ https://issues.apache.org/jira/browse/BEAM-5056?focusedWorklogId=131081&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-131081 ]

ASF GitHub Bot logged work on BEAM-5056:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 03/Aug/18 21:36
            Start Date: 03/Aug/18 21:36
    Worklog Time Spent: 10m 
      Work Description: akedin closed pull request #6118: [BEAM-5056] [SQL] Fix nullability in output schema
URL: https://github.com/apache/beam/pull/6118
 
 
   

This is a PR merged from a forked repository. Because GitHub hides the
original diff of a foreign (forked) pull request once it is merged, the
diff is reproduced below for the sake of provenance:

diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java
index 7e7f22a89f0..236e5fa5639 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java
@@ -125,7 +125,7 @@ public BeamAggregationRel(
                   "combineBy",
                   Combine.perKey(
                       new BeamAggregationTransforms.AggregationAdaptor(
-                          getNamedAggCalls(), CalciteUtils.toBeamSchema(input.getRowType()))))
+                          getNamedAggCalls(), CalciteUtils.toSchema(input.getRowType()))))
               .setCoder(KvCoder.of(keyCoder, aggCoder));
 
       PCollection<Row> mergedStream =
@@ -133,8 +133,8 @@ public BeamAggregationRel(
               "mergeRecord",
               ParDo.of(
                   new BeamAggregationTransforms.MergeAggregationRecord(
-                      CalciteUtils.toBeamSchema(getRowType()), windowFieldIndex)));
-      mergedStream.setRowSchema(CalciteUtils.toBeamSchema(getRowType()));
+                      CalciteUtils.toSchema(getRowType()), windowFieldIndex)));
+      mergedStream.setRowSchema(CalciteUtils.toSchema(getRowType()));
 
       return mergedStream;
     }
@@ -164,7 +164,7 @@ private void validateWindowIsSupported(PCollection<Row> upstream) {
 
     /** Type of sub-rowrecord used as Group-By keys. */
     private Schema exKeyFieldsSchema(RelDataType relDataType) {
-      Schema inputSchema = CalciteUtils.toBeamSchema(relDataType);
+      Schema inputSchema = CalciteUtils.toSchema(relDataType);
       return groupSet
           .asList()
           .stream()
@@ -183,9 +183,7 @@ private Schema exAggFieldsSchema() {
     }
 
     private Schema.Field newRowField(Pair<AggregateCall, String> namedAggCall) {
-      return Schema.Field.of(
-              namedAggCall.right, CalciteUtils.toFieldType(namedAggCall.left.getType()))
-          .withNullable(namedAggCall.left.getType().isNullable());
+      return CalciteUtils.toField(namedAggCall.right, namedAggCall.left.getType());
     }
   }
 
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java
index 84989f0f528..2d050acdb1f 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java
@@ -72,12 +72,12 @@ public Calc copy(RelTraitSet traitSet, RelNode input, RexProgram program) {
 
       BeamSqlExpressionExecutor executor = new BeamSqlFnExecutor(BeamCalcRel.this.getProgram());
 
-      Schema schema = CalciteUtils.toBeamSchema(rowType);
+      Schema schema = CalciteUtils.toSchema(rowType);
       PCollection<Row> projectStream =
           upstream
-              .apply(ParDo.of(new CalcFn(executor, CalciteUtils.toBeamSchema(rowType))))
+              .apply(ParDo.of(new CalcFn(executor, CalciteUtils.toSchema(rowType))))
               .setRowSchema(schema);
-      projectStream.setRowSchema(CalciteUtils.toBeamSchema(getRowType()));
+      projectStream.setRowSchema(CalciteUtils.toSchema(getRowType()));
 
       return projectStream;
     }
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRel.java
index a752e206a42..a36c9eadb4b 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRel.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRel.java
@@ -28,6 +28,8 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.extensions.sql.BeamSqlSeekableTable;
 import org.apache.beam.sdk.extensions.sql.BeamSqlTable;
@@ -146,10 +148,12 @@ private boolean isSideInputJoin() {
       if (isSideInputJoin()) {
         checkArgument(pinput.size() == 1, "More than one input received for side input join");
         return joinAsLookup(leftRelNode, rightRelNode, pinput.get(0))
-            .setRowSchema(CalciteUtils.toBeamSchema(getRowType()));
+            .setRowSchema(CalciteUtils.toSchema(getRowType()));
       }
 
-      Schema leftSchema = CalciteUtils.toBeamSchema(left.getRowType());
+      Schema leftSchema = CalciteUtils.toSchema(left.getRowType());
+      Schema rightSchema = CalciteUtils.toSchema(right.getRowType());
+
       assert pinput.size() == 2;
       PCollection<Row> leftRows = pinput.get(0);
       PCollection<Row> rightRows = pinput.get(1);
@@ -186,10 +190,6 @@ private boolean isSideInputJoin() {
                   MapElements.via(new BeamJoinTransforms.ExtractJoinFields(false, pairs)))
               .setCoder(KvCoder.of(extractKeyRowCoder, rightRows.getCoder()));
 
-      // prepare the NullRows
-      Row leftNullRow = buildNullRow(leftRelNode);
-      Row rightNullRow = buildNullRow(rightRelNode);
-
       // a regular join
       if ((leftRows.isBounded() == PCollection.IsBounded.BOUNDED
               && rightRows.isBounded() == PCollection.IsBounded.BOUNDED)
@@ -201,7 +201,7 @@ private boolean isSideInputJoin() {
               "WindowFns must match for a bounded-vs-bounded/unbounded-vs-unbounded join.", e);
         }
 
-        return standardJoin(extractedLeftRows, extractedRightRows, leftNullRow, rightNullRow);
+        return standardJoin(extractedLeftRows, extractedRightRows, leftSchema, rightSchema);
       } else if ((leftRows.isBounded() == PCollection.IsBounded.BOUNDED
               && rightRows.isBounded() == UNBOUNDED)
           || (leftRows.isBounded() == UNBOUNDED
@@ -224,7 +224,7 @@ private boolean isSideInputJoin() {
               "LEFT side of an OUTER JOIN must be Unbounded table.");
         }
 
-        return sideInputJoin(extractedLeftRows, extractedRightRows, leftNullRow, rightNullRow);
+        return sideInputJoin(extractedLeftRows, extractedRightRows, leftSchema, rightSchema);
       } else {
         throw new UnsupportedOperationException(
             "The inputs to the JOIN have un-joinnable windowFns: " + leftWinFn + ", " + rightWinFn);
@@ -257,25 +257,52 @@ private boolean triggersOncePerWindow(WindowingStrategy windowingStrategy) {
   private PCollection<Row> standardJoin(
       PCollection<KV<Row, Row>> extractedLeftRows,
       PCollection<KV<Row, Row>> extractedRightRows,
-      Row leftNullRow,
-      Row rightNullRow) {
+      Schema leftSchema,
+      Schema rightSchema) {
     PCollection<KV<Row, KV<Row, Row>>> joinedRows = null;
+
     switch (joinType) {
       case LEFT:
-        joinedRows =
-            org.apache.beam.sdk.extensions.joinlibrary.Join.leftOuterJoin(
-                extractedLeftRows, extractedRightRows, rightNullRow);
-        break;
+        {
+          Schema rigthNullSchema = buildNullSchema(rightSchema);
+          Row rightNullRow = Row.nullRow(rigthNullSchema);
+
+          extractedRightRows = setValueCoder(extractedRightRows, SchemaCoder.of(rigthNullSchema));
+
+          joinedRows =
+              org.apache.beam.sdk.extensions.joinlibrary.Join.leftOuterJoin(
+                  extractedLeftRows, extractedRightRows, rightNullRow);
+
+          break;
+        }
       case RIGHT:
-        joinedRows =
-            org.apache.beam.sdk.extensions.joinlibrary.Join.rightOuterJoin(
-                extractedLeftRows, extractedRightRows, leftNullRow);
-        break;
+        {
+          Schema leftNullSchema = buildNullSchema(leftSchema);
+          Row leftNullRow = Row.nullRow(leftNullSchema);
+
+          extractedLeftRows = setValueCoder(extractedLeftRows, SchemaCoder.of(leftNullSchema));
+
+          joinedRows =
+              org.apache.beam.sdk.extensions.joinlibrary.Join.rightOuterJoin(
+                  extractedLeftRows, extractedRightRows, leftNullRow);
+          break;
+        }
       case FULL:
-        joinedRows =
-            org.apache.beam.sdk.extensions.joinlibrary.Join.fullOuterJoin(
-                extractedLeftRows, extractedRightRows, leftNullRow, rightNullRow);
-        break;
+        {
+          Schema leftNullSchema = buildNullSchema(leftSchema);
+          Schema rightNullSchema = buildNullSchema(rightSchema);
+
+          Row leftNullRow = Row.nullRow(leftNullSchema);
+          Row rightNullRow = Row.nullRow(rightNullSchema);
+
+          extractedLeftRows = setValueCoder(extractedLeftRows, SchemaCoder.of(leftNullSchema));
+          extractedRightRows = setValueCoder(extractedRightRows, SchemaCoder.of(rightNullSchema));
+
+          joinedRows =
+              org.apache.beam.sdk.extensions.joinlibrary.Join.fullOuterJoin(
+                  extractedLeftRows, extractedRightRows, leftNullRow, rightNullRow);
+          break;
+        }
       case INNER:
       default:
         joinedRows =
@@ -288,15 +315,15 @@ private boolean triggersOncePerWindow(WindowingStrategy windowingStrategy) {
         joinedRows
             .apply(
                 "JoinParts2WholeRow", MapElements.via(new BeamJoinTransforms.JoinParts2WholeRow()))
-            .setRowSchema(CalciteUtils.toBeamSchema(getRowType()));
+            .setRowSchema(CalciteUtils.toSchema(getRowType()));
     return ret;
   }
 
   public PCollection<Row> sideInputJoin(
       PCollection<KV<Row, Row>> extractedLeftRows,
       PCollection<KV<Row, Row>> extractedRightRows,
-      Row leftNullRow,
-      Row rightNullRow) {
+      Schema leftSchema,
+      Schema rightSchema) {
     // we always make the Unbounded table on the left to do the sideInput join
     // (will convert the result accordingly before return)
     boolean swapped = (extractedLeftRows.isBounded() == PCollection.IsBounded.BOUNDED);
@@ -305,7 +332,19 @@ private boolean triggersOncePerWindow(WindowingStrategy windowingStrategy) {
 
     PCollection<KV<Row, Row>> realLeftRows = swapped ? extractedRightRows : extractedLeftRows;
     PCollection<KV<Row, Row>> realRightRows = swapped ? extractedLeftRows : extractedRightRows;
-    Row realRightNullRow = swapped ? leftNullRow : rightNullRow;
+
+    Row realRightNullRow;
+    if (swapped) {
+      Schema leftNullSchema = buildNullSchema(leftSchema);
+
+      realRightRows = setValueCoder(realRightRows, SchemaCoder.of(leftNullSchema));
+      realRightNullRow = Row.nullRow(leftNullSchema);
+    } else {
+      Schema rightNullSchema = buildNullSchema(rightSchema);
+
+      realRightRows = setValueCoder(realRightRows, SchemaCoder.of(rightNullSchema));
+      realRightNullRow = Row.nullRow(rightNullSchema);
+    }
 
     // swapped still need to pass down because, we need to swap the result back.
     return sideInputJoinHelper(
@@ -327,14 +366,26 @@ private boolean triggersOncePerWindow(WindowingStrategy windowingStrategy) {
                         new BeamJoinTransforms.SideInputJoinDoFn(
                             joinType, rightNullRow, rowsView, swapped))
                     .withSideInputs(rowsView))
-            .setRowSchema(CalciteUtils.toBeamSchema(getRowType()));
+            .setRowSchema(CalciteUtils.toSchema(getRowType()));
 
     return ret;
   }
 
-  private Row buildNullRow(BeamRelNode relNode) {
-    Schema leftType = CalciteUtils.toBeamSchema(relNode.getRowType());
-    return Row.nullRow(leftType);
+  private Schema buildNullSchema(Schema schema) {
+    Schema.Builder builder = Schema.builder();
+
+    builder.addFields(
+        schema.getFields().stream().map(f -> f.withNullable(true)).collect(Collectors.toList()));
+
+    return builder.build();
+  }
+
+  private static <K, V> PCollection<KV<K, V>> setValueCoder(
+      PCollection<KV<K, V>> kvs, Coder<V> valueCoder) {
+    // safe case because PCollection of KV always has KvCoder
+    KvCoder<K, V> coder = (KvCoder<K, V>) kvs.getCoder();
+
+    return kvs.setCoder(KvCoder.of(coder.getKeyCoder(), valueCoder));
   }
 
   private List<Pair<Integer, Integer>> extractJoinColumns(int leftRowColumnCount) {
@@ -386,8 +437,8 @@ private Row buildNullRow(BeamRelNode relNode) {
         new BeamJoinTransforms.JoinAsLookup(
             condition,
             seekableTable,
-            CalciteUtils.toBeamSchema(rightRelNode.getRowType()),
-            CalciteUtils.toBeamSchema(leftRelNode.getRowType()).getFieldCount()));
+            CalciteUtils.toSchema(rightRelNode.getRowType()),
+            CalciteUtils.toSchema(leftRelNode.getRowType()).getFieldCount()));
   }
 
   /** check if {@code BeamRelNode} implements {@code BeamSeekableTable}. */
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java
index 3d6d8a71612..466543568e3 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java
@@ -167,7 +167,7 @@ public int getCount() {
         return upstream
             .apply(Window.into(new GlobalWindows()))
             .apply(new LimitTransform<>())
-            .setRowSchema(CalciteUtils.toBeamSchema(getRowType()));
+            .setRowSchema(CalciteUtils.toSchema(getRowType()));
       } else {
 
         WindowingStrategy<?, ?> windowingStrategy = upstream.getWindowingStrategy();
@@ -202,7 +202,7 @@ public int getCount() {
         return rawStream
             .apply("flatten", Flatten.iterables())
             .setSchema(
-                CalciteUtils.toBeamSchema(getRowType()),
+                CalciteUtils.toSchema(getRowType()),
                 SerializableFunctions.identity(),
                 SerializableFunctions.identity());
       }
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUncollectRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUncollectRel.java
index b80810bb774..8dd17f2cd6e 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUncollectRel.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUncollectRel.java
@@ -63,7 +63,7 @@ public RelNode copy(RelTraitSet traitSet, RelNode input) {
 
       // Each row of the input contains a single array of things to be emitted; Calcite knows
       // what the row looks like
-      Schema outputSchema = CalciteUtils.toBeamSchema(getRowType());
+      Schema outputSchema = CalciteUtils.toSchema(getRowType());
 
       PCollection<Row> uncollected =
           upstream.apply(ParDo.of(new UncollectDoFn(outputSchema))).setRowSchema(outputSchema);
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java
index 829f886b6e6..55fcc1aea0c 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java
@@ -93,7 +93,7 @@ public Correlate copy(
 
       // The correlated subquery
       BeamUncollectRel uncollect = (BeamUncollectRel) BeamSqlRelUtils.getBeamRelInput(right);
-      Schema innerSchema = CalciteUtils.toBeamSchema(uncollect.getRowType());
+      Schema innerSchema = CalciteUtils.toSchema(uncollect.getRowType());
       checkArgument(
           innerSchema.getFieldCount() == 1, "Can only UNNEST a single column", getClass());
 
@@ -101,7 +101,7 @@ public Correlate copy(
           new BeamSqlFnExecutor(
               ((BeamCalcRel) BeamSqlRelUtils.getBeamRelInput(uncollect.getInput())).getProgram());
 
-      Schema joinedSchema = CalciteUtils.toBeamSchema(rowType);
+      Schema joinedSchema = CalciteUtils.toSchema(rowType);
 
       return outer
           .apply(
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamValuesRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamValuesRel.java
index 9d6775a27b2..45517572201 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamValuesRel.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamValuesRel.java
@@ -78,7 +78,7 @@ public BeamValuesRel(
         throw new IllegalStateException("Values with empty tuples!");
       }
 
-      Schema schema = CalciteUtils.toBeamSchema(getRowType());
+      Schema schema = CalciteUtils.toSchema(getRowType());
 
       List<Row> rows = tuples.stream().map(tuple -> tupleToRow(schema, tuple)).collect(toList());
 
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
index a19d55e2fee..687d85588cf 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
@@ -146,7 +146,7 @@ public AggregationAdaptor(
         String aggName = aggCall.right;
 
         if (call.getArgList().size() == 2) {
-          /**
+          /*
            * handle the case of aggregation function has two parameters and use KV pair to bundle
            * two corresponding expressions.
            */
@@ -171,8 +171,9 @@ public AggregationAdaptor(
           sourceFieldExps.add(sourceExp);
         }
 
-        FieldType typeDescriptor = CalciteUtils.toFieldType(call.type);
-        fields.add(Schema.Field.of(aggName, typeDescriptor));
+        Schema.Field field = CalciteUtils.toField(aggName, call.type);
+        Schema.TypeName fieldTypeName = field.getType().getTypeName();
+        fields.add(field);
 
         switch (call.getAggregation().getName()) {
           case "COUNT":
@@ -193,22 +194,17 @@ public AggregationAdaptor(
             break;
           case "VAR_POP":
             aggregators.add(
-                VarianceFn.newPopulation(
-                    BigDecimalConverter.forSqlType(typeDescriptor.getTypeName())));
+                VarianceFn.newPopulation(BigDecimalConverter.forSqlType(fieldTypeName)));
             break;
           case "VAR_SAMP":
-            aggregators.add(
-                VarianceFn.newSample(BigDecimalConverter.forSqlType(typeDescriptor.getTypeName())));
+            aggregators.add(VarianceFn.newSample(BigDecimalConverter.forSqlType(fieldTypeName)));
             break;
           case "COVAR_POP":
             aggregators.add(
-                CovarianceFn.newPopulation(
-                    BigDecimalConverter.forSqlType(typeDescriptor.getTypeName())));
+                CovarianceFn.newPopulation(BigDecimalConverter.forSqlType(fieldTypeName)));
             break;
           case "COVAR_SAMP":
-            aggregators.add(
-                CovarianceFn.newSample(
-                    BigDecimalConverter.forSqlType(typeDescriptor.getTypeName())));
+            aggregators.add(CovarianceFn.newSample(BigDecimalConverter.forSqlType(fieldTypeName)));
             break;
           default:
             if (call.getAggregation() instanceof SqlUserDefinedAggFunction) {
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
index e307b768e08..3bb7b2c16b7 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
@@ -18,8 +18,6 @@
 
 package org.apache.beam.sdk.extensions.sql.impl.utils;
 
-import static org.apache.beam.sdk.schemas.Schema.toSchema;
-
 import com.google.common.collect.BiMap;
 import com.google.common.collect.ImmutableBiMap;
 import com.google.common.collect.ImmutableMap;
@@ -90,12 +88,8 @@
           FieldType.STRING, SqlTypeName.VARCHAR);
 
   /** Generate {@link Schema} from {@code RelDataType} which is used to create table. */
-  public static Schema toBeamSchema(RelDataType tableInfo) {
-    return tableInfo
-        .getFieldList()
-        .stream()
-        .map(CalciteUtils::toBeamSchemaField)
-        .collect(toSchema());
+  public static Schema toSchema(RelDataType tableInfo) {
+    return tableInfo.getFieldList().stream().map(CalciteUtils::toField).collect(Schema.toSchema());
   }
 
   public static SqlTypeName toSqlTypeName(FieldType type) {
@@ -134,6 +128,14 @@ public static FieldType toFieldType(SqlTypeName sqlTypeName) {
     }
   }
 
+  public static Schema.Field toField(RelDataTypeField calciteField) {
+    return toField(calciteField.getName(), calciteField.getType());
+  }
+
+  public static Schema.Field toField(String name, RelDataType calciteType) {
+    return Schema.Field.of(name, toFieldType(calciteType)).withNullable(calciteType.isNullable());
+  }
+
   public static FieldType toFieldType(RelDataType calciteType) {
     switch (calciteType.getSqlTypeName()) {
       case ARRAY:
@@ -143,19 +145,13 @@ public static FieldType toFieldType(RelDataType calciteType) {
         return FieldType.map(
             toFieldType(calciteType.getKeyType()), toFieldType(calciteType.getValueType()));
       case ROW:
-        return FieldType.row(toBeamSchema(calciteType));
+        return FieldType.row(toSchema(calciteType));
 
       default:
         return toFieldType(calciteType.getSqlTypeName());
     }
   }
 
-  public static Schema.Field toBeamSchemaField(RelDataTypeField calciteField) {
-    FieldType fieldType = toFieldType(calciteField.getType());
-    // TODO: We should support Calcite's nullable annotations.
-    return Schema.Field.of(calciteField.getName(), fieldType).withNullable(true);
-  }
-
   /** Create an instance of {@code RelDataType} so it can be used to create a table. */
   public static RelDataType toCalciteRowType(Schema schema, RelDataTypeFactory dataTypeFactory) {
     RelDataTypeFactory.Builder builder = new RelDataTypeFactory.Builder(dataTypeFactory);
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationNullableTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationNullableTest.java
index 4c61561852e..91e19db5dff 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationNullableTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationNullableTest.java
@@ -19,6 +19,7 @@
 
 import static org.apache.beam.sdk.extensions.sql.utils.RowAsserts.matchesScalar;
 import static org.apache.beam.sdk.transforms.SerializableFunctions.identity;
+import static org.junit.Assert.assertEquals;
 
 import java.util.List;
 import org.apache.beam.sdk.schemas.Schema;
@@ -29,7 +30,6 @@
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.Row;
 import org.junit.Before;
-import org.junit.Ignore;
 import org.junit.Rule;
 import org.junit.Test;
 
@@ -122,12 +122,44 @@ public void testAvg() {
     pipeline.run();
   }
 
-  @Ignore
-  // FIXME [BEAM-5056] [SQL] Nullability of aggregation expressions isn't inferred properly
+  @Test
   public void testAvgGroupByNullable() {
-    String sql = "SELECT AVG(f_int1) FROM PCOLLECTION GROUP BY f_int2";
+    String sql = "SELECT AVG(f_int1), f_int2 FROM PCOLLECTION GROUP BY f_int2";
+
+    PCollection<Row> out = boundedInput.apply(SqlTransform.query(sql));
+    Schema schema = out.getSchema();
+
+    PAssert.that(out)
+        .containsInAnyOrder(
+            Row.withSchema(schema).addValues(null, null).build(),
+            Row.withSchema(schema).addValues(2, 1).build(),
+            Row.withSchema(schema).addValues(1, 5).build(),
+            Row.withSchema(schema).addValues(3, 2).build());
+
+    pipeline.run();
+  }
 
-    boundedInput.apply(SqlTransform.query(sql));
+  @Test
+  public void testCountGroupByNullable() {
+    String sql = "SELECT COUNT(f_int1) as c, f_int2 FROM PCOLLECTION GROUP BY f_int2";
+
+    PCollection<Row> out = boundedInput.apply(SqlTransform.query(sql));
+    Schema schema = out.getSchema();
+
+    PAssert.that(out)
+        .containsInAnyOrder(
+            Row.withSchema(schema).addValues(0L, null).build(),
+            Row.withSchema(schema).addValues(1L, 1).build(),
+            Row.withSchema(schema).addValues(1L, 5).build(),
+            Row.withSchema(schema).addValues(1L, 2).build());
+
+    assertEquals(
+        Schema.builder()
+            // COUNT() is never nullable, and calcite knows it
+            .addInt64Field("c")
+            .addNullableField("f_int2", Schema.FieldType.INT32)
+            .build(),
+        schema);
 
     pipeline.run();
   }
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/TestUtils.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/TestUtils.java
index 182978a4e33..fbfe28ad810 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/TestUtils.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/TestUtils.java
@@ -245,9 +245,8 @@ public static Schema buildBeamSqlSchema(Object... args) {
   }
 
   // TODO: support nested.
-  // TODO: support nullable.
   private static Schema.Field toRecordField(Object[] args, int i) {
-    return Schema.Field.of((String) args[i + 1], (FieldType) args[i]).withNullable(true);
+    return Schema.Field.of((String) args[i + 1], (FieldType) args[i]);
   }
 
   /**
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/BeamSqlFnExecutorTestBase.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/BeamSqlFnExecutorTestBase.java
index cfce911c22c..2d4e9b19c6e 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/BeamSqlFnExecutorTestBase.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/BeamSqlFnExecutorTestBase.java
@@ -65,7 +65,7 @@ public static void prepare() {
             .build();
 
     row =
-        Row.withSchema(CalciteUtils.toBeamSchema(relDataType))
+        Row.withSchema(CalciteUtils.toSchema(relDataType))
             .addValues(1234567L, 0, 8.9, 1234567L, "This is an order.")
             .build();
 
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelBoundedVsBoundedTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelBoundedVsBoundedTest.java
index ad32a6c0fed..e5cda16ceb5 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelBoundedVsBoundedTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelBoundedVsBoundedTest.java
@@ -66,12 +66,14 @@ public void testInnerJoin() throws Exception {
     PAssert.that(rows)
         .containsInAnyOrder(
             TestUtils.RowsBuilder.of(
-                    Schema.FieldType.INT32, "order_id",
-                    Schema.FieldType.INT32, "site_id",
-                    Schema.FieldType.INT32, "price",
-                    Schema.FieldType.INT32, "order_id0",
-                    Schema.FieldType.INT32, "site_id0",
-                    Schema.FieldType.INT32, "price0")
+                    Schema.builder()
+                        .addField("order_id", Schema.FieldType.INT32)
+                        .addField("site_id", Schema.FieldType.INT32)
+                        .addField("price", Schema.FieldType.INT32)
+                        .addField("order_id0", Schema.FieldType.INT32)
+                        .addField("site_id0", Schema.FieldType.INT32)
+                        .addField("price0", Schema.FieldType.INT32)
+                        .build())
                 .addRows(2, 3, 3, 1, 2, 3)
                 .getRows());
     pipeline.run();
@@ -91,12 +93,14 @@ public void testLeftOuterJoin() throws Exception {
     PAssert.that(rows)
         .containsInAnyOrder(
             TestUtils.RowsBuilder.of(
-                    Schema.FieldType.INT32, "order_id",
-                    Schema.FieldType.INT32, "site_id",
-                    Schema.FieldType.INT32, "price",
-                    Schema.FieldType.INT32, "order_id0",
-                    Schema.FieldType.INT32, "site_id0",
-                    Schema.FieldType.INT32, "price0")
+                    Schema.builder()
+                        .addField("order_id", Schema.FieldType.INT32)
+                        .addField("site_id", Schema.FieldType.INT32)
+                        .addField("price", Schema.FieldType.INT32)
+                        .addNullableField("order_id0", Schema.FieldType.INT32)
+                        .addNullableField("site_id0", Schema.FieldType.INT32)
+                        .addNullableField("price0", Schema.FieldType.INT32)
+                        .build())
                 .addRows(1, 2, 3, null, null, null, 2, 3, 3, 1, 2, 3, 3, 4, 5, null, null, null)
                 .getRows());
     pipeline.run();
@@ -115,12 +119,14 @@ public void testRightOuterJoin() throws Exception {
     PAssert.that(rows)
         .containsInAnyOrder(
             TestUtils.RowsBuilder.of(
-                    Schema.FieldType.INT32, "order_id",
-                    Schema.FieldType.INT32, "site_id",
-                    Schema.FieldType.INT32, "price",
-                    Schema.FieldType.INT32, "order_id0",
-                    Schema.FieldType.INT32, "site_id0",
-                    Schema.FieldType.INT32, "price0")
+                    Schema.builder()
+                        .addNullableField("order_id", Schema.FieldType.INT32)
+                        .addNullableField("site_id", Schema.FieldType.INT32)
+                        .addNullableField("price", Schema.FieldType.INT32)
+                        .addField("order_id0", Schema.FieldType.INT32)
+                        .addField("site_id0", Schema.FieldType.INT32)
+                        .addField("price0", Schema.FieldType.INT32)
+                        .build())
                 .addRows(2, 3, 3, 1, 2, 3, null, null, null, 2, 3, 3, null, null, null, 3, 4, 5)
                 .getRows());
     pipeline.run();
@@ -139,12 +145,14 @@ public void testFullOuterJoin() throws Exception {
     PAssert.that(rows)
         .containsInAnyOrder(
             TestUtils.RowsBuilder.of(
-                    Schema.FieldType.INT32, "order_id",
-                    Schema.FieldType.INT32, "site_id",
-                    Schema.FieldType.INT32, "price",
-                    Schema.FieldType.INT32, "order_id0",
-                    Schema.FieldType.INT32, "site_id0",
-                    Schema.FieldType.INT32, "price0")
+                    Schema.builder()
+                        .addNullableField("order_id", Schema.FieldType.INT32)
+                        .addNullableField("site_id", Schema.FieldType.INT32)
+                        .addNullableField("price", Schema.FieldType.INT32)
+                        .addNullableField("order_id0", Schema.FieldType.INT32)
+                        .addNullableField("site_id0", Schema.FieldType.INT32)
+                        .addNullableField("price0", Schema.FieldType.INT32)
+                        .build())
                 .addRows(
                     2, 3, 3, 1, 2, 3, 1, 2, 3, null, null, null, 3, 4, 5, null, null, null, null,
                     null, null, 2, 3, 3, null, null, null, 3, 4, 5)
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelUnboundedVsBoundedTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelUnboundedVsBoundedTest.java
index 850ddb9b497..149a57efd12 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelUnboundedVsBoundedTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelUnboundedVsBoundedTest.java
@@ -184,13 +184,17 @@ public void testLeftOuterJoin() throws Exception {
             + " o1.order_id=o2.order_id";
 
     PCollection<Row> rows = compilePipeline(sql, pipeline);
+
     rows.apply(ParDo.of(new BeamSqlOutputToConsoleFn("helloworld")));
+
     PAssert.that(rows.apply(ParDo.of(new TestUtils.BeamSqlRow2StringDoFn())))
         .containsInAnyOrder(
             TestUtils.RowsBuilder.of(
-                    Schema.FieldType.INT32, "order_id",
-                    Schema.FieldType.INT32, "sum_site_id",
-                    Schema.FieldType.STRING, "buyer")
+                    Schema.builder()
+                        .addField("order_id", Schema.FieldType.INT32)
+                        .addField("sum_site_id", Schema.FieldType.INT32)
+                        .addNullableField("buyer", Schema.FieldType.STRING)
+                        .build())
                 .addRows(1, 3, "james", 2, 5, "bond", 3, 3, null)
                 .getStringRows());
     pipeline.run();
@@ -225,9 +229,11 @@ public void testRightOuterJoin() throws Exception {
     PAssert.that(rows.apply(ParDo.of(new TestUtils.BeamSqlRow2StringDoFn())))
         .containsInAnyOrder(
             TestUtils.RowsBuilder.of(
-                    Schema.FieldType.INT32, "order_id",
-                    Schema.FieldType.INT32, "sum_site_id",
-                    Schema.FieldType.STRING, "buyer")
+                    Schema.builder()
+                        .addField("order_id", Schema.FieldType.INT32)
+                        .addField("sum_site_id", Schema.FieldType.INT32)
+                        .addNullableField("buyer", Schema.FieldType.STRING)
+                        .build())
                 .addRows(1, 3, "james", 2, 5, "bond", 3, 3, null)
                 .getStringRows());
     pipeline.run();
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelUnboundedVsUnboundedTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelUnboundedVsUnboundedTest.java
index a23867415c4..c99435ae8cb 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelUnboundedVsUnboundedTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelUnboundedVsUnboundedTest.java
@@ -92,10 +92,12 @@ public void testInnerJoin() throws Exception {
     PAssert.that(rows.apply(ParDo.of(new TestUtils.BeamSqlRow2StringDoFn())))
         .containsInAnyOrder(
             TestUtils.RowsBuilder.of(
-                    Schema.FieldType.INT32, "order_id",
-                    Schema.FieldType.INT32, "sum_site_id",
-                    Schema.FieldType.INT32, "order_id0",
-                    Schema.FieldType.INT32, "sum_site_id0")
+                    Schema.builder()
+                        .addField("order_id1", Schema.FieldType.INT32)
+                        .addField("sum_site_id", Schema.FieldType.INT32)
+                        .addField("order_id", Schema.FieldType.INT32)
+                        .addField("sum_site_id0", Schema.FieldType.INT32)
+                        .build())
                 .addRows(1, 3, 1, 3, 2, 5, 2, 5)
                 .getStringRows());
     pipeline.run();
@@ -123,10 +125,12 @@ public void testLeftOuterJoin() throws Exception {
     PAssert.that(rows.apply(ParDo.of(new TestUtils.BeamSqlRow2StringDoFn())))
         .containsInAnyOrder(
             TestUtils.RowsBuilder.of(
-                    Schema.FieldType.INT32, "order_id",
-                    Schema.FieldType.INT32, "sum_site_id",
-                    Schema.FieldType.INT32, "order_id0",
-                    Schema.FieldType.INT32, "sum_site_id0")
+                    Schema.builder()
+                        .addField("order_id1", Schema.FieldType.INT32)
+                        .addField("sum_site_id", Schema.FieldType.INT32)
+                        .addNullableField("order_id", Schema.FieldType.INT32)
+                        .addNullableField("sum_site_id0", Schema.FieldType.INT32)
+                        .build())
                 .addRows(1, 1, 1, 3, 2, 2, null, null, 2, 2, 2, 5, 3, 3, null, null)
                 .getStringRows());
     pipeline.run();
@@ -148,10 +152,12 @@ public void testRightOuterJoin() throws Exception {
     PAssert.that(rows.apply(ParDo.of(new TestUtils.BeamSqlRow2StringDoFn())))
         .containsInAnyOrder(
             TestUtils.RowsBuilder.of(
-                    Schema.FieldType.INT32, "order_id",
-                    Schema.FieldType.INT32, "sum_site_id",
-                    Schema.FieldType.INT32, "order_id0",
-                    Schema.FieldType.INT32, "sum_site_id0")
+                    Schema.builder()
+                        .addNullableField("order_id1", Schema.FieldType.INT32)
+                        .addNullableField("sum_site_id", Schema.FieldType.INT32)
+                        .addField("order_id", Schema.FieldType.INT32)
+                        .addField("sum_site_id0", Schema.FieldType.INT32)
+                        .build())
                 .addRows(1, 3, 1, 1, null, null, 2, 2, 2, 5, 2, 2, null, null, 3, 3)
                 .getStringRows());
     pipeline.run();
@@ -174,10 +180,12 @@ public void testFullOuterJoin() throws Exception {
     PAssert.that(rows.apply(ParDo.of(new TestUtils.BeamSqlRow2StringDoFn())))
         .containsInAnyOrder(
             TestUtils.RowsBuilder.of(
-                    Schema.FieldType.INT32, "order_id1",
-                    Schema.FieldType.INT32, "sum_site_id",
-                    Schema.FieldType.INT32, "order_id",
-                    Schema.FieldType.INT32, "sum_site_id0")
+                    Schema.builder()
+                        .addNullableField("order_id1", Schema.FieldType.INT32)
+                        .addNullableField("sum_site_id", Schema.FieldType.INT32)
+                        .addNullableField("order_id", Schema.FieldType.INT32)
+                        .addNullableField("sum_site_id0", Schema.FieldType.INT32)
+                        .build())
                 .addRows(
                     1, 1, 1, 3, 6, 2, null, null, 7, 2, null, null, 8, 3, null, null, null, null, 2,
                     5)
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRelTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRelTest.java
index d5116b8cc99..10dace36327 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRelTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRelTest.java
@@ -90,9 +90,11 @@ public void prepare() {
     registerTable(
         "SUB_ORDER_RAM",
         MockedBoundedTable.of(
-            Schema.FieldType.INT64, "order_id",
-            Schema.FieldType.INT32, "site_id",
-            Schema.FieldType.DOUBLE, "price"));
+            Schema.builder()
+                .addField("order_id", Schema.FieldType.INT64)
+                .addField("site_id", Schema.FieldType.INT32)
+                .addNullableField("price", Schema.FieldType.DOUBLE)
+                .build()));
   }
 
   @Test
@@ -153,19 +155,18 @@ public void testOrderBy_timestamp() throws Exception {
 
   @Test
   public void testOrderBy_nullsFirst() throws Exception {
+    Schema schema =
+        Schema.builder()
+            .addField("order_id", Schema.FieldType.INT64)
+            .addNullableField("site_id", Schema.FieldType.INT32)
+            .addField("price", Schema.FieldType.DOUBLE)
+            .build();
+
     registerTable(
         "ORDER_DETAILS",
-        MockedBoundedTable.of(
-                Schema.FieldType.INT64, "order_id",
-                Schema.FieldType.INT32, "site_id",
-                Schema.FieldType.DOUBLE, "price")
+        MockedBoundedTable.of(schema)
             .addRows(1L, 2, 1.0, 1L, null, 2.0, 2L, 1, 3.0, 2L, null, 4.0, 5L, 5, 5.0));
-    registerTable(
-        "SUB_ORDER_RAM",
-        MockedBoundedTable.of(
-            Schema.FieldType.INT64, "order_id",
-            Schema.FieldType.INT32, "site_id",
-            Schema.FieldType.DOUBLE, "price"));
+    registerTable("SUB_ORDER_RAM", MockedBoundedTable.of(schema));
 
     String sql =
         "INSERT INTO SUB_ORDER_RAM(order_id, site_id, price)  SELECT "
@@ -176,10 +177,7 @@ public void testOrderBy_nullsFirst() throws Exception {
     PCollection<Row> rows = compilePipeline(sql, pipeline);
     PAssert.that(rows)
         .containsInAnyOrder(
-            TestUtils.RowsBuilder.of(
-                    Schema.FieldType.INT64, "order_id",
-                    Schema.FieldType.INT32, "site_id",
-                    Schema.FieldType.DOUBLE, "price")
+            TestUtils.RowsBuilder.of(schema)
                 .addRows(1L, null, 2.0, 1L, 2, 1.0, 2L, null, 4.0, 2L, 1, 3.0)
                 .getRows());
     pipeline.run().waitUntilFinish();
@@ -187,19 +185,18 @@ public void testOrderBy_nullsFirst() throws Exception {
 
   @Test
   public void testOrderBy_nullsLast() throws Exception {
+    Schema schema =
+        Schema.builder()
+            .addField("order_id", Schema.FieldType.INT64)
+            .addNullableField("site_id", Schema.FieldType.INT32)
+            .addField("price", Schema.FieldType.DOUBLE)
+            .build();
+
     registerTable(
         "ORDER_DETAILS",
-        MockedBoundedTable.of(
-                Schema.FieldType.INT64, "order_id",
-                Schema.FieldType.INT32, "site_id",
-                Schema.FieldType.DOUBLE, "price")
+        MockedBoundedTable.of(schema)
             .addRows(1L, 2, 1.0, 1L, null, 2.0, 2L, 1, 3.0, 2L, null, 4.0, 5L, 5, 5.0));
-    registerTable(
-        "SUB_ORDER_RAM",
-        MockedBoundedTable.of(
-            Schema.FieldType.INT64, "order_id",
-            Schema.FieldType.INT32, "site_id",
-            Schema.FieldType.DOUBLE, "price"));
+    registerTable("SUB_ORDER_RAM", MockedBoundedTable.of(schema));
 
     String sql =
         "INSERT INTO SUB_ORDER_RAM(order_id, site_id, price)  SELECT "
@@ -210,10 +207,7 @@ public void testOrderBy_nullsLast() throws Exception {
     PCollection<Row> rows = compilePipeline(sql, pipeline);
     PAssert.that(rows)
         .containsInAnyOrder(
-            TestUtils.RowsBuilder.of(
-                    Schema.FieldType.INT64, "order_id",
-                    Schema.FieldType.INT32, "site_id",
-                    Schema.FieldType.DOUBLE, "price")
+            TestUtils.RowsBuilder.of(schema)
                 .addRows(1L, 2, 1.0, 1L, null, 2.0, 2L, 1, 3.0, 2L, null, 4.0)
                 .getRows());
     pipeline.run().waitUntilFinish();
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamSqlRowCoderTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamSqlRowCoderTest.java
index 9c2ec4f7504..1388b98108c 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamSqlRowCoderTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamSqlRowCoderTest.java
@@ -54,7 +54,7 @@ public void encodeAndDecode() throws Exception {
             .add("col_boolean", SqlTypeName.BOOLEAN)
             .build();
 
-    Schema beamSchema = CalciteUtils.toBeamSchema(relDataType);
+    Schema beamSchema = CalciteUtils.toSchema(relDataType);
 
     Row row =
         Row.withSchema(beamSchema)
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
index 1f65b4bee97..7eea6ece289 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
@@ -124,4 +124,48 @@ public void testToCalciteRowTypeNullable() {
     assertEquals(SqlTypeName.VARBINARY, fields.get("f9").getSqlTypeName());
     assertEquals(SqlTypeName.VARCHAR, fields.get("f10").getSqlTypeName());
   }
+
+  @Test
+  public void testRoundTripBeamSchema() {
+    final Schema schema =
+        Schema.builder()
+            .addField("f1", Schema.FieldType.BYTE)
+            .addField("f2", Schema.FieldType.INT16)
+            .addField("f3", Schema.FieldType.INT32)
+            .addField("f4", Schema.FieldType.INT64)
+            .addField("f5", Schema.FieldType.FLOAT)
+            .addField("f6", Schema.FieldType.DOUBLE)
+            .addField("f7", Schema.FieldType.DECIMAL)
+            .addField("f8", Schema.FieldType.BOOLEAN)
+            .addField("f9", Schema.FieldType.BYTES)
+            .addField("f10", Schema.FieldType.STRING)
+            .build();
+
+    final Schema out =
+        CalciteUtils.toSchema(CalciteUtils.toCalciteRowType(schema, dataTypeFactory));
+
+    assertEquals(schema, out);
+  }
+
+  @Test
+  public void testRoundTripBeamNullableSchema() {
+    final Schema schema =
+        Schema.builder()
+            .addNullableField("f1", Schema.FieldType.BYTE)
+            .addNullableField("f2", Schema.FieldType.INT16)
+            .addNullableField("f3", Schema.FieldType.INT32)
+            .addNullableField("f4", Schema.FieldType.INT64)
+            .addNullableField("f5", Schema.FieldType.FLOAT)
+            .addNullableField("f6", Schema.FieldType.DOUBLE)
+            .addNullableField("f7", Schema.FieldType.DECIMAL)
+            .addNullableField("f8", Schema.FieldType.BOOLEAN)
+            .addNullableField("f9", Schema.FieldType.BYTES)
+            .addNullableField("f10", Schema.FieldType.STRING)
+            .build();
+
+    final Schema out =
+        CalciteUtils.toSchema(CalciteUtils.toCalciteRowType(schema, dataTypeFactory));
+
+    assertEquals(schema, out);
+  }
 }
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java
index d12b71971d8..61bdc4cbf8b 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java
@@ -75,7 +75,7 @@ public void testCsvRecorderEncoder() throws Exception {
 
   private static Schema genSchema() {
     JavaTypeFactory typeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
-    return CalciteUtils.toBeamSchema(
+    return CalciteUtils.toSchema(
         typeFactory
             .builder()
             .add("order_id", SqlTypeName.BIGINT)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Issue Time Tracking
-------------------

    Worklog Id:     (was: 131081)
    Time Spent: 3h 20m  (was: 3h 10m)

> [SQL] Nullability of aggregation expressions isn't inferred properly
> --------------------------------------------------------------------
>
>                 Key: BEAM-5056
>                 URL: https://issues.apache.org/jira/browse/BEAM-5056
>             Project: Beam
>          Issue Type: Bug
>          Components: dsl-sql
>            Reporter: Gleb Kanterov
>            Assignee: Xu Mingmin
>            Priority: Major
>          Time Spent: 3h 20m
>  Remaining Estimate: 0h
>
> Given schema and rows:
> {code:java}
> Schema schema =
>     Schema.builder()
>         .addNullableField("f_int1", Schema.FieldType.INT32)
>         .addNullableField("f_int2", Schema.FieldType.INT32)
>         .build();
> List<Row> rows =
>     TestUtils.RowsBuilder.of(schema)
>         .addRows(null, null)
>         .getRows();
> {code}
> Following query fails:
> {code:sql}
> SELECT AVG(f_int1) FROM PCOLLECTION GROUP BY f_int2
> {code}
> {code:java}
> Caused by: java.lang.IllegalArgumentException: Field EXPR$0 is not nullable{code}
>  
>  



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Mime
View raw message