// Copyright (c) Facebook, Inc. and its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "rsocket/statemachine/RSocketStateMachine.h" #include #include #include #include #include #include #include "rsocket/DuplexConnection.h" #include "rsocket/RSocketConnectionEvents.h" #include "rsocket/RSocketParameters.h" #include "rsocket/RSocketResponder.h" #include "rsocket/RSocketStats.h" #include "rsocket/framing/Frame.h" #include "rsocket/framing/FrameSerializer.h" #include "rsocket/framing/FrameTransportImpl.h" #include "rsocket/internal/ClientResumeStatusCallback.h" #include "rsocket/internal/ScheduledSubscriber.h" #include "rsocket/internal/WarmResumeManager.h" #include "rsocket/statemachine/ChannelRequester.h" #include "rsocket/statemachine/ChannelResponder.h" #include "rsocket/statemachine/FireAndForgetResponder.h" #include "rsocket/statemachine/RequestResponseRequester.h" #include "rsocket/statemachine/RequestResponseResponder.h" #include "rsocket/statemachine/StreamRequester.h" #include "rsocket/statemachine/StreamResponder.h" #include "rsocket/statemachine/StreamStateMachineBase.h" #include "yarpl/flowable/Subscription.h" #include "yarpl/single/SingleSubscriptions.h" namespace rsocket { namespace { void disconnectError( std::shared_ptr> subscriber) { std::runtime_error exn{"RSocket connection is disconnected or closed"}; subscriber->onSubscribe(yarpl::flowable::Subscription::create()); subscriber->onError(std::move(exn)); } void disconnectError( std::shared_ptr> observer) { auto exn = folly::make_exception_wrapper( "RSocket connection is disconnected or closed"); observer->onSubscribe(yarpl::single::SingleSubscriptions::empty()); observer->onError(std::move(exn)); } } // namespace RSocketStateMachine::RSocketStateMachine( std::shared_ptr requestResponder, std::unique_ptr keepaliveTimer, RSocketMode mode, std::shared_ptr stats, std::shared_ptr connectionEvents, std::shared_ptr resumeManager, std::shared_ptr coldResumeHandler) : RSocketStateMachine( std::make_shared( std::move(requestResponder)), std::move(keepaliveTimer), mode, std::move(stats), std::move(connectionEvents), std::move(resumeManager), std::move(coldResumeHandler)) {} RSocketStateMachine::RSocketStateMachine( std::shared_ptr requestResponder, std::unique_ptr keepaliveTimer, RSocketMode mode, std::shared_ptr stats, std::shared_ptr connectionEvents, std::shared_ptr resumeManager, std::shared_ptr coldResumeHandler) : mode_{mode}, stats_{stats ? stats : RSocketStats::noop()}, // Streams initiated by a client MUST use odd-numbered and streams // initiated by the server MUST use even-numbered stream identifiers nextStreamId_(mode == RSocketMode::CLIENT ? 1 : 2), resumeManager_(std::move(resumeManager)), requestResponder_{std::move(requestResponder)}, keepaliveTimer_{std::move(keepaliveTimer)}, coldResumeHandler_{std::move(coldResumeHandler)}, connectionEvents_{connectionEvents} { CHECK(resumeManager_) << "provide ResumeManager::makeEmpty() instead of nullptr"; // We deliberately do not "open" input or output to avoid having c'tor on the // stack when processing any signals from the connection. See ::connect and // ::onSubscribe. CHECK(requestResponder_); stats_->socketCreated(); VLOG(2) << "Creating RSocketStateMachine"; } RSocketStateMachine::~RSocketStateMachine() { // this destructor can be called from a different thread because the stream // automatons destroyed on different threads can be the last ones referencing // this. VLOG(3) << "~RSocketStateMachine"; // We rely on SubscriptionPtr and SubscriberPtr to dispatch appropriate // terminal signals. DCHECK(!resumeCallback_); DCHECK(isDisconnected()); // the instance should be closed by via // close method } void RSocketStateMachine::setResumable(bool resumable) { // We should set this flag before we are connected DCHECK(isDisconnected()); isResumable_ = resumable; } void RSocketStateMachine::connectServer( std::shared_ptr frameTransport, const SetupParameters& setupParams) { setResumable(setupParams.resumable); setProtocolVersionOrThrow(setupParams.protocolVersion, frameTransport); connect(std::move(frameTransport)); sendPendingFrames(); } bool RSocketStateMachine::resumeServer( std::shared_ptr frameTransport, const ResumeParameters& resumeParams) { const folly::Optional clientAvailable = (resumeParams.clientPosition == kUnspecifiedResumePosition) ? folly::none : folly::make_optional( resumeManager_->impliedPosition() - resumeParams.clientPosition); const int64_t serverAvailable = resumeManager_->lastSentPosition() - resumeManager_->firstSentPosition(); const int64_t serverDelta = resumeManager_->lastSentPosition() - resumeParams.serverPosition; if (frameTransport) { stats_->socketDisconnected(); } closeFrameTransport( std::runtime_error{"Connection being resumed, dropping old connection"}); setProtocolVersionOrThrow(resumeParams.protocolVersion, frameTransport); connect(std::move(frameTransport)); const auto result = resumeFromPositionOrClose( resumeParams.serverPosition, resumeParams.clientPosition); stats_->serverResume( clientAvailable, serverAvailable, serverDelta, result ? RSocketStats::ResumeOutcome::SUCCESS : RSocketStats::ResumeOutcome::FAILURE); return result; } void RSocketStateMachine::connectClient( std::shared_ptr transport, SetupParameters params) { auto const version = params.protocolVersion == ProtocolVersion::Unknown ? ProtocolVersion::Latest : params.protocolVersion; setProtocolVersionOrThrow(version, transport); setResumable(params.resumable); Frame_SETUP frame( (params.resumable ? FrameFlags::RESUME_ENABLE : FrameFlags::EMPTY_) | (params.payload.metadata ? FrameFlags::METADATA : FrameFlags::EMPTY_), version.major, version.minor, getKeepaliveTime(), Frame_SETUP::kMaxLifetime, std::move(params.token), std::move(params.metadataMimeType), std::move(params.dataMimeType), std::move(params.payload)); // TODO: when the server returns back that it doesn't support resumability, we // should retry without resumability VLOG(3) << "Out: " << frame; connect(std::move(transport)); // making sure we send setup frame first outputFrame(frameSerializer_->serializeOut(std::move(frame))); // then the rest of the cached frames will be sent sendPendingFrames(); } void RSocketStateMachine::resumeClient( ResumeIdentificationToken token, std::shared_ptr transport, std::unique_ptr resumeCallback, ProtocolVersion version) { // Cold-resumption. Set the serializer. if (!frameSerializer_) { CHECK(coldResumeHandler_); coldResumeInProgress_ = true; } setProtocolVersionOrThrow( version == ProtocolVersion::Unknown ? ProtocolVersion::Latest : version, transport); Frame_RESUME resumeFrame( std::move(token), resumeManager_->impliedPosition(), resumeManager_->firstSentPosition(), frameSerializer_->protocolVersion()); VLOG(3) << "Out: " << resumeFrame; // Disconnect a previous client if there is one. disconnect(std::runtime_error{"Resuming client on a different connection"}); setResumable(true); reconnect(std::move(transport), std::move(resumeCallback)); outputFrame(frameSerializer_->serializeOut(std::move(resumeFrame))); } void RSocketStateMachine::connect(std::shared_ptr transport) { VLOG(2) << "Connecting to transport " << transport.get(); CHECK(isDisconnected()); CHECK(transport); // Keep a reference to the argument, make sure the instance survives until // setFrameProcessor() returns. There can be terminating signals processed in // that call which will nullify frameTransport_. frameTransport_ = transport; CHECK(frameSerializer_); frameSerializer_->preallocateFrameSizeField() = transport->isConnectionFramed(); if (connectionEvents_) { connectionEvents_->onConnected(); } // Keep a reference to stats, as processing frames might close this instance. auto const stats = stats_; frameTransport_->setFrameProcessor(shared_from_this()); stats->socketConnected(); } void RSocketStateMachine::sendPendingFrames() { DCHECK(!resumeCallback_); StreamsWriterImpl::sendPendingFrames(); // TODO: turn on only after setup frame was received if (keepaliveTimer_) { keepaliveTimer_->start(shared_from_this()); } } void RSocketStateMachine::disconnect(folly::exception_wrapper ex) { VLOG(2) << "Disconnecting transport"; if (isDisconnected()) { return; } if (connectionEvents_) { connectionEvents_->onDisconnected(ex); } closeFrameTransport(std::move(ex)); if (connectionEvents_) { connectionEvents_->onStreamsPaused(); } stats_->socketDisconnected(); } void RSocketStateMachine::close( folly::exception_wrapper ex, StreamCompletionSignal signal) { if (isClosed()) { return; } isClosed_ = true; stats_->socketClosed(signal); VLOG(6) << "close"; if (auto resumeCallback = std::move(resumeCallback_)) { resumeCallback->onResumeError( ConnectionException(ex ? ex.get_exception()->what() : "RS closing")); } closeStreams(signal); closeFrameTransport(ex); if (auto connectionEvents = std::move(connectionEvents_)) { connectionEvents->onClosed(std::move(ex)); } if (closeCallback_) { closeCallback_->remove(*this); } } void RSocketStateMachine::closeFrameTransport(folly::exception_wrapper ex) { if (isDisconnected()) { DCHECK(!resumeCallback_); return; } // Stop scheduling keepalives since the socket is now disconnected if (keepaliveTimer_) { keepaliveTimer_->stop(); } if (auto resumeCallback = std::move(resumeCallback_)) { resumeCallback->onResumeError(ConnectionException( ex ? ex.get_exception()->what() : "connection closing")); } // Echo the exception to the frameTransport only if the frameTransport started // closing with error. Otherwise we sent some error frame over the wire and // we are closing the transport cleanly. if (frameTransport_) { frameTransport_->close(); frameTransport_ = nullptr; } } void RSocketStateMachine::disconnectOrCloseWithError(Frame_ERROR&& errorFrame) { if (isResumable_) { std::runtime_error exn{errorFrame.payload_.moveDataToString()}; disconnect(std::move(exn)); } else { closeWithError(std::move(errorFrame)); } } void RSocketStateMachine::closeWithError(Frame_ERROR&& error) { VLOG(3) << "closeWithError " << error.payload_.data->cloneAsValue().moveToFbString(); StreamCompletionSignal signal; switch (error.errorCode_) { case ErrorCode::INVALID_SETUP: signal = StreamCompletionSignal::INVALID_SETUP; break; case ErrorCode::UNSUPPORTED_SETUP: signal = StreamCompletionSignal::UNSUPPORTED_SETUP; break; case ErrorCode::REJECTED_SETUP: signal = StreamCompletionSignal::REJECTED_SETUP; break; case ErrorCode::CONNECTION_ERROR: // StreamCompletionSignal::CONNECTION_ERROR is reserved for // frameTransport errors // ErrorCode::CONNECTION_ERROR is a normal Frame_ERROR error code which has // nothing to do with frameTransport case ErrorCode::APPLICATION_ERROR: case ErrorCode::REJECTED: case ErrorCode::RESERVED: case ErrorCode::CANCELED: case ErrorCode::INVALID: default: signal = StreamCompletionSignal::ERROR; } std::runtime_error exn{error.payload_.cloneDataToString()}; if (frameSerializer_) { outputFrameOrEnqueue(frameSerializer_->serializeOut(std::move(error))); } close(std::move(exn), signal); } void RSocketStateMachine::reconnect( std::shared_ptr newFrameTransport, std::unique_ptr resumeCallback) { CHECK(newFrameTransport); CHECK(resumeCallback); CHECK(!resumeCallback_); CHECK(isResumable_); CHECK(mode_ == RSocketMode::CLIENT); // TODO: output frame buffer should not be written to the new connection until // we receive resume ok resumeCallback_ = std::move(resumeCallback); connect(std::move(newFrameTransport)); } void RSocketStateMachine::requestStream( Payload request, std::shared_ptr> responseSink) { if (isDisconnected()) { disconnectError(std::move(responseSink)); return; } auto const streamId = getNextStreamId(); auto stateMachine = std::make_shared( shared_from_this(), streamId, std::move(request)); const auto result = streams_.emplace(streamId, stateMachine); DCHECK(result.second); stateMachine->subscribe(std::move(responseSink)); } std::shared_ptr> RSocketStateMachine::requestChannel( Payload request, bool hasInitialRequest, std::shared_ptr> responseSink) { if (isDisconnected()) { disconnectError(std::move(responseSink)); return nullptr; } auto const streamId = getNextStreamId(); std::shared_ptr stateMachine; if (hasInitialRequest) { stateMachine = std::make_shared( std::move(request), shared_from_this(), streamId); } else { stateMachine = std::make_shared(shared_from_this(), streamId); } const auto result = streams_.emplace(streamId, stateMachine); DCHECK(result.second); stateMachine->subscribe(std::move(responseSink)); return stateMachine; } void RSocketStateMachine::requestResponse( Payload request, std::shared_ptr> responseSink) { if (isDisconnected()) { disconnectError(std::move(responseSink)); return; } auto const streamId = getNextStreamId(); auto stateMachine = std::make_shared( shared_from_this(), streamId, std::move(request)); const auto result = streams_.emplace(streamId, stateMachine); DCHECK(result.second); stateMachine->subscribe(std::move(responseSink)); } void RSocketStateMachine::closeStreams(StreamCompletionSignal signal) { while (!streams_.empty()) { auto it = streams_.begin(); auto streamStateMachine = std::move(it->second); streams_.erase(it); streamStateMachine->endStream(signal); } } void RSocketStateMachine::processFrame(std::unique_ptr frame) { if (isClosed()) { VLOG(4) << "StateMachine has been closed. Discarding incoming frame"; return; } if (!ensureOrAutodetectFrameSerializer(*frame)) { constexpr auto msg = "Cannot detect protocol version"; closeWithError(Frame_ERROR::connectionError(msg)); return; } const auto frameType = frameSerializer_->peekFrameType(*frame); stats_->frameRead(frameType); const auto optStreamId = frameSerializer_->peekStreamId(*frame, false); if (!optStreamId) { constexpr auto msg = "Cannot decode stream ID"; closeWithError(Frame_ERROR::connectionError(msg)); return; } const auto frameLength = frame->computeChainDataLength(); const auto streamId = *optStreamId; handleFrame(streamId, frameType, std::move(frame)); resumeManager_->trackReceivedFrame( frameLength, frameType, streamId, getConsumerAllowance(streamId)); } void RSocketStateMachine::onTerminal(folly::exception_wrapper ex) { if (isResumable_) { disconnect(std::move(ex)); return; } const auto termSignal = ex ? StreamCompletionSignal::CONNECTION_ERROR : StreamCompletionSignal::CONNECTION_END; close(std::move(ex), termSignal); } void RSocketStateMachine::onKeepAliveFrame( ResumePosition resumePosition, std::unique_ptr data, bool keepAliveRespond) { resumeManager_->resetUpToPosition(resumePosition); if (mode_ == RSocketMode::SERVER) { if (keepAliveRespond) { sendKeepalive(FrameFlags::EMPTY_, std::move(data)); } else { closeWithError(Frame_ERROR::connectionError("keepalive without flag")); } } else { if (keepAliveRespond) { closeWithError(Frame_ERROR::connectionError( "client received keepalive with respond flag")); } else if (keepaliveTimer_) { keepaliveTimer_->keepaliveReceived(); } stats_->keepaliveReceived(); } } void RSocketStateMachine::onMetadataPushFrame( std::unique_ptr metadata) { requestResponder_->handleMetadataPush(std::move(metadata)); } void RSocketStateMachine::onResumeOkFrame(ResumePosition resumePosition) { if (!resumeCallback_) { constexpr auto msg = "Received RESUME_OK while not resuming"; closeWithError(Frame_ERROR::connectionError(msg)); return; } if (!resumeManager_->isPositionAvailable(resumePosition)) { auto const msg = folly::sformat( "Client cannot resume, server position {} is not available", resumePosition); closeWithError(Frame_ERROR::connectionError(msg)); return; } if (coldResumeInProgress_) { setNextStreamId(resumeManager_->getLargestUsedStreamId()); for (const auto& it : resumeManager_->getStreamResumeInfos()) { const auto streamId = it.first; const StreamResumeInfo& streamResumeInfo = it.second; if (streamResumeInfo.requester == RequestOriginator::LOCAL && streamResumeInfo.streamType == StreamType::STREAM) { auto subscriber = coldResumeHandler_->handleRequesterResumeStream( streamResumeInfo.streamToken, streamResumeInfo.consumerAllowance); auto stateMachine = std::make_shared( shared_from_this(), streamId, Payload()); // Set requested to true (since cold resumption) stateMachine->setRequested(streamResumeInfo.consumerAllowance); const auto result = streams_.emplace(streamId, stateMachine); DCHECK(result.second); stateMachine->subscribe( std::make_shared>( std::move(subscriber), *folly::EventBaseManager::get()->getEventBase())); } } coldResumeInProgress_ = false; } auto resumeCallback = std::move(resumeCallback_); resumeCallback->onResumeOk(); resumeFromPosition(resumePosition); } void RSocketStateMachine::onErrorFrame( StreamId streamId, ErrorCode errorCode, Payload payload) { if (streamId != 0) { if (!ensureNotInResumption()) { return; } // we ignore messages for streams which don't exist if (auto stateMachine = getStreamStateMachine(streamId)) { if (errorCode != ErrorCode::APPLICATION_ERROR) { // Encapsulate non-user errors with runtime_error, which is more // suitable for LOGging. stateMachine->handleError( std::runtime_error(payload.moveDataToString())); } else { // Don't expose user errors stateMachine->handleError(ErrorWithPayload(std::move(payload))); } } } else { // TODO: handle INVALID_SETUP, UNSUPPORTED_SETUP, REJECTED_SETUP if ((errorCode == ErrorCode::CONNECTION_ERROR || errorCode == ErrorCode::REJECTED_RESUME) && resumeCallback_) { auto resumeCallback = std::move(resumeCallback_); resumeCallback->onResumeError( ResumptionException(payload.cloneDataToString())); // fall through } close( std::runtime_error(payload.moveDataToString()), StreamCompletionSignal::ERROR); } } void RSocketStateMachine::onSetupFrame() { // this should be processed in SetupResumeAcceptor onUnexpectedFrame(0); } void RSocketStateMachine::onResumeFrame() { // this should be processed in SetupResumeAcceptor onUnexpectedFrame(0); } void RSocketStateMachine::onReservedFrame() { onUnexpectedFrame(0); } void RSocketStateMachine::onLeaseFrame() { onUnexpectedFrame(0); } void RSocketStateMachine::onExtFrame() { onUnexpectedFrame(0); } void RSocketStateMachine::onUnexpectedFrame(StreamId streamId) { auto&& msg = folly::sformat("Unexpected frame for stream {}", streamId); closeWithError(Frame_ERROR::connectionError(msg)); } void RSocketStateMachine::handleFrame( StreamId streamId, FrameType frameType, std::unique_ptr payload) { switch (frameType) { case FrameType::KEEPALIVE: { Frame_KEEPALIVE frame; if (!deserializeFrameOrError(frame, std::move(payload))) { return; } VLOG(3) << mode_ << " In: " << frame; onKeepAliveFrame( frame.position_, std::move(frame.data_), !!(frame.header_.flags & FrameFlags::KEEPALIVE_RESPOND)); return; } case FrameType::METADATA_PUSH: { Frame_METADATA_PUSH frame; if (!deserializeFrameOrError(frame, std::move(payload))) { return; } VLOG(3) << mode_ << " In: " << frame; onMetadataPushFrame(std::move(frame.metadata_)); return; } case FrameType::RESUME_OK: { Frame_RESUME_OK frame; if (!deserializeFrameOrError(frame, std::move(payload))) { return; } VLOG(3) << mode_ << " In: " << frame; onResumeOkFrame(frame.position_); return; } case FrameType::ERROR: { Frame_ERROR frame; if (!deserializeFrameOrError(frame, std::move(payload))) { return; } VLOG(3) << mode_ << " In: " << frame; onErrorFrame(streamId, frame.errorCode_, std::move(frame.payload_)); return; } case FrameType::SETUP: onSetupFrame(); return; case FrameType::RESUME: onResumeFrame(); return; case FrameType::RESERVED: onReservedFrame(); return; case FrameType::LEASE: onLeaseFrame(); return; case FrameType::REQUEST_N: { Frame_REQUEST_N frameRequestN; if (!deserializeFrameOrError(frameRequestN, std::move(payload))) { return; } VLOG(3) << mode_ << " In: " << frameRequestN; onRequestNFrame(streamId, frameRequestN.requestN_); break; } case FrameType::CANCEL: { VLOG(3) << mode_ << " In: " << Frame_CANCEL(streamId); onCancelFrame(streamId); break; } case FrameType::PAYLOAD: { Frame_PAYLOAD framePayload; if (!deserializeFrameOrError(framePayload, std::move(payload))) { return; } VLOG(3) << mode_ << " In: " << framePayload; onPayloadFrame( streamId, std::move(framePayload.payload_), framePayload.header_.flagsFollows(), framePayload.header_.flagsComplete(), framePayload.header_.flagsNext()); break; } case FrameType::REQUEST_CHANNEL: { Frame_REQUEST_CHANNEL frame; if (!deserializeFrameOrError(frame, std::move(payload))) { return; } VLOG(3) << mode_ << " In: " << frame; onRequestChannelFrame( streamId, frame.requestN_, std::move(frame.payload_), frame.header_.flagsComplete(), frame.header_.flagsNext(), frame.header_.flagsFollows()); break; } case FrameType::REQUEST_STREAM: { Frame_REQUEST_STREAM frame; if (!deserializeFrameOrError(frame, std::move(payload))) { return; } VLOG(3) << mode_ << " In: " << frame; onRequestStreamFrame( streamId, frame.requestN_, std::move(frame.payload_), frame.header_.flagsFollows()); break; } case FrameType::REQUEST_RESPONSE: { Frame_REQUEST_RESPONSE frame; if (!deserializeFrameOrError(frame, std::move(payload))) { return; } VLOG(3) << mode_ << " In: " << frame; onRequestResponseFrame( streamId, std::move(frame.payload_), frame.header_.flagsFollows()); break; } case FrameType::REQUEST_FNF: { Frame_REQUEST_FNF frame; if (!deserializeFrameOrError(frame, std::move(payload))) { return; } VLOG(3) << mode_ << " In: " << frame; onFireAndForgetFrame( streamId, std::move(frame.payload_), frame.header_.flagsFollows()); break; } case FrameType::EXT: onExtFrame(); return; default: { stats_->unknownFrameReceived(); // per rsocket spec, we will ignore any other unknown frames return; } } } std::shared_ptr RSocketStateMachine::getStreamStateMachine(StreamId streamId) { const auto&& it = streams_.find(streamId); if (it == streams_.end()) { return nullptr; } // we are purposely making a copy of the reference here to avoid problems with // lifetime of the stateMachine when a terminating signal is delivered which // will cause the stateMachine to be destroyed while in one of its methods return it->second; } bool RSocketStateMachine::ensureNotInResumption() { if (resumeCallback_) { // during the time when we are resuming we are can't receive any other // than connection level frames which drives the resumption // TODO(lehecka): this assertion should be handled more elegantly using // different state machine constexpr auto msg = "Received stream frame while resuming"; LOG(ERROR) << msg; closeWithError(Frame_ERROR::connectionError(msg)); return false; } return true; } void RSocketStateMachine::onRequestNFrame( StreamId streamId, uint32_t requestN) { if (!ensureNotInResumption()) { return; } // we ignore messages for streams which don't exist if (auto stateMachine = getStreamStateMachine(streamId)) { stateMachine->handleRequestN(requestN); } } void RSocketStateMachine::onCancelFrame(StreamId streamId) { if (!ensureNotInResumption()) { return; } // we ignore messages for streams which don't exist if (auto stateMachine = getStreamStateMachine(streamId)) { stateMachine->handleCancel(); } } void RSocketStateMachine::onPayloadFrame( StreamId streamId, Payload payload, bool flagsFollows, bool flagsComplete, bool flagsNext) { if (!ensureNotInResumption()) { return; } // we ignore messages for streams which don't exist if (auto stateMachine = getStreamStateMachine(streamId)) { stateMachine->handlePayload( std::move(payload), flagsComplete, flagsNext, flagsFollows); } } void RSocketStateMachine::onRequestStreamFrame( StreamId streamId, uint32_t requestN, Payload payload, bool flagsFollows) { if (!ensureNotInResumption() || !isNewStreamId(streamId)) { return; } auto stateMachine = std::make_shared(shared_from_this(), streamId, requestN); const auto result = streams_.emplace(streamId, stateMachine); DCHECK(result.second); // ensured by calling isNewStreamId stateMachine->handlePayload(std::move(payload), false, false, flagsFollows); } void RSocketStateMachine::onRequestChannelFrame( StreamId streamId, uint32_t requestN, Payload payload, bool flagsComplete, bool flagsNext, bool flagsFollows) { if (!ensureNotInResumption() || !isNewStreamId(streamId)) { return; } auto stateMachine = std::make_shared( shared_from_this(), streamId, requestN); const auto result = streams_.emplace(streamId, stateMachine); DCHECK(result.second); // ensured by calling isNewStreamId stateMachine->handlePayload( std::move(payload), flagsComplete, flagsNext, flagsFollows); } void RSocketStateMachine::onRequestResponseFrame( StreamId streamId, Payload payload, bool flagsFollows) { if (!ensureNotInResumption() || !isNewStreamId(streamId)) { return; } auto stateMachine = std::make_shared(shared_from_this(), streamId); const auto result = streams_.emplace(streamId, stateMachine); DCHECK(result.second); // ensured by calling isNewStreamId stateMachine->handlePayload(std::move(payload), false, false, flagsFollows); } void RSocketStateMachine::onFireAndForgetFrame( StreamId streamId, Payload payload, bool flagsFollows) { if (!ensureNotInResumption() || !isNewStreamId(streamId)) { return; } auto stateMachine = std::make_shared(shared_from_this(), streamId); const auto result = streams_.emplace(streamId, stateMachine); DCHECK(result.second); // ensured by calling isNewStreamId stateMachine->handlePayload(std::move(payload), false, false, flagsFollows); } bool RSocketStateMachine::isNewStreamId(StreamId streamId) { if (frameSerializer_->protocolVersion() > ProtocolVersion{0, 0} && !registerNewPeerStreamId(streamId)) { return false; } return true; } std::shared_ptr> RSocketStateMachine::onNewStreamReady( StreamId streamId, StreamType streamType, Payload payload, std::shared_ptr> response) { if (coldResumeHandler_ && streamType != StreamType::FNF) { auto streamToken = coldResumeHandler_->generateStreamToken(payload, streamId, streamType); resumeManager_->onStreamOpen( streamId, RequestOriginator::REMOTE, streamToken, streamType); } switch (streamType) { case StreamType::CHANNEL: return requestResponder_->handleRequestChannel( std::move(payload), streamId, std::move(response)); case StreamType::STREAM: requestResponder_->handleRequestStream( std::move(payload), streamId, std::move(response)); return nullptr; case StreamType::REQUEST_RESPONSE: // the other overload method should be called CHECK(false); folly::assume_unreachable(); case StreamType::FNF: requestResponder_->handleFireAndForget(std::move(payload), streamId); return nullptr; default: CHECK(false) << "unknown value: " << streamType; folly::assume_unreachable(); } } void RSocketStateMachine::onNewStreamReady( StreamId streamId, StreamType streamType, Payload payload, std::shared_ptr> response) { CHECK(streamType == StreamType::REQUEST_RESPONSE); if (coldResumeHandler_) { auto streamToken = coldResumeHandler_->generateStreamToken(payload, streamId, streamType); resumeManager_->onStreamOpen( streamId, RequestOriginator::REMOTE, streamToken, streamType); } requestResponder_->handleRequestResponse( std::move(payload), streamId, std::move(response)); } void RSocketStateMachine::sendKeepalive(std::unique_ptr data) { sendKeepalive(FrameFlags::KEEPALIVE_RESPOND, std::move(data)); } void RSocketStateMachine::sendKeepalive( FrameFlags flags, std::unique_ptr data) { Frame_KEEPALIVE pingFrame( flags, resumeManager_->impliedPosition(), std::move(data)); VLOG(3) << mode_ << " Out: " << pingFrame; outputFrameOrEnqueue(frameSerializer_->serializeOut(std::move(pingFrame))); stats_->keepaliveSent(); } bool RSocketStateMachine::isPositionAvailable(ResumePosition position) const { return resumeManager_->isPositionAvailable(position); } bool RSocketStateMachine::resumeFromPositionOrClose( ResumePosition serverPosition, ResumePosition clientPosition) { DCHECK(!resumeCallback_); DCHECK(!isDisconnected()); DCHECK(mode_ == RSocketMode::SERVER); const bool clientPositionExist = (clientPosition == kUnspecifiedResumePosition) || clientPosition <= resumeManager_->impliedPosition(); if (clientPositionExist && resumeManager_->isPositionAvailable(serverPosition)) { Frame_RESUME_OK resumeOkFrame{resumeManager_->impliedPosition()}; VLOG(3) << "Out: " << resumeOkFrame; frameTransport_->outputFrameOrDrop( frameSerializer_->serializeOut(std::move(resumeOkFrame))); resumeFromPosition(serverPosition); return true; } auto const msg = folly::to( "Cannot resume server, client lastServerPosition=", serverPosition, " firstClientPosition=", clientPosition, " is not available. Last reset position is ", resumeManager_->firstSentPosition()); closeWithError(Frame_ERROR::connectionError(msg)); return false; } void RSocketStateMachine::resumeFromPosition(ResumePosition position) { DCHECK(!resumeCallback_); DCHECK(!isDisconnected()); DCHECK(resumeManager_->isPositionAvailable(position)); if (connectionEvents_) { connectionEvents_->onStreamsResumed(); } resumeManager_->sendFramesFromPosition(position, *frameTransport_); auto frames = consumePendingOutputFrames(); for (auto& frame : frames) { outputFrameOrEnqueue(std::move(frame)); } if (!isDisconnected() && keepaliveTimer_) { keepaliveTimer_->start(shared_from_this()); } } bool RSocketStateMachine::shouldQueue() { // if we are resuming we cant send any frames until we receive RESUME_OK return isDisconnected() || resumeCallback_; } void RSocketStateMachine::fireAndForget(Payload request) { auto const streamId = getNextStreamId(); Frame_REQUEST_FNF frame{streamId, FrameFlags::EMPTY_, std::move(request)}; outputFrameOrEnqueue(frameSerializer_->serializeOut(std::move(frame))); } void RSocketStateMachine::metadataPush(std::unique_ptr metadata) { Frame_METADATA_PUSH metadataPushFrame{std::move(metadata)}; outputFrameOrEnqueue( frameSerializer_->serializeOut(std::move(metadataPushFrame))); } void RSocketStateMachine::outputFrame(std::unique_ptr frame) { DCHECK(!isDisconnected()); const auto frameType = frameSerializer_->peekFrameType(*frame); stats_->frameWritten(frameType); if (isResumable_) { auto streamIdPtr = frameSerializer_->peekStreamId(*frame, false); CHECK(streamIdPtr) << "Error in serialized frame."; resumeManager_->trackSentFrame( *frame, frameType, *streamIdPtr, getConsumerAllowance(*streamIdPtr)); } frameTransport_->outputFrameOrDrop(std::move(frame)); } uint32_t RSocketStateMachine::getKeepaliveTime() const { return keepaliveTimer_ ? static_cast(keepaliveTimer_->keepaliveTime().count()) : Frame_SETUP::kMaxKeepaliveTime; } bool RSocketStateMachine::isDisconnected() const { return !frameTransport_; } bool RSocketStateMachine::isClosed() const { return isClosed_; } void RSocketStateMachine::writeNewStream( StreamId streamId, StreamType streamType, uint32_t initialRequestN, Payload payload) { if (coldResumeHandler_ && streamType != StreamType::FNF) { const auto streamToken = coldResumeHandler_->generateStreamToken(payload, streamId, streamType); resumeManager_->onStreamOpen( streamId, RequestOriginator::LOCAL, streamToken, streamType); } StreamsWriterImpl::writeNewStream( streamId, streamType, initialRequestN, std::move(payload)); } void RSocketStateMachine::onStreamClosed(StreamId streamId) { streams_.erase(streamId); resumeManager_->onStreamClosed(streamId); } bool RSocketStateMachine::ensureOrAutodetectFrameSerializer( const folly::IOBuf& firstFrame) { if (frameSerializer_) { return true; } if (mode_ != RSocketMode::SERVER) { // this should never happen as clients are initized with FrameSerializer // instance DCHECK(false); return false; } auto serializer = FrameSerializer::createAutodetectedSerializer(firstFrame); if (!serializer) { LOG(ERROR) << "unable to detect protocol version"; return false; } VLOG(2) << "detected protocol version" << serializer->protocolVersion(); frameSerializer_ = std::move(serializer); frameSerializer_->preallocateFrameSizeField() = frameTransport_ && frameTransport_->isConnectionFramed(); return true; } size_t RSocketStateMachine::getConsumerAllowance(StreamId streamId) const { auto const it = streams_.find(streamId); return it != streams_.end() ? it->second->getConsumerAllowance() : 0; } void RSocketStateMachine::registerCloseCallback( RSocketStateMachine::CloseCallback* callback) { closeCallback_ = callback; } DuplexConnection* RSocketStateMachine::getConnection() { return frameTransport_ ? frameTransport_->getConnection() : nullptr; } void RSocketStateMachine::setProtocolVersionOrThrow( ProtocolVersion version, const std::shared_ptr& transport) { CHECK(version != ProtocolVersion::Unknown); // TODO(lehecka): this is a temporary guard to make sure the transport is // explicitly closed when exceptions are thrown. The right solution is to // automatically close duplex connection in the destructor when unique_ptr // is released auto transportGuard = folly::makeGuard([&] { transport->close(); }); if (frameSerializer_) { if (frameSerializer_->protocolVersion() != version) { // serializer is not interchangeable, it would screw up resumability throw std::runtime_error{"Protocol version mismatch"}; } } else { auto frameSerializer = FrameSerializer::createFrameSerializer(version); if (!frameSerializer) { throw std::runtime_error{"Invalid protocol version"}; } frameSerializer_ = std::move(frameSerializer); frameSerializer_->preallocateFrameSizeField() = frameTransport_ && frameTransport_->isConnectionFramed(); } transportGuard.dismiss(); } StreamId RSocketStateMachine::getNextStreamId() { constexpr auto limit = static_cast(std::numeric_limits::max() - 2); auto const streamId = nextStreamId_; if (streamId >= limit) { throw std::runtime_error{"Ran out of stream IDs"}; } CHECK_EQ(0, streams_.count(streamId)) << "Next stream ID already exists in the streams map"; nextStreamId_ += 2; return streamId; } void RSocketStateMachine::setNextStreamId(StreamId streamId) { nextStreamId_ = streamId + 2; } bool RSocketStateMachine::registerNewPeerStreamId(StreamId streamId) { DCHECK_NE(0, streamId); if (nextStreamId_ % 2 == streamId % 2) { // if this is an unknown stream to the socket and this socket is // generating such stream ids, it is an incoming frame on the stream which // no longer exist return false; } if (streamId <= lastPeerStreamId_) { // receiving frame for a stream which no longer exists return false; } lastPeerStreamId_ = streamId; return true; } bool RSocketStateMachine::hasStreams() const { return !streams_.empty(); } } // namespace rsocket