]> pd.if.org Git - zpackage/blob - src/fetchurl.c
rework zpm-add
[zpackage] / src / fetchurl.c
#define _POSIX_C_SOURCE 200809L

#include <ctype.h>
#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netdb.h>
#include <string.h>
#include <strings.h>
#include <stdlib.h>
#include <unistd.h>
#include <signal.h>
#include <time.h>
#include <sys/stat.h>
#include <fcntl.h>

#include "tlse.h"
/* Debug trace helper: prints source file, function and line to stderr. */
#define MARK fprintf(stderr, "%s %s:%d\n", __FILE__, __func__, __LINE__)

/*
 * Components of a URI as split up by tls_parse_uri() and released by
 * tls_free_uri(); both are defined in another translation unit, so the
 * member order here must not change.
 * NOTE(review): whether absent components are left NULL is decided by
 * tls_parse_uri() — confirm there.
 */
struct tls_uri {
	char *scheme, *userinfo, *host, *port;
	char *path, *query, *fragment;
};

int tls_parse_uri(char *, struct tls_uri *);
void tls_free_uri(struct tls_uri *);

/* Returns a connected socket descriptor, or a negative value on failure. */
int open_tcp_connection(char *host, int port);
33
/*
 * Certificate verification callback installed when verifypolicy is 0
 * (the -k option): every chain is accepted without inspection.
 * Always returns 0.
 */
int verify_trust(struct TLSContext *context, struct TLSCertificate **chain,
		int len) {
	/* parameters are required by the callback signature only */
	(void)context;
	(void)chain;
	(void)len;

	return 0;
}
44
/* Digit table for hex(); only the lower-case part (indices 0-15) is used. */
static char hexchars[] = "0123456789abcdefABCDEF";

/*
 * Encode the len bytes at src as lower-case hexadecimal into dst.
 * dst must have room for 2 * len characters; no NUL terminator is added.
 */
static void hex(char *dst, uint8_t *src, size_t len) {
	size_t i;

	for (i = 0; i < len; i++) {
		uint8_t byte = src[i];

		dst[2 * i] = hexchars[byte >> 4];
		dst[2 * i + 1] = hexchars[byte & 0x0f];
	}
}
55
/*
 * Disabled (compiled out): decodes a string of hex digits into bytes,
 * the inverse of hex().  len is the number of hex digits read from src;
 * len / 2 bytes are stored in dst.
 * NOTE(review): nothing visible references this; candidate for deletion.
 */
#if 0
static void hexbin(uint8_t *dst, unsigned char *src, size_t len) {
	size_t i;
	int x;

	for (i=0; i<len; i+=2) {
		sscanf((const char *)src+i, "%02x", &x);
		dst[i/2] = x;
	}
}
#endif
67
68 char *my_getline(struct tls_buffer *b, int fd, size_t *size) {
69         char *loc = 0;
70         char buf[4096];
71         ssize_t rv = 0;
72
73         while (b->error == 0) {
74                 loc = memchr(b->buffer, '\n', b->len);
75                 if (loc) {
76                         *size = loc - b->buffer + 1;
77                         return b->buffer;
78                 } else {
79                         rv = read(fd, buf, sizeof buf);
80                         if (rv == -1) {
81                                 return 0;
82                         }
83                         if (rv == 0) {
84                                 break;
85                         }
86                         tls_buffer_append(b, buf, rv);
87                 }
88         }
89         if (rv == 0) {
90                 *size = b->len;
91                 return b->buffer;
92         }
93         return 0;
94 }
95
96 /*
97  * We use a trust on first use policy.  The trust DB is a simple
98  * file in /var/lib/zpm/known_hosts, or ~/.zpm/known_hosts, or ZPM_KNOWNHOSTS.
99  * if -k is given, no verification is done
100  */
101 int verify_first(struct TLSContext *context, struct TLSCertificate **chain, int
102                 certs) {
103         int err;
104         char *trustfile, *homedir = 0, *host, *fp;
105         unsigned char certhash[65];
106         int trustdb;
107         struct tls_buffer tbuf;
108
109         char *line = 0;
110         size_t len = 0;
111
112         if (certs == 0 || chain == 0) {
113                 return 1;
114         }
115
116         err = tls_certificate_is_valid(chain[0]);
117         if (err) {
118                 return err;
119         }
120
121         if (context->sni) {
122                 err = tls_certificate_valid_subject(chain[0], context->sni);
123                 if (err) {
124                         return err;
125                 }
126         }
127
128         hex(certhash, chain[0]->fp, 32);
129         certhash[64] = 0;
130
131         trustfile = getenv("ZPM_KNOWNHOSTS");
132         if (!trustfile) {
133                 if (geteuid() == 0) {
134                         trustfile = "/var/lib/zpm/known_hosts";
135                 } else {
136                         /* we could do this with a series of
137                          * openat() calls instead of building
138                          * up a string 
139                          */
140                         trustfile = getenv("HOME");
141                         if (!trustfile) {
142                                 fprintf(stderr, "home = %s\n", trustfile);
143                                 return 1;
144                         }
145                         len = snprintf(homedir, 0, "%s/.zpm/known_hosts", trustfile);
146                         homedir = malloc(len+1);
147                         if (!homedir) {
148                                 return 1;
149                         }
150                         len = snprintf(homedir, len+1, "%s/.zpm/known_hosts", trustfile);
151                         trustfile = homedir;
152                 }
153         }
154         /* cert is valid on its face, so check against the trust db */
155         trustdb = open(trustfile, O_RDWR|O_CREAT, 0600);
156         if (trustdb == -1) {
157                 fprintf(stderr, "cannot open trustdb %s: %s\n", trustfile, strerror(errno));
158                 if (homedir) {
159                         free(homedir);
160                 }
161                 return 1;
162         }
163
164         if (homedir) {
165                 free(homedir);
166         }
167
168         len = 0;
169         tls_buffer_init(&tbuf, 128);
170         do {
171                 char *off;
172
173                 tls_buffer_shift(&tbuf, len);
174                 line = my_getline(&tbuf, trustdb, &len);
175
176                 if (!line || !len) {
177                         break;
178                 }
179
180                 fp = line;
181                 while (isspace(*fp)) {
182                         fp++;
183                 }
184                 if (*fp == '#') {
185                         continue;
186                 }
187                 off = strchr(line, ':');
188                 if (!off) {
189                         continue;
190                 }
191                 host = off + 1;
192                 *off = 0;
193                 if (line[len-1] == '\n') {
194                         line[len-1] = 0;
195                 }
196
197                 if (strlen(fp) != 64) {
198                         continue;
199                 }
200
201                 if (len && line[len-1] == '\n') {
202                         line[len-1] = 0;
203                 }
204
205                 if (strcmp(context->sni, host) != 0) {
206                         continue;
207                 }
208
209                 int match = (memcmp(certhash, fp, 64) == 0); 
210
211                 close(trustdb);
212                 tls_buffer_free(&tbuf);
213                 return match ? no_error : bad_certificate;
214         } while (!tbuf.error);
215
216         /* got here, so we should be at EOF, so add this host to trust db */
217         lseek(trustdb, 0, SEEK_END);
218
219         /* re-use the buffer so we only do one write */
220         /* ignore errors, the cert is fine, we just can't update
221          * the trustdb if there's errors here
222          */
223         tbuf.len = 0;
224         tls_buffer_append(&tbuf, certhash, 64);
225         tls_buffer_append_byte(&tbuf, ':');
226         tls_buffer_append(&tbuf, context->sni, strlen(context->sni));
227         tls_buffer_append_byte(&tbuf, '\n');
228         write(trustdb, tbuf.buffer, tbuf.len);
229         close(trustdb);
230         tls_buffer_free(&tbuf);
231
232         return no_error;
233 }
234
235 int verify_roots(struct TLSContext *context, struct TLSCertificate **chain, int
236                 len) {
237         int i, err;
238
239         if (chain) {
240                 for (i = 0; i < len; i++) {
241                         struct TLSCertificate *certificate = chain[i];
242                         err = tls_certificate_is_valid(certificate);
243                         if (err) {
244                                 return err;
245                         }
246                 }
247         }
248         
249         err = tls_certificate_chain_is_valid(chain, len);
250         if (err) {
251                 return err;
252         }
253
254         if (len > 0 && context->sni) {
255                 err = tls_certificate_valid_subject(chain[0], context->sni);
256                 if (err) {
257                         return err;
258                 }
259         }
260
261         /* Perform certificate validation against ROOT CA */
262         err = tls_certificate_chain_is_valid_root(context, chain, len);
263         if (err) {
264                 return err;
265         }
266
267         return no_error;
268 }
269
270 struct io {
271         struct tls_buffer response;
272         struct TLSContext *tls;
273         int socket;
274         int status_code;
275         time_t last_modified;
276         time_t date;
277         size_t content_length;
278         char *redirect;
279 };
280
/*
 * Map an English month abbreviation to 1..12, comparing the first three
 * characters case-insensitively.  Returns 0 if m names no month.
 */
int month(char *m) {
	static const char *const months[] = {
		"Jan", "Feb", "Mar", "Apr", "May", "Jun",
		"Jul", "Aug", "Sep", "Oct", "Nov", "Dec"
	};
	int i;

	for (i = 0; i < 12; i++) {
		if (!strncasecmp(m, months[i], 3)) {
			return i + 1;
		}
	}
	return 0;
}

/*
 * Number of days from 1970-01-01 to the proleptic Gregorian date
 * y-m-d (m in 1..12).  Computed directly because mktime() would
 * interpret the broken-down time in the local time zone.
 */
static long days_from_civil(int y, int m, int d) {
	int era, yoe, doy, doe;

	if (m <= 2) {
		y--;	/* count the year from March so Feb 29 is last */
	}
	era = (y >= 0 ? y : y - 399) / 400;
	yoe = y - era * 400;
	doy = (153 * (m > 2 ? m - 3 : m + 9) + 2) / 5 + d - 1;
	doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
	return (long)era * 146097L + doe - 719468L;
}

/*
 * Parse an HTTP IMF-fixdate such as "Wed, 06 Feb 2019 10:06:05 GMT"
 * (RFC 7231 section 7.1.1.1) into seconds since the Unix epoch.  The
 * date is taken as UTC regardless of the local time zone.
 * Returns -1 if d is not in that form or a field is out of range; the
 * obsolete RFC 850 and asctime() forms are not recognized.
 */
time_t parse_date(char *d) {
	int h, mi, s, dom, Y, mon;
	char M[4];

	/* %3s keeps the month token within M */
	if (sscanf(d, "%*[a-zA-Z,] %d %3s %d %d:%d:%d",
				&dom, M, &Y, &h, &mi, &s) != 6) {
		return -1;
	}

	mon = month(M);
	if (mon == 0 || dom < 1 || dom > 31 || Y < 1 || Y > 9999
			|| h < 0 || h > 23 || mi < 0 || mi > 59
			|| s < 0 || s > 60) {
		return -1;
	}

	return (time_t)days_from_civil(Y, mon, dom) * 86400
		+ h * 3600L + mi * 60L + s;
}
314
315 char *find_header(struct io *io, char *header, size_t *len) {
316         char *eoh, *soh;
317         size_t hlen;
318
319         *len = 0;
320
321         hlen = strlen(header);
322         eoh = strstr(io->response.buffer, "\r\n\r\n");
323         if (!eoh) {
324                 return 0;
325         }
326         soh = io->response.buffer;
327         do {
328                 soh = strstr(soh, "\r\n");
329                 if (soh == eoh) {
330                         break;
331                 }
332                 soh += 2;
333                 if (!memcmp(soh, header, hlen)) {
334                         break;
335                 }
336         } while (soh < eoh);
337
338         if (soh >= eoh) {
339                 return 0;
340         }
341         eoh = strstr(soh, "\r\n");
342         soh += hlen;
343         while (soh < eoh && isspace(*soh)) {
344                 soh++;
345         }
346         *len = (eoh - soh);
347         return soh;
348 }
349
350 void parse_header(struct io *io) {
351         char *s = io->response.buffer;
352         int code = 0;
353         char *hval;
354         size_t hlen;
355
356         while (!isspace(*s)) {
357                 s++;
358         }
359         while (isspace(*s)) {
360                 s++;
361         }
362         code = strtol(s, 0, 10);
363         io->status_code = code;
364
365         hval = find_header(io, "Date:", &hlen);
366         if (hval) {
367                 hval[hlen] = 0;
368                 io->date = parse_date(hval);
369                 hval[hlen] = '\r';
370         }
371         hval = find_header(io, "Last-Modified:", &hlen);
372         if (hval) {
373                 hval[hlen] = 0;
374                 io->last_modified = parse_date(hval);
375                 hval[hlen] = '\r';
376         }
377
378         hval = find_header(io, "Content-Length:", &hlen);
379         if (hval) {
380                 hval[hlen] = 0;
381                 io->content_length = strtoul(hval, 0, 10);
382                 hval[hlen] = '\r';
383         }
384
385         switch (code) {
386                 case 301:
387                 case 302:
388                 case 303:
389                 case 307:
390                         hval = find_header(io, "Location:", &hlen);
391                         if (hval) {
392                                 io->redirect = strndup(hval, hlen);
393                         }
394                         break;
395                 default:
396                         break;
397         }
398
399 }
400
401 ssize_t fill_buffer(struct io *io) {
402         unsigned char buffer[4096];
403         ssize_t ret;
404
405         if (io->tls) {
406                 ret = tls_read(io->tls, buffer, sizeof buffer);
407         } else {
408                 ret = read(io->socket, buffer, sizeof buffer);
409         }
410
411         if (ret > 0) {
412                 tls_buffer_append(&io->response, buffer, ret);
413         }
414
415         return ret;
416 }
417
/*
 * Disabled (compiled out): unfinished draft of a line reader over
 * io->response.  It would not compile if enabled: it uses "io." where
 * "io->" is meant, and the final "if (eol) {" body is never completed.
 * NOTE(review): candidate for deletion; my_getline() covers line reading.
 */
#if 0
char *nextline(struct io *io) {
	char *eol = 0;;

	eol = memchr(io->response.buffer, '\n', io.response.size);
	while (eol == 0) {
		fill_buffer(io);
		eol = memchr(io->response.buffer, '\n', io.response.size);
	}
	if (eol) {

}
#endif
431
/*
 * Append one header line of the form "header: val\r\n" to buf.
 */
void append_header(struct tls_buffer *buf, char *header, char *val) {
	size_t namelen = strlen(header);
	size_t vallen = strlen(val);

	tls_buffer_append(buf, header, namelen);
	tls_buffer_append(buf, ": ", 2);
	tls_buffer_append(buf, val, vallen);
	tls_buffer_append(buf, "\r\n", 2);
}
438
/*
 * Append "header: <ts as an HTTP date in GMT>\r\n" to buf, e.g.
 * "Date: Wed, 06 Feb 2019 10:06:05 GMT".  If ts cannot be broken down
 * (gmtime_r() fails) or does not fit the format, the header is omitted
 * rather than passing a NULL struct tm to strftime().
 */
void append_timeheader(struct tls_buffer *buf, char *header, time_t ts) {
	char timestr[80];
	struct tm tm;

	if (!gmtime_r(&ts, &tm)) {
		return;
	}
	if (strftime(timestr, sizeof timestr, "%a, %d %b %Y %H:%M:%S GMT",
				&tm) == 0) {
		return;
	}
	append_header(buf, header, timestr);
}
448
/*
 * Advance a progress bar len cells wide: print ch to stderr once for
 * every cell crossed as progress moves from was to now out of total.
 * Both positions are clamped to total.  Does nothing when total is 0
 * (unknown size) or len is not positive, which previously divided by 0.
 */
static void pdots(int len, int ch, unsigned long was, unsigned long now,
		unsigned long total) {
	unsigned long from, to;

	if (total == 0 || len <= 0) {
		return;
	}
	if (was > total) {
		was = total;
	}
	if (now > total) {
		now = total;
	}

	from = (unsigned long)len * was / total;
	to = (unsigned long)len * now / total;
	while (from++ < to) {
		putc(ch, stderr);
	}
}
460
461 static void fake_header(struct io *io, int fd) {
462         struct stat st;
463         int code = 200, rv;
464         char *message, codestr[5], length[32];
465         struct tls_buffer *hdr = &io->response;
466
467         if (fd == -1) {
468                 switch (errno) {
469                         case EACCES: code = 403; break;
470                         case ENOENT: code = 404; break;
471                         default: code = 500; break;
472                 }
473         } else {
474                 rv = fstat(fd, &st);
475                 if (rv == -1) {
476                         code = 500;
477                 }
478         }
479
480         if (io->last_modified >= st.st_mtime) {
481                 code = 304;
482         }
483
484         switch (code) {
485                 case 200: message = "OK"; break;
486                 case 304: message = "Not Modified"; break;
487                 case 403: message = "Forbidden"; break;
488                 case 404: message = "Not Found"; break;
489                 case 500: message = "Internal Server Error"; break;
490                 default: break;
491         }
492         sprintf(codestr, "%0.3d ", code);
493         tls_buffer_append(hdr, "HTTP/1.1 ", 9);
494         tls_buffer_append(hdr, codestr, 4);
495         tls_buffer_append_str(hdr, message);
496         tls_buffer_append(hdr, "\r\n", 2);
497
498         append_timeheader(hdr, "Date", time(NULL));
499         append_header(hdr, "Server", "zpm-fetchurl/0.9");
500         if (code < 400) {
501                 append_timeheader(hdr, "Last-Modified", st.st_mtime);
502                 sprintf(length, "%zu", st.st_size);
503                 append_header(hdr, "Content-Length", length);
504         }
505
506         append_header(hdr, "Connection", "close");
507         append_header(hdr, "Content-Type", "application/octet-stream");
508         tls_buffer_append(hdr, "\r\n", 2);
509 }
510
511 int main(int ac, char *av[]) {
512         int sockfd, port = -1, rv;
513         ssize_t ret;
514         int option;
515 #if 0
516         char msg[] = "GET %s HTTP/1.1\r\nHost: %s:%i\r\nConnection: close\r\n\r\n";
517         char msg2[] = "GET %s HTTP/1.1\r\nHost: %s:%i\r\nLast-Modified: %s\r\nConnection: close\r\n\r\n";
518         char msg_buffer[1024];
519 #endif
520         char *req_file = 0;
521         char *host = 0;
522         struct tls_uri uri;
523         char *outfile = 0;
524         int raw = 0, head = 0;
525         int out = 1;
526         int use_tls = 0;
527         struct io io = { {0}, 0, -1, 0, 0, 0, 0, 0 };
528         struct TLSContext *clientssl = 0;
529         int failsilent = 0;
530         char *lmfile = 0;
531         int progressbar = 0;
532         struct tls_buffer request;
533         char lmtime[80];
534         char *eoh = 0;
535         size_t total = 0;
536         size_t header_len;
537         char *url = 0;
538         int redirs = 0, redirlimit = 50, printstatus = 0;
539         int verifypolicy = 1;
540
541         ltc_mp = tfm_desc;
542
543         while ((option = getopt(ac, av, "o:rIfz:#R:SkK")) != -1) {
544                 switch (option) {
545                         case 'o': outfile = optarg; break;
546                         case 'S': printstatus = 1; head = 1; break;
547                         case 'k': verifypolicy = 0; break;
548                         case 'K': verifypolicy = 2; break;
549                         case 'I': head = 1;
550                         case 'r': raw = 1; break;
551                         case 'f': failsilent = 1; break;
552                         case 'z': lmfile = optarg; break;
553                         case 'R': redirlimit = strtol(optarg, 0, 10); break;
554                         case '#': progressbar = 1; break;
555                         default:
556                                   exit(EXIT_FAILURE);
557                                   break;
558                 }
559         }
560
561         if (ac < optind) {
562                 fprintf(stderr, "Usage: %s uri\n", av[0]);
563                 exit(EXIT_FAILURE);
564         }
565
566         io.last_modified = 0;
567         if (lmfile) {
568                 struct stat st;
569                 int rv;
570                 struct tm *mtime;
571                 time_t ts;
572
573                 rv = stat(lmfile, &st);
574                 if (rv == -1) {
575                         perror("stat failed:");
576                         exit(EXIT_FAILURE);
577                 }
578                 ts = st.st_mtime;
579                 io.last_modified = ts;
580                 mtime = gmtime(&ts);
581                 strftime(lmtime, sizeof lmtime, "%a, %d %b %Y %H:%M:%S GMT", mtime);
582         }
583
584         url = strdup(av[optind]);
585         if (!url) {
586                 exit(EXIT_FAILURE);
587         }
588
589         if (outfile) {
590                 out = open(outfile, O_WRONLY|O_CREAT, 0600);
591                 if (out == -1) {
592                         perror("can't open output file:");
593                         exit(EXIT_FAILURE);
594                 }
595         }
596
597         signal(SIGPIPE, SIG_IGN);
598
599         tls_buffer_init(&io.response, 0);
600         tls_buffer_init(&request, 128);
601
602         while (redirs++ <= redirlimit) {
603                 tls_free_uri(&uri);
604                 io.response.len = 0;
605                 request.len = 0;
606
607                 tls_parse_uri(url, &uri);
608                 host = uri.host;
609                 port = atoi(uri.port);
610                 req_file = uri.path;
611
612                 /* construct request */
613                 if (head) {
614                         tls_buffer_append(&request, "HEAD ", 5);
615                 } else {
616                         tls_buffer_append(&request, "GET ", 4);
617                 }
618                 tls_buffer_append(&request, uri.path, strlen(uri.path));
619                 tls_buffer_append(&request, " HTTP/1.1\r\n", 11);
620
621                 append_header(&request, "Host", host);
622                 append_header(&request, "Connection", "close");
623                 if (lmfile) {
624                         append_header(&request, "If-Modified-Since", lmtime);
625                 }
626                 tls_buffer_append(&request, "\r\n", 2);
627
628                 if (!strcmp(uri.scheme, "https")) {
629                         use_tls = 1;
630
631                         clientssl = tls_create_context(TLS_CLIENT, TLS_V12);
632
633                         /* optionally, we can set a certificate validation
634                          * callback function if set_verify is not called, and
635                          * root ca is set, `tls_default_verify` will be used
636                          * (does exactly what `verify` does in this example)
637                          */
638                         if (verifypolicy == 2) {
639                                 char *cert_path = 0;
640                                 cert_path = getenv("ZPM_CERTFILE");
641                                 if (!cert_path) {
642                                         cert_path = "/var/lib/zpm/roots.pem";
643                                 }
644                                 rv = tls_load_root_file(clientssl, cert_path);
645                                 if (rv == -1) {
646                                         fprintf(stderr, "Error loading root certs\n");
647                                         return 1;
648                                 }
649                                 tls_set_verify(clientssl, verify_roots);
650                         } else if (verifypolicy == 1) {
651                                 tls_set_verify(clientssl, verify_first);
652                         } else {
653                                 tls_set_verify(clientssl, verify_trust);
654                         }
655
656                         if (!clientssl) {
657                                 fprintf(stderr, "Error initializing client context\n");
658                                 return -1;
659                         }
660                         tls_sni_set(clientssl, uri.host);
661                         clientssl->sync = 1;
662                         io.tls = clientssl;
663                         sockfd = open_tcp_connection(host, port);
664                         if (sockfd < 0) {
665                                 perror("can't open connection");
666                                 exit(EXIT_FAILURE);
667                         }
668                         tls_set_fd(clientssl, sockfd);
669                         if ((rv = tls_connect(clientssl)) != 1) {
670                                 fprintf(stderr, "Handshake Error %i\n", rv);
671                                 return 1;
672                         }
673                         ret = tls_write(clientssl, request.buffer, request.len);
674                 } else if (!strcmp(uri.scheme, "http")) {
675                         sockfd = open_tcp_connection(host, port);
676                         if (sockfd < 0) {
677                                 perror("can't open connection");
678                                 exit(EXIT_FAILURE);
679                         }
680                         ret = write(sockfd, request.buffer, request.len);
681                 } else if (!strcmp(uri.scheme, "file")) {
682                         sockfd = open(uri.path, O_RDONLY);
683                         fake_header(&io, sockfd);
684                         ret = 0;
685                 } else {
686                         fprintf(stderr, "scheme %s unknown\n", uri.scheme);
687                         exit(EXIT_FAILURE);
688                 }
689
690                 if (ret == -1) {
691                         fprintf(stderr, "unable to write http request: %s\n", strerror(errno));
692                         exit(EXIT_FAILURE);
693                 }
694
695                 io.socket = sockfd;
696
697                 do {
698                         if (io.response.len >= 4) {
699                                 eoh = strstr(io.response.buffer, "\r\n\r\n");
700                         }
701                         if (!eoh) {
702                                 ret = fill_buffer(&io);
703                                 if (ret <= 0) {
704                                         break;
705                                 }
706                         }
707                 } while (!eoh);
708
709                 if (!eoh) {
710                         /* never got (complet) header */
711                         fprintf(stderr, "incomplete response to %s\n", av[optind]);
712                         exit(EXIT_FAILURE);
713                 }
714
715                 header_len = (size_t)(eoh - io.response.buffer) + 4;
716                 parse_header(&io);
717
718                 switch (io.status_code) {
719                         case 301:
720                         case 302:
721                         case 303:
722                         case 307:
723                                 free(url);
724                                 url = strdup(io.redirect);
725                                 continue;
726                                 break;
727                 }
728
729                 if (printstatus) {
730                         printf("%d\n", io.status_code);
731                         break;
732                 }
733
734                 if (!raw) {
735                         tls_buffer_shift(&io.response, header_len);
736                 }
737                 if (head) {
738                         io.response.len -= 2;
739                 }
740
741                 if (progressbar) {
742                         if (io.content_length) {
743                                 fprintf(stderr, "(%lu) ", io.content_length);
744                         }
745                 }
746
747                 do {
748                         write(out, io.response.buffer, io.response.len);
749                         ret = io.response.len;
750                         io.response.len = 0;
751
752                         if (progressbar) {
753                                 if (io.content_length) {
754                                         pdots(50, '.', total, total+ret,
755                                                         io.content_length);
756                                 } else {
757                                         int old = total / 1000000;
758                                         int new = (total+ret)/1000000;
759                                         while (old < new) {
760                                                 putc('.',stderr);
761                                         }
762                                 }
763                                 total += ret;
764                         }
765                         if (head) {
766                                 break;
767                         }
768                         ret = fill_buffer(&io);
769                 } while (ret > 0);
770
771                 if (ret < 0) {
772                         fprintf(stderr, "%s read error %zd\n", uri.scheme, ret);
773                 }
774                 /* futimens(out, ...) */
775                 close(out);
776                 tls_buffer_free(&io.response);
777                 break;
778         }
779
780         if (use_tls) {
781                 tls_shutdown(clientssl);
782                 tls_free(clientssl);
783         }
784
785         close(sockfd);
786         if (progressbar && io.status_code == 200) {
787                 fprintf(stderr, "(%lu)", total);
788                 putc('\n',stderr);
789         }
790
791         return io.status_code == 200 ? 0 : EXIT_FAILURE;
792 }