flink-issues mailing list archives

From twalthr <...@git.apache.org>
Subject [GitHub] flink pull request #5043: [FLINK-2170] [connectors] Add OrcRowInputFormat an...
Date Wed, 22 Nov 2017 14:25:29 GMT
Github user twalthr commented on a diff in the pull request:

    https://github.com/apache/flink/pull/5043#discussion_r152579106
  
    --- Diff: flink-connectors/flink-orc/src/main/java/org/apache/flink/orc/OrcRowInputFormat.java
---
    @@ -0,0 +1,747 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.flink.orc;
    +
    +import org.apache.flink.annotation.VisibleForTesting;
    +import org.apache.flink.api.common.io.FileInputFormat;
    +import org.apache.flink.api.common.typeinfo.TypeInformation;
    +import org.apache.flink.api.java.tuple.Tuple2;
    +import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
    +import org.apache.flink.api.java.typeutils.RowTypeInfo;
    +import org.apache.flink.core.fs.FileInputSplit;
    +import org.apache.flink.core.fs.Path;
    +import org.apache.flink.types.Row;
    +
    +import org.apache.hadoop.conf.Configuration;
    +import org.apache.hadoop.hive.common.type.HiveDecimal;
    +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
    +
    +import org.apache.hadoop.hive.ql.io.sarg.PredicateLeaf;
    +import org.apache.hadoop.hive.ql.io.sarg.SearchArgument;
    +import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory;
    +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
    +import org.apache.orc.OrcConf;
    +import org.apache.orc.OrcFile;
    +import org.apache.orc.Reader;
    +import org.apache.orc.RecordReader;
    +import org.apache.orc.StripeInformation;
    +import org.apache.orc.TypeDescription;
    +
    +import org.slf4j.Logger;
    +import org.slf4j.LoggerFactory;
    +
    +import java.io.IOException;
    +import java.io.ObjectInputStream;
    +import java.io.ObjectOutputStream;
    +import java.io.Serializable;
    +import java.math.BigDecimal;
    +import java.sql.Date;
    +import java.sql.Timestamp;
    +import java.util.ArrayList;
    +import java.util.Arrays;
    +import java.util.List;
    +
    +import static org.apache.flink.orc.OrcUtils.fillRows;
    +
    +/**
    + * InputFormat to read ORC files.
    + */
    +public class OrcRowInputFormat extends FileInputFormat<Row> implements ResultTypeQueryable<Row> {
    +
    +	private static final Logger LOG = LoggerFactory.getLogger(OrcRowInputFormat.class);
    +	// the number of rows read in a batch
    +	private static final int DEFAULT_BATCH_SIZE = 1000;
    +
    +	// the number of rows to read in a batch
    +	private int batchSize;
    +	// the configuration to read with
    +	private Configuration conf;
    +	// the schema of the ORC files to read
    +	private TypeDescription schema;
    +
    +	// the fields of the ORC schema that the returned Rows are composed of.
    +	private int[] selectedFields;
    +	// the type information of the Rows returned by this InputFormat.
    +	private transient RowTypeInfo rowType;
    +
    +	// the ORC reader
    +	private transient RecordReader orcRowsReader;
    +	// the vectorized row data to be read in a batch
    +	private transient VectorizedRowBatch rowBatch;
    +	// the vector of rows that is read in a batch
    +	private transient Row[] rows;
    +
    +	// the number of rows in the current batch
    +	private transient int rowsInBatch;
    +	// the index of the next row to return
    +	private transient int nextRow;
    +
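    +	// the conjunctive filter predicates that are pushed down to the ORC reader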
    +	private ArrayList<Predicate> conjunctPredicates = new ArrayList<>();
    +
    +	/**
    +	 * Creates an OrcRowInputFormat.
    +	 *
    +	 * @param path The path to read ORC files from.
    +	 * @param schemaString The schema of the ORC files as String.
    +	 * @param orcConfig The configuration to read the ORC files with.
    +	 */
    +	public OrcRowInputFormat(String path, String schemaString, Configuration orcConfig) {
    +		this(path, TypeDescription.fromString(schemaString), orcConfig, DEFAULT_BATCH_SIZE);
    +	}
    +
    +	/**
    +	 * Creates an OrcRowInputFormat.
    +	 *
    +	 * @param path The path to read ORC files from.
    +	 * @param schemaString The schema of the ORC files as String.
    +	 * @param orcConfig The configuration to read the ORC files with.
    +	 * @param batchSize The number of Row objects to read in a batch.
    +	 */
    +	public OrcRowInputFormat(String path, String schemaString, Configuration orcConfig, int batchSize) {
    +		this(path, TypeDescription.fromString(schemaString), orcConfig, batchSize);
    +	}
    +
    +	/**
    +	 * Creates an OrcRowInputFormat.
    +	 *
    +	 * @param path The path to read ORC files from.
    +	 * @param orcSchema The schema of the ORC files as ORC TypeDescription.
    +	 * @param orcConfig The configuration to read the ORC files with.
    +	 * @param batchSize The number of Row objects to read in a batch.
    +	 */
    +	public OrcRowInputFormat(String path, TypeDescription orcSchema, Configuration orcConfig, int batchSize) {
    +		super(new Path(path));
    +
    +		// configure OrcInputFormat
    +		this.schema = orcSchema;
    +		this.rowType = (RowTypeInfo) OrcUtils.schemaToTypeInfo(schema);
    +		this.conf = orcConfig;
    +		this.batchSize = batchSize;
    +
    +		// set default selection mask, i.e., all fields.
    +		this.selectedFields = new int[this.schema.getChildren().size()];
    +		for (int i = 0; i < selectedFields.length; i++) {
    +			this.selectedFields[i] = i;
    +		}
    +	}
    +
    +	/**
    +	 * Adds a filter predicate to reduce the number of rows to be returned by the input format.
    +	 * Multiple conjunctive predicates can be added by calling this method multiple times.
    +	 *
    +	 * <p>Note: Predicates can significantly reduce the amount of data that is read.
    +	 * However, the OrcRowInputFormat does not guarantee that all returned rows qualify the
    +	 * predicates. Moreover, predicates are only applied if the referenced field is among the
    +	 * selected fields.</p>
    +	 *
    +	 * @param predicate The filter predicate.
    +	 */
    +	public void addPredicate(Predicate predicate) {
    +		// validate
    +		validatePredicate(predicate);
    +		// add predicate
    +		this.conjunctPredicates.add(predicate);
    +	}
    +
    +	private void validatePredicate(Predicate pred) {
    +		if (pred instanceof ColumnPredicate) {
    +			// check column name
    +			String colName = ((ColumnPredicate) pred).columnName;
    +			if (!this.schema.getFieldNames().contains(colName)) {
    +				throw new IllegalArgumentException("Predicate cannot be applied. " +
    +					"Column '" + colName + "' does not exist in ORC schema.");
    +			}
    +		} else if (pred instanceof Not) {
    +			validatePredicate(((Not) pred).child());
    +		} else if (pred instanceof Or) {
    +			for (Predicate p : ((Or) pred).children()) {
    +				validatePredicate(p);
    +			}
    +		}
    +	}
    +
    +	/**
    +	 * Selects the fields from the ORC schema that are returned by the InputFormat.
    +	 *
    +	 * @param selectedFields The indices of the fields of the ORC schema that are returned by the InputFormat.
    +	 */
    +	public void selectFields(int... selectedFields) {
    +		// set field mapping
    +		this.selectedFields = selectedFields;
    +		// adapt result type
    +		this.rowType = RowTypeInfo.projectFields(this.rowType, selectedFields);
    +	}
    +
    +	/**
    +	 * Computes the ORC projection mask of the fields to include from the selected fields.
    +	 *
    +	 * @return The ORC projection mask.
    +	 */
    +	private boolean[] computeProjectionMask() {
    +		// mask with all fields of the schema
    +		boolean[] projectionMask = new boolean[schema.getMaximumId() + 1];
    +		// for each selected field
    +		for (int inIdx : selectedFields) {
    +			// set all nested fields of a selected field to true
    +			TypeDescription fieldSchema = schema.getChildren().get(inIdx);
    +			for (int i = fieldSchema.getId(); i <= fieldSchema.getMaximumId(); i++) {
    +				projectionMask[i] = true;
    +			}
    +		}
    +		return projectionMask;
    +	}
    +
    +	@Override
    +	public void openInputFormat() throws IOException {
    +		super.openInputFormat();
    +		// create and initialize the row batch
    +		this.rows = new Row[batchSize];
    +		for (int i = 0; i < batchSize; i++) {
    +			rows[i] = new Row(selectedFields.length);
    +		}
    +	}
    +
    +	@Override
    +	public void open(FileInputSplit fileSplit) throws IOException {
    +
    +		LOG.debug("Opening ORC file {}", fileSplit.getPath());
    +
    +		// open ORC file and create reader
    +		org.apache.hadoop.fs.Path hPath = new org.apache.hadoop.fs.Path(fileSplit.getPath().getPath());
    +		Reader orcReader = OrcFile.createReader(hPath, OrcFile.readerOptions(conf));
    +
    +		// get offset and length for the stripes that start in the split
    +		Tuple2<Long, Long> offsetAndLength = getOffsetAndLengthForSplit(fileSplit, getStripes(orcReader));
    +
    +		// create ORC row reader configuration
    +		Reader.Options options = getOptions(orcReader)
    +			.schema(schema)
    +			.range(offsetAndLength.f0, offsetAndLength.f1)
    +			.useZeroCopy(OrcConf.USE_ZEROCOPY.getBoolean(conf))
    +			.skipCorruptRecords(OrcConf.SKIP_CORRUPT_DATA.getBoolean(conf))
    +			.tolerateMissingSchema(OrcConf.TOLERATE_MISSING_SCHEMA.getBoolean(conf));
    +
    +		// configure filters
    +		if (!conjunctPredicates.isEmpty()) {
    +			SearchArgument.Builder b = SearchArgumentFactory.newBuilder();
    +			b = b.startAnd();
    +			for (Predicate predicate : conjunctPredicates) {
    +				predicate.add(b);
    +			}
    +			b = b.end();
    +			options.searchArgument(b.build(), new String[]{});
    +		}
    +
    +		// configure selected fields
    +		options.include(computeProjectionMask());
    +
    +		// create ORC row reader
    +		this.orcRowsReader = orcReader.rows(options);
    +
    +		// request the root id to trigger the lazy assignment of column ids in the ORC schema
    +		this.schema.getId();
    +		// create row batch
    +		this.rowBatch = schema.createRowBatch(batchSize);
    +		rowsInBatch = 0;
    +		nextRow = 0;
    +	}
    +
    +	@VisibleForTesting
    +	Reader.Options getOptions(Reader orcReader) {
    +		return orcReader.options();
    +	}
    +
    +	@VisibleForTesting
    +	List<StripeInformation> getStripes(Reader orcReader) {
    +		return orcReader.getStripes();
    +	}
    +
    +	private Tuple2<Long, Long> getOffsetAndLengthForSplit(FileInputSplit split, List<StripeInformation> stripes) {
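    +		// ORC stripes cannot be split, so each input split reads exactly those stripes that start within its range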
    +		long splitStart = split.getStart();
    +		long splitEnd = splitStart + split.getLength();
    +
    +		long readStart = Long.MAX_VALUE;
    +		long readEnd = Long.MIN_VALUE;
    +
    +		for (StripeInformation s : stripes) {
    +			if (splitStart <= s.getOffset() && s.getOffset() < splitEnd) {
    +				// stripe starts in split, so it is included
    +				readStart = Math.min(readStart, s.getOffset());
    +				readEnd = Math.max(readEnd, s.getOffset() + s.getLength());
    +			}
    +		}
    +
    +		if (readStart < Long.MAX_VALUE) {
    +			// at least one stripe is included
    +			return Tuple2.of(readStart, readEnd - readStart);
    +		} else {
    +			return Tuple2.of(0L, 0L);
    +		}
    +	}
    +
    +	@Override
    +	public void close() throws IOException {
    +		if (orcRowsReader != null) {
    +			this.orcRowsReader.close();
    +		}
    +		this.orcRowsReader = null;
    +	}
    +
    +	@Override
    +	public void closeInputFormat() throws IOException {
    +		this.rows = null;
    +		this.schema = null;
    +		this.rowBatch = null;
    +	}
    +
    +	@Override
    +	public boolean reachedEnd() throws IOException {
    +		return !ensureBatch();
    +	}
    +
    +	/**
    +	 * Checks if there is at least one row left in the batch to return.
    +	 * If no more rows are available, it reads another batch of rows.
    +	 *
    +	 * @return Returns true if there is one more row to return, false otherwise.
    +	 * @throws IOException Thrown if an exception happens while reading a batch.
    +	 */
    +	private boolean ensureBatch() throws IOException {
    +
    +		if (nextRow >= rowsInBatch) {
    +			// No more rows available in the Rows array.
    +			nextRow = 0;
    +			// Try to read the next batch of rows from the ORC file.
    +			boolean moreRows = orcRowsReader.nextBatch(rowBatch);
    +
    +			if (moreRows) {
    +				// Load the data into the Rows array.
    +				rowsInBatch = fillRows(rows, schema, rowBatch, selectedFields);
    +			}
    +			return moreRows;
    +		}
    +		// there is at least one Row left in the Rows array.
    +		return true;
    +	}
    +
    +	@Override
    +	public Row nextRecord(Row reuse) throws IOException {
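    +		// the reuse object is ignored; rows are served from the pre-allocated batch array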
    +		// return the next row
    +		return rows[this.nextRow++];
    +	}
    +
    +	@Override
    +	public TypeInformation<Row> getProducedType() {
    +		return rowType;
    +	}
    +
    +	// --------------------------------------------------------------------------------------------
    +	//  Custom serialization methods
    +	// --------------------------------------------------------------------------------------------
    +
    +	private void writeObject(ObjectOutputStream out) throws IOException {
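    +		// Hadoop's Configuration is a Writable but not Serializable, hence the manual serialization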
    +		out.writeInt(batchSize);
    +		this.conf.write(out);
    +		out.writeUTF(schema.toString());
    +
    +		out.writeInt(selectedFields.length);
    +		for (int f : selectedFields) {
    +			out.writeInt(f);
    +		}
    +
    +		out.writeInt(conjunctPredicates.size());
    +		for (Predicate p : conjunctPredicates) {
    +			out.writeObject(p);
    +		}
    +	}
    +
    +	@SuppressWarnings("unchecked")
    +	private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
    +		batchSize = in.readInt();
    +		org.apache.hadoop.conf.Configuration configuration = new org.apache.hadoop.conf.Configuration();
    +		configuration.readFields(in);
    +
    +		if (this.conf == null) {
    +			this.conf = configuration;
    +		}
    +		this.schema = TypeDescription.fromString(in.readUTF());
    +
    +		this.selectedFields = new int[in.readInt()];
    +		for (int i = 0; i < selectedFields.length; i++) {
    +			this.selectedFields[i] = in.readInt();
    +		}
    +
    +		this.conjunctPredicates = new ArrayList<>();
    +		int numPreds = in.readInt();
    +		for (int i = 0; i < numPreds; i++) {
    +			conjunctPredicates.add((Predicate) in.readObject());
    +		}
    +	}
    +
    +	// --------------------------------------------------------------------------------------------
    +	//  Classes to define predicates
    +	// --------------------------------------------------------------------------------------------
    +
    +	/**
    +	 * A filter predicate that can be evaluated by the OrcRowInputFormat.
    +	 */
    +	public abstract static class Predicate implements Serializable {
    +		protected abstract SearchArgument.Builder add(SearchArgument.Builder builder);
    +	}
    +
    +	abstract static class ColumnPredicate extends Predicate {
    +		final String columnName;
    +		final PredicateLeaf.Type literalType;
    +
    +		ColumnPredicate(String columnName, PredicateLeaf.Type literalType) {
    +			this.columnName = columnName;
    +			this.literalType = literalType;
    +		}
    +
    +		Object castLiteral(Serializable literal) {
    +
    +			switch (literalType) {
    +				case LONG:
    +					if (literal instanceof Byte) {
    +						return new Long((Byte) literal);
    +					} else if (literal instanceof Short) {
    +						return new Long((Short) literal);
    +					} else if (literal instanceof Integer) {
    +						return new Long((Integer) literal);
    +					} else if (literal instanceof Long) {
    +						return literal;
    +					} else {
    +						throw new IllegalArgumentException("A predicate on a LONG column requires an integer
" +
    +							"literal, i.e., Byte, Short, Integer, or Long.");
    +					}
    +				case FLOAT:
    +					if (literal instanceof Float) {
    +						return new Double((Float) literal);
    +					} else if (literal instanceof Double) {
    +						return literal;
    +					} else if (literal instanceof BigDecimal) {
    +						return ((BigDecimal) literal).doubleValue();
    +					} else {
    +						throw new IllegalArgumentException("A predicate on a FLOAT column requires a floating
" +
    +							"literal, i.e., Float or Double.");
    +					}
    +				case STRING:
    +					if (literal instanceof String) {
    +						return literal;
    +					} else {
    +						throw new IllegalArgumentException("A predicate on a STRING column requires a floating
" +
    +							"literal, i.e., Float or Double.");
    +					}
    +				case BOOLEAN:
    +					if (literal instanceof Boolean) {
    +						return literal;
    +					} else {
    +						throw new IllegalArgumentException("A predicate on a BOOLEAN column requires a
Boolean literal.");
    +					}
    +				case DATE:
    +					if (literal instanceof Date) {
    +						return literal;
    +					} else {
    +						throw new IllegalArgumentException("A predicate on a DATE column requires a java.sql.Date
literal.");
    +					}
    +				case TIMESTAMP:
    +					if (literal instanceof Timestamp) {
    +						return literal;
    +					} else {
    +						throw new IllegalArgumentException("A predicate on a TIMESTAMP column requires
a java.sql.Timestamp literal.");
    +					}
    +				case DECIMAL:
    +					if (literal instanceof BigDecimal) {
    +						return new HiveDecimalWritable(HiveDecimal.create((BigDecimal) literal));
    +					} else {
    +						throw new IllegalArgumentException("A predicate on a DECIMAL column requires a
BigDecimal literal.");
    +					}
    +				default:
    +					throw new IllegalArgumentException("Unknown literal type " + literalType);
    +			}
    +		}
    +	}
    +
    +	abstract static class BinaryPredicate extends ColumnPredicate {
    +		final Serializable literal;
    +
    +		BinaryPredicate(String columnName, PredicateLeaf.Type literalType, Serializable literal) {
    +			super(columnName, literalType);
    +			this.literal = literal;
    +		}
    +	}
    +
    +	/**
    +	 * An EQUALS predicate that can be evaluated by the OrcRowInputFormat.
    +	 */
    +	public static class Equals extends BinaryPredicate {
    +		/**
    +		 * Creates an EQUALS predicate.
    +		 *
    +		 * @param columnName The column to check.
    +		 * @param literalType The type of the literal.
    +		 * @param literal The literal value to check the column against.
    +		 */
    +		public Equals(String columnName, PredicateLeaf.Type literalType, Serializable literal) {
    +			super(columnName, literalType, literal);
    +		}
    +
    +		@Override
    +		protected SearchArgument.Builder add(SearchArgument.Builder builder) {
    +			return builder.equals(columnName, literalType, castLiteral(literal));
    +		}
    +
    +		@Override
    +		public String toString() {
    +			return columnName + " = " + literal;
    +		}
    +	}
    +
    +	/**
    +	 * An EQUALS predicate that can be evaluated with null safety by the OrcRowInputFormat.
    +	 */
    +	public static class NullSafeEquals extends BinaryPredicate {
    +		/**
    +		 * Creates a null-safe EQUALS predicate.
    +		 *
    +		 * @param columnName The column to check.
    +		 * @param literalType The type of the literal.
    +		 * @param literal The literal value to check the column against.
    +		 */
    +		public NullSafeEquals(String columnName, PredicateLeaf.Type literalType, Serializable literal) {
    --- End diff --
    
    An untested interface class ;-)
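    
    For context, a minimal usage sketch of the format under review (the input path, ORC schema,
    and ExecutionEnvironment env below are hypothetical, not part of this PR):
    
        // project to two of three columns and push a conjunctive filter into the ORC reader
        Configuration orcConfig = new Configuration();
        OrcRowInputFormat format = new OrcRowInputFormat(
            "hdfs:///path/to/orc",                     // hypothetical input path
            "struct<name:string,age:int,city:string>", // hypothetical ORC schema
            orcConfig);
        format.selectFields(0, 1); // keep (name, age)
        format.addPredicate(
            new OrcRowInputFormat.Equals("age", PredicateLeaf.Type.LONG, 30));
        DataSet<Row> rows = env.createInput(format);
    
    As the Javadoc above states, the pushed-down predicate only prunes data at the ORC level,
    so the returned rows are not guaranteed to satisfy it and may need to be re-filtered.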


---
