]> pd.if.org Git - zpackage/blob - crypto/packet_update.c
ocb mode fixup
[zpackage] / crypto / packet_update.c
1 #define _POSIX_C_SOURCE 200809L
2
3 #include <arpa/inet.h>
4
5 #include "tlse.h"
6
/*
 * 64-bit host <-> network byte-order conversion, used only when the
 * platform headers do not already provide these names.
 *
 * On a big-endian host (htonl(1) == 1) the value is passed through
 * unchanged; otherwise each 32-bit half is byte-swapped with htonl/ntohl
 * and the halves are exchanged.
 *
 * NOTE: the argument is evaluated more than once; pass only
 * side-effect-free expressions.
 */
#ifndef htonll
#define htonll(x) \
	((htonl(1) == 1) \
	 ? (x) \
	 : (((uint64_t)htonl((x) & 0xFFFFFFFF) << 32) | htonl((x) >> 32)))
#endif

#ifndef ntohll
#define ntohll(x) \
	((ntohl(1) == 1) \
	 ? (x) \
	 : (((uint64_t)ntohl((x) & 0xFFFFFFFF) << 32) | ntohl((x) >> 32)))
#endif
14
15 static int crypto_encrypt(struct TLSContext *context, unsigned char *buf,
16                 unsigned char *ct, unsigned int len) {
17         if (context->crypto.created == 1) {
18                 return cbc_encrypt(buf, ct, len,
19                                    &context->crypto.ctx_local.aes_local);
20         }
21
22         memset(ct, 0, len);
23         return TLS_GENERIC_ERROR;
24 }
25
26 static int packet_encrypt(struct TLSPacket *packet) {
27         int header_size = 5;
28         int block_size = TLS_AES_BLOCK_SIZE;
29         int mac_size = 0;
30         unsigned int length = 0;
31         unsigned char padding = 0;
32         //unsigned int pt_length = packet->len - header_size;
33         struct TLSContext *context = packet->context;
34
35         mac_size = tls_mac_length(packet->context);
36         length = packet->len - header_size + TLS_AES_IV_LENGTH + mac_size;
37
38         padding = block_size - length % block_size;
39         length += padding;
40         unsigned char *buf = malloc(length);
41         if (buf) {
42                 unsigned char *ct = malloc(length + header_size);
43                 if (ct) {
44                         unsigned int buf_pos = 0;
45                         memcpy(ct, packet->buf, header_size - 2);
46                         *(unsigned short *)&ct[header_size - 2] = htons(length);
47                         tls_random(buf, TLS_AES_IV_LENGTH);
48                         buf_pos += TLS_AES_IV_LENGTH;
49                         /* copy payload */
50                         memcpy(buf + buf_pos, packet-> buf + header_size, packet->len - header_size);
51                         buf_pos += packet->len - header_size;
52
53                         tls_hmac_message(1, context,
54                                         packet->buf,
55                                         packet->len, NULL, 0,
56                                         buf + buf_pos,
57                                         mac_size);
58                         buf_pos += mac_size;
59
60                         memset(buf + buf_pos, padding - 1, padding);
61                         buf_pos += padding;
62
63                         crypto_encrypt(context, buf, ct + header_size, length);
64                         free(packet->buf);
65                         packet->buf = ct;
66                         packet->len = length + header_size;
67                         packet->size = packet->len;
68                 } else {
69                         /* invalidate packet */
70                         memset (packet->buf, 0, packet-> len);
71                 }
72                 free(buf);
73         } else {
74                 /* invalidate packet */
75                 memset(packet->buf, 0, packet->len);
76         }
77
78         return 1;
79 }
80
/* Write val into at[0] and at[1], most significant byte first. */
static void put16(unsigned char *at, uint16_t val) {
	unsigned char hi = (unsigned char)(val >> 8);
	unsigned char lo = (unsigned char)(val & 0xffu);

	at[0] = hi;
	at[1] = lo;
}
85
86 void tls_packet_update(struct TLSPacket *packet) {
87         struct TLSContext *context;
88         unsigned int header_size = 5;
89         int footer_size = 0;
90         uint16_t real_packet_len;
91         uint64_t lsn;
92         int mac_size = 0;
93         unsigned int length = 0;
94         int context_is_v13 = 0;
95         struct tls_buffer ciphertext;
96
97         if (!packet || packet->broken) {
98                 return;
99         }
100
101         if (packet->context && packet->context->tlsver == TLS_VERSION13
102                         && packet->context->cipher_spec_set
103                         && packet->context->crypto.created) {
104                 /* type */
105                 tls_packet_uint8(packet, packet->buf[0]);
106                 /* no padding
107                  * tls_packet_uint8(packet, 0);
108                  */
109                 footer_size = 1;
110         }
111
112         real_packet_len = packet->len - header_size;
113
114         put16(packet->buf + 3, real_packet_len);
115
116         if (!packet->context) {
117                 return;
118         }
119
120         context = packet->context;
121         lsn = context->local_sequence_number;
122
123         if (context->tlsver == TLS_VERSION13) {
124                 context_is_v13 = 1;
125         }
126
127         if (packet->buf[0] == TLS_CHANGE_CIPHER) {
128                 context->local_sequence_number++;
129                 return;
130         }
131
132         /* If this is a handshake message, update the handshake hash */
133         if (packet->buf[0] == TLS_HANDSHAKE && packet->len > header_size) {
134                 unsigned char handshake_type = packet->buf[header_size];
135                 if (handshake_type != 0x00 && handshake_type != 0x03) {
136                         tls_update_hash(context, packet->buf + header_size,
137                                         real_packet_len -
138                                         footer_size);
139                 }
140         }
141
142         if (!context->cipher_spec_set || !context->crypto.created) {
143                 context->local_sequence_number++;
144                 return;
145         }
146
147         unsigned int pt_length = real_packet_len;
148
149         if (context->crypto.created == 1) {
150         } else if (context->crypto.created == 3) {
151                 mac_size = POLY1305_TAGLEN;
152                 length = real_packet_len + mac_size;
153         } else {
154                 mac_size = TLS_GCM_TAG_LEN;
155                 length = real_packet_len + 8 + mac_size;
156         }
157
158         if (context->crypto.created == 1) {
159                 packet_encrypt(packet); 
160                 context->local_sequence_number++;
161                 return;
162         }
163
164         if (context->crypto.created < 1) {
165                 /* invalidate packet */
166                 memset(packet->buf, 0, packet->len);
167                 context->local_sequence_number++;
168                 return;
169         }
170
171         /* + 1 = type */
172         int ct_size = length + header_size + 12 + TLS_MAX_TAG_LEN + 1;
173         tls_buffer_init(&ciphertext, ct_size);
174
175         if (ciphertext.error) {
176                 /* invalidate packet */
177                 memset(packet->buf, 0, packet->len);
178                 context->local_sequence_number++;
179                 return;
180         }
181
182         /* AEAD */
183         /* sequence number (8 bytes) */
184         /* content type (1 byte) */
185         /* version (2 bytes) */
186         /* length (2 bytes) */
187         unsigned char aad[13];
188         int aad_size = sizeof(aad);
189         unsigned char *sequence = aad;
190         if (context_is_v13) {
191                 aad[0] = TLS_APPLICATION_DATA;
192                 aad[1] = packet->buf[1];
193                 aad[2] = packet->buf[2];
194                 if (packet->context->crypto.created == 3) {
195                         put16(aad+3, real_packet_len + POLY1305_TAGLEN);
196                 } else {
197                         put16(aad+3, real_packet_len + TLS_GCM_TAG_LEN);
198                 }
199                 aad_size = 5;
200                 sequence = aad + 5;
201
202                 *((uint64_t *) (aad+5)) = htonll(lsn);
203
204         } else {
205                 *((uint64_t *) aad) = htonll(lsn);
206                 aad[8] = packet->buf[0];
207                 aad[9] = packet->buf[1];
208                 aad[10] = packet->buf[2];
209                 put16(aad+11, packet->len - header_size);
210         }
211
212         ciphertext.len = header_size;
213
214         if (context->crypto.created == 3) {
215                 int size;
216                 unsigned int counter = 1;
217                 unsigned char poly1305_key[POLY1305_KEYLEN];
218                 chacha_ivupdate(&context->crypto.ctx_local.chacha_local,
219                                 context->crypto.ctx_local_mac.local_aead_iv,
220                                 sequence, (uint8_t *) &counter);
221                 chacha20_poly1305_key(&context->crypto.ctx_local.chacha_local,
222                                 poly1305_key);
223                 size =
224                         chacha20_poly1305_aead(&context->crypto.ctx_local.chacha_local,
225                                         packet->buf + header_size, pt_length,
226                                         aad, aad_size, poly1305_key,
227                                         ciphertext.buffer + ciphertext.len);
228                 ciphertext.len += size;
229         } else {
230                 unsigned char iv [TLS_13_AES_GCM_IV_LENGTH];
231                 if (context_is_v13) {
232                         memcpy(iv, context->crypto.ctx_local_mac.local_iv,
233                                         TLS_13_AES_GCM_IV_LENGTH);
234                         int i;
235                         int offset = TLS_13_AES_GCM_IV_LENGTH - 8;
236                         for (i = 0; i < 8; i++)
237                                 iv[offset + i] = context->crypto.ctx_local_mac.local_iv[offset + i] ^ sequence[i];
238                 } else {
239                         memcpy(iv, context->crypto.ctx_local_mac.local_aead_iv,
240                                         TLS_AES_GCM_IV_LENGTH);
241                         tls_random(iv + TLS_AES_GCM_IV_LENGTH, 8);
242                         tls_buffer_append(&ciphertext,
243                                         iv+TLS_AES_GCM_IV_LENGTH, 8);
244                 }
245
246                 gcm_state *localgcm;
247                 localgcm = &context->crypto.ctx_local.aes_gcm_local;
248
249                 gcm_reset(localgcm);
250                 gcm_add_iv(localgcm, iv, 12);
251                 gcm_add_aad(localgcm, aad, aad_size);
252                 gcm_process(localgcm, packet->buf + header_size, pt_length,
253                                 ciphertext.buffer+ciphertext.len, GCM_ENCRYPT);
254                 ciphertext.len += pt_length;
255
256                 unsigned long taglen = TLS_GCM_TAG_LEN;
257                 gcm_done(localgcm, ciphertext.buffer+ciphertext.len, &taglen);
258                 ciphertext.len += taglen;
259         }
260
261         if (context_is_v13) {
262                 ciphertext.buffer[0] = TLS_APPLICATION_DATA;
263                 tls_buffer_write16(&ciphertext, TLS_V12, 1);
264         } else {
265                 memcpy(ciphertext.buffer, packet->buf, header_size - 2);
266         }
267
268         tls_buffer_write16(&ciphertext, ciphertext.len - header_size, header_size - 2);
269
270         free(packet->buf);
271         packet->buf = ciphertext.buffer;
272         packet->len = ciphertext.len;
273         packet->size = ciphertext.size;
274
275         context->local_sequence_number++;
276 }