diff --git a/src/lib/net/SecureSocket.cpp b/src/lib/net/SecureSocket.cpp index 39f2a1fd..286394c4 100644 --- a/src/lib/net/SecureSocket.cpp +++ b/src/lib/net/SecureSocket.cpp @@ -661,22 +661,15 @@ bool SecureSocket::verifyCertFingerprint() { // calculate received certificate fingerprint - X509 *cert = cert = SSL_get_peer_certificate(m_ssl->m_ssl); - EVP_MD* tempDigest; - unsigned char tempFingerprint[EVP_MAX_MD_SIZE]; - unsigned int tempFingerprintLen; - tempDigest = (EVP_MD*)EVP_sha1(); - int digestResult = X509_digest(cert, tempDigest, tempFingerprint, &tempFingerprintLen); - - if (digestResult <= 0) { - LOG((CLOG_ERR "failed to calculate fingerprint, digest result: %d", digestResult)); + std::vector fingerprint_raw; + try { + fingerprint_raw = barrier::get_ssl_cert_fingerprint(SSL_get_peer_certificate(m_ssl->m_ssl), + barrier::FingerprintType::SHA1); + } catch (const std::exception& e) { + LOG((CLOG_ERR "%s", e.what())); return false; } - // format fingerprint into hexdecimal format with colon separator - std::vector fingerprint_raw; - fingerprint_raw.assign(reinterpret_cast(tempFingerprint), - reinterpret_cast(tempFingerprint) + tempFingerprintLen); auto fingerprint = barrier::format_ssl_fingerprint(fingerprint_raw); LOG((CLOG_NOTE "server fingerprint: %s", fingerprint.c_str())); diff --git a/src/lib/net/SecureUtils.cpp b/src/lib/net/SecureUtils.cpp index 000c56ed..c9222432 100644 --- a/src/lib/net/SecureUtils.cpp +++ b/src/lib/net/SecureUtils.cpp @@ -18,8 +18,26 @@ #include "SecureUtils.h" #include "base/String.h" +#include +#include +#include +#include + namespace barrier { +namespace { + +const EVP_MD* get_digest_for_type(FingerprintType type) +{ + switch (type) { + case FingerprintType::SHA1: return EVP_sha1(); + case FingerprintType::SHA256: return EVP_sha256(); + } + throw std::runtime_error("Unknown fingerprint type " + std::to_string(static_cast(type))); +} + +} // namespace + std::string format_ssl_fingerprint(const std::vector& fingerprint, bool separator) { std::string result = barrier::string::to_hex(fingerprint, 2); @@ -37,4 +55,25 @@ std::string format_ssl_fingerprint(const std::vector& fingerprint, bool return result; } +std::vector get_ssl_cert_fingerprint(X509* cert, FingerprintType type) +{ + if (!cert) { + throw std::runtime_error("certificate is null"); + } + + unsigned char digest[EVP_MAX_MD_SIZE]; + unsigned int digest_length = 0; + int result = X509_digest(cert, get_digest_for_type(type), digest, &digest_length); + + if (result <= 0) { + throw std::runtime_error("failed to calculate fingerprint, digest result: " + + std::to_string(result)); + } + + std::vector digest_vec; + digest_vec.assign(reinterpret_cast(digest), + reinterpret_cast(digest) + digest_length); + return digest_vec; +} + } // namespace barrier diff --git a/src/lib/net/SecureUtils.h b/src/lib/net/SecureUtils.h index 50e944e1..a35c1db7 100644 --- a/src/lib/net/SecureUtils.h +++ b/src/lib/net/SecureUtils.h @@ -18,14 +18,23 @@ #ifndef BARRIER_LIB_NET_SECUREUTILS_H #define BARRIER_LIB_NET_SECUREUTILS_H +#include +#include #include #include namespace barrier { +enum FingerprintType { + SHA1, // deprecated + SHA256, +}; + std::string format_ssl_fingerprint(const std::vector& fingerprint, bool separator = true); +std::vector get_ssl_cert_fingerprint(X509* cert, FingerprintType type); + } // namespace barrier #endif // BARRIER_LIB_NET_SECUREUTILS_H