Fix crash in RSA public key wrapper.

This commit is contained in:
John Preston 2017-12-07 09:34:11 +04:00 committed by Berkus Decker
parent f74793ca3f
commit 71daae1227
2 changed files with 11 additions and 6 deletions

View File

@ -180,6 +180,9 @@ public:
const BIGNUM *raw() const { const BIGNUM *raw() const {
return _data; return _data;
} }
BIGNUM *takeRaw() {
return base::take(_data);
}
bool failed() const { bool failed() const {
return _failed; return _failed;

View File

@ -35,7 +35,7 @@ namespace {
// This is a key setter for compatibility with OpenSSL 1.0 // This is a key setter for compatibility with OpenSSL 1.0
int RSA_set0_key(RSA *r, BIGNUM *n, BIGNUM *e, BIGNUM *d) { int RSA_set0_key(RSA *r, BIGNUM *n, BIGNUM *e, BIGNUM *d) {
if ((r->n == nullptr && n == nullptr) || (r->e == nullptr && e == nullptr)) { if ((r->n == nullptr && n == nullptr) || (r->e == nullptr && e == nullptr)) {
return false; return 0;
} }
if (n != nullptr) { if (n != nullptr) {
BN_free(r->n); BN_free(r->n);
@ -49,7 +49,7 @@ int RSA_set0_key(RSA *r, BIGNUM *n, BIGNUM *e, BIGNUM *d) {
BN_free(r->d); BN_free(r->d);
r->d = d; r->d = d;
} }
return true; return 1;
} }
// This is a key getter for compatibility with OpenSSL 1.0 // This is a key getter for compatibility with OpenSSL 1.0
@ -79,10 +79,12 @@ public:
} }
Private(base::const_byte_span nBytes, base::const_byte_span eBytes) : _rsa(RSA_new()) { Private(base::const_byte_span nBytes, base::const_byte_span eBytes) : _rsa(RSA_new()) {
if (_rsa) { if (_rsa) {
BIGNUM *n = openssl::BigNum(nBytes).raw(); auto n = openssl::BigNum(nBytes).takeRaw();
BIGNUM *e = openssl::BigNum(eBytes).raw(); auto e = openssl::BigNum(eBytes).takeRaw();
RSA_set0_key(_rsa, n, e, nullptr); auto valid = (n != nullptr) && (e != nullptr);
if (!n || !e) { // We still pass both values to RSA_set0_key() so that even
// if only one of them is valid RSA would take ownership of it.
if (!RSA_set0_key(_rsa, n, e, nullptr) || !valid) {
RSA_free(base::take(_rsa)); RSA_free(base::take(_rsa));
} else { } else {
computeFingerprint(); computeFingerprint();