/* * Copyright (c) Facebook, Inc. and its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include namespace folly { namespace detail { // Internal cancellation state object. class CancellationState { public: FOLLY_NODISCARD static CancellationStateSourcePtr create(); private: // Constructed initially with a CancellationSource reference count of 1. CancellationState() noexcept; ~CancellationState(); friend struct CancellationStateTokenDeleter; friend struct CancellationStateSourceDeleter; void removeTokenReference() noexcept; void removeSourceReference() noexcept; public: FOLLY_NODISCARD CancellationStateTokenPtr addTokenReference() noexcept; FOLLY_NODISCARD CancellationStateSourcePtr addSourceReference() noexcept; bool tryAddCallback( CancellationCallback* callback, bool incrementRefCountIfSuccessful) noexcept; void removeCallback(CancellationCallback* callback) noexcept; bool isCancellationRequested() const noexcept; bool canBeCancelled() const noexcept; // Request cancellation. // Return 'true' if cancellation had already been requested. // Return 'false' if this was the first thread to request // cancellation. bool requestCancellation() noexcept; private: void lock() noexcept; void unlock() noexcept; void unlockAndIncrementTokenCount() noexcept; void unlockAndDecrementTokenCount() noexcept; bool tryLockAndCancelUnlessCancelled() noexcept; template bool tryLock(Predicate predicate) noexcept; static bool canBeCancelled(std::uint64_t state) noexcept; static bool isCancellationRequested(std::uint64_t state) noexcept; static bool isLocked(std::uint64_t state) noexcept; static constexpr std::uint64_t kCancellationRequestedFlag = 1; static constexpr std::uint64_t kLockedFlag = 2; static constexpr std::uint64_t kTokenReferenceCountIncrement = 4; static constexpr std::uint64_t kSourceReferenceCountIncrement = std::uint64_t(1) << 33u; static constexpr std::uint64_t kTokenReferenceCountMask = (kSourceReferenceCountIncrement - 1u) - (kTokenReferenceCountIncrement - 1u); static constexpr std::uint64_t kSourceReferenceCountMask = std::numeric_limits::max() - (kSourceReferenceCountIncrement - 1u); // Bit 0 - Cancellation Requested // Bit 1 - Locked Flag // Bits 2-32 - Token reference count (max ~2 billion) // Bits 33-63 - Source reference count (max ~2 billion) std::atomic state_; CancellationCallback* head_; std::thread::id signallingThreadId_; }; inline void CancellationStateTokenDeleter::operator()( CancellationState* state) noexcept { state->removeTokenReference(); } inline void CancellationStateSourceDeleter::operator()( CancellationState* state) noexcept { state->removeSourceReference(); } } // namespace detail inline CancellationToken::CancellationToken( const CancellationToken& other) noexcept : state_() { if (other.state_) { state_ = other.state_->addTokenReference(); } } inline CancellationToken::CancellationToken(CancellationToken&& other) noexcept : state_(std::move(other.state_)) {} inline CancellationToken& CancellationToken::operator=( const CancellationToken& other) noexcept { if (state_ != other.state_) { CancellationToken temp{other}; swap(temp); } return *this; } inline CancellationToken& CancellationToken::operator=( CancellationToken&& other) noexcept { state_ = std::move(other.state_); return *this; } inline bool CancellationToken::isCancellationRequested() const noexcept { return state_ != nullptr && state_->isCancellationRequested(); } inline bool CancellationToken::canBeCancelled() const noexcept { return state_ != nullptr && state_->canBeCancelled(); } inline void CancellationToken::swap(CancellationToken& other) noexcept { std::swap(state_, other.state_); } inline CancellationToken::CancellationToken( detail::CancellationStateTokenPtr state) noexcept : state_(std::move(state)) {} inline bool operator==( const CancellationToken& a, const CancellationToken& b) noexcept { return a.state_ == b.state_; } inline bool operator!=( const CancellationToken& a, const CancellationToken& b) noexcept { return !(a == b); } inline CancellationSource::CancellationSource() : state_(detail::CancellationState::create()) {} inline CancellationSource::CancellationSource( const CancellationSource& other) noexcept : state_() { if (other.state_) { state_ = other.state_->addSourceReference(); } } inline CancellationSource::CancellationSource( CancellationSource&& other) noexcept : state_(std::move(other.state_)) {} inline CancellationSource& CancellationSource::operator=( const CancellationSource& other) noexcept { if (state_ != other.state_) { CancellationSource temp{other}; swap(temp); } return *this; } inline CancellationSource& CancellationSource::operator=( CancellationSource&& other) noexcept { state_ = std::move(other.state_); return *this; } inline CancellationSource CancellationSource::invalid() noexcept { return CancellationSource{detail::CancellationStateSourcePtr{}}; } inline bool CancellationSource::isCancellationRequested() const noexcept { return state_ != nullptr && state_->isCancellationRequested(); } inline bool CancellationSource::canBeCancelled() const noexcept { return state_ != nullptr; } inline CancellationToken CancellationSource::getToken() const noexcept { if (state_ != nullptr) { return CancellationToken{state_->addTokenReference()}; } return CancellationToken{}; } inline bool CancellationSource::requestCancellation() const noexcept { if (state_ != nullptr) { return state_->requestCancellation(); } return false; } inline void CancellationSource::swap(CancellationSource& other) noexcept { std::swap(state_, other.state_); } inline CancellationSource::CancellationSource( detail::CancellationStateSourcePtr&& state) noexcept : state_(std::move(state)) {} template < typename Callable, std::enable_if_t< std::is_constructible:: value, int>> inline CancellationCallback::CancellationCallback( CancellationToken&& ct, Callable&& callable) : next_(nullptr), prevNext_(nullptr), state_(nullptr), callback_(static_cast(callable)), destructorHasRunInsideCallback_(nullptr), callbackCompleted_(false) { if (ct.state_ != nullptr && ct.state_->tryAddCallback(this, false)) { state_ = ct.state_.release(); } } template < typename Callable, std::enable_if_t< std::is_constructible:: value, int>> inline CancellationCallback::CancellationCallback( const CancellationToken& ct, Callable&& callable) : next_(nullptr), prevNext_(nullptr), state_(nullptr), callback_(static_cast(callable)), destructorHasRunInsideCallback_(nullptr), callbackCompleted_(false) { if (ct.state_ != nullptr && ct.state_->tryAddCallback(this, true)) { state_ = ct.state_.get(); } } inline CancellationCallback::~CancellationCallback() { if (state_ != nullptr) { state_->removeCallback(this); } } inline void CancellationCallback::invokeCallback() noexcept { // Invoke within a noexcept context so that we std::terminate() if it throws. callback_(); } namespace detail { inline CancellationStateSourcePtr CancellationState::create() { return CancellationStateSourcePtr{new CancellationState()}; } inline CancellationState::CancellationState() noexcept : state_(kSourceReferenceCountIncrement), head_(nullptr), signallingThreadId_() {} inline CancellationStateTokenPtr CancellationState::addTokenReference() noexcept { state_.fetch_add(kTokenReferenceCountIncrement, std::memory_order_relaxed); return CancellationStateTokenPtr{this}; } inline void CancellationState::removeTokenReference() noexcept { const auto oldState = state_.fetch_sub( kTokenReferenceCountIncrement, std::memory_order_acq_rel); DCHECK( (oldState & kTokenReferenceCountMask) >= kTokenReferenceCountIncrement); if (oldState < (2 * kTokenReferenceCountIncrement)) { delete this; } } inline CancellationStateSourcePtr CancellationState::addSourceReference() noexcept { state_.fetch_add(kSourceReferenceCountIncrement, std::memory_order_relaxed); return CancellationStateSourcePtr{this}; } inline void CancellationState::removeSourceReference() noexcept { const auto oldState = state_.fetch_sub( kSourceReferenceCountIncrement, std::memory_order_acq_rel); DCHECK( (oldState & kSourceReferenceCountMask) >= kSourceReferenceCountIncrement); if (oldState < (kSourceReferenceCountIncrement + kTokenReferenceCountIncrement)) { delete this; } } inline bool CancellationState::isCancellationRequested() const noexcept { return isCancellationRequested(state_.load(std::memory_order_acquire)); } inline bool CancellationState::canBeCancelled() const noexcept { return canBeCancelled(state_.load(std::memory_order_acquire)); } inline bool CancellationState::canBeCancelled(std::uint64_t state) noexcept { // Can be cancelled if there is at least one CancellationSource ref-count // or if cancellation has been requested. return (state >= kSourceReferenceCountIncrement) || isCancellationRequested(state); } inline bool CancellationState::isCancellationRequested( std::uint64_t state) noexcept { return (state & kCancellationRequestedFlag) != 0; } inline bool CancellationState::isLocked(std::uint64_t state) noexcept { return (state & kLockedFlag) != 0; } } // namespace detail } // namespace folly