]> pd.if.org Git - zpackage/blob - crypto/parse_message.c
e4e7b15ef95a61bd56c4a36169c229c05becba66
[zpackage] / crypto / parse_message.c
1 #define _POSIX_C_SOURCE 200809L
2
3 #include <fcntl.h>
4 #include <stdint.h>
5 #include <stdio.h>
6 #include <stdlib.h>
7 #include <string.h>
8 #include <strings.h>
9 #include <sys/mman.h>
10 #include <sys/stat.h>
11 #include <time.h>
12
13 #include <sys/socket.h>
14 #include <arpa/inet.h>
15 #include <unistd.h>
16
17 #include <errno.h>
18
19 #include <tomcrypt.h>
20
21 #include "buffer.h"
22 #include "tlse.h"
23
/*
 * 64-bit host <-> network byte-order conversion, for platforms whose headers
 * do not already provide htonll/ntohll.
 *
 * The host is probed with htonl(1): when that is the identity the value is
 * returned unchanged, otherwise the two 32-bit halves are swapped and each
 * half is converted with htonl/ntohl.
 *
 * The whole expansion is parenthesised so the macros can be used inside
 * larger expressions, and the argument is widened to uint64_t before it is
 * shifted.  The argument is evaluated more than once; do not pass an
 * expression with side effects.
 */
#ifndef htonll
#define htonll(x) \
        ((1 == htonl(1)) ? (uint64_t)(x) : \
         ((((uint64_t)htonl((uint32_t)((uint64_t)(x) & 0xFFFFFFFFU))) << 32) | \
          (uint64_t)htonl((uint32_t)((uint64_t)(x) >> 32))))
#endif

#ifndef ntohll
#define ntohll(x) \
        ((1 == ntohl(1)) ? (uint64_t)(x) : \
         ((((uint64_t)ntohl((uint32_t)((uint64_t)(x) & 0xFFFFFFFFU))) << 32) | \
          (uint64_t)ntohl((uint32_t)((uint64_t)(x) >> 32))))
#endif
31
/*
 * Read a 16-bit big-endian value: buf[0] is the high byte and buf[1] the
 * low byte.  The caller must make two bytes readable at buf.
 */
static uint16_t get16(const unsigned char *buf) {
        unsigned int high = buf[0];
        unsigned int low = buf[1];

        return (uint16_t) ((high << 8) + low);
}
38
/*
 * Sleep for the given number of microseconds using nanosleep().
 *
 * When the sleep is interrupted by a signal (EINTR) it is resumed with the
 * remaining time reported by nanosleep().  Returns the final nanosleep()
 * result: 0 once the full interval has elapsed, or -1 if nanosleep() failed
 * for a reason other than EINTR (errno is left as nanosleep() set it).
 */
static int tls_microsleep(unsigned int microseconds) {
        struct timespec request;
        struct timespec remaining;
        int status;

        request.tv_sec = (time_t) (microseconds / 1000000L);
        request.tv_nsec = (long) (((long) microseconds % 1000000L) * 1000L);

        errno = 0;
        for (;;) {
                status = nanosleep(&request, &remaining);
                if (status != -1) {
                        break;
                }
                if (errno != EINTR) {
                        return status;
                }
                /* interrupted: carry on with whatever time is left */
                request = remaining;
        }
        return status;
}
56
/*
 * Return a random unsigned integer filled in by tls_random().
 *
 * When limit is zero the raw value is returned unreduced.  Otherwise the
 * result is reduced modulo limit (limit is converted to unsigned int, so a
 * negative limit acts as a very large bound).
 *
 * A plain "value % limit" is biased towards small results unless limit is a
 * power of two.  To avoid that, draws below (2^N mod limit) are rejected and
 * retried, which leaves a range that is an exact multiple of limit.  Each
 * draw is rejected with probability below 1/2; the retry count is capped so
 * that a random source that keeps producing rejected values cannot make
 * this function spin forever, in which case the last draw is reduced
 * with the plain modulo.
 */
static unsigned int random_int(int limit) {
        unsigned int res = 0;
        unsigned int bound;
        unsigned int threshold;
        int tries;

        if (!limit) {
                tls_random((unsigned char *) &res, sizeof res);
                return res;
        }

        bound = (unsigned int) limit;
        /* (0 - bound) wraps to 2^N - bound, so this is 2^N mod bound */
        threshold = (0u - bound) % bound;

        for (tries = 0; tries < 64; tries++) {
                res = 0;
                tls_random((unsigned char *) &res, sizeof res);
                if (res >= threshold) {
                        break;
                }
        }

        return res % bound;
}
66
/*
 * Sleep for a random number of microseconds chosen by random_int(max).
 * Note that max is narrowed from long to int when handed to random_int().
 * Returns whatever tls_microsleep() returns.
 */
static int random_sleep(long max) {
        unsigned int delay_us = random_int(max);

        return tls_microsleep(delay_us);
}
70
/*
 * Store the low 32 bits of v at p[0..3] in little-endian order (least
 * significant byte first).  Exactly four bytes are written.
 */
static void U32TO8(unsigned char *p, unsigned long v) {
        int i;

        for (i = 0; i < 4; i++) {
                p[i] = (unsigned char) ((v >> (8 * i)) & 0xff);
        }
}
77
78 static int decrypt_aes_gcm(struct TLSContext *context, uint16_t length, int
79                 header_size, const unsigned char *ptr, const unsigned char
80                 *buf, int buf_len, struct tls_buffer *pt_buffer) {
81         int delta = 8;
82         int pt_length;
83         unsigned char iv[TLS_13_AES_GCM_IV_LENGTH];
84         unsigned char aad[16];
85         int aad_size = sizeof(aad);
86         unsigned char *sequence = aad;
87
88         gcm_state *remote_gcm;
89         remote_gcm = &context->crypto.ctx_remote.aes_gcm_remote;
90
91         gcm_reset(remote_gcm);
92
93         /* set up the aad */
94         if (context->tlsver == TLS_VERSION13) {
95                 aad[0] = TLS_APPLICATION_DATA;
96                 aad[1] = 0x03;
97                 aad[2] = 0x03;
98                 *((unsigned short *) &aad[3]) =
99                         htons(buf_len - header_size);
100                 aad_size = 5;
101                 sequence = aad + 5;
102                 *((uint64_t *) sequence) =
103                         htonll(context->remote_sequence_number);
104                 memcpy(iv, context->crypto.ctx_remote_mac.remote_iv,
105                                 TLS_13_AES_GCM_IV_LENGTH);
106                 int i;
107                 int offset = TLS_13_AES_GCM_IV_LENGTH - 8;
108                 for (i = 0; i < 8; i++) {
109                         iv[offset + i] =
110                                 context->crypto.ctx_remote_mac.remote_iv[offset +
111                                 i] ^ sequence[i];
112                 }
113                 pt_length = buf_len - header_size - TLS_GCM_TAG_LEN;
114                 delta = 0;
115         } else {
116                 aad_size = 13;
117                 pt_length = length - 8 - TLS_GCM_TAG_LEN;
118
119                 /* build aad and iv */
120                 *((uint64_t *) aad) = htonll(context->remote_sequence_number);
121                 aad[8] = buf[0];
122                 aad[9] = buf[1];
123                 aad[10] = buf[2];
124
125                 memcpy(iv, context->crypto.ctx_remote_mac.remote_aead_iv, 4);
126                 memcpy(iv + 4, buf + header_size, 8);
127                 *((unsigned short *) &aad[11]) = htons(pt_length);
128         }
129
130         if (pt_length < 0) {
131                 DEBUG_PRINT("Invalid packet length");
132                 return TLS_BROKEN_PACKET;
133         }
134         DEBUG_DUMP_HEX_LABEL("aad", aad, aad_size);
135         DEBUG_DUMP_HEX_LABEL("aad iv", iv, 12);
136
137         /* I think we do all of this even if they fail to avoid timing
138          * attacks
139          */
140         int res0 = gcm_add_iv(&context->crypto.ctx_remote.aes_gcm_remote, iv, 12);
141         int res1 = gcm_add_aad(&context->crypto.ctx_remote.aes_gcm_remote, aad, aad_size);
142
143         DEBUG_PRINT("PT SIZE: %i\n", pt_length);
144
145         /* TODO we might want to expand this buffer before we call this
146          * function */
147         tls_buffer_expand(pt_buffer, pt_length);
148
149         int res2 = gcm_process(&context->crypto.ctx_remote.
150                         aes_gcm_remote, pt_buffer->buffer,
151                         pt_length,
152                         (char *)buf + header_size + delta,
153                         GCM_DECRYPT);
154         pt_buffer->len = pt_length;
155
156         unsigned char tag[32];
157         unsigned long taglen = 32;
158         int res3 = gcm_done(&context->crypto.ctx_remote.aes_gcm_remote,
159                         tag, &taglen);
160
161         if (res0 || res1 || res2 || res3 || taglen != TLS_GCM_TAG_LEN) {
162                 DEBUG_PRINT
163                         ("ERROR: gcm_add_iv: %i, gcm_add_aad: %i, gcm_process: %i, gcm_done: %i\n",
164                          res0, res1, res2, res3);
165                 return TLS_BROKEN_PACKET;
166         }
167
168         DEBUG_DUMP_HEX_LABEL("decrypted1", pt_buffer->buffer, pt_length);
169         DEBUG_DUMP_HEX_LABEL("tag", tag, taglen);
170
171         /* check tag */
172         if (memcmp(buf + header_size + delta + pt_length, tag, taglen)) {
173                 DEBUG_PRINT("INTEGRITY CHECK FAILED (msg length %i)\n",
174                                 pt_length);
175                 DEBUG_DUMP_HEX_LABEL("TAG RECEIVED",
176                                 buf + header_size + delta + pt_length,
177                                 taglen);
178                 DEBUG_DUMP_HEX_LABEL("TAG COMPUTED", tag, taglen);
179                 tls_alert(context, 1, bad_record_mac);
180                 return TLS_INTEGRITY_FAILED;
181         }
182         ptr = pt_buffer->buffer;
183         length = pt_length;
184
185         return 0;
186 }
187
188 static int decrypt_chacha_poly1305(struct TLSContext *context, uint16_t length,
189                 int header_size, const unsigned char *ptr, const unsigned char
190                 *buf, int buf_len, struct tls_buffer *pt_buffer) {
191
192         unsigned char aad[16];
193         int aad_size = sizeof(aad);
194         unsigned char *sequence = aad;
195
196         int pt_length = length - POLY1305_TAGLEN;
197         unsigned int counter = 1;
198         unsigned char poly1305_key[POLY1305_KEYLEN];
199         unsigned char trail[16];
200         unsigned char mac_tag[POLY1305_TAGLEN];
201
202         aad_size = 16;
203         if (pt_length < 0) {
204                 DEBUG_PRINT("Invalid packet length");
205                 return TLS_BROKEN_PACKET;
206         }
207
208         /* set up add */
209         if (context->tlsver == TLS_VERSION13) {
210                 aad[0] = TLS_APPLICATION_DATA;
211                 aad[1] = 0x03;
212                 aad[2] = 0x03;
213                 *((unsigned short *) &aad[3]) =
214                         htons(buf_len - header_size);
215                 aad_size = 5;
216                 sequence = aad + 5;
217                 *((uint64_t *) sequence) =
218                         htonll(context->remote_sequence_number);
219         } else {
220                 *((uint64_t *) aad) =
221                         htonll(context->remote_sequence_number);
222                 aad[8] = buf[0];
223                 aad[9] = buf[1];
224                 aad[10] = buf[2];
225                 *((unsigned short *) &aad[11]) = htons(pt_length);
226                 aad[13] = 0;
227                 aad[14] = 0;
228                 aad[15] = 0;
229         }
230
231         tls_buffer_expand(pt_buffer, pt_length);
232         chacha_ivupdate(&context->crypto.ctx_remote.chacha_remote,
233                         context->crypto.ctx_remote_mac.remote_aead_iv,
234                         sequence, (unsigned char *) &counter);
235
236         chacha_encrypt_bytes(&context->crypto.ctx_remote.chacha_remote,
237                         buf + header_size, pt_buffer->buffer, pt_length);
238
239         DEBUG_DUMP_HEX_LABEL("decrypted2", pt_buffer->buffer, pt_length);
240         ptr = pt_buffer->buffer;
241         length = pt_length;
242
243         chacha20_poly1305_key(&context->crypto.ctx_remote.chacha_remote,
244                         poly1305_key);
245
246         struct poly1305_context ctx;
247         tls_poly1305_init(&ctx, poly1305_key);
248         tls_poly1305_update(&ctx, aad, aad_size);
249
250         static unsigned char zeropad[] = { 0, 0, 0, 0, 0, 0, 0, 0, 0,
251                 0, 0, 0, 0, 0, 0 };
252
253         int rem = aad_size % 16;
254         if (rem) {
255                 tls_poly1305_update(&ctx, zeropad, 16 - rem);
256         }
257         tls_poly1305_update(&ctx, buf + header_size, pt_length);
258
259         rem = pt_length % 16;
260         if (rem) {
261                 tls_poly1305_update(&ctx, zeropad, 16 - rem);
262         }
263
264         U32TO8(&trail[0], aad_size == 5 ? 5 : 13);
265         *(int *) &trail[4] = 0;
266         U32TO8(&trail[8], pt_length);
267         *(int *) &trail[12] = 0;
268
269         tls_poly1305_update(&ctx, trail, 16);
270         tls_poly1305_finish(&ctx, mac_tag);
271         if (memcmp(mac_tag, buf + header_size + pt_length,
272                                 POLY1305_TAGLEN)) {
273                 DEBUG_PRINT
274                         ("INTEGRITY CHECK FAILED (msg length %i)\n",
275                          length);
276                 DEBUG_DUMP_HEX_LABEL("POLY1305 TAG RECEIVED",
277                                 buf + header_size + pt_length,
278                                 POLY1305_TAGLEN);
279                 DEBUG_DUMP_HEX_LABEL("POLY1305 TAG COMPUTED",
280                                 mac_tag, POLY1305_TAGLEN);
281
282                 tls_alert(context, 1, bad_record_mac);
283                 return TLS_INTEGRITY_FAILED;
284         }
285
286         pt_buffer->len = pt_length;
287
288         return 0;
289 }
290
291 static int decrypt_other(struct TLSContext *context, uint16_t length, int
292                 header_size, const unsigned char *ptr, const unsigned char
293                 *buf, struct tls_buffer *pt_buffer) {
294         int err;
295
296         err = cbc_decrypt(buf+header_size, pt_buffer->buffer, length,
297                         &context->crypto.ctx_remote.aes_remote);
298         if (err) {
299                 DEBUG_PRINT("Decryption error %i\n", (int) err);
300                 return TLS_BROKEN_PACKET;
301         }
302         unsigned char padding_byte = pt_buffer->buffer[length - 1];
303         unsigned char padding = padding_byte + 1;
304
305         /* poodle check */
306         int padding_index = length - padding;
307         if (padding_index > 0) {
308                 int i;
309                 int limit = length - 1;
310                 for (i = length - padding; i < limit; i++) {
311                         if (pt_buffer->buffer[i] != padding_byte) {
312                                 DEBUG_PRINT
313                                         ("BROKEN PACKET (POODLE ?)\n");
314                                 tls_alert(context, 1, decrypt_error);
315                                 return TLS_BROKEN_PACKET;
316                         }
317                 }
318         }
319
320         unsigned int decrypted_length = length;
321         if (padding < decrypted_length) {
322                 decrypted_length -= padding;
323                 pt_buffer->len -= padding;
324         }
325
326         DEBUG_DUMP_HEX_LABEL("decrypted3", pt_buffer->buffer, decrypted_length);
327         ptr = pt_buffer->buffer;
328
329         if (decrypted_length > TLS_AES_IV_LENGTH) {
330                 decrypted_length -= TLS_AES_IV_LENGTH;
331                 ptr += TLS_AES_IV_LENGTH;
332                 tls_buffer_shift(pt_buffer, TLS_AES_IV_LENGTH);
333         }
334         length = decrypted_length;
335
336         unsigned int mac_size = tls_mac_length(context);
337         if (length < mac_size || !mac_size) {
338                 DEBUG_PRINT("BROKEN PACKET\n");
339                 tls_alert(context, 1, decrypt_error);
340                 return TLS_BROKEN_PACKET;
341         }
342
343         length -= mac_size;
344         pt_buffer->len -= mac_size;
345
346         const unsigned char *message_hmac = &ptr[length];
347         unsigned char hmac_out[TLS_MAX_MAC_SIZE];
348         unsigned char temp_buf[5];
349         memcpy(temp_buf, buf, 3);
350         *(unsigned short *) &temp_buf[3] = htons(length);
351         unsigned int hmac_out_len = tls_hmac_message(0, context, temp_buf, 5,
352                         ptr, length, hmac_out, mac_size);
353         if (hmac_out_len != mac_size
354                         || memcmp(message_hmac, hmac_out, mac_size)) {
355                 DEBUG_PRINT("INTEGRITY CHECK FAILED (msg length %i)\n",
356                                 length);
357                 DEBUG_DUMP_HEX_LABEL("HMAC RECEIVED", message_hmac, mac_size);
358                 DEBUG_DUMP_HEX_LABEL("HMAC COMPUTED", hmac_out, hmac_out_len);
359
360                 tls_alert(context, 1, bad_record_mac);
361
362                 return TLS_INTEGRITY_FAILED;
363         }
364
365         return 0;
366 }
367
368 static int decrypt_payload(struct TLSContext *context,
369                 uint16_t length,
370                 int header_size,
371                 const unsigned char *ptr,
372                 const unsigned char *buf,
373                 int buf_len,
374                 struct tls_buffer *pt_buffer
375                 ) {
376
377         if (context->crypto.created == 2) {
378                 return decrypt_aes_gcm(context, length, header_size, ptr, buf,
379                                 buf_len, pt_buffer);
380         } else if (context->crypto.created == 3) {
381                 return decrypt_chacha_poly1305(context, length, header_size,
382                                 ptr, buf, buf_len, pt_buffer);
383         } else if (context->crypto.created == 1) {
384                 return decrypt_other(context, length, header_size, ptr, buf,
385                                 pt_buffer);
386         } else {
387                 tls_buffer_free(pt_buffer);
388                 return TLS_BROKEN_PACKET;
389         }
390
391         return 0;
392 }
393
394 int tls_parse_message(struct TLSContext *context, unsigned char *buf,
395                       int buf_len) {
396         uint8_t type = *buf;
397         uint16_t version; /* a struct of two uint8 per the rfc, but
398                              we encode it as a uint16_t */
399         uint16_t length;
400         int buf_pos = 0;
401         int res = 5;
402         int payload_res = 0;
403         const unsigned char *ptr = 0;
404         /* TODO probably make this buffer part of the context,
405          * and just zero it out instead of malloc and free
406          */
407         struct tls_buffer pt;
408
409         int header_size = res;
410
411         if (buf_len < res) {
412                 return TLS_NEED_MORE_DATA;
413         }
414
415         type = *buf;
416         buf_pos += 1;
417         version = get16(&buf[buf_pos]);
418         buf_pos += 2;
419
420         if (!tls_supported_version(version)) {
421                 return TLS_NOT_SAFE;
422         }
423
424         length = get16(&buf[buf_pos]);
425         buf_pos += 2;
426
427         ptr = buf + buf_pos;
428
429         DEBUG_PRINT("Message type: %0x, length: %i\n", (int)type, (int)length);
430
431         /* this buffer can go out of scope */
432         tls_buffer_init(&pt, 0);
433
434         if (context->cipher_spec_set && type != TLS_CHANGE_CIPHER) {
435                 /* Need to decrypt payload */
436                 DEBUG_DUMP_HEX_LABEL("encrypted", &buf[header_size], length);
437
438                 if (!context->crypto.created) {
439                         DEBUG_PRINT("Encryption context not created\n");
440                         random_sleep(TLS_MAX_ERROR_SLEEP_uS);
441                         return TLS_BROKEN_PACKET;
442                 }
443
444                 int rv = decrypt_payload(context, length, header_size, ptr,
445                                 buf, buf_len, &pt);
446
447                 if (rv != 0) {
448                         tls_buffer_free(&pt);
449                         random_sleep(TLS_MAX_ERROR_SLEEP_uS);
450                         return rv;
451                 }
452
453                 if (pt.error) {
454                         tls_buffer_free(&pt);
455                         random_sleep(TLS_MAX_ERROR_SLEEP_uS);
456                         return TLS_NO_MEMORY;
457                 }
458
459                 ptr = pt.buffer;
460                 length = pt.len;
461         }
462
463         context->remote_sequence_number++;
464
465         if (context->tlsver == TLS_VERSION13) {
466                 /*(context->connection_status == 2) && */
467                 if (type == TLS_APPLICATION_DATA && context->crypto.created) {
468                         do {
469                                 length--;
470                                 type = ptr[length];
471                         } while (!type);
472                 }
473         }
474
475         /* TODO for v1.3 encrypted handshake messages will show up
476          * as application data, so we need to re-compute the record
477          * type, that may be what the above is doing
478          */
479
480         switch (type) {
481                 /* application data */
482                 case TLS_APPLICATION_DATA:
483                         if (context->connection_status != TLS_CONNECTED) {
484                                 DEBUG_PRINT
485                                         ("UNEXPECTED APPLICATION DATA MESSAGE\n");
486                                 payload_res = TLS_UNEXPECTED_MESSAGE;
487                                 tls_alert(context, 1, unexpected_message);
488                                 break;
489                         }
490                         DEBUG_PRINT
491                                 ("APPLICATION DATA MESSAGE (TLS VERSION: %x):\n",
492                                  (int) context->version);
493                         DEBUG_DUMP(ptr, length);
494                         DEBUG_PRINT("\n");
495                         tls_buffer_append(&context->application_buffer, ptr, length);
496                         if (context->application_buffer.error) {
497                                 payload_res =  TLS_NO_MEMORY;
498                         }
499                         break;
500                         /* handshake */
501                 case TLS_HANDSHAKE:
502                         DEBUG_PRINT("HANDSHAKE MESSAGE\n");
503                         payload_res = tls_parse_payload(context, ptr, length);
504                         break;
505                         /* change cipher spec */
506                 case TLS_CHANGE_CIPHER:
507                         if (context->connection_status != 2) {
508                                 if (context->connection_status == 4) {
509                                         DEBUG_PRINT
510                                                 ("IGNORING CHANGE CIPHER SPEC MESSAGE (HELLO RETRY REQUEST)\n");
511                                         break;
512                                 }
513                                 DEBUG_PRINT
514                                         ("UNEXPECTED CHANGE CIPHER SPEC MESSAGE (%i)\n",
515                                          context->connection_status);
516                                 tls_alert(context, 1, unexpected_message);
517                                 payload_res = TLS_UNEXPECTED_MESSAGE;
518                         } else {
519                                 DEBUG_PRINT("CHANGE CIPHER SPEC MESSAGE\n");
520                                 context->cipher_spec_set = 1;
521                                 /* reset sequence numbers */
522                                 context->remote_sequence_number = 0;
523                         }
524                         break;
525                         /* alert */
526                 case TLS_ALERT:
527                         DEBUG_PRINT("ALERT MESSAGE\n");
528                         if (length >= 2) {
529                                 DEBUG_PRINT("ALERT MESSAGE ...\n");
530                                 DEBUG_DUMP_HEX(ptr, length);
531                                 int level = ptr[0];
532                                 int code = ptr[1];
533                                 DEBUG_PRINT("level = %d, code = %d\n",
534                                                 level, code);
535                                 if (level == TLS_ALERT_CRITICAL) {
536                                         context->critical_error = 1;
537                                         res = TLS_ERROR_ALERT;
538                                 }
539                                 context->error_code = code;
540                         } else {
541                                 DEBUG_PRINT("ALERT MESSAGE short\n");
542                         }
543
544                         break;
545                 default:
546                         DEBUG_PRINT("UNKNOWN MESSAGE TYPE: %x\n", (int)type);
547                         payload_res =  TLS_NOT_UNDERSTOOD;
548                         break;
549         }
550
551         tls_buffer_free(&pt);
552
553         if (payload_res < 0) {
554                 return payload_res;
555         }
556
557         if (res > 0) {
558                 return header_size + length;
559         }
560
561         return res;
562 }