198 lines
7.0 KiB
C++
198 lines
7.0 KiB
C++
// Copyright (c) Facebook, Inc. and its affiliates.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "rsocket/statemachine/StreamsWriter.h"
|
|
|
|
#include "rsocket/RSocketStats.h"
|
|
#include "rsocket/framing/FrameSerializer.h"
|
|
|
|
namespace rsocket {
|
|
|
|
void StreamsWriterImpl::outputFrameOrEnqueue(
|
|
std::unique_ptr<folly::IOBuf> frame) {
|
|
if (shouldQueue()) {
|
|
enqueuePendingOutputFrame(std::move(frame));
|
|
} else {
|
|
outputFrame(std::move(frame));
|
|
}
|
|
}
|
|
|
|
void StreamsWriterImpl::sendPendingFrames() {
|
|
// We are free to try to send frames again. Not all frames might be sent if
|
|
// the connection breaks, the rest of them will queue up again.
|
|
auto frames = consumePendingOutputFrames();
|
|
for (auto& frame : frames) {
|
|
outputFrameOrEnqueue(std::move(frame));
|
|
}
|
|
}
|
|
|
|
void StreamsWriterImpl::enqueuePendingOutputFrame(
|
|
std::unique_ptr<folly::IOBuf> frame) {
|
|
auto const length = frame->computeChainDataLength();
|
|
stats().streamBufferChanged(1, static_cast<int64_t>(length));
|
|
pendingSize_ += length;
|
|
pendingOutputFrames_.push_back(std::move(frame));
|
|
}
|
|
|
|
std::deque<std::unique_ptr<folly::IOBuf>>
StreamsWriterImpl::consumePendingOutputFrames() {
  // Detaches every queued frame and returns them in queue order. On return
  // the pending queue is empty, pendingSize_ is zero, and the stats have been
  // told that all previously buffered frames/bytes were released.
  if (auto const numFrames = pendingOutputFrames_.size()) {
    stats().streamBufferChanged(
        -static_cast<int64_t>(numFrames), -static_cast<int64_t>(pendingSize_));
  }
  pendingSize_ = 0;

  // Swap with a fresh deque instead of returning
  // std::move(pendingOutputFrames_): a moved-from standard container is only
  // guaranteed to be "valid but unspecified", while swap guarantees the
  // member is left empty. Callers such as sendPendingFrames() may enqueue
  // back into this same member while iterating the returned frames.
  std::deque<std::unique_ptr<folly::IOBuf>> frames;
  frames.swap(pendingOutputFrames_);
  return frames;
}
|
|
|
|
void StreamsWriterImpl::writeNewStream(
    StreamId streamId,
    StreamType streamType,
    uint32_t initialRequestN,
    Payload payload) {
  // Opens a stream by writing its REQUEST_* frame. writeFragmented() may split
  // the payload (chained buffers included) into several fragments. This
  // callback serializes only the first fragment, choosing the request frame
  // type from streamType; writeFragmented() writes any remaining fragments
  // itself as PAYLOAD frames. Only the CHANNEL and STREAM request frames carry
  // initialRequestN.
  auto const writeRequestFrame = [this, streamId, streamType, initialRequestN](
                                     Payload firstFragment, FrameFlags flags) {
    switch (streamType) {
      case StreamType::CHANNEL:
        outputFrameOrEnqueue(serializer().serializeOut(Frame_REQUEST_CHANNEL(
            streamId, flags, initialRequestN, std::move(firstFragment))));
        return;
      case StreamType::STREAM:
        outputFrameOrEnqueue(serializer().serializeOut(Frame_REQUEST_STREAM(
            streamId, flags, initialRequestN, std::move(firstFragment))));
        return;
      case StreamType::REQUEST_RESPONSE:
        outputFrameOrEnqueue(serializer().serializeOut(Frame_REQUEST_RESPONSE(
            streamId, flags, std::move(firstFragment))));
        return;
      case StreamType::FNF:
        outputFrameOrEnqueue(serializer().serializeOut(
            Frame_REQUEST_FNF(streamId, flags, std::move(firstFragment))));
        return;
      default:
        // Any other StreamType is a programming error.
        CHECK(false) << "invalid stream type " << toString(streamType);
    }
  };

  writeFragmented(
      writeRequestFrame, streamId, FrameFlags::EMPTY_, std::move(payload));
}
|
|
|
|
void StreamsWriterImpl::writeRequestN(Frame_REQUEST_N&& frame) {
  // REQUEST_N frames are never fragmented: serialize and send (or queue).
  auto serialized = serializer().serializeOut(std::move(frame));
  outputFrameOrEnqueue(std::move(serialized));
}
|
|
|
|
void StreamsWriterImpl::writeCancel(Frame_CANCEL&& frame) {
  // CANCEL frames are never fragmented: serialize and send (or queue).
  auto serialized = serializer().serializeOut(std::move(frame));
  outputFrameOrEnqueue(std::move(serialized));
}
|
|
|
|
void StreamsWriterImpl::writePayload(Frame_PAYLOAD&& f) {
  // Take ownership of the caller's frame. Its payload may be split by
  // writeFragmented(); every fragment, the first included, goes out as a
  // PAYLOAD frame on the same stream, and the frame's original flags are
  // passed along as the flags to apply to each fragment.
  Frame_PAYLOAD owned(std::move(f));
  auto const streamId = owned.header_.streamId;
  auto const flagsFromCaller = owned.header_.flags;

  auto emitFragment = [this, streamId](
                          Payload fragment, FrameFlags fragmentFlags) {
    outputFrameOrEnqueue(serializer().serializeOut(
        Frame_PAYLOAD(streamId, fragmentFlags, std::move(fragment))));
  };

  writeFragmented(
      emitFragment, streamId, flagsFromCaller, std::move(owned.payload_));
}
|
|
|
|
void StreamsWriterImpl::writeError(Frame_ERROR&& frame) {
  // ERROR frames are currently sent as a single frame, whatever their size.
  // TODO: implement fragmentation for writeError as well
  auto serialized = serializer().serializeOut(std::move(frame));
  outputFrameOrEnqueue(std::move(serialized));
}
|
|
|
|
// Upper bound on the user data carried by one frame, i.e. the combined size of
// its metadata and data. The frame header is NOT included: 512 bytes are held
// back from 0xFFFFFF (2^24 - 1, about 16M) for it. This assumes the frame
// header never exceeds 512 bytes; a CHECK in FrameTransportImpl enforces that.
// Reserving a fixed 512 bytes wastes at most ~0.003% of a 16M frame, which is
// negligible.
constexpr size_t GENEROUS_MAX_FRAME_SIZE = ((size_t{1} << 24) - 1) - 512;
|
|
|
|
// writeFragmented takes a `payload` and splits it up into chunks which
|
|
// are sent as fragmented requests. The first fragmented payload is
|
|
// given to writeInitialFrame, which is expected to write the initial
|
|
// "REQUEST_" or "PAYLOAD" frame of a stream or response. writeFragmented
|
|
// then writes the rest of the frames as payloads.
|
|
//
|
|
// writeInitialFrame
|
|
// - called with the payload of the first frame to send, and any additional
|
|
// flags (eg, addFlags with FOLLOWS, if there are more frames to write)
|
|
// streamId
|
|
// - The stream ID to write additional fragments with
|
|
// addFlags
|
|
// - All flags that writeInitialFrame wants to write the first frame with,
|
|
// and all flags that subsequent fragmented payloads will be sent with
|
|
// payload
|
|
// - The unsplit payload to send, possibly in multiple fragments
|
|
template <typename WriteInitialFrame>
|
|
void StreamsWriterImpl::writeFragmented(
|
|
WriteInitialFrame writeInitialFrame,
|
|
StreamId const streamId,
|
|
FrameFlags const addFlags,
|
|
Payload payload) {
|
|
folly::IOBufQueue metaQueue{folly::IOBufQueue::cacheChainLength()};
|
|
folly::IOBufQueue dataQueue{folly::IOBufQueue::cacheChainLength()};
|
|
|
|
// have to keep track of "did the full payload even have a metadata", because
|
|
// the rsocket protocol makes a distinction between a zero-length metadata
|
|
// and a null metadata.
|
|
bool const haveNonNullMeta = !!payload.metadata;
|
|
metaQueue.append(std::move(payload.metadata));
|
|
dataQueue.append(std::move(payload.data));
|
|
|
|
bool isFirstFrame = true;
|
|
|
|
while (true) {
|
|
Payload sendme;
|
|
|
|
// chew off some metadata (splitAtMost will never return a null pointer,
|
|
// safe to compute length on it always)
|
|
if (haveNonNullMeta) {
|
|
sendme.metadata = metaQueue.splitAtMost(GENEROUS_MAX_FRAME_SIZE);
|
|
DCHECK_GE(
|
|
GENEROUS_MAX_FRAME_SIZE, sendme.metadata->computeChainDataLength());
|
|
}
|
|
sendme.data = dataQueue.splitAtMost(
|
|
GENEROUS_MAX_FRAME_SIZE -
|
|
(haveNonNullMeta ? sendme.metadata->computeChainDataLength() : 0));
|
|
|
|
auto const metaLeft = metaQueue.chainLength();
|
|
auto const dataLeft = dataQueue.chainLength();
|
|
auto const moreFragments = metaLeft || dataLeft;
|
|
auto const flags =
|
|
(moreFragments ? FrameFlags::FOLLOWS : FrameFlags::EMPTY_) | addFlags;
|
|
|
|
if (isFirstFrame) {
|
|
isFirstFrame = false;
|
|
writeInitialFrame(std::move(sendme), flags);
|
|
} else {
|
|
outputFrameOrEnqueue(serializer().serializeOut(
|
|
Frame_PAYLOAD(streamId, flags, std::move(sendme))));
|
|
}
|
|
|
|
if (!moreFragments) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace rsocket
|