]> pd.if.org Git - zpackage/blob - crypto/parse_message.c
remove stray debug fprintf
[zpackage] / crypto / parse_message.c
1 #define _POSIX_C_SOURCE 200809L
2
3 #include <fcntl.h>
4 #include <stdint.h>
5 #include <stdio.h>
6 #include <stdlib.h>
7 #include <string.h>
8 #include <strings.h>
9 #include <sys/mman.h>
10 #include <sys/stat.h>
11 #include <time.h>
12
13 #include <sys/socket.h>
14 #include <arpa/inet.h>
15 #include <unistd.h>
16
17 #include <errno.h>
18
19 #include <tomcrypt.h>
20
21 #include "buffer.h"
22 #include "tlse.h"
23
/*
 * 64-bit host <-> network byte order conversion, only defined when the
 * platform headers have not already provided these macros.
 *
 * On a big-endian host (htonl(1) == 1) the value is returned untouched;
 * otherwise each 32-bit half is converted with htonl()/ntohl() and the two
 * halves are exchanged.
 *
 * NOTE: the argument is evaluated more than once, so only pass
 * expressions without side effects.
 */
#ifndef htonll
#define htonll(x) \
	((htonl(1) == 1) \
		? (x) \
		: (((uint64_t) htonl((x) & 0xFFFFFFFF) << 32) | htonl((x) >> 32)))
#endif

#ifndef ntohll
#define ntohll(x) \
	((ntohl(1) == 1) \
		? (x) \
		: (((uint64_t) ntohl((x) & 0xFFFFFFFF) << 32) | ntohl((x) >> 32)))
#endif
31
32 #ifdef DEBUG
/*
 * Map a TLS alert description code to its name from the TLS 1.2 / 1.3
 * alert registries (RFC 5246 section 7.2, RFC 8446 section 6), for debug
 * output.  Codes not listed map to "unknown alert".
 *
 * The returned pointer refers to a string literal and must not be modified
 * or freed.
 */
static const char *tls_alert_msg_name(int code) {
	switch (code) {
		case 0: return "close_notify";
		case 10: return "unexpected_message";
		case 20: return "bad_record_mac";
		case 21: return "decryption_failed";
		case 22: return "record_overflow";
		case 30: return "decompression_failure";
		case 40: return "handshake_failure";
		case 41: return "no_certificate";
		case 42: return "bad_certificate";
		case 43: return "unsupported_certificate";
		case 44: return "certificate_revoked";
		case 45: return "certificate_expired";
		case 46: return "certificate_unknown";
		case 47: return "illegal_parameter";
		case 48: return "unknown_ca";
		case 49: return "access_denied";
		case 50: return "decode_error";
		case 51: return "decrypt_error";
		case 60: return "export_restriction";
		case 70: return "protocol_version";
		case 71: return "insufficient_security";
		case 80: return "internal_error";
		case 86: return "inappropriate_fallback";
		case 90: return "user_canceled";
		case 100: return "no_renegotiation";
		case 109: return "missing_extension";
		case 110: return "unsupported_extension";
		case 112: return "unrecognized_name";
		case 113: return "bad_certificate_status_response";
		case 115: return "unknown_psk_identity";
		case 116: return "certificate_required";
		case 120: return "no_application_protocol";
		default: break;
	}
	return "unknown alert";
}
42 #endif
43
/*
 * Decode a 16-bit unsigned integer stored big-endian (network order) in
 * buf[0] (high byte) and buf[1] (low byte).
 */
static uint16_t get16(const unsigned char *buf) {
	unsigned int high = buf[0];
	unsigned int low = buf[1];

	return (uint16_t) ((high << 8) | low);
}
50
/*
 * Sleep for `microseconds` microseconds.
 *
 * If nanosleep() is interrupted by a signal (EINTR), the sleep is resumed
 * for the time that was still remaining.  Returns 0 once the full interval
 * has elapsed, or -1 (errno set by nanosleep()) on any other failure.
 */
static int tls_microsleep(unsigned int microseconds) {
	struct timespec wanted = {
		.tv_sec = (time_t) (microseconds / 1000000L),
		.tv_nsec = (long) (((long) microseconds % 1000000L) * 1000L),
	};
	struct timespec remaining;

	errno = 0;
	for (;;) {
		int status = nanosleep(&wanted, &remaining);

		if (status != -1 || errno != EINTR) {
			return status;
		}
		/* interrupted: go back to sleep for whatever is left */
		wanted = remaining;
	}
}
68
/*
 * Return a random value drawn from tls_random().
 *
 * With a non-zero limit the result is uniformly distributed in
 * [0, (unsigned int) limit).  A negative limit is converted to unsigned
 * first, exactly like the plain modulo this replaces.  With limit == 0 the
 * raw random unsigned int is returned.
 *
 * A plain `res % limit` is biased unless limit is a power of two.  To avoid
 * this, raw values below (2^N mod limit) are redrawn (rejection sampling).
 * The remaining range has a size that is an exact multiple of limit.
 * Redraws are capped so that a misbehaving random source cannot loop
 * forever.  In that (very unlikely) case the last draw is reduced with the
 * old, slightly biased modulo.
 */
static unsigned int random_int(int limit) {
	enum { MAX_REDRAWS = 32 };
	unsigned int res = 0;
	unsigned int bound;
	unsigned int threshold;
	int attempt;

	tls_random((unsigned char *) &res, sizeof res);
	if (!limit) {
		return res;
	}

	bound = (unsigned int) limit;
	/* (2^N - bound) % bound == 2^N % bound, computed without overflow */
	threshold = (0U - bound) % bound;
	for (attempt = 0; attempt < MAX_REDRAWS && res < threshold; attempt++) {
		tls_random((unsigned char *) &res, sizeof res);
	}

	return res % bound;
}
78
/*
 * Sleep for a random number of microseconds chosen by random_int(max).
 *
 * The callers in this file use it on error paths.  max is narrowed to int
 * when handed to random_int().  Returns the result of tls_microsleep().
 */
static int random_sleep(long max) {
	unsigned int delay_us = random_int(max);

	return tls_microsleep(delay_us);
}
82
/*
 * Store the low 32 bits of v into p[0..3] in little-endian order
 * (least significant byte first).  p must have room for 4 bytes.
 */
static void U32TO8(unsigned char *p, unsigned long v) {
	int byte;

	for (byte = 0; byte < 4; byte++) {
		p[byte] = (unsigned char) ((v >> (8 * byte)) & 0xff);
	}
}
89
90 static int decrypt_aes_gcm(struct TLSContext *context, uint16_t length, int
91                 header_size, const unsigned char *ptr, const unsigned char
92                 *buf, int buf_len, struct tls_buffer *pt_buffer) {
93         int delta = 8;
94         int pt_length;
95         unsigned char iv[TLS_13_AES_GCM_IV_LENGTH];
96         unsigned char aad[16];
97         int aad_size = sizeof(aad);
98         unsigned char *sequence = aad;
99
100         gcm_state *remote_gcm;
101         remote_gcm = &context->crypto.ctx_remote.aes_gcm_remote;
102
103         gcm_reset(remote_gcm);
104
105         /* set up the aad */
106         if (context->tlsver == TLS_VERSION13) {
107                 aad[0] = TLS_APPLICATION_DATA;
108                 aad[1] = 0x03;
109                 aad[2] = 0x03;
110                 *((unsigned short *) &aad[3]) =
111                         htons(buf_len - header_size);
112                 aad_size = 5;
113                 sequence = aad + 5;
114                 *((uint64_t *) sequence) =
115                         htonll(context->remote_sequence_number);
116                 memcpy(iv, context->crypto.ctx_remote_mac.remote_iv,
117                                 TLS_13_AES_GCM_IV_LENGTH);
118                 int i;
119                 int offset = TLS_13_AES_GCM_IV_LENGTH - 8;
120                 for (i = 0; i < 8; i++) {
121                         iv[offset + i] =
122                                 context->crypto.ctx_remote_mac.remote_iv[offset +
123                                 i] ^ sequence[i];
124                 }
125                 pt_length = buf_len - header_size - TLS_GCM_TAG_LEN;
126                 delta = 0;
127         } else {
128                 aad_size = 13;
129                 pt_length = length - 8 - TLS_GCM_TAG_LEN;
130
131                 /* build aad and iv */
132                 *((uint64_t *) aad) = htonll(context->remote_sequence_number);
133                 aad[8] = buf[0];
134                 aad[9] = buf[1];
135                 aad[10] = buf[2];
136
137                 memcpy(iv, context->crypto.ctx_remote_mac.remote_aead_iv, 4);
138                 memcpy(iv + 4, buf + header_size, 8);
139                 *((unsigned short *) &aad[11]) = htons(pt_length);
140         }
141
142         if (pt_length < 0) {
143                 DEBUG_PRINT("Invalid packet length");
144                 return TLS_BROKEN_PACKET;
145         }
146         DEBUG_DUMP_HEX_LABEL("aad", aad, aad_size);
147         DEBUG_DUMP_HEX_LABEL("aad iv", iv, 12);
148
149         /* I think we do all of this even if they fail to avoid timing
150          * attacks
151          */
152         int res0 = gcm_add_iv(&context->crypto.ctx_remote.aes_gcm_remote, iv, 12);
153         int res1 = gcm_add_aad(&context->crypto.ctx_remote.aes_gcm_remote, aad, aad_size);
154
155         DEBUG_PRINT("PT SIZE: %i\n", pt_length);
156
157         /* TODO we might want to expand this buffer before we call this
158          * function */
159         tls_buffer_expand(pt_buffer, pt_length);
160
161         int res2 = gcm_process(&context->crypto.ctx_remote.
162                         aes_gcm_remote, pt_buffer->buffer,
163                         pt_length,
164                         (char *)buf + header_size + delta,
165                         GCM_DECRYPT);
166         pt_buffer->len = pt_length;
167
168         unsigned char tag[32];
169         unsigned long taglen = 32;
170         int res3 = gcm_done(&context->crypto.ctx_remote.aes_gcm_remote,
171                         tag, &taglen);
172
173         if (res0 || res1 || res2 || res3 || taglen != TLS_GCM_TAG_LEN) {
174                 DEBUG_PRINT
175                         ("ERROR: gcm_add_iv: %i, gcm_add_aad: %i, gcm_process: %i, gcm_done: %i\n",
176                          res0, res1, res2, res3);
177                 return TLS_BROKEN_PACKET;
178         }
179
180         DEBUG_DUMP_HEX_LABEL("decrypted1", pt_buffer->buffer, pt_length);
181         DEBUG_DUMP_HEX_LABEL("tag", tag, taglen);
182
183         /* check tag */
184         if (memcmp(buf + header_size + delta + pt_length, tag, taglen)) {
185                 DEBUG_PRINT("INTEGRITY CHECK FAILED (msg length %i)\n",
186                                 pt_length);
187                 DEBUG_DUMP_HEX_LABEL("TAG RECEIVED",
188                                 buf + header_size + delta + pt_length,
189                                 taglen);
190                 DEBUG_DUMP_HEX_LABEL("TAG COMPUTED", tag, taglen);
191                 tls_alert(context, 1, bad_record_mac);
192                 return TLS_INTEGRITY_FAILED;
193         }
194         ptr = pt_buffer->buffer;
195         length = pt_length;
196
197         return 0;
198 }
199
200 static int decrypt_chacha_poly1305(struct TLSContext *context, uint16_t length,
201                 int header_size, const unsigned char *ptr, const unsigned char
202                 *buf, int buf_len, struct tls_buffer *pt_buffer) {
203
204         unsigned char aad[16];
205         int aad_size = sizeof(aad);
206         unsigned char *sequence = aad;
207
208         int pt_length = length - POLY1305_TAGLEN;
209         unsigned int counter = 1;
210         unsigned char poly1305_key[POLY1305_KEYLEN];
211         unsigned char trail[16];
212         unsigned char mac_tag[POLY1305_TAGLEN];
213
214         aad_size = 16;
215         if (pt_length < 0) {
216                 DEBUG_PRINT("Invalid packet length");
217                 return TLS_BROKEN_PACKET;
218         }
219
220         /* set up add */
221         if (context->tlsver == TLS_VERSION13) {
222                 aad[0] = TLS_APPLICATION_DATA;
223                 aad[1] = 0x03;
224                 aad[2] = 0x03;
225                 *((unsigned short *) &aad[3]) =
226                         htons(buf_len - header_size);
227                 aad_size = 5;
228                 sequence = aad + 5;
229                 *((uint64_t *) sequence) =
230                         htonll(context->remote_sequence_number);
231         } else {
232                 *((uint64_t *) aad) =
233                         htonll(context->remote_sequence_number);
234                 aad[8] = buf[0];
235                 aad[9] = buf[1];
236                 aad[10] = buf[2];
237                 *((unsigned short *) &aad[11]) = htons(pt_length);
238                 aad[13] = 0;
239                 aad[14] = 0;
240                 aad[15] = 0;
241         }
242
243         tls_buffer_expand(pt_buffer, pt_length);
244         chacha_ivupdate(&context->crypto.ctx_remote.chacha_remote,
245                         context->crypto.ctx_remote_mac.remote_aead_iv,
246                         sequence, (unsigned char *) &counter);
247
248         chacha_encrypt_bytes(&context->crypto.ctx_remote.chacha_remote,
249                         buf + header_size, pt_buffer->buffer, pt_length);
250
251         DEBUG_DUMP_HEX_LABEL("decrypted2", pt_buffer->buffer, pt_length);
252         ptr = pt_buffer->buffer;
253         length = pt_length;
254
255         chacha20_poly1305_key(&context->crypto.ctx_remote.chacha_remote,
256                         poly1305_key);
257
258         struct poly1305_context ctx;
259         tls_poly1305_init(&ctx, poly1305_key);
260         tls_poly1305_update(&ctx, aad, aad_size);
261
262         static unsigned char zeropad[] = { 0, 0, 0, 0, 0, 0, 0, 0, 0,
263                 0, 0, 0, 0, 0, 0 };
264
265         int rem = aad_size % 16;
266         if (rem) {
267                 tls_poly1305_update(&ctx, zeropad, 16 - rem);
268         }
269         tls_poly1305_update(&ctx, buf + header_size, pt_length);
270
271         rem = pt_length % 16;
272         if (rem) {
273                 tls_poly1305_update(&ctx, zeropad, 16 - rem);
274         }
275
276         U32TO8(&trail[0], aad_size == 5 ? 5 : 13);
277         *(int *) &trail[4] = 0;
278         U32TO8(&trail[8], pt_length);
279         *(int *) &trail[12] = 0;
280
281         tls_poly1305_update(&ctx, trail, 16);
282         tls_poly1305_finish(&ctx, mac_tag);
283         if (memcmp(mac_tag, buf + header_size + pt_length,
284                                 POLY1305_TAGLEN)) {
285                 DEBUG_PRINT
286                         ("INTEGRITY CHECK FAILED (msg length %i)\n",
287                          length);
288                 DEBUG_DUMP_HEX_LABEL("POLY1305 TAG RECEIVED",
289                                 buf + header_size + pt_length,
290                                 POLY1305_TAGLEN);
291                 DEBUG_DUMP_HEX_LABEL("POLY1305 TAG COMPUTED",
292                                 mac_tag, POLY1305_TAGLEN);
293
294                 tls_alert(context, 1, bad_record_mac);
295                 return TLS_INTEGRITY_FAILED;
296         }
297
298         pt_buffer->len = pt_length;
299
300         return 0;
301 }
302
303 static int decrypt_other(struct TLSContext *context, uint16_t length, int
304                 header_size, const unsigned char *ptr, const unsigned char
305                 *buf, struct tls_buffer *pt_buffer) {
306         int err;
307
308         ENTER;
309         tls_buffer_expand(pt_buffer, length);
310         DEBUG_PRINTLN("cbc_decrypt(%p, %p, %lu, %p)\n",
311                         buf+header_size, pt_buffer->buffer,
312                         (unsigned long)length, &context->crypto.ctx_remote.aes_remote);
313         err = cbc_decrypt(buf+header_size, pt_buffer->buffer, length,
314                         &context->crypto.ctx_remote.aes_remote);
315         if (err) {
316                 DEBUG_PRINTLN("Decryption error %d %s\n", err,
317                                 error_to_string(err));
318                 LEAVE;
319                 return TLS_BROKEN_PACKET;
320         }
321         pt_buffer->len = length;
322         unsigned char padding_byte = pt_buffer->buffer[length - 1];
323         unsigned char padding = padding_byte + 1;
324         DEBUG_PRINTLN("cbc padding byte = %d\n", (int)padding_byte);
325
326         /* poodle check */
327         int padding_index = length - padding;
328         if (padding_index > 0) {
329                 int i;
330                 int limit = length - 1;
331                 for (i = length - padding; i < limit; i++) {
332                         if (pt_buffer->buffer[i] != padding_byte) {
333                                 DEBUG_PRINTLN("BROKEN PACKET (POODLE ?)\n");
334                                 tls_alert(context, 1, decrypt_error);
335                                 LEAVE;
336                                 return TLS_BROKEN_PACKET;
337                         }
338                 }
339         }
340
341         unsigned int decrypted_length = length;
342         if (padding < decrypted_length) {
343                 decrypted_length -= padding;
344                 pt_buffer->len -= padding;
345         }
346
347         DEBUG_DUMP_HEX_LABEL("decrypted", pt_buffer->buffer, decrypted_length);
348
349         if (decrypted_length > TLS_AES_IV_LENGTH) {
350                 decrypted_length -= TLS_AES_IV_LENGTH;
351                 tls_buffer_shift(pt_buffer, TLS_AES_IV_LENGTH);
352         }
353         ptr = pt_buffer->buffer;
354         length = decrypted_length;
355
356         unsigned int mac_size = tls_mac_length(context);
357         if (length < mac_size || !mac_size) {
358                 DEBUG_PRINTLN("BROKEN PACKET\n");
359                 tls_alert(context, 1, decrypt_error);
360                                 LEAVE;
361                 return TLS_BROKEN_PACKET;
362         }
363
364         DEBUG_PRINTLN("mac size %u\n", mac_size);
365         length -= mac_size;
366         pt_buffer->len -= mac_size;
367
368         const unsigned char *message_hmac = &ptr[length];
369         unsigned char hmac_out[TLS_MAX_MAC_SIZE];
370         unsigned char temp_buf[5];
371         memcpy(temp_buf, buf, 3);
372         *(unsigned short *) &temp_buf[3] = htons(length);
373         unsigned int hmac_out_len = tls_hmac_message(0, context, temp_buf, 5,
374                         ptr, length, hmac_out, mac_size);
375         if (hmac_out_len != mac_size
376                         || memcmp(message_hmac, hmac_out, mac_size)) {
377                 DEBUG_PRINTLN("INTEGRITY CHECK FAILED (msg length %i)\n",
378                                 length);
379                 DEBUG_DUMP_HEX_LABEL("HMAC RECEIVED", message_hmac, mac_size);
380                 DEBUG_DUMP_HEX_LABEL("HMAC COMPUTED", hmac_out, hmac_out_len);
381
382                 tls_alert(context, 1, bad_record_mac);
383
384                 LEAVE;
385                 return TLS_INTEGRITY_FAILED;
386         }
387
388         LEAVE;
389         return 0;
390 }
391
392 static int decrypt_payload(struct TLSContext *context,
393                 uint16_t length,
394                 int header_size,
395                 const unsigned char *ptr,
396                 const unsigned char *buf,
397                 int buf_len,
398                 struct tls_buffer *pt_buffer
399                 ) {
400
401         if (context->crypto.created == 2) {
402                 return decrypt_aes_gcm(context, length, header_size, ptr, buf,
403                                 buf_len, pt_buffer);
404         } else if (context->crypto.created == 3) {
405                 return decrypt_chacha_poly1305(context, length, header_size,
406                                 ptr, buf, buf_len, pt_buffer);
407         } else if (context->crypto.created == 1) {
408                 return decrypt_other(context, length, header_size, ptr, buf,
409                                 pt_buffer);
410         } else {
411                 tls_buffer_free(pt_buffer);
412                 return TLS_BROKEN_PACKET;
413         }
414
415         return 0;
416 }
417
418 int tls_parse_message(struct TLSContext *context, unsigned char *buf,
419                       int buf_len) {
420         uint8_t type = *buf;
421         uint16_t version; /* a struct of two uint8 per the rfc, but
422                              we encode it as a uint16_t */
423         uint16_t length;
424         int buf_pos = 0;
425         int res = 5;
426         int payload_res = 0;
427         const unsigned char *ptr = 0;
428         /* TODO probably make this buffer part of the context,
429          * and just zero it out instead of malloc and free
430          */
431         struct tls_buffer pt;
432
433         int header_size = res;
434
435         ENTER;
436
437         if (buf_len < res) {
438                 LEAVE;
439                 return TLS_NEED_MORE_DATA;
440         }
441
442         type = *buf;
443         buf_pos += 1;
444         version = get16(&buf[buf_pos]);
445         buf_pos += 2;
446
447         if (!tls_supported_version(version)) {
448                 LEAVE;
449                 return TLS_NOT_SAFE;
450         }
451
452         length = get16(&buf[buf_pos]);
453         buf_pos += 2;
454
455         ptr = buf + buf_pos;
456
457         DEBUG_PRINTLN("Message type: %d, length: %d\n", (int)type, (int)length);
458
459         /* this buffer can go out of scope */
460         tls_buffer_init(&pt, 0);
461
462         if (context->cipher_spec_set && type != TLS_CHANGE_CIPHER) {
463                 /* Need to decrypt payload */
464                 DEBUG_DUMP_HEX_LABEL("encrypted", &buf[header_size], length);
465
466                 if (!context->crypto.created) {
467                         DEBUG_PRINT("Encryption context not created\n");
468                         random_sleep(TLS_MAX_ERROR_SLEEP_uS);
469                         LEAVE;
470                         return TLS_BROKEN_PACKET;
471                 }
472
473                 int rv = decrypt_payload(context, length, header_size, ptr,
474                                 buf, buf_len, &pt);
475
476                 if (rv != 0) {
477                         tls_buffer_free(&pt);
478                         random_sleep(TLS_MAX_ERROR_SLEEP_uS);
479                         LEAVE;
480                         return rv;
481                 }
482
483                 if (pt.error) {
484                         tls_buffer_free(&pt);
485                         random_sleep(TLS_MAX_ERROR_SLEEP_uS);
486                         LEAVE;
487                         return TLS_NO_MEMORY;
488                 }
489
490                 ptr = pt.buffer;
491                 length = pt.len;
492         }
493
494         context->remote_sequence_number++;
495
496         if (context->tlsver == TLS_VERSION13) {
497                 /*(context->connection_status == 2) && */
498                 if (type == TLS_APPLICATION_DATA && context->crypto.created) {
499                         do {
500                                 length--;
501                                 type = ptr[length];
502                         } while (!type);
503                 }
504         }
505
506         /* TODO for v1.3 encrypted handshake messages will show up
507          * as application data, so we need to re-compute the record
508          * type, that may be what the above is doing
509          */
510
511         switch (type) {
512                 /* application data */
513                 case TLS_APPLICATION_DATA:
514                         if (context->connection_status != TLS_CONNECTED) {
515                                 DEBUG_PRINT
516                                         ("UNEXPECTED APPLICATION DATA MESSAGE\n");
517                                 payload_res = TLS_UNEXPECTED_MESSAGE;
518                                 tls_alert(context, 1, unexpected_message);
519                                 break;
520                         }
521                         DEBUG_PRINT
522                                 ("APPLICATION DATA MESSAGE (TLS VERSION: %x):\n",
523                                  (int) context->version);
524                         DEBUG_DUMP(ptr, length);
525                         DEBUG_PRINT("\n");
526                         tls_buffer_append(&context->application_buffer, ptr, length);
527                         if (context->application_buffer.error) {
528                                 payload_res =  TLS_NO_MEMORY;
529                         }
530                         break;
531                         /* handshake */
532                 case TLS_HANDSHAKE:
533                         DEBUG_PRINT("HANDSHAKE MESSAGE\n");
534                         payload_res = tls_parse_payload(context, ptr, length);
535                         break;
536                         /* change cipher spec */
537                 case TLS_CHANGE_CIPHER:
538                         if (context->connection_status != 2) {
539                                 if (context->connection_status == 4) {
540                                         DEBUG_PRINT
541                                                 ("IGNORING CHANGE CIPHER SPEC MESSAGE (HELLO RETRY REQUEST)\n");
542                                         break;
543                                 }
544                                 DEBUG_PRINT
545                                         ("UNEXPECTED CHANGE CIPHER SPEC MESSAGE (%i)\n",
546                                          context->connection_status);
547                                 tls_alert(context, 1, unexpected_message);
548                                 payload_res = TLS_UNEXPECTED_MESSAGE;
549                         } else {
550                                 DEBUG_PRINT("CHANGE CIPHER SPEC MESSAGE\n");
551                                 context->cipher_spec_set = 1;
552                                 /* reset sequence numbers */
553                                 context->remote_sequence_number = 0;
554                         }
555                         break;
556                         /* alert */
557                 case TLS_ALERT:
558                         DEBUG_PRINTLN("ALERT MESSAGE\n");
559                         if (length >= 2) {
560                                 int level = ptr[0];
561                                 int code = ptr[1];
562                                 DEBUG_PRINTLN("level = %d, code = %d\n", level, code);
563                                 if (level == TLS_ALERT_CRITICAL) {
564                                         DEBUG_PRINTLN("critical error: %s\n",
565                                                         tls_alert_msg_name(code));
566                                         context->critical_error = 1;
567                                         res = TLS_ERROR_ALERT;
568                                 }
569                                 context->error_code = code;
570                         } else {
571                                 DEBUG_PRINT("ALERT MESSAGE short\n");
572                         }
573
574                         break;
575                 default:
576                         DEBUG_PRINT("UNKNOWN MESSAGE TYPE: %x\n", (int)type);
577                         payload_res = TLS_NOT_UNDERSTOOD;
578                         break;
579         }
580
581         tls_buffer_free(&pt);
582
583         if (payload_res < 0) {
584                 LEAVE;
585                 return payload_res;
586         }
587
588         if (res > 0) {
589                 LEAVE;
590                 return header_size + length;
591         }
592
593         LEAVE;
594         return res;
595 }