]> pd.if.org Git - zpackage/blobdiff - crypto/handshake.c
commit files needed for zpm-fetchurl
[zpackage] / crypto / handshake.c
diff --git a/crypto/handshake.c b/crypto/handshake.c
new file mode 100644 (file)
index 0000000..30eb788
--- /dev/null
@@ -0,0 +1,1080 @@
+#define _POSIX_C_SOURCE 200809L
+
+#include <arpa/inet.h>
+
+#include "tlse.h"
+#include "buffer.h"
+
+#define TLS12_FLAG 0x01
+#define TLS13_FLAG 0x03
+
+static unsigned char *encrypt_rsa(struct TLSContext *context,
+                                       const unsigned char *buffer,
+                                       unsigned int len,
+                                       unsigned int *size) {
+       *size = 0;
+       if (!len || !context || !context->certificates
+           || !context->certificates_count
+           || !context->certificates[0]
+           || !context->certificates[0]->der_bytes
+           || !context->certificates[0]->der_len) {
+               DEBUG_PRINT("No certificate set\n");
+               return NULL;
+       }
+       rsa_key key;
+       int err;
+       err = rsa_import(context->certificates[0]->der_bytes,
+                       context->certificates[0]->der_len, &key);
+
+       if (err) {
+               DEBUG_PRINT("Error importing RSA certificate (code: %i)\n",
+                           err);
+               return NULL;
+       }
+       unsigned long out_size = TLS_MAX_RSA_KEY;
+       unsigned char *out = malloc(out_size);
+       int hash_idx = find_hash("sha256");
+       int prng_idx = find_prng("sprng");
+       err = rsa_encrypt_key_ex(buffer, len, out, &out_size, (unsigned char *)
+                       "Concept", 7, NULL, prng_idx, hash_idx,
+                       LTC_PKCS_1_V1_5, &key);
+       rsa_free(&key);
+       if (err || !out_size) {
+               free(out);
+               return NULL;
+       }
+       *size = (unsigned int) out_size;
+       return out;
+}
+
+
+void add_supported_versions(struct tls_buffer *buf, int versions) {
+       size_t size;
+       char version12[] = { 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x03 };
+       char version13[] = { 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04 };
+       char both[] = { 0x00, 0x2b, 0x00, 0x05, 0x04, 0x03, 0x04, 0x03, 0x03 };
+       char *use;
+
+       switch (versions) {
+               case 1:
+                       use = version12;
+                       size = sizeof version12;
+                       break;
+               case 2:
+                       use = version13;
+                       size = sizeof version13;
+                       break;
+               case 3:
+                       use = both;
+                       size = sizeof both;
+                       break;
+       }
+
+       tls_buffer_append(buf, use, size);
+}
+
+static void add_sni_extension(struct tls_buffer *buf, char *sni) {
+       size_t len;
+
+       if (!buf || !sni) {
+               return;
+       }
+
+       len = strlen(sni);
+
+       /* server name extension id = 0x00 0x00 */
+       tls_buffer_append16(buf, 0x0000);
+       /* length of server name extension */
+       tls_buffer_append16(buf, len + 5);
+       /* length of first entry */
+       tls_buffer_append16(buf, len + 3);
+       /* DNS hostname */
+       tls_buffer_append_byte(buf, 0x00);
+       /* length of entry */
+       tls_buffer_append16(buf, len);
+       /* actual server name indication */
+       tls_buffer_append(buf, sni, len);
+
+}
+
+
+/*
+00 20 - 0x20 (32) bytes of cipher suite data
+
+
+*/
+
+static void add_cipher_suites(struct tls_buffer *buf, int suites) {
+       /* the five TLS 1.3 cipher suites in B.4 of rfc 8446 */
+       /* chacha20 preferred */
+       unsigned char tls_13_suites[] = {
+               0x13, 0x03, 0x13, 0x01, 0x13, 0x02, 0x13, 0x04, 0x13, 0x04
+       };
+       unsigned char tls_12_suites[] = {
+               0xcc, 0xa9, 0xc0, 0x2b, 0xc0, 0x23, 0xcc, 0xa8, 0xc0, 0x2f,
+               0xc0, 0x27, 0x00, 0x9e, 0x00, 0x6b, 0x00, 0x67, 0xcc, 0xaa
+               };
+
+       size_t len = 0;
+
+       if (suites & 1) {
+               len += sizeof tls_12_suites;
+       }
+       
+       if (suites & 2) {
+               len += sizeof tls_13_suites;
+       }
+
+       tls_buffer_expand(buf, len + 2);
+       tls_buffer_append16(buf, len);
+       /* if we're including 1.3 ciphers, put them first so they're preferred
+        */
+       if (suites & 2) {
+               tls_buffer_append(buf, tls_13_suites, sizeof tls_13_suites);
+       }
+       if (suites & 1) {
+               tls_buffer_append(buf, tls_12_suites, sizeof tls_12_suites);
+       }
+}
+
+void add_signed_certificate_timestamp_extension(struct tls_buffer *buf) {
+       char sct[] = { 0x00, 0x12, 0x00, 0x00 }; /* sct id and zero bytes */
+       tls_buffer_append(buf, sct, sizeof sct);
+}
+
+/*
+ * 00 05 - assigned value for extension "status request"
+ * 00 05 - 0x5 (5) bytes of "status request" extension data follows
+ * 01 - assigned value for "certificate status type: OCSP"
+ * 00 00 - 0x0 (0) bytes of responderID information
+ * 00 00 - 0x0 (0) bytes of request extension information 
+ */
+void add_status_request_extension(struct tls_buffer *buf) {
+       char sr[] = { 0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00 };
+       tls_buffer_append(buf, sr, sizeof sr);
+}
+
+void add_supported_groups_extension(struct tls_buffer *buf) {
+       /* supported groups */
+       /* this specifies the curves */
+       unsigned char groups[] = {
+               /* extension id and size in bytes */
+               0x00, 0x0a, 0x00, 0x08,
+               /* six bytes of groups */
+               0x00, 0x06,
+#if 0
+               /* x25519, */
+               0x00, 0x1d,
+#endif
+               /* secp256r1, secp384r1, secp521r1 */
+               0x00, 0x17, 0x00, 0x18, 0x00, 0x19
+       };
+       tls_buffer_append(buf, groups, sizeof groups);
+}
+
+/*
+ * 00 0b - assigned value for extension "EC points format"
+ * 00 02 - 0x2 (2) bytes of "EC points format" extension data follows
+ * 01 - 0x1 (1) bytes of data are in the supported formats list
+ * 00 - assigned value for uncompressed form 
+ */
+void add_ec_point_formats_extension(struct tls_buffer *buf) {
+       char formats[] = { 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00 };
+       tls_buffer_append(buf, formats, sizeof formats);
+}
+
+void add_signature_algorithms_extension(struct tls_buffer *buf) {
+       char algorithms[] = {
+               0x00, 0x0d, 0x00, 0x0e, 0x00, 0x0c, /* id and lengths */
+               0x04, 0x01, /* RSA/PKCS1/SHA256 */
+               0x04, 0x03, /* ECDSA/SECP256r1/SHA256 */
+               0x05, 0x01, /* RSA/PKCS1/SHA386 */
+               0x05, 0x03, /* ECDSA/SECP384r1/SHA384 */
+               0x06, 0x01, /* RSA/PKCS1/SHA512 */
+               0x06, 0x03 /* ECDSA/SECP521r1/SHA512 */
+       };
+#if 0
+       /* TODO x25519 ? */
+       char tls13algs[] = {
+               0x04, 0x03,
+               0x08, 0x04, /* RSA-PSS-RSAE-SHA256 */
+               0x04, 0x01,
+               0x05, 0x03,
+               0x08, 0x05, /* RSA-PSS-RSAE-SHA384 */
+               0x05, 0x01,
+               0x08, 0x06, /* RSA-PSS-RSAE-SHA512 */
+               0x06, 0x01
+                       /* and 0x02, 0x01 for RSA-PKCS1-SHA1 */
+
+       };
+#endif
+
+       tls_buffer_append(buf, algorithms, sizeof algorithms);
+}
+
+static void add_renegotiation_info_extension(struct tls_buffer *buf) {
+               /* two bytes id, and one byte of zero bytes of info */
+       char info[] = { 0xff, 0x01, 0x00, 0x01, 0x00 };
+
+       tls_buffer_append(buf, info, sizeof info);
+}
+
+/*
+ * 00 33 - assigned value for extension "Key Share"
+ * 00 26 - 0x26 (38) bytes of "Key Share" extension data follows
+ * 00 24 - 0x24 (36) bytes of key share data follows
+ * 00 1d - assigned value for x25519 (key exchange via curve25519)
+ * 00 20 - 0x20 (32) bytes of public key follows
+ * 35 80 ... 62 54 - public key from the step "Client Key Exchange Generation" 
+ */
+static void add_key_share_extension(struct tls_buffer *buf, struct TLSContext
+               *ctx) {
+       char kseid[] = { 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00,
+               0x20 };
+       char bogus_key[32] = { 0 };
+
+       if (!ctx) {
+               return;
+       }
+
+       /* TODO figure out where the client key share is */
+       tls_buffer_append(buf, kseid, sizeof kseid);
+       tls_buffer_append(buf, bogus_key, sizeof bogus_key);
+}
+
+/*
+ * 00 2d - assigned value for extension "PSK Key Exchange Modes"
+ * 00 02 - 0x2 (2) bytes of "PSK Key Exchange Modes" extension data follows
+ * 01 - 0x1 (1) bytes of exchange modes follow
+ * 01 - assigned value for "PSK with (EC)DHE key establishment" 
+ *
+ * we don't actually pre-share keys here, so ignored, but we'll send it
+ * anyway
+ */
+/* TODO probably need to get these from the context */
+static void add_pks_key_exchanges_modes_extension(struct tls_buffer *buf) {
+       char psk[] = { 0x00, 0x2d, 0x00, 0x02, 0x01, 0x01 };
+
+       tls_buffer_append(buf, psk, sizeof psk);
+}
+
+static void set_handshake_header(char *buf, int type, size_t length) {
+       buf[0] = type & 0xff;
+       buf[1] = (length >> 16) & 0xff;
+       buf[2] = (length >> 8) & 0xff;
+       buf[3] = (length >> 0) & 0xff;
+}
+
+int tls_client_hello(struct TLSContext *ctx, struct tls_buffer *hello) {
+       size_t hello_offset = hello->len;
+
+       /* make room for the handshake header */
+       tls_buffer_expand(hello, 4);
+       hello->len += 4;
+
+       /* actual client hello structure follows */
+
+       tls_buffer_append(hello, "\x03\x03", 2); /* legacy_version */
+       /* random not set up yet */
+       //tls_random(ctx->local_random, 32);
+       tls_buffer_append(hello, ctx->local_random, 32); /* client random */
+       tls_buffer_append(hello, "\0", 1); /* legacy_session_id */
+       /* alternatively, append a 32 (0x20) and 32 random bytes as a bogus
+        * session id */
+
+       /*
+        * cipher suites
+        * TODO need a way to only use v1.3
+        */
+       int suites = TLS12_FLAG; /* always use v1.2 suites */
+
+       if (ctx->tlsver == TLS_VERSION13) {
+               suites |= TLS13_FLAG;
+       }
+       add_cipher_suites(hello, suites);
+
+               /* legacy_compression_methods */
+       tls_buffer_append(hello, "\1\0", 2);
+       
+       /* 
+        * extensions
+        * TODO I don't think the extension order matters, so the code below
+        * can be simplified by putting all the extensions together by version
+        */
+       size_t extensions_start = hello->len;
+       /* first two bytes are length of extensions, so make room to fill them
+        * in once we know the size
+        */
+       tls_buffer_append(hello, "\0\0", 2);
+
+       /* TODO need to track which extensions we're sending:
+        * "If a client receives an extension type in ServerHello that it did
+        * not request in the associated ClientHello, it MUST abort the
+        * handshake with an unsupported_extension fatal alert."
+        */
+       add_sni_extension(hello, ctx->sni); /* server name indicator */
+
+#if 0
+       /* TODO not sure why 1.3 doesn't need or want this */
+       /* TODO duckduckgo.com seems to fail with this one */
+       if (ctx->tlsver == TLS_VERSION12) {
+               add_status_request_extension(hello);
+       }
+#endif
+
+       add_supported_groups_extension(hello);
+
+       /* v1.2 only, points are fixed in v1.3 */
+       if (ctx->tlsver == TLS_VERSION12) {
+               add_ec_point_formats_extension(hello);
+       }
+
+       add_signature_algorithms_extension(hello);
+
+       if (ctx->tlsver == TLS_VERSION13) {
+               add_key_share_extension(hello, ctx);
+               add_pks_key_exchanges_modes_extension(hello);
+       }
+
+       /* v1.2 only, 1.3 doesn't support renegotiation
+        * and doesn't seem to need cert ts
+        */
+       if (ctx->tlsver == TLS_VERSION12) {
+               add_renegotiation_info_extension(hello);
+               add_signed_certificate_timestamp_extension(hello);
+       }
+
+       if (ctx->tlsver == TLS_VERSION13) {
+               /* supported versions is mandatory in V1.3 */
+               /* 1 v1.2 only, 2 = v1.3 only, 3 = both */
+               /* TODO need a context flag to allow fallback to v1.2 */
+               /* could probably pass this in v1.2 and the server
+                * would ignore it */
+               add_supported_versions(hello, 3);
+       }
+
+       /* set the extensions length */
+       size_t extensions_length = hello->len - extensions_start - 2;
+       tls_buffer_write16(hello, extensions_length, extensions_start);
+
+       /* fill in the handshake header */
+       size_t hello_length = hello->len - hello_offset - 4;
+       set_handshake_header(hello->buffer+hello_offset, client_hello,
+                       hello_length);
+
+       tls_buffer_compact(hello);
+       return hello->error;
+}
+
+#if 0
+void pbytes(unsigned char *b, size_t len, char *label) {
+       size_t i;
+
+       fprintf(stderr, "%s (%zu bytes)\n", label ? label : "dumping", len);
+
+       for (i=0; i<len; i++) {
+               fprintf(stderr, "%s%02x%s",
+                               i % 20 ? " " : "",
+                               b[i],
+                               (i+1) % 20 ? "" : "\n"
+                      );
+       }
+       if (i%20) {
+               fprintf(stderr, "\n");
+       }
+}
+#endif
+
+struct TLSPacket *tls_build_client_hello(struct TLSContext *context) {
+       if (context->connection_status == 4) {
+               unsigned char header[4] = { 0xFE, 0, 0, 0 };
+               unsigned char hash[TLS_MAX_SHA_SIZE];
+               static unsigned char sha256_helloretryrequest[] =
+                   { 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE,
+                       0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2,
+                       0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07,
+                       0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C
+               };
+               fprintf(stderr, "got hello retry request\n");
+               memcpy(context->local_random, sha256_helloretryrequest, 32);
+               int hash_len = tls_done_hash(context, hash);
+               header[3] = (unsigned char) hash_len;
+               tls_update_hash(context, header, sizeof header);
+               tls_update_hash(context, hash, hash_len);
+       } else if (context->tlsver != TLS_VERSION13) {
+               //fprintf(stderr, "creating local_random\n");
+               if (!tls_random(context->local_random, TLS_SERVER_RANDOM_SIZE)) {
+                       return NULL;
+               }
+       }
+
+       struct tls_buffer shadow;
+       char record_header[] = { 0x16, 0x03, 0x03, 0x00, 0x00 };
+
+       tls_buffer_init(&shadow, 106);
+       tls_buffer_append(&shadow, record_header, sizeof record_header);
+       tls_client_hello(context, &shadow);
+       tls_buffer_writebe(&shadow, 3, 6, shadow.len - 9);
+
+       if (shadow.error) {
+               tls_buffer_free(&shadow);
+               return NULL;
+       }
+
+       struct TLSPacket *packet = malloc(sizeof *packet);
+
+       if (!packet) {
+               return NULL;
+       }
+
+       free(packet->buf);
+       packet->buf = shadow.buffer;
+       packet->len = shadow.len;
+       packet->size = shadow.size;
+       packet->payload_pos = 0;
+       packet->broken = 0;
+       packet->context = context;
+
+       tls_packet_update(packet);
+
+       return packet;
+}
+
+int tls_send_client_hello(struct TLSContext *ctx) {
+       return ctx ? 1 : 0;
+}
+
+struct TLSPacket *tls_build_hello(struct TLSContext *context,
+                                 int tls13_downgrade) {
+       if (context->connection_status == 4) {
+               unsigned char header[4] = { 0xFE, 0, 0, 0 };
+               unsigned char hash[TLS_MAX_SHA_SIZE];
+               static unsigned char sha256_helloretryrequest[] =
+                   { 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE,
+                       0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2,
+                       0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07,
+                       0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C
+               };
+               fprintf(stderr, "got hello retry request\n");
+               memcpy(context->local_random, sha256_helloretryrequest, 32);
+               int hash_len = tls_done_hash(context, hash);
+               header[3] = (unsigned char) hash_len;
+               tls_update_hash(context, header, sizeof header);
+               tls_update_hash(context, hash, hash_len);
+       } else if (!context->is_server || context->tlsver != TLS_VERSION13) {
+               fprintf(stderr, "creating local_random\n");
+               if (!tls_random(context->local_random, TLS_SERVER_RANDOM_SIZE)) {
+                       return NULL;
+               }
+       }
+
+       if (context->is_server && tls13_downgrade) {
+               if (tls13_downgrade == TLS_V12 || tls13_downgrade == DTLS_V12)
+               {
+                       memcpy(context->local_random +
+                              TLS_SERVER_RANDOM_SIZE - 8, "DOWNGRD\x01",
+                              8);
+               } else {
+                       memcpy(context->local_random +
+                              TLS_SERVER_RANDOM_SIZE - 8, "DOWNGRD\x00",
+                              8);
+               }
+       }
+
+       if (!context->is_server) {
+               struct tls_buffer shadow;
+               char record_header[] = { 0x16, 0x03, 0x03, 0x00, 0x00 };
+
+               tls_buffer_init(&shadow, 106);
+               tls_buffer_append(&shadow, record_header, sizeof record_header);
+               tls_client_hello(context, &shadow);
+               tls_buffer_writebe(&shadow, 3, 6, shadow.len - 9);
+
+               if (shadow.error) {
+                       tls_buffer_free(&shadow);
+                       return NULL;
+               }
+
+               struct TLSPacket *packet = malloc(sizeof *packet);
+
+               if (!packet) {
+                       return NULL;
+               }
+
+               free(packet->buf);
+               packet->buf = shadow.buffer;
+               packet->len = shadow.len;
+               packet->size = shadow.size;
+               packet->payload_pos = 0;
+               packet->broken = 0;
+               packet->context = context;
+
+               tls_packet_update(packet);
+
+               fprintf(stderr, "returning packet\n");
+               return packet;
+       }
+
+       /* context must be server from here on out */
+
+       unsigned short packet_version = context->version;
+       unsigned short version = context->version;
+
+       if (context->version == TLS_V13) {
+               version = TLS_V12;
+       } else if (context->version == DTLS_V13) {
+               version = DTLS_V12;
+       }
+
+       struct TLSPacket *packet =
+               tls_create_packet(context, TLS_HANDSHAKE, packet_version, 0);
+
+       /* hello */
+       tls_packet_uint8(packet, server_hello);
+
+       tls_packet_uint24(packet, 0);
+
+       int start_len = packet->len;
+       tls_packet_uint16(packet, version);
+
+       tls_packet_append(packet, context->local_random,
+                       TLS_SERVER_RANDOM_SIZE);
+
+       /* session size, always 0, we don't support sessions */
+       tls_packet_uint8(packet, 0);
+
+       int extension_len = 0;
+       int alpn_len = 0;
+       int alpn_negotiated_len = 0;
+       unsigned char shared_key[TLS_MAX_RSA_KEY];
+       unsigned long shared_key_len = TLS_MAX_RSA_KEY;
+       unsigned short shared_key_short = 0;
+       int selected_group = 0;
+       if (context->tlsver == TLS_VERSION13) {
+               if (context->curve == &curve25519) {
+                       extension_len += 8 + 32;
+                       shared_key_short = (unsigned short) 32;
+                       if (context->finished_key) {
+                               memcpy(shared_key,
+                                               context->
+                                               finished_key, 32);
+                               free(context->finished_key);
+                               context->finished_key = NULL;
+                       }
+                       selected_group = context->curve->iana;
+                       /* make context->curve NULL (x25519 is a different implementation) */
+                       context->curve = NULL;
+               } else if (context->ecc_dhe) {
+                       if (ecc_ansi_x963_export
+                                       (context->ecc_dhe, shared_key,
+                                        &shared_key_len)) {
+                               DEBUG_PRINT
+                                       ("Error exporting ECC DHE key\n");
+                               tls_destroy_packet(packet);
+                               tls_alert(context, 1, internal_error);
+                               return NULL;
+                       }
+                       tls_ecc_dhe_free(context);
+                       extension_len += 8 + shared_key_len;
+                       shared_key_short =
+                               (uint16_t)shared_key_len;
+                       if (context->curve) {
+                               selected_group =
+                                       context->curve->iana;
+                       }
+               } else if (context->dhe) {
+                       selected_group = context->dhe->iana;
+                       tls_dh_export_Y(shared_key,
+                                       &shared_key_len,
+                                       context->dhe);
+                       tls_dhe_free(context);
+                       extension_len += 8 + shared_key_len;
+                       shared_key_short = shared_key_len;
+               }
+
+               extension_len += 6;
+       }
+
+       if (context->negotiated_alpn && context->tlsver != TLS_VERSION13) {
+               alpn_negotiated_len = strlen(context->negotiated_alpn);
+               alpn_len = alpn_negotiated_len + 1;
+               extension_len += alpn_len + 6;
+       }
+
+       /* ciphers */
+       /* fallback ... this should never happen */
+       if (!context->cipher) {
+               context->cipher = TLS_DHE_RSA_WITH_AES_128_CBC_SHA;
+       }
+
+       tls_packet_uint16(packet, context->cipher);
+       /* no compression */
+       tls_packet_uint8(packet, 0);
+
+       if (context->tlsver == TLS_VERSION13) {
+               /* supported versions */
+               tls_packet_uint16(packet, 0x2B);
+
+               tls_packet_uint16(packet, 2);
+               if (context->version == TLS_V13) {
+                       tls_packet_uint16(packet,
+                                       context-> tls13_version ?
+                                       context-> tls13_version :
+                                       TLS_V13);
+               } else {
+                       tls_packet_uint16(packet, context->version);
+               }
+
+               if (context->connection_status == 4) {
+                       /* fallback to the mandatory secp256r1 */
+                       tls_packet_uint16(packet, 0x33);
+                       tls_packet_uint16(packet, 2);
+                       tls_packet_uint16(packet, (uint16_t) secp256r1.iana);
+               }
+
+               if (shared_key_short && selected_group) {
+                       /* key share */
+                       tls_packet_uint16(packet, 0x33);
+                       tls_packet_uint16(packet, shared_key_short + 4);
+                       tls_packet_uint16(packet, selected_group);
+                       tls_packet_uint16(packet, shared_key_short);
+                       tls_packet_append(packet, (unsigned char *) shared_key,
+                                       shared_key_short);
+               }
+       }
+
+       if (!packet->broken && packet->buf) {
+               tls_set_packet_length(packet, packet->len - start_len);
+       }
+
+       tls_packet_update(packet);
+
+       return packet;
+}
+
+struct TLSPacket *tls_buffer_packet(struct tls_buffer *b, struct TLSContext *c) {
+       struct TLSPacket *p = 0;
+
+       if (b && c) {
+               p = tls_create_packet(c, TLS_HANDSHAKE, c->version, 0);
+
+               if (p) {
+                       free(p->buf);
+                       p->buf = b->buffer;
+                       p->size = b->size;
+                       p->len = b->len;
+                       p->payload_pos = 0;
+                       p->broken = 0;
+                       p->context = c;
+               } else {
+                       tls_buffer_free(b);
+               }
+       }
+
+       return p;
+}
+
+static void append_dhe(struct TLSContext *ctx, struct tls_buffer *buf) {
+       unsigned char dh_Ys[0xFFF];
+       unsigned char dh_p[0xFFF];
+       unsigned char dh_g[0xFFF];
+       unsigned long dh_p_len = sizeof dh_p;
+       unsigned long dh_g_len = sizeof dh_g;
+       unsigned long dh_Ys_len = sizeof dh_Ys;
+
+       if (tls_dh_export_pqY(dh_p, &dh_p_len, dh_g, &dh_g_len, dh_Ys,
+                               &dh_Ys_len, ctx->dhe)) {
+               DEBUG_PRINT("ERROR EXPORTING DHE KEY %p\n", ctx->dhe);
+               buf->error = 1;
+               tls_dhe_free(ctx);
+               return;
+       }
+
+       tls_dhe_free(ctx);
+
+       DEBUG_DUMP_HEX_LABEL("Yc", dh_Ys, dh_Ys_len);
+
+       tls_buffer_append24(buf, dh_Ys_len + 2);
+
+       tls_buffer_append16(buf, dh_Ys_len);
+       tls_buffer_append(buf, dh_Ys, dh_Ys_len);
+}
+
+static void append_ecdhe(struct TLSContext *ctx, struct tls_buffer *buf) {
+       unsigned char out[TLS_MAX_RSA_KEY];
+       unsigned long out_len = TLS_MAX_RSA_KEY;
+
+       //fprintf(stderr, "ecc dhe\n");
+
+       if (ecc_ansi_x963_export(ctx->ecc_dhe, out, &out_len)) {
+               DEBUG_PRINT("Error exporting ECC key\n");
+               buf->error = 1;
+       }
+
+       tls_ecc_dhe_free(ctx);
+
+       tls_buffer_append_byte(buf, 0x10);
+       tls_buffer_append24(buf, out_len + 1);
+
+       tls_buffer_append_byte(buf, out_len);
+       tls_buffer_append(buf, out, out_len);
+}
+
+static void set_record_size(struct tls_buffer *b) {
+       uint16_t size;
+
+       size = b->len - 5;
+       tls_buffer_write16(b, size, 3);
+}
+
+struct TLSPacket *tls_client_key_exchange(struct TLSContext *context) {
+       struct tls_buffer cke;
+       struct TLSPacket *p;
+
+       tls_buffer_init(&cke, 42);
+       tls_buffer_append_byte(&cke, 0x16);
+       tls_buffer_append16(&cke, 0x0303);
+       tls_buffer_append16(&cke, 0); /* record size placeholder */
+
+       if (context->ecc_dhe) {
+               append_ecdhe(context, &cke);
+       } else {
+               append_dhe(context, &cke);
+       }
+       set_record_size(&cke);
+
+       p = tls_buffer_packet(&cke, context);
+
+       tls_compute_key(context, 48);
+       context->connection_status = 2;
+       tls_packet_update(p);
+
+       return p;
+}
+
+static int tls_build_random(struct TLSPacket *packet) {
+       int res = 0;
+       unsigned char rand_bytes[48];
+       int bytes = 48;
+
+       if (!tls_random(rand_bytes, bytes)) {
+               return TLS_GENERIC_ERROR;
+       }
+
+       /* max supported version */
+       if (packet->context->is_server) {
+               *(unsigned short *) rand_bytes =
+                   htons(packet->context->version);
+       } else {
+               *(unsigned short *) rand_bytes = htons(TLS_V12);
+       }
+
+       /* DEBUG_DUMP_HEX_LABEL("PREMASTER KEY", rand_bytes, bytes); */
+
+       free(packet->context->premaster_key);
+
+       packet->context->premaster_key = malloc(bytes);
+       if (!packet->context->premaster_key) {
+               return TLS_NO_MEMORY;
+       }
+
+       packet->context->premaster_key_len = bytes;
+       memcpy(packet->context->premaster_key, rand_bytes,
+              packet->context->premaster_key_len);
+
+       unsigned int out_len;
+
+       unsigned char *random = encrypt_rsa(packet->context,
+                       packet->context->premaster_key,
+                       packet->context->premaster_key_len, &out_len);
+
+       tls_compute_key(packet->context, bytes);
+       if (random && out_len > 2) {
+               tls_packet_uint24(packet, out_len + 2);
+               tls_packet_uint16(packet, out_len);
+               tls_packet_append(packet, random, out_len);
+       } else {
+               res = TLS_GENERIC_ERROR;
+       }
+
+       free(random);
+
+       if (res) {
+               return res;
+       }
+
+       return out_len + 2;
+}
+
+void tls_send_client_key_exchange(struct TLSContext *context) {
+       struct TLSPacket *packet;
+
+       int ephemeral = tls_cipher_is_ephemeral(context);
+
+       if (ephemeral && context->premaster_key && context->premaster_key_len) {
+               //fprintf(stderr, "YYYY\n");
+               packet = tls_client_key_exchange(context);
+               tls_queue_packet(packet);
+               return;
+               if (ephemeral == 1) {
+                       /* dhe */
+               } else if (context->ecc_dhe) {
+                       /* ecc dhe */
+               }
+       } else {
+               /* TODO should never happen, should always require
+                * either DHE or ECC DHE */
+               fprintf(stderr, "ZZZZ build random\n");
+               return;
+               packet = tls_create_packet(context, TLS_HANDSHAKE, context->version, 0);
+               tls_packet_uint8(packet, 0x10);
+               tls_build_random(packet);
+       }
+       context->connection_status = 2;
+       tls_packet_update(packet);
+       tls_queue_packet(packet);
+       return;
+}
+
+static uint32_t get24(const unsigned char *buf) {
+       return (*buf << 16) + (*(buf+1) << 8) + *(buf+2);
+}
+
+static uint16_t get16(const unsigned char *buf) {
+       return (*(buf) << 8) + *(buf+1);
+}
+
+int tls_hello_complete(const unsigned char *buf, size_t len) {
+       size_t more;
+
+       if (len < 3) {
+               return 0;
+       }
+
+       more = get16(buf);
+       if (more > len - 3) {
+               fprintf(stderr, "%s:%d\n", __func__, __LINE__);
+               fprintf(stderr, "have %zu, want %zu\n", len, more);
+               return 0;
+       }
+       return 1;
+}
+
+int tls_parse_server_hello(struct TLSContext *ctx, const unsigned char *buf, size_t len) {
+       size_t i = 0;
+       size_t more = 0;
+
+       if (ctx->connection_status != 0 && ctx->connection_status != 4) {
+               return TLS_UNEXPECTED_MESSAGE;
+       }
+
+       if (!tls_hello_complete(buf, len)) {
+               return TLS_NEED_MORE_DATA;
+       }
+
+       /* 3 bytes server hello data size */
+       more = get24(buf+i);
+       i+=3;
+       /* TODO check size reported vs actual */
+
+       /* two bytes server version */
+       uint16_t server_ver = get16(buf+i);
+       i+=2;
+       if (server_ver != ctx->version) {
+               /* TODO allow (or not) downgrade to v1.2 */
+               return TLS_UNEXPECTED_MESSAGE;
+       }
+
+       /* 32 bytes server random */
+       memcpy(ctx->remote_random, buf+i, 32);
+       i+=32;
+
+       /* 1 byte of session id length */
+       uint8_t session_len = *(buf+i);
+       i+=1;
+
+       char *session_id;
+       /* possible session id bytes */
+       /* TODO skip? we don't actually use session ids */
+       if (session_len) {
+               session_id = malloc(session_len);
+               if (!session_id) {
+                       return 0;
+               }
+
+               memcpy(session_id, buf+i, session_len);
+       }
+       i+=session_len;
+
+       /* two bytes cipher suite selected */
+       ctx->cipher = get16(buf+i);
+       i+=2;
+       if (!tls_cipher_supported(ctx, ctx->cipher)) {
+               ctx->cipher = 0;
+               DEBUG_PRINT("NO CIPHER SUPPORTED\n");
+               return TLS_NO_COMMON_CIPHER;
+       }
+
+       /* one byte compression method */
+       uint8_t compression_method = *(buf+i);
+       i++;
+       if (compression_method != 0) {
+               return 0;
+       }
+
+       if (i > 0 && ctx->connection_status != 4) {
+               ctx->connection_status = 1;
+       }
+
+       if (i+2 > len) {
+               /* no extensions */
+               return 0;
+       }
+
+       /* two bytes extensions length */
+       uint16_t extensions_length = get16(buf+i);
+       i+=2;
+
+       if (extensions_length + i != len) {
+               /* mismatch */
+               return 0;
+       }
+
+       /* extensions */
+       while (i+5 <= len) {
+               int etype = get16(buf+i);
+               i+=2;
+
+               /* TODO check that extension type is in the list of
+                * extensions we sent, if not, abort with
+                * unsupported_extension fatal alert
+                */
+               uint16_t elen = get16(buf+i);
+               i+=2;
+               if (elen == 0) {
+                       continue;
+               }
+               if (i+elen > len) {
+                       return TLS_BROKEN_PACKET;
+               }
+
+               uint16_t sni_len;
+               uint16_t alpn_len;
+               const unsigned char *alpn;
+               unsigned char alpn_size;
+               int alpn_pos = 0;
+               uint16_t group_len, iana_n;
+               int j = i;
+               int selected = 0;
+
+               switch (etype) {
+                       case 0x0000:
+                               sni_len = get16(buf+i);
+                               if (sni_len) {
+                                       ctx->sni = malloc(sni_len + 1);
+                                       memcpy(ctx->sni, buf+i+2, sni_len);
+                                       ctx->sni[sni_len] = 0;
+                               }
+                               break;
+                       case 0x000a: /* supported groups */
+                               fprintf(stderr, "supported groups\n");
+                               /* supported groups */
+                               if (i+2 > len) {
+                                       return TLS_BROKEN_PACKET;
+                               }
+
+                               group_len = get16(buf+i);
+                               for (j = i; i < i + group_len+2; i+=2) {
+                                       iana_n = get16(buf+j);
+                                       switch (iana_n) {
+                                               case 23:
+                                                       ctx->curve = &secp256r1;
+                                                       selected = 1;
+                                                       break;
+                                               case 24:
+                                                       ctx->curve = &secp384r1;
+                                                       selected = 1;
+                                                       break;
+                                               case 29:
+                                                       ctx->curve = &curve25519;
+                                                       selected = 1;
+                                                       break;
+                                               case 25:
+                                                       ctx->curve = &secp521r1;
+                                                       selected = 1;
+                                                       break;
+                                       }
+                               }
+                               /* if ctx->curve */
+                               if (selected) {
+                                       fprintf(stderr, "SELECTED CURVE %s\n",
+                                                ctx->curve->name);
+                               }
+                       case 0x0010:
+                               if (!ctx->alpn || ctx->alpn_count == 0) {
+                                       break;
+                               }
+                               if (i+2 > len) {
+                                       return TLS_BROKEN_PACKET;
+                               }
+
+                               alpn_len = get16(buf+i);
+
+                               if (alpn_len == 0 || alpn_len > elen - 2) {
+                                       /* TODO broken */
+                                       break;
+                               }
+                               /* a server's alpn list "must contain exactly
+                                * one "ProtocolName"
+                                */
+                               alpn_size = buf[i + 2];
+                               alpn = buf + i + 3;
+                               if (i + alpn_size + 3 < len) {
+                                       break;
+                               }
+
+                               if (!tls_alpn_contains(ctx, (char *)alpn, alpn_size)) {
+                                       break;
+                               }
+                               free(ctx->negotiated_alpn);
+                               ctx->negotiated_alpn = malloc(alpn_size + 1);
+                               if (ctx->negotiated_alpn) {
+                                       memcpy(ctx->negotiated_alpn,
+                                                       &alpn[alpn_pos],
+                                                       alpn_size);
+                                       ctx->negotiated_alpn[alpn_size] = 0;
+                               }
+                               break;
+                       case 0xff01: /* renegotiation info */
+                               //fprintf(stderr, "renegotiation info\n");
+                               /* ignore, we don't support renegotiation */
+                               break;
+                       case 0x0033: /* key share */
+                               /* TODO parse key share */
+                               fprintf(stderr, "key share info\n");
+                               break;
+                       case 0x000b:
+                               /* signature algorithms */
+                               break;
+                       case 0x002b: /* supported versions */
+                               /* should be two bytes of 0x00 0x02
+                                * indicating two bytes of server version
+                                * then 0x03 0x04 for v1.3
+                                */
+                               fprintf(stderr, "supported versions\n");
+                               break;
+                       default:
+                               fprintf(stderr, "unknown extension %04x\n", etype);
+                               break;
+               }
+               i+=elen;
+       }
+
+#if 0
+       if (ctx->connection_status != 4) {
+               ctx->connection_status = 1;
+       }
+#endif
+
+       return 1;
+}