From: Kirill Isakov Date: Sat, 17 Jul 2021 12:17:11 +0000 (+0600) Subject: Allow running sptps_test on Windows X-Git-Url: https://git.tinc-vpn.org/git/browse?a=commitdiff_plain;h=ada609f3ab838fdcb522de54510c414452be5950;p=tinc Allow running sptps_test on Windows On Windows, you're not supposed to call select() on anything except proper BSD sockets, so we can't reuse the same select() loop that's been working fine on every other operating system. This is a hack which reads stdin in a separate thread and pushes data to the main through a TCP socket, which can then be used with select() instead of reading stdin directly. --- diff --git a/src/Makefile.am b/src/Makefile.am index 12261fff..4b266a73 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -272,6 +272,10 @@ tincd_LDADD = $(MINIUPNPC_LIBS) tincd_LDFLAGS = -pthread endif +if MINGW +sptps_test_LDFLAGS = -pthread +endif + tinc_LDADD = $(READLINE_LIBS) $(CURSES_LIBS) sptps_speed_LDADD = -lrt diff --git a/src/sptps_test.c b/src/sptps_test.c index ea81a28d..87f9b51b 100644 --- a/src/sptps_test.c +++ b/src/sptps_test.c @@ -25,11 +25,23 @@ #include +#ifdef HAVE_MINGW +#include +#endif + #include "crypto.h" #include "ecdsa.h" #include "sptps.h" #include "utils.h" +#ifndef MAX +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#endif + +#ifndef HAVE_MINGW +#define closesocket(s) close(s) +#endif + // Symbols necessary to link with logger.o bool send_request(void *c, const char *msg, ...) { (void)c; @@ -125,6 +137,155 @@ static void usage() { fprintf(stderr, "Report bugs to tinc@tinc-vpn.org.\n"); } +#ifdef HAVE_MINGW + +int stdin_sock_fd = -1; + +// Windows does not allow calling select() on anything but sockets. Therefore, +// to keep the same code as on other operating systems, we have to put a +// separate thread between the stdin and the sptps loop way below. This thread +// reads stdin and sends its content to the main thread through a TCP socket, +// which can be properly select()'ed. +void *stdin_reader_thread(void *arg) { + struct sockaddr_in sa; + socklen_t sa_size = sizeof(sa); + + while(true) { + int peer_fd = accept(stdin_sock_fd, (struct sockaddr *) &sa, &sa_size); + + if(peer_fd < 0) { + fprintf(stderr, "accept() failed: %s\n", strerror(errno)); + continue; + } + + if(verbose) { + fprintf(stderr, "New connection received from :%d\n", ntohs(sa.sin_port)); + } + + char buf[1024]; + ssize_t nread; + + while((nread = read(STDIN_FILENO, buf, sizeof(buf))) > 0) { + if(verbose) { + fprintf(stderr, "Read %lld bytes from input\n", nread); + } + + char *start = buf; + ssize_t nleft = nread; + + while(nleft) { + ssize_t nsend = send(peer_fd, start, nleft, 0); + + if(nsend < 0) { + if(sockwouldblock(sockerrno)) { + continue; + } + + break; + } + + start += nsend; + nleft -= nsend; + } + + if(nleft) { + fprintf(stderr, "Could not send data: %s\n", strerror(errno)); + break; + } + + if(verbose) { + fprintf(stderr, "Sent %lld bytes to peer\n", nread); + } + } + + closesocket(peer_fd); + } + + closesocket(stdin_sock_fd); + stdin_sock_fd = -1; +} + +int start_input_reader() { + if(stdin_sock_fd != -1) { + fprintf(stderr, "stdin thread can only be started once.\n"); + return -1; + } + + stdin_sock_fd = socket(AF_INET, SOCK_STREAM, 0); + + if(stdin_sock_fd < 0) { + fprintf(stderr, "Could not create server socket: %s\n", strerror(errno)); + return -1; + } + + struct sockaddr_in serv_sa; + + memset(&serv_sa, 0, sizeof(serv_sa)); + + serv_sa.sin_family = AF_INET; + + serv_sa.sin_addr.s_addr = htonl(0x7f000001); // 127.0.0.1 + + int res = bind(stdin_sock_fd, (struct sockaddr *)&serv_sa, sizeof(serv_sa)); + + if(res < 0) { + fprintf(stderr, "Could not bind socket: %s\n", strerror(errno)); + goto server_err; + } + + if(listen(stdin_sock_fd, 1) < 0) { + fprintf(stderr, "Could not listen: %s\n", strerror(errno)); + goto server_err; + } + + struct sockaddr_in connect_sa; + + socklen_t addr_len = sizeof(connect_sa); + + if(getsockname(stdin_sock_fd, (struct sockaddr *)&connect_sa, &addr_len) < 0) { + fprintf(stderr, "Could not determine the address of the stdin thread socket\n"); + goto server_err; + } + + if(verbose) { + fprintf(stderr, "stdin thread is listening on :%d\n", ntohs(connect_sa.sin_port)); + } + + pthread_t th; + int err = pthread_create(&th, NULL, stdin_reader_thread, NULL); + + if(err) { + fprintf(stderr, "Could not start reader thread: %s\n", strerror(err)); + goto server_err; + } + + int client_fd = socket(AF_INET, SOCK_STREAM, 0); + + if(client_fd < 0) { + fprintf(stderr, "Could not create client socket: %s\n", strerror(errno)); + return -1; + } + + if(connect(client_fd, (struct sockaddr *)&connect_sa, sizeof(connect_sa)) < 0) { + fprintf(stderr, "Could not connect: %s\n", strerror(errno)); + closesocket(client_fd); + return -1; + } + + return client_fd; + +server_err: + + if(stdin_sock_fd != -1) { + closesocket(stdin_sock_fd); + stdin_sock_fd = -1; + } + + return -1; +} + +#endif // HAVE_MINGW + int main(int argc, char *argv[]) { program_name = argv[0]; bool initiator = false; @@ -219,7 +380,7 @@ int main(int argc, char *argv[]) { initiator = true; } - srand(time(NULL)); + srand(getpid()); #ifdef HAVE_LINUX @@ -364,6 +525,21 @@ int main(int argc, char *argv[]) { return 1; } +#ifdef HAVE_MINGW + + if(!readonly) { + in = start_input_reader(); + + if(in < 0) { + fprintf(stderr, "Could not init stdin reader thread\n"); + return 1; + } + } + +#endif + + int max_fd = MAX(sock, in); + while(true) { if(writeonly && readonly) { break; @@ -374,21 +550,23 @@ int main(int argc, char *argv[]) { fd_set fds; FD_ZERO(&fds); -#ifndef HAVE_MINGW if(!readonly && s.instate) { FD_SET(in, &fds); } -#endif FD_SET(sock, &fds); - if(select(sock + 1, &fds, NULL, NULL, NULL) <= 0) { + if(select(max_fd + 1, &fds, NULL, NULL, NULL) <= 0) { return 1; } if(FD_ISSET(in, &fds)) { +#ifdef HAVE_MINGW + ssize_t len = recv(in, buf, readsize, 0); +#else ssize_t len = read(in, buf, readsize); +#endif if(len < 0) { fprintf(stderr, "Could not read from stdin: %s\n", strerror(errno)); @@ -396,6 +574,11 @@ int main(int argc, char *argv[]) { } if(len == 0) { +#ifdef HAVE_MINGW + shutdown(in, SD_SEND); + closesocket(in); +#endif + if(quit) { break; } @@ -469,5 +652,7 @@ int main(int argc, char *argv[]) { return 1; } + closesocket(sock); + return 0; }