vn-verdnaturachat/ios/Pods/Flipper-Folly/folly/io/async/SSLContext.cpp

664 lines
19 KiB
C++
Raw Normal View History

/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/io/async/SSLContext.h>
#include <folly/Format.h>
#include <folly/Memory.h>
#include <folly/Random.h>
#include <folly/SharedMutex.h>
#include <folly/SpinLock.h>
#include <folly/ssl/Init.h>
#include <folly/system/ThreadId.h>
// ---------------------------------------------------------------------
// SSLContext implementation
// ---------------------------------------------------------------------
namespace folly {
//
// For OpenSSL portability API
// SSLContext implementation
// Builds a fresh SSL_CTX and applies the default settings used by this class.
//
// @param version  Oldest protocol version the caller is willing to speak;
//                 older protocols are switched off through SSL_OP_NO_* flags.
// @throws std::runtime_error if SSL_CTX_new() returns nullptr.
SSLContext::SSLContext(SSLVersion version) {
  folly::ssl::init();

  ctx_ = SSL_CTX_new(SSLv23_method());
  if (!ctx_) {
    throw std::runtime_error("SSL_CTX_new: " + getErrors());
  }

  // Work out which protocol versions have to be disabled for `version`.
  // SSLv2 (and any unrecognised value) disables nothing.
  int opt = 0;
  if (version == TLSv1) {
    opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
  } else if (version == SSLv3) {
    opt = SSL_OP_NO_SSLv2;
  } else if (version == TLSv1_2) {
    opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 |
        SSL_OP_NO_TLSv1_1;
  }

  // TLS 1.3 stays off by default (when this OpenSSL build knows about it)
  // because of semantic differences, e.g. assumptions about getSession()
  // returning a resumable session, SSL_CTX_set_ciphersuites, etc.
#if FOLLY_OPENSSL_IS_110
  SSL_CTX_set_max_proto_version(ctx_, TLS1_2_VERSION);
#endif

  int newOpt = SSL_CTX_set_options(ctx_, opt);
  DCHECK((newOpt & opt) == opt);

  SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);

  checkPeerName_ = false;

  SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);

  sslAcceptRunner_ = std::make_unique<SSLAcceptRunner>();

#if FOLLY_OPENSSL_HAS_SNI
  // Route the TLS server-name extension through this object.
  SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
  SSL_CTX_set_tlsext_servername_arg(ctx_, this);
#endif
}
// Releases this object's reference on the SSL_CTX and, when ALPN support is
// compiled in, frees the advertised protocol buffers.
SSLContext::~SSLContext() {
  if (ctx_) {
    SSL_CTX_free(ctx_);
    ctx_ = nullptr;
  }

#if FOLLY_OPENSSL_HAS_ALPN
  deleteNextProtocolsStrings();
#endif
}
void SSLContext::ciphers(const std::string& ciphers) {
setCiphersOrThrow(ciphers);
}
// Installs the given elliptic curves on the context as a single
// colon-separated list.
//
// An empty vector is a no-op. On OpenSSL builds older than 1.0.2 the call
// compiles to nothing.
//
// @throws std::runtime_error if SSL_CTX_set1_curves_list() reports failure.
void SSLContext::setClientECCurvesList(
    const std::vector<std::string>& ecCurves) {
  if (ecCurves.empty()) {
    return;
  }
#if OPENSSL_VERSION_NUMBER >= 0x1000200fL
  std::string colonSeparated;
  join(":", ecCurves, colonSeparated);
  if (SSL_CTX_set1_curves_list(ctx_, colonSeparated.c_str()) == 0) {
    throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors());
  }
#endif
}
// Configures the ECDH curve for the context, looked up by its short name.
//
// Elliptic-Curve Diffie-Hellman parameters are either "named curves" from
// RFC 4492 section 5.1.1, or explicitly described curves over binary fields.
// OpenSSL only supports the "named curves", which provide maximum
// interoperability.
//
// An unknown name or a curve that cannot be instantiated is logged with
// LOG(FATAL). When ECDH support is compiled out, throws std::runtime_error.
void SSLContext::setServerECCurve(const std::string& curveName) {
#if OPENSSL_VERSION_NUMBER >= 0x0090800fL && !defined(OPENSSL_NO_ECDH)
  const int nid = OBJ_sn2nid(curveName.c_str());
  if (nid == 0) {
    LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
  }

  EC_KEY* ecdh = EC_KEY_new_by_curve_name(nid);
  if (ecdh == nullptr) {
    LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
  }

  SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
  // The local handle is released right after being handed to the context.
  EC_KEY_free(ecdh);
#else
  throw std::runtime_error("Elliptic curve encryption not allowed");
#endif
}
// Wraps an existing SSL_CTX, taking an additional reference on it.
//
// @throws std::runtime_error if the reference count cannot be incremented.
SSLContext::SSLContext(SSL_CTX* ctx) : ctx_(ctx) {
  const int upRefResult = SSL_CTX_up_ref(ctx);
  if (upRefResult == 0) {
    throw std::runtime_error("Failed to increment SSL_CTX refcount");
  }
}
// Applies the given X509 verification parameters to the context.
//
// A null/empty parameter object is ignored.
//
// @throws std::runtime_error if SSL_CTX_set1_param() does not return 1.
void SSLContext::setX509VerifyParam(
    const ssl::X509VerifyParam& x509VerifyParam) {
  if (x509VerifyParam) {
    const int rc = SSL_CTX_set1_param(ctx_, x509VerifyParam.get());
    if (rc != 1) {
      throw std::runtime_error("SSL_CTX_set1_param " + getErrors());
    }
  }
}
// Installs the cipher list on the context and, on success, remembers the
// string that was supplied.
//
// @throws std::runtime_error if SSL_CTX_set_cipher_list() returns 0; in that
//         case the remembered cipher string is left untouched.
void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
  if (SSL_CTX_set_cipher_list(ctx_, ciphers.c_str()) == 0) {
    throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
  }
  providedCiphersString_ = ciphers;
}
// Stores the peer-verification policy for this context.
//
// USE_CTX is rejected with a CHECK failure: it would mean "use the context's
// own setting", which cannot itself be the context's setting.
void SSLContext::setVerificationOption(
    const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
  CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
  this->verifyPeer_ = verifyPeer;
}
// Translates a verification policy into the matching OpenSSL SSL_VERIFY_*
// bit mask.
//
// USE_CTX must never be passed in (enforced with CHECK). Any value that is
// not VERIFY or VERIFY_REQ_CLIENT_CERT maps to SSL_VERIFY_NONE.
int SSLContext::getVerificationMode(
    const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
  CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);

  if (verifyPeer == SSLVerifyPeerEnum::VERIFY) {
    return SSL_VERIFY_PEER;
  }
  if (verifyPeer == SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT) {
    return SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
  }
  // NO_VERIFY and anything unexpected.
  return SSL_VERIFY_NONE;
}
// Returns the SSL_VERIFY_* mask for the policy currently stored on this
// context.
int SSLContext::getVerificationMode() {
  const int mode = getVerificationMode(verifyPeer_);
  return mode;
}
// Configures peer authentication on the context.
//
// @param checkPeerCert  When true, the peer certificate is required and
//                       verified; when false, verification is disabled.
// @param checkPeerName  Whether the peer name should be checked as well.
//                       Ignored when checkPeerCert is false.
// @param peerName       Fixed peer name to record; cleared when
//                       checkPeerCert is false.
void SSLContext::authenticate(
    bool checkPeerCert,
    bool checkPeerName,
    const std::string& peerName) {
  if (!checkPeerCert) {
    // can't check name without cert!
    checkPeerName_ = false;
    peerFixedName_.clear();
    SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr);
    return;
  }

  checkPeerName_ = checkPeerName;
  peerFixedName_ = peerName;
  const int mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT |
      SSL_VERIFY_CLIENT_ONCE;
  SSL_CTX_set_verify(ctx_, mode, nullptr);
}
// Loads a certificate chain from a file into the context.
//
// @param path    File to read.
// @param format  Must be "PEM"; any other value is rejected.
// @throws std::invalid_argument if either argument is nullptr.
// @throws std::runtime_error for an unsupported format or when OpenSSL fails
//         to load the file (the message includes the path and error details).
void SSLContext::loadCertificate(const char* path, const char* format) {
  if (path == nullptr || format == nullptr) {
    throw std::invalid_argument(
        "loadCertificateChain: either <path> or <format> is nullptr");
  }
  if (strcmp(format, "PEM") != 0) {
    throw std::runtime_error(
        "Unsupported certificate format: " + std::string(format));
  }
  if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 1) {
    return;
  }

  // Snapshot errno before doing anything else that might overwrite it.
  const int errnoCopy = errno;
  std::string reason("SSL_CTX_use_certificate_chain_file: ");
  reason += path;
  reason += ": ";
  reason += getErrors(errnoCopy);
  throw std::runtime_error(reason);
}
// Loads a certificate (and optional chain) from an in-memory PEM buffer.
//
// The first X509 PEM block becomes the context's certificate. Every further
// X509 block in the buffer is appended to the certificate chain, up to a
// fixed maximum of 64 additional certificates.
//
// @param cert  PEM text; its data pointer must not be nullptr.
// @throws std::invalid_argument if cert.data() is nullptr.
// @throws std::runtime_error if the buffer cannot be copied into a BIO, the
//         first certificate cannot be parsed or installed, a chain
//         certificate cannot be added, or the buffer still holds another
//         certificate after the maximum number of chain entries was added.
void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
  if (cert.data() == nullptr) {
    throw std::invalid_argument("loadCertificate: <cert> is nullptr");
  }

  ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
  if (bio == nullptr) {
    throw std::runtime_error("BIO_new: " + getErrors());
  }

  // A short or failed write (including an empty buffer) is an error.
  int written = BIO_write(bio.get(), cert.data(), int(cert.size()));
  if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
    throw std::runtime_error("BIO_write: " + getErrors());
  }

  ssl::X509UniquePtr x509(
      PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
  if (x509 == nullptr) {
    throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
  }

  if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
    throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
  }

  // Any further X509 PEM blocks are treated as additional certificates in
  // the certificate chain.
  constexpr size_t kMaxCertChain = 64;
  for (size_t i = 0; i < kMaxCertChain; i++) {
    x509.reset(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
    if (x509 == nullptr) {
      // Running out of PEM blocks is the normal exit; drop the parse error
      // that the failed read left on the OpenSSL error queue.
      ERR_clear_error();
      return;
    }
    if (SSL_CTX_add1_chain_cert(ctx_, x509.get()) == 0) {
      // Report the function that actually failed (add1, not add0).
      throw std::runtime_error("SSL_CTX_add1_chain_cert: " + getErrors());
    }
  }

  throw std::runtime_error(
      "loadCertificateFromBufferPEM(): Too many certificates in chain");
}
// Loads a private key from a file into the context.
//
// @param path    File to read.
// @param format  Must be "PEM"; any other value is rejected.
// @throws std::invalid_argument if either argument is nullptr.
// @throws std::runtime_error for an unsupported format or when
//         SSL_CTX_use_PrivateKey_file() returns 0.
void SSLContext::loadPrivateKey(const char* path, const char* format) {
  if (path == nullptr || format == nullptr) {
    throw std::invalid_argument(
        "loadPrivateKey: either <path> or <format> is nullptr");
  }
  if (strcmp(format, "PEM") != 0) {
    throw std::runtime_error(
        "Unsupported private key format: " + std::string(format));
  }
  const int rc = SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM);
  if (rc == 0) {
    throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
  }
}
// Loads a private key from an in-memory PEM buffer into the context.
//
// @param pkey  PEM text; its data pointer must not be nullptr.
// @throws std::invalid_argument if pkey.data() is nullptr.
// @throws std::runtime_error if the buffer cannot be copied into a BIO, the
//         key cannot be parsed, or the key cannot be installed.
void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
  if (pkey.data() == nullptr) {
    throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
  }

  ssl::BioUniquePtr memBio(BIO_new(BIO_s_mem()));
  if (memBio == nullptr) {
    throw std::runtime_error("BIO_new: " + getErrors());
  }

  // The whole buffer must be accepted by the BIO in one write.
  const int bytesWritten =
      BIO_write(memBio.get(), pkey.data(), int(pkey.size()));
  const bool wroteEverything = bytesWritten > 0 &&
      static_cast<unsigned>(bytesWritten) == pkey.size();
  if (!wroteEverything) {
    throw std::runtime_error("BIO_write: " + getErrors());
  }

  ssl::EvpPkeyUniquePtr parsedKey(
      PEM_read_bio_PrivateKey(memBio.get(), nullptr, nullptr, nullptr));
  if (parsedKey == nullptr) {
    throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
  }

  if (SSL_CTX_use_PrivateKey(ctx_, parsedKey.get()) == 0) {
    throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
  }
}
// Loads a certificate and its private key from in-memory PEM buffers, then
// verifies that the two belong together.
//
// @throws std::runtime_error if the pair does not match; exceptions from the
//         individual load steps propagate unchanged.
void SSLContext::loadCertKeyPairFromBufferPEM(
    folly::StringPiece cert,
    folly::StringPiece pkey) {
  loadCertificateFromBufferPEM(cert);
  loadPrivateKeyFromBufferPEM(pkey);

  const bool pairMatches = isCertKeyPairValid();
  if (pairMatches) {
    return;
  }
  throw std::runtime_error("SSL certificate and private key do not match");
}
// Loads a certificate and its private key from files, then verifies that the
// two belong together.
//
// @throws std::runtime_error if the pair does not match; exceptions from the
//         individual load steps propagate unchanged.
void SSLContext::loadCertKeyPairFromFiles(
    const char* certPath,
    const char* keyPath,
    const char* certFormat,
    const char* keyFormat) {
  loadCertificate(certPath, certFormat);
  loadPrivateKey(keyPath, keyFormat);

  const bool pairMatches = isCertKeyPairValid();
  if (pairMatches) {
    return;
  }
  throw std::runtime_error("SSL certificate and private key do not match");
}
// Returns true when SSL_CTX_check_private_key() reports success (1) for the
// certificate and private key currently loaded in the context.
bool SSLContext::isCertKeyPairValid() const {
  const int checkResult = SSL_CTX_check_private_key(ctx_);
  return checkResult == 1;
}
// Loads trusted CA certificates from the file at `path`.
//
// @throws std::invalid_argument if path is nullptr.
// @throws std::runtime_error if SSL_CTX_load_verify_locations() returns 0.
void SSLContext::loadTrustedCertificates(const char* path) {
  if (path == nullptr) {
    throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
  }

  const int rc = SSL_CTX_load_verify_locations(ctx_, path, nullptr);
  if (rc == 0) {
    throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
  }

  // Leave the OpenSSL error queue empty after a successful load.
  ERR_clear_error();
}
// Replaces the context's certificate store with `store`.
// NOTE(review): OpenSSL's SSL_CTX_set_cert_store is documented as taking over
// the caller's reference to the store — confirm callers do not free it.
void SSLContext::loadTrustedCertificates(X509_STORE* store) {
  ::SSL_CTX_set_cert_store(ctx_, store);
}
void SSLContext::loadClientCAList(const char* path) {
auto clientCAs = SSL_load_client_CA_file(path);
if (clientCAs == nullptr) {
LOG(ERROR) << "Unable to load ca file: " << path << " " << getErrors();
return;
}
SSL_CTX_set_client_CA_list(ctx_, clientCAs);
}
void SSLContext::passwordCollector(
std::shared_ptr<PasswordCollector> collector) {
if (collector == nullptr) {
LOG(ERROR) << "passwordCollector: ignore invalid password collector";
return;
}
collector_ = collector;
SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
}
#if FOLLY_OPENSSL_HAS_SNI
void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
serverNameCb_ = cb;
}
void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
clientHelloCbs_.push_back(cb);
}
// OpenSSL server-name (SNI) callback registered by the constructor.
//
// @param ssl   Connection being negotiated.
// @param al    Out-parameter for the TLS alert; written only for the
//              fatal-alert result.
// @param data  The SSLContext registered as callback argument.
// @return One of the SSL_TLSEXT_ERR_* codes.
int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
  auto* self = static_cast<SSLContext*>(data);
  if (self == nullptr) {
    return SSL_TLSEXT_ERR_NOACK;
  }

  // Generic callbacks to happen after we receive the Client Hello.
  // For example, we use one to switch which cipher we use depending
  // on the user's TLS version. Because the primary purpose of
  // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
  // are side-uses, we ignore any possible failures other than just logging
  // them.
  for (auto& helloCb : self->clientHelloCbs_) {
    helloCb(ssl);
  }

  if (!self->serverNameCb_) {
    return SSL_TLSEXT_ERR_NOACK;
  }

  const ServerNameCallbackResult result = self->serverNameCb_(ssl);
  if (result == SERVER_NAME_FOUND) {
    return SSL_TLSEXT_ERR_OK;
  }
  if (result == SERVER_NAME_NOT_FOUND) {
    return SSL_TLSEXT_ERR_NOACK;
  }
  if (result == SERVER_NAME_NOT_FOUND_ALERT_FATAL) {
    *al = TLS1_AD_UNRECOGNIZED_NAME;
    return SSL_TLSEXT_ERR_ALERT_FATAL;
  }

  // Any other result value is a programming error.
  CHECK(false);
  return SSL_TLSEXT_ERR_NOACK;
}
#endif // FOLLY_OPENSSL_HAS_SNI
#if FOLLY_OPENSSL_HAS_ALPN
// OpenSSL ALPN selection callback.
//
// Picks one of the advertised protocol lists (see pickNextProtocols()) and
// lets SSL_select_next_proto() choose a protocol against the peer's list.
//
// @param out/outlen  Receive the selected protocol; set to nullptr/0 when
//                    nothing is advertised.
// @param in/inlen    Protocol list offered by the peer.
// @param data        The SSLContext registered with the callback; must not be
//                    null (CHECKed).
// @return SSL_TLSEXT_ERR_OK, or SSL_TLSEXT_ERR_NOACK when negotiation did not
//         produce OPENSSL_NPN_NEGOTIATED.
int SSLContext::alpnSelectCallback(
    SSL* /* ssl */,
    const unsigned char** out,
    unsigned char* outlen,
    const unsigned char* in,
    unsigned int inlen,
    void* data) {
  auto* context = static_cast<SSLContext*>(data);
  CHECK(context);

  if (context->advertisedNextProtocols_.empty()) {
    *out = nullptr;
    *outlen = 0;
    return SSL_TLSEXT_ERR_OK;
  }

  const auto chosenIndex = context->pickNextProtocols();
  const auto& chosen = context->advertisedNextProtocols_[chosenIndex];
  // SSL_select_next_proto() wants a non-const output pointer.
  const int negotiation = SSL_select_next_proto(
      const_cast<unsigned char**>(out),
      outlen,
      chosen.protocols,
      chosen.length,
      in,
      inlen);
  return negotiation == OPENSSL_NPN_NEGOTIATED ? SSL_TLSEXT_ERR_OK
                                               : SSL_TLSEXT_ERR_NOACK;
}
// Advertises a single protocol list. Implemented as the randomized variant
// with exactly one entry of weight 1; returns whatever that call returns.
bool SSLContext::setAdvertisedNextProtocols(
    const std::list<std::string>& protocols) {
  const bool ok = setRandomizedAdvertisedNextProtocols({{1, protocols}});
  return ok;
}
// Replaces the advertised ALPN protocol lists with weighted alternatives.
//
// Each item with a non-empty protocol list is serialized as a sequence of
// (one length byte, name bytes) records and stored together with its weight.
// The weights feed the distribution used by pickNextProtocols(). The first
// stored list is additionally installed through SSL_CTX_set_alpn_protos().
//
// Any previously advertised protocols are always cleared first.
//
// @param items  Candidate protocol lists with their weights. Items whose
//               protocol list is empty are skipped.
// @return false if `items` is empty, if any protocol name is 256 bytes or
//         longer, if the summed weight of the stored items is zero, or if
//         SSL_CTX_set_alpn_protos() fails; true otherwise.
bool SSLContext::setRandomizedAdvertisedNextProtocols(
    const std::list<NextProtocolsItem>& items) {
  unsetNextProtocols();
  if (items.empty()) {
    return false;
  }
  int total_weight = 0;
  for (const auto& item : items) {
    if (item.protocols.empty()) {
      continue;
    }
    AdvertisedNextProtocolsItem advertised_item;
    advertised_item.length = 0;
    for (const auto& proto : item.protocols) {
      // One byte for the length prefix of this name.
      ++advertised_item.length;
      const auto protoLength = proto.length();
      if (protoLength >= 256) {
        // The length prefix is a single byte; drop everything stored so far.
        deleteNextProtocolsStrings();
        return false;
      }
      advertised_item.length += unsigned(protoLength);
    }

    // operator new[] throws on failure instead of returning nullptr, so no
    // null check is needed. The buffer is held by a unique_ptr until the
    // container has accepted it, so it cannot leak if push_back throws.
    std::unique_ptr<unsigned char[]> encoded(
        new unsigned char[advertised_item.length]);
    unsigned char* dst = encoded.get();
    for (const auto& proto : item.protocols) {
      const auto protoLength = uint8_t(proto.length());
      *dst++ = protoLength;
      memcpy(dst, proto.data(), protoLength);
      dst += protoLength;
    }
    advertised_item.protocols = encoded.get();

    total_weight += item.weight;
    advertisedNextProtocols_.push_back(advertised_item);
    // From here on deleteNextProtocolsStrings() is responsible for the buffer.
    (void)encoded.release();
    advertisedNextProtocolWeights_.push_back(item.weight);
  }
  if (total_weight == 0) {
    deleteNextProtocolsStrings();
    return false;
  }
  nextProtocolDistribution_ = std::discrete_distribution<>(
      advertisedNextProtocolWeights_.begin(),
      advertisedNextProtocolWeights_.end());
  SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
  // Client cannot really use randomized alpn
  // Note that this function reverses the typical return value convention
  // of openssl and returns 0 on success.
  return SSL_CTX_set_alpn_protos(
             ctx_,
             advertisedNextProtocols_[0].protocols,
             advertisedNextProtocols_[0].length) == 0;
}
// Frees every serialized protocol buffer held by this context and empties
// both the protocol list and the parallel weight list.
void SSLContext::deleteNextProtocolsStrings() {
  for (const auto& entry : advertisedNextProtocols_) {
    delete[] entry.protocols;
  }
  advertisedNextProtocols_.clear();
  advertisedNextProtocolWeights_.clear();
}
// Removes all ALPN configuration: frees the stored protocol buffers, then
// unregisters the selection callback and the protocol list on the context.
void SSLContext::unsetNextProtocols() {
  deleteNextProtocolsStrings();

  SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
  SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);

  // OpenSSL internals sometimes record a malloc failure when asked to
  // memdup (NULL, 0); wipe the error stack so that does not surface later.
  ERR_clear_error();
}
size_t SSLContext::pickNextProtocols() {
CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
auto rng = ThreadLocalPRNG();
return size_t(nextProtocolDistribution_(rng));
}
#endif // FOLLY_OPENSSL_HAS_ALPN
// Creates a new SSL object from this context and returns the raw pointer.
//
// @throws std::runtime_error if SSL_new() returns nullptr.
SSL* SSLContext::createSSL() const {
  if (SSL* ssl = SSL_new(ctx_)) {
    return ssl;
  }
  throw std::runtime_error("SSL_new: " + getErrors());
}
// Sets the session id context from the bytes of `context`. At most
// SSL_MAX_SID_CTX_LENGTH bytes are passed on; a longer string is truncated.
void SSLContext::setSessionCacheContext(const std::string& context) {
  const auto* idBytes =
      reinterpret_cast<const unsigned char*>(context.data());
  const unsigned int idLength = std::min<unsigned int>(
      static_cast<unsigned int>(context.length()), SSL_MAX_SID_CTX_LENGTH);
  SSL_CTX_set_session_id_context(ctx_, idBytes, idLength);
}
/**
 * Match a name with a pattern. The pattern may include wildcard. A single
 * wildcard "*" can match up to one component in the domain name.
 *
 * Characters are compared case-insensitively. A "*" in the pattern consumes
 * host characters up to (but not including) the next '.' or the end of the
 * host.
 *
 * @param  host    Host name, typically the name of the remote host;
 *                 NUL-terminated
 * @param  pattern Name retrieved from certificate; need not be NUL-terminated
 * @param  size    Size of "pattern"
 * @return True, if "host" matches "pattern". False otherwise.
 */
bool SSLContext::matchName(const char* host, const char* pattern, int size) {
  // toupper() is undefined for negative arguments other than EOF, which is
  // what a non-ASCII byte becomes where plain char is signed. Widen through
  // unsigned char first.
  const auto upper = [](char c) {
    return toupper(static_cast<unsigned char>(c));
  };

  int i = 0; // position in pattern
  int j = 0; // position in host
  while (i < size && host[j] != '\0') {
    if (upper(pattern[i]) == upper(host[j])) {
      ++i;
      ++j;
      continue;
    }
    if (pattern[i] == '*') {
      // Swallow the rest of the current host component.
      while (host[j] != '.' && host[j] != '\0') {
        ++j;
      }
      ++i;
      continue;
    }
    break;
  }
  // Both the pattern and the host must be fully consumed.
  return i == size && host[j] == '\0';
}
int SSLContext::passwordCallback(char* password, int size, int, void* data) {
auto context = (SSLContext*)data;
if (context == nullptr || context->passwordCollector() == nullptr) {
return 0;
}
std::string userPassword;
// call user defined password collector to get password
context->passwordCollector()->getPassword(userPassword, size);
auto const length = std::min(userPassword.size(), size_t(size));
std::memcpy(password, userPassword.data(), length);
return int(length);
}
#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
// Turns on SSL_MODE_HANDSHAKE_CUTTHROUGH for this context.
void SSLContext::enableFalseStart() {
  const auto cutThroughMode = SSL_MODE_HANDSHAKE_CUTTHROUGH;
  SSL_CTX_set_mode(ctx_, cutThroughMode);
}
#endif
void SSLContext::initializeOpenSSL() {
folly::ssl::init();
}
// Adds the given SSL_OP_* bits to the context's options.
//
// @throws std::runtime_error if any requested bit is missing from the option
//         mask returned by SSL_CTX_set_options().
void SSLContext::setOptions(long options) {
  const long applied = SSL_CTX_set_options(ctx_, options);
  const bool allRequestedBitsSet = (applied & options) == options;
  if (!allRequestedBitsSet) {
    throw std::runtime_error("SSL_CTX_set_options failed");
  }
}
// Drains the OpenSSL error queue into a single "; "-separated string.
//
// Errors without a reason string are rendered as "SSL error # <hex code>".
// If the queue was empty, the result is "error code: <errnoCopy>" instead.
//
// @param errnoCopy  Value to report when OpenSSL has no queued errors.
std::string SSLContext::getErrors(int errnoCopy) {
  std::string errors;
  errors.reserve(512);

  char fallback[256];
  for (unsigned long code = ERR_get_error(); code != 0;
       code = ERR_get_error()) {
    if (!errors.empty()) {
      errors += "; ";
    }
    const char* reason = ERR_reason_error_string(code);
    if (reason != nullptr) {
      errors += reason;
      continue;
    }
    snprintf(fallback, sizeof(fallback) - 1, "SSL error # %08lX", code);
    errors += fallback;
  }

  if (errors.empty()) {
    errors = "error code: " + folly::to<std::string>(errnoCopy);
  }
  return errors;
}
// Lifts the protocol-version cap set by the constructor by resetting the
// maximum protocol version to 0. Compiles to a no-op unless
// FOLLY_OPENSSL_IS_110 is set.
void SSLContext::enableTLS13() {
#if FOLLY_OPENSSL_IS_110
  constexpr int kNoVersionCap = 0;
  SSL_CTX_set_max_proto_version(ctx_, kNoVersionCap);
#endif
}
// Streams the collector's describe() text to `os` and returns the stream.
std::ostream& operator<<(std::ostream& os, const PasswordCollector& collector) {
  return os << collector.describe();
}
} // namespace folly