Github user dineshjoshi commented on a diff in the pull request:
https://github.com/apache/cassandra/pull/253#discussion_r216087977
--- Diff: src/java/org/apache/cassandra/net/MessageIn.java ---
@@ -231,4 +241,437 @@ public String toString()
sbuf.append("FROM:").append(from).append(" TYPE:").append(getMessageType()).append("
VERB:").append(verb);
return sbuf.toString();
}
+
+ public static MessageInProcessor getProcessor(InetAddressAndPort peer, int messagingVersion)
+ {
+ return getProcessor(peer, messagingVersion, MessageInProcessor.MESSAGING_SERVICE_CONSUMER);
+
+ }
+
+ public static MessageInProcessor getProcessor(InetAddressAndPort peer, int messagingVersion,
BiConsumer<MessageIn, Integer> messageConsumer)
+ {
+ return messagingVersion >= MessagingService.VERSION_40
+ ? new MessageInProcessorAsOf40(peer, messagingVersion, messageConsumer)
+ : new MessageInProcessorPre40(peer, messagingVersion, messageConsumer);
+
+ }
+
+ /**
+ * Implementations contain the mechanics and logic of parsing incoming messages.
Allows for both non-blocking
+ * and blocking styles of interaction via the {@link #process(ByteBuf)} and {@link
#process(RebufferingByteBufDataInputPlus)}
+ * methods, respectively.
+ *
+ * Does not contain the actual deserialization code for message fields nor payload.
That is left to the
+ * {@link MessageIn#read(DataInputPlus, int, int)} family of methods.
+ */
+ public static abstract class MessageInProcessor
+ {
+ /**
+ * The current state of deserializing an incoming message. This enum is only
used in the nonblocking versions.
+ */
+ public enum State
+ {
+ READ_PREFIX,
+ READ_IP_ADDRESS,
+ READ_VERB,
+ READ_PARAMETERS_SIZE,
+ READ_PARAMETERS_DATA,
+ READ_PAYLOAD_SIZE,
+ READ_PAYLOAD
+ }
+
+ static final int VERB_LENGTH = Integer.BYTES;
+
+ /**
+ * The default target for consuming deserialized {@link MessageIn}.
+ */
+ private static final BiConsumer<MessageIn, Integer> MESSAGING_SERVICE_CONSUMER
= (messageIn, id) -> MessagingService.instance().receive(messageIn, id);
+
+ final InetAddressAndPort peer;
+ final int messagingVersion;
+
+ /**
+ * Abstracts out depending directly on {@link MessagingService#receive(MessageIn,
int)}; this makes tests more sane
+ * as they don't require nor trigger the entire message processing circus.
+ */
+ final BiConsumer<MessageIn, Integer> messageConsumer;
+
+ /**
+ * Captures the current {@link State} of processing a message. Primarily useful
in the non-blocking use case.
+ */
+ State state = State.READ_PREFIX;
+
+ /**
+ * Captures the current data we've parsed out of in incoming message. Primarily
useful in the non-blocking use case.
+ */
+ MessageHeader messageHeader;
+
+ /**
+ * Process the buffer in a non-blocking manner. Will try to read out as much
of a message(s) as possible,
+ * and send any fully deserialized messages to {@link #messageConsumer}.
+ */
+ public abstract void process(ByteBuf in) throws IOException;
+
+ /**
+ * Process the buffer in a blocking manner. Will read as many messages as possible,
blocking for more data,
+ * and send any fully deserialized messages to {@link #messageConsumer}.
+ */
+ public abstract void process(RebufferingByteBufDataInputPlus in) throws IOException;
+
+ MessageInProcessor(InetAddressAndPort peer, int messagingVersion, BiConsumer<MessageIn,
Integer> messageConsumer)
+ {
+ this.peer = peer;
+ this.messagingVersion = messagingVersion;
+ this.messageConsumer = messageConsumer;
+ }
+
+ /**
+ * Only applicable in the non-blocking use case, and should ony be used for testing!!!
+ */
+ @VisibleForTesting
+ public MessageHeader getMessageHeader()
+ {
+ return messageHeader;
+ }
+
+ /**
+ * A simple struct to hold the message header data as it is being built up.
+ */
+ public static class MessageHeader
+ {
+ public int messageId;
+ long constructionTime;
+ public InetAddressAndPort from;
+ public MessagingService.Verb verb;
+ int payloadSize;
+
+ Map<ParameterType, Object> parameters = Collections.emptyMap();
+
+ /**
+ * Length of the parameter data. If the message's version is {@link MessagingService#VERSION_40}
or higher,
+ * this value is the total number of header bytes; else, for legacy messaging,
this is the number of
+ * key/value entries in the header.
+ */
+ int parameterLength;
+ }
+
+ MessageHeader readPrefix(DataInputPlus in) throws IOException
+ {
+ MessagingService.validateMagic(in.readInt());
+ MessageHeader messageHeader = new MessageHeader();
+ messageHeader.messageId = in.readInt();
+ int messageTimestamp = in.readInt(); // make sure to read the sent timestamp,
even if DatabaseDescriptor.hasCrossNodeTimeout() is not enabled
+ messageHeader.constructionTime = MessageIn.deriveConstructionTime(peer, messageTimestamp,
ApproximateTime.currentTimeMillis());
+
+ return messageHeader;
+ }
+ }
+
+ /**
+ * Reads the incoming stream of bytes in the 4.0 format.
+ */
+ static class MessageInProcessorAsOf40 extends MessageInProcessor
+ {
+ MessageInProcessorAsOf40(InetAddressAndPort peer, int messagingVersion, BiConsumer<MessageIn,
Integer> messageConsumer)
+ {
+ super(peer, messagingVersion, messageConsumer);
+ assert messagingVersion >= MessagingService.VERSION_40;
+ }
+
+ @SuppressWarnings("resource")
+ public void process(ByteBuf in) throws IOException
+ {
+ ByteBufDataInputPlus inputPlus = new ByteBufDataInputPlus(in);
+ while (true)
+ {
+ switch (state)
+ {
+ case READ_PREFIX:
+ if (in.readableBytes() < MessageOut.MESSAGE_PREFIX_SIZE)
+ return;
+ MessageHeader header = readPrefix(inputPlus);
+ if (header == null)
+ return;
+ header.from = peer;
+ messageHeader = header;
+ state = State.READ_VERB;
+ // fall-through
+ case READ_VERB:
+ if (in.readableBytes() < VERB_LENGTH)
+ return;
+ messageHeader.verb = MessagingService.Verb.fromId(in.readInt());
+ state = State.READ_PARAMETERS_SIZE;
+ // fall-through
+ case READ_PARAMETERS_SIZE:
+ long length = VIntCoding.readUnsignedVInt(in);
+ if (length < 0)
+ return;
+ messageHeader.parameterLength = Ints.checkedCast(length);
+ messageHeader.parameters = messageHeader.parameterLength == 0
? Collections.emptyMap() : new EnumMap<>(ParameterType.class);
+ state = State.READ_PARAMETERS_DATA;
+ // fall-through
+ case READ_PARAMETERS_DATA:
+ if (messageHeader.parameterLength > 0)
+ {
+ if (in.readableBytes() < messageHeader.parameterLength)
+ return;
+ readParameters(inputPlus, messageHeader.parameterLength,
messageHeader.parameters);
+ }
+ state = State.READ_PAYLOAD_SIZE;
+ // fall-through
+ case READ_PAYLOAD_SIZE:
+ length = VIntCoding.readUnsignedVInt(in);
+ if (length < 0)
+ return;
+ messageHeader.payloadSize = (int) length;
+ state = State.READ_PAYLOAD;
+ // fall-through
+ case READ_PAYLOAD:
+ if (in.readableBytes() < messageHeader.payloadSize)
+ return;
+
+ MessageIn<Object> messageIn = MessageIn.read(inputPlus,
messagingVersion,
+ messageHeader.messageId,
messageHeader.constructionTime, messageHeader.from,
+ messageHeader.payloadSize,
messageHeader.verb, messageHeader.parameters);
+
+ if (messageIn != null)
+ messageConsumer.accept(messageIn, messageHeader.messageId);
+
+ state = State.READ_PREFIX;
+ messageHeader = null;
+ break;
+ default:
+ throw new IllegalStateException("unknown/unhandled state: " +
state);
+ }
+ }
+ }
+
+ private void readParameters(DataInputPlus inputPlus, int parameterLength, Map<ParameterType,
Object> parameters) throws IOException
+ {
+ TrackedDataInputPlus inputTracker = new TrackedDataInputPlus(inputPlus);
+
+ while (inputTracker.getBytesRead() < parameterLength)
+ {
+ String key = DataInputStream.readUTF(inputTracker);
+ ParameterType parameterType = ParameterType.byName.get(key);
+ long valueLength = VIntCoding.readUnsignedVInt(inputTracker);
+ parameters.put(parameterType, parameterType.serializer.deserialize(inputTracker,
messagingVersion));
+ }
+ }
+
+ public void process(RebufferingByteBufDataInputPlus in) throws IOException
+ {
+ while (in.isOpen() && !in.isEmpty())
+ {
+ messageHeader = readPrefix(in);
+ messageHeader.from = peer;
+ messageHeader.verb = MessagingService.Verb.fromId(in.readInt());
+ messageHeader.parameterLength = Ints.checkedCast(VIntCoding.readUnsignedVInt(in));
+ messageHeader.parameters = messageHeader.parameterLength == 0 ? Collections.emptyMap()
: new EnumMap<>(ParameterType.class);
+ if (messageHeader.parameterLength > 0)
+ readParameters(in, messageHeader.parameterLength, messageHeader.parameters);
+
+ messageHeader.payloadSize = Ints.checkedCast(VIntCoding.readUnsignedVInt(in));
+ MessageIn<Object> messageIn = MessageIn.read(in, messagingVersion,
+ messageHeader.messageId,
messageHeader.constructionTime, messageHeader.from,
+ messageHeader.payloadSize,
messageHeader.verb, messageHeader.parameters);
+ if (messageIn != null)
+ messageConsumer.accept(messageIn, messageHeader.messageId);
+ }
+ }
+ }
+
+ /**
+ * Reads the incoming stream of bytes in the pre-4.0 format.
+ */
+ static class MessageInProcessorPre40 extends MessageInProcessor
+ {
+ private static final int PARAMETERS_SIZE_LENGTH = Integer.BYTES;
+ private static final int PARAMETERS_VALUE_SIZE_LENGTH = Integer.BYTES;
+ private static final int PAYLOAD_SIZE_LENGTH = Integer.BYTES;
+
+ MessageInProcessorPre40(InetAddressAndPort peer, int messagingVersion, BiConsumer<MessageIn,
Integer> messageConsumer)
+ {
+ super(peer, messagingVersion, messageConsumer);
+ assert messagingVersion < MessagingService.VERSION_40;
+ }
+
+ public void process(ByteBuf in) throws IOException
+ {
+ ByteBufDataInputPlus inputPlus = new ByteBufDataInputPlus(in);
+ while (true)
+ {
+ switch (state)
+ {
+ case READ_PREFIX:
+ if (in.readableBytes() < MessageOut.MESSAGE_PREFIX_SIZE)
+ return;
+ MessageHeader header = readPrefix(inputPlus);
+ if (header == null)
+ return;
+ messageHeader = header;
+ state = State.READ_IP_ADDRESS;
+ // fall-through
+ case READ_IP_ADDRESS:
+ // unfortunately, this assumes knowledge of how CompactEndpointSerializationHelper
serializes data (the first byte is the size).
+ // first, check that we can actually read the size byte, then
check if we can read that number of bytes.
+ // the "+ 1" is to make sure we have the size byte in addition
to the serialized IP addr count of bytes in the buffer.
+ int readableBytes = in.readableBytes();
+ if (readableBytes < 1 || readableBytes < in.getByte(in.readerIndex())
+ 1)
+ return;
+ messageHeader.from = CompactEndpointSerializationHelper.instance.deserialize(inputPlus,
messagingVersion);
+ state = State.READ_VERB;
+ // fall-through
+ case READ_VERB:
+ if (in.readableBytes() < VERB_LENGTH)
+ return;
+ messageHeader.verb = MessagingService.Verb.fromId(in.readInt());
+ state = State.READ_PARAMETERS_SIZE;
+ // fall-through
+ case READ_PARAMETERS_SIZE:
+ if (in.readableBytes() < PARAMETERS_SIZE_LENGTH)
+ return;
+ messageHeader.parameterLength = in.readInt();
+ messageHeader.parameters = messageHeader.parameterLength == 0
? Collections.emptyMap() : new EnumMap<>(ParameterType.class);
+ state = State.READ_PARAMETERS_DATA;
+ // fall-through
+ case READ_PARAMETERS_DATA:
+ if (messageHeader.parameterLength > 0)
+ {
+ if (!readParameters(in, inputPlus, messageHeader.parameterLength,
messageHeader.parameters))
+ return;
+ }
+ state = State.READ_PAYLOAD_SIZE;
+ // fall-through
+ case READ_PAYLOAD_SIZE:
+ if (in.readableBytes() < PAYLOAD_SIZE_LENGTH)
+ return;
+ messageHeader.payloadSize = in.readInt();
+ state = State.READ_PAYLOAD;
+ // fall-through
+ case READ_PAYLOAD:
+ if (in.readableBytes() < messageHeader.payloadSize)
+ return;
+
+ MessageIn<Object> messageIn = MessageIn.read(inputPlus,
messagingVersion,
+ messageHeader.messageId,
messageHeader.constructionTime, messageHeader.from,
+ messageHeader.payloadSize,
messageHeader.verb, messageHeader.parameters);
+
+ if (messageIn != null)
+ messageConsumer.accept(messageIn, messageHeader.messageId);
+
+ state = State.READ_PREFIX;
+ messageHeader = null;
+ break;
+ default:
+ throw new IllegalStateException("unknown/unhandled state: " +
state);
+ }
+ }
+ }
+
+ /**
+ * @return <code>true</code> if all the parameters have been read
from the {@link ByteBuf}; else, <code>false</code>.
+ */
+ private boolean readParameters(ByteBuf in, ByteBufDataInputPlus inputPlus, int
parameterCount, Map<ParameterType, Object> parameters) throws IOException
+ {
+ // makes the assumption that map.size() is a constant time function (HashMap.size()
is)
+ while (parameters.size() < parameterCount)
+ {
+ if (!canReadNextParam(in))
+ return false;
+
+ String key = DataInputStream.readUTF(inputPlus);
+ ParameterType parameterType = ParameterType.byName.get(key);
+ byte[] value = new byte[in.readInt()];
+ in.readBytes(value);
+ try (DataInputBuffer buffer = new DataInputBuffer(value))
+ {
+ parameters.put(parameterType, parameterType.serializer.deserialize(buffer,
messagingVersion));
+ }
+ }
+
+ return true;
+ }
+
+ /**
+ * Determine if we can read the next parameter from the {@link ByteBuf}. This
method will *always* set the {@code in}
+ * readIndex back to where it was when this method was invoked.
+ * <p>
+ * NOTE: this function would be sooo much simpler if we included a parameters
length int in the messaging format,
+ * instead of checking the remaining readable bytes for each field as we're parsing
it. c'est la vie ...
+ */
+ @VisibleForTesting
+ boolean canReadNextParam(ByteBuf in)
+ {
+ in.markReaderIndex();
+ // capture the readableBytes value here to avoid all the virtual function
calls.
+ // subtract 6 as we know we'll be reading a short and an int (for the utf
and value lengths).
+ final int minimumBytesRequired = 6;
+ int readableBytes = in.readableBytes() - minimumBytesRequired;
+ if (readableBytes < 0)
+ return false;
+
+ // this is a tad invasive, but since we know the UTF string is prefaced with
a 2-byte length,
+ // read that to make sure we have enough bytes to read the string itself.
+ short strLen = in.readShort();
+ // check if we can read that many bytes for the UTF
+ if (strLen > readableBytes)
+ {
+ in.resetReaderIndex();
+ return false;
+ }
+ in.skipBytes(strLen);
+ readableBytes -= strLen;
+
+ // check if we can read the value length
+ if (readableBytes < PARAMETERS_VALUE_SIZE_LENGTH)
+ {
+ in.resetReaderIndex();
+ return false;
+ }
+ int valueLength = in.readInt();
+ // check if we read that many bytes for the value
+ if (valueLength > readableBytes)
+ {
+ in.resetReaderIndex();
+ return false;
+ }
+
+ in.resetReaderIndex();
+ return true;
+ }
+
+ public void process(RebufferingByteBufDataInputPlus in) throws IOException
+ {
+ while (in.isOpen() && !in.isEmpty())
+ {
+ messageHeader = readPrefix(in);
+ messageHeader.from = CompactEndpointSerializationHelper.instance.deserialize(in,
messagingVersion);
+ messageHeader.verb = MessagingService.Verb.fromId(in.readInt());
+ messageHeader.parameterLength = in.readInt();
+ messageHeader.parameters = messageHeader.parameterLength == 0 ? Collections.emptyMap()
: new EnumMap<>(ParameterType.class);
+ if (messageHeader.parameterLength > 0)
+ readParameters(in, messageHeader.parameterLength, messageHeader.parameters);
+
+ messageHeader.payloadSize = in.readInt();
+ MessageIn<Object> messageIn = MessageIn.read(in, messagingVersion,
+ messageHeader.messageId,
messageHeader.constructionTime, messageHeader.from,
+ messageHeader.payloadSize,
messageHeader.verb, messageHeader.parameters);
+ if (messageIn != null)
+ messageConsumer.accept(messageIn, messageHeader.messageId);
+ }
+ }
+
+ private void readParameters(RebufferingByteBufDataInputPlus in, int parameterCount,
Map<ParameterType, Object> parameters) throws IOException
+ {
+ // makes the assumption that map.size() is a constant time function (HashMap.size()
is)
+ while (parameters.size() < parameterCount)
+ {
+ String key = DataInputStream.readUTF(in);
+ ParameterType parameterType = ParameterType.byName.get(key);
+ int valueLength = in.readInt();
--- End diff --
Unused variable.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: pr-unsubscribe@cassandra.apache.org
For additional commands, e-mail: pr-help@cassandra.apache.org
|