flink-issues mailing list archives

From ggevay <...@git.apache.org>
Subject [GitHub] flink pull request: [FLINK-3477] [runtime] Add hash-based combine strategy f...
Date Wed, 01 Jun 2016 09:44:18 GMT
Github user ggevay commented on a diff in the pull request:

    https://github.com/apache/flink/pull/1517#discussion_r65331149
  
    --- Diff: flink-runtime/src/main/java/org/apache/flink/runtime/operators/hash/ReduceHashTable.java ---
    @@ -0,0 +1,1048 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.flink.runtime.operators.hash;
    +
    +import org.apache.flink.api.common.functions.ReduceFunction;
    +import org.apache.flink.api.common.typeutils.SameTypePairComparator;
    +import org.apache.flink.api.common.typeutils.TypeComparator;
    +import org.apache.flink.api.common.typeutils.TypePairComparator;
    +import org.apache.flink.api.common.typeutils.TypeSerializer;
    +import org.apache.flink.core.memory.DataInputView;
    +import org.apache.flink.core.memory.MemorySegment;
    +import org.apache.flink.runtime.io.disk.RandomAccessInputView;
    +import org.apache.flink.runtime.memory.AbstractPagedOutputView;
    +import org.apache.flink.util.MathUtils;
    +import org.apache.flink.util.Collector;
    +import org.apache.flink.util.MutableObjectIterator;
    +import org.slf4j.Logger;
    +import org.slf4j.LoggerFactory;
    +
    +import java.io.EOFException;
    +import java.io.IOException;
    +import java.util.ArrayList;
    +import java.util.Arrays;
    +import java.util.List;
    +
    +/**
    + * This hash table supports updating elements, and it also has processRecordWithReduce,
    + * which makes one reduce step with the given record.
    + *
    + * The memory is divided into three areas:
    + *  - Bucket area: this contains the bucket heads:
    + *    an 8 byte pointer to the first link of a linked list in the record area
    + *  - Record area: this contains the actual data in linked list elements. A linked list element starts
    + *    with an 8 byte pointer to the next element, and then the record follows.
    + *  - Staging area: This is a small, temporary storage area for writing updated records. This is needed,
    + *    because before serializing a record, there is no way to know in advance how large it will be.
    + *    Therefore, we can't serialize directly into the record area when we are doing an update, because
    + *    if it turns out to be larger than the old record, then it would overwrite some other record
    + *    that happens to be after the old one in memory. The solution is to serialize to the staging area first,
    + *    and then copy it to the place of the original if it has the same size, otherwise allocate a new linked
    + *    list element at the end of the record area, and mark the old one as abandoned. This creates "holes" in
    + *    the record area, so compactions are eventually needed.
    + *
    + *  Compaction happens by deleting everything in the bucket area, and then reinserting all elements.
    + *  The reinsertion happens by forgetting the structure (the linked lists) of the record area, reading it
    + *  sequentially, and inserting all non-abandoned records, starting from the beginning of the record area.
    + *  Note that insertions never overwrite a record that has not yet been read by the reinsertion sweep, because
    + *  both the insertions and the reads happen sequentially in the record area, and the insertions never
    + *  overtake the reading sweep: only the non-abandoned subset of what has already been read is written back.
    + *
    + *  Note: we have to abandon the old linked list element even when the updated record has a smaller size
    + *  than the original, because otherwise we wouldn't know where the next record starts during a reinsertion
    + *  sweep.
    + *
    + *  The number of buckets depends on how large the records are. The serializer might be able to tell us this,
    + *  in which case we calculate the number of buckets upfront and don't do resizes.
    + *  If the serializer doesn't know the size, then we start with a small number of buckets, and resize
    + *  whenever the number of inserted elements exceeds the number of buckets.
    + *
    + *  The number of memory segments given to the staging area is usually one, because it just needs to hold
    + *  one record.
    + *
    + * Note: For hashing, we need to use MathUtils.hash because of its avalanche property, so that
    + * changing only some high bits of the original value doesn't leave the lower bits of the hash unaffected.
    + * This is because when choosing the bucket for a record, we mask only the
    + * lower bits (see numBucketsMask). Lots of collisions would occur when, for example,
    + * the original value that is hashed is some bitset, where lots of different values
    + * that differ only in the higher bits will actually occur.
    + */
    +
    +public class ReduceHashTable<T> extends AbstractMutableHashTable<T> {
    +
    +	private static final Logger LOG = LoggerFactory.getLogger(ReduceHashTable.class);
    +
    +	/** The minimum number of memory segments ReduceHashTable needs to be supplied with in order to work. */
    +	private static final int MIN_NUM_MEMORY_SEGMENTS = 3;
    +
    +	// Note: the following two constants can't be negative, because negative values are reserved for storing the
    +	// negated size of the record, when it is abandoned (not part of any linked list).
    +
    +	/** The last link in the linked lists will have this as next pointer. */
    +	private static final long END_OF_LIST = Long.MAX_VALUE;
    +
    +	/** This value means that prevElemPtr is "pointing to the bucket head", and not into the record segments. */
    +	private static final long INVALID_PREV_POINTER = Long.MAX_VALUE - 1;
    +
    +
    +	private static final long RECORD_OFFSET_IN_LINK = 8;
    +
    +
    +	/** this is used by processRecordWithReduce */
    +	private final ReduceFunction<T> reducer;
    +
    +	/** emit() sends data to outputCollector */
    +	private final Collector<T> outputCollector;
    +
    +	private final boolean objectReuseEnabled;
    +
    +	/**
    +	 * This initially contains all the memory we have, and then segments
    +	 * are taken from it by bucketSegments, recordArea, and stagingSegments.
    +	 */
    +	private final ArrayList<MemorySegment> freeMemorySegments;
    +
    +	private final int numAllMemorySegments;
    +
    +	private final int segmentSize;
    +
    +	/**
    +	 * These will contain the bucket heads.
    +	 * The bucket heads are pointers to the linked lists containing the actual records.
    +	 */
    +	private MemorySegment[] bucketSegments;
    +
    +	private static final int bucketSize = 8, bucketSizeBits = 3;
    +
    +	private int numBuckets;
    +	private int numBucketsMask;
    +	private final int numBucketsPerSegment, numBucketsPerSegmentBits, numBucketsPerSegmentMask;
    +
    +	/**
    +	 * The segments where the actual data is stored.
    +	 */
    +	private final RecordArea recordArea;
    +
    +	/**
    +	 * Segments for the staging area.
    +	 * (It should contain at most one record at all times.)
    +	 */
    +	private final ArrayList<MemorySegment> stagingSegments;
    +	private final RandomAccessInputView stagingSegmentsInView;
    +	private final StagingOutputView stagingSegmentsOutView;
    +
    +	private T reuse;
    +
    +	/** This is the internal prober that insertOrReplaceRecord and processRecordWithReduce use. */
    +	private final HashTableProber<T> prober;
    +
    +	/** The number of elements currently held by the table. */
    +	private long numElements = 0;
    +
    +	/** The number of bytes wasted by updates that couldn't overwrite the old record due to size change. */
    +	private long holes = 0;
    +
    +	/**
    +	 * If the serializer knows the size of the records, then we can calculate the optimal number of buckets
    +	 * upfront, so we don't need resizes.
    +	 */
    +	private boolean enableResize;
    +
    +
    +	/**
    +	 * This constructor is for the case when we will only call those operations that are also
    +	 * present on CompactingHashTable.
    +	 */
    +	public ReduceHashTable(TypeSerializer<T> serializer, TypeComparator<T> comparator, List<MemorySegment> memory) {
    +		this(serializer, comparator, memory, null, null, false);
    +	}
    +
    +	public ReduceHashTable(TypeSerializer<T> serializer, TypeComparator<T> comparator, List<MemorySegment> memory,
    +						ReduceFunction<T> reducer, Collector<T> outputCollector, boolean objectReuseEnabled) {
    +		super(serializer, comparator);
    +		this.reducer = reducer;
    +		this.numAllMemorySegments = memory.size();
    +		this.freeMemorySegments = new ArrayList<>(memory);
    +		this.outputCollector = outputCollector;
    +		this.objectReuseEnabled = objectReuseEnabled;
    +
    +		// some sanity checks first
    +		if (freeMemorySegments.size() < MIN_NUM_MEMORY_SEGMENTS) {
    +			throw new IllegalArgumentException("Too few memory segments provided. ReduceHashTable needs at least " +
    +				MIN_NUM_MEMORY_SEGMENTS + " memory segments.");
    +		}
    +
    +		// Get the size of the first memory segment and record it. All further buffers must have the same size.
    +		// The size must also be a power of 2.
    +		segmentSize = freeMemorySegments.get(0).size();
    +		if ( (segmentSize & segmentSize - 1) != 0) {
    +			throw new IllegalArgumentException("Hash Table requires buffers whose size is a power of 2.");
    +		}
    +
    +		this.numBucketsPerSegment = segmentSize / bucketSize;
    +		this.numBucketsPerSegmentBits = MathUtils.log2strict(this.numBucketsPerSegment);
    +		this.numBucketsPerSegmentMask = (1 << this.numBucketsPerSegmentBits) - 1;
    +
    +		recordArea = new RecordArea(segmentSize);
    +
    +		stagingSegments = new ArrayList<>();
    +		stagingSegmentsInView = new RandomAccessInputView(stagingSegments, segmentSize, false);
    +		stagingSegmentsOutView = new StagingOutputView(stagingSegments, segmentSize);
    +
    +		prober = new HashTableProber<>(buildSideComparator, new SameTypePairComparator<>(buildSideComparator));
    +
    +		enableResize = buildSideSerializer.getLength() == -1;
    +	}
    +
    +	/**
    +	 * Gets the total capacity of this hash table, in bytes.
    +	 *
    +	 * @return The hash table's total capacity.
    +	 */
    +	public long getCapacity() {
    +		return (long) numAllMemorySegments * segmentSize;
    +	}
    +
    +	/**
    +	 * Gets the number of bytes currently occupied in this hash table.
    +	 *
    +	 * @return The number of bytes occupied.
    +	 */
    +	public long getOccupancy() {
    +		return (long) (numAllMemorySegments - freeMemorySegments.size()) * segmentSize;
    +	}
    +
    +	private void open(int numBucketSegments) {
    +		synchronized (stateLock) {
    +			if (!closed) {
    +				throw new IllegalStateException("currently not closed.");
    +			}
    +			closed = false;
    +		}
    +
    +		allocateBucketSegments(numBucketSegments);
    +
    +		stagingSegments.add(forcedAllocateSegment());
    +
    +		reuse = buildSideSerializer.createInstance();
    +	}
    +
    +	/**
    +	 * Initialize the hash table
    +	 */
    +	@Override
    +	public void open() {
    +		open(calcInitialNumBucketSegments());
    +	}
    +
    +	@Override
    +	public void close() {
    +		// make sure that we close only once
    +		synchronized (stateLock) {
    +			if (closed) {
    +				return;
    +			}
    +			closed = true;
    +		}
    +
    +		LOG.debug("Closing ReduceHashTable and releasing resources.");
    +
    +		releaseBucketSegments();
    +
    +		recordArea.giveBackSegments();
    +
    +		freeMemorySegments.addAll(stagingSegments);
    +		stagingSegments.clear();
    +
    +		numElements = 0;
    +		holes = 0;
    +	}
    +
    +	@Override
    +	public void abort() {
    +		LOG.debug("Aborting ReduceHashTable.");
    +		close();
    +	}
    +
    +	@Override
    +	public List<MemorySegment> getFreeMemory() {
    +		if (!this.closed) {
    +			throw new IllegalStateException("Cannot return memory while ReduceHashTable is open.");
    +		}
    +
    +		return freeMemorySegments;
    +	}
    +
    +	private int calcInitialNumBucketSegments() {
    +		int recordLength = buildSideSerializer.getLength();
    +		double fraction;
    +		if (recordLength == -1) {
    +			// Resizing seems to be quite efficient, so here we can err on the side of too few bucket segments.
    +			// Even with small records, we lose only ~15% speed.
    +			fraction = 0.1;
    +		} else {
    +			fraction = 8.0 / (16 + recordLength);
    +			// note: enableResize is false in this case, so no resizing will happen
    +		}
    +
    +		int ret = Math.max(1, MathUtils.roundDownToPowerOf2((int)(numAllMemorySegments * fraction)));
    +
    +		// We can't handle more than Integer.MAX_VALUE buckets (e.g. because hash functions return int)
    +		if ((long)ret * numBucketsPerSegment > Integer.MAX_VALUE) {
    +			ret = MathUtils.roundDownToPowerOf2(Integer.MAX_VALUE / numBucketsPerSegment);
    +		}
    +		return ret;
    +	}
    +
    +	private void allocateBucketSegments(int numBucketSegments) {
    +		if (numBucketSegments < 1) {
    +			throw new RuntimeException("Bug in ReduceHashTable");
    +		}
    +
    +		bucketSegments = new MemorySegment[numBucketSegments];
    +		for(int i = 0; i < bucketSegments.length; i++) {
    +			bucketSegments[i] = forcedAllocateSegment();
    +			// Init all pointers in all buckets to END_OF_LIST
    +			for(int j = 0; j < numBucketsPerSegment; j++) {
    +				bucketSegments[i].putLong(j << bucketSizeBits, END_OF_LIST);
    +			}
    +		}
    +		numBuckets = numBucketSegments * numBucketsPerSegment;
    +		numBucketsMask = (1 << MathUtils.log2strict(numBuckets)) - 1;
    +	}
    +
    +	private void releaseBucketSegments() {
    +		freeMemorySegments.addAll(Arrays.asList(bucketSegments));
    +		bucketSegments = null;
    +	}
    +
    +	private MemorySegment allocateSegment() {
    +		int s = freeMemorySegments.size();
    +		if (s > 0) {
    +			return freeMemorySegments.remove(s - 1);
    +		} else {
    +			return null;
    +		}
    +	}
    +
    +	private MemorySegment forcedAllocateSegment() {
    +		MemorySegment segment = allocateSegment();
    +		if (segment == null) {
    +			throw new RuntimeException("Bug in ReduceHashTable: A free segment should have been available.");
    +		}
    +		return segment;
    +	}
    +
    +	/**
    +	 * Searches the hash table for the record with matching key, and updates it (making one reduce step) if found,
    +	 * otherwise inserts a new entry.
    +	 *
    +	 * (If there are multiple entries with the same key, then it will update one of them.)
    +	 *
    +	 * @param record The record to be processed.
    +	 */
    +	public void processRecordWithReduce(T record) throws Exception {
    +		if (closed) {
    +			return;
    +		}
    +
    +		T match = prober.getMatchFor(record, reuse);
    +		if (match == null) {
    +			prober.insertAfterNoMatch(record);
    +		} else {
    +			// do the reduce step
    +			T res = reducer.reduce(match, record);
    +
    +			// We have given reuse to the reducer UDF, so create a new one if object reuse is disabled
    +			if (!objectReuseEnabled) {
    +				reuse = buildSideSerializer.createInstance();
    +			}
    +
    +			prober.updateMatch(res);
    +		}
    +	}
    +
    +	/**
    +	 * Searches the hash table for a record with the given key.
    +	 * If it is found, then it is overwritten with the specified record.
    +	 * Otherwise, the specified record is inserted.
    +	 * @param record The record to insert or to replace with.
    +	 * @throws IOException (EOFException specifically, if memory ran out)
    +	 */
    +	@Override
    +	public void insertOrReplaceRecord(T record) throws IOException {
    +		if (closed) {
    +			return;
    +		}
    +
    +		T match = prober.getMatchFor(record, reuse);
    +		if (match == null) {
    +			prober.insertAfterNoMatch(record);
    +		} else {
    +			prober.updateMatch(record);
    +		}
    +	}
    +
    +	/**
    +	 * Inserts the given record into the hash table.
    +	 * Note: this method doesn't care about whether a record with the same key is already present.
    +	 * @param record The record to insert.
    +	 * @throws IOException (EOFException specifically, if memory ran out)
    +	 */
    +	@Override
    +	public void insert(T record) throws IOException {
    +		if (closed) {
    +			return;
    +		}
    +
    +		final int hashCode = MathUtils.jenkinsHash(buildSideComparator.hash(record));
    +		final int bucket = hashCode & numBucketsMask;
    +		final int bucketSegmentIndex = bucket >>> numBucketsPerSegmentBits; // which segment contains the bucket
    +		final MemorySegment bucketSegment = bucketSegments[bucketSegmentIndex];
    +		final int bucketOffset = (bucket & numBucketsPerSegmentMask) << bucketSizeBits; // offset of the bucket in the segment
    +		final long firstPointer = bucketSegment.getLong(bucketOffset);
    +
    +		try {
    +			final long newFirstPointer = recordArea.appendPointerAndRecord(firstPointer, record);
    +			bucketSegment.putLong(bucketOffset, newFirstPointer);
    +		} catch (EOFException ex) {
    +			compactOrThrow();
    +			insert(record);
    +			return;
    +		}
    +
    +		numElements++;
    +		resizeTableIfNecessary();
    +	}
    +
    +	private void resizeTableIfNecessary() throws IOException {
    +		if (enableResize && numElements > numBuckets) {
    +			final long newNumBucketSegments = 2L * bucketSegments.length;
    +			// Checks:
    +			// - we can't handle more than Integer.MAX_VALUE buckets
    +			// - don't take more memory than the free memory we have left
    +			// - the buckets shouldn't occupy more than half of all our memory
    +			if (newNumBucketSegments * numBucketsPerSegment < Integer.MAX_VALUE &&
    --- End diff --
    
    That wouldn't work, because the number of bucket segments (and hence the number of buckets) has to be a power of 2, so that taking the modulo can be done quickly by bitwise AND-ing with a mask instead of dividing [1].
    
    [1] https://gmplib.org/~tege/x86-timing.pdf
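    The point about masking can be checked with a small, self-contained Java sketch. It assumes 8-byte bucket heads as in the diff and, for the sample numbers, 32 KiB memory segments (so 4096 buckets per segment). `BucketAddressing` and its methods are hypothetical helpers written only for illustration, not Flink API. The sketch compares `hash & (numBuckets - 1)` against an unsigned remainder, and splits a bucket index into a segment index and a byte offset the way `insert()` does with `numBucketsPerSegmentBits` and `numBucketsPerSegmentMask`.

```java
// Hypothetical illustration class; not part of Flink.
public final class BucketAddressing {

    // Each bucket head is an 8-byte pointer, as in the diff (bucketSizeBits = 3).
    static final int BUCKET_SIZE_BITS = 3;

    /** True iff n is a positive power of 2. */
    public static boolean isPowerOfTwo(int n) {
        return n > 0 && (n & (n - 1)) == 0;
    }

    /** Cheap bucket selection by masking; equals the modulo only when numBuckets is a power of 2. */
    public static int bucketByMask(int hash, int numBuckets) {
        return hash & (numBuckets - 1);
    }

    /** Bucket selection by (slower) unsigned division; correct for any numBuckets > 0. */
    public static int bucketByModulo(int hash, int numBuckets) {
        return Integer.remainderUnsigned(hash, numBuckets);
    }

    /** Splits a bucket index into {segment index, byte offset of the bucket head in that segment}. */
    public static int[] locate(int bucket, int numBucketsPerSegment) {
        if (!isPowerOfTwo(numBucketsPerSegment)) {
            throw new IllegalArgumentException("numBucketsPerSegment must be a power of 2");
        }
        int bits = Integer.numberOfTrailingZeros(numBucketsPerSegment); // log2 of a power of 2
        int segmentIndex = bucket >>> bits;
        int offset = (bucket & (numBucketsPerSegment - 1)) << BUCKET_SIZE_BITS;
        return new int[] {segmentIndex, offset};
    }

    public static void main(String[] args) {
        int[] hashes = {0, 1, 12345, -7, Integer.MIN_VALUE, Integer.MAX_VALUE};

        // 1024 buckets (a power of 2): masking and unsigned modulo agree for every hash.
        for (int h : hashes) {
            if (bucketByMask(h, 1024) != bucketByModulo(h, 1024)) {
                throw new AssertionError("mismatch for hash " + h);
            }
        }

        // 1536 buckets (1.5x growth instead of 2x): the two results diverge.
        System.out.println(bucketByMask(1600, 1536) + " vs " + bucketByModulo(1600, 1536));

        // With 4096 buckets per segment, bucket 5000 lives in segment 1 at byte offset 904 * 8.
        int[] loc = locate(5000, 4096);
        if (loc[0] != 1 || loc[1] != 7232) {
            throw new AssertionError("unexpected location");
        }
    }
}
```

    Running `main` prints `1088 vs 64`. With 1536 buckets the mask 1535 has bit 9 cleared, so buckets 512 through 1023 can never be selected, whereas a true modulo would use them. Doubling the bucket-segment count, as `resizeTableIfNecessary()` does, keeps the count a power of 2 and the mask valid.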

