#define _POSIX_C_SOURCE 200809L #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "buffer.h" #include "tlse.h" #ifndef htonll #define htonll(x) ((1==htonl(1)) ? (x) : ((uint64_t)htonl((x) & 0xFFFFFFFF) << 32) | htonl((x) >> 32)) #endif #ifndef ntohll #define ntohll(x) ((1==ntohl(1)) ? (x) : ((uint64_t)ntohl((x) & 0xFFFFFFFF) << 32) | ntohl((x) >> 32)) #endif #ifdef DEBUG static char *tls_alert_msg_name(int code) { switch (code) { case 0: return "close_notify"; break; case 10: return "unexpected_message"; break; case 20: return "bad_record_mac"; break; default: break; } return "unknown alert"; } #endif static uint16_t get16(const unsigned char *buf) { uint16_t res; res = ((*buf) << 8) + (*(buf+1)); return res; } static int tls_microsleep(unsigned int microseconds) { struct timespec ts; struct timespec rem; int rv; ts.tv_sec = (time_t) (microseconds / 1000000L); ts.tv_nsec = (long) (((long) microseconds % 1000000L) * 1000L); errno = 0; while ((rv = nanosleep(&ts, &rem)) == -1) { if (errno != EINTR) { return rv; } ts = rem; } return rv; } /* TODO this is biased unless limit is a power of 2 */ static unsigned int random_int(int limit) { unsigned int res = 0; tls_random((unsigned char *) &res, sizeof(int)); if (limit) { res %= limit; } return res; } static int random_sleep(long max) { return tls_microsleep(random_int(max)); } static void U32TO8(unsigned char *p, unsigned long v) { p[0] = (v) & 0xff; p[1] = (v >> 8) & 0xff; p[2] = (v >> 16) & 0xff; p[3] = (v >> 24) & 0xff; } static int decrypt_aes_gcm(struct TLSContext *context, uint16_t length, int header_size, const unsigned char *ptr, const unsigned char *buf, int buf_len, struct tls_buffer *pt_buffer) { int delta = 8; int pt_length; unsigned char iv[TLS_13_AES_GCM_IV_LENGTH]; unsigned char aad[16]; int aad_size = sizeof(aad); unsigned char *sequence = aad; gcm_state *remote_gcm; remote_gcm = &context->crypto.ctx_remote.aes_gcm_remote; gcm_reset(remote_gcm); /* set up the aad */ if (context->tlsver == TLS_VERSION13) { aad[0] = TLS_APPLICATION_DATA; aad[1] = 0x03; aad[2] = 0x03; *((unsigned short *) &aad[3]) = htons(buf_len - header_size); aad_size = 5; sequence = aad + 5; *((uint64_t *) sequence) = htonll(context->remote_sequence_number); memcpy(iv, context->crypto.ctx_remote_mac.remote_iv, TLS_13_AES_GCM_IV_LENGTH); int i; int offset = TLS_13_AES_GCM_IV_LENGTH - 8; for (i = 0; i < 8; i++) { iv[offset + i] = context->crypto.ctx_remote_mac.remote_iv[offset + i] ^ sequence[i]; } pt_length = buf_len - header_size - TLS_GCM_TAG_LEN; delta = 0; } else { aad_size = 13; pt_length = length - 8 - TLS_GCM_TAG_LEN; /* build aad and iv */ *((uint64_t *) aad) = htonll(context->remote_sequence_number); aad[8] = buf[0]; aad[9] = buf[1]; aad[10] = buf[2]; memcpy(iv, context->crypto.ctx_remote_mac.remote_aead_iv, 4); memcpy(iv + 4, buf + header_size, 8); *((unsigned short *) &aad[11]) = htons(pt_length); } if (pt_length < 0) { DEBUG_PRINT("Invalid packet length"); return TLS_BROKEN_PACKET; } DEBUG_DUMP_HEX_LABEL("aad", aad, aad_size); DEBUG_DUMP_HEX_LABEL("aad iv", iv, 12); /* I think we do all of this even if they fail to avoid timing * attacks */ int res0 = gcm_add_iv(&context->crypto.ctx_remote.aes_gcm_remote, iv, 12); int res1 = gcm_add_aad(&context->crypto.ctx_remote.aes_gcm_remote, aad, aad_size); DEBUG_PRINT("PT SIZE: %i\n", pt_length); /* TODO we might want to expand this buffer before we call this * function */ tls_buffer_expand(pt_buffer, pt_length); int res2 = gcm_process(&context->crypto.ctx_remote. aes_gcm_remote, pt_buffer->buffer, pt_length, (char *)buf + header_size + delta, GCM_DECRYPT); pt_buffer->len = pt_length; unsigned char tag[32]; unsigned long taglen = 32; int res3 = gcm_done(&context->crypto.ctx_remote.aes_gcm_remote, tag, &taglen); if (res0 || res1 || res2 || res3 || taglen != TLS_GCM_TAG_LEN) { DEBUG_PRINT ("ERROR: gcm_add_iv: %i, gcm_add_aad: %i, gcm_process: %i, gcm_done: %i\n", res0, res1, res2, res3); return TLS_BROKEN_PACKET; } DEBUG_DUMP_HEX_LABEL("decrypted1", pt_buffer->buffer, pt_length); DEBUG_DUMP_HEX_LABEL("tag", tag, taglen); /* check tag */ if (memcmp(buf + header_size + delta + pt_length, tag, taglen)) { DEBUG_PRINT("INTEGRITY CHECK FAILED (msg length %i)\n", pt_length); DEBUG_DUMP_HEX_LABEL("TAG RECEIVED", buf + header_size + delta + pt_length, taglen); DEBUG_DUMP_HEX_LABEL("TAG COMPUTED", tag, taglen); tls_alert(context, 1, bad_record_mac); return TLS_INTEGRITY_FAILED; } ptr = pt_buffer->buffer; length = pt_length; return 0; } static int decrypt_chacha_poly1305(struct TLSContext *context, uint16_t length, int header_size, const unsigned char *ptr, const unsigned char *buf, int buf_len, struct tls_buffer *pt_buffer) { unsigned char aad[16]; int aad_size = sizeof(aad); unsigned char *sequence = aad; int pt_length = length - POLY1305_TAGLEN; unsigned int counter = 1; unsigned char poly1305_key[POLY1305_KEYLEN]; unsigned char trail[16]; unsigned char mac_tag[POLY1305_TAGLEN]; aad_size = 16; if (pt_length < 0) { DEBUG_PRINT("Invalid packet length"); return TLS_BROKEN_PACKET; } /* set up add */ if (context->tlsver == TLS_VERSION13) { aad[0] = TLS_APPLICATION_DATA; aad[1] = 0x03; aad[2] = 0x03; *((unsigned short *) &aad[3]) = htons(buf_len - header_size); aad_size = 5; sequence = aad + 5; *((uint64_t *) sequence) = htonll(context->remote_sequence_number); } else { *((uint64_t *) aad) = htonll(context->remote_sequence_number); aad[8] = buf[0]; aad[9] = buf[1]; aad[10] = buf[2]; *((unsigned short *) &aad[11]) = htons(pt_length); aad[13] = 0; aad[14] = 0; aad[15] = 0; } tls_buffer_expand(pt_buffer, pt_length); chacha_ivupdate(&context->crypto.ctx_remote.chacha_remote, context->crypto.ctx_remote_mac.remote_aead_iv, sequence, (unsigned char *) &counter); chacha_encrypt_bytes(&context->crypto.ctx_remote.chacha_remote, buf + header_size, pt_buffer->buffer, pt_length); DEBUG_DUMP_HEX_LABEL("decrypted2", pt_buffer->buffer, pt_length); ptr = pt_buffer->buffer; length = pt_length; chacha20_poly1305_key(&context->crypto.ctx_remote.chacha_remote, poly1305_key); struct poly1305_context ctx; tls_poly1305_init(&ctx, poly1305_key); tls_poly1305_update(&ctx, aad, aad_size); static unsigned char zeropad[] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; int rem = aad_size % 16; if (rem) { tls_poly1305_update(&ctx, zeropad, 16 - rem); } tls_poly1305_update(&ctx, buf + header_size, pt_length); rem = pt_length % 16; if (rem) { tls_poly1305_update(&ctx, zeropad, 16 - rem); } U32TO8(&trail[0], aad_size == 5 ? 5 : 13); *(int *) &trail[4] = 0; U32TO8(&trail[8], pt_length); *(int *) &trail[12] = 0; tls_poly1305_update(&ctx, trail, 16); tls_poly1305_finish(&ctx, mac_tag); if (memcmp(mac_tag, buf + header_size + pt_length, POLY1305_TAGLEN)) { DEBUG_PRINT ("INTEGRITY CHECK FAILED (msg length %i)\n", length); DEBUG_DUMP_HEX_LABEL("POLY1305 TAG RECEIVED", buf + header_size + pt_length, POLY1305_TAGLEN); DEBUG_DUMP_HEX_LABEL("POLY1305 TAG COMPUTED", mac_tag, POLY1305_TAGLEN); tls_alert(context, 1, bad_record_mac); return TLS_INTEGRITY_FAILED; } pt_buffer->len = pt_length; return 0; } static int decrypt_other(struct TLSContext *context, uint16_t length, int header_size, const unsigned char *ptr, const unsigned char *buf, struct tls_buffer *pt_buffer) { int err; ENTER; tls_buffer_expand(pt_buffer, length); DEBUG_PRINTLN("cbc_decrypt(%p, %p, %lu, %p)\n", buf+header_size, pt_buffer->buffer, (unsigned long)length, &context->crypto.ctx_remote.aes_remote); err = cbc_decrypt(buf+header_size, pt_buffer->buffer, length, &context->crypto.ctx_remote.aes_remote); if (err) { DEBUG_PRINTLN("Decryption error %d %s\n", err, error_to_string(err)); LEAVE; return TLS_BROKEN_PACKET; } pt_buffer->len = length; unsigned char padding_byte = pt_buffer->buffer[length - 1]; unsigned char padding = padding_byte + 1; DEBUG_PRINTLN("cbc padding byte = %d\n", (int)padding_byte); /* poodle check */ int padding_index = length - padding; if (padding_index > 0) { int i; int limit = length - 1; for (i = length - padding; i < limit; i++) { if (pt_buffer->buffer[i] != padding_byte) { DEBUG_PRINTLN("BROKEN PACKET (POODLE ?)\n"); tls_alert(context, 1, decrypt_error); LEAVE; return TLS_BROKEN_PACKET; } } } unsigned int decrypted_length = length; if (padding < decrypted_length) { decrypted_length -= padding; pt_buffer->len -= padding; } DEBUG_DUMP_HEX_LABEL("decrypted", pt_buffer->buffer, decrypted_length); if (decrypted_length > TLS_AES_IV_LENGTH) { decrypted_length -= TLS_AES_IV_LENGTH; tls_buffer_shift(pt_buffer, TLS_AES_IV_LENGTH); } ptr = pt_buffer->buffer; length = decrypted_length; unsigned int mac_size = tls_mac_length(context); if (length < mac_size || !mac_size) { DEBUG_PRINTLN("BROKEN PACKET\n"); tls_alert(context, 1, decrypt_error); LEAVE; return TLS_BROKEN_PACKET; } DEBUG_PRINTLN("mac size %u\n", mac_size); length -= mac_size; pt_buffer->len -= mac_size; const unsigned char *message_hmac = &ptr[length]; unsigned char hmac_out[TLS_MAX_MAC_SIZE]; unsigned char temp_buf[5]; memcpy(temp_buf, buf, 3); *(unsigned short *) &temp_buf[3] = htons(length); unsigned int hmac_out_len = tls_hmac_message(0, context, temp_buf, 5, ptr, length, hmac_out, mac_size); if (hmac_out_len != mac_size || memcmp(message_hmac, hmac_out, mac_size)) { DEBUG_PRINTLN("INTEGRITY CHECK FAILED (msg length %i)\n", length); DEBUG_DUMP_HEX_LABEL("HMAC RECEIVED", message_hmac, mac_size); DEBUG_DUMP_HEX_LABEL("HMAC COMPUTED", hmac_out, hmac_out_len); tls_alert(context, 1, bad_record_mac); LEAVE; return TLS_INTEGRITY_FAILED; } LEAVE; return 0; } static int decrypt_payload(struct TLSContext *context, uint16_t length, int header_size, const unsigned char *ptr, const unsigned char *buf, int buf_len, struct tls_buffer *pt_buffer ) { if (context->crypto.created == 2) { return decrypt_aes_gcm(context, length, header_size, ptr, buf, buf_len, pt_buffer); } else if (context->crypto.created == 3) { return decrypt_chacha_poly1305(context, length, header_size, ptr, buf, buf_len, pt_buffer); } else if (context->crypto.created == 1) { return decrypt_other(context, length, header_size, ptr, buf, pt_buffer); } else { tls_buffer_free(pt_buffer); return TLS_BROKEN_PACKET; } return 0; } int tls_parse_message(struct TLSContext *context, unsigned char *buf, int buf_len) { uint8_t type = *buf; uint16_t version; /* a struct of two uint8 per the rfc, but we encode it as a uint16_t */ uint16_t length; int buf_pos = 0; int res = 5; int payload_res = 0; const unsigned char *ptr = 0; /* TODO probably make this buffer part of the context, * and just zero it out instead of malloc and free */ struct tls_buffer pt; int header_size = res; ENTER; if (buf_len < res) { LEAVE; return TLS_NEED_MORE_DATA; } type = *buf; buf_pos += 1; version = get16(&buf[buf_pos]); buf_pos += 2; if (!tls_supported_version(version)) { LEAVE; return TLS_NOT_SAFE; } length = get16(&buf[buf_pos]); buf_pos += 2; ptr = buf + buf_pos; DEBUG_PRINTLN("Message type: %d, length: %d\n", (int)type, (int)length); /* this buffer can go out of scope */ tls_buffer_init(&pt, 0); if (context->cipher_spec_set && type != TLS_CHANGE_CIPHER) { /* Need to decrypt payload */ DEBUG_DUMP_HEX_LABEL("encrypted", &buf[header_size], length); if (!context->crypto.created) { DEBUG_PRINT("Encryption context not created\n"); random_sleep(TLS_MAX_ERROR_SLEEP_uS); LEAVE; return TLS_BROKEN_PACKET; } int rv = decrypt_payload(context, length, header_size, ptr, buf, buf_len, &pt); if (rv != 0) { tls_buffer_free(&pt); random_sleep(TLS_MAX_ERROR_SLEEP_uS); LEAVE; return rv; } if (pt.error) { tls_buffer_free(&pt); random_sleep(TLS_MAX_ERROR_SLEEP_uS); LEAVE; return TLS_NO_MEMORY; } ptr = pt.buffer; length = pt.len; } context->remote_sequence_number++; if (context->tlsver == TLS_VERSION13) { /*(context->connection_status == 2) && */ if (type == TLS_APPLICATION_DATA && context->crypto.created) { do { length--; type = ptr[length]; } while (!type); } } /* TODO for v1.3 encrypted handshake messages will show up * as application data, so we need to re-compute the record * type, that may be what the above is doing */ switch (type) { /* application data */ case TLS_APPLICATION_DATA: if (context->connection_status != TLS_CONNECTED) { DEBUG_PRINT ("UNEXPECTED APPLICATION DATA MESSAGE\n"); payload_res = TLS_UNEXPECTED_MESSAGE; tls_alert(context, 1, unexpected_message); break; } DEBUG_PRINT ("APPLICATION DATA MESSAGE (TLS VERSION: %x):\n", (int) context->version); DEBUG_DUMP(ptr, length); DEBUG_PRINT("\n"); tls_buffer_append(&context->application_buffer, ptr, length); if (context->application_buffer.error) { payload_res = TLS_NO_MEMORY; } break; /* handshake */ case TLS_HANDSHAKE: DEBUG_PRINT("HANDSHAKE MESSAGE\n"); payload_res = tls_parse_payload(context, ptr, length); break; /* change cipher spec */ case TLS_CHANGE_CIPHER: if (context->connection_status != 2) { if (context->connection_status == 4) { DEBUG_PRINT ("IGNORING CHANGE CIPHER SPEC MESSAGE (HELLO RETRY REQUEST)\n"); break; } DEBUG_PRINT ("UNEXPECTED CHANGE CIPHER SPEC MESSAGE (%i)\n", context->connection_status); tls_alert(context, 1, unexpected_message); payload_res = TLS_UNEXPECTED_MESSAGE; } else { DEBUG_PRINT("CHANGE CIPHER SPEC MESSAGE\n"); context->cipher_spec_set = 1; /* reset sequence numbers */ context->remote_sequence_number = 0; } break; /* alert */ case TLS_ALERT: DEBUG_PRINTLN("ALERT MESSAGE\n"); if (length >= 2) { int level = ptr[0]; int code = ptr[1]; DEBUG_PRINTLN("level = %d, code = %d\n", level, code); if (level == TLS_ALERT_CRITICAL) { DEBUG_PRINTLN("critical error: %s\n", tls_alert_msg_name(code)); context->critical_error = 1; res = TLS_ERROR_ALERT; } context->error_code = code; } else { DEBUG_PRINT("ALERT MESSAGE short\n"); } break; default: DEBUG_PRINT("UNKNOWN MESSAGE TYPE: %x\n", (int)type); payload_res = TLS_NOT_UNDERSTOOD; break; } tls_buffer_free(&pt); if (payload_res < 0) { LEAVE; return payload_res; } if (res > 0) { LEAVE; return header_size + length; } LEAVE; return res; }