// Copyright (c) Facebook, Inc. and its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "rsocket/statemachine/StreamsWriter.h" #include "rsocket/RSocketStats.h" #include "rsocket/framing/FrameSerializer.h" namespace rsocket { void StreamsWriterImpl::outputFrameOrEnqueue( std::unique_ptr frame) { if (shouldQueue()) { enqueuePendingOutputFrame(std::move(frame)); } else { outputFrame(std::move(frame)); } } void StreamsWriterImpl::sendPendingFrames() { // We are free to try to send frames again. Not all frames might be sent if // the connection breaks, the rest of them will queue up again. auto frames = consumePendingOutputFrames(); for (auto& frame : frames) { outputFrameOrEnqueue(std::move(frame)); } } void StreamsWriterImpl::enqueuePendingOutputFrame( std::unique_ptr frame) { auto const length = frame->computeChainDataLength(); stats().streamBufferChanged(1, static_cast(length)); pendingSize_ += length; pendingOutputFrames_.push_back(std::move(frame)); } std::deque> StreamsWriterImpl::consumePendingOutputFrames() { if (auto const numFrames = pendingOutputFrames_.size()) { stats().streamBufferChanged( -static_cast(numFrames), -static_cast(pendingSize_)); pendingSize_ = 0; } return std::move(pendingOutputFrames_); } void StreamsWriterImpl::writeNewStream( StreamId streamId, StreamType streamType, uint32_t initialRequestN, Payload payload) { // for simplicity, require that sent buffers don't consist of chains writeFragmented( [&](Payload p, FrameFlags flags) { switch (streamType) { case StreamType::CHANNEL: outputFrameOrEnqueue( serializer().serializeOut(Frame_REQUEST_CHANNEL( streamId, flags, initialRequestN, std::move(p)))); break; case StreamType::STREAM: outputFrameOrEnqueue(serializer().serializeOut(Frame_REQUEST_STREAM( streamId, flags, initialRequestN, std::move(p)))); break; case StreamType::REQUEST_RESPONSE: outputFrameOrEnqueue(serializer().serializeOut( Frame_REQUEST_RESPONSE(streamId, flags, std::move(p)))); break; case StreamType::FNF: outputFrameOrEnqueue(serializer().serializeOut( Frame_REQUEST_FNF(streamId, flags, std::move(p)))); break; default: CHECK(false) << "invalid stream type " << toString(streamType); } }, streamId, FrameFlags::EMPTY_, std::move(payload)); } void StreamsWriterImpl::writeRequestN(Frame_REQUEST_N&& frame) { outputFrameOrEnqueue(serializer().serializeOut(std::move(frame))); } void StreamsWriterImpl::writeCancel(Frame_CANCEL&& frame) { outputFrameOrEnqueue(serializer().serializeOut(std::move(frame))); } void StreamsWriterImpl::writePayload(Frame_PAYLOAD&& f) { Frame_PAYLOAD frame = std::move(f); auto const streamId = frame.header_.streamId; auto const initialFlags = frame.header_.flags; writeFragmented( [this, streamId](Payload p, FrameFlags flags) { outputFrameOrEnqueue(serializer().serializeOut( Frame_PAYLOAD(streamId, flags, std::move(p)))); }, streamId, initialFlags, std::move(frame.payload_)); } void StreamsWriterImpl::writeError(Frame_ERROR&& frame) { // TODO: implement fragmentation for writeError as well outputFrameOrEnqueue(serializer().serializeOut(std::move(frame))); } // The max amount of user data transmitted per frame - eg the size // of the data and metadata combined, plus the size of the frame header. // This assumes that the frame header will never be more than 512 bytes in // size. A CHECK in FrameTransportImpl enforces this. The idea is that // 16M is so much larger than the ~500 bytes possibly wasted that it won't // be noticeable (0.003% wasted at most) constexpr size_t GENEROUS_MAX_FRAME_SIZE = 0xFFFFFF - 512; // writeFragmented takes a `payload` and splits it up into chunks which // are sent as fragmented requests. The first fragmented payload is // given to writeInitialFrame, which is expected to write the initial // "REQUEST_" or "PAYLOAD" frame of a stream or response. writeFragmented // then writes the rest of the frames as payloads. // // writeInitialFrame // - called with the payload of the first frame to send, and any additional // flags (eg, addFlags with FOLLOWS, if there are more frames to write) // streamId // - The stream ID to write additional fragments with // addFlags // - All flags that writeInitialFrame wants to write the first frame with, // and all flags that subsequent fragmented payloads will be sent with // payload // - The unsplit payload to send, possibly in multiple fragments template void StreamsWriterImpl::writeFragmented( WriteInitialFrame writeInitialFrame, StreamId const streamId, FrameFlags const addFlags, Payload payload) { folly::IOBufQueue metaQueue{folly::IOBufQueue::cacheChainLength()}; folly::IOBufQueue dataQueue{folly::IOBufQueue::cacheChainLength()}; // have to keep track of "did the full payload even have a metadata", because // the rsocket protocol makes a distinction between a zero-length metadata // and a null metadata. bool const haveNonNullMeta = !!payload.metadata; metaQueue.append(std::move(payload.metadata)); dataQueue.append(std::move(payload.data)); bool isFirstFrame = true; while (true) { Payload sendme; // chew off some metadata (splitAtMost will never return a null pointer, // safe to compute length on it always) if (haveNonNullMeta) { sendme.metadata = metaQueue.splitAtMost(GENEROUS_MAX_FRAME_SIZE); DCHECK_GE( GENEROUS_MAX_FRAME_SIZE, sendme.metadata->computeChainDataLength()); } sendme.data = dataQueue.splitAtMost( GENEROUS_MAX_FRAME_SIZE - (haveNonNullMeta ? sendme.metadata->computeChainDataLength() : 0)); auto const metaLeft = metaQueue.chainLength(); auto const dataLeft = dataQueue.chainLength(); auto const moreFragments = metaLeft || dataLeft; auto const flags = (moreFragments ? FrameFlags::FOLLOWS : FrameFlags::EMPTY_) | addFlags; if (isFirstFrame) { isFirstFrame = false; writeInitialFrame(std::move(sendme), flags); } else { outputFrameOrEnqueue(serializer().serializeOut( Frame_PAYLOAD(streamId, flags, std::move(sendme)))); } if (!moreFragments) { break; } } } } // namespace rsocket