pd.if.org Git - zpackage/blob - crypto/handshake.c
remove stray debug fprintf
[zpackage] / crypto / handshake.c
1 #define _POSIX_C_SOURCE 200809L
2
3 #include <arpa/inet.h>
4
5 #include "tlse.h"
6 #include "buffer.h"
7
/*
 * Bit flags selecting which cipher suite families add_cipher_suites()
 * offers.  Each flag is a distinct bit so they can be OR-ed together.
 */
#define TLS12_FLAG 0x01 /* offer the TLS 1.2 cipher suites */
#define TLS13_FLAG 0x02 /* offer the TLS 1.3 cipher suites */
10
11 static unsigned char *encrypt_rsa(struct TLSContext *context,
12                                         const unsigned char *buffer,
13                                         unsigned int len,
14                                         unsigned int *size) {
15         *size = 0;
16         if (!len || !context || !context->certificates
17             || !context->certificates_count
18             || !context->certificates[0]
19             || !context->certificates[0]->der_bytes
20             || !context->certificates[0]->der_len) {
21                 DEBUG_PRINT("No certificate set\n");
22                 return NULL;
23         }
24         rsa_key key;
25         int err;
26         err = rsa_import(context->certificates[0]->der_bytes,
27                         context->certificates[0]->der_len, &key);
28
29         if (err) {
30                 DEBUG_PRINT("Error importing RSA certificate (code: %i)\n",
31                             err);
32                 return NULL;
33         }
34         unsigned long out_size = TLS_MAX_RSA_KEY;
35         unsigned char *out = malloc(out_size);
36         int hash_idx = find_hash("sha256");
37         int prng_idx = find_prng("sprng");
38         err = rsa_encrypt_key_ex(buffer, len, out, &out_size, (unsigned char *)
39                         "Concept", 7, NULL, prng_idx, hash_idx,
40                         LTC_PKCS_1_V1_5, &key);
41         rsa_free(&key);
42         if (err || !out_size) {
43                 free(out);
44                 return NULL;
45         }
46         *size = (unsigned int) out_size;
47         return out;
48 }
49
50
/*
 * Append the supported_versions extension (type 0x002b) for a ClientHello.
 *
 * `versions` selects what is advertised: 1 = TLS 1.2 only, 2 = TLS 1.3
 * only, 3 = both (TLS 1.3 listed first, so it is preferred).  Any other
 * value appends nothing.
 */
void add_supported_versions(struct tls_buffer *buf, int versions) {
	/* extension type, extension length, list length in bytes, versions */
	char version12[] = { 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x03 };
	char version13[] = { 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04 };
	char both[] = { 0x00, 0x2b, 0x00, 0x05, 0x04, 0x03, 0x04, 0x03, 0x03 };
	char *use;
	size_t size;

	switch (versions) {
		case 1:
			use = version12;
			size = sizeof version12;
			break;
		case 2:
			use = version13;
			size = sizeof version13;
			break;
		case 3:
			use = both;
			size = sizeof both;
			break;
		default:
			/* unknown selector: nothing valid to advertise, and
			 * use/size would otherwise be read uninitialised */
			return;
	}

	tls_buffer_append(buf, use, size);
}
75
/*
 * Append the server_name extension (RFC 6066, type 0x0000) carrying a
 * single host_name entry for `sni`.
 *
 * Nothing is appended when buf or sni is NULL, when sni is empty (a
 * HostName must be at least one byte long), or when sni is too long for
 * the 16-bit length fields of the extension.
 */
static void add_sni_extension(struct tls_buffer *buf, char *sni) {
	size_t len;

	if (!buf || !sni) {
		return;
	}

	len = strlen(sni);

	/* the extension data is len + 5 bytes and must fit in 16 bits */
	if (len == 0 || len > 0xffff - 5) {
		return;
	}

	/* server name extension id = 0x00 0x00 */
	tls_buffer_append16(buf, 0x0000);
	/* length of server name extension */
	tls_buffer_append16(buf, len + 5);
	/* length of the server name list */
	tls_buffer_append16(buf, len + 3);
	/* name type: DNS hostname */
	tls_buffer_append_byte(buf, 0x00);
	/* length of the hostname */
	tls_buffer_append16(buf, len);
	/* the hostname itself, without a terminating NUL */
	tls_buffer_append(buf, sni, len);
}
99
100
/*
 * Append the ClientHello cipher_suites vector: a 2-byte length in bytes
 * followed by 2-byte cipher suite identifiers.
 *
 * `suites` is a bit mask: bit 0 (value 1) adds the TLS 1.2 suites and
 * bit 1 (value 2) adds the TLS 1.3 suites.  When both are requested the
 * TLS 1.3 suites are written first so that they are preferred.
 */
static void add_cipher_suites(struct tls_buffer *buf, int suites) {
	/* the five TLS 1.3 cipher suites in B.4 of rfc 8446, chacha20 preferred */
	unsigned char tls_13_suites[] = {
		0x13, 0x03, /* TLS_CHACHA20_POLY1305_SHA256 */
		0x13, 0x01, /* TLS_AES_128_GCM_SHA256 */
		0x13, 0x02, /* TLS_AES_256_GCM_SHA384 */
		0x13, 0x04, /* TLS_AES_128_CCM_SHA256 */
		0x13, 0x05  /* TLS_AES_128_CCM_8_SHA256 */
	};
	unsigned char tls_12_suites[] = {
		0xcc, 0xa9, /* TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 */
		0xc0, 0x2b, /* TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 */
		0xc0, 0x23, /* TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 */
		0xcc, 0xa8, /* TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 */
		0xc0, 0x2f, /* TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 */
		0xc0, 0x27, /* TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 */
		0x00, 0x9e, /* TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 */
		0x00, 0x6b, /* TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 */
		0x00, 0x67, /* TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 */
		0xcc, 0xaa  /* TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 */
	};
	size_t len = 0;

	if (suites & 1) {
		len += sizeof tls_12_suites;
	}

	if (suites & 2) {
		len += sizeof tls_13_suites;
	}

	tls_buffer_expand(buf, len + 2);
	tls_buffer_append16(buf, len);
	/* if we're including 1.3 ciphers, put them first so they're preferred */
	if (suites & 2) {
		tls_buffer_append(buf, tls_13_suites, sizeof tls_13_suites);
	}
	if (suites & 1) {
		tls_buffer_append(buf, tls_12_suites, sizeof tls_12_suites);
	}
}
139
/*
 * Append an empty signed_certificate_timestamp extension: type 0x0012
 * followed by a zero extension-data length.
 */
void add_signed_certificate_timestamp_extension(struct tls_buffer *buf) {
	char ext[4];

	ext[0] = 0x00; /* extension type, high byte */
	ext[1] = 0x12; /* extension type, low byte */
	ext[2] = 0x00; /* extension data length, high byte */
	ext[3] = 0x00; /* extension data length, low byte */

	tls_buffer_append(buf, ext, sizeof ext);
}
144
/*
 * Append a status_request extension asking for OCSP stapling, with an
 * empty responder ID list and no request extensions.  Wire layout:
 *
 *   00 05  extension type "status request"
 *   00 05  5 bytes of extension data follow
 *   01     certificate status type: OCSP
 *   00 00  0 bytes of responder ID information
 *   00 00  0 bytes of request extension information
 */
void add_status_request_extension(struct tls_buffer *buf) {
	char ext[] = {
		0x00, 0x05,     /* extension type */
		0x00, 0x05,     /* extension data length */
		0x01,           /* status type: OCSP */
		0x00, 0x00,     /* responder_id_list length */
		0x00, 0x00      /* request_extensions length */
	};

	tls_buffer_append(buf, ext, sizeof ext);
}
156
/*
 * Append the supported_groups extension (type 0x000a) listing the named
 * curves offered: secp256r1, secp384r1 and secp521r1.
 *
 * x25519 (0x001d) is not offered yet; adding it would also require
 * increasing both length fields below by 2.
 */
void add_supported_groups_extension(struct tls_buffer *buf) {
	unsigned char ext[] = {
		0x00, 0x0a,     /* extension type */
		0x00, 0x08,     /* extension data length */
		0x00, 0x06,     /* named group list length: three 2-byte groups */
		0x00, 0x17,     /* secp256r1 */
		0x00, 0x18,     /* secp384r1 */
		0x00, 0x19      /* secp521r1 */
	};

	tls_buffer_append(buf, ext, sizeof ext);
}
174
/*
 * Append the ec_point_formats extension, offering only the uncompressed
 * point format.  Wire layout:
 *
 *   00 0b  extension type "EC points format"
 *   00 02  2 bytes of extension data follow
 *   01     1 byte of data in the supported formats list
 *   00     uncompressed form
 */
void add_ec_point_formats_extension(struct tls_buffer *buf) {
	char ext[] = {
		0x00, 0x0b,     /* extension type */
		0x00, 0x02,     /* extension data length */
		0x01,           /* format list length */
		0x00            /* uncompressed */
	};

	tls_buffer_append(buf, ext, sizeof ext);
}
185
/*
 * Append the signature_algorithms extension (type 0x000d) listing six
 * SignatureScheme values: RSA PKCS#1 v1.5 and ECDSA with SHA-256, SHA-384
 * and SHA-512.
 *
 * NOTE(review): RFC 8446 section 4.2.3 restricts rsa_pkcs1_* to
 * certificate signatures in TLS 1.3.  Handshake signatures from RSA
 * servers need rsa_pss_rsae_sha256/384/512 (0x0804, 0x0805, 0x0806),
 * which are not offered yet.  TODO: add them (and perhaps ed25519, 0x0807)
 * once the library can verify them.
 */
void add_signature_algorithms_extension(struct tls_buffer *buf) {
	char algorithms[] = {
		0x00, 0x0d, /* extension type */
		0x00, 0x0e, /* extension data length: 14 bytes */
		0x00, 0x0c, /* scheme list length: six 2-byte schemes */
		0x04, 0x01, /* rsa_pkcs1_sha256 */
		0x04, 0x03, /* ecdsa_secp256r1_sha256 */
		0x05, 0x01, /* rsa_pkcs1_sha384 */
		0x05, 0x03, /* ecdsa_secp384r1_sha384 */
		0x06, 0x01, /* rsa_pkcs1_sha512 */
		0x06, 0x03  /* ecdsa_secp521r1_sha512 */
	};

	tls_buffer_append(buf, algorithms, sizeof algorithms);
}
214
/*
 * Append an empty renegotiation_info extension (RFC 5746, type 0xff01):
 * 2-byte type, 2-byte data length of 1, and a single zero byte giving an
 * empty renegotiated_connection field, as sent on an initial handshake.
 */
static void add_renegotiation_info_extension(struct tls_buffer *buf) {
	/* unsigned char: 0xff does not fit in a (possibly signed) plain char */
	unsigned char info[] = { 0xff, 0x01, 0x00, 0x01, 0x00 };

	tls_buffer_append(buf, info, sizeof info);
}
221
/*
 * Append the key_share extension (RFC 8446 section 4.2.8, type 0x0033).
 *
 * No ephemeral client key is generated yet.  Rather than a fake key, this
 * sends an empty client_shares vector, which RFC 8446 explicitly permits:
 *
 *   00 33  assigned value for extension "Key Share"
 *   00 02  2 bytes of "Key Share" extension data follow
 *   00 00  0 bytes of KeyShareEntry data follow
 *
 * A server receiving this is expected to pick a group from
 * supported_groups and answer with a HelloRetryRequest.  The previous
 * placeholder (an all-zero x25519 key) had two problems: x25519 is not
 * offered in supported_groups, and an all-zero key is one that peers are
 * required to reject.
 *
 * Does nothing if ctx is NULL.
 */
static void add_key_share_extension(struct tls_buffer *buf, struct TLSContext
		*ctx) {
	char kse[] = { 0x00, 0x33, 0x00, 0x02, 0x00, 0x00 };

	if (!ctx) {
		return;
	}

	/* TODO generate a real ephemeral key for a group listed in
	 * supported_groups and send its KeyShareEntry here */
	tls_buffer_append(buf, kse, sizeof kse);
}
244
/*
 * Append the psk_key_exchange_modes extension offering only
 * "PSK with (EC)DHE key establishment".  Wire layout:
 *
 *   00 2d  extension type "PSK Key Exchange Modes"
 *   00 02  2 bytes of extension data follow
 *   01     1 byte of exchange modes follows
 *   01     PSK with (EC)DHE key establishment
 *
 * No keys are actually pre-shared here; the extension is sent anyway.
 * TODO probably need to get these from the context
 */
static void add_pks_key_exchanges_modes_extension(struct tls_buffer *buf) {
	char modes[] = {
		0x00, 0x2d,     /* extension type */
		0x00, 0x02,     /* extension data length */
		0x01,           /* number of modes */
		0x01            /* psk_dhe_ke */
	};

	tls_buffer_append(buf, modes, sizeof modes);
}
260
/*
 * Fill the 4-byte handshake header at buf: one byte of message type,
 * then the body length as a 24-bit big-endian integer.  Only the low
 * 8 bits of type and the low 24 bits of length are stored.
 */
static void set_handshake_header(char *buf, int type, size_t length) {
	unsigned char *out = (unsigned char *) buf;
	size_t i;

	out[0] = (unsigned char) (type & 0xff);
	for (i = 1; i <= 3; i++) {
		/* i = 1, 2, 3 -> shifts of 16, 8, 0 */
		out[i] = (unsigned char) ((length >> (8 * (3 - i))) & 0xff);
	}
}
267
268 int tls_client_hello(struct TLSContext *ctx, struct tls_buffer *hello) {
269         size_t hello_offset = hello->len;
270
271         /* make room for the handshake header */
272         tls_buffer_expand(hello, 4);
273         hello->len += 4;
274
275         /* actual client hello structure follows */
276
277         tls_buffer_append(hello, "\x03\x03", 2); /* legacy_version */
278         /* random not set up yet */
279         //tls_random(ctx->local_random, 32);
280         tls_buffer_append(hello, ctx->local_random, 32); /* client random */
281         tls_buffer_append(hello, "\0", 1); /* legacy_session_id */
282         /* alternatively, append a 32 (0x20) and 32 random bytes as a bogus
283          * session id */
284
285         /*
286          * cipher suites
287          * TODO need a way to only use v1.3
288          */
289         int suites = TLS12_FLAG; /* always use v1.2 suites */
290
291         if (ctx->tlsver == TLS_VERSION13) {
292                 suites |= TLS13_FLAG;
293         }
294         add_cipher_suites(hello, suites);
295
296         /* legacy_compression_methods */
297         tls_buffer_append(hello, "\1\0", 2);
298         
299         /* 
300          * extensions
301          * TODO I don't think the extension order matters, so the code below
302          * can be simplified by putting all the extensions together by version
303          */
304         size_t extensions_start = hello->len;
305         /* first two bytes are length of extensions, so make room to fill them
306          * in once we know the size
307          */
308         tls_buffer_append(hello, "\0\0", 2);
309
310         /* TODO need to track which extensions we're sending:
311          * "If a client receives an extension type in ServerHello that it did
312          * not request in the associated ClientHello, it MUST abort the
313          * handshake with an unsupported_extension fatal alert."
314          */
315         add_sni_extension(hello, ctx->sni); /* server name indicator */
316
317 #if 0
318         /* TODO not sure why 1.3 doesn't need or want this */
319         /* TODO duckduckgo.com seems to fail with this one */
320         if (ctx->tlsver == TLS_VERSION12) {
321                 add_status_request_extension(hello);
322         }
323 #endif
324
325         add_supported_groups_extension(hello);
326
327         /* v1.2 only, points are fixed in v1.3 */
328         if (ctx->tlsver == TLS_VERSION12) {
329                 add_ec_point_formats_extension(hello);
330         }
331
332         add_signature_algorithms_extension(hello);
333
334         if (ctx->tlsver == TLS_VERSION13) {
335                 add_key_share_extension(hello, ctx);
336                 add_pks_key_exchanges_modes_extension(hello);
337         }
338
339         /* v1.2 only, 1.3 doesn't support renegotiation
340          * and doesn't seem to need cert ts
341          */
342         if (ctx->tlsver == TLS_VERSION12) {
343                 add_renegotiation_info_extension(hello);
344                 add_signed_certificate_timestamp_extension(hello);
345         }
346
347         if (ctx->tlsver == TLS_VERSION13) {
348                 /* supported versions is mandatory in V1.3 */
349                 /* 1 v1.2 only, 2 = v1.3 only, 3 = both */
350                 /* TODO need a context flag to allow fallback to v1.2 */
351                 /* could probably pass this in v1.2 and the server
352                  * would ignore it */
353                 add_supported_versions(hello, 3);
354         }
355
356         /* set the extensions length */
357         size_t extensions_length = hello->len - extensions_start - 2;
358         tls_buffer_write16(hello, extensions_length, extensions_start);
359
360         /* fill in the handshake header */
361         size_t hello_length = hello->len - hello_offset - 4;
362         set_handshake_header(hello->buffer+hello_offset, client_hello,
363                         hello_length);
364
365         tls_buffer_compact(hello);
366         return hello->error;
367 }
368
#if 0
/*
 * Debug helper (compiled out): hex-dump len bytes of b to stderr under a
 * heading, 20 space-separated bytes per line.
 */
void pbytes(unsigned char *b, size_t len, char *label) {
	size_t n;

	fprintf(stderr, "%s (%zu bytes)\n", label ? label : "dumping", len);

	for (n = 0; n < len; n++) {
		const char *sep = (n % 20 == 0) ? "" : " ";
		const char *eol = ((n + 1) % 20 == 0) ? "\n" : "";

		fprintf(stderr, "%s%02x%s", sep, b[n], eol);
	}
	if (n % 20 != 0) {
		fputc('\n', stderr);
	}
}
#endif
387
388 struct TLSPacket *tls_build_client_hello(struct TLSContext *context) {
389         if (context->connection_status == 4) {
390                 unsigned char header[4] = { 0xFE, 0, 0, 0 };
391                 unsigned char hash[TLS_MAX_SHA_SIZE];
392                 static unsigned char sha256_helloretryrequest[] =
393                     { 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE,
394                         0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2,
395                         0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07,
396                         0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C
397                 };
398                 fprintf(stderr, "got hello retry request\n");
399                 memcpy(context->local_random, sha256_helloretryrequest, 32);
400                 int hash_len = tls_done_hash(context, hash);
401                 header[3] = (unsigned char) hash_len;
402                 tls_update_hash(context, header, sizeof header);
403                 tls_update_hash(context, hash, hash_len);
404         } else if (context->tlsver != TLS_VERSION13) {
405                 //fprintf(stderr, "creating local_random\n");
406                 if (!tls_random(context->local_random, TLS_SERVER_RANDOM_SIZE)) {
407                         return NULL;
408                 }
409         }
410
411         struct tls_buffer shadow;
412         char record_header[] = { 0x16, 0x03, 0x03, 0x00, 0x00 };
413
414         tls_buffer_init(&shadow, 106);
415         tls_buffer_append(&shadow, record_header, sizeof record_header);
416         tls_client_hello(context, &shadow);
417         tls_buffer_writebe(&shadow, 3, 6, shadow.len - 9);
418
419         if (shadow.error) {
420                 tls_buffer_free(&shadow);
421                 return NULL;
422         }
423
424         struct TLSPacket *packet = malloc(sizeof *packet);
425
426         if (!packet) {
427                 return NULL;
428         }
429
430         packet->buf = shadow.buffer;
431         packet->len = shadow.len;
432         packet->size = shadow.size;
433         packet->payload_pos = 0;
434         packet->broken = 0;
435         packet->context = context;
436
437         tls_packet_update(packet);
438
439         return packet;
440 }
441
/*
 * Placeholder for sending the ClientHello.  Nothing is sent yet; it only
 * reports 1 when given a context and 0 when ctx is NULL.
 */
int tls_send_client_hello(struct TLSContext *ctx) {
	if (ctx == NULL) {
		return 0;
	}
	return 1;
}
445
446 struct TLSPacket *tls_build_hello(struct TLSContext *context,
447                                   int tls13_downgrade) {
448         if (context->connection_status == 4) {
449                 unsigned char header[4] = { 0xFE, 0, 0, 0 };
450                 unsigned char hash[TLS_MAX_SHA_SIZE];
451                 static unsigned char sha256_helloretryrequest[] =
452                     { 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE,
453                         0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2,
454                         0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07,
455                         0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C
456                 };
457                 fprintf(stderr, "got hello retry request\n");
458                 memcpy(context->local_random, sha256_helloretryrequest, 32);
459                 int hash_len = tls_done_hash(context, hash);
460                 header[3] = (unsigned char) hash_len;
461                 tls_update_hash(context, header, sizeof header);
462                 tls_update_hash(context, hash, hash_len);
463         } else if (!context->is_server || context->tlsver != TLS_VERSION13) {
464                 fprintf(stderr, "creating local_random\n");
465                 if (!tls_random(context->local_random, TLS_SERVER_RANDOM_SIZE)) {
466                         return NULL;
467                 }
468         }
469
470         if (context->is_server && tls13_downgrade) {
471                 if (tls13_downgrade == TLS_V12 || tls13_downgrade == DTLS_V12)
472                 {
473                         memcpy(context->local_random +
474                                TLS_SERVER_RANDOM_SIZE - 8, "DOWNGRD\x01",
475                                8);
476                 } else {
477                         memcpy(context->local_random +
478                                TLS_SERVER_RANDOM_SIZE - 8, "DOWNGRD\x00",
479                                8);
480                 }
481         }
482
483         if (!context->is_server) {
484                 struct tls_buffer shadow;
485                 char record_header[] = { 0x16, 0x03, 0x03, 0x00, 0x00 };
486
487                 tls_buffer_init(&shadow, 106);
488                 tls_buffer_append(&shadow, record_header, sizeof record_header);
489                 tls_client_hello(context, &shadow);
490                 tls_buffer_writebe(&shadow, 3, 6, shadow.len - 9);
491
492                 if (shadow.error) {
493                         tls_buffer_free(&shadow);
494                         return NULL;
495                 }
496
497                 struct TLSPacket *packet = malloc(sizeof *packet);
498
499                 if (!packet) {
500                         return NULL;
501                 }
502
503                 free(packet->buf);
504                 packet->buf = shadow.buffer;
505                 packet->len = shadow.len;
506                 packet->size = shadow.size;
507                 packet->payload_pos = 0;
508                 packet->broken = 0;
509                 packet->context = context;
510
511                 tls_packet_update(packet);
512
513                 fprintf(stderr, "returning packet\n");
514                 return packet;
515         }
516
517         /* context must be server from here on out */
518
519         unsigned short packet_version = context->version;
520         unsigned short version = context->version;
521
522         if (context->version == TLS_V13) {
523                 version = TLS_V12;
524         } else if (context->version == DTLS_V13) {
525                 version = DTLS_V12;
526         }
527
528         struct TLSPacket *packet =
529                 tls_create_packet(context, TLS_HANDSHAKE, packet_version, 0);
530
531         /* hello */
532         tls_packet_uint8(packet, server_hello);
533
534         tls_packet_uint24(packet, 0);
535
536         int start_len = packet->len;
537         tls_packet_uint16(packet, version);
538
539         tls_packet_append(packet, context->local_random,
540                         TLS_SERVER_RANDOM_SIZE);
541
542         /* session size, always 0, we don't support sessions */
543         tls_packet_uint8(packet, 0);
544
545         int extension_len = 0;
546         int alpn_len = 0;
547         int alpn_negotiated_len = 0;
548         unsigned char shared_key[TLS_MAX_RSA_KEY];
549         unsigned long shared_key_len = TLS_MAX_RSA_KEY;
550         unsigned short shared_key_short = 0;
551         int selected_group = 0;
552         if (context->tlsver == TLS_VERSION13) {
553                 if (context->curve == &curve25519) {
554                         extension_len += 8 + 32;
555                         shared_key_short = (unsigned short) 32;
556                         if (context->finished_key) {
557                                 memcpy(shared_key,
558                                                 context->
559                                                 finished_key, 32);
560                                 free(context->finished_key);
561                                 context->finished_key = NULL;
562                         }
563                         selected_group = context->curve->iana;
564                         /* make context->curve NULL (x25519 is a different implementation) */
565                         context->curve = NULL;
566                 } else if (context->ecc_dhe) {
567                         if (ecc_ansi_x963_export
568                                         (context->ecc_dhe, shared_key,
569                                          &shared_key_len)) {
570                                 DEBUG_PRINT
571                                         ("Error exporting ECC DHE key\n");
572                                 tls_destroy_packet(packet);
573                                 tls_alert(context, 1, internal_error);
574                                 return NULL;
575                         }
576                         tls_ecc_dhe_free(context);
577                         extension_len += 8 + shared_key_len;
578                         shared_key_short =
579                                 (uint16_t)shared_key_len;
580                         if (context->curve) {
581                                 selected_group =
582                                         context->curve->iana;
583                         }
584                 } else if (context->dhe) {
585                         selected_group = context->dhe->iana;
586                         tls_dh_export_Y(shared_key,
587                                         &shared_key_len,
588                                         context->dhe);
589                         tls_dhe_free(context);
590                         extension_len += 8 + shared_key_len;
591                         shared_key_short = shared_key_len;
592                 }
593
594                 extension_len += 6;
595         }
596
597         if (context->negotiated_alpn && context->tlsver != TLS_VERSION13) {
598                 alpn_negotiated_len = strlen(context->negotiated_alpn);
599                 alpn_len = alpn_negotiated_len + 1;
600                 extension_len += alpn_len + 6;
601         }
602
603         /* ciphers */
604         /* fallback ... this should never happen */
605         if (!context->cipher) {
606                 context->cipher = TLS_DHE_RSA_WITH_AES_128_CBC_SHA;
607         }
608
609         tls_packet_uint16(packet, context->cipher);
610         /* no compression */
611         tls_packet_uint8(packet, 0);
612
613         if (context->tlsver == TLS_VERSION13) {
614                 /* supported versions */
615                 tls_packet_uint16(packet, 0x2B);
616
617                 tls_packet_uint16(packet, 2);
618                 if (context->version == TLS_V13) {
619                         tls_packet_uint16(packet,
620                                         context-> tls13_version ?
621                                         context-> tls13_version :
622                                         TLS_V13);
623                 } else {
624                         tls_packet_uint16(packet, context->version);
625                 }
626
627                 if (context->connection_status == 4) {
628                         /* fallback to the mandatory secp256r1 */
629                         tls_packet_uint16(packet, 0x33);
630                         tls_packet_uint16(packet, 2);
631                         tls_packet_uint16(packet, (uint16_t) secp256r1.iana);
632                 }
633
634                 if (shared_key_short && selected_group) {
635                         /* key share */
636                         tls_packet_uint16(packet, 0x33);
637                         tls_packet_uint16(packet, shared_key_short + 4);
638                         tls_packet_uint16(packet, selected_group);
639                         tls_packet_uint16(packet, shared_key_short);
640                         tls_packet_append(packet, (unsigned char *) shared_key,
641                                         shared_key_short);
642                 }
643         }
644
645         if (!packet->broken && packet->buf) {
646                 tls_set_packet_length(packet, packet->len - start_len);
647         }
648
649         tls_packet_update(packet);
650
651         return packet;
652 }
653
654 struct TLSPacket *tls_buffer_packet(struct tls_buffer *b, struct TLSContext *c) {
655         struct TLSPacket *p = 0;
656
657         if (b && c) {
658                 p = tls_create_packet(c, TLS_HANDSHAKE, c->version, 0);
659
660                 if (p) {
661                         free(p->buf);
662                         p->buf = b->buffer;
663                         p->size = b->size;
664                         p->len = b->len;
665                         p->payload_pos = 0;
666                         p->broken = 0;
667                         p->context = c;
668                 } else {
669                         tls_buffer_free(b);
670                 }
671         }
672
673         return p;
674 }
675
676 static void append_dhe(struct TLSContext *ctx, struct tls_buffer *buf) {
677         unsigned char dh_Ys[0xFFF];
678         unsigned char dh_p[0xFFF];
679         unsigned char dh_g[0xFFF];
680         unsigned long dh_p_len = sizeof dh_p;
681         unsigned long dh_g_len = sizeof dh_g;
682         unsigned long dh_Ys_len = sizeof dh_Ys;
683
684         ENTER;
685         if (tls_dh_export_pqY(dh_p, &dh_p_len, dh_g, &dh_g_len, dh_Ys,
686                                 &dh_Ys_len, ctx->dhe)) {
687                 DEBUG_PRINT("ERROR EXPORTING DHE KEY %p\n", ctx->dhe);
688                 buf->error = 1;
689                 tls_dhe_free(ctx);
690                 LEAVE;;
691                 return;
692         }
693         tls_buffer_append_byte(buf, 0x10);
694
695         tls_dhe_free(ctx);
696
697         DEBUG_DUMP_HEX_LABEL("Yc", dh_Ys, dh_Ys_len);
698
699         tls_buffer_append24(buf, dh_Ys_len + 2);
700         tls_buffer_append16(buf, dh_Ys_len);
701         tls_buffer_append(buf, dh_Ys, dh_Ys_len);
702         LEAVE;
703 }
704
705 static void append_ecdhe(struct TLSContext *ctx, struct tls_buffer *buf) {
706         unsigned char out[TLS_MAX_RSA_KEY];
707         unsigned long out_len = TLS_MAX_RSA_KEY;
708
709         ENTER;
710
711         if (ecc_ansi_x963_export(ctx->ecc_dhe, out, &out_len)) {
712                 DEBUG_PRINT("Error exporting ECC key\n");
713                 buf->error = 1;
714                 LEAVE;;
715                 return;
716         }
717
718         tls_ecc_dhe_free(ctx);
719
720         tls_buffer_append_byte(buf, 0x10);
721         tls_buffer_append24(buf, out_len + 1);
722
723         tls_buffer_append_byte(buf, out_len);
724         tls_buffer_append(buf, out, out_len);
725         LEAVE;
726 }
727
728 static void set_record_size(struct tls_buffer *b) {
729         uint16_t size;
730
731         size = b->len - 5;
732         tls_buffer_write16(b, size, 3);
733 }
734
735 struct TLSPacket *tls_client_key_exchange(struct TLSContext *context) {
736         struct tls_buffer cke;
737         struct TLSPacket *p;
738
739         ENTER;
740         tls_buffer_init(&cke, 42);
741         tls_buffer_append_byte(&cke, 0x16);
742         tls_buffer_append16(&cke, 0x0303);
743         tls_buffer_append16(&cke, 0); /* record size placeholder */
744
745         if (context->ecc_dhe) {
746                 append_ecdhe(context, &cke);
747         } else {
748                 append_dhe(context, &cke);
749         }
750         set_record_size(&cke);
751
752         p = tls_buffer_packet(&cke, context);
753
754         tls_compute_key(context, 48);
755         context->connection_status = 2;
756         tls_packet_update(p);
757
758         LEAVE;
759         return p;
760 }
761
762 static int tls_build_random(struct TLSPacket *packet) {
763         int res = 0;
764         unsigned char rand_bytes[48];
765         int bytes = 48;
766
767         if (!tls_random(rand_bytes, bytes)) {
768                 return TLS_GENERIC_ERROR;
769         }
770
771         /* max supported version */
772         if (packet->context->is_server) {
773                 *(unsigned short *) rand_bytes =
774                     htons(packet->context->version);
775         } else {
776                 *(unsigned short *) rand_bytes = htons(TLS_V12);
777         }
778
779         /* DEBUG_DUMP_HEX_LABEL("PREMASTER KEY", rand_bytes, bytes); */
780
781         free(packet->context->premaster_key);
782
783         packet->context->premaster_key = malloc(bytes);
784         if (!packet->context->premaster_key) {
785                 return TLS_NO_MEMORY;
786         }
787
788         packet->context->premaster_key_len = bytes;
789         memcpy(packet->context->premaster_key, rand_bytes,
790                packet->context->premaster_key_len);
791
792         unsigned int out_len;
793
794         unsigned char *random = encrypt_rsa(packet->context,
795                         packet->context->premaster_key,
796                         packet->context->premaster_key_len, &out_len);
797
798         tls_compute_key(packet->context, bytes);
799         if (random && out_len > 2) {
800                 tls_packet_uint24(packet, out_len + 2);
801                 tls_packet_uint16(packet, out_len);
802                 tls_packet_append(packet, random, out_len);
803         } else {
804                 res = TLS_GENERIC_ERROR;
805         }
806
807         free(random);
808
809         if (res) {
810                 return res;
811         }
812
813         return out_len + 2;
814 }
815
816 void tls_send_client_key_exchange(struct TLSContext *context) {
817         struct TLSPacket *packet;
818
819         ENTER;
820         int ephemeral = tls_cipher_is_ephemeral(context);
821
822         if (ephemeral && context->premaster_key && context->premaster_key_len) {
823                 //fprintf(stderr, "YYYY\n");
824                 packet = tls_client_key_exchange(context);
825                 tls_queue_packet(packet);
826                 LEAVE;
827                 return;
828                 if (ephemeral == 1) {
829                         /* dhe */
830                 } else if (context->ecc_dhe) {
831                         /* ecc dhe */
832                 }
833         } else {
834                 /* TODO should never happen, should always require
835                  * either DHE or ECC DHE */
836                 fprintf(stderr, "ZZZZ build random\n");
837                 LEAVE;
838                 return;
839                 packet = tls_create_packet(context, TLS_HANDSHAKE, context->version, 0);
840                 tls_packet_uint8(packet, 0x10);
841                 tls_build_random(packet);
842         }
843         context->connection_status = 2;
844         tls_packet_update(packet);
845         tls_queue_packet(packet);
846         LEAVE;
847         return;
848 }
849
/* Decode a 24-bit big-endian (network order) unsigned integer from buf[0..2]. */
static uint32_t get24(const unsigned char *buf) {
	uint32_t value = 0;

	for (int k = 0; k < 3; k++) {
		value = (value << 8) | buf[k];
	}
	return value;
}
853
/* Decode a 16-bit big-endian (network order) unsigned integer from buf[0..1]. */
static uint16_t get16(const unsigned char *buf) {
	unsigned hi = buf[0];
	unsigned lo = buf[1];

	return (uint16_t) ((hi << 8) | lo);
}
857
/*
 * Check whether a complete handshake message body is available.
 *
 * buf points at the 3-byte, big-endian handshake body length (the same
 * field tls_parse_server_hello() decodes with get24()), and len is the
 * number of bytes available starting at buf.
 *
 * Returns 1 when the length field and the whole body it announces are
 * present, 0 otherwise.
 */
int tls_hello_complete(const unsigned char *buf, size_t len) {
	size_t more;

	if (len < 3) {
		return 0;
	}

	/* the length is 24 bits wide; reading only 16 bits picked up the
	 * high and middle bytes and made short messages look complete */
	more = ((size_t) buf[0] << 16) | ((size_t) buf[1] << 8) | buf[2];

	return more <= len - 3;
}
873
874 int tls_parse_server_hello(struct TLSContext *ctx, const unsigned char *buf, size_t len) {
875         size_t i = 0;
876         size_t more = 0;
877
878         if (ctx->connection_status != 0 && ctx->connection_status != 4) {
879                 return TLS_UNEXPECTED_MESSAGE;
880         }
881
882         if (!tls_hello_complete(buf, len)) {
883                 return TLS_NEED_MORE_DATA;
884         }
885
886         /* 3 bytes server hello data size */
887         more = get24(buf+i);
888         i+=3;
889         /* TODO check size reported vs actual */
890
891         /* two bytes server version */
892         uint16_t server_ver = get16(buf+i);
893         i+=2;
894         DEBUG_PRINTLN("server version = %04x\n", server_ver);
895         if (server_ver != ctx->version) {
896                 /* TODO allow (or not) downgrade to v1.2 */
897                 return TLS_UNEXPECTED_MESSAGE;
898         }
899
900         /* 32 bytes server random */
901         memcpy(ctx->remote_random, buf+i, 32);
902         i+=32;
903
904         /* 1 byte of session id length */
905         uint8_t session_len = *(buf+i);
906         i+=1;
907
908         char *session_id;
909         /* possible session id bytes */
910         /* TODO skip? we don't actually use session ids */
911         if (session_len) {
912                 session_id = malloc(session_len);
913                 if (!session_id) {
914                         return 0;
915                 }
916
917                 memcpy(session_id, buf+i, session_len);
918         }
919         i+=session_len;
920
921         /* two bytes cipher suite selected */
922         ctx->cipher = get16(buf+i);
923         i+=2;
924         DEBUG_PRINTLN("server cipher = %04x\n", ctx->cipher);
925         if (!tls_cipher_supported(ctx, ctx->cipher)) {
926                 ctx->cipher = 0;
927                 DEBUG_PRINT("NO CIPHER SUPPORTED\n");
928                 MARK;
929                 return TLS_NO_COMMON_CIPHER;
930         }
931
932         /* one byte compression method */
933         uint8_t compression_method = *(buf+i);
934         i++;
935         if (compression_method != 0) {
936                 return 0;
937         }
938
939         if (i > 0 && ctx->connection_status != 4) {
940                 ctx->connection_status = 1;
941         }
942
943         if (i+2 > len) {
944                 /* no extensions */
945                 return 0;
946         }
947
948         /* two bytes extensions length */
949         uint16_t extensions_length = get16(buf+i);
950         i+=2;
951
952         if (extensions_length + i != len) {
953                 /* mismatch */
954                 return 0;
955         }
956
957         /* extensions */
958         while (i+5 <= len) {
959                 int etype = get16(buf+i);
960                 i+=2;
961
962                 /* TODO check that extension type is in the list of
963                  * extensions we sent, if not, abort with
964                  * unsupported_extension fatal alert
965                  */
966                 uint16_t elen = get16(buf+i);
967                 i+=2;
968                 if (elen == 0) {
969                         continue;
970                 }
971                 if (i+elen > len) {
972                 MARK;
973                         return TLS_BROKEN_PACKET;
974                 }
975
976                 uint16_t sni_len;
977                 uint16_t alpn_len;
978                 const unsigned char *alpn;
979                 unsigned char alpn_size;
980                 int alpn_pos = 0;
981                 uint16_t group_len, iana_n;
982                 int j = i;
983                 int selected = 0;
984
985                 switch (etype) {
986                         case 0x0000:
987                                 sni_len = get16(buf+i);
988                                 if (sni_len) {
989                                         ctx->sni = malloc(sni_len + 1);
990                                         memcpy(ctx->sni, buf+i+2, sni_len);
991                                         ctx->sni[sni_len] = 0;
992                                 }
993                                 break;
994                         case 0x000a: /* supported groups */
995                                 fprintf(stderr, "supported groups\n");
996                                 /* supported groups */
997                                 if (i+2 > len) {
998                                         return TLS_BROKEN_PACKET;
999                                 }
1000
1001                                 group_len = get16(buf+i);
1002                                 for (j = i; i < i + group_len+2; i+=2) {
1003                                         iana_n = get16(buf+j);
1004                                         switch (iana_n) {
1005                                                 case 23:
1006                                                         ctx->curve = &secp256r1;
1007                                                         selected = 1;
1008                                                         break;
1009                                                 case 24:
1010                                                         ctx->curve = &secp384r1;
1011                                                         selected = 1;
1012                                                         break;
1013                                                 case 29:
1014                                                         ctx->curve = &curve25519;
1015                                                         selected = 1;
1016                                                         break;
1017                                                 case 25:
1018                                                         ctx->curve = &secp521r1;
1019                                                         selected = 1;
1020                                                         break;
1021                                         }
1022                                 }
1023                                 /* if ctx->curve */
1024                                 if (selected) {
1025                                         DEBUG_PRINTLN("SELECTED CURVE %s\n",
1026                                                  ctx->curve->name);
1027                                 }
1028                         case 0x0010:
1029                                 if (!ctx->alpn || ctx->alpn_count == 0) {
1030                                         break;
1031                                 }
1032                                 if (i+2 > len) {
1033                                         return TLS_BROKEN_PACKET;
1034                                 }
1035
1036                                 alpn_len = get16(buf+i);
1037
1038                                 if (alpn_len == 0 || alpn_len > elen - 2) {
1039                                         /* TODO broken */
1040                                         break;
1041                                 }
1042                                 /* a server's alpn list "must contain exactly
1043                                  * one "ProtocolName"
1044                                  */
1045                                 alpn_size = buf[i + 2];
1046                                 alpn = buf + i + 3;
1047                                 if (i + alpn_size + 3 < len) {
1048                                         break;
1049                                 }
1050
1051                                 if (!tls_alpn_contains(ctx, (char *)alpn, alpn_size)) {
1052                                         break;
1053                                 }
1054                                 free(ctx->negotiated_alpn);
1055                                 ctx->negotiated_alpn = malloc(alpn_size + 1);
1056                                 if (ctx->negotiated_alpn) {
1057                                         memcpy(ctx->negotiated_alpn,
1058                                                         &alpn[alpn_pos],
1059                                                         alpn_size);
1060                                         ctx->negotiated_alpn[alpn_size] = 0;
1061                                 }
1062                                 break;
1063                         case 0xff01: /* renegotiation info */
1064                                 //fprintf(stderr, "renegotiation info\n");
1065                                 /* ignore, we don't support renegotiation */
1066                 MARK;
1067                                 break;
1068                         case 0x0033: /* key share */
1069                                 /* TODO parse key share */
1070                                 fprintf(stderr, "key share info\n");
1071                                 break;
1072                         case 0x000b:
1073                                 /* signature algorithms */
1074                 MARK;
1075                                 break;
1076                         case 0x002b: /* supported versions */
1077                                 /* should be two bytes of 0x00 0x02
1078                                  * indicating two bytes of server version
1079                                  * then 0x03 0x04 for v1.3
1080                                  */
1081                                 fprintf(stderr, "supported versions\n");
1082                                 break;
1083                         default:
1084                                 fprintf(stderr, "unknown extension %04x\n", etype);
1085                                 break;
1086                 }
1087                 i+=elen;
1088         }
1089
1090 #if 0
1091         if (ctx->connection_status != 4) {
1092                 ctx->connection_status = 1;
1093         }
1094 #endif
1095
1096                 MARK;
1097         return 1;
1098 }