]> pd.if.org Git - zpackage/blobdiff - src/fetchurl.c
remove stray debug fprintf
[zpackage] / src / fetchurl.c
index 21e2e661e7a85fa169e81f792a9154acdb3ad8eb..5ff67ddb2ddf257b69122159251cdcb02690c1eb 100644 (file)
@@ -16,6 +16,8 @@
 
 #include "tlse.h"
 
+#define DEBUG(lvl, ...) if (debuglevel >= lvl ) { fprintf(stderr, __VA_ARGS__); }
+
 struct tls_uri {
        char *scheme;
        char *userinfo;
@@ -55,6 +57,8 @@ static void hex(char *dst, uint8_t *src, size_t len) {
        }
 }
 
+static int debuglevel = 0;
+
 #if 0
 static void hexbin(uint8_t *dst, unsigned char *src, size_t len) {
        size_t i;
@@ -276,14 +280,23 @@ int verify_roots(struct TLSContext *context, struct TLSCertificate **chain, int
 
 struct io {
        struct tls_buffer response;
+       struct tls_buffer chunkbuf;
        struct TLSContext *tls;
        int socket;
+       int chunked;
+       int chunknum;
+       size_t chunksize;
+       size_t chunkleft;
+       size_t chunktotal;
+       size_t chunkbytesread;
        int status_code;
        time_t last_modified;
        time_t date;
        size_t content_length;
+       size_t received;
        char *redirect;
 };
+ssize_t unchunk(struct io *io);
 
 int month(char *m) {
        char *months[] = {"Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul",
@@ -326,18 +339,26 @@ char *find_header(struct io *io, char *header, size_t *len) {
        *len = 0;
 
        hlen = strlen(header);
+
+       /* TODO can't do this, buffer may not be zero terminated */
        eoh = strstr(io->response.buffer, "\r\n\r\n");
        if (!eoh) {
                return 0;
        }
+
        soh = io->response.buffer;
        do {
+               /* skip the first line for some reason */
                soh = strstr(soh, "\r\n");
                if (soh == eoh) {
                        break;
                }
                soh += 2;
-               if (!memcmp(soh, header, hlen)) {
+               /* done if not enough room */
+               if (hlen > (size_t)(eoh - soh)) {
+                       break;
+               }
+               if (!strncasecmp(soh, header, hlen)) {
                        break;
                }
        } while (soh < eoh);
@@ -389,11 +410,22 @@ void parse_header(struct io *io) {
                hval[hlen] = '\r';
        }
 
+       hval = find_header(io, "Transfer-Encoding:", &hlen);
+       if (hval) {
+               hval[hlen] = 0;
+               io->content_length = strtoul(hval, 0, 10);
+               if (!strcmp(hval, "chunked")) {
+                       io->chunked = 1;
+               }
+               hval[hlen] = '\r';
+       }
+
        switch (code) {
                case 301:
                case 302:
                case 303:
                case 307:
+                       DEBUG(1, "looking for Location header\n");
                        hval = find_header(io, "Location:", &hlen);
                        if (hval) {
                                io->redirect = strndup(hval, hlen);
@@ -405,23 +437,136 @@ void parse_header(struct io *io) {
 
 }
 
+/* fill buffer needs to put bytes into the response buffer
+ * if the transfer encoding is chunked, it will need to
+ * put the bytes into the chunkbuf first, then call
+ * unchunk.  if unchunk return 0, then it needs more data,
+ * otherwise unchunk returns the number of bytes transferred
+ */
+
 ssize_t fill_buffer(struct io *io) {
        unsigned char buffer[4096];
-       ssize_t ret;
+       ssize_t ret = 0;
 
-       if (io->tls) {
-               ret = tls_read(io->tls, buffer, sizeof buffer);
-       } else {
-               ret = read(io->socket, buffer, sizeof buffer);
-       }
+       ret = unchunk(io);
 
-       if (ret > 0) {
-               tls_buffer_append(&io->response, buffer, ret);
+       while (ret == 0) {
+               if (io->tls) {
+                       ret = tls_read(io->tls, buffer, sizeof buffer);
+               } else {
+                       ret = read(io->socket, buffer, sizeof buffer);
+               }
+
+               if (ret <= 0) {
+                       break;
+               }
+
+               if (io->chunked) {
+                       tls_buffer_append(&io->chunkbuf, buffer, ret);
+                       //fwrite(buffer, ret, 1, stderr);
+                       ret = unchunk(io);
+                       if (ret != 0 || io->chunksize == 0) {
+                               break;
+                       }
+               } else {
+                       tls_buffer_append(&io->response, buffer, ret);
+                       break;
+               }
        }
 
        return ret;
 }
 
+/* essentially memmem */
+void *lookfor(const void *buf, size_t buflen, const void *pattern, size_t len) {
+      const char *bf = buf;
+      const char *pt = pattern;
+      const char *p = bf;
+
+      while (len <= (buflen - (p - bf))) {
+            if ((p = memchr(p, *pt, buflen - (p - bf))) != 0) {
+                  if (memcmp(p, pattern, len) == 0) {
+                        return (void *)p;
+                 } else {
+                         p++;
+                 }
+            } else {
+                   break;
+           }
+      }
+      return NULL;
+}
+
+/* returns read chunksize, unshifts the line */
+ssize_t read_chunksize(struct io *io) {
+       char *cr;
+       ssize_t cs;
+
+       //fwrite(io->chunkbuf.buffer, io->chunkbuf.len, 1, stderr);
+       
+       /* there could be up to two leading bytes */
+       if (io->chunkbuf.len >= 2 && io->chunkbuf.buffer[0] == '\r' && io->chunkbuf.buffer[1] == '\n') {
+               tls_buffer_shift(&io->chunkbuf, 2);
+       }
+
+       cr = lookfor(io->chunkbuf.buffer, io->chunkbuf.len, "\r\n", 2);
+
+       if (cr == 0) {
+               return -1;
+       }
+
+       cs = strtol(io->chunkbuf.buffer, 0, 16);
+       tls_buffer_shift(&io->chunkbuf, cr - io->chunkbuf.buffer + 2);
+
+       return cs;
+}
+
+/* unchunk's job is to move bytes from the chunk buf to the response buf */
+/* return bytes from chunk, 0 if unable.  once last chunk, changed chunked
+ * to 0?
+ */
+ssize_t unchunk(struct io *io) {
+       ssize_t bytes_to_move = 0;
+       ssize_t chunksize;
+
+       if (!io || !io->chunked) {
+               return 0;
+       }
+
+       if (io->chunkleft == 0) {
+               chunksize = read_chunksize(io);
+               if (chunksize == -1) {
+                       return 0;
+               }
+               io->chunksize = chunksize;
+               if (io->chunksize == 0) {
+                       /* end of chunked data */
+                       io->chunked = 0;
+                       return 0;
+               }
+               io->chunknum++;
+               io->chunkleft = io->chunksize;
+               io->chunktotal += io->chunksize;
+       }
+
+       if (io->chunkbuf.len == 0) {
+               /* need more bytes */
+               return 0;
+       }
+
+       bytes_to_move = io->chunkbuf.len < io->chunkleft ? io->chunkbuf.len : io->chunkleft;
+
+       tls_buffer_append(&io->response, io->chunkbuf.buffer, bytes_to_move);
+       io->chunkleft -= bytes_to_move;
+       io->chunkbytesread += bytes_to_move;
+
+       /* chunk is terminated with a crlf */
+       //tls_buffer_shift(&io->chunkbuf, bytes_to_move + io->chunkleft ? 0 : 2);
+       tls_buffer_shift(&io->chunkbuf, bytes_to_move);
+
+       return bytes_to_move;
+}
+
 #if 0
 char *nextline(struct io *io) {
        char *eol = 0;;
@@ -578,7 +723,7 @@ int main(int ac, char *av[]) {
        int raw = 0, head = 0;
        int out = 1; /* output file descriptor */
        int use_tls = 0;
-       struct io io = { {0}, 0, -1, 0, 0, 0, 0, 0 };
+       struct io io = { {0}, {0}, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
        struct TLSContext *clientssl = 0;
        int failsilent = 0;
        char *lmfile = 0;
@@ -590,13 +735,14 @@ int main(int ac, char *av[]) {
        size_t total = 0;
        size_t header_len;
        char *url = 0;
-       int redirs = 0, redirlimit = 50, printstatus = 0;
+       int redirs = 0, redirlimit = 50, printstatus = 0, showreq = 0;
        int verifypolicy = 1, calcoutfile = 0, ifnewer = 0;
 
        ltc_mp = tfm_desc;
 
-       while ((option = getopt(ac, av, "o:OrIfz:np#R:SkKU:")) != -1) {
+       while ((option = getopt(ac, av, "do:OrIfz:np#RL:SkKU:")) != -1) {
                switch (option) {
+                       case 'd': debuglevel++; break;
                        case 'o': outfile = optarg; break;
                        case 'O': calcoutfile = 1; break;
                        case 'S': printstatus = 1; head = 1; break;
@@ -605,10 +751,11 @@ int main(int ac, char *av[]) {
                        case 'U': user_agent = optarg; break;
                        case 'I': head = 1;
                        case 'r': raw = 1; break;
+                       case 'R': showreq = 1; break;
                        case 'f': failsilent = 1; break;
                        case 'z': lmfile = optarg; break;
                        case 'n': ifnewer = 1; break;
-                       case 'R': redirlimit = strtol(optarg, 0, 10); break;
+                       case 'L': redirlimit = strtol(optarg, 0, 10); break;
                        case 'p':
                        case '#': progressbar = 1; break;
                        default:
@@ -671,6 +818,7 @@ int main(int ac, char *av[]) {
        while (redirs++ <= redirlimit) {
                tls_free_uri(&uri);
                io.response.len = 0;
+               io.chunked = 0;
                request.len = 0;
                eoh = 0;
 
@@ -697,6 +845,7 @@ int main(int ac, char *av[]) {
                        append_header(&request, "User-Agent", user_agent);
                }
                append_header(&request, "Accept", "*/*");
+               //append_header(&request, "Accept-Encoding", "chunked, identity;q=0.5");
                append_header(&request, "Connection", "close");
                if (lmfile) {
                        append_header(&request, "If-Modified-Since", lmtime);
@@ -706,6 +855,7 @@ int main(int ac, char *av[]) {
                if (!strcmp(uri.scheme, "https")) {
                        use_tls = 1;
 
+                       DEBUG(1, "creating tls context\n");
                        clientssl = tls_create_context(TLS_CLIENT, TLS_V12);
 
                        /* optionally, we can set a certificate validation
@@ -724,10 +874,14 @@ int main(int ac, char *av[]) {
                                        fprintf(stderr, "Error loading root certs\n");
                                        return 1;
                                }
+                               DEBUG(1, "verifying ssl cert via roots\n");
                                tls_set_verify(clientssl, verify_roots);
                        } else if (verifypolicy == 1) {
+                               DEBUG(1, "verifying ssl cert via first use\n");
                                tls_set_verify(clientssl, verify_first);
+                               DEBUG(1, "verified ssl cert via first use\n");
                        } else {
+                               DEBUG(1, "verifying ssl cert via trust\n");
                                tls_set_verify(clientssl, verify_trust);
                        }
 
@@ -736,9 +890,11 @@ int main(int ac, char *av[]) {
                                return -1;
                        }
                        tls_sni_set(clientssl, uri.host);
+                       DEBUG(1, "set sni to %s\n", uri.host);
                        clientssl->sync = 1;
                        io.tls = clientssl;
                        sockfd = open_tcp_connection(host, port);
+                       DEBUG(1, "opened tcp socket fd %d\n", sockfd);
                        if (sockfd < 0) {
                                perror("can't open connection");
                                exit(EXIT_FAILURE);
@@ -765,6 +921,7 @@ int main(int ac, char *av[]) {
                        exit(EXIT_FAILURE);
                }
 
+               DEBUG(1, "wrote http request\n");
                if (ret == -1) {
                        fprintf(stderr, "unable to write http request: %s\n", strerror(errno));
                        exit(EXIT_FAILURE);
@@ -772,26 +929,33 @@ int main(int ac, char *av[]) {
 
                io.socket = sockfd;
 
+               eoh = 0;
                do {
                        if (io.response.len >= 4) {
                                eoh = strstr(io.response.buffer, "\r\n\r\n");
                        }
                        if (!eoh) {
+                               DEBUG(1, "filling buffer\n");
                                ret = fill_buffer(&io);
                                if (ret <= 0) {
                                        break;
                                }
                        }
                } while (!eoh);
+               DEBUG(1, "got response\n");
 
                if (!eoh) {
-                       /* never got (complet) header */
-                       fprintf(stderr, "incomplete response to %s\n", av[optind]);
+                       /* never got (complete) header */
+                       fprintf(stderr, "incomplete response (ret = %zd) to %s\n", ret, url);
+                       fprintf(stderr, "have:\n");
+                       fwrite(io.response.buffer, io.response.len, 1, stderr);
                        exit(EXIT_FAILURE);
                }
 
                header_len = (size_t)(eoh - io.response.buffer) + 4;
+
                parse_header(&io);
+               DEBUG(1, "parsed response header, code %d\n", io.status_code);
 
                switch (io.status_code) {
                        case 304:
@@ -801,8 +965,10 @@ int main(int ac, char *av[]) {
                        case 302:
                        case 303:
                        case 307:
+                               DEBUG(1, "redirecting to %s\n", io.redirect);
                                free(url);
                                url = strdup(io.redirect);
+                               DEBUG(1, "redirecting to %s\n", url);
                                close(io.socket);
                                continue;
                                break;
@@ -812,58 +978,96 @@ int main(int ac, char *av[]) {
                        printf("%d\n", io.status_code);
                        break;
                }
-
-               if (!raw) {
-                       tls_buffer_shift(&io.response, header_len);
+               if (showreq) {
+                       fwrite(request.buffer, request.len, 1, stderr);
                }
+
                if (head) {
                        io.response.len -= 2;
+                       write(out, io.response.buffer, io.response.len);
+                       break;
                }
 
-               if (progressbar) {
-                       if (io.content_length) {
-                               fprintf(stderr, "(%lu) ", io.content_length);
-                       }
+               if (io.status_code == 304) {
+                       break;
                }
 
                if (outfile) {
-                       out = open(outfile, O_WRONLY|O_CREAT, 0600);
+                       out = open(outfile, O_WRONLY|O_CREAT|O_TRUNC, 0600);
                        if (out == -1) {
                                perror("can't open output file:");
                                exit(EXIT_FAILURE);
                        }
                }
 
+               if (progressbar) {
+                       if (io.content_length) {
+                               fprintf(stderr, "(%lu) ", io.content_length);
+                       }
+               }
+
+               if (raw) {
+                       write(out, io.response.buffer, header_len);
+               }
+               tls_buffer_shift(&io.response, header_len);
+
+               if (io.chunked) {
+                       /* we've written out the head if needed, so
+                        * what's in the response buffer is the
+                        * chunked encoding, so just reassign that
+                        * to the chunkbuf and reinit */
+                       io.chunkbuf = io.response;
+                       tls_buffer_init(&io.response, 0);
+                       /* and put whatever we've got into the response
+                        * buffer, may not be needed, fill buffer
+                        * can handle it.
+                        */ 
+                       //unchunk(&io);
+               }
+
                do {
-                       write(out, io.response.buffer, io.response.len);
-                       ret = io.response.len;
-                       io.response.len = 0;
+                       size_t before = io.received;
+                       if (io.response.len) {
+                               if (io.content_length && io.response.len + io.received > io.content_length) {
+                                       io.response.len = io.content_length - io.received;
+                                       /* we just ignore trailing garbage */
+                               }
+                               write(out, io.response.buffer, io.response.len);
+                               io.received += io.response.len;
+                               ret = io.response.len;
+                               io.response.len = 0;
+                       }
 
                        if (progressbar) {
                                if (io.content_length) {
-                                       pdots(50, '.', total, total+ret,
+                                       pdots(50, '.', before, io.received,
                                                        io.content_length);
                                } else {
                                        putc('\r', stderr);
-                                       fprintf(stderr, "%zu", total+ret);
+                                       fprintf(stderr, "%zu", io.received);
                                }
                                total += ret;
                        }
                        if (head) {
                                break;
                        }
+                       if (io.content_length && io.received >= io.content_length) {
+                               break;
+                       }
                        ret = fill_buffer(&io);
                } while (ret > 0);
 
+               //fprintf(stderr, "total received: %zu/%zu\n", io.received, io.content_length);
                if (ret < 0) {
                        fprintf(stderr, "%s read error %zd\n", uri.scheme, ret);
                }
-               struct timespec ts[2];
-               ts[0].tv_sec = 0; ts[0].tv_nsec = UTIME_OMIT;
-               ts[1].tv_sec = io.last_modified;
-               ts[1].tv_nsec = 0;
-
-               futimens(out, ts);
+               if (io.last_modified != 0) {
+                       struct timespec ts[2];
+                       ts[0].tv_sec = 0; ts[0].tv_nsec = UTIME_OMIT;
+                       ts[1].tv_sec = io.last_modified;
+                       ts[1].tv_nsec = 0;
+                       futimens(out, ts);
+               }
                close(out);
                tls_buffer_free(&io.response);
                break;
@@ -876,8 +1080,12 @@ int main(int ac, char *av[]) {
 
        close(sockfd);
        if (progressbar && io.status_code == 200) {
-               fprintf(stderr, "(%lu)", total);
-               putc('\n',stderr);
+               if (io.received == io.content_length || io.content_length == 0) {
+                       fprintf(stderr, " done\n");
+               } else if (io.content_length != io.received) {
+                       fprintf(stderr, "failed (%zu bytes read)\n", total);
+                       io.status_code = 531; /* non official code */
+               }
        }
 
        return io.status_code < 400 ? 0 : EXIT_FAILURE;