pd.if.org Git - zpackage/blob - crypto/parse_client_hello.c
allow partial package ids in packagehash
[zpackage] / crypto / parse_client_hello.c
1 #define _POSIX_C_SOURCE 200809L
2
3 #include <fcntl.h>
4 #include <stdint.h>
5 #include <stdio.h>
6 #include <stdlib.h>
7 #include <string.h>
8 #include <sys/stat.h>
9
10 #include "tlse.h"
11 #include "chacha.h"
12 #include "buffer.h"
13
14 static int tls_cipher_is_fs(struct TLSContext *context, unsigned short cipher) {
15         if (!context) {
16                 return 0;
17         }
18
19         if (context->tlsver == TLS_VERSION13) {
20                 switch (cipher) {
21                         case TLS_AES_128_GCM_SHA256:
22                         case TLS_AES_256_GCM_SHA384:
23                         case TLS_CHACHA20_POLY1305_SHA256:
24                                 return 1;
25                 }
26                 return 0;
27         }
28
29         switch (cipher) {
30                 case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA:
31                 case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:
32                 case TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256:
33                         if (context && context->certificates
34                                         && context->certificates_count
35                                         && context->ec_private_key) {
36                                 return 1;
37                         }
38                         return 0;
39                 case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256:
40                 case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384:
41                 case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
42                 case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384:
43                         if (context->version == TLS_V12
44                                         || context->version == DTLS_V12) {
45                                 if (context && context->certificates
46                                                 && context->certificates_count
47                                                 && context->ec_private_key) {
48                                         return 1;
49                                 }
50                         }
51                 return 0;
52         case TLS_DHE_RSA_WITH_AES_128_CBC_SHA:
53         case TLS_DHE_RSA_WITH_AES_256_CBC_SHA:
54         case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:
55         case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
56                 return 1;
57         case TLS_DHE_RSA_WITH_AES_128_CBC_SHA256:
58         case TLS_DHE_RSA_WITH_AES_256_CBC_SHA256:
59         case TLS_DHE_RSA_WITH_AES_128_GCM_SHA256:
60         case TLS_DHE_RSA_WITH_AES_256_GCM_SHA384:
61         case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
62         case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
63         case TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384:
64         case TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256:
65         case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256:
66                 if (context->tlsver == TLS_VERSION12) {
67                         return 1;
68                 }
69                 break;
70         }
71         return 0;
72 }
73
/*
 * Decode a 16-bit unsigned integer stored in network (big-endian) byte
 * order at buf[0] and buf[1].  The caller must guarantee that both bytes
 * are readable.
 */
static uint16_t get16(const unsigned char *buf) {
	uint16_t hi = buf[0];
	uint16_t lo = buf[1];

	return (uint16_t) ((hi << 8) | lo);
}
80
81 int tls_choose_cipher(struct TLSContext *context, const unsigned char *buf,
82                       int buf_len, int *scsv_set) {
83         int i;
84         if (scsv_set) {
85                 *scsv_set = 0;
86         }
87         if (!context) {
88                 return 0;
89         }
90         int selected_cipher = TLS_NO_COMMON_CIPHER;
91
92         if (selected_cipher == TLS_NO_COMMON_CIPHER) {
93                 for (i = 0; i < buf_len; i += 2) {
94                         uint16_t cipher = get16(&buf[i]);
95                         if (tls_cipher_is_fs(context, cipher)) {
96                                 selected_cipher = cipher;
97                                 break;
98                         }
99                 }
100         }
101
102         for (i = 0; i < buf_len; i += 2) {
103                 uint16_t cipher = get16(&buf[i]);
104                 if (cipher == TLS_FALLBACK_SCSV) {
105                         if (scsv_set) {
106                                 *scsv_set = 1;
107                         }
108                         if (selected_cipher != TLS_NO_COMMON_CIPHER) {
109                                 break;
110                         }
111                 }
112         }
113         return selected_cipher;
114 }
115
116 int tls_parse_client_hello(struct TLSContext *context,
117                 const unsigned char *buf, int buf_len,
118                 unsigned int *write_packets) {
119         *write_packets = 0;
120         if (context->connection_status != 0 && context->connection_status != 4) {
121                 DEBUG_PRINT("UNEXPECTED HELLO MESSAGE\n");
122                 return TLS_UNEXPECTED_MESSAGE;
123         }
124
125         int res = 0;
126         int downgraded = 0;
127
128         int hello_min_size = TLS_CLIENT_HELLO_MINSIZE;
129
130         if (buf_len < hello_min_size) {
131                 return TLS_NEED_MORE_DATA;
132         }
133
134         /* big endian */
135         int bytes_to_follow = buf[0] * 0x10000 + buf[1] * 0x100 + buf[2];
136         res += 3;
137         if (buf_len - res < bytes_to_follow) {
138                 return TLS_NEED_MORE_DATA;
139         }
140
141         if (buf_len - res < 2) {
142                 return TLS_NEED_MORE_DATA;
143         }
144
145         unsigned short version = get16(&buf[res]);
146
147         res += 2;
148         if (!tls_supported_version(version)) {
149                 return TLS_NOT_SAFE;
150         }
151         DEBUG_PRINT("VERSION REQUIRED BY REMOTE %x, VERSION NOW %x\n",
152                         (int) version, (int) context->version);
153         memcpy(context->remote_random, &buf[res], TLS_CLIENT_RANDOM_SIZE);
154         res += TLS_CLIENT_RANDOM_SIZE;
155
156         unsigned char session_len = buf[res++];
157         if (buf_len - res < session_len) return TLS_NEED_MORE_DATA;
158
159         if (session_len && session_len <= TLS_MAX_SESSION_ID) {
160                 memcpy(context->session, &buf[res], session_len);
161                 context->session_size = session_len;
162                 DEBUG_DUMP_HEX_LABEL("REMOTE SESSION ID: ",
163                                 context->session,
164                                 context->session_size);
165         } else {
166                 context->session_size = 0;
167         }
168         res += session_len;
169
170         const unsigned char *cipher_buffer = NULL;
171         unsigned short cipher_len = 0;
172         int scsv_set = 0;
173         if (context->is_server) {
174                 if (buf_len - res < 2) {
175                         return TLS_NEED_MORE_DATA;
176                 }
177
178                 cipher_len = get16(&buf[res]);
179                 res += 2;
180                 if (buf_len - res < cipher_len) return TLS_NEED_MORE_DATA;
181                 /* faster than cipher_len % 2 */
182                 /* TODO unlikely, let the compiler worry about it */
183                 if (cipher_len & 1) {
184                         return TLS_BROKEN_PACKET;
185                 }
186
187                 cipher_buffer = &buf[res];
188                 res += cipher_len;
189
190                 if (buf_len - res < 1) {
191                         return TLS_NEED_MORE_DATA;
192                 }
193
194                 unsigned char compression_list_size = buf[res++];
195                 if (buf_len - res < compression_list_size) {
196                         return TLS_NEED_MORE_DATA;
197                 }
198                 /* no compression support */
199                 res += compression_list_size;
200         } else {
201                 /* client */
202                 if (buf_len - res < 2) {
203                         return TLS_NEED_MORE_DATA;
204                 }
205
206                 unsigned short cipher = get16(&buf[res]);
207                 res += 2;
208                 context->cipher = cipher;
209                 if (!tls_cipher_supported(context, cipher)) {
210                         context->cipher = 0;
211                         DEBUG_PRINT("NO CIPHER SUPPORTED\n");
212                         return TLS_NO_COMMON_CIPHER;
213                 }
214                 DEBUG_PRINT("CIPHER: %s\n", tls_cipher_name(context));
215                 if (buf_len - res < 1) return TLS_NEED_MORE_DATA;
216                 unsigned char compression = buf[res++];
217                 if (compression != 0) {
218                         DEBUG_PRINT("COMPRESSION NOT SUPPORTED\n");
219                         return TLS_COMPRESSION_NOT_SUPPORTED;
220                 }
221         }
222
223         if (res > 0) {
224                 if (context->is_server) {
225                         *write_packets = 2;
226                 }
227                 if (context->connection_status != 4) {
228                         context->connection_status = 1;
229                 }
230         }
231
232
233         if (res > 2) {
234                 res += 2;
235         }
236         const unsigned char *key_share = NULL;
237         unsigned short key_size = 0;
238         while (buf_len - res >= 4) {
239                 /* have extensions */
240                 unsigned short extension_type = get16(&buf[res]);
241                 res += 2;
242                 unsigned short extension_len = get16(&buf[res]);
243                 res += 2;
244                 DEBUG_PRINT("Extension: 0x0%x (%i), len: %i\n",
245                             (int) extension_type, (int) extension_type,
246                             (int) extension_len);
247                 if (extension_len) {
248                         /* SNI extension */
249                         if (buf_len - res < extension_len) {
250                                 return TLS_NEED_MORE_DATA;
251                         }
252                         if (extension_type == 0x00) {
253                                 /* unsigned char sni_type = buf[res + 2]; */
254                                 uint16_t sni_host_len = get16(&buf[res+3]);
255                                 if (buf_len - res - 5 < sni_host_len) {
256                                         return TLS_NEED_MORE_DATA;
257                                 }
258
259                                 if (sni_host_len) {
260                                         free(context->sni);
261                                         context->sni = malloc(sni_host_len + 1);
262                                         if (context->sni) {
263                                                 memcpy(context->sni,
264                                                                 &buf[res + 5],
265                                                                 sni_host_len);
266                                                 context->sni[sni_host_len] = 0;
267                                                 DEBUG_PRINT("SNI HOST INDICATOR: [%s]\n", context->sni);
268                                         }
269                                 }
270                         } else if (extension_type == 0x0A) {
271                                 /* supported groups */
272                                 if (buf_len - res > 2) {
273                                         uint16_t group_len = get16(&buf[res]);
274                                         if (buf_len - res >= group_len + 2) {
275                                                 DEBUG_DUMP_HEX_LABEL
276                                                     ("SUPPORTED GROUPS",
277                                                      &buf[res + 2],
278                                                      group_len);
279                                                 int i;
280                                                 int selected = 0;
281                                                 for (i = 0; i < group_len;
282                                                      i += 2) {
283                                                         uint16_t iana_n = get16(&buf[res + 2 + i]);
284                                                         switch (iana_n) {
285                                                         case 23:
286                                                                 context->
287                                                                     curve =
288                                                                     &secp256r1;
289                                                                 selected =
290                                                                     1;
291                                                                 break;
292                                                         case 24:
293                                                                 context->
294                                                                     curve =
295                                                                     &secp384r1;
296                                                                 selected =
297                                                                     1;
298                                                                 break;
299 #if 0
300                                                                 /* needs different implementation */
301                                                                 case 29:
302                                                                 context->curve = &curve25519;
303                                                                 selected = 1;
304                                                                 break;
305 #endif
306 #if 0
307                                                                 /* do not use it anymore */
308                                                                 case 25:
309                                                                 context->curve = &secp521r1;
310                                                                 selected = 1;
311                                                                 break;
312 #endif
313                                                         }
314                                                         if (selected) {
315                                                                 DEBUG_PRINT
316                                                                     ("SELECTED CURVE %s\n",
317                                                                      context->
318                                                                      curve->
319                                                                      name);
320                                                                 break;
321                                                         }
322                                                 }
323                                         }
324                                 }
325                         } else
326                                 if (extension_type == 0x10 && context->alpn &&
327                                                 context->alpn_count) {
328                                 if (buf_len - res > 2) {
329                                         uint16_t alpn_len = get16(&buf[res]);
330                                         if (alpn_len && alpn_len <=
331                                                         extension_len - 2) {
332                                                 unsigned char *alpn =
333                                                     (unsigned char *)
334                                                     &buf[res + 2];
335                                                 int alpn_pos = 0;
336                                                 while (alpn_pos < alpn_len) {
337                                                         unsigned char
338                                                          alpn_size =
339                                                             alpn
340                                                             [alpn_pos++];
341                                                         if (alpn_size +
342                                                             alpn_pos >=
343                                                             extension_len)
344                                                                 break;
345                                                         if ((alpn_size)
346                                                             &&
347                                                             (tls_alpn_contains
348                                                              (context,
349                                                               (char *)
350                                                               &alpn
351                                                               [alpn_pos],
352                                                               alpn_size)))
353                                                         {
354                                                                 free
355                                                                     (context->
356                                                                      negotiated_alpn);
357                                                                 context->
358                                                                     negotiated_alpn
359                                                                     =
360                                                                     malloc
361                                                                     (alpn_size
362                                                                      + 1);
363                                                                 if (context->negotiated_alpn) {
364                                                                         memcpy
365                                                                             (context->
366                                                                              negotiated_alpn,
367                                                                              &alpn
368                                                                              [alpn_pos],
369                                                                              alpn_size);
370                                                                         context->
371                                                                             negotiated_alpn
372                                                                             [alpn_size]
373                                                                             =
374                                                                             0;
375                                                                         DEBUG_PRINT
376                                                                             ("NEGOTIATED ALPN: %s\n",
377                                                                              context->
378                                                                              negotiated_alpn);
379                                                                 }
380                                                                 break;
381                                                         }
382                                                         alpn_pos +=
383                                                             alpn_size;
384                                                         /* ServerHello contains just one alpn */
385                                                         if (!context->
386                                                             is_server)
387                                                                 break;
388                                                 }
389                                         }
390                                 }
391                         } else if (extension_type == 0x0D) {
392                                 /* supported signatures */
393                                 DEBUG_DUMP_HEX_LABEL
394                                     ("SUPPORTED SIGNATURES", &buf[res],
395                                      extension_len);
396                         } else if (extension_type == 0x0B) {
397                                 /* supported point formats */
398                                 DEBUG_DUMP_HEX_LABEL
399                                     ("SUPPORTED POINT FORMATS", &buf[res],
400                                      extension_len);
401                         }
402                         else if (extension_type == 0x2B) {
403                                 /* supported versions */
404                                 if ((buf[res] == extension_len - 1)
405                                     && (extension_len > 4)) {
406                                         DEBUG_DUMP_HEX_LABEL
407                                             ("SUPPORTED VERSIONS",
408                                              &buf[res], extension_len);
409                                         /* tls 1.3 draft version 28 */
410                                         int i;
411                                         int limit = (int) buf[res];
412                                         if (limit == extension_len - 1) {
413                                                 limit--;
414                                                 for (i = 1; i < limit;
415                                                      i += 2) {
416                                                         if ((get16(&buf[res + i]) == 0x7F1C)
417                                                             ||
418                                                             (get16(&buf[res + i]) == TLS_V13)) {
419                                                                 context->
420                                                                     version
421                                                                     =
422                                                                     TLS_V13;
423                                                                 context->
424                                                                     tls13_version
425                                                                     =
426                                                                     get16(&buf[res + i]);
427                                                                 DEBUG_PRINT
428                                                                     ("TLS 1.3 SUPPORTED\n");
429                                                                 break;
430                                                         }
431                                                 }
432                                         }
433                                 }
434                         } else if (extension_type == 0x2A) {
435                                 /* early data */
436                                 DEBUG_DUMP_HEX_LABEL
437                                     ("EXTENSION, EARLY DATA", &buf[res],
438                                      extension_len);
439                         } else if (extension_type == 0x29) {
440                                 /* pre shared key */
441                                 DEBUG_DUMP_HEX_LABEL
442                                     ("EXTENSION, PRE SHARED KEY",
443                                      &buf[res], extension_len);
444                         } else if (extension_type == 0x33) {
445                                 /* key share */
446                                 key_size = get16(&buf[res]);
447                                 if (key_size > extension_len - 2) {
448                                         DEBUG_PRINT("BROKEN KEY SHARE\n");
449                                         return TLS_BROKEN_PACKET;
450                                 }
451                                 DEBUG_DUMP_HEX_LABEL
452                                     ("EXTENSION, KEY SHARE", &buf[res],
453                                      extension_len);
454                                 key_share = &buf[res + 2];
455                         } else if (extension_type == 0x0D) {
456                                 /* signature algorithms */
457                                 DEBUG_DUMP_HEX_LABEL
458                                     ("EXTENSION, SIGNATURE ALGORITHMS",
459                                      &buf[res], extension_len);
460                         } else if (extension_type == 0x2D) {
461                                 /* psk key exchange modes */
462                                 DEBUG_DUMP_HEX_LABEL
463                                     ("EXTENSION, PSK KEY EXCHANGE MODES",
464                                      &buf[res], extension_len);
465                         }
466                         res += extension_len;
467                 }
468         }
469
470         if (buf_len != res) {
471                 return TLS_NEED_MORE_DATA;
472         }
473
474         if (context->is_server && cipher_buffer && cipher_len) {
475                 int cipher =
476                     tls_choose_cipher(context, cipher_buffer, cipher_len,
477                                       &scsv_set);
478                 if (cipher < 0) {
479                         DEBUG_PRINT("NO COMMON CIPHERS\n");
480                         return cipher;
481                 }
482                 if (downgraded && scsv_set) {
483                         DEBUG_PRINT("NO DOWNGRADE (SCSV SET)\n");
484                         tls_alert(context, 1, inappropriate_fallback);
485                         context->critical_error = 1;
486                         return TLS_NOT_SAFE;
487                 }
488                 context->cipher = cipher;
489         }
490         if (key_share && key_size && context->tlsver == TLS_VERSION13) {
491                 int key_share_err =
492                     tls_parse_key_share(context, key_share, key_size);
493                 if (key_share_err) {
494                         /* request hello retry */
495                         if (context->connection_status != 4) {
496                                 *write_packets = 5;
497                                 context->hs_messages[1] = 0;
498                                 context->connection_status = 4;
499                                 return res;
500                         } else
501                                 return key_share_err;
502                 }
503                 /* we have key share */
504                 context->connection_status = 3;
505         }
506         return res;
507 }