--- /dev/null
+#define _POSIX_C_SOURCE 200809L
+
+#include <fcntl.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <strings.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <time.h>
+
+#include <sys/socket.h>
+#include <arpa/inet.h>
+#include <unistd.h>
+
+#include <errno.h>
+
+#include <tomcrypt.h>
+
+#include "buffer.h"
+#include "tlse.h"
+
+#ifndef htonll
+#define htonll(x) ((1==htonl(1)) ? (x) : ((uint64_t)htonl((x) & 0xFFFFFFFF) << 32) | htonl((x) >> 32))
+#endif
+
+#ifndef ntohll
+#define ntohll(x) ((1==ntohl(1)) ? (x) : ((uint64_t)ntohl((x) & 0xFFFFFFFF) << 32) | ntohl((x) >> 32))
+#endif
+
+static uint16_t get16(const unsigned char *buf) {
+ uint16_t res;
+
+ res = ((*buf) << 8) + (*(buf+1));
+ return res;
+}
+
+static int tls_microsleep(unsigned int microseconds) {
+ struct timespec ts;
+ struct timespec rem;
+ int rv;
+
+ ts.tv_sec = (time_t) (microseconds / 1000000L);
+ ts.tv_nsec = (long) (((long) microseconds % 1000000L) * 1000L);
+
+ errno = 0;
+ while ((rv = nanosleep(&ts, &rem)) == -1) {
+ if (errno != EINTR) {
+ return rv;
+ }
+ ts = rem;
+ }
+ return rv;
+}
+
+/* TODO this is biased unless limit is a power of 2 */
+static unsigned int random_int(int limit) {
+ unsigned int res = 0;
+ tls_random((unsigned char *) &res, sizeof(int));
+ if (limit) {
+ res %= limit;
+ }
+ return res;
+}
+
+static int random_sleep(long max) {
+ return tls_microsleep(random_int(max));
+}
+
+static void U32TO8(unsigned char *p, unsigned long v) {
+ p[0] = (v) & 0xff;
+ p[1] = (v >> 8) & 0xff;
+ p[2] = (v >> 16) & 0xff;
+ p[3] = (v >> 24) & 0xff;
+}
+
+static int decrypt_aes_gcm(struct TLSContext *context, uint16_t length, int
+ header_size, const unsigned char *ptr, const unsigned char
+ *buf, int buf_len, struct tls_buffer *pt_buffer) {
+ int delta = 8;
+ int pt_length;
+ unsigned char iv[TLS_13_AES_GCM_IV_LENGTH];
+ unsigned char aad[16];
+ int aad_size = sizeof(aad);
+ unsigned char *sequence = aad;
+
+ gcm_state *remote_gcm;
+ remote_gcm = &context->crypto.ctx_remote.aes_gcm_remote;
+
+ gcm_reset(remote_gcm);
+
+ /* set up the aad */
+ if (context->tlsver == TLS_VERSION13) {
+ aad[0] = TLS_APPLICATION_DATA;
+ aad[1] = 0x03;
+ aad[2] = 0x03;
+ *((unsigned short *) &aad[3]) =
+ htons(buf_len - header_size);
+ aad_size = 5;
+ sequence = aad + 5;
+ *((uint64_t *) sequence) =
+ htonll(context->remote_sequence_number);
+ memcpy(iv, context->crypto.ctx_remote_mac.remote_iv,
+ TLS_13_AES_GCM_IV_LENGTH);
+ int i;
+ int offset = TLS_13_AES_GCM_IV_LENGTH - 8;
+ for (i = 0; i < 8; i++) {
+ iv[offset + i] =
+ context->crypto.ctx_remote_mac.remote_iv[offset +
+ i] ^ sequence[i];
+ }
+ pt_length = buf_len - header_size - TLS_GCM_TAG_LEN;
+ delta = 0;
+ } else {
+ aad_size = 13;
+ pt_length = length - 8 - TLS_GCM_TAG_LEN;
+
+ /* build aad and iv */
+ *((uint64_t *) aad) = htonll(context->remote_sequence_number);
+ aad[8] = buf[0];
+ aad[9] = buf[1];
+ aad[10] = buf[2];
+
+ memcpy(iv, context->crypto.ctx_remote_mac.remote_aead_iv, 4);
+ memcpy(iv + 4, buf + header_size, 8);
+ *((unsigned short *) &aad[11]) = htons(pt_length);
+ }
+
+ if (pt_length < 0) {
+ DEBUG_PRINT("Invalid packet length");
+ return TLS_BROKEN_PACKET;
+ }
+ DEBUG_DUMP_HEX_LABEL("aad", aad, aad_size);
+ DEBUG_DUMP_HEX_LABEL("aad iv", iv, 12);
+
+ /* I think we do all of this even if they fail to avoid timing
+ * attacks
+ */
+ int res0 = gcm_add_iv(&context->crypto.ctx_remote.aes_gcm_remote, iv, 12);
+ int res1 = gcm_add_aad(&context->crypto.ctx_remote.aes_gcm_remote, aad, aad_size);
+
+ DEBUG_PRINT("PT SIZE: %i\n", pt_length);
+
+ /* TODO we might want to expand this buffer before we call this
+ * function */
+ tls_buffer_expand(pt_buffer, pt_length);
+
+ int res2 = gcm_process(&context->crypto.ctx_remote.
+ aes_gcm_remote, pt_buffer->buffer,
+ pt_length,
+ (char *)buf + header_size + delta,
+ GCM_DECRYPT);
+ pt_buffer->len = pt_length;
+
+ unsigned char tag[32];
+ unsigned long taglen = 32;
+ int res3 = gcm_done(&context->crypto.ctx_remote.aes_gcm_remote,
+ tag, &taglen);
+
+ if (res0 || res1 || res2 || res3 || taglen != TLS_GCM_TAG_LEN) {
+ DEBUG_PRINT
+ ("ERROR: gcm_add_iv: %i, gcm_add_aad: %i, gcm_process: %i, gcm_done: %i\n",
+ res0, res1, res2, res3);
+ return TLS_BROKEN_PACKET;
+ }
+
+ DEBUG_DUMP_HEX_LABEL("decrypted1", pt_buffer->buffer, pt_length);
+ DEBUG_DUMP_HEX_LABEL("tag", tag, taglen);
+
+ /* check tag */
+ if (memcmp(buf + header_size + delta + pt_length, tag, taglen)) {
+ DEBUG_PRINT("INTEGRITY CHECK FAILED (msg length %i)\n",
+ pt_length);
+ DEBUG_DUMP_HEX_LABEL("TAG RECEIVED",
+ buf + header_size + delta + pt_length,
+ taglen);
+ DEBUG_DUMP_HEX_LABEL("TAG COMPUTED", tag, taglen);
+ tls_alert(context, 1, bad_record_mac);
+ return TLS_INTEGRITY_FAILED;
+ }
+ ptr = pt_buffer->buffer;
+ length = pt_length;
+
+ return 0;
+}
+
+static int decrypt_chacha_poly1305(struct TLSContext *context, uint16_t length,
+ int header_size, const unsigned char *ptr, const unsigned char
+ *buf, int buf_len, struct tls_buffer *pt_buffer) {
+
+ unsigned char aad[16];
+ int aad_size = sizeof(aad);
+ unsigned char *sequence = aad;
+
+ int pt_length = length - POLY1305_TAGLEN;
+ unsigned int counter = 1;
+ unsigned char poly1305_key[POLY1305_KEYLEN];
+ unsigned char trail[16];
+ unsigned char mac_tag[POLY1305_TAGLEN];
+
+ aad_size = 16;
+ if (pt_length < 0) {
+ DEBUG_PRINT("Invalid packet length");
+ return TLS_BROKEN_PACKET;
+ }
+
+ /* set up add */
+ if (context->tlsver == TLS_VERSION13) {
+ aad[0] = TLS_APPLICATION_DATA;
+ aad[1] = 0x03;
+ aad[2] = 0x03;
+ *((unsigned short *) &aad[3]) =
+ htons(buf_len - header_size);
+ aad_size = 5;
+ sequence = aad + 5;
+ *((uint64_t *) sequence) =
+ htonll(context->remote_sequence_number);
+ } else {
+ *((uint64_t *) aad) =
+ htonll(context->remote_sequence_number);
+ aad[8] = buf[0];
+ aad[9] = buf[1];
+ aad[10] = buf[2];
+ *((unsigned short *) &aad[11]) = htons(pt_length);
+ aad[13] = 0;
+ aad[14] = 0;
+ aad[15] = 0;
+ }
+
+ tls_buffer_expand(pt_buffer, pt_length);
+ chacha_ivupdate(&context->crypto.ctx_remote.chacha_remote,
+ context->crypto.ctx_remote_mac.remote_aead_iv,
+ sequence, (unsigned char *) &counter);
+
+ chacha_encrypt_bytes(&context->crypto.ctx_remote.chacha_remote,
+ buf + header_size, pt_buffer->buffer, pt_length);
+
+ DEBUG_DUMP_HEX_LABEL("decrypted2", pt_buffer->buffer, pt_length);
+ ptr = pt_buffer->buffer;
+ length = pt_length;
+
+ chacha20_poly1305_key(&context->crypto.ctx_remote.chacha_remote,
+ poly1305_key);
+
+ struct poly1305_context ctx;
+ tls_poly1305_init(&ctx, poly1305_key);
+ tls_poly1305_update(&ctx, aad, aad_size);
+
+ static unsigned char zeropad[] = { 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0 };
+
+ int rem = aad_size % 16;
+ if (rem) {
+ tls_poly1305_update(&ctx, zeropad, 16 - rem);
+ }
+ tls_poly1305_update(&ctx, buf + header_size, pt_length);
+
+ rem = pt_length % 16;
+ if (rem) {
+ tls_poly1305_update(&ctx, zeropad, 16 - rem);
+ }
+
+ U32TO8(&trail[0], aad_size == 5 ? 5 : 13);
+ *(int *) &trail[4] = 0;
+ U32TO8(&trail[8], pt_length);
+ *(int *) &trail[12] = 0;
+
+ tls_poly1305_update(&ctx, trail, 16);
+ tls_poly1305_finish(&ctx, mac_tag);
+ if (memcmp(mac_tag, buf + header_size + pt_length,
+ POLY1305_TAGLEN)) {
+ DEBUG_PRINT
+ ("INTEGRITY CHECK FAILED (msg length %i)\n",
+ length);
+ DEBUG_DUMP_HEX_LABEL("POLY1305 TAG RECEIVED",
+ buf + header_size + pt_length,
+ POLY1305_TAGLEN);
+ DEBUG_DUMP_HEX_LABEL("POLY1305 TAG COMPUTED",
+ mac_tag, POLY1305_TAGLEN);
+
+ tls_alert(context, 1, bad_record_mac);
+ return TLS_INTEGRITY_FAILED;
+ }
+
+ pt_buffer->len = pt_length;
+
+ return 0;
+}
+
+static int decrypt_other(struct TLSContext *context, uint16_t length, int
+ header_size, const unsigned char *ptr, const unsigned char
+ *buf, struct tls_buffer *pt_buffer) {
+ int err;
+
+ err = cbc_decrypt(buf+header_size, pt_buffer->buffer, length,
+ &context->crypto.ctx_remote.aes_remote);
+ if (err) {
+ DEBUG_PRINT("Decryption error %i\n", (int) err);
+ return TLS_BROKEN_PACKET;
+ }
+ unsigned char padding_byte = pt_buffer->buffer[length - 1];
+ unsigned char padding = padding_byte + 1;
+
+ /* poodle check */
+ int padding_index = length - padding;
+ if (padding_index > 0) {
+ int i;
+ int limit = length - 1;
+ for (i = length - padding; i < limit; i++) {
+ if (pt_buffer->buffer[i] != padding_byte) {
+ DEBUG_PRINT
+ ("BROKEN PACKET (POODLE ?)\n");
+ tls_alert(context, 1, decrypt_error);
+ return TLS_BROKEN_PACKET;
+ }
+ }
+ }
+
+ unsigned int decrypted_length = length;
+ if (padding < decrypted_length) {
+ decrypted_length -= padding;
+ pt_buffer->len -= padding;
+ }
+
+ DEBUG_DUMP_HEX_LABEL("decrypted3", pt_buffer->buffer, decrypted_length);
+ ptr = pt_buffer->buffer;
+
+ if (decrypted_length > TLS_AES_IV_LENGTH) {
+ decrypted_length -= TLS_AES_IV_LENGTH;
+ ptr += TLS_AES_IV_LENGTH;
+ tls_buffer_shift(pt_buffer, TLS_AES_IV_LENGTH);
+ }
+ length = decrypted_length;
+
+ unsigned int mac_size = tls_mac_length(context);
+ if (length < mac_size || !mac_size) {
+ DEBUG_PRINT("BROKEN PACKET\n");
+ tls_alert(context, 1, decrypt_error);
+ return TLS_BROKEN_PACKET;
+ }
+
+ length -= mac_size;
+ pt_buffer->len -= mac_size;
+
+ const unsigned char *message_hmac = &ptr[length];
+ unsigned char hmac_out[TLS_MAX_MAC_SIZE];
+ unsigned char temp_buf[5];
+ memcpy(temp_buf, buf, 3);
+ *(unsigned short *) &temp_buf[3] = htons(length);
+ unsigned int hmac_out_len = tls_hmac_message(0, context, temp_buf, 5,
+ ptr, length, hmac_out, mac_size);
+ if (hmac_out_len != mac_size
+ || memcmp(message_hmac, hmac_out, mac_size)) {
+ DEBUG_PRINT("INTEGRITY CHECK FAILED (msg length %i)\n",
+ length);
+ DEBUG_DUMP_HEX_LABEL("HMAC RECEIVED", message_hmac, mac_size);
+ DEBUG_DUMP_HEX_LABEL("HMAC COMPUTED", hmac_out, hmac_out_len);
+
+ tls_alert(context, 1, bad_record_mac);
+
+ return TLS_INTEGRITY_FAILED;
+ }
+
+ return 0;
+}
+
+static int decrypt_payload(struct TLSContext *context,
+ uint16_t length,
+ int header_size,
+ const unsigned char *ptr,
+ const unsigned char *buf,
+ int buf_len,
+ struct tls_buffer *pt_buffer
+ ) {
+
+ if (context->crypto.created == 2) {
+ return decrypt_aes_gcm(context, length, header_size, ptr, buf,
+ buf_len, pt_buffer);
+ } else if (context->crypto.created == 3) {
+ return decrypt_chacha_poly1305(context, length, header_size,
+ ptr, buf, buf_len, pt_buffer);
+ } else if (context->crypto.created == 1) {
+ return decrypt_other(context, length, header_size, ptr, buf,
+ pt_buffer);
+ } else {
+ tls_buffer_free(pt_buffer);
+ return TLS_BROKEN_PACKET;
+ }
+
+ return 0;
+}
+
+int tls_parse_message(struct TLSContext *context, unsigned char *buf,
+ int buf_len) {
+ uint8_t type = *buf;
+ uint16_t version; /* a struct of two uint8 per the rfc, but
+ we encode it as a uint16_t */
+ uint16_t length;
+ int buf_pos = 0;
+ int res = 5;
+ int payload_res = 0;
+ const unsigned char *ptr = 0;
+ /* TODO probably make this buffer part of the context,
+ * and just zero it out instead of malloc and free
+ */
+ struct tls_buffer pt;
+
+ int header_size = res;
+
+ if (buf_len < res) {
+ return TLS_NEED_MORE_DATA;
+ }
+
+ type = *buf;
+ buf_pos += 1;
+ version = get16(&buf[buf_pos]);
+ buf_pos += 2;
+
+ if (!tls_supported_version(version)) {
+ return TLS_NOT_SAFE;
+ }
+
+ length = get16(&buf[buf_pos]);
+ buf_pos += 2;
+
+ ptr = buf + buf_pos;
+
+ DEBUG_PRINT("Message type: %0x, length: %i\n", (int)type, (int)length);
+
+ /* this buffer can go out of scope */
+ tls_buffer_init(&pt, 0);
+
+ if (context->cipher_spec_set && type != TLS_CHANGE_CIPHER) {
+ /* Need to decrypt payload */
+ DEBUG_DUMP_HEX_LABEL("encrypted", &buf[header_size], length);
+
+ if (!context->crypto.created) {
+ DEBUG_PRINT("Encryption context not created\n");
+ random_sleep(TLS_MAX_ERROR_SLEEP_uS);
+ return TLS_BROKEN_PACKET;
+ }
+
+ int rv = decrypt_payload(context, length, header_size, ptr,
+ buf, buf_len, &pt);
+
+ if (rv != 0) {
+ tls_buffer_free(&pt);
+ random_sleep(TLS_MAX_ERROR_SLEEP_uS);
+ return rv;
+ }
+
+ if (pt.error) {
+ tls_buffer_free(&pt);
+ random_sleep(TLS_MAX_ERROR_SLEEP_uS);
+ return TLS_NO_MEMORY;
+ }
+
+ ptr = pt.buffer;
+ length = pt.len;
+ }
+
+ context->remote_sequence_number++;
+
+ if (context->tlsver == TLS_VERSION13) {
+ /*(context->connection_status == 2) && */
+ if (type == TLS_APPLICATION_DATA && context->crypto.created) {
+ do {
+ length--;
+ type = ptr[length];
+ } while (!type);
+ }
+ }
+
+ /* TODO for v1.3 encrypted handshake messages will show up
+ * as application data, so we need to re-compute the record
+ * type, that may be what the above is doing
+ */
+
+ switch (type) {
+ /* application data */
+ case TLS_APPLICATION_DATA:
+ if (context->connection_status != TLS_CONNECTED) {
+ DEBUG_PRINT
+ ("UNEXPECTED APPLICATION DATA MESSAGE\n");
+ payload_res = TLS_UNEXPECTED_MESSAGE;
+ tls_alert(context, 1, unexpected_message);
+ break;
+ }
+ DEBUG_PRINT
+ ("APPLICATION DATA MESSAGE (TLS VERSION: %x):\n",
+ (int) context->version);
+ DEBUG_DUMP(ptr, length);
+ DEBUG_PRINT("\n");
+ tls_buffer_append(&context->application_buffer, ptr, length);
+ if (context->application_buffer.error) {
+ payload_res = TLS_NO_MEMORY;
+ }
+ break;
+ /* handshake */
+ case TLS_HANDSHAKE:
+ DEBUG_PRINT("HANDSHAKE MESSAGE\n");
+ payload_res = tls_parse_payload(context, ptr, length);
+ break;
+ /* change cipher spec */
+ case TLS_CHANGE_CIPHER:
+ if (context->connection_status != 2) {
+ if (context->connection_status == 4) {
+ DEBUG_PRINT
+ ("IGNORING CHANGE CIPHER SPEC MESSAGE (HELLO RETRY REQUEST)\n");
+ break;
+ }
+ DEBUG_PRINT
+ ("UNEXPECTED CHANGE CIPHER SPEC MESSAGE (%i)\n",
+ context->connection_status);
+ tls_alert(context, 1, unexpected_message);
+ payload_res = TLS_UNEXPECTED_MESSAGE;
+ } else {
+ DEBUG_PRINT("CHANGE CIPHER SPEC MESSAGE\n");
+ context->cipher_spec_set = 1;
+ /* reset sequence numbers */
+ context->remote_sequence_number = 0;
+ }
+ break;
+ /* alert */
+ case TLS_ALERT:
+ DEBUG_PRINT("ALERT MESSAGE\n");
+ if (length >= 2) {
+ DEBUG_PRINT("ALERT MESSAGE ...\n");
+ DEBUG_DUMP_HEX(ptr, length);
+ int level = ptr[0];
+ int code = ptr[1];
+ DEBUG_PRINT("level = %d, code = %d\n",
+ level, code);
+ if (level == TLS_ALERT_CRITICAL) {
+ context->critical_error = 1;
+ res = TLS_ERROR_ALERT;
+ }
+ context->error_code = code;
+ } else {
+ DEBUG_PRINT("ALERT MESSAGE short\n");
+ }
+
+ break;
+ default:
+ DEBUG_PRINT("UNKNOWN MESSAGE TYPE: %x\n", (int)type);
+ payload_res = TLS_NOT_UNDERSTOOD;
+ break;
+ }
+
+ tls_buffer_free(&pt);
+
+ if (payload_res < 0) {
+ return payload_res;
+ }
+
+ if (res > 0) {
+ return header_size + length;
+ }
+
+ return res;
+}