Fix compiler warning.
[tinc] / src / sptps_test.c
1 /*
2     sptps_test.c -- Simple Peer-to-Peer Security test program
3     Copyright (C) 2011-2022 Guus Sliepen <guus@tinc-vpn.org>
4
5     This program is free software; you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation; either version 2 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License along
16     with this program; if not, write to the Free Software Foundation, Inc.,
17     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19
20 #include "system.h"
21
22 #ifdef HAVE_LINUX
23 #include <linux/if_tun.h>
24 #endif
25
26 #include <getopt.h>
27
28 #ifdef HAVE_MINGW
29 #include <pthread.h>
30 #endif
31
32 #include "crypto.h"
33 #include "ecdsa.h"
34 #include "meta.h"
35 #include "protocol.h"
36 #include "sptps.h"
37 #include "utils.h"
38 #include "names.h"
39
40 #ifndef HAVE_MINGW
41 #define closesocket(s) close(s)
42 #endif
43
44 // Symbols necessary to link with logger.o
/* Link-time stub: sptps_test never speaks the tinc meta protocol, so this
   ignores both the connection and the message and always reports failure. */
bool send_request(struct connection_t *c, const char *msg, ...) {
	(void)msg;
	(void)c;

	bool sent = false;
	return sent;
}
50
51 list_t connection_list;
52
53 bool send_meta(struct connection_t *c, const void *msg, size_t len) {
54         (void)c;
55         (void)msg;
56         (void)len;
57         return false;
58 }
59
// Non-static globals; presumably referenced by other linked tinc objects
// (NOTE(review): confirm which objects need them).
bool do_detach = false;
struct timeval now;
int addressfamily = AF_UNSPEC;  // narrowed by -4 / -6 in main()

// Command-line switches, all off until set by main()'s option parsing.
static bool special;    // -s: interpret lines starting with #, ^ and $
static bool verbose;    // -v
static bool readonly;   // -r
static bool writeonly;  // -w

// Plaintext I/O descriptors: stdin and stdout by default. main() may replace
// them, e.g. with a tun device or (on MinGW) the stdin relay socket.
static int in = 0;
static int out = 1;
70
71 static bool send_data(void *handle, uint8_t type, const void *data, size_t len) {
72         (void)type;
73         char hex[len * 2 + 1];
74         bin2hex(data, hex, len);
75
76         if(verbose) {
77                 fprintf(stderr, "Sending %lu bytes of data:\n%s\n", (unsigned long)len, hex);
78         }
79
80         const int *sock = handle;
81         const char *p = data;
82
83         while(len) {
84                 ssize_t sent = send(*sock, p, len, 0);
85
86                 if(sent <= 0) {
87                         fprintf(stderr, "Error sending data: %s\n", strerror(errno));
88                         return false;
89                 }
90
91                 p += sent;
92                 len -= sent;
93         }
94
95         return true;
96 }
97
98 static bool receive_record(void *handle, uint8_t type, const void *data, uint16_t len) {
99         (void)handle;
100
101         if(verbose) {
102                 fprintf(stderr, "Received type %d record of %u bytes:\n", type, len);
103         }
104
105         if(writeonly) {
106                 return true;
107         }
108
109         const char *p = data;
110
111         while(len) {
112                 ssize_t written = write(out, p, len);
113
114                 if(written <= 0) {
115                         fprintf(stderr, "Error writing received data: %s\n", strerror(errno));
116                         return false;
117                 }
118
119                 p += written;
120                 len -= written;
121         }
122
123         return true;
124 }
125
/* Long-option table for getopt_long(). Each entry mirrors a short option in
   main()'s optstring; "help" uses the non-character value 1. "verbose" takes
   no argument (like -v), and "tun" exposes -t. */
static struct option const long_options[] = {
	{"datagram", no_argument, NULL, 'd'},
	{"quit", no_argument, NULL, 'q'},
	{"readonly", no_argument, NULL, 'r'},
	{"tun", no_argument, NULL, 't'},
	{"writeonly", no_argument, NULL, 'w'},
	{"packet-loss", required_argument, NULL, 'L'},
	{"replay-window", required_argument, NULL, 'W'},
	{"special", no_argument, NULL, 's'},
	{"verbose", no_argument, NULL, 'v'},
	{"help", no_argument, NULL, 1},
	{NULL, 0, NULL, 0}
};
138
139 static void usage(void) {
140         static const char *message =
141                 "Usage: %s [options] my_ed25519_key_file his_ed25519_key_file [host] port\n"
142                 "\n"
143                 "Valid options are:\n"
144                 "  -d, --datagram          Enable datagram mode.\n"
145                 "  -q, --quit              Quit when EOF occurs on stdin.\n"
146                 "  -r, --readonly          Only send data from the socket to stdout.\n"
147 #ifdef HAVE_LINUX
148                 "  -t, --tun               Use a tun device instead of stdio.\n"
149 #endif
150                 "  -w, --writeonly         Only send data from stdin to the socket.\n"
151                 "  -L, --packet-loss RATE  Fake packet loss of RATE percent.\n"
152                 "  -R, --replay-window N   Set replay window to N bytes.\n"
153                 "  -s, --special           Enable special handling of lines starting with #, ^ and $.\n"
154                 "  -v, --verbose           Display debug messages.\n"
155                 "  -4                      Use IPv4.\n"
156                 "  -6                      Use IPv6.\n"
157                 "\n"
158                 "Report bugs to tinc@tinc-vpn.org.\n";
159
160         fprintf(stderr, message, program_name);
161 }
162
163 #ifdef HAVE_MINGW
164
165 int stdin_sock_fd = -1;
166
167 // Windows does not allow calling select() on anything but sockets. Therefore,
168 // to keep the same code as on other operating systems, we have to put a
169 // separate thread between the stdin and the sptps loop way below. This thread
170 // reads stdin and sends its content to the main thread through a TCP socket,
171 // which can be properly select()'ed.
172 static void *stdin_reader_thread(void *arg) {
173         struct sockaddr_in sa;
174         socklen_t sa_size = sizeof(sa);
175
176         while(true) {
177                 int peer_fd = accept(stdin_sock_fd, (struct sockaddr *) &sa, &sa_size);
178
179                 if(peer_fd < 0) {
180                         fprintf(stderr, "accept() failed: %s\n", strerror(errno));
181                         continue;
182                 }
183
184                 if(verbose) {
185                         fprintf(stderr, "New connection received from :%d\n", ntohs(sa.sin_port));
186                 }
187
188                 char buf[1024];
189                 ssize_t nread;
190
191                 while((nread = read(STDIN_FILENO, buf, sizeof(buf))) > 0) {
192                         if(verbose) {
193                                 fprintf(stderr, "Read %lld bytes from input\n", nread);
194                         }
195
196                         char *start = buf;
197                         ssize_t nleft = nread;
198
199                         while(nleft) {
200                                 ssize_t nsend = send(peer_fd, start, nleft, 0);
201
202                                 if(nsend < 0) {
203                                         if(sockwouldblock(sockerrno)) {
204                                                 continue;
205                                         }
206
207                                         break;
208                                 }
209
210                                 start += nsend;
211                                 nleft -= nsend;
212                         }
213
214                         if(nleft) {
215                                 fprintf(stderr, "Could not send data: %s\n", strerror(errno));
216                                 break;
217                         }
218
219                         if(verbose) {
220                                 fprintf(stderr, "Sent %lld bytes to peer\n", nread);
221                         }
222                 }
223
224                 closesocket(peer_fd);
225         }
226
227         closesocket(stdin_sock_fd);
228         stdin_sock_fd = -1;
229         return NULL;
230 }
231
232 static int start_input_reader(void) {
233         if(stdin_sock_fd != -1) {
234                 fprintf(stderr, "stdin thread can only be started once.\n");
235                 return -1;
236         }
237
238         stdin_sock_fd = socket(AF_INET, SOCK_STREAM, 0);
239
240         if(stdin_sock_fd < 0) {
241                 fprintf(stderr, "Could not create server socket: %s\n", strerror(errno));
242                 return -1;
243         }
244
245         struct sockaddr_in serv_sa;
246
247         memset(&serv_sa, 0, sizeof(serv_sa));
248
249         serv_sa.sin_family = AF_INET;
250
251         serv_sa.sin_addr.s_addr = htonl(0x7f000001); // 127.0.0.1
252
253         int res = bind(stdin_sock_fd, (struct sockaddr *)&serv_sa, sizeof(serv_sa));
254
255         if(res < 0) {
256                 fprintf(stderr, "Could not bind socket: %s\n", strerror(errno));
257                 goto server_err;
258         }
259
260         if(listen(stdin_sock_fd, 1) < 0) {
261                 fprintf(stderr, "Could not listen: %s\n", strerror(errno));
262                 goto server_err;
263         }
264
265         struct sockaddr_in connect_sa;
266
267         socklen_t addr_len = sizeof(connect_sa);
268
269         if(getsockname(stdin_sock_fd, (struct sockaddr *)&connect_sa, &addr_len) < 0) {
270                 fprintf(stderr, "Could not determine the address of the stdin thread socket\n");
271                 goto server_err;
272         }
273
274         if(verbose) {
275                 fprintf(stderr, "stdin thread is listening on :%d\n", ntohs(connect_sa.sin_port));
276         }
277
278         pthread_t th;
279         int err = pthread_create(&th, NULL, stdin_reader_thread, NULL);
280
281         if(err) {
282                 fprintf(stderr, "Could not start reader thread: %s\n", strerror(err));
283                 goto server_err;
284         }
285
286         int client_fd = socket(AF_INET, SOCK_STREAM, 0);
287
288         if(client_fd < 0) {
289                 fprintf(stderr, "Could not create client socket: %s\n", strerror(errno));
290                 return -1;
291         }
292
293         if(connect(client_fd, (struct sockaddr *)&connect_sa, sizeof(connect_sa)) < 0) {
294                 fprintf(stderr, "Could not connect: %s\n", strerror(errno));
295                 closesocket(client_fd);
296                 return -1;
297         }
298
299         return client_fd;
300
301 server_err:
302
303         if(stdin_sock_fd != -1) {
304                 closesocket(stdin_sock_fd);
305                 stdin_sock_fd = -1;
306         }
307
308         return -1;
309 }
310
311 #endif // HAVE_MINGW
312
/* sptps_test entry point.
 *
 * Parses options, connects to (host given) or accepts from (no host) a peer,
 * loads our private and the peer's public Ed25519 key, starts an SPTPS
 * session, then shuttles plaintext between stdin/stdout (or a tun device)
 * and the encrypted socket until EOF or an error. Returns 0 on a clean stop.
 */
int main(int argc, char *argv[]) {
	program_name = argv[0];
	bool initiator = false;
	bool datagram = false;
#ifdef HAVE_LINUX
	bool tun = false;
#endif
	int packetloss = 0;
	int r;
	int option_index = 0;
	bool quit = false;

	while((r = getopt_long(argc, argv, "dqrstwL:W:v46", long_options, &option_index)) != EOF) {
		switch(r) {
		case 0:   /* long option */
			break;

		case 'd': /* datagram mode */
			datagram = true;
			break;

		case 'q': /* close connection on EOF from stdin */
			quit = true;
			break;

		case 'r': /* read only */
			readonly = true;
			break;

		case 't': /* use a tun device instead of stdio */
#ifdef HAVE_LINUX
			tun = true;
#else
			fprintf(stderr, "--tun is only supported on Linux.\n");
			usage();
			return 1;
#endif
			break;

		case 'w': /* write only */
			writeonly = true;
			break;

		case 'L': /* packet loss rate */
			packetloss = atoi(optarg);
			break;

		case 'W': /* replay window size */
			sptps_replaywin = atoi(optarg);
			break;

		case 'v': /* be verbose */
			verbose = true;
			break;

		case 's': /* special character handling */
			special = true;
			break;

		case '?': /* wrong options */
			usage();
			return 1;

		case '4': /* IPv4 */
			addressfamily = AF_INET;
			break;

		case '6': /* IPv6 */
			addressfamily = AF_INET6;
			break;

		case 1: /* help */
			usage();
			return 0;

		default:
			break;
		}
	}

	// Shift so argv[1] is the first positional argument.
	argc -= optind - 1;
	argv += optind - 1;

	// Positional arguments: my key, his key, [host], port. Giving a host makes us the initiator.
	if(argc < 4 || argc > 5) {
		fprintf(stderr, "Wrong number of arguments.\n");
		usage();
		return 1;
	}

	if(argc > 4) {
		initiator = true;
	}

#ifdef HAVE_LINUX

	if(tun) {
		in = out = open("/dev/net/tun", O_RDWR | O_NONBLOCK);

		if(in < 0) {
			fprintf(stderr, "Could not open tun device: %s\n", strerror(errno));
			return 1;
		}

		struct ifreq ifr = {
			.ifr_flags = IFF_TUN
		};

		if(ioctl(in, TUNSETIFF, &ifr)) {
			fprintf(stderr, "Could not configure tun interface: %s\n", strerror(errno));
			return 1;
		}

		ifr.ifr_name[IFNAMSIZ - 1] = 0;
		fprintf(stderr, "Using tun interface %s\n", ifr.ifr_name);
	}

#endif

#ifdef HAVE_MINGW
	static struct WSAData wsa_state;

	if(WSAStartup(MAKEWORD(2, 2), &wsa_state)) {
		return 1;
	}

#endif

	struct addrinfo *ai = NULL, hint;
	memset(&hint, 0, sizeof(hint));

	hint.ai_family = addressfamily;
	hint.ai_socktype = datagram ? SOCK_DGRAM : SOCK_STREAM;
	hint.ai_protocol = datagram ? IPPROTO_UDP : IPPROTO_TCP;
	hint.ai_flags = initiator ? 0 : AI_PASSIVE;

	// getaddrinfo() reports failures through its return value, not errno/sockerrno.
	int gai_err = getaddrinfo(initiator ? argv[3] : NULL, initiator ? argv[4] : argv[3], &hint, &ai);

	if(gai_err || !ai) {
		fprintf(stderr, "getaddrinfo() failed: %s\n", gai_err ? gai_strerror(gai_err) : "no addresses found");
		return 1;
	}

	int sock = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);

	if(sock < 0) {
		fprintf(stderr, "Could not create socket: %s\n", sockstrerror(sockerrno));
		freeaddrinfo(ai);
		return 1;
	}

	int one = 1;
	setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (void *)&one, sizeof(one));

	if(initiator) {
		int res = connect(sock, ai->ai_addr, ai->ai_addrlen);

		freeaddrinfo(ai);
		ai = NULL;

		if(res) {
			fprintf(stderr, "Could not connect to peer: %s\n", sockstrerror(sockerrno));
			return 1;
		}

		fprintf(stderr, "Connected\n");
	} else {
		int res = bind(sock, ai->ai_addr, ai->ai_addrlen);

		freeaddrinfo(ai);
		ai = NULL;

		if(res) {
			fprintf(stderr, "Could not bind socket: %s\n", sockstrerror(sockerrno));
			return 1;
		}

		if(!datagram) {
			if(listen(sock, 1)) {
				fprintf(stderr, "Could not listen on socket: %s\n", sockstrerror(sockerrno));
				return 1;
			}

			fprintf(stderr, "Listening...\n");

			int listen_sock = sock;
			sock = accept(listen_sock, NULL, NULL);

			if(sock < 0) {
				fprintf(stderr, "Could not accept connection: %s\n", sockstrerror(sockerrno));
				return 1;
			}

			// Only one peer is served, so the listening socket is no longer needed.
			closesocket(listen_sock);
		} else {
			fprintf(stderr, "Listening...\n");

			// Peek at the first datagram to learn the peer's address, then
			// connect() so later send()/recv() only talk to that peer.
			char buf[65536];
			struct sockaddr_storage addr;  // large enough for IPv6, unlike struct sockaddr
			socklen_t addrlen = sizeof(addr);

			if(recvfrom(sock, buf, sizeof(buf), MSG_PEEK, (struct sockaddr *)&addr, &addrlen) <= 0) {
				fprintf(stderr, "Could not read from socket: %s\n", sockstrerror(sockerrno));
				return 1;
			}

			if(connect(sock, (struct sockaddr *)&addr, addrlen)) {
				fprintf(stderr, "Could not accept connection: %s\n", sockstrerror(sockerrno));
				return 1;
			}
		}

		fprintf(stderr, "Connected\n");
	}

	crypto_init();
	prng_init();

	FILE *fp = fopen(argv[1], "r");

	if(!fp) {
		fprintf(stderr, "Could not open %s: %s\n", argv[1], strerror(errno));
		return 1;
	}

	ecdsa_t *mykey = ecdsa_read_pem_private_key(fp);
	fclose(fp);

	if(!mykey) {
		return 1;
	}

	fp = fopen(argv[2], "r");

	if(!fp) {
		fprintf(stderr, "Could not open %s: %s\n", argv[2], strerror(errno));
		free(mykey);
		return 1;
	}

	ecdsa_t *hiskey = ecdsa_read_pem_public_key(fp);
	fclose(fp);

	if(!hiskey) {
		free(mykey);
		return 1;
	}

	if(verbose) {
		fprintf(stderr, "Keys loaded\n");
	}

	sptps_t s;

	if(!sptps_start(&s, &sock, initiator, datagram, mykey, hiskey, "sptps_test", 10, send_data, receive_record)) {
		free(mykey);
		free(hiskey);
		return 1;
	}

#ifdef HAVE_MINGW

	if(!readonly) {
		in = start_input_reader();

		if(in < 0) {
			fprintf(stderr, "Could not init stdin reader thread\n");
			free(mykey);
			free(hiskey);
			return 1;
		}
	}

#endif

	int max_fd = MAX(sock, in);

	while(true) {
		if(writeonly && readonly) {
			break;
		}

		char buf[65535] = "";
		// Stream reads stop one byte short of the buffer, so it always stays
		// NUL-terminated for the atoi() used by the '#' special command.
		size_t readsize = datagram ? 1460u : sizeof(buf) - 1;

		fd_set fds;
		FD_ZERO(&fds);

		if(!readonly && s.instate) {
			FD_SET(in, &fds);
		}

		FD_SET(sock, &fds);

		if(select(max_fd + 1, &fds, NULL, NULL, NULL) <= 0) {
			free(mykey);
			free(hiskey);
			return 1;
		}

		if(FD_ISSET(in, &fds)) {
#ifdef HAVE_MINGW
			ssize_t len = recv(in, buf, readsize, 0);
#else
			ssize_t len = read(in, buf, readsize);
#endif

			if(len < 0) {
				fprintf(stderr, "Could not read from stdin: %s\n", strerror(errno));
				free(mykey);
				free(hiskey);
				return 1;
			}

			if(len == 0) {
#ifdef HAVE_MINGW
				shutdown(in, SD_SEND);
				closesocket(in);
#endif

				if(quit) {
					break;
				}

				readonly = true;
				continue;
			}

			if(special && buf[0] == '#') {
				s.outseqno = atoi(buf + 1);
			}

			if(special && buf[0] == '^') {
				sptps_send_record(&s, SPTPS_HANDSHAKE, NULL, 0);
			} else if(special && buf[0] == '$') {
				sptps_force_kex(&s);

				if(len > 1) {
					sptps_send_record(&s, 0, buf, len);
				}
			} else if(!sptps_send_record(&s, buf[0] == '!' ? 1 : 0, buf, (len == 1 && buf[0] == '\n') ? 0 : buf[0] == '*' ? sizeof(buf) : (size_t)len)) {
				free(mykey);
				free(hiskey);
				return 1;
			}
		}

		if(FD_ISSET(sock, &fds)) {
			ssize_t len = recv(sock, buf, sizeof(buf), 0);

			if(len < 0) {
				fprintf(stderr, "Could not read from socket: %s\n", sockstrerror(sockerrno));
				free(mykey);
				free(hiskey);
				return 1;
			}

			if(len == 0) {
				fprintf(stderr, "Connection terminated by peer.\n");
				break;
			}

			if(verbose) {
				char hex[len * 2 + 1];
				bin2hex(buf, hex, len);
				fprintf(stderr, "Received %ld bytes of data:\n%s\n", (long)len, hex);
			}

			if(packetloss && (int)prng(100) < packetloss) {
				if(verbose) {
					fprintf(stderr, "Dropped.\n");
				}

				continue;
			}

			char *bufp = buf;

			while(len) {
				size_t done = sptps_receive_data(&s, bufp, len);

				if(!done) {
					if(!datagram) {
						free(mykey);
						free(hiskey);
						return 1;
					}

					// A rejected datagram consumes nothing; drop it rather than
					// retrying the same bytes forever.
					break;
				}

				bufp += done;
				len -= (ssize_t) done;
			}
		}
	}

	bool stopped = sptps_stop(&s);

	free(mykey);
	free(hiskey);
	closesocket(sock);

	return stopped ? 0 : 1;
}