#define _POSIX_C_SOURCE 200809L #include #include #include #include #include #include #include #include #include #include #include #include #include #include "tomcrypt.h" #define mp_init(a) ltc_mp.init(a) #define mp_init_multi ltc_init_multi #define mp_clear(a) ltc_mp.deinit(a) #define mp_clear_multi ltc_deinit_multi #define mp_count_bits(a) ltc_mp.count_bits(a) #define mp_read_radix(a, b, c) ltc_mp.read_radix(a, b, c) #define mp_unsigned_bin_size(a) ltc_mp.unsigned_size(a) #define mp_to_unsigned_bin(a, b) ltc_mp.unsigned_write(a, b) #define mp_read_unsigned_bin(a, b, c) ltc_mp.unsigned_read(a, b, c) #define mp_exptmod(a, b, c, d) ltc_mp.exptmod(a, b, c, d) #define mp_add(a, b, c) ltc_mp.add(a, b, c) #define mp_mul(a, b, c) ltc_mp.mul(a, b, c) #define mp_cmp(a, b) ltc_mp.compare(a, b) #define mp_cmp_d(a, b) ltc_mp.compare_d(a, b) #define mp_sqr(a, b) ltc_mp.sqr(a, b) #define mp_mod(a, b, c) ltc_mp.mpdiv(a, b, NULL, c) #define mp_sub(a, b, c) ltc_mp.sub(a, b, c) #define mp_set(a, b) ltc_mp.set_int(a, b) #if (CRYPT <= 0x0117) #define LTC_PKCS_1_EMSA LTC_LTC_PKCS_1_EMSA #define LTC_PKCS_1_V1_5 LTC_LTC_PKCS_1_V1_5 #define LTC_PKCS_1_PSS LTC_LTC_PKCS_1_PSS #endif #include "tlse.h" #include "chacha.h" #include "buffer.h" #define TLS_DH_DEFAULT_P "87A8E61DB4B6663CFFBBD19C651959998CEEF608660DD0F25D2CEED4435E3B00E00DF8F1D61957D4FAF7DF4561B2AA3016C3D91134096FAA3BF4296D830E9A7C209E0C6497517ABD5A8A9D306BCF67ED91F9E6725B4758C022E0B1EF4275BF7B6C5BFC11D45F9088B941F54EB1E59BB8BC39A0BF12307F5C4FDB70C581B23F76B63ACAE1CAA6B7902D52526735488A0EF13C6D9A51BFA4AB3AD8347796524D8EF6A167B5A41825D967E144E5140564251CCACB83E6B486F6B3CA3F7971506026C0B857F689962856DED4010ABD0BE621C3A3960A54E710C375F26375D7014103A4B54330C198AF126116D2276E11715F693877FAD7EF09CADB094AE91E1A1597" #define TLS_DH_DEFAULT_G 
"3FB32C9B73134D0B2E77506660EDBD484CA7B18F21EF205407F4793A1A0BA12510DBC15077BE463FFF4FED4AAC0BB555BE3A6C1B0C6B47B1BC3773BF7E8C6F62901228F8C28CBB18A55AE31341000A650196F931C77A57F2DDF463E5E9EC144B777DE62AAAB8A8628AC376D282D6ED3864E67982428EBC831D14348F6F2F9193B5045AF2767164E1DFC967C1FB3F2E55A4BD1BFFE83B9C80D052B985D182EA0ADB2A3B7313D3FE14C8484B1E052588B9B7D2BBD2DF016199ECD06E1557CD0915B3353BBB64E0EC377FD028370DF92B52C7891428CDC67EB6184B523D1DB246C32F63078490F00EF8D647D148D47954515E2327CFEF98C582664B4C0F6CC41659" #define TLS_DHE_KEY_SIZE 2048 #ifndef htonll #define htonll(x) ((1==htonl(1)) ? (x) : ((uint64_t)htonl((x) & 0xFFFFFFFF) << 32) | htonl((x) >> 32)) #endif #ifndef ntohll #define ntohll(x) ((1==ntohl(1)) ? (x) : ((uint64_t)ntohl((x) & 0xFFFFFFFF) << 32) | ntohl((x) >> 32)) #endif #define CHECK_HANDSHAKE_STATE(context, n, limit) { if (context->hs_messages[n] >= limit) { DEBUG_PRINT("* UNEXPECTED MESSAGE (%i)\n", (int)n); payload_res = TLS_UNEXPECTED_MESSAGE; break; } context->hs_messages[n]++; } #ifdef DEBUG int tls_indent = 0; int tls_indent_i = 0; #endif //#define MARK typedef enum { KEA_dhe_dss, KEA_dhe_rsa, KEA_dh_anon, KEA_rsa, KEA_dh_dss, KEA_dh_rsa, KEA_ec_diffie_hellman } KeyExchangeAlgorithm; typedef enum { rsa_sign = 1, dss_sign = 2, rsa_fixed_dh = 3, dss_fixed_dh = 4, rsa_ephemeral_dh_RESERVED = 5, dss_ephemeral_dh_RESERVED = 6, fortezza_dms_RESERVED = 20, ecdsa_sign = 64, rsa_fixed_ecdh = 65, ecdsa_fixed_ecdh = 66 } TLSClientCertificateType; typedef enum { none = 0, md5 = 1, sha1 = 2, sha224 = 3, sha256 = 4, sha384 = 5, sha512 = 6, _md5_sha1 = 255 } TLSHashAlgorithm; typedef enum { anonymous = 0, rsa = 1, dsa = 2, ecdsa = 3 } TLSSignatureAlgorithm; struct OID_chain { void *top; unsigned char *oid; }; typedef ssize_t (*tls_recv_func)(int sockfd, void *buf, size_t len, int flags); typedef ssize_t (*tls_send_func)(int sockfd, const void *buf, size_t len, int flags); static const unsigned int version_id[] = { 1, 1, 1, 0 }; static const unsigned int 
pk_id[] = { 1, 1, 7, 0 }; static const unsigned int serial_id[] = { 1, 1, 2, 1, 0 }; static const unsigned int issurer_id[] = { 1, 1, 4, 0 }; static const unsigned int owner_id[] = { 1, 1, 6, 0 }; static const unsigned int validity_id[] = { 1, 1, 5, 0 }; static const unsigned int algorithm_id[] = { 1, 1, 3, 0 }; static const unsigned int sign_id[] = { 1, 3, 2, 1, 0 }; static const unsigned int sign_id2[] = { 1, 3, 2, 2, 0 }; static const unsigned int priv_id[] = { 1, 4, 0 }; static const unsigned int priv_der_id[] = { 1, 3, 1, 0 }; static const unsigned int ecc_priv_id[] = { 1, 2, 0 }; static const unsigned char country_oid[] = { 0x55, 0x04, 0x06, 0x00 }; static const unsigned char state_oid[] = { 0x55, 0x04, 0x08, 0x00 }; static const unsigned char location_oid[] = { 0x55, 0x04, 0x07, 0x00 }; static const unsigned char entity_oid[] = { 0x55, 0x04, 0x0A, 0x00 }; static const unsigned char subject_oid[] = { 0x55, 0x04, 0x03, 0x00 }; static const unsigned char san_oid[] = { 0x55, 0x1D, 0x11, 0x00 }; static const unsigned char ocsp_oid[] = { 0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x01, 0x00 }; static const unsigned char TLS_RSA_SIGN_RSA_OID[] = { 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01, 0x00 }; static const unsigned char TLS_RSA_SIGN_MD5_OID[] = { 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x04, 0x00 }; static const unsigned char TLS_RSA_SIGN_SHA1_OID[] = { 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05, 0x00 }; static const unsigned char TLS_RSA_SIGN_SHA256_OID[] = { 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B, 0x00 }; static const unsigned char TLS_RSA_SIGN_SHA384_OID[] = { 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C, 0x00 }; static const unsigned char TLS_RSA_SIGN_SHA512_OID[] = { 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D, 0x00 }; #if 0 static const unsigned char TLS_ECDSA_SIGN_SHA1_OID[] = {0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x01, 0x05, 0x00, 0x00}; static const unsigned char TLS_ECDSA_SIGN_SHA224_OID[] = 
{0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x01, 0x05, 0x00, 0x00}; static const unsigned char TLS_ECDSA_SIGN_SHA256_OID[] = {0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x02, 0x05, 0x00, 0x00}; static const unsigned char TLS_ECDSA_SIGN_SHA384_OID[] = {0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x03, 0x05, 0x00, 0x00}; static const unsigned char TLS_ECDSA_SIGN_SHA512_OID[] = {0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x04, 0x05, 0x00, 0x00}; #endif static const unsigned char TLS_EC_PUBLIC_KEY_OID[] = { 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01, 0x00 }; static const unsigned char TLS_EC_prime192v1_OID[] = { 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x01, 0x00 }; static const unsigned char TLS_EC_prime192v2_OID[] = { 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x02, 0x00 }; static const unsigned char TLS_EC_prime192v3_OID[] = { 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x03, 0x00 }; static const unsigned char TLS_EC_prime239v1_OID[] = { 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x04, 0x00 }; static const unsigned char TLS_EC_prime239v2_OID[] = { 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x05, 0x00 }; static const unsigned char TLS_EC_prime239v3_OID[] = { 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x06, 0x00 }; static const unsigned char TLS_EC_prime256v1_OID[] = { 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x07, 0x00 }; #define TLS_EC_secp256r1_OID TLS_EC_prime256v1_OID static const unsigned char TLS_EC_secp224r1_OID[] = { 0x2B, 0x81, 0x04, 0x00, 0x21, 0x00 }; static const unsigned char TLS_EC_secp384r1_OID[] = { 0x2B, 0x81, 0x04, 0x00, 0x22, 0x00 }; static const unsigned char TLS_EC_secp521r1_OID[] = { 0x2B, 0x81, 0x04, 0x00, 0x23, 0x00 }; int tls_random(unsigned char *key, int len); void tls_destroy_packet(struct TLSPacket *packet); struct TLSPacket *tls_build_hello(struct TLSContext *context, int tls13_downgrade); /* not supported */ #if 0 static unsigned char TLS_DSA_SIGN_SHA1_OID[] = {0x2A, 0x86, 0x52, 0xCE, 0x38, 0x04, 0x03, 0x00}; #endif static uint16_t get16(const 
unsigned char *buf) { uint16_t res; res = ((*buf) << 8) + (*(buf+1)); return res; } static uint32_t get24(const unsigned char *buf) { uint32_t res; res = (buf[0] << 16) + (buf[1] << 8) + buf[2]; return res; } #ifdef DEBUG static char *packet_content_type(int type) { switch (type) { case 20: return "change_cipher_spec"; break; case 21: return "alert"; break; case 22: return "handshake"; break; case 23: return "application_data"; break; default: break; } return "unknown content type"; } static char *packet_handshake_type(int type) { switch (type) { case 0: return "hello_request"; break; case 1: return "client_hello"; break; case 2: return "server_hello"; break; case 11: return "certificate"; break; case 12: return "server_key_exchange"; break; case 13: return "certificate_request"; break; case 14: return "server_hello_done"; break; case 15: return "certificate_verify"; break; case 16: return "client_key_exchange"; break; case 20: return "finished"; break; default: break; } return "unknown handshake type"; } #endif size_t tls_queue_packet(struct TLSPacket *packet) { ENTER; if (!packet) { LEAVE; return -1; } struct TLSContext *context = packet->context; if (!context) { LEAVE; return -1; } DEBUG_PRINTLN("sending packet type %d %s\n", (int)packet->buf[0], packet_content_type(packet->buf[0])); if (packet->buf[0] == 22) { DEBUG_PRINTLN("handshake type %d %s\n", (int)packet->buf[5], packet_handshake_type(packet->buf[5]) ); } tls_buffer_append(&context->output_buffer, packet->buf, packet->len); tls_destroy_packet(packet); LEAVE; return context->output_buffer.len; } static void tls_send_change_cipher_spec(struct TLSContext *context) { ENTER; struct TLSPacket *packet = tls_create_packet(context, TLS_CHANGE_CIPHER, context->version, 64); tls_packet_uint8(packet, 1); tls_packet_update(packet); context->local_sequence_number = 0; tls_queue_packet(packet); LEAVE; return; } static void tls_send_encrypted_extensions(struct TLSContext *context) { struct TLSPacket *packet = 
tls_create_packet(context, TLS_HANDSHAKE, context->version, 3); tls_packet_uint8(packet, 0x08); if (context->negotiated_alpn) { int alpn_negotiated_len = strlen(context->negotiated_alpn); int alpn_len = alpn_negotiated_len + 1; tls_packet_uint24(packet, alpn_len + 8); tls_packet_uint16(packet, alpn_len + 6); tls_packet_uint16(packet, 0x10); tls_packet_uint16(packet, alpn_len + 2); tls_packet_uint16(packet, alpn_len); tls_packet_uint8(packet, alpn_negotiated_len); tls_packet_append(packet, (unsigned char *) context-> negotiated_alpn, alpn_negotiated_len); } else { tls_packet_uint24(packet, 2); tls_packet_uint16(packet, 0); } tls_packet_update(packet); tls_queue_packet(packet); return; } static void tls_send_done(struct TLSContext *context) { struct TLSPacket *packet = tls_create_packet(context, TLS_HANDSHAKE, context->version, 0); tls_packet_uint8(packet, 0x0E); tls_packet_uint24(packet, 0); tls_packet_update(packet); tls_queue_packet(packet); return; } static void tls_send_certificate(struct TLSContext *context) { int i; unsigned int all_certificate_size = 0; int certificates_count; struct TLSCertificate **certificates; ENTER; if (context->is_server) { certificates_count = context->certificates_count; certificates = context->certificates; } else { certificates_count = context->client_certificates_count; certificates = context->client_certificates; } int delta = 3; if (context->tlsver == TLS_VERSION13) { delta = 5; } int is_ecdsa = tls_is_ecdsa(context); /* TODO can do one loop and test for ecdsa inside loop */ if (is_ecdsa) { for (i = 0; i < certificates_count; i++) { struct TLSCertificate *cert = certificates[i]; if (cert && cert->der_len && cert->ec_algorithm) { all_certificate_size += cert->der_len + delta; } } } else { for (i = 0; i < certificates_count; i++) { struct TLSCertificate *cert = certificates[i]; if (cert && cert->der_len && !cert->ec_algorithm) { all_certificate_size += cert->der_len + delta; } } } for (i = 0; i < certificates_count; i++) { struct 
TLSCertificate *cert = certificates[i]; if (cert && cert->der_len) { all_certificate_size += cert->der_len + delta; } } if (!all_certificate_size) { DEBUG_PRINT("NO CERTIFICATE SET\n"); } struct TLSPacket *packet = tls_create_packet(context, TLS_HANDSHAKE, context->version, 0); tls_packet_uint8(packet, 0x0B); if (all_certificate_size) { /* context */ if (context->tlsver == TLS_VERSION13) { tls_packet_uint24(packet, all_certificate_size + 4); tls_packet_uint8(packet, 0); } else { tls_packet_uint24(packet, all_certificate_size + 3); } tls_packet_uint24(packet, all_certificate_size); for (i = 0; i < certificates_count; i++) { struct TLSCertificate *cert = certificates[i]; if (cert && cert->der_len) { /* is RSA certificate ? */ if (is_ecdsa && !cert->ec_algorithm) { continue; } /* is ECC certificate ? */ if (!is_ecdsa && cert->ec_algorithm) { continue; } /* 2 times -> one certificate */ tls_packet_uint24(packet, cert->der_len); tls_packet_append(packet, cert->der_bytes, cert->der_len); /* extension */ if (context->tlsver == TLS_VERSION13) { tls_packet_uint16(packet, 0); } } } } else { tls_packet_uint24(packet, all_certificate_size); if (context->tlsver == TLS_VERSION13) { tls_packet_uint8(packet, 0); } } tls_packet_update(packet); tls_queue_packet(packet); LEAVE; return; } int tls_supported_version(uint16_t ver) { switch (ver) { case TLS_V12: case TLS_V13: break; default: DEBUG_PRINT("UNSUPPORTED TLS VERSION %x\n", (int)ver); return 0; } return 1; } void tls_set_packet_length(struct TLSPacket *packet, uint32_t length) { int offset = packet->payload_pos; packet->buf[offset] = (length >> 16) & 0xff; packet->buf[offset+1] = (length >> 8) & 0xff; packet->buf[offset+2] = (length >> 0) & 0xff; } static void tls_init() { static int loaded = 0; if (loaded) { return; } DEBUG_PRINT("Initializing dependencies\n"); loaded = 1; #ifdef LTM_DESC ltc_mp = ltm_desc; #else #ifdef TFM_DESC ltc_mp = tfm_desc; #endif #endif /* TODO remove these */ #if 0 register_hash(&md5_desc); 
register_hash(&sha1_desc); #endif register_hash(&sha256_desc); register_hash(&sha384_desc); register_hash(&sha512_desc); register_prng(&sprng_desc); register_cipher(&aes_desc); tls_ecc_init_curves(); } static unsigned char *decrypt_rsa(struct TLSContext *context, const unsigned char *buffer, unsigned int len, unsigned int *size) { *size = 0; if (!len || !context || !context->private_key || !context->private_key->der_bytes || !context->private_key->der_len) { DEBUG_PRINT("No private key set\n"); return NULL; } rsa_key key; int err; err = rsa_import(context->private_key->der_bytes, context->private_key->der_len, &key); if (err) { DEBUG_PRINT("Error importing RSA key (code: %i)\n", err); return NULL; } unsigned char *out = malloc(len); unsigned long out_size = len; int hash_idx = find_hash("sha256"); int res = 0; err = rsa_decrypt_key_ex(buffer, len, out, &out_size, (unsigned char *) "Concept", 7, hash_idx, LTC_PKCS_1_V1_5, &res, &key); rsa_free(&key); if (err || !out_size) { DEBUG_PRINT("RSA DECRYPT ERROR\n"); free(out); return NULL; } *size = (unsigned int) out_size; return out; } static int verify_rsa(struct TLSContext *context, unsigned int hash_type, const unsigned char *buffer, unsigned int len, const unsigned char *message, unsigned long message_len) { rsa_key key; int err; if (len == 0) { return TLS_GENERIC_ERROR; } struct TLSCertificate **cert; int count; if (context->is_server) { cert = context->client_certificates; count = context->client_certificates_count; } else { cert = context->certificates; count = context->certificates_count; } if (count == 0 || !cert) { return TLS_GENERIC_ERROR; } err = rsa_import(cert[0]->der_bytes, cert[0]->der_len, &key); if (err) { DEBUG_PRINT("Error importing RSA certificate (code: %i)\n", err); return TLS_GENERIC_ERROR; } int hash_idx = -1; unsigned char hash[TLS_MAX_HASH_LEN]; unsigned long hash_len; hash_len = (unsigned long)sizeof hash; switch (hash_type) { case md5: hash_idx = find_hash("md5"); break; case sha1: hash_idx = 
find_hash("sha1"); break; case sha256: hash_idx = find_hash("sha256"); break; case sha384: hash_idx = find_hash("sha384"); break; case sha512: hash_idx = find_hash("sha512"); break; } err = hash_memory(hash_idx, message, message_len, hash, &hash_len); if (hash_idx < 0 || err) { DEBUG_PRINT("Unsupported hash type: %i\n", hash_type); return TLS_GENERIC_ERROR; } int rsa_stat = 0; if (context->tlsver == TLS_VERSION13) { err = rsa_verify_hash_ex(buffer, len, hash, hash_len, LTC_PKCS_1_PSS, hash_idx, 0, &rsa_stat, &key); } else { err = rsa_verify_hash_ex(buffer, len, hash, hash_len, LTC_PKCS_1_V1_5, hash_idx, 0, &rsa_stat, &key); } rsa_free(&key); if (err) { return 0; } return rsa_stat; } static int sign_rsa(struct TLSContext *context, unsigned int hash_type, const unsigned char *message, unsigned int message_len, unsigned char *out, unsigned long *outlen) { rsa_key key; int err; int hash_index = -1; unsigned char hash[TLS_MAX_HASH_LEN]; unsigned long hash_len = 0; //hash_state state; if (!outlen || !context || !out || !context->private_key || !context->private_key->der_bytes || !context->private_key->der_len) { DEBUG_PRINT("No private key set\n"); return TLS_GENERIC_ERROR; } err = rsa_import(context->private_key->der_bytes, context->private_key->der_len, &key); if (err) { DEBUG_PRINT("Error %d importing RSA certificate", err); return TLS_GENERIC_ERROR; } switch (hash_type) { case sha1: hash_index = find_hash("sha1"); hash_len = 20; break; case sha256: hash_index = find_hash("sha256"); hash_len = 32; break; case sha384: hash_index = find_hash("sha384"); hash_len = 48; break; case sha512: hash_index = find_hash("sha512"); hash_len = 64; break; case md5: case _md5_sha1: hash_index = find_hash("md5"); hash_len = 16; break; } if (hash_index < 0 || err) { DEBUG_PRINT("Unsupported hash type: %i\n", hash_type); return TLS_GENERIC_ERROR; } hash_memory(hash_index, message, message_len, hash, &hash_len); if (hash_type == _md5_sha1) { unsigned long hlen = 20; hash_index = 
find_hash("sha1"); hash_memory(hash_index, message, message_len, hash+16, &hlen); hash_len += hlen; } //err = hash_memory(hash_idx, message, message_len, hash, &hash_len); if (context->tlsver == TLS_VERSION13) { err = rsa_sign_hash_ex(hash, hash_len, out, outlen, LTC_PKCS_1_PSS, NULL, find_prng("sprng"), hash_index, hash_type == sha256 ? 32 : 48, &key); } else { err = rsa_sign_hash_ex(hash, hash_len, out, outlen, LTC_PKCS_1_V1_5, NULL, find_prng("sprng"), hash_index, 0, &key); } rsa_free(&key); if (err) { return 0; } return 1; } static int tls_is_point(ecc_key * key) { void *prime, *b, *t1, *t2; int err; if ((err = mp_init_multi(&prime, &b, &t1, &t2, NULL)) != CRYPT_OK) { return err; } /* load prime and b */ if ((err = mp_read_radix(prime, key->dp->prime, 16)) != CRYPT_OK) { goto error; } if ((err = mp_read_radix(b, key->dp->B, 16)) != CRYPT_OK) { goto error; } /* compute y^2 */ if ((err = mp_sqr(key->pubkey.y, t1)) != CRYPT_OK) { goto error; } /* compute x^3 */ if ((err = mp_sqr(key->pubkey.x, t2)) != CRYPT_OK) { goto error; } if ((err = mp_mod(t2, prime, t2)) != CRYPT_OK) { goto error; } if ((err = mp_mul(key->pubkey.x, t2, t2)) != CRYPT_OK) { goto error; } /* compute y^2 - x^3 */ if ((err = mp_sub(t1, t2, t1)) != CRYPT_OK) { goto error; } /* compute y^2 - x^3 + 3x */ if ((err = mp_add(t1, key->pubkey.x, t1)) != CRYPT_OK) { goto error; } if ((err = mp_add(t1, key->pubkey.x, t1)) != CRYPT_OK) { goto error; } if ((err = mp_add(t1, key->pubkey.x, t1)) != CRYPT_OK) { goto error; } if ((err = mp_mod(t1, prime, t1)) != CRYPT_OK) { goto error; } while (mp_cmp_d(t1, 0) == LTC_MP_LT) { if ((err = mp_add(t1, prime, t1)) != CRYPT_OK) { goto error; } } while (mp_cmp(t1, prime) != LTC_MP_LT) { if ((err = mp_sub(t1, prime, t1)) != CRYPT_OK) { goto error; } } /* compare to b */ if (mp_cmp(t1, b) != LTC_MP_EQ) { err = CRYPT_INVALID_PACKET; } else { err = CRYPT_OK; } error: mp_clear_multi(prime, b, t1, t2, NULL); return err; } static int tls_ecc_import_key(const unsigned char 
*private_key, int private_len, const unsigned char *public_key, int public_len, ecc_key *key, const ltc_ecc_set_type *dp) { int err; if (!key || !ltc_mp.name) { return CRYPT_MEM; } key->type = PK_PRIVATE; if (mp_init_multi (&key->pubkey.x, &key->pubkey.y, &key->pubkey.z, &key->k, NULL) != CRYPT_OK) return CRYPT_MEM; if (public_len && !public_key[0]) { public_key++; public_len--; } if ((err = mp_read_unsigned_bin(key->pubkey.x, (unsigned char *) public_key + 1, (public_len - 1) >> 1)) != CRYPT_OK) { mp_clear_multi(key->pubkey.x, key->pubkey.y, key->pubkey.z, key->k, NULL); return err; } if ((err = mp_read_unsigned_bin(key->pubkey.y, (unsigned char *) public_key + 1 + ((public_len - 1) >> 1), (public_len - 1) >> 1)) != CRYPT_OK) { mp_clear_multi(key->pubkey.x, key->pubkey.y, key->pubkey.z, key->k, NULL); return err; } if ((err = mp_read_unsigned_bin(key->k, (unsigned char *) private_key, private_len)) != CRYPT_OK) { mp_clear_multi(key->pubkey.x, key->pubkey.y, key->pubkey.z, key->k, NULL); return err; } key->idx = -1; key->dp = dp; /* set z */ if ((err = mp_set(key->pubkey.z, 1)) != CRYPT_OK) { mp_clear_multi(key->pubkey.x, key->pubkey.y, key->pubkey.z, key->k, NULL); return err; } /* is it a point on the curve? 
*/ if ((err = tls_is_point(key)) != CRYPT_OK) { DEBUG_PRINT("KEY IS NOT ON CURVE\n"); mp_clear_multi(key->pubkey.x, key->pubkey.y, key->pubkey.z, key->k, NULL); return err; } /* we're good */ return CRYPT_OK; } static int sign_ecdsa(struct TLSContext *context, unsigned int hash_type, const unsigned char *message, unsigned int message_len, unsigned char *out, unsigned long *outlen) { if (!outlen || !context || !out || !outlen || !context->ec_private_key || !context->ec_private_key->priv || !context->ec_private_key->priv_len || !context->ec_private_key->pk || !context->ec_private_key->pk_len) { DEBUG_PRINT("No private ECDSA key set\n"); return TLS_GENERIC_ERROR; } const struct ECCCurveParameters *curve = NULL; switch (context->ec_private_key->ec_algorithm) { case 19: curve = &secp192r1; break; case 20: curve = &secp224k1; break; case 21: curve = &secp224r1; break; case 22: curve = &secp256k1; break; case 23: curve = &secp256r1; break; case 24: curve = &secp384r1; break; case 25: curve = &secp521r1; break; default: DEBUG_PRINT("UNSUPPORTED CURVE\n"); } if (!curve) { return TLS_GENERIC_ERROR; } ecc_key key; int err; ltc_ecc_set_type *dp = (ltc_ecc_set_type *)&curve->dp; /* broken ... 
fix this */ err = tls_ecc_import_key(context->ec_private_key->priv, context->ec_private_key->priv_len, context->ec_private_key->pk, context->ec_private_key->pk_len, &key, dp); if (err) { DEBUG_PRINT("Error importing ECC certificate (code: %i)\n", (int) err); return TLS_GENERIC_ERROR; } unsigned char hash[TLS_MAX_HASH_LEN]; unsigned long hash_len = 0; int hash_index; switch (hash_type) { case sha1: hash_index = find_hash("sha1"); hash_len = 20; break; case sha256: hash_index = find_hash("sha256"); hash_len = 32; break; case sha384: hash_index = find_hash("sha384"); hash_len = 48; break; case sha512: hash_index = find_hash("sha512"); hash_len = 64; break; case md5: case _md5_sha1: hash_index = find_hash("md5"); hash_len = 16; break; } hash_memory(hash_index, message, message_len, hash, &hash_len); if (hash_type == _md5_sha1) { unsigned long hlen = 20; hash_index = find_hash("sha1"); hash_memory(hash_index, message, message_len, hash+16, &hlen); hash_len += hlen; } if (err) { DEBUG_PRINT("Unsupported hash type: %i\n", hash_type); return TLS_GENERIC_ERROR; } /* "Let z be the Ln leftmost bits of e, where Ln is the bit length of * the group order n." */ if ((int)hash_len > curve->size) { hash_len = curve->size; } err = ecc_sign_hash(hash, hash_len, out, outlen, NULL, find_prng("sprng"), &key); DEBUG_DUMP_HEX_LABEL("ECC SIGNATURE", out, *outlen); ecc_free(&key); return err ? 
0 : 1; } static void tls_send_certificate_verify(struct TLSContext *context) { struct TLSPacket *packet = tls_create_packet(context, TLS_HANDSHAKE, context->version, 0); /* certificate verify */ tls_packet_uint8(packet, 0x0F); tls_packet_uint24(packet, 0); unsigned char out[TLS_MAX_RSA_KEY]; unsigned long out_len = TLS_MAX_RSA_KEY; unsigned char signing_data[TLS_MAX_HASH_SIZE + 98]; int signing_data_len; /* first 64 bytes to 0x20 (32) */ memset(signing_data, 0x20, 64); /* context string 33 bytes */ if (context->is_server) { memcpy(signing_data + 64, "TLS 1.3, server CertificateVerify", 33); } else { memcpy(signing_data + 64, "TLS 1.3, client CertificateVerify", 33); } /* a single 0 byte separator */ signing_data[97] = 0; signing_data_len = 98; signing_data_len += tls_get_hash(context, signing_data + 98); DEBUG_DUMP_HEX_LABEL("verify data", signing_data, signing_data_len); int hash_algorithm = sha256; if (tls_is_ecdsa(context)) { switch (context->ec_private_key->ec_algorithm) { case 23: /* secp256r1 + sha256 */ tls_packet_uint16(packet, 0x0403); break; case 24: /* secp384r1 + sha384 */ tls_packet_uint16(packet, 0x0503); hash_algorithm = sha384; break; case 25: /* secp521r1 + sha512 */ tls_packet_uint16(packet, 0x0603); hash_algorithm = sha512; break; default: DEBUG_PRINT("UNSUPPORTED CURVE (SIGNING)\n"); packet->broken = 1; /* TODO error */ return; } } else { tls_packet_uint16(packet, 0x0804); } int packet_size = 2; if (tls_is_ecdsa(context)) { if (sign_ecdsa(context, hash_algorithm, signing_data, signing_data_len, out, &out_len) == 1) { DEBUG_PRINT ("ECDSA signing OK! (ECDSA, length %lu)\n", out_len); tls_packet_uint16(packet, out_len); tls_packet_append(packet, out, out_len); packet_size += out_len + 2; } } else if (sign_rsa(context, hash_algorithm, signing_data, signing_data_len, out, &out_len) == 1) { DEBUG_PRINT("RSA signing OK! 
(length %lu)\n", out_len); tls_packet_uint16(packet, out_len); tls_packet_append(packet, out, out_len); packet_size += out_len + 2; } tls_set_packet_length(packet, packet_size); tls_packet_update(packet); tls_queue_packet(packet); return; } static int tls_ecc_import_pk(const unsigned char *public_key, int public_len, ecc_key * key, const ltc_ecc_set_type * dp) { int err; if (!key || !ltc_mp.name) { return CRYPT_MEM; } key->type = PK_PUBLIC; if (mp_init_multi(&key->pubkey.x, &key->pubkey.y, &key->pubkey.z, &key->k, NULL) != CRYPT_OK) { return CRYPT_MEM; } if (public_len && !public_key[0]) { public_key++; public_len--; } if ((err = mp_read_unsigned_bin(key->pubkey.x, (unsigned char *) public_key + 1, (public_len - 1) >> 1)) != CRYPT_OK) { mp_clear_multi(key->pubkey.x, key->pubkey.y, key->pubkey.z, key->k, NULL); return err; } if ((err = mp_read_unsigned_bin(key->pubkey.y, (unsigned char *) public_key + 1 + ((public_len - 1) >> 1), (public_len - 1) >> 1)) != CRYPT_OK) { mp_clear_multi(key->pubkey.x, key->pubkey.y, key->pubkey.z, key->k, NULL); return err; } key->idx = -1; key->dp = dp; /* set z */ if ((err = mp_set(key->pubkey.z, 1)) != CRYPT_OK) { mp_clear_multi(key->pubkey.x, key->pubkey.y, key->pubkey.z, key->k, NULL); return err; } /* is it a point on the curve? 
*/ if ((err = tls_is_point(key)) != CRYPT_OK) { DEBUG_PRINT("KEY IS NOT ON CURVE\n"); mp_clear_multi(key->pubkey.x, key->pubkey.y, key->pubkey.z, key->k, NULL); return err; } /* we're good */ return CRYPT_OK; } static int tls_verify_ecdsa(struct TLSContext *context, unsigned int hash_type, const unsigned char *buffer, unsigned int len, const unsigned char *message, unsigned int message_len, const struct ECCCurveParameters *curve_hint) { ecc_key key; int err; if (!curve_hint) { curve_hint = context->curve; } if (len == 0) { return TLS_GENERIC_ERROR; } struct TLSCertificate **cert; int count; if (context->is_server) { cert = context->client_certificates; count = context->client_certificates_count; } else { cert = context->certificates; count = context->certificates_count; } if (count == 0 || !cert || !cert[0] || !cert[0]->pk || !cert[0]->pk_len) { return TLS_GENERIC_ERROR; } err = tls_ecc_import_pk(cert[0]->pk, cert[0]->pk_len, &key, (ltc_ecc_set_type *)&curve_hint->dp); if (err) { DEBUG_PRINT("Error importing ECC certificate (code: %i)", err); return TLS_GENERIC_ERROR; } int hash_idx = -1; unsigned char hash[TLS_MAX_HASH_LEN]; unsigned long hash_len = 0; switch (hash_type) { case md5: hash_idx = find_hash("md5"); hash_len = 16; break; case sha1: hash_idx = find_hash("sha1"); hash_len = 20; break; case sha256: hash_idx = find_hash("sha256"); hash_len = 32; break; case sha384: hash_idx = find_hash("sha384"); hash_len = 48; break; case sha512: hash_idx = find_hash("sha512"); hash_len = 64; break; } err = hash_memory(hash_idx, message, message_len, hash, &hash_len); if (hash_idx < 0 || err) { DEBUG_PRINT("Unsupported hash type: %i\n", hash_type); return TLS_GENERIC_ERROR; } int ecc_stat = 0; err = ecc_verify_hash(buffer, len, hash, hash_len, &ecc_stat, &key); ecc_free(&key); if (err) { return 0; } return ecc_stat; } static void prf_helper(int hash_idx, unsigned long dlen, unsigned char *output, unsigned int outlen, const unsigned char *secret, const unsigned int 
secret_len, const unsigned char *label, unsigned int label_len, unsigned char *seed, unsigned int seed_len, unsigned char *seed_b, unsigned int seed_b_len)
{
    unsigned char digest_out0[TLS_MAX_HASH_LEN];
    unsigned char digest_out1[TLS_MAX_HASH_LEN];
    unsigned int i;
    hmac_state hmac;

    /* A(1) = HMAC(secret, label || seed || seed_b) */
    hmac_init(&hmac, hash_idx, secret, secret_len);
    hmac_process(&hmac, label, label_len);
    hmac_process(&hmac, seed, seed_len);
    if (seed_b && seed_b_len) {
        hmac_process(&hmac, seed_b, seed_b_len);
    }
    hmac_done(&hmac, digest_out0, &dlen);

    int idx = 0;

    while (outlen) {
        /* block = HMAC(secret, A(i) || label || seed || seed_b) */
        hmac_init(&hmac, hash_idx, secret, secret_len);
        hmac_process(&hmac, digest_out0, dlen);
        hmac_process(&hmac, label, label_len);
        hmac_process(&hmac, seed, seed_len);
        if (seed_b && seed_b_len) {
            hmac_process(&hmac, seed_b, seed_b_len);
        }
        hmac_done(&hmac, digest_out1, &dlen);

        unsigned int copylen = outlen;
        if (copylen > dlen) {
            copylen = dlen;
        }

        /* XOR (not assign): tls_prf() runs this helper twice over the same
         * zeroed buffer to combine two hashes */
        for (i = 0; i < copylen; i++) {
            output[idx++] ^= digest_out1[i];
            outlen--;
        }

        if (!outlen) {
            break;
        }

        /* A(i+1) = HMAC(secret, A(i)) */
        hmac_init(&hmac, hash_idx, secret, secret_len);
        hmac_process(&hmac, digest_out0, dlen);
        hmac_done(&hmac, digest_out0, &dlen);
    }
}

/*
 * Fill output with outlen bytes derived from secret, label, seed and the
 * optional seed_b.
 *
 * When context->version is TLS_V12, a single HMAC expansion is used with
 * SHA-384 if tls_mac_length() reports TLS_SHA384_MAC_SIZE and SHA-256
 * otherwise. For any other version, output is zeroed and two prf_helper()
 * passes (MD5 over the first half of the secret, SHA-1 over the second half)
 * are XORed together.
 *
 * Returns without writing anything when secret is NULL or empty.
 */
static void tls_prf(struct TLSContext *context, unsigned char *output, unsigned int outlen, const unsigned char *secret, const unsigned int secret_len, const unsigned char *label, unsigned int label_len, unsigned char *seed, unsigned int seed_len, unsigned char *seed_b, unsigned int seed_b_len)
{
    if (!secret || !secret_len) {
        DEBUG_PRINT("NULL SECRET\n");
        return;
    }

    /* TODO I don't think this is right, wouldn't use md5 for tls v1.3 */
    if (context->version != TLS_V12) {
        int md5_hash_idx = find_hash("md5");
        int sha1_hash_idx = find_hash("sha1");
        /* halves overlap by one byte when secret_len is odd */
        int half_secret = (secret_len + 1) / 2;

        /* prf_helper XORs into output, so it must start out zeroed */
        memset(output, 0, outlen);
        prf_helper(md5_hash_idx, 16, output, outlen, secret, half_secret, label, label_len, seed, seed_len, seed_b, seed_b_len);
        prf_helper(sha1_hash_idx, 20, output, outlen, secret + (secret_len - half_secret), secret_len - half_secret, label, label_len, seed, seed_len, seed_b, seed_b_len);
    } else {
        /* sha256_hmac */
        unsigned char digest_out0[TLS_MAX_HASH_LEN];
        unsigned char digest_out1[TLS_MAX_HASH_LEN];
        unsigned long dlen = 32;
        int hash_idx;
        unsigned int mac_length = tls_mac_length(context);

        if (mac_length == TLS_SHA384_MAC_SIZE) {
            hash_idx = find_hash("sha384");
            dlen = mac_length;
        } else {
            hash_idx = find_hash("sha256");
        }

        unsigned int i;
        hmac_state hmac;

        /* A(1) = HMAC(secret, label || seed || seed_b) */
        hmac_init(&hmac, hash_idx, secret, secret_len);
        hmac_process(&hmac, label, label_len);
        hmac_process(&hmac, seed, seed_len);
        if (seed_b && seed_b_len) {
            hmac_process(&hmac, seed_b, seed_b_len);
        }
        hmac_done(&hmac, digest_out0, &dlen);

        int idx = 0;

        while (outlen) {
            /* block = HMAC(secret, A(i) || label || seed || seed_b) */
            hmac_init(&hmac, hash_idx, secret, secret_len);
            hmac_process(&hmac, digest_out0, dlen);
            hmac_process(&hmac, label, label_len);
            hmac_process(&hmac, seed, seed_len);
            if (seed_b && seed_b_len) {
                hmac_process(&hmac, seed_b, seed_b_len);
            }
            hmac_done(&hmac, digest_out1, &dlen);

            unsigned int copylen = outlen;
            if (copylen > dlen) {
                copylen = (unsigned int) dlen;
            }

            /* plain copy here, unlike the XOR in prf_helper() */
            for (i = 0; i < copylen; i++) {
                output[idx++] = digest_out1[i];
                outlen--;
            }

            if (!outlen) {
                break;
            }

            /* A(i+1) = HMAC(secret, A(i)) */
            hmac_init(&hmac, hash_idx, secret, secret_len);
            hmac_process(&hmac, digest_out0, dlen);
            hmac_done(&hmac, digest_out0, &dlen);
        }
    }
}

/*
 * Queue a Finished handshake message (type 20).
 *
 * Server, TLS 1.3: verify data is HMAC(finished_key, handshake hash); if the
 * finished key or the handshake hash is missing, the packet is destroyed and
 * nothing is queued.
 * Server, other versions: verify data is tls_prf() over the master key with
 * label "server finished" and the final handshake hash, after which the
 * handshake hash is destroyed.
 * Client: verify data is tls_prf() over the master key with label
 * "client finished" and the current handshake hash.
 */
static void tls_send_finished(struct TLSContext *context)
{
    ENTER;

    struct TLSPacket *packet = tls_create_packet(context, TLS_HANDSHAKE, context->version, TLS_MIN_FINISHED_OPAQUE_LEN + 64);

    tls_packet_uint8(packet, 20);

    /* advertised body length: MAC length in 1.3, fixed size otherwise */
    if (context->tlsver == TLS_VERSION13) {
        tls_packet_uint24(packet, tls_mac_length(context));
    } else {
        tls_packet_uint24(packet, TLS_MIN_FINISHED_OPAQUE_LEN);
    }

    /* verify */
    unsigned char hash[TLS_MAX_HASH_SIZE];
    unsigned long out_size = TLS_MIN_FINISHED_OPAQUE_LEN;
    unsigned char out[TLS_MAX_HASH_SIZE];
    unsigned int hash_len;
    int context_is_v13 = 0;

    if (packet->context->tlsver == TLS_VERSION13) {
        context_is_v13 = 1;
    }

    /* server verifies client's message */
    if (context->is_server) {
        if (context_is_v13) {
            hash_len = tls_get_hash(context, hash);
            if (!context->finished_key || !hash_len) {
                DEBUG_PRINT ("NO FINISHED KEY COMPUTED OR NO HANDSHAKE HASH\n");
                /* TODO probably need to terminate */
                tls_destroy_packet(packet);
                LEAVE;
                return;
            }

            DEBUG_DUMP_HEX_LABEL("HS HASH", hash, hash_len);
            DEBUG_DUMP_HEX_LABEL("HS FINISH", context->finished_key, hash_len);

            /* the finished key is used with the same length as the hash */
            out_size = hash_len;
            hmac_state hmac;
            hmac_init(&hmac, tls_get_hash_idx(context), context->finished_key, hash_len);
            hmac_process(&hmac, hash, hash_len);
            hmac_done(&hmac, out, &out_size);
        } else {
            hash_len = tls_done_hash(context, hash);
            tls_prf(context, out, TLS_MIN_FINISHED_OPAQUE_LEN, context->master_key, context->master_key_len, (unsigned char *) "server finished", 15, hash, hash_len, NULL, 0);
            tls_destroy_hash(context);
        }
    } else {
        /* client */
        /* NOTE(review): this branch does not special-case TLS 1.3 even though
         * the length field above does - confirm client-side 1.3 is unused */
        hash_len = tls_get_hash(context, hash);
        tls_prf(context, out, TLS_MIN_FINISHED_OPAQUE_LEN, context->master_key, context->master_key_len, (unsigned char *) "client finished", 15, hash, hash_len, NULL, 0);
    }

    tls_packet_append(packet, out, out_size);
    tls_packet_update(packet);
    DEBUG_DUMP_HEX_LABEL("VERIFY DATA", out, out_size);
    tls_queue_packet(packet);

    LEAVE;
    return;
}

/*
 * Map context->cipher to a length: 16 for the AES-128 suites, 32 for the
 * AES-256 and ChaCha20 suites listed, 0 for any other value.
 */
static int tls_key_length(struct TLSContext *context)
{
    switch (context->cipher) {
    case TLS_RSA_WITH_AES_128_CBC_SHA:
    case TLS_RSA_WITH_AES_128_CBC_SHA256:
    case TLS_RSA_WITH_AES_128_GCM_SHA256:
    case TLS_DHE_RSA_WITH_AES_128_CBC_SHA:
    case TLS_DHE_RSA_WITH_AES_128_CBC_SHA256:
    case TLS_DHE_RSA_WITH_AES_128_GCM_SHA256:
    case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:
    case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
    case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
    case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA:
    case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256:
    case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
    case TLS_AES_128_GCM_SHA256:
        return 16;
    case TLS_RSA_WITH_AES_256_CBC_SHA:
    case TLS_RSA_WITH_AES_256_CBC_SHA256:
    case TLS_RSA_WITH_AES_256_GCM_SHA384:
    case TLS_DHE_RSA_WITH_AES_256_CBC_SHA:
    case TLS_DHE_RSA_WITH_AES_256_CBC_SHA256:
    case
TLS_DHE_RSA_WITH_AES_256_GCM_SHA384: case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: case TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384: case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256: case TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: case TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256: case TLS_AES_256_GCM_SHA384: case TLS_CHACHA20_POLY1305_SHA256: return 32; } return 0; } /* 0 is none, 1 is GCM?, 2 is chacha */ int tls_is_aead(struct TLSContext *context) { switch (context->cipher) { case TLS_RSA_WITH_AES_128_GCM_SHA256: case TLS_RSA_WITH_AES_256_GCM_SHA384: case TLS_DHE_RSA_WITH_AES_128_GCM_SHA256: case TLS_DHE_RSA_WITH_AES_256_GCM_SHA384: case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: case TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: case TLS_AES_128_GCM_SHA256: case TLS_AES_256_GCM_SHA384: return 1; case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256: case TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: case TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256: case TLS_CHACHA20_POLY1305_SHA256: return 2; } return 0; } int tls_mac_length(struct TLSContext *context) { switch (context->cipher) { case TLS_RSA_WITH_AES_128_CBC_SHA: case TLS_RSA_WITH_AES_256_CBC_SHA: case TLS_DHE_RSA_WITH_AES_128_CBC_SHA: case TLS_DHE_RSA_WITH_AES_256_CBC_SHA: case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: return TLS_SHA1_MAC_SIZE; case TLS_RSA_WITH_AES_128_CBC_SHA256: case TLS_RSA_WITH_AES_256_CBC_SHA256: case TLS_RSA_WITH_AES_128_GCM_SHA256: case TLS_DHE_RSA_WITH_AES_128_CBC_SHA256: case TLS_DHE_RSA_WITH_AES_256_CBC_SHA256: case TLS_DHE_RSA_WITH_AES_128_GCM_SHA256: case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: case 
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256: case TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: case TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256: case TLS_AES_128_GCM_SHA256: case TLS_CHACHA20_POLY1305_SHA256: case TLS_AES_128_CCM_SHA256: case TLS_AES_128_CCM_8_SHA256: return TLS_SHA256_MAC_SIZE; case TLS_RSA_WITH_AES_256_GCM_SHA384: case TLS_DHE_RSA_WITH_AES_256_GCM_SHA384: case TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384: case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: case TLS_AES_256_GCM_SHA384: return TLS_SHA384_MAC_SIZE; } return 0; } int _private_tls13_key(struct TLSContext *context, int handshake) { int key_length = tls_key_length(context); int mac_length = tls_mac_length(context); if (!context->premaster_key || !context->premaster_key_len) { return 0; } if (!key_length || !mac_length) { DEBUG_PRINT ("KEY EXPANSION FAILED, KEY LENGTH: %i, MAC LENGTH: %i\n", key_length, mac_length); return 0; } unsigned char *clientkey = NULL; unsigned char *serverkey = NULL; unsigned char *clientiv = NULL; unsigned char *serveriv = NULL; int is_aead = tls_is_aead(context); unsigned char local_keybuffer[TLS_V13_MAX_KEY_SIZE]; unsigned char local_ivbuffer[TLS_V13_MAX_IV_SIZE]; unsigned char remote_keybuffer[TLS_V13_MAX_KEY_SIZE]; unsigned char remote_ivbuffer[TLS_V13_MAX_IV_SIZE]; unsigned char prk[TLS_MAX_HASH_SIZE]; unsigned char hash[TLS_MAX_HASH_SIZE]; static unsigned char earlysecret[TLS_MAX_HASH_SIZE]; const char *server_key = "s ap traffic"; const char *client_key = "c ap traffic"; if (handshake) { server_key = "s hs traffic"; client_key = "c hs traffic"; } unsigned char salt[TLS_MAX_HASH_SIZE]; hash_state md; /* TODO what is the point of this ? 
*/ if (mac_length == TLS_SHA384_MAC_SIZE) { sha384_init(&md); sha384_done(&md, hash); } else { sha256_init(&md); sha256_done(&md, hash); } /* extract secret "early" */ if (context->master_key && context->master_key_len && !handshake) { DEBUG_DUMP_HEX_LABEL("USING PREVIOUS SECRET", context->master_key, context->master_key_len); tls_hkdf_expand_label(mac_length, salt, mac_length, context->master_key, context->master_key_len, "derived", 7, hash, mac_length); DEBUG_DUMP_HEX_LABEL("salt", salt, mac_length); tls_hkdf_extract(mac_length, prk, mac_length, salt, mac_length, earlysecret, mac_length); } else { tls_hkdf_extract(mac_length, prk, mac_length, NULL, 0, earlysecret, mac_length); /* derive secret for handshake "tls13 derived": */ DEBUG_DUMP_HEX_LABEL("null hash", hash, mac_length); tls_hkdf_expand_label(mac_length, salt, mac_length, prk, mac_length, "derived", 7, hash, mac_length); /* extract secret "handshake": */ DEBUG_DUMP_HEX_LABEL("salt", salt, mac_length); tls_hkdf_extract(mac_length, prk, mac_length, salt, mac_length, context->premaster_key, context->premaster_key_len); } if (!is_aead) { DEBUG_PRINT("KEY EXPANSION FAILED, NON AEAD CIPHER\n"); return 0; } unsigned char secret[TLS_MAX_MAC_SIZE]; unsigned char hs_secret[TLS_MAX_HASH_SIZE]; int hash_size; if (handshake) { hash_size = tls_get_hash(context, hash); } else { hash_size = tls_done_hash(context, hash); } DEBUG_DUMP_HEX_LABEL("messages hash", hash, hash_size); if (context->is_server) { tls_hkdf_expand_label(mac_length, hs_secret, mac_length, prk, mac_length, server_key, 12, context-> server_finished_hash ? context-> server_finished_hash : hash, hash_size); DEBUG_DUMP_HEX_LABEL(server_key, hs_secret, mac_length); serverkey = local_keybuffer; serveriv = local_ivbuffer; clientkey = remote_keybuffer; clientiv = remote_ivbuffer; } else { tls_hkdf_expand_label(mac_length, hs_secret, mac_length, prk, mac_length, client_key, 12, context-> server_finished_hash ? 
context-> server_finished_hash : hash, hash_size); serverkey = remote_keybuffer; serveriv = remote_ivbuffer; clientkey = local_keybuffer; clientiv = local_ivbuffer; } int iv_length = TLS_13_AES_GCM_IV_LENGTH; if (is_aead == 2) { iv_length = TLS_CHACHA20_IV_LENGTH; } tls_hkdf_expand_label(mac_length, local_keybuffer, key_length, hs_secret, mac_length, "key", 3, NULL, 0); tls_hkdf_expand_label(mac_length, local_ivbuffer, iv_length, hs_secret, mac_length, "iv", 2, NULL, 0); tls_hkdf_expand_label(mac_length, secret, mac_length, prk, mac_length, context->is_server ? client_key : server_key, 12, context->server_finished_hash ? context->server_finished_hash : hash, hash_size); tls_hkdf_expand_label(mac_length, remote_keybuffer, key_length, secret, mac_length, "key", 3, NULL, 0); tls_hkdf_expand_label(mac_length, remote_ivbuffer, iv_length, secret, mac_length, "iv", 2, NULL, 0); DEBUG_DUMP_HEX_LABEL("CLIENT KEY", clientkey, key_length); DEBUG_DUMP_HEX_LABEL("CLIENT IV", clientiv, iv_length); DEBUG_DUMP_HEX_LABEL("SERVER KEY", serverkey, key_length); DEBUG_DUMP_HEX_LABEL("SERVER IV", serveriv, iv_length); free(context->finished_key); free(context->remote_finished_key); if (handshake) { context->finished_key = malloc(mac_length); context->remote_finished_key = malloc(mac_length); if (context->finished_key) { tls_hkdf_expand_label(mac_length, context->finished_key, mac_length, hs_secret, mac_length, "finished", 8, NULL, 0); DEBUG_DUMP_HEX_LABEL("FINISHED", context->finished_key, mac_length); } if (context->remote_finished_key) { tls_hkdf_expand_label(mac_length, context->remote_finished_key, mac_length, secret, mac_length, "finished", 8, NULL, 0); DEBUG_DUMP_HEX_LABEL("FINISHED", context->finished_key, mac_length); } } else { context->finished_key = NULL; context->remote_finished_key = NULL; free(context->server_finished_hash); context->server_finished_hash = NULL; } if (context->is_server) { if (is_aead == 2) { memcpy(context->crypto.ctx_remote_mac.remote_nonce, clientiv, 
iv_length); memcpy(context->crypto.ctx_local_mac.local_nonce, serveriv, iv_length); } else if (is_aead) { memcpy(context->crypto.ctx_remote_mac.remote_iv, clientiv, iv_length); memcpy(context->crypto.ctx_local_mac.local_iv, serveriv, iv_length); } if (tls_crypto_create(context, key_length, serverkey, serveriv, clientkey, clientiv)) { return 0; } } else { if (is_aead == 2) { memcpy(context->crypto.ctx_local_mac.local_nonce, clientiv, iv_length); memcpy(context->crypto.ctx_remote_mac.remote_nonce, serveriv, iv_length); } else if (is_aead) { memcpy(context->crypto.ctx_local_mac.local_iv, clientiv, iv_length); memcpy(context->crypto.ctx_remote_mac.remote_iv, serveriv, iv_length); } if (tls_crypto_create(context, key_length, clientkey, clientiv, serverkey, serveriv)) { return 0; } } context->crypto.created = 1 + is_aead; free(context->master_key); context->master_key = malloc(mac_length); if (context->master_key) { memcpy(context->master_key, prk, mac_length); context->master_key_len = mac_length; } context->local_sequence_number = 0; context->remote_sequence_number = 0; /* * extract client_mac_key(mac_key_length) * extract server_mac_key(mac_key_length) * extract client_key(enc_key_length) * extract server_key(enc_key_length) * extract client_iv(fixed_iv_lengh) * extract server_iv(fixed_iv_length) */ return 1; } static int tls_expand_key(struct TLSContext *context) { unsigned char key[TLS_MAX_KEY_EXPANSION_SIZE]; if (context->tlsver == TLS_VERSION13) { return 0; } if (!context->master_key || !context->master_key_len) { return 0; } int key_length = tls_key_length(context); int mac_length = tls_mac_length(context); if (!key_length || !mac_length) { DEBUG_PRINT ("KEY EXPANSION FAILED, KEY LENGTH: %i, MAC LENGTH: %i\n", key_length, mac_length); return 0; } unsigned char *clientkey = NULL; unsigned char *serverkey = NULL; unsigned char *clientiv = NULL; unsigned char *serveriv = NULL; int iv_length = TLS_AES_IV_LENGTH; int is_aead = tls_is_aead(context); if 
(context->is_server) { tls_prf(context, key, sizeof(key), context->master_key, context->master_key_len, (unsigned char *) "key expansion", 13, context->local_random, TLS_SERVER_RANDOM_SIZE, context->remote_random, TLS_CLIENT_RANDOM_SIZE); } else { tls_prf(context, key, sizeof(key), context->master_key, context->master_key_len, (unsigned char *) "key expansion", 13, context->remote_random, TLS_SERVER_RANDOM_SIZE, context->local_random, TLS_CLIENT_RANDOM_SIZE); } DEBUG_DUMP_HEX_LABEL("LOCAL RANDOM ", context->local_random, TLS_SERVER_RANDOM_SIZE); DEBUG_DUMP_HEX_LABEL("REMOTE RANDOM", context->remote_random, TLS_CLIENT_RANDOM_SIZE); DEBUG_PRINT("\n=========== EXPANSION ===========\n"); DEBUG_DUMP_HEX(key, TLS_MAX_KEY_EXPANSION_SIZE); DEBUG_PRINT("\n"); int pos = 0; if (is_aead == 2) { iv_length = TLS_CHACHA20_IV_LENGTH; } else { if (is_aead) { iv_length = TLS_AES_GCM_IV_LENGTH; } else { if (context->is_server) { memcpy(context->crypto.ctx_remote_mac.remote_mac, &key[pos], mac_length); pos += mac_length; memcpy(context->crypto.ctx_local_mac.local_mac, &key[pos], mac_length); pos += mac_length; } else { memcpy(context->crypto.ctx_local_mac.local_mac, &key[pos], mac_length); pos += mac_length; memcpy(context->crypto.ctx_remote_mac.remote_mac, &key[pos], mac_length); pos += mac_length; } } } clientkey = &key[pos]; pos += key_length; serverkey = &key[pos]; pos += key_length; clientiv = &key[pos]; pos += iv_length; serveriv = &key[pos]; pos += iv_length; DEBUG_PRINT("EXPANSION %i/%i\n", (int) pos, (int) TLS_MAX_KEY_EXPANSION_SIZE); DEBUG_DUMP_HEX_LABEL("CLIENT KEY", clientkey, key_length); DEBUG_DUMP_HEX_LABEL("CLIENT IV", clientiv, iv_length); #if 0 DEBUG_DUMP_HEX_LABEL("CLIENT MAC KEY", context->is_server ? context->crypto. 
ctx_remote_mac.remote_mac : context-> crypto.ctx_local_mac.local_mac, mac_length); #endif DEBUG_DUMP_HEX_LABEL("SERVER KEY", serverkey, key_length); DEBUG_DUMP_HEX_LABEL("SERVER IV", serveriv, iv_length); #if 0 DEBUG_DUMP_HEX_LABEL("SERVER MAC KEY", context->is_server ? context->crypto. ctx_local_mac.local_mac : context->crypto. ctx_remote_mac.remote_mac, mac_length); #endif if (context->is_server) { if (is_aead == 2) { memcpy(context->crypto.ctx_remote_mac.remote_nonce, clientiv, iv_length); memcpy(context->crypto.ctx_local_mac.local_nonce, serveriv, iv_length); } else { if (is_aead) { memcpy(context->crypto.ctx_remote_mac. remote_aead_iv, clientiv, iv_length); memcpy(context->crypto.ctx_local_mac.local_aead_iv, serveriv, iv_length); } } if (tls_crypto_create(context, key_length, serverkey, serveriv, clientkey, clientiv)) { return 0; } } else { if (is_aead == 2) { memcpy(context->crypto.ctx_local_mac.local_nonce, clientiv, iv_length); memcpy(context->crypto.ctx_remote_mac.remote_nonce, serveriv, iv_length); } else { if (is_aead) { memcpy(context->crypto.ctx_local_mac.local_aead_iv, clientiv, iv_length); memcpy(context->crypto.ctx_remote_mac. 
remote_aead_iv, serveriv, iv_length); } } if (tls_crypto_create(context, key_length, clientkey, clientiv, serverkey, serveriv)) { return 0; } } /* * extract client_mac_key(mac_key_length) * extract server_mac_key(mac_key_length) * extract client_key(enc_key_length) * extract server_key(enc_key_length) * extract client_iv(fixed_iv_lengh) * extract server_iv(fixed_iv_length) */ return 1; } int tls_compute_key(struct TLSContext *context, unsigned int key_len) { if (context->tlsver == TLS_VERSION13) { return 0; } if (!context->premaster_key || !context->premaster_key_len || key_len < 48) { DEBUG_PRINT("CANNOT COMPUTE MASTER SECRET\n"); return 0; } unsigned char master_secret_label[] = "master secret"; #ifdef TLS_CHECK_PREMASTER_KEY if (!tls_cipher_is_ephemeral(context)) { uint16_t version = get16(context->premaster_key); /* this check is not true for DHE/ECDHE ciphers */ if (context->version > version) { DEBUG_PRINT("Mismatch protocol version 0x(%x)\n", version); return 0; } } #endif free(context->master_key); context->master_key_len = 0; context->master_key = NULL; context->master_key = malloc(key_len); if (!context->master_key) { return 0; } context->master_key_len = key_len; if (context->is_server) { tls_prf(context, context->master_key, context->master_key_len, context->premaster_key, context->premaster_key_len, master_secret_label, 13, context->remote_random, TLS_CLIENT_RANDOM_SIZE, context->local_random, TLS_SERVER_RANDOM_SIZE); } else { tls_prf(context, context->master_key, context->master_key_len, context->premaster_key, context->premaster_key_len, master_secret_label, 13, context->local_random, TLS_CLIENT_RANDOM_SIZE, context->remote_random, TLS_SERVER_RANDOM_SIZE); } free(context->premaster_key); context->premaster_key = NULL; context->premaster_key_len = 0; DEBUG_PRINT("\n=========== Master key ===========\n"); DEBUG_DUMP_HEX(context->master_key, context->master_key_len); DEBUG_PRINT("\n"); tls_expand_key(context); return 1; } int _is_oid(const unsigned char 
*oid, const unsigned char *compare_to, int compare_to_len) { int i = 0; while ((oid[i]) && (i < compare_to_len)) { if (oid[i] != compare_to[i]) return 0; i++; } return 1; } int _is_oid2(const unsigned char *oid, const unsigned char *compare_to, int compare_to_len, int oid_len) { int i = 0; if (oid_len < compare_to_len) { compare_to_len = oid_len; } while (i < compare_to_len) { if (oid[i] != compare_to[i]) { return 0; } i++; } return 1; } struct TLSCertificate *tls_create_certificate() { struct TLSCertificate zero = { 0 }; struct TLSCertificate *cert = malloc(sizeof *cert); if (cert) { *cert = zero; } cert->not_before[0] = 0; cert->not_after[0] = 0; return cert; } int tls_certificate_valid_subject_name(const unsigned char *cert_subject, const char *subject) { /* no subjects ... */ if ((!cert_subject || !cert_subject[0]) && (!subject || !subject[0])) { return 0; } if (!subject || !subject[0]) { return bad_certificate; } if (!cert_subject || !cert_subject[0]) { return bad_certificate; } /* exact match */ if (!strcmp((const char *) cert_subject, subject)) { return 0; } const char *wildcard = strchr((const char *) cert_subject, '*'); if (wildcard) { /* 6.4.3 (1) The client SHOULD NOT attempt to match a presented * identifier in which the wildcard character comprises a label * other than the left-most label */ if (!wildcard[1]) { /* subject is [*] * or * subject is [something*] .. 
invalid */ return bad_certificate; } wildcard++; const char *match = strstr(subject, wildcard); if ((!match) && (wildcard[0] == '.')) { /* check *.domain.com against domain.com */ wildcard++; if (!strcasecmp(subject, wildcard)) return 0; } if (match) { unsigned long offset = (unsigned long) match - (unsigned long) subject; if (offset) { /* check for foo.*.domain.com against *.domain.com (invalid) */ if (memchr(subject, '.', offset)) return bad_certificate; } /* check if exact match */ if (!strcasecmp(match, wildcard)) { return 0; } } } return bad_certificate; } int tls_certificate_valid_subject(struct TLSCertificate *cert, const char *subject) { int i; if (!cert) { return certificate_unknown; } int err = tls_certificate_valid_subject_name(cert->subject, subject); if (err && cert->san) { for (i = 0; i < cert->san_length; i++) { err = tls_certificate_valid_subject_name(cert->san[i], subject); if (!err) { return err; } } } return err; } int tls_certificate_is_valid(struct TLSCertificate *cert) { if (!cert) { return certificate_unknown; } char ts[16]; /* YYYYMMDDHHMMSSZ */ time_t t = time(NULL); struct tm *utc = gmtime(&t); if (utc) { strftime(ts, sizeof ts, "%Y%m%d%H%M%SZ", utc); if (strcmp(cert->not_before, ts) > 0) { DEBUG_PRINT ("Certificate is not yet valid, now: %s (validity: %s - %s)\n", ts, cert->not_before, cert->not_after); return certificate_expired; } if (strcmp(cert->not_after, ts) < 0) { DEBUG_PRINT ("Expired certificate, now: %s (validity: %s - %s)\n", ts, cert->not_before, cert->not_after); return certificate_expired; } DEBUG_PRINT("Valid certificate, now: %s (validity: %s - %s)\n", ts, cert->not_before, cert->not_after); } return 0; } void tls_certificate_set_copy(unsigned char **member, const unsigned char *val, int len) { if (!member) { return; } free(*member); if (len) { *member = malloc(len + 1); if (*member) { memcpy(*member, val, len); (*member)[len] = 0; } } else { *member = NULL; } } void tls_certificate_set_copy_date(unsigned char *member, 
const unsigned char *val, int len) { if (len > 4) { if (val[0] >= '5') { member[0] = '1'; member[1] = '9'; } else { member[0] = '2'; member[1] = '0'; } memcpy(member + 2, val, len); member[len] = 0; } else { member[0] = 0; } } void tls_certificate_set_key(struct TLSCertificate *cert, const unsigned char *val, int len) { if (!val[0] && len % 2) { val++; len--; } tls_certificate_set_copy(&cert->pk, val, len); if (cert->pk) { cert->pk_len = len; } } void tls_certificate_set_priv(struct TLSCertificate *cert, const unsigned char *val, int len) { tls_certificate_set_copy(&cert->priv, val, len); if (cert->priv) { cert->priv_len = len; } } void tls_certificate_set_sign_key(struct TLSCertificate *cert, const unsigned char *val, int len) { if (!val[0] && len % 2) { val++; len--; } tls_certificate_set_copy(&cert->sign_key, val, len); if (cert->sign_key) { cert->sign_len = len; } } void tls_certificate_set_exponent(struct TLSCertificate *cert, const unsigned char *val, int len) { tls_certificate_set_copy(&cert->exponent, val, len); if (cert->exponent) { cert->exponent_len = len; } } void tls_certificate_set_serial(struct TLSCertificate *cert, const unsigned char *val, int len) { tls_certificate_set_copy(&cert->serial_number, val, len); if (cert->serial_number) { cert->serial_len = len; } } void tls_certificate_set_algorithm(unsigned int *algorithm, const unsigned char *val, int len) { if (len == 7 && _is_oid(val, TLS_EC_PUBLIC_KEY_OID, 7)) { *algorithm = TLS_EC_PUBLIC_KEY; return; } if (len == 8) { if (_is_oid(val, TLS_EC_prime192v1_OID, len)) { *algorithm = TLS_EC_prime192v1; return; } if (_is_oid(val, TLS_EC_prime192v2_OID, len)) { *algorithm = TLS_EC_prime192v2; return; } if (_is_oid(val, TLS_EC_prime192v3_OID, len)) { *algorithm = TLS_EC_prime192v3; return; } if (_is_oid(val, TLS_EC_prime239v1_OID, len)) { *algorithm = TLS_EC_prime239v1; return; } if (_is_oid(val, TLS_EC_prime239v2_OID, len)) { *algorithm = TLS_EC_prime239v2; return; } if (_is_oid(val, 
TLS_EC_prime239v3_OID, len)) { *algorithm = TLS_EC_prime239v3; return; } if (_is_oid(val, TLS_EC_prime256v1_OID, len)) { *algorithm = TLS_EC_prime256v1; return; } } if (len == 5) { if (_is_oid2 (val, TLS_EC_secp224r1_OID, len, sizeof(TLS_EC_secp224r1_OID) - 1)) { *algorithm = TLS_EC_secp224r1; return; } if (_is_oid2 (val, TLS_EC_secp384r1_OID, len, sizeof(TLS_EC_secp384r1_OID) - 1)) { *algorithm = TLS_EC_secp384r1; return; } if (_is_oid2 (val, TLS_EC_secp521r1_OID, len, sizeof(TLS_EC_secp521r1_OID) - 1)) { *algorithm = TLS_EC_secp521r1; return; } } if (len != 9) { return; } if (_is_oid(val, TLS_RSA_SIGN_SHA256_OID, 9)) { *algorithm = TLS_RSA_SIGN_SHA256; return; } if (_is_oid(val, TLS_RSA_SIGN_RSA_OID, 9)) { *algorithm = TLS_RSA_SIGN_RSA; return; } if (_is_oid(val, TLS_RSA_SIGN_SHA1_OID, 9)) { *algorithm = TLS_RSA_SIGN_SHA1; return; } if (_is_oid(val, TLS_RSA_SIGN_SHA512_OID, 9)) { *algorithm = TLS_RSA_SIGN_SHA512; return; } if (_is_oid(val, TLS_RSA_SIGN_SHA384_OID, 9)) { *algorithm = TLS_RSA_SIGN_SHA384; return; } if (_is_oid(val, TLS_RSA_SIGN_MD5_OID, 9)) { *algorithm = TLS_RSA_SIGN_MD5; return; } } void tls_destroy_certificate(struct TLSCertificate *cert) { if (cert) { int i; free(cert->exponent); free(cert->pk); free(cert->issuer_country); free(cert->issuer_state); free(cert->issuer_location); free(cert->issuer_entity); free(cert->issuer_subject); free(cert->country); free(cert->state); free(cert->location); free(cert->subject); for (i = 0; i < cert->san_length; i++) { free(cert->san[i]); } free(cert->san); free(cert->ocsp); free(cert->serial_number); free(cert->entity); cert->not_before[0] = 0; cert->not_after[0] = 0; free(cert->sign_key); free(cert->priv); free(cert->der_bytes); free(cert->bytes); free(cert->fingerprint); free(cert); } } struct TLSPacket *tls_create_packet(struct TLSContext *context, unsigned char type, unsigned short version, int payload_size_hint) { struct TLSPacket *packet = malloc(sizeof *packet); if (!packet) { return NULL; } 
packet->broken = 0; if (payload_size_hint > 0) { packet->size = payload_size_hint + 10; } else { packet->size = TLS_BLOB_INCREMENT; } packet->buf = malloc(packet->size); memset(packet->buf, 0, packet->size); packet->context = context; if (!packet->buf) { free(packet); return NULL; } if (context) { packet->payload_pos = 6; packet->len = packet->payload_pos - 1; } else { packet->len = 5; } packet->buf[0] = type; /* big endian protocol version */ packet->buf[1] = version >> 8; packet->buf[2] = version & 0xff; if (version == TLS_V13) { packet->buf[2] = 0x04; } return packet; } void tls_destroy_packet(struct TLSPacket *packet) { if (packet) { if (packet->buf) { free(packet->buf); } free(packet); } } int tls_crypto_create(struct TLSContext *context, int key_length, unsigned char *localkey, unsigned char *localiv, unsigned char *remotekey, unsigned char *remoteiv) { if (context->crypto.created) { if (context->crypto.created == 1) { cbc_done(&context->crypto.ctx_remote.aes_remote); cbc_done(&context->crypto.ctx_local.aes_local); } else { if (context->crypto.created == 2) { unsigned char dummy_buffer[32]; unsigned long tag_len = 0; gcm_done(&context->crypto.ctx_remote. aes_gcm_remote, dummy_buffer, &tag_len); gcm_done(&context->crypto.ctx_local. aes_gcm_local, dummy_buffer, &tag_len); } } context->crypto.created = 0; } int is_aead = tls_is_aead(context); int cipherID = find_cipher("aes"); DEBUG_PRINT("Using cipher ID: %x\n", (int) context->cipher); if (is_aead == 2) { unsigned int counter = 1; chacha_keysetup(&context->crypto.ctx_local.chacha_local, localkey, key_length * 8); chacha_ivsetup_96bitnonce(&context->crypto.ctx_local. chacha_local, localiv, (unsigned char *) &counter); chacha_keysetup(&context->crypto.ctx_remote.chacha_remote, remotekey, key_length * 8); chacha_ivsetup_96bitnonce(&context->crypto.ctx_remote. 
chacha_remote, remoteiv, (unsigned char *) &counter); context->crypto.created = 3; } else { if (is_aead) { int res1 = gcm_init(&context->crypto.ctx_local.aes_gcm_local, cipherID, localkey, key_length); int res2 = gcm_init(&context->crypto.ctx_remote.aes_gcm_remote, cipherID, remotekey, key_length); if (res1 || res2) { return TLS_GENERIC_ERROR; } context->crypto.created = 2; } else { int res1 = cbc_start(cipherID, localiv, localkey, key_length, 0, &context->crypto.ctx_local.aes_local); int res2 = cbc_start(cipherID, remoteiv, remotekey, key_length, 0, &context->crypto.ctx_remote.aes_remote); if (res1 || res2) { return TLS_GENERIC_ERROR; } context->crypto.created = 1; } } return 0; } static void tls_crypto_done(struct TLSContext *context) { unsigned char dummy_buffer[32]; unsigned long tag_len = 0; switch (context->crypto.created) { case 1: cbc_done(&context->crypto.ctx_remote.aes_remote); cbc_done(&context->crypto.ctx_local.aes_local); break; case 2: gcm_done(&context->crypto.ctx_remote.aes_gcm_remote, dummy_buffer, &tag_len); gcm_done(&context->crypto.ctx_local.aes_gcm_local, dummy_buffer, &tag_len); break; } context->crypto.created = 0; } int tls_packet_append(struct TLSPacket *packet, const unsigned char *buf, unsigned int len) { void *new; if (!packet || packet->broken) { return -1; } if (!len) { return 0; } unsigned int new_len = packet->len + len; if (new_len > packet->size) { packet->size = (new_len / TLS_BLOB_INCREMENT + 1) * TLS_BLOB_INCREMENT; new = TLS_REALLOC(packet->buf, packet->size); if (new) { packet->buf = new; } else { free(packet->buf); packet->size = 0; packet->len = 0; packet->broken = 1; return -1; } } memcpy(packet->buf + packet->len, buf, len); packet->len = new_len; return new_len; } int tls_packet_uint8(struct TLSPacket *packet, unsigned char i) { return tls_packet_append(packet, &i, 1); } int tls_packet_uint16(struct TLSPacket *packet, unsigned short i) { unsigned short ni = htons(i); return tls_packet_append(packet, (unsigned char *) &ni, 
2); } int tls_packet_uint32(struct TLSPacket *packet, unsigned int i) { unsigned int ni = htonl(i); return tls_packet_append(packet, (unsigned char *) &ni, 4); } int tls_packet_uint24(struct TLSPacket *packet, unsigned int i) { unsigned char buf[3]; buf[0] = (i >> 16) & 0xff; buf[1] = (i >> 8) & 0xff; buf[2] = (i >> 0) & 0xff; return tls_packet_append(packet, buf, 3); } int tls_random(unsigned char *key, int len) { #ifdef __APPLE__ for (int i = 0; i < len; i++) { unsigned int v = arc4random() % 0x100; key[i] = (char) v; } return 1; #else /* TODO use open and read */ FILE *fp = fopen("/dev/urandom", "r"); if (fp) { int key_len = fread(key, 1, len, fp); fclose(fp); if (key_len == len) return 1; } #endif return 0; } int tls_established(struct TLSContext *context) { return context && context->connection_status == TLS_CONNECTED; } void tls_read_clear(struct TLSContext *context) { if (context) { tls_buffer_free(&context->application_buffer); } } struct TLSContext *tls_create_context(int is_server, unsigned short version) { struct TLSContext zero = {0}; uint16_t ver = 0; struct TLSContext *context = 0; if (version == TLS_V13 && !is_server) { /* TLS 1.3 clients not supported */ return NULL; } tls_init(); switch (version) { case TLS_V13: case TLS_V12: context = malloc(sizeof *context); break; default: return NULL; } ver = version - 0x0201; if (context) { *context = zero; context->is_server = is_server; context->version = version; context->tlsver = ver; context->hs_index = -1; /* set up output buffer */ tls_buffer_init(&context->output_buffer, 0); tls_buffer_init(&context->input_buffer, 0); tls_buffer_init(&context->cached_handshake, 0); tls_buffer_init(&context->application_buffer, 0); } return context; } struct TLSContext *tls_accept(struct TLSContext *context) { if (!context || !context->is_server) { return NULL; } struct TLSContext *child = malloc(sizeof *child); if (child) { memset(child, 0, sizeof(struct TLSContext)); child->is_server = 1; child->is_child = 1; 
child->version = context->version; child->certificates = context->certificates; child->certificates_count = context->certificates_count; child->private_key = context->private_key; child->ec_private_key = context->ec_private_key; child->root_certificates = context->root_certificates; child->root_count = context->root_count; child->default_dhe_p = context->default_dhe_p; child->default_dhe_g = context->default_dhe_g; child->curve = context->curve; child->alpn = context->alpn; child->alpn_count = context->alpn_count; } return child; } int tls_add_alpn(struct TLSContext *context, const char *alpn) { void *new; if (!context || !alpn || !alpn[0] || (context->is_server && context->is_child)) { return TLS_GENERIC_ERROR; } int len = strlen(alpn); if (tls_alpn_contains(context, alpn, len)) { return 0; } new = TLS_REALLOC(context->alpn, (context->alpn_count + 1) * sizeof(char *)); if (new) { context->alpn = new; } else { free(context->alpn); context->alpn = 0; context->alpn_count = 0; return TLS_NO_MEMORY; } char *alpn_ref = malloc(len + 1); context->alpn[context->alpn_count] = alpn_ref; if (alpn_ref) { memcpy(alpn_ref, alpn, len); alpn_ref[len] = 0; context->alpn_count++; } else { return TLS_NO_MEMORY; } return 0; } int tls_alpn_contains(struct TLSContext *context, const char *alpn, unsigned char alpn_size) { int i; if (!context || !alpn || !alpn_size || !context->alpn) { return 0; } for (i = 0; i < context->alpn_count; i++) { const char *alpn_local = context->alpn[i]; if (alpn_local) { int len = strlen(alpn_local); if (alpn_size == len) { if (!memcmp(alpn_local, alpn, alpn_size)) { return 1; } } } } return 0; } void tls_destroy_context(struct TLSContext *context) { int i; if (!context) { return; } if (!context->is_child) { if (context->certificates) { for (i = 0; i < context->certificates_count; i++) { tls_destroy_certificate(context-> certificates[i]); } } if (context->root_certificates) { for (i = 0; i < context->root_count; i++) { tls_destroy_certificate(context-> 
root_certificates[i]); } free(context->root_certificates); context->root_certificates = NULL; } if (context->private_key) { tls_destroy_certificate(context->private_key); } if (context->ec_private_key) { tls_destroy_certificate(context->ec_private_key); } free(context->certificates); free(context->default_dhe_p); free(context->default_dhe_g); if (context->alpn) { for (i = 0; i < context->alpn_count; i++) { free(context->alpn[i]); } free(context->alpn); } } if (context->client_certificates) { for (i = 0; i < context->client_certificates_count; i++) { tls_destroy_certificate(context-> client_certificates[i]); } free(context->client_certificates); } context->client_certificates = NULL; free(context->master_key); free(context->premaster_key); if (context->crypto.created) { tls_crypto_done(context); } tls_done_hash(context, NULL); tls_destroy_hash(context); tls_buffer_free(&context->output_buffer); tls_buffer_free(&context->input_buffer); tls_buffer_free(&context->application_buffer); tls_buffer_free(&context->cached_handshake); //free(context->cached_handshake); free(context->sni); tls_dhe_free(context); tls_ecc_dhe_free(context); free(context->negotiated_alpn); free(context->finished_key); free(context->remote_finished_key); free(context->server_finished_hash); free(context); } int tls_cipher_is_ephemeral(struct TLSContext *context) { if (!context) { return 0; } switch (context->cipher) { case TLS_DHE_RSA_WITH_AES_128_CBC_SHA: case TLS_DHE_RSA_WITH_AES_256_CBC_SHA: case TLS_DHE_RSA_WITH_AES_128_CBC_SHA256: case TLS_DHE_RSA_WITH_AES_256_CBC_SHA256: case TLS_DHE_RSA_WITH_AES_128_GCM_SHA256: case TLS_DHE_RSA_WITH_AES_256_GCM_SHA384: case TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256: return 1; case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: case TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256: case 
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384: case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: case TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: return 2; case TLS_AES_128_GCM_SHA256: case TLS_CHACHA20_POLY1305_SHA256: case TLS_AES_128_CCM_SHA256: case TLS_AES_128_CCM_8_SHA256: case TLS_AES_256_GCM_SHA384: if (context->dhe) { return 1; } return 2; } return 0; } int tls_is_ecdsa(struct TLSContext *context) { if (!context) { return 0; } switch (context->cipher) { case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384: case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: case TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: return 1; } if (context->ec_private_key) { return 1; } return 0; } static void tls_send_server_key_exchange(struct TLSContext *context, int method) { if (!context->is_server) { DEBUG_PRINT ("CANNOT BUILD SERVER KEY EXCHANGE MESSAGE FOR CLIENTS\n"); return; } struct TLSPacket *packet = tls_create_packet(context, TLS_HANDSHAKE, context->version, 0); tls_packet_uint8(packet, 0x0C); unsigned char dummy[3]; tls_packet_append(packet, dummy, 3); int start_len = packet->len; if (method == KEA_dhe_rsa) { tls_dhe_create(context); const char *default_dhe_p = context->default_dhe_p; const char *default_dhe_g = context->default_dhe_g; int key_size; if (!default_dhe_p || !default_dhe_g) { default_dhe_p = TLS_DH_DEFAULT_P; default_dhe_g = TLS_DH_DEFAULT_G; key_size = TLS_DHE_KEY_SIZE / 8; } else { key_size = strlen(default_dhe_p); } if (tls_dh_make_key(key_size, context->dhe, default_dhe_p, default_dhe_g, 0, 0)) { DEBUG_PRINT("ERROR CREATING DHE KEY\n"); free(packet); free(context->dhe); context->dhe = NULL; /* TODO set error */ return; } 
unsigned char dh_Ys[0xFFF]; unsigned char dh_p[0xFFF]; unsigned char dh_g[0xFFF]; unsigned long dh_p_len = sizeof(dh_p); unsigned long dh_g_len = sizeof(dh_g); unsigned long dh_Ys_len = sizeof(dh_Ys); if (tls_dh_export_pqY (dh_p, &dh_p_len, dh_g, &dh_g_len, dh_Ys, &dh_Ys_len, context->dhe)) { DEBUG_PRINT("ERROR EXPORTING DHE KEY\n"); free(packet); /* TODO set error */ return; } DEBUG_PRINT("LEN: %lu (%lu, %lu)\n", dh_Ys_len, dh_p_len, dh_g_len); DEBUG_DUMP_HEX_LABEL("DHE PK", dh_Ys, dh_Ys_len); DEBUG_DUMP_HEX_LABEL("DHE P", dh_p, dh_p_len); DEBUG_DUMP_HEX_LABEL("DHE G", dh_g, dh_g_len); tls_packet_uint16(packet, dh_p_len); tls_packet_append(packet, dh_p, dh_p_len); tls_packet_uint16(packet, dh_g_len); tls_packet_append(packet, dh_g, dh_g_len); tls_packet_uint16(packet, dh_Ys_len); tls_packet_append(packet, dh_Ys, dh_Ys_len); /* dh_p */ /* dh_g */ /* dh_Ys */ } else if (method == KEA_ec_diffie_hellman) { /* 3 = named curve */ if (!context->curve) { context->curve = tls_ecc_default_curve; } tls_packet_uint8(packet, 3); tls_packet_uint16(packet, context->curve->iana); tls_ecc_dhe_create(context); ltc_ecc_set_type *dp = (ltc_ecc_set_type *) & context->curve->dp; if (ecc_make_key_ex(NULL, find_prng("sprng"), context->ecc_dhe, dp)) { free(context->ecc_dhe); context->ecc_dhe = NULL; DEBUG_PRINT("Error generating ECC key\n"); free(packet); /* TODO set error */ return; } unsigned char out[TLS_MAX_RSA_KEY]; unsigned long out_len = TLS_MAX_RSA_KEY; if (ecc_ansi_x963_export(context->ecc_dhe, out, &out_len)) { DEBUG_PRINT("Error exporting ECC key\n"); free(packet); /* TODO abort */ return; } tls_packet_uint8(packet, out_len); tls_packet_append(packet, out, out_len); } else { /* TODO abort */ free(packet); DEBUG_PRINT("Unsupported ephemeral method: %i\n", method); return; } /* signature */ unsigned int params_len = packet->len - start_len; unsigned int message_len = params_len + TLS_CLIENT_RANDOM_SIZE + TLS_SERVER_RANDOM_SIZE; unsigned char *message = malloc(message_len); if 
(message) { unsigned char out[TLS_MAX_RSA_KEY]; unsigned long out_len = TLS_MAX_RSA_KEY; int hash_algorithm; hash_algorithm = sha256; if (tls_is_ecdsa(context)) { hash_algorithm = sha512; tls_packet_uint8(packet, hash_algorithm); tls_packet_uint8(packet, ecdsa); } else { tls_packet_uint8(packet, hash_algorithm); tls_packet_uint8(packet, rsa_sign); } memcpy(message, context->remote_random, TLS_CLIENT_RANDOM_SIZE); memcpy(message + TLS_CLIENT_RANDOM_SIZE, context->local_random, TLS_SERVER_RANDOM_SIZE); memcpy(message + TLS_CLIENT_RANDOM_SIZE + TLS_SERVER_RANDOM_SIZE, packet->buf + start_len, params_len); if (tls_is_ecdsa(context)) { if (sign_ecdsa(context, hash_algorithm, message, message_len, out, &out_len) == 1) { DEBUG_PRINT ("Signing OK! (ECDSA, length %lu)\n", out_len); tls_packet_uint16(packet, out_len); tls_packet_append(packet, out, out_len); } } else if (sign_rsa(context, hash_algorithm, message, message_len, out, &out_len) == 1) { DEBUG_PRINT("Signing OK! (length %lu)\n", out_len); tls_packet_uint16(packet, out_len); tls_packet_append(packet, out, out_len); } free(message); } if (!packet->broken && packet->buf) { tls_set_packet_length(packet, packet->len - start_len); } tls_packet_update(packet); tls_queue_packet(packet); return; } #if 0 void _private_tls_set_session_id(struct TLSContext *context) { if (context->tlsver == TLS_VERSION13 && context->session_size == TLS_MAX_SESSION_ID) { return; } if (tls_random(context->session, TLS_MAX_SESSION_ID)) { context->session_size = TLS_MAX_SESSION_ID; } else { context->session_size = 0; } } #endif struct TLSPacket *tls_certificate_request(struct TLSContext *context) { if (!context || !context->is_server) { return NULL; } unsigned short packet_version = context->version; struct TLSPacket *packet = tls_create_packet(context, TLS_HANDSHAKE, packet_version, 0); if (!packet) { return NULL; } /* cert request and size placeholder */ unsigned char dummy[] = { 0x0d, 0, 0, 0 }; tls_packet_append(packet, dummy, sizeof dummy); 
int start_len = packet->len; if (context->tlsver == TLS_VERSION13) { /* certificate request context */ tls_packet_uint8(packet, 0); /* extensions */ tls_packet_uint16(packet, 18); /* signature algorithms */ tls_packet_uint16(packet, 0x0D); tls_packet_uint16(packet, 14); tls_packet_uint16(packet, 12); #if 0 rsa_pkcs1_sha256 tls_packet_uint16(packet, 0x0401); rsa_pkcs1_sha384 tls_packet_uint16(packet, 0x0501); rsa_pkcs1_sha512 tls_packet_uint16(packet, 0x0601); #endif /* ecdsa_secp256r1_sha256 */ tls_packet_uint16(packet, 0x0403); /* ecdsa_secp384r1_sha384 */ tls_packet_uint16(packet, 0x0503); /* ecdsa_secp521r1_sha512 */ tls_packet_uint16(packet, 0x0604); /* rsa_pss_rsae_sha256 */ tls_packet_uint16(packet, 0x0804); /* rsa_pss_rsae_sha384 */ tls_packet_uint16(packet, 0x0805); /* rsa_pss_rsae_sha512 */ tls_packet_uint16(packet, 0x0806); } else { tls_packet_uint8(packet, 1); tls_packet_uint8(packet, rsa_sign); if (context->version == TLS_V12) { /* 10 pairs or 2 bytes */ tls_packet_uint16(packet, 10); tls_packet_uint8(packet, sha256); tls_packet_uint8(packet, rsa); tls_packet_uint8(packet, sha1); tls_packet_uint8(packet, rsa); tls_packet_uint8(packet, sha384); tls_packet_uint8(packet, rsa); tls_packet_uint8(packet, sha512); tls_packet_uint8(packet, rsa); tls_packet_uint8(packet, md5); tls_packet_uint8(packet, rsa); } /* no DistinguishedName yet */ tls_packet_uint16(packet, 0); } if (!packet->broken) { tls_set_packet_length(packet, packet->len - start_len); } tls_packet_update(packet); return packet; } int tls_parse_key_share(struct TLSContext *context, const unsigned char *buf, int buf_len) { int i = 0; struct ECCCurveParameters *curve = 0; struct DHKey *dhkey = 0; int dhe_key_size = 0; const unsigned char *buffer = NULL; unsigned char *out2; unsigned long out_size; uint16_t key_size; while (buf_len >= 4) { uint16_t named_group = get16(&buf[i]); i += 2; buf_len -= 2; key_size = get16(&buf[i]); i += 2; buf_len -= 2; if (key_size > buf_len) { return TLS_BROKEN_PACKET; } 
switch (named_group) { case 0x0017: curve = &secp256r1; buffer = &buf[i]; DEBUG_PRINT("KEY SHARE => secp256r1\n"); buf_len = 0; continue; case 0x0018: /* secp384r1 */ curve = &secp384r1; buffer = &buf[i]; DEBUG_PRINT("KEY SHARE => secp384r1\n"); buf_len = 0; continue; case 0x0019: /* secp521r1 */ break; case 0x001D: /* x25519 */ if (key_size != 32) { DEBUG_PRINT ("INVALID x25519 KEY SIZE (%i)\n", key_size); continue; } curve = &curve25519; buffer = &buf[i]; DEBUG_PRINT("KEY SHARE => x25519\n"); buf_len = 0; continue; break; case 0x001E: /* x448 */ break; case 0x0100: dhkey = &ffdhe2048; dhe_key_size = 2048; break; case 0x0101: dhkey = &ffdhe3072; dhe_key_size = 3072; break; case 0x0102: dhkey = &ffdhe4096; dhe_key_size = 4096; break; case 0x0103: dhkey = &ffdhe6144; dhe_key_size = 6144; break; case 0x0104: dhkey = &ffdhe8192; dhe_key_size = 8192; break; } i += key_size; buf_len -= key_size; } if (curve) { context->curve = curve; if (curve == &curve25519) { if (!tls_random(context->local_random, TLS_SERVER_RANDOM_SIZE)) { return TLS_GENERIC_ERROR; } unsigned char secret[32]; static const unsigned char basepoint[32] = { 9 }; tls_random(secret, 32); secret[0] &= 248; secret[31] &= 127; secret[31] |= 64; /* use finished key to store public key */ free(context->finished_key); context->finished_key = malloc(32); if (!context->finished_key) { return TLS_GENERIC_ERROR; } x25519(context->finished_key, secret, basepoint); free(context->premaster_key); context->premaster_key = malloc(32); if (!context->premaster_key) { return TLS_GENERIC_ERROR; } x25519(context->premaster_key, secret, buffer); context->premaster_key_len = 32; return 0; } tls_ecc_dhe_create(context); ltc_ecc_set_type *dp = (ltc_ecc_set_type *)&context->curve->dp; if (ecc_make_key_ex(NULL, find_prng("sprng"), context->ecc_dhe, dp)) { free(context->ecc_dhe); context->ecc_dhe = NULL; DEBUG_PRINT("Error generating ECC DHE key\n"); return TLS_GENERIC_ERROR; } if (!tls_random(context->local_random, 
TLS_SERVER_RANDOM_SIZE)) { return TLS_GENERIC_ERROR; } ecc_key client_key; memset(&client_key, 0, sizeof client_key); if (ecc_ansi_x963_import_ex (buffer, key_size, &client_key, dp)) { DEBUG_PRINT("Error importing ECC DHE key\n"); return TLS_GENERIC_ERROR; } out2 = malloc(key_size); out_size = key_size; int err = ecc_shared_secret(context->ecc_dhe, &client_key, out2, &out_size); ecc_free(&client_key); if (err) { DEBUG_PRINT("ECC DHE DECRYPT ERROR %i\n", err); free(out2); return TLS_GENERIC_ERROR; } DEBUG_PRINT("OUT_SIZE: %lu\n", out_size); DEBUG_DUMP_HEX_LABEL("ECC DHE", out2, out_size); free(context->premaster_key); context->premaster_key = out2; context->premaster_key_len = out_size; return 0; } else if (dhkey) { tls_dhe_create(context); if (!tls_random(context->local_random, TLS_SERVER_RANDOM_SIZE)) { return TLS_GENERIC_ERROR; } if (tls_dh_make_key(dhe_key_size / 8, context->dhe, (const char *)dhkey->p, (const char *)dhkey->g, 0, 0)) { free(context->dhe); context->dhe = NULL; DEBUG_PRINT("Error generating DHE key\n"); return TLS_GENERIC_ERROR; } unsigned int dhe_out_size; out2 = tls_decrypt_dhe(context, buffer, key_size, &dhe_out_size, 0); if (!out2) { DEBUG_PRINT("Error generating DHE shared key\n"); return TLS_GENERIC_ERROR; } free(context->premaster_key); context->premaster_key = out2; context->premaster_key_len = dhe_out_size; if (context->dhe) { context->dhe->iana = dhkey->iana; } return 0; } DEBUG_PRINT("NO COMMON KEY SHARE SUPPORTED\n"); return TLS_NO_COMMON_CIPHER; } int tls_parse_certificate(struct TLSContext *context, const unsigned char *buf, int buf_len, int is_client) { int res = 0; if (buf_len < 3) { return TLS_NEED_MORE_DATA; } int size = get24(buf); /* not enough data, so just consume all of it */ if (size <= 4) { return 3 + size; } res += 3; /* skip over the size field */ if (context->tlsver == TLS_VERSION13) { int context_size = buf[res]; res++; /* must be 0 */ if (context_size) { res += context_size; } } if (buf_len - res < size) { return 
TLS_NEED_MORE_DATA; } int idx = 0; int valid_certificate = 0; while (size > 0) { idx++; if (buf_len - res < 3) { return TLS_NEED_MORE_DATA; } int certificate_size = get24(buf+res); res += 3; if (buf_len - res < certificate_size) { return TLS_NEED_MORE_DATA; } /* load chain */ int certificates_in_chain = 0; int res2 = res; unsigned int remaining = certificate_size; do { if (remaining <= 3) { break; } certificates_in_chain++; unsigned int certificate_size2 = get24(buf+res2); res2 += 3; remaining -= 3; if (certificate_size2 > remaining) { DEBUG_PRINT ("Invalid certificate size (%i from %i bytes remaining)\n", certificate_size2, remaining); break; } remaining -= certificate_size2; struct TLSCertificate *cert = asn1_parse(context, &buf [res2], certificate_size2, is_client); if (cert) { if (certificate_size2) { cert->bytes = malloc(certificate_size2); if (cert->bytes) { cert->len = certificate_size2; memcpy(cert->bytes, &buf[res2], certificate_size2); } } /* valid certificate */ if (is_client) { void *new; valid_certificate = 1; new = TLS_REALLOC(context-> client_certificates, (context-> client_certificates_count + 1) * sizeof(struct TLSCertificate *)); if (!new) { free(context-> client_certificates); context-> client_certificates = 0; return TLS_NO_MEMORY; } context->client_certificates = new; context-> client_certificates [context->client_certificates_count] = cert; context-> client_certificates_count++; } else { void *new; new = TLS_REALLOC(context-> certificates, (context-> certificates_count + 1) * sizeof(struct TLSCertificate *)); if (!new) { free(context-> certificates); context->certificates = 0; return TLS_NO_MEMORY; } context->certificates = new; context->certificates[context-> certificates_count] = cert; context->certificates_count++; if ((cert->pk) || (cert->priv)) valid_certificate = 1; else if (!context->is_server) valid_certificate = 1; } } res2 += certificate_size2; /* extension */ if (context->tlsver == TLS_VERSION13) { if (remaining >= 2) { /* ignore 
extensions */ remaining -= 2; uint16_t size = get16(&buf[res2]); if (size && size >= remaining) { res2 += size; remaining -= size; } } } } while (remaining > 0); if (remaining) { DEBUG_PRINT("Extra %i bytes after certificate\n", remaining); } size -= certificate_size + 3; res += certificate_size; } if (!valid_certificate) { MARK; return TLS_UNSUPPORTED_CERTIFICATE; } if (res != buf_len) { DEBUG_PRINT("Warning: %i bytes read from %i byte buffer\n", (int) res, (int) buf_len); } return res; } static int parse_dh( const unsigned char *buf, int buf_len, const unsigned char **out, int *out_size) { int res = 0; *out = NULL; *out_size = 0; if (buf_len < 2) { return TLS_NEED_MORE_DATA; } uint16_t size = get16(buf); res += 2; if (buf_len - res < size) { return TLS_NEED_MORE_DATA; } DEBUG_DUMP_HEX(&buf[res], size); *out = &buf[res]; *out_size = size; res += size; return res; } static int tls_parse_random(struct TLSContext *context, const unsigned char *buf, int buf_len) { int res = 0; int ephemeral = tls_cipher_is_ephemeral(context); uint16_t size; if (ephemeral == 2) { if (buf_len < 1) { return TLS_NEED_MORE_DATA; } size = buf[0]; res += 1; } else { if (buf_len < 2) { return TLS_NEED_MORE_DATA; } size = get16(buf); res += 2; } if (buf_len - res < size) { return TLS_NEED_MORE_DATA; } unsigned int out_len = 0; unsigned char *random = NULL; switch (ephemeral) { case 1: random = tls_decrypt_dhe(context, &buf[res], size, &out_len, 1); break; case 2: random = tls_decrypt_ecc_dhe(context, &buf[res], size, &out_len, 1); break; default: random = decrypt_rsa(context, &buf[res], size, &out_len); } if (random && out_len > 2) { /* *(unsigned short *)&random[0] = htons(context->version); */ DEBUG_DUMP_HEX_LABEL("PRE MASTER KEY", random, out_len); free(context->premaster_key); context->premaster_key = random; context->premaster_key_len = out_len; tls_compute_key(context, 48); } else { free(random); return 0; } res += size; return res; } static const unsigned char *parse_signature(const 
unsigned char *buf, int buf_len, int *hash_algorithm, int *sign_algorithm, int *sig_size, int *offset) { int res = 0; if (buf_len < 2) { return NULL; } *hash_algorithm = _md5_sha1; *sign_algorithm = rsa_sign; *sig_size = 0; *hash_algorithm = buf[res]; res++; *sign_algorithm = buf[res]; res++; uint16_t size = get16(&buf[res]); res += 2; if (buf_len - res < size) { return NULL; } DEBUG_DUMP_HEX(&buf[res], size); *sig_size = size; *offset = res + size; return &buf[res]; } int tls_parse_server_key_exchange(struct TLSContext *context, const unsigned char *buf, int buf_len) { int res = 0; int dh_res = 0; if (buf_len < 3) { return TLS_NEED_MORE_DATA; } int size = get24(buf); res += 3; const unsigned char *packet_ref = buf + res; if (buf_len - res < size) { return TLS_NEED_MORE_DATA; } if (!size) { return res; } unsigned char has_ds_params = 0; int key_size = 0; const struct ECCCurveParameters *curve = NULL; const unsigned char *pk_key = NULL; int ephemeral = tls_cipher_is_ephemeral(context); if (ephemeral) { if (ephemeral == 1) { has_ds_params = 1; } else { if (buf[res++] != 3) { /* named curve */ /* any other method is not supported */ return 0; } if (buf_len - res < 3) { return TLS_NEED_MORE_DATA; } int iana_n = get16(&buf[res]); res += 2; key_size = buf[res]; res++; if (buf_len - res < key_size) { return TLS_NEED_MORE_DATA; } DEBUG_PRINT("IANA CURVE NUMBER: %i\n", iana_n); switch (iana_n) { case 19: curve = &secp192r1; break; case 20: curve = &secp224k1; break; case 21: curve = &secp224r1; break; case 22: curve = &secp256k1; break; case 23: curve = &secp256r1; break; case 24: curve = &secp384r1; break; case 25: curve = &secp521r1; break; default: DEBUG_PRINT("UNSUPPORTED CURVE\n"); return TLS_GENERIC_ERROR; } pk_key = &buf[res]; res += key_size; context->curve = curve; } } const unsigned char *dh_p = NULL; int dh_p_len = 0; const unsigned char *dh_g = NULL; int dh_g_len = 0; const unsigned char *dh_Ys = NULL; int dh_Ys_len = 0; if (has_ds_params) { DEBUG_PRINT(" dh_p: 
"); dh_res = parse_dh(&buf[res], buf_len - res, &dh_p, &dh_p_len); if (dh_res <= 0) { return TLS_BROKEN_PACKET; } res += dh_res; DEBUG_PRINT("\n"); DEBUG_PRINT(" dh_q: "); dh_res = parse_dh(&buf[res], buf_len - res, &dh_g, &dh_g_len); if (dh_res <= 0) { MARK; return TLS_BROKEN_PACKET; } res += dh_res; DEBUG_PRINT("\n"); DEBUG_PRINT(" dh_Ys: "); dh_res = parse_dh(&buf[res], buf_len - res, &dh_Ys, &dh_Ys_len); if (dh_res <= 0) { MARK; return TLS_BROKEN_PACKET; } res += dh_res; DEBUG_PRINT("\n"); } int sign_size; int hash_algorithm; int sign_algorithm; int packet_size = res - 3; int offset = 0; DEBUG_PRINT(" SIGNATURE (%i/%i/%i): ", packet_size, dh_res, key_size); const unsigned char *signature = parse_signature(&buf[res], buf_len - res, &hash_algorithm, &sign_algorithm, &sign_size, &offset); DEBUG_PRINT("\n"); if (sign_size <= 0 || !signature) { return TLS_BROKEN_PACKET; } res += offset; /* check signature */ unsigned int message_len = packet_size + TLS_CLIENT_RANDOM_SIZE + TLS_SERVER_RANDOM_SIZE; unsigned char *message = malloc(message_len); if (message) { memcpy(message, context->local_random, TLS_CLIENT_RANDOM_SIZE); memcpy(message + TLS_CLIENT_RANDOM_SIZE, context->remote_random, TLS_SERVER_RANDOM_SIZE); memcpy(message + TLS_CLIENT_RANDOM_SIZE + TLS_SERVER_RANDOM_SIZE, packet_ref, packet_size); if (tls_is_ecdsa(context)) { if (tls_verify_ecdsa (context, hash_algorithm, signature, sign_size, message, message_len, NULL) != 1) { DEBUG_PRINT ("ECC Server signature FAILED!\n"); free(message); return TLS_BROKEN_PACKET; } } else { if (verify_rsa(context, hash_algorithm, signature, sign_size, message, message_len) != 1) { DEBUG_PRINT("Server signature FAILED!\n"); free(message); return TLS_BROKEN_PACKET; } } free(message); } if (buf_len - res) { DEBUG_PRINT("EXTRA %i BYTES AT THE END OF MESSAGE\n", buf_len - res); DEBUG_DUMP_HEX(&buf[res], buf_len - res); DEBUG_PRINT("\n"); } if (ephemeral == 1) { tls_dhe_create(context); DEBUG_DUMP_HEX_LABEL("DHP", dh_p, dh_p_len); 
DEBUG_DUMP_HEX_LABEL("DHG", dh_g, dh_g_len);
        /* size the local key after the larger of p and g */
        int dhe_key_size = dh_p_len;
        if (dh_g_len > dh_p_len) {
            dhe_key_size = dh_g_len;
        }
        if (tls_dh_make_key(dhe_key_size, context->dhe, (const char *) dh_p, (const char *) dh_g, dh_p_len, dh_g_len)) {
            DEBUG_PRINT("ERROR CREATING DHE KEY\n");
            free(context->dhe);
            context->dhe = NULL;
            return TLS_GENERIC_ERROR;
        }

        unsigned int dh_key_size = 0;
        unsigned char *key = tls_decrypt_dhe(context, dh_Ys, dh_Ys_len, &dh_key_size, 0);
        DEBUG_DUMP_HEX_LABEL("DH COMMON SECRET", key, dh_key_size);
        /* the premaster key is only replaced when a secret was produced */
        if (key && dh_key_size) {
            free(context->premaster_key);
            context->premaster_key = key;
            context->premaster_key_len = dh_key_size;
        }
    } else if (ephemeral == 2 && curve && pk_key && key_size) {
        tls_ecc_dhe_create(context);

        ltc_ecc_set_type *dp = (ltc_ecc_set_type *) & curve->dp;
        if (ecc_make_key_ex (NULL, find_prng("sprng"), context->ecc_dhe, dp)) {
            free(context->ecc_dhe);
            context->ecc_dhe = NULL;
            DEBUG_PRINT("Error generating ECC key\n");
            return TLS_GENERIC_ERROR;
        }

        free(context->premaster_key);
        context->premaster_key_len = 0;

        unsigned int out_len = 0;
        context->premaster_key = tls_decrypt_ecc_dhe(context, pk_key, key_size, &out_len, 0);
        if (context->premaster_key) {
            context->premaster_key_len = out_len;
        }
    }

    return res;
}

/*
 * Parse a ClientKeyExchange handshake body (24-bit length first).
 *
 * Only accepted while connection_status is 1; on success the premaster
 * secret has been handed to tls_parse_random() and connection_status
 * becomes 2.
 *
 * Returns the number of bytes consumed, TLS_UNEXPECTED_MESSAGE,
 * TLS_NEED_MORE_DATA or TLS_BROKEN_PACKET.
 */
int tls_parse_client_key_exchange(struct TLSContext *context, const unsigned char *buf, int buf_len)
{
    if (context->connection_status != 1) {
        DEBUG_PRINT ("UNEXPECTED CLIENT KEY EXCHANGE MESSAGE (connections status: %i)\n", (int) context->connection_status);
        return TLS_UNEXPECTED_MESSAGE;
    }

    int res = 0;
    int dh_res = 0;

    if (buf_len < 3) {
        return TLS_NEED_MORE_DATA;
    }

    int size = get24(buf);
    res += 3;

    if (buf_len - res < size) {
        return TLS_NEED_MORE_DATA;
    }

    /* an empty body is accepted without changing the connection state */
    if (!size) {
        return res;
    }

    dh_res = tls_parse_random(context, &buf[res], size);
    if (dh_res <= 0) {
        DEBUG_PRINT("broken key\n");
        return TLS_BROKEN_PACKET;
    }
    DEBUG_PRINT("\n");

    res += size;
    context->connection_status = 2;
    return res;
}

/*
 * Parse a ServerHelloDone handshake body: the 24-bit length and whatever
 * payload it announces are skipped.
 * Returns the number of bytes consumed or TLS_NEED_MORE_DATA.
 */
static int tls_parse_server_hello_done(const unsigned char *buf, int buf_len)
{
    int res = 0;
    if (buf_len < 3) {
        return TLS_NEED_MORE_DATA;
    }

    int size = get24(buf);
    res += 3;

    if (buf_len - res < size) {
        return TLS_NEED_MORE_DATA;
    }

    MARK;
    res += size;
    return res;
}

/*
 * Parse and verify a Finished handshake body.
 *
 * TLS 1.3: the verify data is compared with an HMAC of the handshake hash
 * keyed with remote_finished_key. Other versions: it is compared with the
 * PRF output for the "client finished" / "server finished" label.
 *
 * *write_packets is set to 3 when a server has to answer with its own
 * Finished. Returns the number of bytes consumed or a negative TLS_* code.
 */
int tls_parse_finished(struct TLSContext *context, const unsigned char *buf, int buf_len, unsigned int *write_packets)
{
    if (context->connection_status < 2 || context->connection_status == TLS_CONNECTED) {
        DEBUG_PRINT("UNEXPECTED FINISHED MESSAGE\n");
        return TLS_UNEXPECTED_MESSAGE;
    }

    int res = 0;
    *write_packets = 0;
    if (buf_len < 3) {
        return TLS_NEED_MORE_DATA;
    }

    int size = get24(buf);
    res += 3;

    if (size < TLS_MIN_FINISHED_OPAQUE_LEN) {
        DEBUG_PRINT("Invalid finished packet size: %i\n", size);
        return TLS_BROKEN_PACKET;
    }

    if (buf_len - res < size) {
        return TLS_NEED_MORE_DATA;
    }

    unsigned char hash[TLS_MAX_SHA_SIZE];
    unsigned int hash_len = tls_get_hash(context, hash);

    if (context->tlsver == TLS_VERSION13) {
        unsigned char hash_out[TLS_MAX_SHA_SIZE];
        unsigned long out_size = TLS_MAX_SHA_SIZE;
        if (!context->remote_finished_key || !hash_len) {
            DEBUG_PRINT ("NO FINISHED KEY COMPUTED OR NO HANDSHAKE HASH\n");
            return TLS_NOT_VERIFIED;
        }

        DEBUG_DUMP_HEX_LABEL("HS HASH", hash, hash_len);
        DEBUG_DUMP_HEX_LABEL("HS FINISH", context->remote_finished_key, hash_len);

        out_size = hash_len;
        hmac_state hmac;
        hmac_init(&hmac, tls_get_hash_idx(context), context->remote_finished_key, hash_len);
        hmac_process(&hmac, hash, hash_len);
        hmac_done(&hmac, hash_out, &out_size);

        if (size != (int)out_size || memcmp(hash_out, &buf[res], size)) {
            DEBUG_PRINT ("Finished validation error (sequence number, local: %i, remote: %i)\n", (int) context->local_sequence_number, (int) context->remote_sequence_number);
            DEBUG_DUMP_HEX_LABEL("FINISHED OPAQUE", &buf[res], size);
            DEBUG_DUMP_HEX_LABEL("VERIFY", hash_out, out_size);
            return TLS_NOT_VERIFIED;
        }
        if (context->is_server) {
            /* TLS 1.3 server: the client Finished completes the handshake */
            context->connection_status = TLS_CONNECTED;
            res += size;
            _private_tls13_key(context, 0);
            context->local_sequence_number = 0;
context->remote_sequence_number = 0; return res; } /* TODO client verify */ } else { /* verify */ unsigned char *out = malloc(size); if (!out) { DEBUG_PRINT("Error in malloc (%i bytes)\n", (int) size); return TLS_NO_MEMORY; } /* server verifies client's message */ if (context->is_server) { tls_prf(context, out, size, context->master_key, context->master_key_len, (unsigned char *) "client finished", 15, hash, hash_len, NULL, 0); } else { tls_prf(context, out, size, context->master_key, context->master_key_len, (unsigned char *) "server finished", 15, hash, hash_len, NULL, 0); } if (memcmp(out, &buf[res], size)) { free(out); DEBUG_PRINT ("Finished validation error (sequence number, local: %i, remote: %i)\n", (int) context->local_sequence_number, (int) context->remote_sequence_number); DEBUG_DUMP_HEX_LABEL("FINISHED OPAQUE", &buf[res], size); DEBUG_DUMP_HEX_LABEL("VERIFY", out, size); return TLS_NOT_VERIFIED; } free(out); } if (context->is_server) { *write_packets = 3; } else { context->connection_status = TLS_CONNECTED; } // fprintf(stderr, "set conn status = %d\n", context->connection_status); MARK; res += size; return res; } int tls_parse_verify_tls13(struct TLSContext *context, const unsigned char *buf, int buf_len) { if (buf_len < 7) { return TLS_NEED_MORE_DATA; } int size = get24(buf); if (size < 2) { return buf_len; } unsigned char signing_data[TLS_MAX_HASH_SIZE + 98]; int signing_data_len; /* first 64 bytes to 0x20 (32) */ memset(signing_data, 0x20, 64); /* context string 33 bytes */ if (context->is_server) { memcpy(signing_data + 64, "TLS 1.3, server CertificateVerify", 33); } else { memcpy(signing_data + 64, "TLS 1.3, client CertificateVerify", 33); } /* a single 0 byte separator */ signing_data[97] = 0; signing_data_len = 98; signing_data_len += tls_get_hash(context, signing_data + 98); DEBUG_DUMP_HEX_LABEL("signature data", signing_data, signing_data_len); uint16_t signature = get16(&buf[3]); uint16_t signature_size = get16(&buf[5]); int valid = 0; if 
(buf_len < size + 7) { return TLS_NEED_MORE_DATA; } switch (signature) { case 0x0403: /* secp256r1 + sha256 */ valid = tls_verify_ecdsa(context, sha256, buf + 7, signature_size, signing_data, signing_data_len, &secp256r1); break; case 0x0503: /* secp384r1 + sha384 */ valid = tls_verify_ecdsa(context, sha384, buf + 7, signature_size, signing_data, signing_data_len, &secp384r1); break; case 0x0603: /* secp521r1 + sha512 */ valid = tls_verify_ecdsa(context, sha512, buf + 7, signature_size, signing_data, signing_data_len, &secp521r1); break; case 0x0804: valid = verify_rsa(context, sha256, buf + 7, signature_size, signing_data, signing_data_len); break; default: DEBUG_PRINT("Unsupported signature: %x\n", (int) signature); return TLS_UNSUPPORTED_CERTIFICATE; } if (valid != 1) { DEBUG_PRINT("Signature FAILED!\n"); return TLS_DECRYPTION_FAILED; } return buf_len; } int tls_parse_verify(struct TLSContext *context, const unsigned char *buf, int buf_len) { if (context->tlsver == TLS_VERSION13) { return tls_parse_verify_tls13(context, buf, buf_len); } if (buf_len < 7) { return TLS_BAD_CERTIFICATE; } int bytes_to_follow = get24(buf); if (buf_len - 3 < bytes_to_follow) { return TLS_BAD_CERTIFICATE; } int res = -1; unsigned int hash = buf[3]; unsigned int algorithm = buf[4]; if (algorithm != rsa) { return TLS_UNSUPPORTED_CERTIFICATE; } uint16_t size = get16(&buf[5]); if (bytes_to_follow - 4 < size) { return TLS_BAD_CERTIFICATE; } DEBUG_PRINT("ALGORITHM %i/%i (%i)\n", hash, algorithm, (int) size); DEBUG_DUMP_HEX_LABEL("VERIFY", &buf[7], bytes_to_follow - 7); res = verify_rsa(context, hash, &buf[7], size, context->cached_handshake.buffer, context->cached_handshake.len); tls_buffer_free(&context->cached_handshake); if (res == 1) { DEBUG_PRINT("Signature OK\n"); context->client_verified = 1; } else { DEBUG_PRINT("Signature FAILED\n"); context->client_verified = 0; } return 1; } /* TODO This is actually a parse a handshake message */ int tls_parse_payload(struct TLSContext *context, 
const unsigned char *buf, int buf_len)
{
    ENTER;
    int orig_len = buf_len;

    /* handshake data on an established connection: answer with an alert
     * instead of renegotiating */
    if (context->connection_status == TLS_CONNECTED) {
        if (context->version == TLS_V13) {
            tls_alert(context, 1, unexpected_message);
        } else {
            tls_alert(context, 0, no_renegotiation_RESERVED);
        }
        return 1;
    }

    /* each message: 1 type byte, 24-bit length, body */
    while (buf_len >= 4 && !context->critical_error) {
        int payload_res = 0;
        //unsigned char update_hash = 1;
        unsigned char type = buf[0];
        unsigned int write_packets = 0;
        int certificate_verify_alert = no_error;
        /* length field (3 bytes) plus body; the type byte is not counted */
        int payload_size = get24(buf+1) + 3;
        if (buf_len < payload_size + 1) {
            return TLS_NEED_MORE_DATA;
        }

        switch (type) {
        case 0x00:
            /* hello request */
            CHECK_HANDSHAKE_STATE(context, 0, 1);
            DEBUG_PRINT(" => HELLO REQUEST (RENEGOTIATION?)\n");
            if (context->is_server) {
                payload_res = TLS_UNEXPECTED_MESSAGE;
            } else {
                if (context->connection_status == TLS_CONNECTED) {
                    /* renegotiation */
                    payload_res = TLS_NO_RENEGOTIATION;
                } else {
                    payload_res = TLS_UNEXPECTED_MESSAGE;
                }
            }
            /* no payload */
            break;
        case 0x01:
            /* client hello */
            CHECK_HANDSHAKE_STATE(context, 1, 1);
            DEBUG_PRINT(" => CLIENT HELLO\n");
            if (context->is_server) {
                payload_res = tls_parse_client_hello(context, buf + 1, payload_size, &write_packets);
            } else {
                payload_res = TLS_UNEXPECTED_MESSAGE;
            }
            break;
        case 0x02:
            /* server hello */
            CHECK_HANDSHAKE_STATE(context, 2, 1);
            DEBUG_PRINT(" => SERVER HELLO\n");
            if (context->is_server) {
                payload_res = TLS_UNEXPECTED_MESSAGE;
            } else {
                write_packets = 0;
                payload_res = tls_parse_server_hello(context, buf + 1, payload_size);
            }
            break;
        case 0x03:
            /* hello verify request */
            DEBUG_PRINT(" => VERIFY REQUEST\n");
            CHECK_HANDSHAKE_STATE(context, 3, 1);
            payload_res = TLS_UNEXPECTED_MESSAGE;
            break;
        case 0x0B:
            /* certificate */
            CHECK_HANDSHAKE_STATE(context, 4, 1);
            DEBUG_PRINT(" => CERTIFICATE\n");
            if (context->tlsver == TLS_VERSION13) {
                if (context->connection_status == 2) {
                    payload_res = tls_parse_certificate(context, buf + 1, payload_size, context->is_server);
                    if (context->is_server) {
                        if (context->certificate_verify && context->client_certificates_count) {
                            certificate_verify_alert = context->certificate_verify(context, context->client_certificates, context->client_certificates_count);
                        }
                        /* empty certificates are permitted for client */
                        if (payload_res <= 0) {
                            payload_res = 1;
                        }
                    }
                } else {
                    payload_res = TLS_UNEXPECTED_MESSAGE;
                }
            } else if (context->connection_status == 1) {
                if (context->is_server) {
                    /* client certificate */
                    payload_res = tls_parse_certificate(context, buf + 1, payload_size, 1);
                    if (context->certificate_verify && context->client_certificates_count) {
                        certificate_verify_alert = context->certificate_verify(context, context->client_certificates, context->client_certificates_count);
                    }
                    /* empty certificates are permitted for client */
                    if (payload_res <= 0) {
                        payload_res = 1;
                    }
                } else {
                    payload_res = tls_parse_certificate(context, buf + 1, payload_size, 0);
                    /* the verify callback is a member of the context */
                    if (context->certificate_verify && context->certificates_count) {
                        certificate_verify_alert = context->certificate_verify(context, context->certificates, context->certificates_count);
                    }
                }
            } else {
                payload_res = TLS_UNEXPECTED_MESSAGE;
            }
            break;
        case 0x0C:
            /* server key exchange */
            CHECK_HANDSHAKE_STATE(context, 5, 1);
            DEBUG_PRINT(" => SERVER KEY EXCHANGE\n");
            if (context->is_server) {
                payload_res = TLS_UNEXPECTED_MESSAGE;
            } else {
                payload_res = tls_parse_server_key_exchange(context, buf + 1, payload_size);
            }
            break;
        case 0x0D:
            /* certificate request */
            CHECK_HANDSHAKE_STATE(context, 6, 1);
            /* server to client */
            if (context->is_server) {
                payload_res = TLS_UNEXPECTED_MESSAGE;
            } else {
                /* 2 = a certificate has to be sent with the client handshake */
                context->client_verified = 2;
            }
            DEBUG_PRINT(" => CERTIFICATE REQUEST\n");
            break;
        case 0x0E:
            /* server hello done */
            CHECK_HANDSHAKE_STATE(context, 7, 1);
            DEBUG_PRINT(" => SERVER HELLO DONE\n");
            if (context->is_server) {
                payload_res = TLS_UNEXPECTED_MESSAGE;
            } else {
                payload_res = tls_parse_server_hello_done(buf + 1, payload_size);
                if (payload_res > 0) {
                    write_packets = 1;
                }
            }
            break;
        case 0x0F:
            /* certificate verify */
            CHECK_HANDSHAKE_STATE(context, 8, 1);
            DEBUG_PRINT(" => CERTIFICATE VERIFY\n");
            if (context->connection_status == 2) {
                payload_res = tls_parse_verify(context, buf + 1, payload_size);
            } else {
                payload_res = TLS_UNEXPECTED_MESSAGE;
            }
            break;
        case 0x10:
            /* client key exchange */
            CHECK_HANDSHAKE_STATE(context, 9, 1);
            DEBUG_PRINT(" => CLIENT KEY EXCHANGE\n");
            if (context->is_server) {
                payload_res = tls_parse_client_key_exchange(context, buf + 1, payload_size);
            } else {
                payload_res = TLS_UNEXPECTED_MESSAGE;
            }
            break;
        case 0x14:
            /* finished */
            tls_buffer_free(&context->cached_handshake);
            CHECK_HANDSHAKE_STATE(context, 10, 1);
            DEBUG_PRINT(" => FINISHED\n");
            payload_res = tls_parse_finished(context, buf + 1, payload_size, &write_packets);
            if (payload_res > 0) {
                /* handshake complete: allow a fresh set of messages */
                memset(context->hs_messages, 0, sizeof(context->hs_messages));
            }
            break;
        default:
            DEBUG_PRINT(" => NOT UNDERSTOOD PAYLOAD TYPE: %x\n", (int) type);
            return TLS_NOT_UNDERSTOOD;
        }

        //if (type != 0x00 && update_hash) {
        if (type != 0x00) {
            tls_update_hash(context, buf, payload_size + 1);
        }

        if (certificate_verify_alert != no_error) {
            MARK;
            tls_alert(context, 1, certificate_verify_alert);
            context->critical_error = 1;
        }

        /* translate parser errors into alerts */
        if (payload_res < 0) {
            switch (payload_res) {
            case TLS_UNEXPECTED_MESSAGE:
                MARK;
                tls_alert(context, 1, unexpected_message);
                break;
            case TLS_COMPRESSION_NOT_SUPPORTED:
                MARK;
                tls_alert(context, 1, decompression_failure_RESERVED);
                break;
            case TLS_BROKEN_PACKET:
                MARK;
                tls_alert(context, 1, decode_error);
                break;
            case TLS_NO_MEMORY:
                MARK;
                tls_alert(context, 1, internal_error);
                break;
            case TLS_NOT_VERIFIED:
                MARK;
                tls_alert(context, 1, bad_record_mac);
                break;
            case TLS_BAD_CERTIFICATE:
                MARK;
                if (context->is_server) {
                    /* bad client certificate, continue */
                    tls_alert(context, 0, bad_certificate);
                    payload_res = 0;
                } else {
                    tls_alert(context, 1, bad_certificate);
                }
                break;
            case TLS_UNSUPPORTED_CERTIFICATE:
                MARK;
                tls_alert(context, 1, unsupported_certificate);
                break;
            case TLS_NO_COMMON_CIPHER:
                MARK;
                tls_alert(context, 1, insufficient_security);
                break;
            case TLS_NOT_UNDERSTOOD:
                MARK;
                tls_alert(context, 1, internal_error);
                break;
            case TLS_NO_RENEGOTIATION:
                MARK;
                tls_alert(context, 0, no_renegotiation_RESERVED);
                payload_res = 0;
                break;
            case TLS_DECRYPTION_FAILED:
                MARK;
                tls_alert(context, 1, decryption_failed_RESERVED);
                break;
            }
            if (payload_res < 0) {
                return payload_res;
            }
        }

        if (certificate_verify_alert != no_error) {
            payload_res = TLS_BAD_CERTIFICATE;
            /* TODO this is set but not used */
        }

        /* except renegotiation */
        struct TLSPacket *pkt;
        switch (write_packets) {
        case 1:
            if (context->client_verified == 2) {
                tls_send_certificate(context);
                context->client_verified = 0;
            }

            /* client handshake */
            tls_send_client_key_exchange(context);
            tls_send_change_cipher_spec(context);

            context->cipher_spec_set = 1;
            context->local_sequence_number = 0;

            tls_send_finished(context);

            context->cipher_spec_set = 0;
            break;
        case 2:
            /* server handshake */
            DEBUG_PRINT("<= SENDING SERVER HELLO\n");
            if (context->connection_status == 3) {
                context->connection_status = 2;
                tls_queue_packet(tls_build_hello(context, 0));
                tls_send_change_cipher_spec(context);
                _private_tls13_key(context, 1);
                context->cipher_spec_set = 1;

                DEBUG_PRINT("<= SENDING ENCRYPTED EXTENSIONS\n");
                tls_send_encrypted_extensions(context);

                if (context->request_client_certificate) {
                    DEBUG_PRINT("<= SENDING CERTIFICATE REQUEST\n");
                    tls_queue_packet(tls_certificate_request(context));
                }

                tls_send_certificate(context);
                tls_send_certificate_verify(context);
                tls_send_finished(context);

                /* new key */
                free(context->server_finished_hash);
                context->server_finished_hash = malloc(tls_mac_length(context));
                if (context->server_finished_hash) {
                    tls_get_hash(context, context->server_finished_hash);
                }
                break;
            }
            tls_queue_packet(tls_build_hello(context, 0));
            DEBUG_PRINT("<= SENDING CERTIFICATE\n");
            tls_send_certificate(context);

            int ephemeral_cipher = tls_cipher_is_ephemeral(context);
            if (ephemeral_cipher) {
                DEBUG_PRINT("<= SENDING EPHEMERAL DH KEY\n");
tls_send_server_key_exchange(context, ephemeral_cipher == 1 ? KEA_dhe_rsa : KEA_ec_diffie_hellman); } if (context->request_client_certificate) { DEBUG_PRINT ("<= SENDING CERTIFICATE REQUEST\n"); tls_queue_packet (tls_certificate_request (context)); } tls_send_done(context); break; case 3: /* finished */ tls_send_change_cipher_spec(context); tls_send_finished(context); context->connection_status = TLS_CONNECTED; break; case 4: /* dtls only */ /* TODO error */ break; case 5: /* hello retry request */ DEBUG_PRINT("<= SENDING HELLO RETRY REQUEST\n"); pkt = tls_build_hello(context, 0); tls_queue_packet(pkt); break; } payload_size++; buf += payload_size; buf_len -= payload_size; } LEAVE; return orig_len; } unsigned int asn1_get_len(const unsigned char *buffer, int buf_len, unsigned int *octets) { *octets = 0; if (buf_len < 1) { return 0; } unsigned char size = buffer[0]; int i; if (size & 0x80) { *octets = size & 0x7F; if ((int) *octets > buf_len - 1) { return 0; } /* max 32 bits */ unsigned int ref_octets = *octets; if (*octets > 4) { ref_octets = 4; } if ((int) *octets > buf_len - 1) { return 0; } unsigned int long_size = 0; unsigned int coef = 1; for (i = ref_octets; i > 0; i--) { long_size += buffer[i] * coef; coef *= 0x100; } ++*octets; return long_size; } ++*octets; return size; } void print_index(const unsigned int *fields) { int i = 0; while (fields[i]) { if (i) { DEBUG_PRINT("."); } DEBUG_PRINT("%i", fields[i]); i++; } while (i < 6) { DEBUG_PRINT(" "); i++; } } int _is_field(const unsigned int *fields, const unsigned int *prefix) { int i = 0; while (prefix[i]) { if (fields[i] != prefix[i]) { return 0; } i++; } return 1; } static int tls_hash_len(int algorithm) { switch (algorithm) { case TLS_RSA_SIGN_MD5: return 16; case TLS_RSA_SIGN_SHA1: return 20; case TLS_RSA_SIGN_SHA256: return 32; case TLS_RSA_SIGN_SHA384: return 48; case TLS_RSA_SIGN_SHA512: return 64; } return 0; } static unsigned char *tls_compute_hash(int algorithm, const unsigned char *message, 
unsigned int message_len) { unsigned char *hash = NULL; int err; int hash_index = -1; unsigned long hash_len = 0; if (!message || !message_len) { return hash; } switch (algorithm) { case TLS_RSA_SIGN_MD5: DEBUG_PRINT("SIGN MD5\n"); hash_index = find_hash("md5"); hash_len = 16; break; case TLS_RSA_SIGN_SHA1: DEBUG_PRINT("SIGN SHA1\n"); hash_index = find_hash("sha1"); hash_len = 20; break; case TLS_RSA_SIGN_SHA256: DEBUG_PRINT("SIGN SHA256\n"); hash_index = find_hash("sha256"); hash_len = 32; break; case TLS_RSA_SIGN_SHA384: DEBUG_PRINT("SIGN SHA384\n"); hash_index = find_hash("sha384"); hash_len = 48; break; case TLS_RSA_SIGN_SHA512: DEBUG_PRINT("SIGN SHA512\n"); hash_index = find_hash("sha512"); hash_len = 64; break; default: DEBUG_PRINT("UNKNOWN SIGNATURE ALGORITHM\n"); return NULL; break; } hash = malloc(hash_len); if (!hash) { return NULL; } err = hash_memory(hash_index, message, message_len, hash, &hash_len); if (err) { return NULL; } return hash; } int tls_certificate_verify_signature(struct TLSCertificate *cert, struct TLSCertificate *parent) { if (!cert || !parent || !cert->sign_key || !cert->fingerprint || !cert->sign_len || !parent->der_bytes || !parent->der_len) { DEBUG_PRINT("CANNOT VERIFY SIGNATURE "); if (!cert) { DEBUG_PRINT("!cert "); } else { if (!cert->sign_key) { DEBUG_PRINT("!cert->sign_key "); } if (!cert->fingerprint) { DEBUG_PRINT("!cert->fingerprint "); } if (!cert->sign_len) { DEBUG_PRINT("!cert->sign_len "); } } if (!parent) { DEBUG_PRINT("!parent "); } else { if (!parent->der_bytes) { DEBUG_PRINT("!parent->der_bytes "); } if (!parent->der_len) { DEBUG_PRINT("!parent->der_len "); } } DEBUG_PRINT("\n"); return 0; } DEBUG_PRINT("checking alg\n"); int hash_len = tls_hash_len(cert->algorithm); if (hash_len <= 0) { return 0; } int hash_index; switch (cert->algorithm) { case TLS_RSA_SIGN_MD5: hash_index = find_hash("md5"); break; case TLS_RSA_SIGN_SHA1: hash_index = find_hash("sha1"); break; case TLS_RSA_SIGN_SHA256: hash_index = 
find_hash("sha256"); break; case TLS_RSA_SIGN_SHA384: hash_index = find_hash("sha384"); break; case TLS_RSA_SIGN_SHA512: hash_index = find_hash("sha512"); break; default: DEBUG_PRINT("UNKNOWN SIGNATURE ALGORITHM\n"); return 0; } rsa_key key; DEBUG_PRINTLN("rsa_import(%p, %d, %p)\n", parent->der_bytes, parent->der_len, &key); int err = rsa_import(parent->der_bytes, parent->der_len, &key); if (err) { DEBUG_PRINTLN ("Error importing RSA certificate (code: %i)\n", err); DEBUG_PRINT("Message: %s\n", error_to_string(err)); DEBUG_DUMP_HEX_LABEL("CERTIFICATE", parent->der_bytes, parent->der_len); return 0; } int rsa_stat = 0; unsigned char *signature = cert->sign_key; int signature_len = cert->sign_len; if (!signature[0]) { signature++; signature_len--; } err = rsa_verify_hash_ex(signature, signature_len, cert->fingerprint, hash_len, LTC_PKCS_1_V1_5, hash_index, 0, &rsa_stat, &key); rsa_free(&key); if (err) { DEBUG_PRINT("HASH VERIFY ERROR %i\n", err); return 0; } DEBUG_PRINT("CERTIFICATE VALIDATION: %i\n", rsa_stat); return rsa_stat; } int tls_certificate_chain_is_valid(struct TLSCertificate **certificates, int len) { if (!certificates || !len) { return bad_certificate; } int i; DEBUG_PRINT("verifying %i length cert chain\n", len); len--; /* expired certificate or not yet valid ? */ if (tls_certificate_is_valid(certificates[0])) { return bad_certificate; } /* check */ for (i = 0; i < len; i++) { /* certificate in chain is expired ? 
*/ if (tls_certificate_is_valid(certificates[i + 1])) { return bad_certificate; } if (!tls_certificate_verify_signature(certificates[i], certificates[i + 1])) { DEBUG_PRINT ("tls_certificate_verify_signature certs[%d], certs[%d+1] failed\n", i, i); return bad_certificate; } } return 0; } int tls_certificate_chain_is_valid_root(struct TLSContext *context, struct TLSCertificate **certificates, int len) { int i, j; if (!certificates || !len || !context->root_certificates || !context->root_count) { return bad_certificate; } for (i = 0; i < len; i++) { for (j = 0; j < context->root_count; j++) { /* check if root certificate expired */ if (tls_certificate_is_valid (context->root_certificates[j])) { continue; } /* if any root validates any certificate in the chain, * then is root validated */ if (tls_certificate_verify_signature(certificates[i], context->root_certificates[j])) { return 0; } } } return bad_certificate; } int _private_is_oid(struct OID_chain *ref_chain, const unsigned char *looked_oid, int looked_oid_len) { while (ref_chain) { if (ref_chain->oid) { if (_is_oid2 (ref_chain->oid, looked_oid, 16, looked_oid_len)) { return 1; } } ref_chain = (struct OID_chain *) ref_chain->top; } return 0; } int _private_asn1_parse(struct TLSContext *context, struct TLSCertificate *cert, const unsigned char *buffer, int size, int level, unsigned int *fields, unsigned char *has_key, int client_cert, unsigned char *top_oid, struct OID_chain *chain) { struct OID_chain local_chain; DEBUG_INDEX(fields); DEBUG_PRINT("\n"); local_chain.top = chain; int pos = 0; /* X.690 */ int idx = 0; unsigned char oid[16]; memset(oid, 0, 16); local_chain.oid = oid; if (has_key) { *has_key = 0; } unsigned char local_has_key = 0; const unsigned char *cert_data = NULL; unsigned int cert_len = 0; while (pos < size) { unsigned int start_pos = pos; if (size - pos < 2) { return TLS_NEED_MORE_DATA; } unsigned char first = buffer[pos++]; unsigned char type = first & 0x1F; unsigned char constructed = first & 
0x20; unsigned char element_class = first >> 6; int octets = 0; unsigned int temp; idx++; if (level <= TLS_ASN1_MAXLEVEL) { fields[level - 1] = idx; } DEBUG_INDEX(fields); DEBUG_PRINT("\n"); int length = asn1_get_len((unsigned char *) &buffer[pos], size - pos, &octets); DEBUG_PRINT("asn1_get_len = %u\n", length); if ((octets > 4) || (octets > size - pos)) { DEBUG_PRINT ("CANNOT READ CERTIFICATE octets = %d, size = %d pos = %d, size - pos = %d\n", octets, size, pos, size - pos); return pos; } pos += octets; if (size - pos < length) { return TLS_NEED_MORE_DATA; } /*DEBUG_PRINT("FIRST: %x => %x (%i)\n", (int)first, (int)type, length); */ /* sequence */ /*DEBUG_PRINT("%2i: ", level); */ #ifdef DEBUG DEBUG_INDEX(fields); int i1; for (i1 = 1; i1 < level; i1++) { DEBUG_PRINT(" "); } #endif if (length && constructed) { switch (type) { case 0x03: DEBUG_PRINT("CONSTRUCTED BITSTREAM\n"); break; case 0x10: DEBUG_PRINT("SEQUENCE\n"); if ((level == 2) && (idx == 1)) { cert_len = length + (pos - start_pos); cert_data = &buffer[start_pos]; } /* private key on server or public key on client */ if (!cert->version && (_is_field(fields, priv_der_id))) { free(cert->der_bytes); temp = length + (pos - start_pos); cert->der_bytes = malloc(temp); if (cert->der_bytes) { memcpy(cert->der_bytes, &buffer[start_pos], temp); cert->der_len = temp; } else cert->der_len = 0; } break; case 0x11: DEBUG_PRINT("EMBEDDED PDV\n"); break; case 0x00: if (element_class == 0x02) { DEBUG_PRINT("CONTEXT-SPECIFIC\n"); break; } default: DEBUG_PRINT("CONSTRUCT TYPE %02X\n",(int)type); } local_has_key = 0; _private_asn1_parse(context, cert, &buffer[pos], length, level + 1, fields, &local_has_key, client_cert, top_oid, &local_chain); if (((local_has_key && context && (!context->is_server || client_cert)) || !context) && (_is_field(fields, pk_id))) { free(cert->der_bytes); temp = length + (pos - start_pos); cert->der_bytes = malloc(temp); if (cert->der_bytes) { memcpy(cert->der_bytes, &buffer[start_pos], temp); 
cert->der_len = temp; } else { cert->der_len = 0; } } } else { switch (type) { case 0x00: /* end of content */ DEBUG_PRINT("END OF CONTENT\n"); return pos; break; case 0x01: /* boolean */ temp = buffer[pos]; DEBUG_PRINT("BOOLEAN: %i\n", temp); break; case 0x02: /* integer */ if (_is_field(fields, pk_id)) { if (has_key) { *has_key = 1; } if (idx == 1) { tls_certificate_set_key (cert, &buffer[pos], length); } else if (idx == 2) { tls_certificate_set_exponent (cert, &buffer[pos], length); } } else if (_is_field(fields, serial_id)) { tls_certificate_set_serial(cert, &buffer [pos], length); } if (_is_field(fields, version_id)) { if (length == 1) { cert->version = buffer[pos]; } #ifdef TLS_X509_V1_SUPPORT else { cert->version = 0; } idx++; #endif } if (level >= 2) { unsigned int fields_temp[3]; fields_temp[0] = fields[level - 2]; fields_temp[1] = fields[level - 1]; fields_temp[2] = 0; if (_is_field (fields_temp, priv_id)) { tls_certificate_set_priv (cert, &buffer[pos], length); } } DEBUG_PRINT("INTEGER(%i): ", length); DEBUG_DUMP_HEX(&buffer[pos], length); if ((chain) && (length > 2)) { if (_private_is_oid (chain, san_oid, sizeof(san_oid) - 1)) { void *new; new = TLS_REALLOC(cert->san, sizeof (unsigned char *) * (cert-> san_length + 1)); if (new) { cert->san = new; cert->san[cert-> san_length] = NULL; tls_certificate_set_copy (&cert-> san[cert-> san_length], &buffer[pos], length); DEBUG_PRINT (" => SUBJECT ALTERNATIVE NAME: %s", cert-> san[cert-> san_length]); cert->san_length++; } else { free(cert-> san); cert->san = 0; cert->san_length = 0; } } } DEBUG_PRINT("\n"); break; case 0x03: if (_is_field(fields, pk_id)) { if (has_key) *has_key = 1; } /* bitstream */ DEBUG_PRINT("BITSTREAM(%i): ", length); DEBUG_DUMP_HEX(&buffer[pos], length); DEBUG_PRINT("\n"); if (_is_field(fields, sign_id) || _is_field(fields, sign_id2)) { DEBUG_PRINT("set sign key\n"); tls_certificate_set_sign_key(cert, &buffer [pos], length); } else if (cert->ec_algorithm && (_is_field(fields, pk_id))) { 
tls_certificate_set_key(cert, &buffer [pos], length); } else { if (buffer[pos] == 0x00 && length > 256) { _private_asn1_parse (context, cert, &buffer[pos] + 1, length - 1, level + 1, fields, &local_has_key, client_cert, top_oid, &local_chain); } else { _private_asn1_parse (context, cert, &buffer[pos], length, level + 1, fields, &local_has_key, client_cert, top_oid, &local_chain); } if (top_oid) { if (_is_oid2 (top_oid, TLS_EC_prime256v1_OID, sizeof(oid), sizeof (TLS_EC_prime256v1) - 1)) { cert-> ec_algorithm = secp256r1.iana; } else if (_is_oid2 (top_oid, TLS_EC_secp224r1_OID, sizeof(oid), sizeof (TLS_EC_secp224r1_OID) - 1)) { cert-> ec_algorithm = secp224r1.iana; } else if (_is_oid2 (top_oid, TLS_EC_secp384r1_OID, sizeof(oid), sizeof (TLS_EC_secp384r1_OID) - 1)) { cert-> ec_algorithm = secp384r1.iana; } else if (_is_oid2 (top_oid, TLS_EC_secp521r1_OID, sizeof(oid), sizeof (TLS_EC_secp521r1_OID) - 1)) { cert-> ec_algorithm = secp521r1.iana; } if ((cert->ec_algorithm) && (!cert->pk)) tls_certificate_set_key (cert, &buffer[pos], length); } } break; case 0x04: if (top_oid && _is_field(fields, ecc_priv_id) && !cert->priv) { DEBUG_PRINT("BINARY STRING(%i): ", length); DEBUG_DUMP_HEX(&buffer[pos], length); DEBUG_PRINT("\n"); tls_certificate_set_priv(cert, &buffer [pos], length); } else { _private_asn1_parse(context, cert, &buffer[pos], length, level + 1, fields, &local_has_key, client_cert, top_oid, &local_chain); } break; case 0x05: DEBUG_PRINT("NULL\n"); break; case 0x06: /* object identifier */ if (_is_field(fields, pk_id)) { if (length == 8 || length == 5) { tls_certificate_set_algorithm (&cert->ec_algorithm, &buffer[pos], length); } else { tls_certificate_set_algorithm (&cert->key_algorithm, &buffer[pos], length); } } if (_is_field(fields, algorithm_id)) tls_certificate_set_algorithm (&cert->algorithm, &buffer[pos], length); DEBUG_PRINT("OBJECT IDENTIFIER(%i): ", length); DEBUG_DUMP_HEX(&buffer[pos], length); DEBUG_PRINT("\n"); /* check previous oid */ if (_is_oid2 
(oid, ocsp_oid, 16, sizeof(ocsp_oid) - 1)) tls_certificate_set_copy(&cert->ocsp, &buffer[pos], length); if (length < 16) { memcpy(oid, &buffer[pos], length); } else { memcpy(oid, &buffer[pos], 16); } if (top_oid) memcpy(top_oid, oid, 16); break; case 0x09: DEBUG_PRINT("REAL NUMBER(%i): ", length); DEBUG_DUMP_HEX(&buffer[pos], length); DEBUG_PRINT("\n"); break; case 0x17: /* utc time */ DEBUG_PRINT("UTC TIME: ["); DEBUG_DUMP(&buffer[pos], length); DEBUG_PRINT("]\n"); if (_is_field(fields, validity_id)) { if (idx == 1) { tls_certificate_set_copy_date (cert->not_before, &buffer[pos], length); } else { tls_certificate_set_copy_date (cert->not_after, &buffer[pos], length); } } break; case 0x18: /* generalized time */ DEBUG_PRINT("GENERALIZED TIME: ["); DEBUG_DUMP(&buffer[pos], length); DEBUG_PRINT("]\n"); break; case 0x13: /* printable string */ case 0x0C: case 0x14: case 0x15: case 0x16: case 0x19: case 0x1A: case 0x1B: case 0x1C: case 0x1D: case 0x1E: if (_is_field(fields, issurer_id)) { if (_is_oid(oid, country_oid, 3)) { tls_certificate_set_copy (&cert->issuer_country, &buffer[pos], length); } else if (_is_oid (oid, state_oid, 3)) { tls_certificate_set_copy (&cert->issuer_state, &buffer[pos], length); } else if (_is_oid (oid, location_oid, 3)) { tls_certificate_set_copy (&cert-> issuer_location, &buffer[pos], length); } else if (_is_oid (oid, entity_oid, 3)) { tls_certificate_set_copy (&cert->issuer_entity, &buffer[pos], length); } else if (_is_oid (oid, subject_oid, 3)) { tls_certificate_set_copy (&cert->issuer_subject, &buffer[pos], length); } } else if (_is_field(fields, owner_id)) { if (_is_oid(oid, country_oid, 3)) { tls_certificate_set_copy (&cert->country, &buffer[pos], length); } else if (_is_oid (oid, state_oid, 3)) { tls_certificate_set_copy (&cert->state, &buffer[pos], length); } else if (_is_oid (oid, location_oid, 3)) { tls_certificate_set_copy (&cert->location, &buffer[pos], length); } else if (_is_oid (oid, entity_oid, 3)) { tls_certificate_set_copy 
(&cert->entity, &buffer[pos], length); } else if (_is_oid (oid, subject_oid, 3)) { tls_certificate_set_copy (&cert->subject, &buffer[pos], length); } } DEBUG_PRINT("STR: ["); DEBUG_DUMP(&buffer[pos], length); DEBUG_PRINT("]\n"); break; case 0x10: DEBUG_PRINT("EMPTY SEQUENCE\n"); break; case 0xA: DEBUG_PRINT("ENUMERATED(%i): ", length); DEBUG_DUMP_HEX(&buffer[pos], length); DEBUG_PRINT("\n"); break; default: DEBUG_PRINT("========> NOT SUPPORTED %x\n", (int) type); /* not supported / needed */ break; } } pos += length; } if (cert_len && cert_data) { int h = find_hash("sha256"); size_t len = sizeof cert->fp; hash_memory(h, cert_data,cert_len, cert->fp, &len); } if (level == 2 && cert->sign_key && cert->sign_len && cert_len && cert_data) { free(cert->fingerprint); cert->fingerprint = tls_compute_hash(cert->algorithm, cert_data, cert_len); #ifdef DEBUG if (cert->fingerprint) { DEBUG_DUMP_HEX_LABEL("FINGERPRINT", cert->fingerprint, tls_hash_len(cert->algorithm)); } #endif } return pos; } struct TLSCertificate *asn1_parse(struct TLSContext *context, const unsigned char *buffer, int size, int client_cert) { unsigned int fields[TLS_ASN1_MAXLEVEL] = { 0 }; struct TLSCertificate *cert = tls_create_certificate(); if (cert) { if (client_cert < 0) { client_cert = 0; /* private key */ unsigned char top_oid[16]; memset(top_oid, 0, sizeof(top_oid)); _private_asn1_parse(context, cert, buffer, size, 1, fields, NULL, client_cert, top_oid, NULL); } else { _private_asn1_parse(context, cert, buffer, size, 1, fields, NULL, client_cert, NULL, NULL); } } return cert; } int tls_clear_certificates(struct TLSContext *tls) { int i; if (!tls || !tls->is_server || tls->is_child) { return TLS_GENERIC_ERROR; } if (tls->root_certificates) { for (i = 0; i < tls->root_count; i++) { tls_destroy_certificate(tls->root_certificates[i]); } } tls->root_certificates = NULL; tls->root_count = 0; if (tls->private_key) { tls_destroy_certificate(tls->private_key); } tls->private_key = NULL; if 
(tls->ec_private_key) { tls_destroy_certificate(tls->ec_private_key); } tls->ec_private_key = NULL; free(tls->certificates); tls->certificates = NULL; tls->certificates_count = 0; return 0; } /* This is just a wrapper around parse message so we don't * call read more often than necessary. IOW, if there's * more than one record in the input buffer, process them all */ int tls_consume_stream(struct TLSContext *context) { if (!context) { return TLS_GENERIC_ERROR; } if (context->critical_error) { return TLS_BROKEN_CONNECTION; } size_t tls_buffer_len = context->input_buffer.len; unsigned char *buffer = context->input_buffer.buffer; unsigned int index = 0; int err_flag = 0; int tls_header_size; int tls_size_offset; tls_size_offset = 3; tls_header_size = 5; while (tls_buffer_len >= 5) { uint16_t length; length = get16(buffer + index + tls_size_offset) + tls_header_size; if (length > tls_buffer_len) { /* record not complete */ break; } /* This is the only place tls_parse_message is called */ int consumed = tls_parse_message(context, buffer+index, length); if (consumed < 0) { fprintf(stderr, "parse message error: %d\n", consumed); err_flag = consumed; break; } index += length; tls_buffer_len -= length; if (context->critical_error) { err_flag = TLS_BROKEN_CONNECTION; break; } } if (err_flag || context->input_buffer.error) { if (!context->critical_error) { context->critical_error = 1; } DEBUG_PRINT("ERROR IN CONSUME: %i\n", err_flag); tls_buffer_free(&context->input_buffer); return err_flag; } tls_buffer_shift(&context->input_buffer, index); return index; } void tls_close_notify(struct TLSContext *context) { if (!context || context->critical_error) { return; } context->critical_error = 1; DEBUG_PRINT("CLOSE\n"); tls_alert(context, 0, close_notify); } void tls_alert(struct TLSContext *context, int critical, int code) { if (!context) { return; } struct TLSPacket *packet = tls_create_packet(context, TLS_ALERT, context->version, 0); tls_packet_uint8(packet, critical ? 
TLS_ALERT_CRITICAL : TLS_ALERT_WARNING); tls_packet_uint8(packet, code); tls_packet_update(packet); if (critical) { context->critical_error = 1; } tls_queue_packet(packet); } int tls_is_broken(struct TLSContext *context) { if (!context || context->critical_error) { return 1; } return 0; } /* TODO I don't see that this ever gets cleared */ int tls_request_client_certificate(struct TLSContext *context) { if (!context || !context->is_server) { return 0; } context->request_client_certificate = 1; return 1; } int tls_client_verified(struct TLSContext *context) { if (!context || context->critical_error) { return 0; } return context->client_verified == 1; } int tls_sni_set(struct TLSContext *context, const char *sni) { if (!context || context->is_server || context->critical_error || context->connection_status != 0) { return 0; } free(context->sni); errno = 0; context->sni = sni ? strdup(sni) : 0; return context->sni ? 1 : 0; } int tls_default_verify(struct TLSContext *context, struct TLSCertificate **certificate_chain, int len) { int i; int err; if (certificate_chain) { for (i = 0; i < len; i++) { struct TLSCertificate *certificate = certificate_chain[i]; /* check validity date */ err = tls_certificate_is_valid(certificate); if (err) { return err; } } } /* check if chain is valid */ err = tls_certificate_chain_is_valid(certificate_chain, len); if (err) { return err; } /* check certificate subject */ if (!context->is_server && context->sni && len > 0 && certificate_chain) { err = tls_certificate_valid_subject(certificate_chain[0], context->sni); if (err) { return err; } } err = tls_certificate_chain_is_valid_root(context, certificate_chain, len); if (err) { return err; } DEBUG_PRINT("Certificate OK\n"); return no_error; } ssize_t tls_fsync(struct TLSContext *context) { size_t buflen = 0; size_t offset = 0; ssize_t send_res = 0; int fd; unsigned char *buffer; tls_send_func write_cb = NULL; if (!context) { return 0; } fd = context->fd; if (fd < 0) { return -1; } buffer = 
/*
 * Map the file `pem_filename` read-only and hand its contents to
 * tls_load_root_certificates().
 *
 * Returns whatever tls_load_root_certificates() returns, or -1 when the
 * context is NULL or the file cannot be opened, stat'ed or mapped.  The
 * mapping and the descriptor are released before returning.
 */
int tls_load_root_file(struct TLSContext *context, const char *pem_filename)
{
	struct stat info;
	int loaded = -1;
	int fd;

	if (!context) {
		return -1;
	}

	fd = open(pem_filename, O_RDONLY);
	if (fd == -1) {
		return -1;
	}

	if (fstat(fd, &info) != -1) {
		void *map = mmap(NULL, info.st_size, PROT_READ, MAP_PRIVATE,
				fd, 0);

		if (map != MAP_FAILED) {
			loaded = tls_load_root_certificates(context, map,
					info.st_size);
			munmap(map, info.st_size);
		}
	}

	/* single exit: the descriptor is closed on every path past open() */
	close(fd);
	return loaded;
}
(tls_consume_stream(context) >= 0) { ssize_t res = tls_fsync(context); if (res < 0) { return res; } } if (tls_established(context)) { return 1; } } if (read_size <= 0) { return TLS_BROKEN_CONNECTION; } return 0; } /* TODO this is really do the handshake */ int tls_connect(struct TLSContext *context) { int res; ssize_t read_size; MARK; if (!context || context->fd < 0 || context->critical_error) { if (!context) { MARK; } else if (context->fd < 0) { MARK; } else { MARK; } return TLS_GENERIC_ERROR; } MARK; if (context->is_server) { return TLS_UNEXPECTED_MESSAGE; } MARK; res = tls_queue_packet(tls_build_client_hello(context)); MARK; if (res < 0) { return res; } MARK; res = tls_fsync(context); MARK; if (res < 0) { return res; } while ((read_size = tls_safe_read(context)) > 0) { if ((res = tls_consume_stream(context)) >= 0) { res = tls_fsync(context); if (res < 0) { return res; } } MARK; if (tls_established(context)) { MARK; return 1; } MARK; if (context->critical_error) { fprintf(stderr, "critical error: %d\n", context->critical_error); return TLS_GENERIC_ERROR; } } MARK; return read_size; } int tls_shutdown(struct TLSContext *tls) { if (!tls || tls->fd <= 0) { return TLS_GENERIC_ERROR; } tls_close_notify(tls); return 0; } /* TODO configure for maximum packet data length * max is 2^14 - 5 byte header - 32 byte mac - padding which depends * on the cipher (up to 255 bytes I think). */ ssize_t tls_write(struct TLSContext *context, const void *buf, size_t count) { if (!context) { return TLS_GENERIC_ERROR; } if (context->connection_status != TLS_CONNECTED) { return TLS_UNEXPECTED_MESSAGE; } if (count > TLS_MAXTLS_APP_SIZE) { count = TLS_MAXTLS_APP_SIZE; } if (!buf || !count) { return 0; } struct TLSPacket *packet = tls_create_packet(context, TLS_APPLICATION_DATA, context->version, count); tls_packet_append(packet, buf, count); tls_packet_update(packet); tls_queue_packet(packet); /* TODO think about this. 
context->sync with O_NONBLOCK might be a * problem */ if (context->sync) { ssize_t res; res = tls_fsync(context); if (res == -1) { return res; } } return count; } static ssize_t tls_readbuf(struct TLSContext *tls, void *buf, size_t count) { if (count > tls->application_buffer.len) { count = tls->application_buffer.len; } if (count > 0) { /* TODO should have a buffer read and shift */ memcpy(buf, tls->application_buffer.buffer, count); tls_buffer_shift(&tls->application_buffer, count); } return count; } ssize_t tls_read(struct TLSContext *context, void *buf, size_t count) { if (!context) { return TLS_GENERIC_ERROR; } if (context->application_buffer.len) { return tls_readbuf(context, buf, count); } if (context->fd <= 0 || context->critical_error) { return TLS_GENERIC_ERROR; } if (!tls_established(context)) { return TLS_GENERIC_ERROR; } if (context->application_buffer.len == 0 && !context->critical_error) { /* attempt to fill buffer, unless we're already in an error * state */ ssize_t read_size; while ((read_size = tls_safe_read(context)) > 0) { if (tls_consume_stream(context) > 0) { tls_fsync(context); break; } if (context->critical_error && !context->application_buffer.len) { /* if there's a critical error, don't bail if * we managed to get some data */ return TLS_GENERIC_ERROR; } } if (read_size <= 0 && context->application_buffer.len == 0) { /* can return errors as for read(2) */ return read_size; } } return tls_readbuf(context, buf, count); }