]> pd.if.org Git - zpackage/blob - zpm-fetchurl.c
implement trust on first use
[zpackage] / zpm-fetchurl.c
#define _POSIX_C_SOURCE 200809L

#include <ctype.h>
#include <errno.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <netinet/in.h>
#include <netdb.h>
#include <fcntl.h>
#include <strings.h>
#include <unistd.h>

#include "tlse.h"
18
/*
 * Components of a URI as split up by tls_parse_uri() (defined elsewhere).
 * main() reads scheme, host, port and path from it; the port is kept as a
 * string and converted to a number by the caller.
 * NOTE(review): which fields tls_parse_uri() may leave NULL, and whether
 * the strings are owned copies, is not visible here -- presumably
 * tls_free_uri() releases them; confirm against its definition.
 */
struct tls_uri {
	char *scheme;
	char *userinfo;
	char *host;
	char *port;
	char *path;
	char *query;
	char *fragment;
};
/* Split a URI string into its components (see struct tls_uri). */
int tls_parse_uri(char *, struct tls_uri *);
/* Presumably releases what tls_parse_uri() stored in the struct. */
void tls_free_uri(struct tls_uri *);

/*
 * Connect to host:port over TCP (defined elsewhere).  main() treats a
 * negative return as failure and otherwise uses the value as a socket fd.
 */
int open_tcp_connection(char *host, int port);
32
/*
 * Certificate verification callback used when verification is switched
 * off (verify policy 0, selected with -k): every chain is accepted and
 * none of the arguments are inspected.
 */
int verify_trust(struct TLSContext *context, struct TLSCertificate **chain, int len) {
	(void)context;
	(void)chain;
	(void)len;

	return 0;
}
43
/* Hex digit lookup table; hex() only uses the lowercase first 16 entries. */
static const char hexchars[] = "0123456789abcdefABCDEF";

/*
 * Write the 2*len lowercase hex digits of src[0..len-1] to dst.
 * dst must have room for 2*len characters; it is NOT NUL-terminated
 * here, so callers that want a C string add the terminator themselves.
 */
static void hex(char *dst, const uint8_t *src, size_t len) {
	size_t i;

	for (i = 0; i < len; i++) {
		dst[2 * i] = hexchars[src[i] >> 4];
		dst[2 * i + 1] = hexchars[src[i] & 0xf];
	}
}
54
/*
 * Disabled: intended inverse of hex() -- decode the len hex digits at src
 * into len/2 bytes at dst.  Nothing in this file calls it.
 * NOTE(review): before re-enabling, note that %x stores through an
 * unsigned int *, not the int * passed here; the leading 0 in "%02x" is a
 * printf flag with no meaning to scanf; sscanf skips leading whitespace;
 * x is left unset if sscanf fails; and an odd len makes the last call read
 * one digit plus whatever follows it.  Decoding the two nibbles directly
 * would avoid all of this.
 */
#if 0
static void hexbin(uint8_t *dst, unsigned char *src, size_t len) {
	size_t i;
	int x;

	for (i=0; i<len; i+=2) {
		sscanf((const char *)src+i, "%02x", &x);
		dst[i/2] = x;
	}
}
#endif
66
67 char *my_getline(struct tls_buffer *b, int fd, size_t *size) {
68         char *loc = 0;
69         char buf[4096];
70         ssize_t rv = 0;
71
72         while (b->error == 0) {
73                 loc = memchr(b->buffer, '\n', b->len);
74                 if (loc) {
75                         *size = loc - b->buffer + 1;
76                         return b->buffer;
77                 } else {
78                         rv = read(fd, buf, sizeof buf);
79                         if (rv == -1) {
80                                 return 0;
81                         }
82                         if (rv == 0) {
83                                 break;
84                         }
85                         tls_buffer_append(b, buf, rv);
86                 }
87         }
88         if (rv == 0) {
89                 *size = b->len;
90                 return b->buffer;
91         }
92         return 0;
93 }
94
95 /*
96  * We use a trust on first use policy.  The trust DB is a simple
97  * file in /var/lib/zpm/known_hosts, or ~/.zpm/known_hosts, or ZPM_KNOWNHOSTS.
98  * if -k is given, no verification is done
99  */
100 int verify_first(struct TLSContext *context, struct TLSCertificate **chain, int
101                 certs) {
102         int err;
103         char *trustfile, *homedir = 0, *host, *fp;
104         unsigned char certhash[65];
105         int trustdb;
106         struct tls_buffer tbuf;
107
108         char *line = 0;
109         size_t len = 0;
110
111         if (certs == 0 || chain == 0) {
112                 return 1;
113         }
114
115         err = tls_certificate_is_valid(chain[0]);
116         if (err) {
117                 return err;
118         }
119
120         if (context->sni) {
121                 err = tls_certificate_valid_subject(chain[0], context->sni);
122                 if (err) {
123                         return err;
124                 }
125         }
126
127         hex(certhash, chain[0]->fp, 32);
128         certhash[64] = 0;
129
130         trustfile = getenv("ZPM_KNOWNHOSTS");
131         if (!trustfile) {
132                 if (geteuid() == 0) {
133                         trustfile = "/var/lib/zpm/known_hosts";
134                 } else {
135                         homedir = getenv("HOME");
136                         if (!homedir) {
137                                 return 1;
138                         }
139                         len = snprintf(homedir, 0, "%s/.zpm/known_hosts", homedir);
140                         homedir = malloc(len+1);
141                         if (!homedir) {
142                                 return 1;
143                         }
144                         len = snprintf(homedir, len+1, "%s/.zpm/known_hosts", homedir);
145                         trustfile = homedir;
146                 }
147         }
148         /* cert is valid on its face, so check against the trust db */
149         trustdb = open(trustfile, O_RDWR|O_CREAT, 0600);
150         if (homedir) {
151                 free(homedir);
152         }
153         if (trustdb == -1) {
154                 fprintf(stderr, "cannot open trustdb: %s\n", strerror(errno));
155                 return 1;
156         }
157
158         len = 0;
159         tls_buffer_init(&tbuf, 128);
160         do {
161                 char *off;
162
163                 tls_buffer_shift(&tbuf, len);
164                 line = my_getline(&tbuf, trustdb, &len);
165
166                 if (!line || !len) {
167                         break;
168                 }
169
170                 fp = line;
171                 while (isspace(*fp)) {
172                         fp++;
173                 }
174                 if (*fp == '#') {
175                         continue;
176                 }
177                 off = strchr(line, ':');
178                 if (!off) {
179                         continue;
180                 }
181                 host = off + 1;
182                 *off = 0;
183                 if (line[len-1] == '\n') {
184                         line[len-1] = 0;
185                 }
186
187                 if (strlen(fp) != 64) {
188                         continue;
189                 }
190
191                 if (len && line[len-1] == '\n') {
192                         line[len-1] = 0;
193                 }
194
195                 if (strcmp(context->sni, host) != 0) {
196                         continue;
197                 }
198
199                 int match = (memcmp(certhash, fp, 64) == 0); 
200
201                 close(trustdb);
202                 tls_buffer_free(&tbuf);
203                 return match ? no_error : bad_certificate;
204         } while (!tbuf.error);
205
206         /* got here, so we should be at EOF, so add this host to trust db */
207         lseek(trustdb, 0, SEEK_END);
208         write(trustdb, certhash, 64);
209         write(trustdb, ":", 1);
210         write(trustdb, context->sni, strlen(context->sni));
211         write(trustdb, "\n", 1);
212         close(trustdb);
213         tls_buffer_free(&tbuf);
214
215         return no_error;
216 }
217
218 int verify_roots(struct TLSContext *context, struct TLSCertificate **chain, int
219                 len) {
220         int i, err;
221
222         if (chain) {
223                 for (i = 0; i < len; i++) {
224                         struct TLSCertificate *certificate = chain[i];
225                         // check validity date
226                         err = tls_certificate_is_valid(certificate);
227                         if (err) {
228                                 return err;
229                         }
230                         // check certificate in certificate->bytes of length
231                         // certificate->len the certificate is in ASN.1 DER
232                         // format
233                 }
234         }
235         // check if chain is valid
236         err = tls_certificate_chain_is_valid(chain, len);
237         if (err) {
238                 return err;
239         }
240
241         if (len > 0 && context->sni) {
242                 err = tls_certificate_valid_subject(chain[0], context->sni);
243                 if (err) {
244                         return err;
245                 }
246         }
247
248         /* Perform certificate validation against ROOT CA */
249         err = tls_certificate_chain_is_valid_root(context, chain, len);
250         if (err) {
251                 return err;
252         }
253
254         return no_error;
255 }
256
/*
 * State for one HTTP(S) exchange, reused across redirects by main().
 * Keep the field order stable: the struct may be initialised positionally.
 */
struct io {
	struct tls_buffer response;	/* data read from the connection by fill_buffer() */
	struct TLSContext *tls;		/* TLS session, or 0 to read io->socket directly */
	int socket;			/* connected socket fd; main() starts it at -1 */
	int status_code;		/* numeric status from parse_header() */
	time_t last_modified;		/* Last-Modified via parse_date() (-1 if unparsable) */
	time_t date;			/* Date via parse_date() (-1 if unparsable) */
	size_t content_length;		/* Content-Length value, when one was seen */
	char *redirect;			/* strndup()ed Location of a redirect response */
};
267
/*
 * Map an English month abbreviation to 1..12.  Only the first three
 * characters are compared, case-insensitively ("jan", "January" -> 1).
 * Returns 0 if m does not name a month.
 */
int month(char *m) {
	static const char *const months[] = {"Jan", "Feb", "Mar", "Apr", "May",
		"Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec" };
	int i;

	for (i = 0; i < 12; i++) {
		if (!strncasecmp(m, months[i], 3)) {
			return i + 1;
		}
	}
	return 0;
}

/*
 * Days from 1970-01-01 to the proleptic Gregorian date year-mon-mday
 * (mon in 1..12).  Counting years from March puts Feb 29 at the end of
 * the year, so leap days need no special case.
 */
static long days_from_civil(long year, int mon, int mday) {
	long era, yoe, doy, doe;

	if (mon <= 2) {
		year--;
	}
	era = (year >= 0 ? year : year - 399) / 400;
	yoe = year - era * 400;					/* 0..399 */
	doy = (153 * (mon > 2 ? mon - 3 : mon + 9) + 2) / 5 + mday - 1;
	doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;		/* 0..146096 */
	return era * 146097 + doe - 719468;
}

/*
 * Parse an RFC 1123 HTTP date such as "Wed, 06 Feb 2019 10:06:05 GMT"
 * into seconds since the Unix epoch.  HTTP dates are GMT, so the value is
 * computed in UTC; mktime() would wrongly apply the local time zone.
 * Returns -1 if d is not in that format or names no valid month (the
 * older RFC 850 and asctime() forms are not supported).
 */
time_t parse_date(char *d) {
	int h, m, s, dom, Y, mon;
	char M[4];

	/* %3s keeps the (untrusted) month token inside M */
	if (sscanf(d, "%*[a-zA-Z,] %d %3s %d %d:%d:%d",
				&dom, M, &Y, &h, &m, &s) != 6) {
		return -1;
	}
	mon = month(M);
	if (mon == 0) {
		return -1;
	}
	return (time_t)days_from_civil(Y, mon, dom) * 86400
		+ (time_t)h * 3600 + (time_t)m * 60 + s;
}
303
304 char *find_header(struct io *io, char *header, size_t *len) {
305         char *eoh, *soh;
306         size_t hlen;
307
308         *len = 0;
309
310         hlen = strlen(header);
311         eoh = strstr(io->response.buffer, "\r\n\r\n");
312         if (!eoh) {
313                 return 0;
314         }
315         soh = io->response.buffer;
316         do {
317                 soh = strstr(soh, "\r\n");
318                 if (soh == eoh) {
319                         break;
320                 }
321                 soh += 2;
322                 if (!memcmp(soh, header, hlen)) {
323                         break;
324                 }
325         } while (soh < eoh);
326
327         if (soh >= eoh) {
328                 return 0;
329         }
330         eoh = strstr(soh, "\r\n");
331         soh += hlen;
332         while (soh < eoh && isspace(*soh)) {
333                 soh++;
334         }
335         *len = (eoh - soh);
336         return soh;
337 }
338
339 void parse_header(struct io *io) {
340         char *s = io->response.buffer;
341         int code = 0;
342         char *hval;
343         size_t hlen;
344
345         while (!isspace(*s)) {
346                 s++;
347         }
348         while (isspace(*s)) {
349                 s++;
350         }
351         code = strtol(s, 0, 10);
352         io->status_code = code;
353
354         hval = find_header(io, "Date:", &hlen);
355         if (hval) {
356                 hval[hlen] = 0;
357                 io->date = parse_date(hval);
358                 hval[hlen] = '\r';
359         }
360         hval = find_header(io, "Last-Modified:", &hlen);
361         if (hval) {
362                 hval[hlen] = 0;
363                 io->last_modified = parse_date(hval);
364                 hval[hlen] = '\r';
365         }
366
367         hval = find_header(io, "Content-Length:", &hlen);
368         if (hval) {
369                 hval[hlen] = 0;
370                 io->content_length = strtoul(hval, 0, 10);
371                 hval[hlen] = '\r';
372         }
373
374         switch (code) {
375                 case 301:
376                 case 302:
377                 case 303:
378                 case 307:
379                         hval = find_header(io, "Location:", &hlen);
380                         if (hval) {
381                                 io->redirect = strndup(hval, hlen);
382                         }
383                         break;
384                 default:
385                         break;
386         }
387
388 }
389
390 ssize_t fill_buffer(struct io *io) {
391         unsigned char buffer[4096];
392         ssize_t ret;
393
394         if (io->tls) {
395                 ret = tls_read(io->tls, buffer, sizeof buffer);
396         } else {
397                 ret = read(io->socket, buffer, sizeof buffer);
398         }
399
400         if (ret > 0) {
401                 tls_buffer_append(&io->response, buffer, ret);
402         }
403
404         return ret;
405 }
406
/*
 * Disabled, unfinished draft of a "next line of the response" reader;
 * nothing calls it.
 * NOTE(review): it would not compile if enabled: io is a pointer, so
 * io.response must be io->response; the function's closing brace is
 * missing; and elsewhere the buffered length is read from .len, so .size
 * is probably not the intended field -- confirm against struct tls_buffer.
 */
#if 0
char *nextline(struct io *io) {
	char *eol = 0;;

	eol = memchr(io->response.buffer, '\n', io.response.size);
	while (eol == 0) {
		fill_buffer(io);
		eol = memchr(io->response.buffer, '\n', io.response.size);
	}
	if (eol) {

}
#endif
420
/* Append one request header line, "<header>: <val>\r\n", to buf. */
void append_header(struct tls_buffer *buf, char *header, char *val) {
	size_t name_len = strlen(header);
	size_t value_len = strlen(val);

	tls_buffer_append(buf, header, name_len);
	tls_buffer_append(buf, ": ", 2);
	tls_buffer_append(buf, val, value_len);
	tls_buffer_append(buf, "\r\n", 2);
}
427
/*
 * Advance a len-character progress bar: print to stderr as many ch
 * characters as the bar grows when the transferred count moves from was
 * to now out of total.  Counts beyond total are clamped to it; nothing is
 * printed when total is 0 or len is not positive.
 */
static void pdots(int len, int ch, unsigned long was, unsigned long now,
		unsigned long total) {
	unsigned long long from, to;

	if (total == 0 || len <= 0) {
		return;
	}
	if (was > total) {
		was = total;
	}
	if (now > total) {
		now = total;
	}
	/* widen before multiplying so len * count cannot wrap a 32-bit long */
	from = (unsigned long long)len * was / total;
	to = (unsigned long long)len * now / total;
	for (; from < to; from++) {
		putc(ch, stderr);
	}
}
439
440 int main(int ac, char *av[]) {
441         int sockfd, port = -1, rv;
442         ssize_t ret;
443         int option;
444 #if 0
445         char msg[] = "GET %s HTTP/1.1\r\nHost: %s:%i\r\nConnection: close\r\n\r\n";
446         char msg2[] = "GET %s HTTP/1.1\r\nHost: %s:%i\r\nLast-Modified: %s\r\nConnection: close\r\n\r\n";
447         char msg_buffer[1024];
448 #endif
449         char *req_file = 0;
450         char *host = 0;
451         struct tls_uri uri;
452         char *outfile = 0;
453         int raw = 0, head = 0;
454         int out = 1;
455         int use_tls = 0;
456         struct io io = { {0}, 0, -1, 0, 0, 0, 0, 0 };
457         struct TLSContext *clientssl = 0;
458         int failsilent = 0;
459         char *lmfile = 0;
460         int progressbar = 0;
461         struct tls_buffer request;
462         char lmtime[80];
463         char *eoh = 0;
464         size_t total = 0;
465         size_t header_len;
466         char *url = 0;
467         int redirs = 0, redirlimit = 50, printstatus = 0;
468         int verifypolicy = 1;
469
470         ltc_mp = tfm_desc;
471
472         while ((option = getopt(ac, av, "o:rIfz:#R:SkK")) != -1) {
473                 switch (option) {
474                         case 'o': outfile = optarg; break;
475                         case 'S': printstatus = 1; head = 1; break;
476                         case 'k': verifypolicy = 0; break;
477                         case 'K': verifypolicy = 2; break;
478                         case 'I': head = 1;
479                         case 'r': raw = 1; break;
480                         case 'f': failsilent = 1; break;
481                         case 'z': lmfile = optarg; break;
482                         case 'R': redirlimit = strtol(optarg, 0, 10); break;
483                         case '#': progressbar = 1; break;
484                         default:
485                                   exit(EXIT_FAILURE);
486                                   break;
487                 }
488         }
489
490         if (ac < optind) {
491                 fprintf(stderr, "Usage: %s uri\n", av[0]);
492                 exit(EXIT_FAILURE);
493         }
494
495         if (lmfile) {
496                 struct stat st;
497                 int rv;
498                 struct tm *mtime;
499                 time_t ts;
500
501                 rv = stat(lmfile, &st);
502                 if (rv == -1) {
503                         perror("stat failed:");
504                         exit(EXIT_FAILURE);
505                 }
506                 ts = st.st_mtime;
507                 mtime = gmtime(&ts);
508                 strftime(lmtime, sizeof lmtime, "%a, %d %b %Y %H:%M:%S GMT", mtime);
509         }
510
511         url = strdup(av[optind]);
512         if (!url) {
513                 exit(EXIT_FAILURE);
514         }
515
516         if (outfile) {
517                 out = open(outfile, O_WRONLY|O_CREAT, 0600);
518                 if (out == -1) {
519                         perror("can't open output file:");
520                         exit(EXIT_FAILURE);
521                 }
522         }
523
524         signal(SIGPIPE, SIG_IGN);
525
526         tls_buffer_init(&io.response, 0);
527         tls_buffer_init(&request, 128);
528
529         while (redirs++ <= redirlimit) {
530                 tls_free_uri(&uri);
531                 io.response.len = 0;
532                 request.len = 0;
533
534                 tls_parse_uri(url, &uri);
535                 host = uri.host;
536                 port = atoi(uri.port);
537                 req_file = uri.path;
538
539                 /* construct request */
540                 if (head) {
541                         tls_buffer_append(&request, "HEAD ", 5);
542                 } else {
543                         tls_buffer_append(&request, "GET ", 4);
544                 }
545                 tls_buffer_append(&request, uri.path, strlen(uri.path));
546                 tls_buffer_append(&request, " HTTP/1.1", 9);
547                 tls_buffer_append(&request, "\r\n", 2);
548
549                 append_header(&request, "Host", host);
550                 append_header(&request, "Connection", "close");
551                 if (lmfile) {
552                         append_header(&request, "If-Modified-Since", lmtime);
553                 }
554                 tls_buffer_append(&request, "\r\n", 2);
555                 //fprintf(stderr, "msg =\n%.*s", (int)request.len, request.buffer);
556
557                 if (!strcmp(uri.scheme, "https")) {
558                         use_tls = 1;
559
560                         clientssl = tls_create_context(TLS_CLIENT, TLS_V12);
561
562                         /* optionally, we can set a certificate validation
563                          * callback function if set_verify is not called, and
564                          * root ca is set, `tls_default_verify` will be used
565                          * (does exactly what `verify` does in this example)
566                          */
567                         if (verifypolicy == 2) {
568                                 char *cert_path = 0;
569                                 cert_path = getenv("ZPM_CERTFILE");
570                                 if (!cert_path) {
571                                         cert_path = "/var/lib/zpm/roots.pem";
572                                 }
573                                 rv = tls_load_root_file(clientssl, cert_path);
574                                 if (rv == -1) {
575                                         fprintf(stderr, "Error loading root certs\n");
576                                         return 1;
577                                 }
578                                 tls_set_verify(clientssl, verify_roots);
579                         } else if (verifypolicy == 1) {
580                                 tls_set_verify(clientssl, verify_first);
581                         } else {
582                                 tls_set_verify(clientssl, verify_trust);
583                         }
584
585                         if (!clientssl) {
586                                 fprintf(stderr, "Error initializing client context\n");
587                                 return -1;
588                         }
589                         tls_sni_set(clientssl, uri.host);
590                         clientssl->sync = 1;
591                         io.tls = clientssl;
592                 }
593
594                 sockfd = open_tcp_connection(host, port);
595                 io.socket = sockfd;
596
597                 if (sockfd < 0) {
598                         exit(EXIT_FAILURE);
599                 }
600
601                 if (use_tls) {
602                         tls_set_fd(clientssl, sockfd);
603
604                         if ((rv = tls_connect(clientssl)) != 1) {
605                                 fprintf(stderr, "Handshake Error %i\n", rv);
606                                 return -5;
607                         }
608
609                         ret = tls_write(clientssl, request.buffer, request.len);
610                 } else {
611                         ret = write(io.socket, request.buffer, request.len);
612                 }
613
614                 if (ret < 0) {
615                         fprintf(stderr, "write error %zd\n", ret);
616                         return -6;
617                 }
618
619                 do {
620                         ret = fill_buffer(&io);
621                         if (ret < 0) {
622                                 break;
623                         }
624                         eoh = strstr(io.response.buffer, "\r\n\r\n");
625                         if (ret == 0) {
626                                 break;
627                         }
628                 } while (!eoh);
629
630                 if (!eoh) {
631                         /* never got (complet) header */
632                         fprintf(stderr, "incomplete response to %s\n", av[optind]);
633                         exit(EXIT_FAILURE);
634                 }
635
636                 header_len = (size_t)(eoh - io.response.buffer) + 4;
637                 parse_header(&io);
638
639                 switch (io.status_code) {
640                         case 301:
641                         case 302:
642                         case 303:
643                         case 307:
644                                 free(url);
645                                 url = strdup(io.redirect);
646                                 continue;
647                                 break;
648                 }
649
650                 if (printstatus) {
651                         printf("%d\n", io.status_code);
652                         break;
653                 }
654
655                 if (io.status_code != 200) {
656                         break;
657                 }
658
659                 if (!raw) {
660                         tls_buffer_shift(&io.response, header_len);
661                 }
662                 if (head) {
663                         io.response.len -= 2;
664                 }
665
666                 if (progressbar) {
667                         if (io.content_length) {
668                                 fprintf(stderr, "(%lu) ", io.content_length);
669                         }
670                 }
671
672                 do {
673                         write(out, io.response.buffer, io.response.len);
674                         ret = io.response.len;
675                         io.response.len = 0;
676
677                         if (progressbar) {
678                                 if (io.content_length) {
679                                         pdots(50, '.', total, total+ret,
680                                                         io.content_length);
681                                 } else {
682                                         int old = total / 1000000;
683                                         int new = (total+ret)/1000000;
684                                         while (old < new) {
685                                                 putc('.',stderr);
686                                         }
687                                 }
688                                 total += ret;
689                         }
690                         ret = fill_buffer(&io);
691                 } while (ret > 0);
692
693                 if (ret < 0) {
694                         fprintf(stderr, "%s read error %zd\n", uri.scheme, ret);
695                 }
696                 /* futimens(out, ...) */
697                 close(out);
698                 tls_buffer_free(&io.response);
699                 break;
700         }
701
702         if (use_tls) {
703                 tls_shutdown(clientssl);
704                 tls_free(clientssl);
705         }
706
707         close(sockfd);
708         if (progressbar && io.status_code == 200) {
709                 fprintf(stderr, "(%lu)", total);
710                 putc('\n',stderr);
711         }
712
713         return io.status_code == 200 ? 0 : EXIT_FAILURE;
714 }