]> pd.if.org Git - zpackage/blobdiff - crypto/parse_message.c
commit files needed for zpm-fetchurl
[zpackage] / crypto / parse_message.c
diff --git a/crypto/parse_message.c b/crypto/parse_message.c
new file mode 100644 (file)
index 0000000..e4e7b15
--- /dev/null
@@ -0,0 +1,562 @@
+#define _POSIX_C_SOURCE 200809L
+
+#include <fcntl.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <strings.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <time.h>
+
+#include <sys/socket.h>
+#include <arpa/inet.h>
+#include <unistd.h>
+
+#include <errno.h>
+
+#include <tomcrypt.h>
+
+#include "buffer.h"
+#include "tlse.h"
+
+#ifndef htonll
+#define htonll(x) ((1==htonl(1)) ? (x) : ((uint64_t)htonl((x) & 0xFFFFFFFF) << 32) | htonl((x) >> 32))
+#endif
+
+#ifndef ntohll
+#define ntohll(x) ((1==ntohl(1)) ? (x) : ((uint64_t)ntohl((x) & 0xFFFFFFFF) << 32) | ntohl((x) >> 32))
+#endif
+
+static uint16_t get16(const unsigned char *buf) {
+       uint16_t res;
+
+       res = ((*buf) << 8) + (*(buf+1));
+       return res;
+}
+
+static int tls_microsleep(unsigned int microseconds) {
+       struct timespec ts;
+       struct timespec rem;
+       int rv;
+
+       ts.tv_sec = (time_t) (microseconds / 1000000L);
+       ts.tv_nsec = (long) (((long) microseconds % 1000000L) * 1000L);
+
+       errno = 0;
+       while ((rv = nanosleep(&ts, &rem)) == -1) {
+               if (errno != EINTR) {
+                       return rv;
+               }
+               ts = rem;
+       }
+       return rv;
+}
+
+/* TODO this is biased unless limit is a power of 2 */
+static unsigned int random_int(int limit) {
+       unsigned int res = 0;
+       tls_random((unsigned char *) &res, sizeof(int));
+       if (limit) {
+               res %= limit;
+       }
+       return res;
+}
+
+static int random_sleep(long max) {
+       return tls_microsleep(random_int(max));
+}
+
+static void U32TO8(unsigned char *p, unsigned long v) {
+       p[0] = (v) & 0xff;
+       p[1] = (v >> 8) & 0xff;
+       p[2] = (v >> 16) & 0xff;
+       p[3] = (v >> 24) & 0xff;
+}
+
+static int decrypt_aes_gcm(struct TLSContext *context, uint16_t length, int
+               header_size, const unsigned char *ptr, const unsigned char
+               *buf, int buf_len, struct tls_buffer *pt_buffer) {
+       int delta = 8;
+       int pt_length;
+       unsigned char iv[TLS_13_AES_GCM_IV_LENGTH];
+       unsigned char aad[16];
+       int aad_size = sizeof(aad);
+       unsigned char *sequence = aad;
+
+       gcm_state *remote_gcm;
+       remote_gcm = &context->crypto.ctx_remote.aes_gcm_remote;
+
+       gcm_reset(remote_gcm);
+
+       /* set up the aad */
+       if (context->tlsver == TLS_VERSION13) {
+               aad[0] = TLS_APPLICATION_DATA;
+               aad[1] = 0x03;
+               aad[2] = 0x03;
+               *((unsigned short *) &aad[3]) =
+                       htons(buf_len - header_size);
+               aad_size = 5;
+               sequence = aad + 5;
+               *((uint64_t *) sequence) =
+                       htonll(context->remote_sequence_number);
+               memcpy(iv, context->crypto.ctx_remote_mac.remote_iv,
+                               TLS_13_AES_GCM_IV_LENGTH);
+               int i;
+               int offset = TLS_13_AES_GCM_IV_LENGTH - 8;
+               for (i = 0; i < 8; i++) {
+                       iv[offset + i] =
+                               context->crypto.ctx_remote_mac.remote_iv[offset +
+                               i] ^ sequence[i];
+               }
+               pt_length = buf_len - header_size - TLS_GCM_TAG_LEN;
+               delta = 0;
+       } else {
+               aad_size = 13;
+               pt_length = length - 8 - TLS_GCM_TAG_LEN;
+
+               /* build aad and iv */
+               *((uint64_t *) aad) = htonll(context->remote_sequence_number);
+               aad[8] = buf[0];
+               aad[9] = buf[1];
+               aad[10] = buf[2];
+
+               memcpy(iv, context->crypto.ctx_remote_mac.remote_aead_iv, 4);
+               memcpy(iv + 4, buf + header_size, 8);
+               *((unsigned short *) &aad[11]) = htons(pt_length);
+       }
+
+       if (pt_length < 0) {
+               DEBUG_PRINT("Invalid packet length");
+               return TLS_BROKEN_PACKET;
+       }
+       DEBUG_DUMP_HEX_LABEL("aad", aad, aad_size);
+       DEBUG_DUMP_HEX_LABEL("aad iv", iv, 12);
+
+       /* I think we do all of this even if they fail to avoid timing
+        * attacks
+        */
+       int res0 = gcm_add_iv(&context->crypto.ctx_remote.aes_gcm_remote, iv, 12);
+       int res1 = gcm_add_aad(&context->crypto.ctx_remote.aes_gcm_remote, aad, aad_size);
+
+       DEBUG_PRINT("PT SIZE: %i\n", pt_length);
+
+       /* TODO we might want to expand this buffer before we call this
+        * function */
+       tls_buffer_expand(pt_buffer, pt_length);
+
+       int res2 = gcm_process(&context->crypto.ctx_remote.
+                       aes_gcm_remote, pt_buffer->buffer,
+                       pt_length,
+                       (char *)buf + header_size + delta,
+                       GCM_DECRYPT);
+       pt_buffer->len = pt_length;
+
+       unsigned char tag[32];
+       unsigned long taglen = 32;
+       int res3 = gcm_done(&context->crypto.ctx_remote.aes_gcm_remote,
+                       tag, &taglen);
+
+       if (res0 || res1 || res2 || res3 || taglen != TLS_GCM_TAG_LEN) {
+               DEBUG_PRINT
+                       ("ERROR: gcm_add_iv: %i, gcm_add_aad: %i, gcm_process: %i, gcm_done: %i\n",
+                        res0, res1, res2, res3);
+               return TLS_BROKEN_PACKET;
+       }
+
+       DEBUG_DUMP_HEX_LABEL("decrypted1", pt_buffer->buffer, pt_length);
+       DEBUG_DUMP_HEX_LABEL("tag", tag, taglen);
+
+       /* check tag */
+       if (memcmp(buf + header_size + delta + pt_length, tag, taglen)) {
+               DEBUG_PRINT("INTEGRITY CHECK FAILED (msg length %i)\n",
+                               pt_length);
+               DEBUG_DUMP_HEX_LABEL("TAG RECEIVED",
+                               buf + header_size + delta + pt_length,
+                               taglen);
+               DEBUG_DUMP_HEX_LABEL("TAG COMPUTED", tag, taglen);
+               tls_alert(context, 1, bad_record_mac);
+               return TLS_INTEGRITY_FAILED;
+       }
+       ptr = pt_buffer->buffer;
+       length = pt_length;
+
+       return 0;
+}
+
+static int decrypt_chacha_poly1305(struct TLSContext *context, uint16_t length,
+               int header_size, const unsigned char *ptr, const unsigned char
+               *buf, int buf_len, struct tls_buffer *pt_buffer) {
+
+       unsigned char aad[16];
+       int aad_size = sizeof(aad);
+       unsigned char *sequence = aad;
+
+       int pt_length = length - POLY1305_TAGLEN;
+       unsigned int counter = 1;
+       unsigned char poly1305_key[POLY1305_KEYLEN];
+       unsigned char trail[16];
+       unsigned char mac_tag[POLY1305_TAGLEN];
+
+       aad_size = 16;
+       if (pt_length < 0) {
+               DEBUG_PRINT("Invalid packet length");
+               return TLS_BROKEN_PACKET;
+       }
+
+       /* set up add */
+       if (context->tlsver == TLS_VERSION13) {
+               aad[0] = TLS_APPLICATION_DATA;
+               aad[1] = 0x03;
+               aad[2] = 0x03;
+               *((unsigned short *) &aad[3]) =
+                       htons(buf_len - header_size);
+               aad_size = 5;
+               sequence = aad + 5;
+               *((uint64_t *) sequence) =
+                       htonll(context->remote_sequence_number);
+       } else {
+               *((uint64_t *) aad) =
+                       htonll(context->remote_sequence_number);
+               aad[8] = buf[0];
+               aad[9] = buf[1];
+               aad[10] = buf[2];
+               *((unsigned short *) &aad[11]) = htons(pt_length);
+               aad[13] = 0;
+               aad[14] = 0;
+               aad[15] = 0;
+       }
+
+       tls_buffer_expand(pt_buffer, pt_length);
+       chacha_ivupdate(&context->crypto.ctx_remote.chacha_remote,
+                       context->crypto.ctx_remote_mac.remote_aead_iv,
+                       sequence, (unsigned char *) &counter);
+
+       chacha_encrypt_bytes(&context->crypto.ctx_remote.chacha_remote,
+                       buf + header_size, pt_buffer->buffer, pt_length);
+
+       DEBUG_DUMP_HEX_LABEL("decrypted2", pt_buffer->buffer, pt_length);
+       ptr = pt_buffer->buffer;
+       length = pt_length;
+
+       chacha20_poly1305_key(&context->crypto.ctx_remote.chacha_remote,
+                       poly1305_key);
+
+       struct poly1305_context ctx;
+       tls_poly1305_init(&ctx, poly1305_key);
+       tls_poly1305_update(&ctx, aad, aad_size);
+
+       static unsigned char zeropad[] = { 0, 0, 0, 0, 0, 0, 0, 0, 0,
+               0, 0, 0, 0, 0, 0 };
+
+       int rem = aad_size % 16;
+       if (rem) {
+               tls_poly1305_update(&ctx, zeropad, 16 - rem);
+       }
+       tls_poly1305_update(&ctx, buf + header_size, pt_length);
+
+       rem = pt_length % 16;
+       if (rem) {
+               tls_poly1305_update(&ctx, zeropad, 16 - rem);
+       }
+
+       U32TO8(&trail[0], aad_size == 5 ? 5 : 13);
+       *(int *) &trail[4] = 0;
+       U32TO8(&trail[8], pt_length);
+       *(int *) &trail[12] = 0;
+
+       tls_poly1305_update(&ctx, trail, 16);
+       tls_poly1305_finish(&ctx, mac_tag);
+       if (memcmp(mac_tag, buf + header_size + pt_length,
+                               POLY1305_TAGLEN)) {
+               DEBUG_PRINT
+                       ("INTEGRITY CHECK FAILED (msg length %i)\n",
+                        length);
+               DEBUG_DUMP_HEX_LABEL("POLY1305 TAG RECEIVED",
+                               buf + header_size + pt_length,
+                               POLY1305_TAGLEN);
+               DEBUG_DUMP_HEX_LABEL("POLY1305 TAG COMPUTED",
+                               mac_tag, POLY1305_TAGLEN);
+
+               tls_alert(context, 1, bad_record_mac);
+               return TLS_INTEGRITY_FAILED;
+       }
+
+       pt_buffer->len = pt_length;
+
+       return 0;
+}
+
+static int decrypt_other(struct TLSContext *context, uint16_t length, int
+               header_size, const unsigned char *ptr, const unsigned char
+               *buf, struct tls_buffer *pt_buffer) {
+       int err;
+
+       err = cbc_decrypt(buf+header_size, pt_buffer->buffer, length,
+                       &context->crypto.ctx_remote.aes_remote);
+       if (err) {
+               DEBUG_PRINT("Decryption error %i\n", (int) err);
+               return TLS_BROKEN_PACKET;
+       }
+       unsigned char padding_byte = pt_buffer->buffer[length - 1];
+       unsigned char padding = padding_byte + 1;
+
+       /* poodle check */
+       int padding_index = length - padding;
+       if (padding_index > 0) {
+               int i;
+               int limit = length - 1;
+               for (i = length - padding; i < limit; i++) {
+                       if (pt_buffer->buffer[i] != padding_byte) {
+                               DEBUG_PRINT
+                                       ("BROKEN PACKET (POODLE ?)\n");
+                               tls_alert(context, 1, decrypt_error);
+                               return TLS_BROKEN_PACKET;
+                       }
+               }
+       }
+
+       unsigned int decrypted_length = length;
+       if (padding < decrypted_length) {
+               decrypted_length -= padding;
+               pt_buffer->len -= padding;
+       }
+
+       DEBUG_DUMP_HEX_LABEL("decrypted3", pt_buffer->buffer, decrypted_length);
+       ptr = pt_buffer->buffer;
+
+       if (decrypted_length > TLS_AES_IV_LENGTH) {
+               decrypted_length -= TLS_AES_IV_LENGTH;
+               ptr += TLS_AES_IV_LENGTH;
+               tls_buffer_shift(pt_buffer, TLS_AES_IV_LENGTH);
+       }
+       length = decrypted_length;
+
+       unsigned int mac_size = tls_mac_length(context);
+       if (length < mac_size || !mac_size) {
+               DEBUG_PRINT("BROKEN PACKET\n");
+               tls_alert(context, 1, decrypt_error);
+               return TLS_BROKEN_PACKET;
+       }
+
+       length -= mac_size;
+       pt_buffer->len -= mac_size;
+
+       const unsigned char *message_hmac = &ptr[length];
+       unsigned char hmac_out[TLS_MAX_MAC_SIZE];
+       unsigned char temp_buf[5];
+       memcpy(temp_buf, buf, 3);
+       *(unsigned short *) &temp_buf[3] = htons(length);
+       unsigned int hmac_out_len = tls_hmac_message(0, context, temp_buf, 5,
+                       ptr, length, hmac_out, mac_size);
+       if (hmac_out_len != mac_size
+                       || memcmp(message_hmac, hmac_out, mac_size)) {
+               DEBUG_PRINT("INTEGRITY CHECK FAILED (msg length %i)\n",
+                               length);
+               DEBUG_DUMP_HEX_LABEL("HMAC RECEIVED", message_hmac, mac_size);
+               DEBUG_DUMP_HEX_LABEL("HMAC COMPUTED", hmac_out, hmac_out_len);
+
+               tls_alert(context, 1, bad_record_mac);
+
+               return TLS_INTEGRITY_FAILED;
+       }
+
+       return 0;
+}
+
+static int decrypt_payload(struct TLSContext *context,
+               uint16_t length,
+               int header_size,
+               const unsigned char *ptr,
+               const unsigned char *buf,
+               int buf_len,
+               struct tls_buffer *pt_buffer
+               ) {
+
+       if (context->crypto.created == 2) {
+               return decrypt_aes_gcm(context, length, header_size, ptr, buf,
+                               buf_len, pt_buffer);
+       } else if (context->crypto.created == 3) {
+               return decrypt_chacha_poly1305(context, length, header_size,
+                               ptr, buf, buf_len, pt_buffer);
+       } else if (context->crypto.created == 1) {
+               return decrypt_other(context, length, header_size, ptr, buf,
+                               pt_buffer);
+       } else {
+               tls_buffer_free(pt_buffer);
+               return TLS_BROKEN_PACKET;
+       }
+
+       return 0;
+}
+
+int tls_parse_message(struct TLSContext *context, unsigned char *buf,
+                     int buf_len) {
+       uint8_t type = *buf;
+       uint16_t version; /* a struct of two uint8 per the rfc, but
+                            we encode it as a uint16_t */
+       uint16_t length;
+       int buf_pos = 0;
+       int res = 5;
+       int payload_res = 0;
+       const unsigned char *ptr = 0;
+       /* TODO probably make this buffer part of the context,
+        * and just zero it out instead of malloc and free
+        */
+       struct tls_buffer pt;
+
+       int header_size = res;
+
+       if (buf_len < res) {
+               return TLS_NEED_MORE_DATA;
+       }
+
+       type = *buf;
+       buf_pos += 1;
+       version = get16(&buf[buf_pos]);
+       buf_pos += 2;
+
+       if (!tls_supported_version(version)) {
+               return TLS_NOT_SAFE;
+       }
+
+       length = get16(&buf[buf_pos]);
+       buf_pos += 2;
+
+       ptr = buf + buf_pos;
+
+       DEBUG_PRINT("Message type: %0x, length: %i\n", (int)type, (int)length);
+
+       /* this buffer can go out of scope */
+       tls_buffer_init(&pt, 0);
+
+       if (context->cipher_spec_set && type != TLS_CHANGE_CIPHER) {
+               /* Need to decrypt payload */
+               DEBUG_DUMP_HEX_LABEL("encrypted", &buf[header_size], length);
+
+               if (!context->crypto.created) {
+                       DEBUG_PRINT("Encryption context not created\n");
+                       random_sleep(TLS_MAX_ERROR_SLEEP_uS);
+                       return TLS_BROKEN_PACKET;
+               }
+
+               int rv = decrypt_payload(context, length, header_size, ptr,
+                               buf, buf_len, &pt);
+
+               if (rv != 0) {
+                       tls_buffer_free(&pt);
+                       random_sleep(TLS_MAX_ERROR_SLEEP_uS);
+                       return rv;
+               }
+
+               if (pt.error) {
+                       tls_buffer_free(&pt);
+                       random_sleep(TLS_MAX_ERROR_SLEEP_uS);
+                       return TLS_NO_MEMORY;
+               }
+
+               ptr = pt.buffer;
+               length = pt.len;
+       }
+
+       context->remote_sequence_number++;
+
+       if (context->tlsver == TLS_VERSION13) {
+               /*(context->connection_status == 2) && */
+               if (type == TLS_APPLICATION_DATA && context->crypto.created) {
+                       do {
+                               length--;
+                               type = ptr[length];
+                       } while (!type);
+               }
+       }
+
+       /* TODO for v1.3 encrypted handshake messages will show up
+        * as application data, so we need to re-compute the record
+        * type, that may be what the above is doing
+        */
+
+       switch (type) {
+               /* application data */
+               case TLS_APPLICATION_DATA:
+                       if (context->connection_status != TLS_CONNECTED) {
+                               DEBUG_PRINT
+                                       ("UNEXPECTED APPLICATION DATA MESSAGE\n");
+                               payload_res = TLS_UNEXPECTED_MESSAGE;
+                               tls_alert(context, 1, unexpected_message);
+                               break;
+                       }
+                       DEBUG_PRINT
+                               ("APPLICATION DATA MESSAGE (TLS VERSION: %x):\n",
+                                (int) context->version);
+                       DEBUG_DUMP(ptr, length);
+                       DEBUG_PRINT("\n");
+                       tls_buffer_append(&context->application_buffer, ptr, length);
+                       if (context->application_buffer.error) {
+                               payload_res =  TLS_NO_MEMORY;
+                       }
+                       break;
+                       /* handshake */
+               case TLS_HANDSHAKE:
+                       DEBUG_PRINT("HANDSHAKE MESSAGE\n");
+                       payload_res = tls_parse_payload(context, ptr, length);
+                       break;
+                       /* change cipher spec */
+               case TLS_CHANGE_CIPHER:
+                       if (context->connection_status != 2) {
+                               if (context->connection_status == 4) {
+                                       DEBUG_PRINT
+                                               ("IGNORING CHANGE CIPHER SPEC MESSAGE (HELLO RETRY REQUEST)\n");
+                                       break;
+                               }
+                               DEBUG_PRINT
+                                       ("UNEXPECTED CHANGE CIPHER SPEC MESSAGE (%i)\n",
+                                        context->connection_status);
+                               tls_alert(context, 1, unexpected_message);
+                               payload_res = TLS_UNEXPECTED_MESSAGE;
+                       } else {
+                               DEBUG_PRINT("CHANGE CIPHER SPEC MESSAGE\n");
+                               context->cipher_spec_set = 1;
+                               /* reset sequence numbers */
+                               context->remote_sequence_number = 0;
+                       }
+                       break;
+                       /* alert */
+               case TLS_ALERT:
+                       DEBUG_PRINT("ALERT MESSAGE\n");
+                       if (length >= 2) {
+                               DEBUG_PRINT("ALERT MESSAGE ...\n");
+                               DEBUG_DUMP_HEX(ptr, length);
+                               int level = ptr[0];
+                               int code = ptr[1];
+                               DEBUG_PRINT("level = %d, code = %d\n",
+                                               level, code);
+                               if (level == TLS_ALERT_CRITICAL) {
+                                       context->critical_error = 1;
+                                       res = TLS_ERROR_ALERT;
+                               }
+                               context->error_code = code;
+                       } else {
+                               DEBUG_PRINT("ALERT MESSAGE short\n");
+                       }
+
+                       break;
+               default:
+                       DEBUG_PRINT("UNKNOWN MESSAGE TYPE: %x\n", (int)type);
+                       payload_res =  TLS_NOT_UNDERSTOOD;
+                       break;
+       }
+
+       tls_buffer_free(&pt);
+
+       if (payload_res < 0) {
+               return payload_res;
+       }
+
+       if (res > 0) {
+               return header_size + length;
+       }
+
+       return res;
+}