Extract common logic in OpenSSL-specific code
[tinc] / src / openssl / rsa.c
1 /*
2     rsa.c -- RSA key handling
3     Copyright (C) 2007-2022 Guus Sliepen <guus@tinc-vpn.org>
4
5     This program is free software; you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation; either version 2 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License along
16     with this program; if not, write to the Free Software Foundation, Inc.,
17     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19
20 #include "../system.h"
21
22 #include <openssl/pem.h>
23 #include <openssl/rsa.h>
24
25 #define TINC_RSA_INTERNAL
26
27 #if OPENSSL_VERSION_MAJOR < 3
28 typedef RSA rsa_t;
29 #else
30 #include <openssl/encoder.h>
31 #include <openssl/decoder.h>
32 #include <openssl/core_names.h>
33 #include <openssl/param_build.h>
34 #include <assert.h>
35
36 typedef EVP_PKEY rsa_t;
37 #endif
38
39 #include "log.h"
40 #include "../logger.h"
41 #include "../rsa.h"
42
43 // Set RSA keys
44
45 #if OPENSSL_VERSION_MAJOR >= 3
46 static EVP_PKEY *build_rsa_key(int selection, const BIGNUM *bn_n, const BIGNUM *bn_e, const BIGNUM *bn_d) {
47         assert(bn_n);
48         assert(bn_e);
49
50         EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new_from_name(NULL, "RSA", NULL);
51
52         if(!ctx) {
53                 openssl_err("initialize key context");
54                 return NULL;
55         }
56
57         OSSL_PARAM_BLD *bld = OSSL_PARAM_BLD_new();
58         OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_RSA_N, bn_n);
59         OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_RSA_E, bn_e);
60
61         if(bn_d) {
62                 OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_RSA_D, bn_d);
63         }
64
65         OSSL_PARAM *params = OSSL_PARAM_BLD_to_param(bld);
66         EVP_PKEY *key = NULL;
67
68         bool ok = EVP_PKEY_fromdata_init(ctx) > 0
69                   && EVP_PKEY_fromdata(ctx, &key, selection, params) > 0;
70
71         OSSL_PARAM_free(params);
72         OSSL_PARAM_BLD_free(bld);
73         EVP_PKEY_CTX_free(ctx);
74
75         if(ok) {
76                 return key;
77         }
78
79         openssl_err("build key");
80         return NULL;
81 }
82 #endif
83
84 static bool hex_to_bn(BIGNUM **bn, const char *hex) {
85         return (size_t)BN_hex2bn(bn, hex) == strlen(hex);
86 }
87
88 static rsa_t *rsa_set_hex_key(const char *n, const char *e, const char *d) {
89         rsa_t *rsa = NULL;
90         BIGNUM *bn_n = NULL;
91         BIGNUM *bn_e = NULL;
92         BIGNUM *bn_d = NULL;
93
94         if(!hex_to_bn(&bn_n, n) || !hex_to_bn(&bn_e, e) || (d && !hex_to_bn(&bn_d, d))) {
95                 goto exit;
96         }
97
98 #if OPENSSL_VERSION_MAJOR < 3
99         rsa = RSA_new();
100
101         if(rsa) {
102                 RSA_set0_key(rsa, bn_n, bn_e, bn_d);
103         }
104
105 #else
106         int selection = bn_d ? EVP_PKEY_KEYPAIR : EVP_PKEY_PUBLIC_KEY;
107         rsa = build_rsa_key(selection, bn_n, bn_e, bn_d);
108 #endif
109
110 exit:
111 #if OPENSSL_VERSION_MAJOR < 3
112
113         if(!rsa)
114 #endif
115         {
116                 BN_free(bn_d);
117                 BN_free(bn_e);
118                 BN_free(bn_n);
119         }
120
121         return rsa;
122 }
123
124 rsa_t *rsa_set_hex_public_key(const char *n, const char *e) {
125         return rsa_set_hex_key(n, e, NULL);
126 }
127
128 rsa_t *rsa_set_hex_private_key(const char *n, const char *e, const char *d) {
129         return rsa_set_hex_key(n, e, d);
130 }
131
132 // Read PEM RSA keys
133
134 #if OPENSSL_VERSION_MAJOR >= 3
135 static rsa_t *read_key_from_pem(FILE *fp, int selection) {
136         rsa_t *rsa = NULL;
137         OSSL_DECODER_CTX *ctx = OSSL_DECODER_CTX_new_for_pkey(&rsa, "PEM", NULL, "RSA", selection, NULL, NULL);
138
139         if(!ctx) {
140                 openssl_err("initialize decoder");
141                 return NULL;
142         }
143
144         bool ok = OSSL_DECODER_from_fp(ctx, fp);
145         OSSL_DECODER_CTX_free(ctx);
146
147         if(!ok) {
148                 rsa = NULL;
149                 openssl_err("read RSA key from file");
150         }
151
152         return rsa;
153 }
154 #endif
155
156 rsa_t *rsa_read_pem_public_key(FILE *fp) {
157         rsa_t *rsa;
158
159 #if OPENSSL_VERSION_MAJOR < 3
160         rsa = PEM_read_RSAPublicKey(fp, NULL, NULL, NULL);
161
162         if(!rsa) {
163                 rewind(fp);
164                 rsa = PEM_read_RSA_PUBKEY(fp, NULL, NULL, NULL);
165         }
166
167 #else
168         rsa = read_key_from_pem(fp, OSSL_KEYMGMT_SELECT_PUBLIC_KEY);
169 #endif
170
171         if(!rsa) {
172                 openssl_err("read RSA public key");
173         }
174
175         return rsa;
176 }
177
178 rsa_t *rsa_read_pem_private_key(FILE *fp) {
179         rsa_t *rsa;
180
181 #if OPENSSL_VERSION_MAJOR < 3
182         rsa = PEM_read_RSAPrivateKey(fp, NULL, NULL, NULL);
183 #else
184         rsa = read_key_from_pem(fp, OSSL_KEYMGMT_SELECT_PRIVATE_KEY);
185 #endif
186
187         if(!rsa) {
188                 openssl_err("read RSA private key");
189         }
190
191         return rsa;
192 }
193
194 size_t rsa_size(const rsa_t *rsa) {
195 #if OPENSSL_VERSION_MAJOR < 3
196         return RSA_size(rsa);
197 #else
198         return EVP_PKEY_get_size(rsa);
199 #endif
200 }
201
202 #if OPENSSL_VERSION_MAJOR >= 3
203 // Initialize encryption or decryption context. Must return >0 on success, ≤0 on failure.
204 typedef int (enc_init_t)(EVP_PKEY_CTX *ctx);
205
206 // Encrypt or decrypt data. Must return >0 on success, ≤0 on failure.
207 typedef int (enc_process_t)(EVP_PKEY_CTX *ctx, unsigned char *out, size_t *outlen, const unsigned char *in, size_t inlen);
208
209 static bool rsa_encrypt_decrypt(rsa_t *rsa, const void *in, size_t len, void *out,
210                                 enc_init_t init, enc_process_t process) {
211         EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new(rsa, NULL);
212
213         if(ctx) {
214                 size_t outlen = len;
215
216                 bool ok = init(ctx) > 0
217                           && EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_NO_PADDING) > 0
218                           && process(ctx, out, &outlen, in, len) > 0
219                           && outlen == len;
220
221                 EVP_PKEY_CTX_free(ctx);
222
223                 if(ok) {
224                         return true;
225                 }
226         }
227
228         return false;
229 }
230 #endif
231
232 bool rsa_public_encrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
233 #if OPENSSL_VERSION_MAJOR < 3
234
235         if((size_t)RSA_public_encrypt((int) len, in, out, rsa, RSA_NO_PADDING) == len) {
236 #else
237
238         if(rsa_encrypt_decrypt(rsa, in, len, out, EVP_PKEY_encrypt_init, EVP_PKEY_encrypt)) {
239 #endif
240                 return true;
241         }
242
243         openssl_err("perform RSA encryption");
244         return false;
245 }
246
247 bool rsa_private_decrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
248 #if OPENSSL_VERSION_MAJOR < 3
249
250         if((size_t)RSA_private_decrypt((int) len, in, out, rsa, RSA_NO_PADDING) == len) {
251 #else
252
253         if(rsa_encrypt_decrypt(rsa, in, len, out, EVP_PKEY_decrypt_init, EVP_PKEY_decrypt)) {
254 #endif
255                 return true;
256         }
257
258         openssl_err("perform RSA decryption");
259         return false;
260 }
261
262 void rsa_free(rsa_t *rsa) {
263         if(rsa) {
264 #if OPENSSL_VERSION_MAJOR < 3
265                 RSA_free(rsa);
266 #else
267                 EVP_PKEY_free(rsa);
268 #endif
269         }
270 }