xref: /trueos/usr.bin/csup/mux.c (revision bcd0e15cf642d6e5bf78ee585ad282b0e3061864)
1 /*-
2  * Copyright (c) 2003-2006, Maxime Henrion <mux@FreeBSD.org>
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
15  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
22  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
23  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
24  * SUCH DAMAGE.
25  *
26  * $FreeBSD$
27  */
28 
29 #include <sys/param.h>
30 #include <sys/socket.h>
31 #include <sys/uio.h>
32 
33 #include <netinet/in.h>
34 
35 #include <assert.h>
36 #include <errno.h>
37 #include <pthread.h>
38 #include <stdarg.h>
39 #include <stdio.h>
40 #include <stdlib.h>
41 #include <string.h>
42 #include <unistd.h>
43 
44 #include "misc.h"
45 #include "mux.h"
46 
/*
 * Packet types: the first byte of every multiplexer packet.
 */
enum {
	MUX_STARTUPREQ = 0,
	MUX_STARTUPREP = 1,
	MUX_CONNECT = 2,
	MUX_ACCEPT = 3,
	MUX_RESET = 4,
	MUX_DATA = 5,
	MUX_WINDOW = 6,
	MUX_CLOSE = 7
};

/*
 * Number of header bytes exchanged on the wire for each packet type,
 * type byte included.
 */
enum {
	MUX_STARTUPHDRSZ = 3,
	MUX_CONNECTHDRSZ = 8,
	MUX_ACCEPTHDRSZ = 8,
	MUX_RESETHDRSZ = 2,
	MUX_DATAHDRSZ = 4,
	MUX_WINDOWHDRSZ = 6,
	MUX_CLOSEHDRSZ = 2
};

/* Protocol version negotiated during the startup handshake. */
enum { MUX_PROTOVER = 0 };
71 
/*
 * Packet header, laid out exactly as it travels on the wire.
 *
 * Every packet starts with a one-byte type (one of the MUX_* packet
 * types), followed by the fields of the matching union member.  All the
 * structures are __packed so that no padding is inserted.  For a given
 * type, only the first MUX_*HDRSZ bytes of this structure are actually
 * sent or received.  Multi-byte fields hold network byte order values:
 * the code filling them uses htons()/htonl() and the code reading them
 * uses ntohs()/ntohl().
 */
struct mux_header {
	uint8_t type;
	union {
		/* MUX_STARTUPREQ and MUX_STARTUPREP: protocol version. */
		struct {
			uint16_t version;
		} __packed mh_startup;
		/* MUX_CONNECT: channel ID, sender's MSS and receive window edge. */
		struct {
			uint8_t id;
			uint16_t mss;
			uint32_t window;
		} __packed mh_connect;
		/* MUX_ACCEPT: same fields as MUX_CONNECT. */
		struct {
			uint8_t id;
			uint16_t mss;
			uint32_t window;
		} __packed mh_accept;
		/* MUX_RESET: channel ID only. */
		struct {
			uint8_t id;
		} __packed mh_reset;
		/* MUX_DATA: channel ID and length of the payload that follows. */
		struct {
			uint8_t id;
			uint16_t len;
		} __packed mh_data;
		/* MUX_WINDOW: channel ID and updated receive window edge. */
		struct {
			uint8_t id;
			uint32_t window;
		} __packed mh_window;
		/* MUX_CLOSE: channel ID only. */
		struct {
			uint8_t id;
		} __packed mh_close;
	} mh_u;
} __packed;

/* Shorthands: write mh.mh_data instead of mh.mh_u.mh_data, etc. */
#define	mh_startup		mh_u.mh_startup
#define	mh_connect		mh_u.mh_connect
#define	mh_accept		mh_u.mh_accept
#define	mh_reset		mh_u.mh_reset
#define	mh_data			mh_u.mh_data
#define	mh_window		mh_u.mh_window
#define	mh_close		mh_u.mh_close
112 
/* Size of the per-multiplexer channel table. */
enum { MUX_MAXCHAN = 2 };

/* Channel states, stored in struct chan "state". */
enum {
	CS_UNUSED = 0,
	CS_LISTENING = 1,
	CS_CONNECTING = 2,
	CS_ESTABLISHED = 3,
	CS_RDCLOSED = 4,	/* Peer closed its side; we may still write. */
	CS_WRCLOSED = 5,	/* We closed our side; we may still read. */
	CS_CLOSED = 6
};

/* Channel flags: packets the sender thread still has to emit. */
enum {
	CF_CONNECT = 0x01,
	CF_ACCEPT = 0x02,
	CF_RESET = 0x04,
	CF_WINDOW = 0x08,
	CF_DATA = 0x10,
	CF_CLOSE = 0x20
};

/* Per-channel buffering, in bytes. */
enum {
	CHAN_SBSIZE = 16 * 1024,	/* Send buffer size. */
	CHAN_RBSIZE = 16 * 1024,	/* Receive buffer size. */
	CHAN_MAXSEGSIZE = 1024		/* Maximum segment size. */
};
135 
/*
 * Circular byte buffer.  The storage is one byte larger than "size", so
 * that in == out unambiguously means "empty".  Both indices always stay
 * within [0, size].
 */
struct buf {
	size_t	 size;		/* Usable capacity in bytes. */
	size_t	 in;		/* Where the next byte gets stored. */
	size_t	 out;		/* Where the next byte gets fetched. */
	uint8_t	*data;		/* Storage, size + 1 bytes long. */
};
143 
/* One multiplexed channel. */
struct chan {
	struct mux	*mux;		/* Owning multiplexer. */
	pthread_mutex_t	lock;		/* Protects the channel state. */
	int		state;		/* One of the CS_* values. */
	int		flags;		/* Pending CF_* packets to send. */

	/* Receive side. */
	struct buf	*recvbuf;
	pthread_cond_t	rdready;	/* Data arrived or state changed. */
	uint32_t	recvseq;	/* Bytes consumed by chan_read(). */
	uint16_t	recvmss;	/* Largest segment we accept. */

	/* Send side. */
	struct buf	*sendbuf;
	pthread_cond_t	wrready;	/* Space freed or state changed. */
	uint32_t	sendseq;	/* Bytes written to the socket. */
	uint32_t	sendwin;	/* Window edge advertised by the peer. */
	uint16_t	sendmss;	/* Peer's maximum segment size. */
};
163 
164 struct mux {
165 	int		closed;
166 	int		status;
167 	int		socket;
168 	pthread_mutex_t	lock;
169 	pthread_cond_t	done;
170 	struct chan	*channels[MUX_MAXCHAN];
171 	int		nchans;
172 
173 	/* Sender thread data. */
174 	pthread_t	sender;
175 	pthread_cond_t	sender_newwork;
176 	pthread_cond_t	sender_started;
177 	int		sender_waiting;
178 	int		sender_ready;
179 	int		sender_lastid;
180 
181 	/* Receiver thread data. */
182 	pthread_t	receiver;
183 };
184 
/* Socket I/O helpers. */
static int		 sock_writev(int s, struct iovec *iov, int iovcnt);
static int		 sock_write(int s, void *buf, size_t size);
static ssize_t		 sock_read(int s, void *buf, size_t size);
static int		 sock_readwait(int s, void *buf, size_t size);

/* Multiplexer internals. */
static int		 mux_init(struct mux *m);
static void		 mux_lock(struct mux *m);
static void		 mux_unlock(struct mux *m);

/* Channel internals. */
static struct chan	*chan_new(struct mux *m);
static struct chan	*chan_get(struct mux *m, int id);
static struct chan	*chan_connect(struct mux *m, int id);
static void		 chan_lock(struct chan *chan);
static void		 chan_unlock(struct chan *chan);
static int		 chan_insert(struct mux *m, struct chan *chan);
static void		 chan_free(struct chan *chan);

/* Circular buffers. */
static struct buf	*buf_new(size_t size);
static size_t		 buf_count(struct buf *buf);
static size_t		 buf_avail(struct buf *buf);
static void		 buf_get(struct buf *buf, void *data, size_t size);
static void		 buf_put(struct buf *buf, const void *data, size_t size);
static void		 buf_free(struct buf *buf);

/* Sender thread. */
static void		 sender_wakeup(struct mux *m);
static void		*sender_loop(void *arg);
static int		 sender_waitforwork(struct mux *m, int *what);
static int		 sender_scan(struct mux *m, int *what);
static void		 sender_cleanup(void *arg);

/* Receiver thread. */
static void		*receiver_loop(void *arg);
216 
/*
 * Write all the data described by "iov" to socket "s", restarting
 * after signal interruptions and partial writes.  The iovec array is
 * modified in place to keep track of progress.
 *
 * Returns 0 once every byte has been written, or -1 with errno set.
 */
static int
sock_writev(int s, struct iovec *iov, int iovcnt)
{
	ssize_t nbytes;

	for (;;) {
		nbytes = writev(s, iov, iovcnt);
		if (nbytes == -1) {
			if (errno == EINTR)
				continue;
			return (-1);
		}
		/* Skip the iovecs that have been written entirely. */
		while (iovcnt > 0 && (size_t)nbytes >= iov->iov_len) {
			nbytes -= iov->iov_len;
			iov++;
			iovcnt--;
		}
		/*
		 * We are done only when every iovec has been consumed.
		 * "nbytes" reaching zero is not enough: a partial write can
		 * end exactly on an iovec boundary with more iovecs left.
		 */
		if (iovcnt == 0)
			return (0);
		/* Advance within the partially written iovec. */
		iov->iov_len -= nbytes;
		iov->iov_base = (char *)iov->iov_base + nbytes;
	}
}
239 
/*
 * Write the whole "size" bytes at "buf" to socket "s".
 * Thin wrapper around sock_writev() with a single iovec.
 */
static int
sock_write(int s, void *buf, size_t size)
{
	struct iovec vec = { .iov_base = buf, .iov_len = size };

	return (sock_writev(s, &vec, 1));
}
251 
/*
 * read(2) wrapper that transparently retries when interrupted by a
 * signal.  Returns whatever read(2) finally returned.
 */
static ssize_t
sock_read(int s, void *buf, size_t size)
{
	ssize_t got;

	do {
		got = read(s, buf, size);
	} while (got == -1 && errno == EINTR);
	return (got);
}
263 
/*
 * Read exactly "size" bytes from socket "s" into "buf".
 * Returns 0 on success, or -1 with errno set.  End-of-file before all
 * the bytes arrived is reported as ECONNRESET.
 */
static int
sock_readwait(int s, void *buf, size_t size)
{
	char *dst = buf;
	size_t remaining = size;
	ssize_t got;

	while (remaining != 0) {
		got = sock_read(s, dst, remaining);
		if (got < 0)
			return (-1);
		if (got == 0) {
			/* The peer went away in the middle of a transfer. */
			errno = ECONNRESET;
			return (-1);
		}
		dst += got;
		remaining -= (size_t)got;
	}
	return (0);
}
286 
287 static void
mux_lock(struct mux * m)288 mux_lock(struct mux *m)
289 {
290 	int error;
291 
292 	error = pthread_mutex_lock(&m->lock);
293 	assert(!error);
294 }
295 
296 static void
mux_unlock(struct mux * m)297 mux_unlock(struct mux *m)
298 {
299 	int error;
300 
301 	error = pthread_mutex_unlock(&m->lock);
302 	assert(!error);
303 }
304 
305 /* Create a TCP multiplexer on the given socket. */
306 struct mux *
mux_open(int sock,struct chan ** chan)307 mux_open(int sock, struct chan **chan)
308 {
309 	struct mux *m;
310 	struct chan *chan0;
311 	int error;
312 
313 	m = xmalloc(sizeof(struct mux));
314 	memset(m->channels, 0, sizeof(m->channels));
315 	m->nchans = 0;
316 	m->closed = 0;
317 	m->status = -1;
318 	m->socket = sock;
319 
320 	m->sender_waiting = 0;
321 	m->sender_lastid = 0;
322 	m->sender_ready = 0;
323 	pthread_mutex_init(&m->lock, NULL);
324 	pthread_cond_init(&m->done, NULL);
325 	pthread_cond_init(&m->sender_newwork, NULL);
326 	pthread_cond_init(&m->sender_started, NULL);
327 
328 	error = mux_init(m);
329 	if (error)
330 		goto bad;
331 	chan0 = chan_connect(m, 0);
332 	if (chan0 == NULL)
333 		goto bad;
334 	*chan = chan0;
335 	return (m);
336 bad:
337 	mux_shutdown(m, NULL, STATUS_FAILURE);
338 	(void)mux_close(m);
339 	return (NULL);
340 }
341 
342 int
mux_close(struct mux * m)343 mux_close(struct mux *m)
344 {
345 	struct chan *chan;
346 	int i, status;
347 
348 	assert(m->closed);
349 	for (i = 0; i < m->nchans; i++) {
350 		chan = m->channels[i];
351 		if (chan != NULL)
352 			chan_free(chan);
353 	}
354 	pthread_cond_destroy(&m->sender_started);
355 	pthread_cond_destroy(&m->sender_newwork);
356 	pthread_cond_destroy(&m->done);
357 	pthread_mutex_destroy(&m->lock);
358 	status = m->status;
359 	free(m);
360 	return (status);
361 }
362 
363 /* Close a channel. */
364 int
chan_close(struct chan * chan)365 chan_close(struct chan *chan)
366 {
367 
368 	chan_lock(chan);
369 	if (chan->state == CS_ESTABLISHED) {
370 		chan->state = CS_WRCLOSED;
371 		chan->flags |= CF_CLOSE;
372 	} else if (chan->state == CS_RDCLOSED) {
373 		chan->state = CS_CLOSED;
374 		chan->flags |= CF_CLOSE;
375 	} else if (chan->state == CS_WRCLOSED || chan->state == CS_CLOSED) {
376 		chan_unlock(chan);
377 		return (0);
378 	} else {
379 		chan_unlock(chan);
380 		return (-1);
381 	}
382 	chan_unlock(chan);
383 	sender_wakeup(chan->mux);
384 	return (0);
385 }
386 
387 void
chan_wait(struct chan * chan)388 chan_wait(struct chan *chan)
389 {
390 
391 	chan_lock(chan);
392 	while (chan->state != CS_CLOSED)
393 		pthread_cond_wait(&chan->rdready, &chan->lock);
394 	chan_unlock(chan);
395 }
396 
397 /* Returns the ID of an available channel in the listening state. */
398 int
chan_listen(struct mux * m)399 chan_listen(struct mux *m)
400 {
401 	struct chan *chan;
402 	int i;
403 
404 	mux_lock(m);
405 	for (i = 0; i < m->nchans; i++) {
406 		chan = m->channels[i];
407 		chan_lock(chan);
408 		if (chan->state == CS_UNUSED) {
409 			mux_unlock(m);
410 			chan->state = CS_LISTENING;
411 			chan_unlock(chan);
412 			return (i);
413 		}
414 		chan_unlock(chan);
415 	}
416 	mux_unlock(m);
417 	chan = chan_new(m);
418 	chan->state = CS_LISTENING;
419 	i = chan_insert(m, chan);
420 	if (i == -1)
421 		chan_free(chan);
422 	return (i);
423 }
424 
425 struct chan *
chan_accept(struct mux * m,int id)426 chan_accept(struct mux *m, int id)
427 {
428 	struct chan *chan;
429 
430 	chan = chan_get(m, id);
431 	while (chan->state == CS_LISTENING)
432 		pthread_cond_wait(&chan->rdready, &chan->lock);
433 	if (chan->state != CS_ESTABLISHED) {
434 		errno = ECONNRESET;
435 		chan_unlock(chan);
436 		return (NULL);
437 	}
438 	chan_unlock(chan);
439 	return (chan);
440 }
441 
442 /* Read bytes from a channel. */
443 ssize_t
chan_read(struct chan * chan,void * buf,size_t size)444 chan_read(struct chan *chan, void *buf, size_t size)
445 {
446 	char *cp;
447 	size_t count, n;
448 
449 	cp = buf;
450 	chan_lock(chan);
451 	for (;;) {
452 		if (chan->state == CS_RDCLOSED || chan->state == CS_CLOSED) {
453 			chan_unlock(chan);
454 			return (0);
455 		}
456 		if (chan->state != CS_ESTABLISHED &&
457 		    chan->state != CS_WRCLOSED) {
458 			chan_unlock(chan);
459 			errno = EBADF;
460 			return (-1);
461 		}
462 		count = buf_count(chan->recvbuf);
463 		if (count > 0)
464 			break;
465 		pthread_cond_wait(&chan->rdready, &chan->lock);
466 	}
467 	n = min(count, size);
468 	buf_get(chan->recvbuf, cp, n);
469 	chan->recvseq += n;
470 	chan->flags |= CF_WINDOW;
471 	chan_unlock(chan);
472 	/* We need to wake up the sender so that it sends a window update. */
473 	sender_wakeup(chan->mux);
474 	return (n);
475 }
476 
477 /* Write bytes to a channel. */
478 ssize_t
chan_write(struct chan * chan,const void * buf,size_t size)479 chan_write(struct chan *chan, const void *buf, size_t size)
480 {
481 	const char *cp;
482 	size_t avail, n, pos;
483 
484 	pos = 0;
485 	cp = buf;
486 	chan_lock(chan);
487 	while (pos < size) {
488 		for (;;) {
489 			if (chan->state != CS_ESTABLISHED &&
490 			    chan->state != CS_RDCLOSED) {
491 				chan_unlock(chan);
492 				errno = EPIPE;
493 				return (-1);
494 			}
495 			avail = buf_avail(chan->sendbuf);
496 			if (avail > 0)
497 				break;
498 			pthread_cond_wait(&chan->wrready, &chan->lock);
499 		}
500 		n = min(avail, size - pos);
501 		buf_put(chan->sendbuf, cp + pos, n);
502 		pos += n;
503 	}
504 	chan_unlock(chan);
505 	sender_wakeup(chan->mux);
506 	return (size);
507 }
508 
509 /*
510  * Internal channel API.
511  */
512 
513 static struct chan *
chan_connect(struct mux * m,int id)514 chan_connect(struct mux *m, int id)
515 {
516 	struct chan *chan;
517 
518 	chan = chan_get(m, id);
519 	if (chan->state != CS_UNUSED) {
520 		chan_unlock(chan);
521 		return (NULL);
522 	}
523 	chan->state = CS_CONNECTING;
524 	chan->flags |= CF_CONNECT;
525 	chan_unlock(chan);
526 	sender_wakeup(m);
527 	chan_lock(chan);
528 	while (chan->state == CS_CONNECTING)
529 		pthread_cond_wait(&chan->wrready, &chan->lock);
530 	if (chan->state != CS_ESTABLISHED) {
531 		chan_unlock(chan);
532 		return (NULL);
533 	}
534 	chan_unlock(chan);
535 	return (chan);
536 }
537 
538 /*
539  * Get a channel from its ID, creating it if necessary.
540  * The channel is returned locked.
541  */
542 static struct chan *
chan_get(struct mux * m,int id)543 chan_get(struct mux *m, int id)
544 {
545 	struct chan *chan;
546 
547 	assert(id < MUX_MAXCHAN);
548 	mux_lock(m);
549 	chan = m->channels[id];
550 	if (chan == NULL) {
551 		chan = chan_new(m);
552 		m->channels[id] = chan;
553 		m->nchans++;
554 	}
555 	chan_lock(chan);
556 	mux_unlock(m);
557 	return (chan);
558 }
559 
560 /* Lock a channel. */
561 static void
chan_lock(struct chan * chan)562 chan_lock(struct chan *chan)
563 {
564 	int error;
565 
566 	error = pthread_mutex_lock(&chan->lock);
567 	assert(!error);
568 }
569 
570 /* Unlock a channel.  */
571 static void
chan_unlock(struct chan * chan)572 chan_unlock(struct chan *chan)
573 {
574 	int error;
575 
576 	error = pthread_mutex_unlock(&chan->lock);
577 	assert(!error);
578 }
579 
580 /*
581  * Create a new channel.
582  */
583 static struct chan *
chan_new(struct mux * m)584 chan_new(struct mux *m)
585 {
586 	struct chan *chan;
587 
588 	chan = xmalloc(sizeof(struct chan));
589 	chan->state = CS_UNUSED;
590 	chan->flags = 0;
591 	chan->mux = m;
592 	chan->sendbuf = buf_new(CHAN_SBSIZE);
593 	chan->sendseq = 0;
594 	chan->sendwin = 0;
595 	chan->sendmss = 0;
596 	chan->recvbuf = buf_new(CHAN_RBSIZE);
597 	chan->recvseq = 0;
598 	chan->recvmss = CHAN_MAXSEGSIZE;
599 	pthread_mutex_init(&chan->lock, NULL);
600 	pthread_cond_init(&chan->rdready, NULL);
601 	pthread_cond_init(&chan->wrready, NULL);
602 	return (chan);
603 }
604 
605 /* Free any resources associated with a channel. */
606 static void
chan_free(struct chan * chan)607 chan_free(struct chan *chan)
608 {
609 
610 	pthread_cond_destroy(&chan->rdready);
611 	pthread_cond_destroy(&chan->wrready);
612 	pthread_mutex_destroy(&chan->lock);
613 	buf_free(chan->recvbuf);
614 	buf_free(chan->sendbuf);
615 	free(chan);
616 }
617 
618 /* Insert the new channel in the channel list. */
619 static int
chan_insert(struct mux * m,struct chan * chan)620 chan_insert(struct mux *m, struct chan *chan)
621 {
622 	int i;
623 
624 	mux_lock(m);
625 	for (i = 0; i < MUX_MAXCHAN; i++) {
626 		if (m->channels[i] == NULL) {
627 			m->channels[i] = chan;
628 			m->nchans++;
629 			mux_unlock(m);
630 			return (i);
631 		}
632 	}
633 	errno = ENOBUFS;
634 	return (-1);
635 }
636 
637 /*
638  * Initialize the multiplexer protocol.
639  *
640  * This means negotiating protocol version and starting
641  * the receiver and sender threads.
642  */
643 static int
mux_init(struct mux * m)644 mux_init(struct mux *m)
645 {
646 	struct mux_header mh;
647 	int error;
648 
649 	mh.type = MUX_STARTUPREQ;
650 	mh.mh_startup.version = htons(MUX_PROTOVER);
651 	error = sock_write(m->socket, &mh, MUX_STARTUPHDRSZ);
652 	if (error)
653 		return (-1);
654 	error = sock_readwait(m->socket, &mh, MUX_STARTUPHDRSZ);
655 	if (error)
656 		return (-1);
657 	if (mh.type != MUX_STARTUPREP ||
658 	    ntohs(mh.mh_startup.version) != MUX_PROTOVER)
659 		return (-1);
660 	mux_lock(m);
661 	error = pthread_create(&m->sender, NULL, sender_loop, m);
662 	if (error) {
663 		mux_unlock(m);
664 		return (-1);
665 	}
666 	/*
667 	 * Make sure the sender thread has run and is waiting for new work
668 	 * before going on.  Otherwise, it might lose the race and a
669 	 * request, which will cause a deadlock.
670 	 */
671 	while (!m->sender_ready)
672 		pthread_cond_wait(&m->sender_started, &m->lock);
673 
674 	mux_unlock(m);
675 	error = pthread_create(&m->receiver, NULL, receiver_loop, m);
676 	if (error)
677 		return (-1);
678 	return (0);
679 }
680 
681 /*
682  * Close all the channels, terminate the sender and receiver thread.
683  * This is an important function because it is used every time we need
684  * to wake up all the worker threads to abort the program.
685  *
686  * This function accepts an error message that will be printed if the
687  * multiplexer wasn't already closed.  This is useful because it ensures
688  * that only the first error message will be printed, and that it will
689  * be printed before doing the actual shutdown work.  If this is a
690  * normal shutdown, NULL can be passed instead.
691  *
692  * The "status" parameter of the first mux_shutdown() call is retained
693  * and then returned by mux_close(),  so that the main thread can know
694  * what type of error happened in the end, if any.
695  */
696 void
mux_shutdown(struct mux * m,const char * errmsg,int status)697 mux_shutdown(struct mux *m, const char *errmsg, int status)
698 {
699 	pthread_t self, sender, receiver;
700 	struct chan *chan;
701 	const char *name;
702 	void *val;
703 	int i, ret;
704 
705 	mux_lock(m);
706 	if (m->closed) {
707 		mux_unlock(m);
708 		return;
709 	}
710 	m->closed = 1;
711 	m->status = status;
712 	self = pthread_self();
713 	sender = m->sender;
714 	receiver = m->receiver;
715 	if (errmsg != NULL) {
716 		if (pthread_equal(self, receiver))
717 			name = "Receiver";
718 		else if (pthread_equal(self, sender))
719 			name = "Sender";
720 		else
721 			name = NULL;
722 		if (name == NULL)
723 			lprintf(-1, "%s\n", errmsg);
724 		else
725 			lprintf(-1, "%s: %s\n", name, errmsg);
726 	}
727 
728 	for (i = 0; i < MUX_MAXCHAN; i++) {
729 		if (m->channels[i] != NULL) {
730 			chan = m->channels[i];
731 			chan_lock(chan);
732 			if (chan->state != CS_UNUSED) {
733 				chan->state = CS_CLOSED;
734 				chan->flags = 0;
735 				pthread_cond_broadcast(&chan->rdready);
736 				pthread_cond_broadcast(&chan->wrready);
737 			}
738 			chan_unlock(chan);
739 		}
740 	}
741 	mux_unlock(m);
742 
743 	if (!pthread_equal(self, receiver)) {
744 		ret = pthread_cancel(receiver);
745 		assert(!ret);
746 		pthread_join(receiver, &val);
747 		assert(val == PTHREAD_CANCELED);
748 	}
749 	if (!pthread_equal(self, sender)) {
750 		ret = pthread_cancel(sender);
751 		assert(!ret);
752 		pthread_join(sender, &val);
753 		assert(val == PTHREAD_CANCELED);
754 	}
755 }
756 
757 static void
sender_wakeup(struct mux * m)758 sender_wakeup(struct mux *m)
759 {
760 	int waiting;
761 
762 	mux_lock(m);
763 	waiting = m->sender_waiting;
764 	mux_unlock(m);
765 	/*
766 	 * We don't care about the race here: if the sender was
767 	 * waiting and is not anymore, we'll just send a useless
768 	 * signal; if he wasn't waiting then he won't go to sleep
769 	 * before having sent what we want him to.
770 	 */
771 	if (waiting)
772 		pthread_cond_signal(&m->sender_newwork);
773 }
774 
775 static void *
sender_loop(void * arg)776 sender_loop(void *arg)
777 {
778 	struct iovec iov[3];
779 	struct mux_header mh;
780 	struct mux *m;
781 	struct chan *chan;
782 	struct buf *buf;
783 	uint32_t winsize;
784 	uint16_t hdrsize, size, len;
785 	int error, id, iovcnt, what = 0;
786 
787 	m = (struct mux *)arg;
788 	what = 0;
789 again:
790 	id = sender_waitforwork(m, &what);
791 	chan = chan_get(m, id);
792 	hdrsize = size = 0;
793 	switch (what) {
794 	case CF_CONNECT:
795 		mh.type = MUX_CONNECT;
796 		mh.mh_connect.id = id;
797 		mh.mh_connect.mss = htons(chan->recvmss);
798 		mh.mh_connect.window = htonl(chan->recvseq +
799 		    chan->recvbuf->size);
800 		hdrsize = MUX_CONNECTHDRSZ;
801 		break;
802 	case CF_ACCEPT:
803 		mh.type = MUX_ACCEPT;
804 		mh.mh_accept.id = id;
805 		mh.mh_accept.mss = htons(chan->recvmss);
806 		mh.mh_accept.window = htonl(chan->recvseq +
807 		    chan->recvbuf->size);
808 		hdrsize = MUX_ACCEPTHDRSZ;
809 		break;
810 	case CF_RESET:
811 		mh.type = MUX_RESET;
812 		mh.mh_reset.id = id;
813 		hdrsize = MUX_RESETHDRSZ;
814 		break;
815 	case CF_WINDOW:
816 		mh.type = MUX_WINDOW;
817 		mh.mh_window.id = id;
818 		mh.mh_window.window = htonl(chan->recvseq +
819 		    chan->recvbuf->size);
820 		hdrsize = MUX_WINDOWHDRSZ;
821 		break;
822 	case CF_DATA:
823 		mh.type = MUX_DATA;
824 		mh.mh_data.id = id;
825 		size = min(buf_count(chan->sendbuf), chan->sendmss);
826 		winsize = chan->sendwin - chan->sendseq;
827 		if (winsize < size)
828 			size = winsize;
829 		mh.mh_data.len = htons(size);
830 		hdrsize = MUX_DATAHDRSZ;
831 		break;
832 	case CF_CLOSE:
833 		mh.type = MUX_CLOSE;
834 		mh.mh_close.id = id;
835 		hdrsize = MUX_CLOSEHDRSZ;
836 		break;
837 	}
838 	if (size > 0) {
839 		assert(mh.type == MUX_DATA);
840 		/*
841 		 * Older FreeBSD versions (and maybe other OSes) have the
842 		 * iov_base field defined as char *.  Cast to char * to
843 		 * silence a warning in this case.
844 		 */
845 		iov[0].iov_base = (char *)&mh;
846 		iov[0].iov_len = hdrsize;
847 		iovcnt = 1;
848 		/* We access the buffer directly to avoid some copying. */
849 		buf = chan->sendbuf;
850 		len = min(size, buf->size + 1 - buf->out);
851 		iov[iovcnt].iov_base = buf->data + buf->out;
852 		iov[iovcnt].iov_len = len;
853 		iovcnt++;
854 		if (size > len) {
855 			/* Wrapping around. */
856 			iov[iovcnt].iov_base = buf->data;
857 			iov[iovcnt].iov_len = size - len;
858 			iovcnt++;
859 		}
860 		/*
861 		 * Since we're the only thread sending bytes from the
862 		 * buffer and modifying buf->out, it's safe to unlock
863 		 * here during I/O.  It avoids keeping the channel lock
864 		 * too long, since write() might block.
865 		 */
866 		chan_unlock(chan);
867 		error = sock_writev(m->socket, iov, iovcnt);
868 		if (error)
869 			goto bad;
870 		chan_lock(chan);
871 		chan->sendseq += size;
872 		buf->out += size;
873 		if (buf->out > buf->size)
874 			buf->out -= buf->size + 1;
875 		pthread_cond_signal(&chan->wrready);
876 		chan_unlock(chan);
877 	} else {
878 		chan_unlock(chan);
879 		error = sock_write(m->socket, &mh, hdrsize);
880 		if (error)
881 			goto bad;
882 	}
883 	goto again;
884 bad:
885 	if (error == EPIPE)
886 		mux_shutdown(m, strerror(errno), STATUS_TRANSIENTFAILURE);
887 	else
888 		mux_shutdown(m, strerror(errno), STATUS_FAILURE);
889 	return (NULL);
890 }
891 
/*
 * Cancellation cleanup handler pushed by sender_waitforwork(): releases
 * the mux lock that the sender holds while it waits for work.
 */
static void
sender_cleanup(void *arg)
{
	mux_unlock((struct mux *)arg);
}
900 
901 static int
sender_waitforwork(struct mux * m,int * what)902 sender_waitforwork(struct mux *m, int *what)
903 {
904 	int id;
905 
906 	mux_lock(m);
907 	pthread_cleanup_push(sender_cleanup, m);
908 	if (!m->sender_ready) {
909 		pthread_cond_signal(&m->sender_started);
910 		m->sender_ready = 1;
911 	}
912 	while ((id = sender_scan(m, what)) == -1) {
913 		m->sender_waiting = 1;
914 		pthread_cond_wait(&m->sender_newwork, &m->lock);
915 	}
916 	m->sender_waiting = 0;
917 	pthread_cleanup_pop(1);
918 	return (id);
919 }
920 
921 /*
922  * Scan for work to do for the sender.  Has to be called with
923  * the multiplexer lock held.
924  */
925 static int
sender_scan(struct mux * m,int * what)926 sender_scan(struct mux *m, int *what)
927 {
928 	struct chan *chan;
929 	int id;
930 
931 	if (m->nchans <= 0)
932 		return (-1);
933 	id = m->sender_lastid;
934 	do {
935 		id++;
936 		if (id >= m->nchans)
937 			id = 0;
938 		chan = m->channels[id];
939 		chan_lock(chan);
940 		if (chan->state != CS_UNUSED) {
941 			if (chan->sendseq != chan->sendwin &&
942 			    buf_count(chan->sendbuf) > 0)
943 				chan->flags |= CF_DATA;
944 			if (chan->flags) {
945 				/* By order of importance. */
946 				if (chan->flags & CF_CONNECT)
947 					*what = CF_CONNECT;
948 				else if (chan->flags & CF_ACCEPT)
949 					*what = CF_ACCEPT;
950 				else if (chan->flags & CF_RESET)
951 					*what = CF_RESET;
952 				else if (chan->flags & CF_WINDOW)
953 					*what = CF_WINDOW;
954 				else if (chan->flags & CF_DATA)
955 					*what = CF_DATA;
956 				else if (chan->flags & CF_CLOSE)
957 					*what = CF_CLOSE;
958 				chan->flags &= ~*what;
959 				chan_unlock(chan);
960 				m->sender_lastid = id;
961 				return (id);
962 			}
963 		}
964 		chan_unlock(chan);
965 	} while (id != m->sender_lastid);
966 	return (-1);
967 }
968 
/*
 * Read the remainder of a packet header whose type byte has already
 * been consumed.  "hsize" is the full on-the-wire header size for the
 * packet type; the bytes land right after "mh.type".
 */
#define	SOCK_READREST(s, mh, hsize)					\
	sock_readwait((s), (char *)&(mh) + sizeof((mh).type),		\
	    (hsize) - sizeof((mh).type))
972 
973 void *
receiver_loop(void * arg)974 receiver_loop(void *arg)
975 {
976 	struct mux_header mh;
977 	struct mux *m;
978 	struct chan *chan;
979 	struct buf *buf;
980 	uint16_t size, len;
981 	int error;
982 
983 	m = (struct mux *)arg;
984 	while ((error = sock_readwait(m->socket, &mh.type,
985 	    sizeof(mh.type))) == 0) {
986 		switch (mh.type) {
987 		case MUX_CONNECT:
988 			error = SOCK_READREST(m->socket, mh, MUX_CONNECTHDRSZ);
989 			if (error)
990 				goto bad;
991 			chan = chan_get(m, mh.mh_connect.id);
992 			if (chan->state == CS_LISTENING) {
993 				chan->state = CS_ESTABLISHED;
994 				chan->sendmss = ntohs(mh.mh_connect.mss);
995 				chan->sendwin = ntohl(mh.mh_connect.window);
996 				chan->flags |= CF_ACCEPT;
997 				pthread_cond_signal(&chan->rdready);
998 			} else
999 				chan->flags |= CF_RESET;
1000 			chan_unlock(chan);
1001 			sender_wakeup(m);
1002 			break;
1003 		case MUX_ACCEPT:
1004 			error = SOCK_READREST(m->socket, mh, MUX_ACCEPTHDRSZ);
1005 			if (error)
1006 				goto bad;
1007 			chan = chan_get(m, mh.mh_accept.id);
1008 			if (chan->state == CS_CONNECTING) {
1009 				chan->sendmss = ntohs(mh.mh_accept.mss);
1010 				chan->sendwin = ntohl(mh.mh_accept.window);
1011 				chan->state = CS_ESTABLISHED;
1012 				pthread_cond_signal(&chan->wrready);
1013 				chan_unlock(chan);
1014 			} else {
1015 				chan->flags |= CF_RESET;
1016 				chan_unlock(chan);
1017 				sender_wakeup(m);
1018 			}
1019 			break;
1020 		case MUX_RESET:
1021 			error = SOCK_READREST(m->socket, mh, MUX_RESETHDRSZ);
1022 			if (error)
1023 				goto bad;
1024 			goto badproto;
1025 		case MUX_WINDOW:
1026 			error = SOCK_READREST(m->socket, mh, MUX_WINDOWHDRSZ);
1027 			if (error)
1028 				goto bad;
1029 			chan = chan_get(m, mh.mh_window.id);
1030 			if (chan->state == CS_ESTABLISHED ||
1031 			    chan->state == CS_RDCLOSED) {
1032 				chan->sendwin = ntohl(mh.mh_window.window);
1033 				chan_unlock(chan);
1034 				sender_wakeup(m);
1035 			} else {
1036 				chan_unlock(chan);
1037 			}
1038 			break;
1039 		case MUX_DATA:
1040 			error = SOCK_READREST(m->socket, mh, MUX_DATAHDRSZ);
1041 			if (error)
1042 				goto bad;
1043 			chan = chan_get(m, mh.mh_data.id);
1044 			len = ntohs(mh.mh_data.len);
1045 			buf = chan->recvbuf;
1046 			if ((chan->state != CS_ESTABLISHED &&
1047 			     chan->state != CS_WRCLOSED) ||
1048 			    (len > buf_avail(buf) ||
1049 			     len > chan->recvmss)) {
1050 				chan_unlock(chan);
1051 				goto badproto;
1052 				return (NULL);
1053 			}
1054 			/*
1055 			 * Similarly to the sender code, it's safe to
1056 			 * unlock the channel here.
1057 			 */
1058 			chan_unlock(chan);
1059 			size = min(buf->size + 1 - buf->in, len);
1060 			error = sock_readwait(m->socket,
1061 			    buf->data + buf->in, size);
1062 			if (error)
1063 				goto bad;
1064 			if (len > size) {
1065 				/* Wrapping around. */
1066 				error = sock_readwait(m->socket,
1067 				    buf->data, len - size);
1068 				if (error)
1069 					goto bad;
1070 			}
1071 			chan_lock(chan);
1072 			buf->in += len;
1073 			if (buf->in > buf->size)
1074 				buf->in -= buf->size + 1;
1075 			pthread_cond_signal(&chan->rdready);
1076 			chan_unlock(chan);
1077 			break;
1078 		case MUX_CLOSE:
1079 			error = SOCK_READREST(m->socket, mh, MUX_CLOSEHDRSZ);
1080 			if (error)
1081 				goto bad;
1082 			chan = chan_get(m, mh.mh_close.id);
1083 			if (chan->state == CS_ESTABLISHED)
1084 				chan->state = CS_RDCLOSED;
1085 			else if (chan->state == CS_WRCLOSED)
1086 				chan->state = CS_CLOSED;
1087 			else
1088 				goto badproto;
1089 			pthread_cond_signal(&chan->rdready);
1090 			chan_unlock(chan);
1091 			break;
1092 		default:
1093 			goto badproto;
1094 		}
1095 	}
1096 bad:
1097 	if (errno == ECONNRESET || errno == ECONNABORTED)
1098 		mux_shutdown(m, strerror(errno), STATUS_TRANSIENTFAILURE);
1099 	else
1100 		mux_shutdown(m, strerror(errno), STATUS_FAILURE);
1101 	return (NULL);
1102 badproto:
1103 	mux_shutdown(m, "Protocol error", STATUS_FAILURE);
1104 	return (NULL);
1105 }
1106 
1107 /*
1108  * Circular buffers API.
1109  */
1110 
1111 static struct buf *
buf_new(size_t size)1112 buf_new(size_t size)
1113 {
1114 	struct buf *buf;
1115 
1116 	buf = xmalloc(sizeof(struct buf));
1117 	buf->data = xmalloc(size + 1);
1118 	buf->size = size;
1119 	buf->in = 0;
1120 	buf->out = 0;
1121 	return (buf);
1122 }
1123 
1124 static void
buf_free(struct buf * buf)1125 buf_free(struct buf *buf)
1126 {
1127 
1128 	free(buf->data);
1129 	free(buf);
1130 }
1131 
1132 /* Number of bytes stored in the buffer. */
1133 static size_t
buf_count(struct buf * buf)1134 buf_count(struct buf *buf)
1135 {
1136 	size_t count;
1137 
1138 	if (buf->in >= buf->out)
1139 		count = buf->in - buf->out;
1140 	else
1141 		count = buf->size + 1 + buf->in - buf->out;
1142 	return (count);
1143 }
1144 
1145 /* Number of bytes available in the buffer. */
1146 static size_t
buf_avail(struct buf * buf)1147 buf_avail(struct buf *buf)
1148 {
1149 	size_t avail;
1150 
1151 	if (buf->out > buf->in)
1152 		avail = buf->out - buf->in - 1;
1153 	else
1154 		avail = buf->size + buf->out - buf->in;
1155 	return (avail);
1156 }
1157 
1158 static void
buf_put(struct buf * buf,const void * data,size_t size)1159 buf_put(struct buf *buf, const void *data, size_t size)
1160 {
1161 	const char *cp;
1162 	size_t len;
1163 
1164 	assert(size > 0);
1165 	assert(buf_avail(buf) >= size);
1166 	cp = data;
1167 	len = buf->size + 1 - buf->in;
1168 	if (len < size) {
1169 		/* Wrapping around. */
1170 		memcpy(buf->data + buf->in, cp, len);
1171 		memcpy(buf->data, cp + len, size - len);
1172 	} else {
1173 		/* Not wrapping around. */
1174 		memcpy(buf->data + buf->in, cp, size);
1175 	}
1176 	buf->in += size;
1177 	if (buf->in > buf->size)
1178 		buf->in -= buf->size + 1;
1179 }
1180 
1181 static void
buf_get(struct buf * buf,void * data,size_t size)1182 buf_get(struct buf *buf, void *data, size_t size)
1183 {
1184 	char *cp;
1185 	size_t len;
1186 
1187 	assert(size > 0);
1188 	assert(buf_count(buf) >= size);
1189 	cp = data;
1190 	len = buf->size + 1 - buf->out;
1191 	if (len < size) {
1192 		/* Wrapping around. */
1193 		memcpy(cp, buf->data + buf->out, len);
1194 		memcpy(cp + len, buf->data, size - len);
1195 	} else {
1196 		/* Not wrapping around. */
1197 		memcpy(cp, buf->data + buf->out, size);
1198 	}
1199 	buf->out += size;
1200 	if (buf->out > buf->size)
1201 		buf->out -= buf->size + 1;
1202 }
1203