83a7a2c732eb15314408a5e708d20c117782be69
[debian/amanda] / common-src / ssh-security.c
1 /*
2  * Amanda, The Advanced Maryland Automatic Network Disk Archiver
3  * Copyright (c) 1999 University of Maryland
4  * All Rights Reserved.
5  *
6  * Permission to use, copy, modify, distribute, and sell this software and its
7  * documentation for any purpose is hereby granted without fee, provided that
8  * the above copyright notice appear in all copies and that both that
9  * copyright notice and this permission notice appear in supporting
10  * documentation, and that the name of U.M. not be used in advertising or
11  * publicity pertaining to distribution of the software without specific,
12  * written prior permission.  U.M. makes no representations about the
13  * suitability of this software for any purpose.  It is provided "as is"
14  * without express or implied warranty.
15  *
16  * U.M. DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
17  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT SHALL U.M.
18  * BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
19  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
20  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
21  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
22  *
23  * Authors: the Amanda Development Team.  Its members are listed in a
24  * file named AUTHORS, in the root directory of this distribution.
25  */
26
27 /*
28  * $Id: ssh-security.c,v 1.8.2.1 2006/04/11 11:11:16 martinea Exp $
29  *
30  * ssh-security.c - security and transport over ssh or a ssh-like command.
31  *
32  * XXX still need to check for initial keyword on connect so we can skip
33  * over shell garbage and other stuff that ssh might want to spew out.
34  */
35
36 #include "amanda.h"
37 #include "event.h"
38 #include "packet.h"
39 #include "queue.h"
40 #include "security.h"
41 #include "stream.h"
42 #include "version.h"
43
44 #ifdef SSH_SECURITY
45
46 /*#define       SSH_DEBUG*/
47
48 #ifdef SSH_DEBUG
49 #define sshprintf(x)    dbprintf(x)
50 #else
51 #define sshprintf(x)
52 #endif
53
54 /*
55  * Path to the ssh binary.  This should be configurable.
56  */
57 #define SSH_PATH        "/usr/bin/ssh"
58
59 /*
60  * Arguments to ssh.  This should also be configurable
61  */
62 #define SSH_ARGS        "-x", "-l", CLIENT_LOGIN
63
64 /*
65  * Number of seconds ssh has to start up
66  */
67 #define CONNECT_TIMEOUT 20
68
69 /*
70  * Magic values for ssh_conn->handle
71  */
72 #define H_TAKEN -1              /* ssh_conn->tok was already read */
73 #define H_EOF   -2              /* this connection has been shut down */
74
75 /*
76  * This is a ssh connection to a host.  We should only have
77  * one connection per host.
78  */
79 struct ssh_conn {
80     int read, write;                            /* pipes to ssh */
81     pid_t pid;                                  /* pid of ssh process */
82     char pkt[NETWORK_BLOCK_BYTES];              /* last pkt read */
83     unsigned long pktlen;                       /* len of above */
84     struct {                                    /* buffer read() calls */
85         char buf[STREAM_BUFSIZE];               /* buffer */
86         size_t left;                    /* unread data */
87         ssize_t size;                   /* size of last read */
88     } readbuf;
89     event_handle_t *ev_read;                    /* read (EV_READFD) handle */
90     int ev_read_refcnt;                         /* number of readers */
91     char hostname[MAX_HOSTNAME_LENGTH+1];       /* host we're talking to */
92     char *errmsg;                               /* error passed up */
93     int refcnt;                                 /* number of handles using */
94     int handle;                                 /* last proto handle read */
95     TAILQ_ENTRY(ssh_conn) tq;                   /* queue handle */
96 };
97
98
99 struct ssh_stream;
100
101 /*
102  * This is the private handle data.
103  */
104 struct ssh_handle {
105     security_handle_t sech;             /* MUST be first */
106     char *hostname;                     /* ptr to rc->hostname */
107     struct ssh_stream *rs;              /* virtual stream we xmit over */
108
109     union {
110         void (*recvpkt) P((void *, pkt_t *, security_status_t));
111                                         /* func to call when packet recvd */
112         void (*connect) P((void *, security_handle_t *, security_status_t));
113                                         /* func to call when connected */
114     } fn;
115     void *arg;                          /* argument to pass function */
116     event_handle_t *ev_timeout;         /* timeout handle for recv */
117 };
118
119 /*
120  * This is the internal security_stream data for ssh.
121  */
122 struct ssh_stream {
123     security_stream_t secstr;           /* MUST be first */
124     struct ssh_conn *rc;                /* physical connection */
125     int handle;                         /* protocol handle */
126     event_handle_t *ev_read;            /* read (EV_WAIT) event handle */
127     void (*fn) P((void *, void *, ssize_t));    /* read event fn */
128     void *arg;                          /* arg for previous */
129 };
130
131 /*
132  * Interface functions
133  */
134 static int ssh_sendpkt P((void *, pkt_t *));
135 static int ssh_stream_accept P((void *));
136 static int ssh_stream_auth P((void *));
137 static int ssh_stream_id P((void *));
138 static int ssh_stream_write P((void *, const void *, size_t));
139 static void *ssh_stream_client P((void *, int));
140 static void *ssh_stream_server P((void *));
141 static void ssh_accept P((int, int,
142     void (*)(security_handle_t *, pkt_t *)));
143 static void ssh_close P((void *));
144 static void ssh_connect P((const char *,
145     char *(*)(char *, void *), 
146     void (*)(void *, security_handle_t *, security_status_t), void *));
147 static void ssh_recvpkt P((void *,
148     void (*)(void *, pkt_t *, security_status_t), void *, int));
149 static void ssh_recvpkt_cancel P((void *));
150 static void ssh_stream_close P((void *));
151 static void ssh_stream_read P((void *, void (*)(void *, void *, ssize_t),
152     void *));
153 static void ssh_stream_read_cancel P((void *));
154
155 /*
156  * This is our interface to the outside world.
157  */
158 const security_driver_t ssh_security_driver = {
159     "SSH",
160     ssh_connect,
161     ssh_accept,
162     ssh_close,
163     ssh_sendpkt,
164     ssh_recvpkt,
165     ssh_recvpkt_cancel,
166     ssh_stream_server,
167     ssh_stream_accept,
168     ssh_stream_client,
169     ssh_stream_close,
170     ssh_stream_auth,
171     ssh_stream_id,
172     ssh_stream_write,
173     ssh_stream_read,
174     ssh_stream_read_cancel,
175 };
176
177 /*
178  * This is a queue of open connections
179  */
180 static struct {
181     TAILQ_HEAD(, ssh_conn) tailq;
182     int qlength;
183 } connq = {
184     TAILQ_HEAD_INITIALIZER(connq.tailq), 0
185 };
186 #define connq_first()           TAILQ_FIRST(&connq.tailq)
187 #define connq_next(rc)          TAILQ_NEXT(rc, tq)
188 #define connq_append(rc)        do {                                    \
189     TAILQ_INSERT_TAIL(&connq.tailq, rc, tq);                            \
190     connq.qlength++;                                                    \
191 } while (0)
192 #define connq_remove(rc)        do {                                    \
193     assert(connq.qlength > 0);                                          \
194     TAILQ_REMOVE(&connq.tailq, rc, tq);                                 \
195     connq.qlength--;                                                    \
196 } while (0)
197
198 static int newhandle = 1;
199
200 /*
201  * This is a function that should be called if a new security_handle_t is
202  * created.  If NULL, no new handles are created.
203  * It is passed the new handle and the received pkt
204  */
205 static void (*accept_fn) P((security_handle_t *, pkt_t *));
206
207 /*
208  * Local functions
209  */
210 static void connect_callback P((void *));
211 static void connect_timeout P((void *));
212 static int send_token P((struct ssh_conn *, int, const void *, size_t));
213 static int recv_token P((struct ssh_conn *, int));
214 static void recvpkt_callback P((void *, void *, ssize_t));
215 static void recvpkt_timeout P((void *));
216 static void stream_read_callback P((void *));
217
218 static int runssh P((struct ssh_conn *));
219 static struct ssh_conn *conn_get P((const char *));
220 static void conn_put P((struct ssh_conn *));
221 static void conn_read P((struct ssh_conn *));
222 static void conn_read_cancel P((struct ssh_conn *));
223 static void conn_read_callback P((void *));
224 static int net_writev P((int, struct iovec *, int));
225 static ssize_t net_read P((struct ssh_conn *, void *, size_t, int));
226 static int net_read_fillbuf P((struct ssh_conn *, int, int));
227 static void parse_pkt P((pkt_t *, const void *, size_t));
228
229
230 /*
231  * ssh version of a security handle allocator.  Logically sets
232  * up a network "connection".
233  */
234 static void
235 ssh_connect(hostname, conf_fn, fn, arg)
236     const char *hostname;
237     char *(*conf_fn) P((char *, void *));
238     void (*fn) P((void *, security_handle_t *, security_status_t));
239     void *arg;
240 {
241     struct ssh_handle *rh;
242     struct hostent *he;
243
244     assert(fn != NULL);
245     assert(hostname != NULL);
246
247     sshprintf(("%s: ssh: ssh_connect: %s\n", debug_prefix_time(NULL), hostname));
248
249     rh = alloc(sizeof(*rh));
250     security_handleinit(&rh->sech, &ssh_security_driver);
251     rh->hostname = NULL;
252     rh->rs = NULL;
253     rh->ev_timeout = NULL;
254
255     if ((he = gethostbyname(hostname)) == NULL) {
256         security_seterror(&rh->sech,
257             "%s: could not resolve hostname", hostname);
258         (*fn)(arg, &rh->sech, S_ERROR);
259         return;
260     }
261     rh->hostname = he->h_name;  /* will be replaced */
262     rh->rs = ssh_stream_client(rh, newhandle++);
263
264     if (rh->rs == NULL)
265         goto error;
266
267     rh->hostname = rh->rs->rc->hostname;
268
269     if (rh->rs->rc->pid < 0) {
270         /*
271          * We need to open a new connection.
272          *
273          * XXX need to eventually limit number of outgoing connections here.
274          */
275         if (runssh(rh->rs->rc) < 0) {
276             security_seterror(&rh->sech,
277                 "can't connect to %s: %s", hostname, rh->rs->rc->errmsg);
278             goto error;
279         }
280     }
281     /*
282      * The socket will be opened async so hosts that are down won't
283      * block everything.  We need to register a write event
284      * so we will know when the socket comes alive.
285      *
286      * Overload rh->rs->ev_read to provide a write event handle.
287      * We also register a timeout.
288      */
289     rh->fn.connect = fn;
290     rh->arg = arg;
291     rh->rs->ev_read = event_register(rh->rs->rc->write, EV_WRITEFD,
292         connect_callback, rh);
293     rh->ev_timeout = event_register(CONNECT_TIMEOUT, EV_TIME,
294         connect_timeout, rh);
295
296     return;
297
298 error:
299     (*fn)(arg, &rh->sech, S_ERROR);
300 }
301
302 /*
303  * Called when a ssh connection is finished connecting and is ready
304  * to be authenticated.
305  */
306 static void
307 connect_callback(cookie)
308     void *cookie;
309 {
310     struct ssh_handle *rh = cookie;
311
312     event_release(rh->rs->ev_read);
313     rh->rs->ev_read = NULL;
314     event_release(rh->ev_timeout);
315     rh->ev_timeout = NULL;
316
317     (*rh->fn.connect)(rh->arg, &rh->sech, S_OK);
318 }
319
320 /*
321  * Called if a connection times out before completion.
322  */
323 static void
324 connect_timeout(cookie)
325     void *cookie;
326 {
327     struct ssh_handle *rh = cookie;
328
329     event_release(rh->rs->ev_read);
330     rh->rs->ev_read = NULL;
331     event_release(rh->ev_timeout);
332     rh->ev_timeout = NULL;
333
334     (*rh->fn.connect)(rh->arg, &rh->sech, S_TIMEOUT);
335 }
336
337 /*
338  * Setup to handle new incoming connections
339  */
340 static void
341 ssh_accept(in, out, fn)
342     int in, out;
343     void (*fn) P((security_handle_t *, pkt_t *));
344 {
345     struct ssh_conn *rc;
346
347     rc = conn_get("unknown");
348     rc->read = in;
349     rc->write = out;
350     accept_fn = fn;
351     conn_read(rc);
352 }
353
354 /*
355  * Locate an existing connection to the given host, or create a new,
356  * unconnected entry if none exists.  The caller is expected to check
357  * for the lack of a connection (rc->read == -1) and set one up.
358  */
359 static struct ssh_conn *
360 conn_get(hostname)
361     const char *hostname;
362 {
363     struct ssh_conn *rc;
364
365     sshprintf(("%s: ssh: conn_get: %s\n", debug_prefix_time(NULL), hostname));
366
367     for (rc = connq_first(); rc != NULL; rc = connq_next(rc)) {
368         if (strcasecmp(hostname, rc->hostname) == 0)
369             break;
370     }
371
372     if (rc != NULL) {
373         rc->refcnt++;
374         sshprintf(("%s: ssh: conn_get: exists, refcnt to %s is now %d\n", debug_prefix_time(NULL),
375             rc->hostname, rc->refcnt));
376         return (rc);
377     }
378
379     sshprintf(("%s: ssh: conn_get: creating new handle\n", debug_prefix_time(NULL)));
380     /*
381      * We can't be creating a new handle if we are the client
382      */
383     assert(accept_fn == NULL);
384     rc = alloc(sizeof(*rc));
385     rc->read = rc->write = -1;
386     rc->pid = -1;
387     rc->readbuf.left = 0;
388     rc->readbuf.size = 0;
389     rc->ev_read = NULL;
390     strncpy(rc->hostname, hostname, sizeof(rc->hostname) - 1);
391     rc->hostname[sizeof(rc->hostname) - 1] = '\0';
392     rc->errmsg = NULL;
393     rc->refcnt = 1;
394     rc->handle = -1;
395     connq_append(rc);
396     return (rc);
397 }
398
399 /*
400  * Delete a reference to a connection, and close it if it is the last
401  * reference.
402  */
403 static void
404 conn_put(rc)
405     struct ssh_conn *rc;
406 {
407     amwait_t status;
408
409     assert(rc->refcnt > 0);
410     --rc->refcnt;
411     sshprintf(("%s: ssh: conn_put: decrementing refcnt for %s to %d\n", debug_prefix_time(NULL),
412         rc->hostname, rc->refcnt));
413     if (rc->refcnt > 0) {
414         return;
415     }
416     sshprintf(("%s: ssh: conn_put: closing connection to %s\n", debug_prefix_time(NULL), rc->hostname));
417     if (rc->read != -1)
418         aclose(rc->read);
419     if (rc->write != -1)
420         aclose(rc->write);
421     if (rc->pid != -1) {
422         waitpid(rc->pid, &status, WNOHANG);
423     }
424     if (rc->ev_read != NULL)
425         event_release(rc->ev_read);
426     if (rc->errmsg != NULL)
427         amfree(rc->errmsg);
428     connq_remove(rc);
429     amfree(rc);
430 }
431
432 /*
433  * Turn on read events for a conn.  Or, increase a ev_read_refcnt if we are
434  * already receiving read events.
435  */
436 static void
437 conn_read(rc)
438     struct ssh_conn *rc;
439 {
440
441     if (rc->ev_read != NULL) {
442         rc->ev_read_refcnt++;
443         sshprintf(("%s: ssh: conn_read: incremented ev_read_refcnt to %d for %s\n", debug_prefix_time(NULL),
444             rc->ev_read_refcnt, rc->hostname));
445         return;
446     }
447     sshprintf(("%s: ssh: conn_read registering event handler for %s\n", debug_prefix_time(NULL),
448         rc->hostname));
449     rc->ev_read = event_register(rc->read, EV_READFD, conn_read_callback, rc);
450     rc->ev_read_refcnt = 1;
451 }
452
453 static void
454 conn_read_cancel(rc)
455     struct ssh_conn *rc;
456 {
457
458     --rc->ev_read_refcnt;
459     sshprintf(("%s: ssh: conn_read_cancel: decremented ev_read_refcnt to %d for %s\n", debug_prefix_time(NULL),
460         rc->ev_read_refcnt, rc->hostname));
461     if(rc->ev_read_refcnt > 0) {
462         return;
463     }
464     sshprintf(("%s: ssh: conn_read_cancel: releasing event handler for %s\n", debug_prefix_time(NULL),
465         rc->hostname));
466     event_release(rc->ev_read);
467     rc->ev_read = NULL;
468 }
469
470 /*
471  * frees a handle allocated by the above
472  */
473 static void
474 ssh_close(inst)
475     void *inst;
476 {
477     struct ssh_handle *rh = inst;
478
479     assert(rh != NULL);
480
481     sshprintf(("%s: ssh: closing handle to %s\n", debug_prefix_time(NULL), rh->hostname));
482
483     if (rh->rs != NULL) {
484         /* This may be null if we get here on an error */
485         ssh_recvpkt_cancel(rh);
486         security_stream_close(&rh->rs->secstr);
487     }
488     /* keep us from getting here again */
489     rh->sech.driver = NULL;
490     amfree(rh);
491 }
492
493 /*
494  * Forks a ssh to the host listed in rc->hostname
495  * Returns negative on error, with an errmsg in rc->errmsg.
496  */
497 static int
498 runssh(rc)
499     struct ssh_conn *rc;
500 {
501     int rpipe[2], wpipe[2];
502     char *amandad_path;
503
504     if (pipe(rpipe) < 0 || pipe(wpipe) < 0) {
505         rc->errmsg = newvstralloc("pipe: ", strerror(errno), NULL);
506         return (-1);
507     }
508     switch (rc->pid = fork()) {
509     case -1:
510         rc->errmsg = newvstralloc("fork: ", strerror(errno), NULL);
511         aclose(rpipe[0]);
512         aclose(rpipe[1]);
513         aclose(wpipe[0]);
514         aclose(wpipe[1]);
515         return (-1);
516     case 0:
517         dup2(wpipe[0], 0);
518         dup2(rpipe[1], 1);
519         dup2(rpipe[1], 2);
520         break;
521     default:
522         rc->read = rpipe[0];
523         aclose(rpipe[1]);
524         rc->write = wpipe[1];
525         aclose(wpipe[0]);
526         return (0);
527     }
528
529     safe_fd(-1, 0);
530
531     amandad_path = vstralloc(libexecdir, "/", "amandad", versionsuffix(),
532         NULL);
533     execlp(SSH_PATH, SSH_PATH, SSH_ARGS, rc->hostname, amandad_path,
534         "-auth=ssh", NULL);
535     error("error: couldn't exec %s: %s", SSH_PATH, strerror(errno));
536
537     /* should nerver go here, shut up compiler warning */
538     return(-1);
539 }
540
541 /*
542  * Transmit a packet.
543  */
544 static int
545 ssh_sendpkt(cookie, pkt)
546     void *cookie;
547     pkt_t *pkt;
548 {
549     char buf[sizeof(pkt_t)];
550     struct ssh_handle *rh = cookie;
551     size_t len;
552
553     assert(rh != NULL);
554     assert(pkt != NULL);
555
556     sshprintf(("%s: ssh: sendpkt: enter\n", debug_prefix_time(NULL)));
557
558     len = strlen(pkt->body) + 2;
559     buf[0] = (char)pkt->type;
560     strcpy(&buf[1], pkt->body);
561
562     sshprintf(("%s: ssh: sendpkt: %s (%d) pkt_t (len %d) contains:\n\n\"%s\"\n\n", debug_prefix_time(NULL),
563         pkt_type2str(pkt->type), pkt->type, strlen(pkt->body), pkt->body));
564
565     if (ssh_stream_write(rh->rs, buf, len) < 0) {
566         security_seterror(&rh->sech, security_stream_geterror(&rh->rs->secstr));
567         return (-1);
568     }
569     return (0);
570 }
571
572 /*
573  * Set up to receive a packet asyncronously, and call back when
574  * it has been read.
575  */
576 static void
577 ssh_recvpkt(cookie, fn, arg, timeout)
578     void *cookie, *arg;
579     void (*fn) P((void *, pkt_t *, security_status_t));
580     int timeout;
581 {
582     struct ssh_handle *rh = cookie;
583
584     assert(rh != NULL);
585
586     sshprintf(("%s: ssh: recvpkt registered for %s\n", debug_prefix_time(NULL), rh->hostname));
587
588     /*
589      * Reset any pending timeout on this handle
590      */
591     if (rh->ev_timeout != NULL)
592         event_release(rh->ev_timeout);
593
594     /*
595      * Negative timeouts mean no timeout
596      */
597     if (timeout < 0)
598         rh->ev_timeout = NULL;
599     else
600         rh->ev_timeout = event_register(timeout, EV_TIME, recvpkt_timeout, rh);
601
602     rh->fn.recvpkt = fn;
603     rh->arg = arg;
604     ssh_stream_read(rh->rs, recvpkt_callback, rh);
605 }
606
607 /*
608  * Remove a async receive request from the queue
609  */
610 static void
611 ssh_recvpkt_cancel(cookie)
612     void *cookie;
613 {
614     struct ssh_handle *rh = cookie;
615
616     sshprintf(("%s: ssh: cancelling recvpkt for %s\n", debug_prefix_time(NULL), rh->hostname));
617
618     assert(rh != NULL);
619
620     ssh_stream_read_cancel(rh->rs);
621     if (rh->ev_timeout != NULL) {
622         event_release(rh->ev_timeout);
623         rh->ev_timeout = NULL;
624     }
625 }
626
627 /*
628  * This is called when a handle is woken up because data read off of the
629  * net is for it.
630  */
631 static void
632 recvpkt_callback(cookie, buf, bufsize)
633     void *cookie, *buf;
634     ssize_t bufsize;
635 {
636     pkt_t pkt;
637     struct ssh_handle *rh = cookie;
638
639     assert(rh != NULL);
640
641     /*
642      * We need to cancel the recvpkt request before calling
643      * the callback because the callback may reschedule us.
644      */
645     ssh_recvpkt_cancel(rh);
646
647     switch (bufsize) {
648     case 0:
649         security_seterror(&rh->sech,
650             "EOF on read from %s", rh->hostname);
651         (*rh->fn.recvpkt)(rh->arg, NULL, S_ERROR);
652         return;
653     case -1:
654         security_seterror(&rh->sech, security_stream_geterror(&rh->rs->secstr));
655         (*rh->fn.recvpkt)(rh->arg, NULL, S_ERROR);
656         return;
657     default:
658         break;
659     }
660
661     parse_pkt(&pkt, buf, bufsize);
662     sshprintf(("%s: ssh: received %s packet (%d) from %s, contains:\n\n\"%s\"\n\n", debug_prefix_time(NULL),
663         pkt_type2str(pkt.type), pkt.type, rh->hostname, pkt.body));
664     (*rh->fn.recvpkt)(rh->arg, &pkt, S_OK);
665 }
666
667 /*
668  * This is called when a handle times out before receiving a packet.
669  */
670 static void
671 recvpkt_timeout(cookie)
672     void *cookie;
673 {
674     struct ssh_handle *rh = cookie;
675
676     assert(rh != NULL);
677
678     sshprintf(("%s: ssh: recvpkt timeout for %s\n", debug_prefix_time(NULL), rh->hostname));
679
680     ssh_recvpkt_cancel(rh);
681     (*rh->fn.recvpkt)(rh->arg, NULL, S_TIMEOUT);
682 }
683
684 /*
685  * Create the server end of a stream.  For ssh, this means setup a stream
686  * object and allocate a new handle for it.
687  */
688 static void *
689 ssh_stream_server(h)
690     void *h;
691 {
692     struct ssh_handle *rh = h;
693     struct ssh_stream *rs;
694
695     assert(rh != NULL);
696
697     rs = alloc(sizeof(*rs));
698     security_streaminit(&rs->secstr, &ssh_security_driver);
699     rs->rc = conn_get(rh->hostname);
700     /*
701      * Stream should already be setup!
702      */
703     if (rs->rc->read < 0) {
704         conn_put(rs->rc);
705         amfree(rs);
706         security_seterror(&rh->sech, "lost connection to %s", rh->hostname);
707         return (NULL);
708     }
709     rh->hostname = rs->rc->hostname;
710     /*
711      * so as not to conflict with the amanda server's handle numbers,
712      * we start at 5000 and work down
713      */
714     rs->handle = 5000 - newhandle++;
715     rs->ev_read = NULL;
716     sshprintf(("%s: ssh: stream_server: created stream %d\n", debug_prefix_time(NULL), rs->handle));
717     return (rs);
718 }
719
720 /*
721  * Accept an incoming connection on a stream_server socket
722  * Nothing needed for ssh.
723  */
724 static int
725 ssh_stream_accept(s)
726     void *s;
727 {
728
729     return (0);
730 }
731
732 /*
733  * Return a connected stream.  For ssh, this means setup a stream
734  * with the supplied handle.
735  */
736 static void *
737 ssh_stream_client(h, id)
738     void *h;
739     int id;
740 {
741     struct ssh_handle *rh = h;
742     struct ssh_stream *rs;
743
744     assert(rh != NULL);
745
746     if (id <= 0) {
747         security_seterror(&rh->sech,
748             "%d: invalid security stream id", id);
749         return (NULL);
750     }
751
752     rs = alloc(sizeof(*rs));
753     security_streaminit(&rs->secstr, &ssh_security_driver);
754     rs->handle = id;
755     rs->ev_read = NULL;
756     rs->rc = conn_get(rh->hostname);
757
758     sshprintf(("%s: ssh: stream_client: connected to stream %d\n", debug_prefix_time(NULL), id));
759
760     return (rs);
761 }
762
763 /*
764  * Close and unallocate resources for a stream.
765  */
766 static void
767 ssh_stream_close(s)
768     void *s;
769 {
770     struct ssh_stream *rs = s;
771
772     assert(rs != NULL);
773
774     sshprintf(("%s: ssh: stream_close: closing stream %d\n", debug_prefix_time(NULL), rs->handle));
775
776     ssh_stream_read_cancel(rs);
777     conn_put(rs->rc);
778     amfree(rs);
779 }
780
781 /*
782  * Authenticate a stream
783  * Nothing needed for ssh.  The connection is authenticated by sshd
784  * on startup.
785  */
786 static int
787 ssh_stream_auth(s)
788     void *s;
789 {
790
791     return (0);
792 }
793
794 /*
795  * Returns the stream id for this stream.  This is just the local
796  * port.
797  */
798 static int
799 ssh_stream_id(s)
800     void *s;
801 {
802     struct ssh_stream *rs = s;
803
804     assert(rs != NULL);
805
806     return (rs->handle);
807 }
808
809 /*
810  * Write a chunk of data to a stream.  Blocks until completion.
811  */
812 static int
813 ssh_stream_write(s, buf, size)
814     void *s;
815     const void *buf;
816     size_t size;
817 {
818     struct ssh_stream *rs = s;
819
820     assert(rs != NULL);
821
822     sshprintf(("%s: ssh: stream_write: writing %d bytes to %s:%d\n", debug_prefix_time(NULL), size,
823         rs->rc->hostname, rs->handle));
824
825     if (send_token(rs->rc, rs->handle, buf, size) < 0) {
826         security_stream_seterror(&rs->secstr, rs->rc->errmsg);
827         return (-1);
828     }
829     return (0);
830 }
831
832 /*
833  * Submit a request to read some data.  Calls back with the given
834  * function and arg when completed.
835  */
836 static void
837 ssh_stream_read(s, fn, arg)
838     void *s, *arg;
839     void (*fn) P((void *, void *, ssize_t));
840 {
841     struct ssh_stream *rs = s;
842
843     assert(rs != NULL);
844
845     /*
846      * Only one read request can be active per stream.
847      */
848     if (rs->ev_read == NULL) {
849         rs->ev_read = event_register((event_id_t)rs->rc, EV_WAIT,
850             stream_read_callback, rs);
851         conn_read(rs->rc);
852     }
853     rs->fn = fn;
854     rs->arg = arg;
855 }
856
857 /*
858  * Cancel a previous stream read request.  It's ok if we didn't have a read
859  * scheduled.
860  */
861 static void
862 ssh_stream_read_cancel(s)
863     void *s;
864 {
865     struct ssh_stream *rs = s;
866
867     assert(rs != NULL);
868
869     if (rs->ev_read != NULL) {
870         event_release(rs->ev_read);
871         rs->ev_read = NULL;
872         conn_read_cancel(rs->rc);
873     }
874 }
875
876 /*
877  * Callback for ssh_stream_read
878  */
879 static void
880 stream_read_callback(arg)
881     void *arg;
882 {
883     struct ssh_stream *rs = arg;
884     assert(rs != NULL);
885
886     sshprintf(("%s: ssh: stream_read_callback: handle %d\n", debug_prefix_time(NULL), rs->handle));
887
888     /*
889      * Make sure this was for us.  If it was, then blow away the handle
890      * so it doesn't get claimed twice.  Otherwise, leave it alone.
891      *
892      * If the handle is EOF, pass that up to our callback.
893      */
894     if (rs->rc->handle == rs->handle) {
895         sshprintf(("%s: ssh: stream_read_callback: it was for us\n", debug_prefix_time(NULL)));
896         rs->rc->handle = H_TAKEN;
897     } else if (rs->rc->handle != H_EOF) {
898         sshprintf(("%s: ssh: stream_read_callback: not for us\n", debug_prefix_time(NULL)));
899         return;
900     }
901
902     /*
903      * Remove the event first, and then call the callback.
904      * We remove it first because we don't want to get in their
905      * way if they reschedule it.
906      */
907     ssh_stream_read_cancel(rs);
908
909     if (rs->rc->pktlen == 0) {
910         sshprintf(("%s: ssh: stream_read_callback: EOF\n", debug_prefix_time(NULL)));
911         (*rs->fn)(rs->arg, NULL, 0);
912         return;
913     }
914     sshprintf(("%s: ssh: stream_read_callback: read %ld bytes from %s:%d\n", debug_prefix_time(NULL),
915         rs->rc->pktlen, rs->rc->hostname, rs->handle));
916     (*rs->fn)(rs->arg, rs->rc->pkt, rs->rc->pktlen);
917 }
918
919 /*
920  * The callback for the netfd for the event handler
921  * Determines if this packet is for this security handle,
922  * and does the real callback if so.
923  */
924 static void
925 conn_read_callback(cookie)
926     void *cookie;
927 {
928     struct ssh_conn *rc = cookie;
929     struct ssh_handle *rh;
930     pkt_t pkt;
931     int rval;
932
933     assert(cookie != NULL);
934
935     sshprintf(("%s: ssh: conn_read_callback\n",debug_prefix_time(NULL)));
936
937     /* Read the data off the wire.  If we get errors, shut down. */
938     rval = recv_token(rc, 60);
939     sshprintf(("%s: ssh: conn_read_callback: recv_token returned %d\n", debug_prefix_time(NULL), rval));
940     if (rval <= 0) {
941         rc->pktlen = 0;
942         rc->handle = H_EOF;
943         rval = event_wakeup((event_id_t)rc);
944         sshprintf(("%s: ssh: conn_read_callback: event_wakeup return %d\n", debug_prefix_time(NULL), rval));
945         /* delete our 'accept' reference */
946         if (accept_fn != NULL)
947             conn_put(rc);
948         accept_fn = NULL;
949         return;
950     }
951
952     /* If there are events waiting on this handle, we're done */
953     rval = event_wakeup((event_id_t)rc);
954     sshprintf(("%s: ssh: conn_read_callback: event_wakeup return %d\n", debug_prefix_time(NULL), rval));
955     if (rval > 0)
956         return;
957
958     /* If there is no accept fn registered, then drop the packet */
959     if (accept_fn == NULL)
960         return;
961
962     rh = alloc(sizeof(*rh));
963     security_handleinit(&rh->sech, &ssh_security_driver);
964     rh->hostname = rc->hostname;
965     rh->rs = ssh_stream_client(rh, rc->handle);
966     rh->ev_timeout = NULL;
967
968     sshprintf(("%s: ssh: new connection\n", debug_prefix_time(NULL)));
969     parse_pkt(&pkt, rc->pkt, rc->pktlen);
970     sshprintf(("%s: ssh: calling accept_fn\n", debug_prefix_time(NULL)));
971     (*accept_fn)(&rh->sech, &pkt);
972 }
973
974 static void
975 parse_pkt(pkt, buf, bufsize)
976     pkt_t *pkt;
977     const void *buf;
978     size_t bufsize;
979 {
980     const unsigned char *bufp = buf;
981
982     sshprintf(("%s: ssh: parse_pkt: parsing buffer of %d bytes\n", debug_prefix_time(NULL), bufsize));
983
984     pkt->type = (pktype_t)*bufp++;
985     bufsize--;
986
987     if (bufsize == 0) {
988         pkt->body[0] = '\0';
989     } else {
990         if (bufsize > sizeof(pkt->body) - 1)
991             bufsize = sizeof(pkt->body) - 1;
992         memcpy(pkt->body, bufp, bufsize);
993         pkt->body[sizeof(pkt->body) - 1] = '\0';
994     }
995
996     sshprintf(("%s: ssh: parse_pkt: %s (%d): \"%s\"\n", debug_prefix_time(NULL), pkt_type2str(pkt->type),
997         pkt->type, pkt->body));
998 }
999
1000
1001 /*
1002  * Transmits a chunk of data over a ssh_handle, adding
1003  * the necessary headers to allow the remote end to decode it.
1004  */
1005 static int
1006 send_token(rc, handle, buf, len)
1007     struct ssh_conn *rc;
1008     int handle;
1009     const void *buf;
1010     size_t len;
1011 {
1012     unsigned int netlength, nethandle;
1013     struct iovec iov[3];
1014
1015     sshprintf(("%s: ssh: send_token: handle %d writing %d bytes to %s\n", debug_prefix_time(NULL), handle, len,
1016         rc->hostname));
1017
1018     assert(sizeof(netlength) == 4);
1019
1020     /*
1021      * Format is:
1022      *   32 bit length (network byte order)
1023      *   32 bit handle (network byte order)
1024      *   data
1025      */
1026     netlength = htonl(len);
1027     iov[0].iov_base = (void *)&netlength;
1028     iov[0].iov_len = sizeof(netlength);
1029
1030     nethandle = htonl(handle);
1031     iov[1].iov_base = (void *)&nethandle;
1032     iov[1].iov_len = sizeof(nethandle);
1033
1034     iov[2].iov_base = (void *)buf;
1035     iov[2].iov_len = len;
1036
1037     if (net_writev(rc->write, iov, 3) < 0) {
1038         rc->errmsg = newvstralloc(rc->errmsg, "ssh write error to ",
1039             rc->hostname, ": ", strerror(errno), NULL);
1040         return (-1);
1041     }
1042     return (0);
1043 }
1044
1045 static int
1046 recv_token(rc, timeout)
1047     struct ssh_conn *rc;
1048     int timeout;
1049 {
1050     unsigned int netint;
1051
1052     assert(sizeof(netint) == 4);
1053
1054     assert(rc->read >= 0);
1055
1056     sshprintf(("%s: ssh: recv_token: reading from %s\n", debug_prefix_time(NULL), rc->hostname));
1057
1058     switch (net_read(rc, &netint, sizeof(netint), timeout)) {
1059     case -1:
1060         rc->errmsg = newvstralloc(rc->errmsg, "recv error: ", strerror(errno),
1061             NULL);
1062         sshprintf(("%s: ssh: recv_token: A return(-1)\n", debug_prefix_time(NULL)));
1063         return (-1);
1064     case 0:
1065         rc->pktlen = 0;
1066         sshprintf(("%s: ssh: recv_token: A return(0)\n", debug_prefix_time(NULL)));
1067         return (0);
1068     default:
1069         break;
1070     }
1071     rc->pktlen = ntohl(netint);
1072     if (rc->pktlen > sizeof(rc->pkt)) {
1073         rc->errmsg = newstralloc(rc->errmsg, "recv error: huge packet");
1074         sshprintf(("%s: ssh: recv_token: B return(-1)\n", debug_prefix_time(NULL)));
1075         return (-1);
1076     }
1077
1078     switch (net_read(rc, &netint, sizeof(netint), timeout)) {
1079     case -1:
1080         rc->errmsg = newvstralloc(rc->errmsg, "recv error: ", strerror(errno),
1081             NULL);
1082         sshprintf(("%s: ssh: recv_token: C return(-1)\n", debug_prefix_time(NULL)));
1083         return (-1);
1084     case 0:
1085         rc->pktlen = 0;
1086         sshprintf(("%s: ssh: recv_token: D return(0)\n", debug_prefix_time(NULL)));
1087         return (0);
1088     default:
1089         break;
1090     }
1091     rc->handle = ntohl(netint);
1092
1093     switch (net_read(rc, rc->pkt, rc->pktlen, timeout)) {
1094     case -1:
1095         rc->errmsg = newvstralloc(rc->errmsg, "recv error: ", strerror(errno),
1096             NULL);
1097         sshprintf(("%s: ssh: recv_token: E return(-1)\n", debug_prefix_time(NULL)));
1098         return (-1);
1099     case 0:
1100         rc->pktlen = 0;
1101         break;
1102     default:
1103         break;
1104     }
1105
1106     sshprintf(("%s: ssh: recv_token: read %ld bytes from %s\n", debug_prefix_time(NULL), rc->pktlen,
1107         rc->hostname));
1108     sshprintf(("%s: ssh: recv_token: end %d\n", debug_prefix_time(NULL),rc->pktlen));
1109     return (rc->pktlen);
1110 }
1111
1112 /*
1113  * Writes out the entire iovec
1114  */
1115 static int
1116 net_writev(fd, iov, iovcnt)
1117     int fd, iovcnt;
1118     struct iovec *iov;
1119 {
1120     int delta, n, total;
1121
1122     assert(iov != NULL);
1123
1124     total = 0;
1125     while (iovcnt > 0) {
1126         /*
1127          * Write the iovec
1128          */
1129         total += n = writev(fd, iov, iovcnt);
1130         if (n < 0)
1131             return (-1);
1132         if (n == 0) {
1133             errno = EIO;
1134             return (-1);
1135         }
1136         /*
1137          * Iterate through each iov.  Figure out what we still need
1138          * to write out.
1139          */
1140         for (; n > 0; iovcnt--, iov++) {
1141             /* 'delta' is the bytes written from this iovec */
1142             delta = n < iov->iov_len ? n : iov->iov_len;
1143             /* subtract from the total num bytes written */
1144             n -= delta;
1145             assert(n >= 0);
1146             /* subtract from this iovec */
1147             iov->iov_len -= delta;
1148             iov->iov_base = (char *)iov->iov_base + delta;
1149             /* if this iovec isn't empty, run the writev again */
1150             if (iov->iov_len > 0)
1151                 break;
1152         }
1153     }
1154     return (total);
1155 }
1156
1157 /*
1158  * Like read(), but waits until the entire buffer has been filled.
1159  */
1160 static ssize_t
1161 net_read(rc, vbuf, origsize, timeout)
1162     struct ssh_conn *rc;
1163     void *vbuf;
1164     size_t origsize;
1165     int timeout;
1166 {
1167     char *buf = vbuf, *off;     /* ptr arith */
1168     int nread;
1169     size_t size = origsize;
1170
1171     sshprintf(("%s: ssh: net_read: begin %d\n", debug_prefix_time(NULL), origsize));
1172     while (size > 0) {
1173         sshprintf(("%s: ssh: net_read: while %d\n", debug_prefix_time(NULL), size));
1174         if (rc->readbuf.left == 0) {
1175             if (net_read_fillbuf(rc, timeout, size) < 0) {
1176                 sshprintf(("%s: ssh: net_read: end retrun(-1)\n", debug_prefix_time(NULL)));
1177                 return (-1);
1178             }
1179             if (rc->readbuf.size == 0) {
1180                 sshprintf(("%s: ssh: net_read: end retrun(0)\n", debug_prefix_time(NULL)));
1181                 return (0);
1182             }
1183         }
1184         nread = min(rc->readbuf.left, size);
1185         off = rc->readbuf.buf + rc->readbuf.size - rc->readbuf.left;
1186         memcpy(buf, off, nread);
1187
1188         buf += nread;
1189         size -= nread;
1190         rc->readbuf.left -= nread;
1191     }
1192     sshprintf(("%s: ssh: net_read: end %d\n", debug_prefix_time(NULL), origsize));
1193     return ((ssize_t)origsize);
1194 }
1195
1196 /*
1197  * net_read likes to do a lot of little reads.  Buffer it.
1198  */
1199 static int
1200 net_read_fillbuf(rc, timeout, size)
1201     struct ssh_conn *rc;
1202     int timeout;
1203     int size;
1204 {
1205     fd_set readfds;
1206     struct timeval tv;
1207     if(size > sizeof(rc->readbuf.buf)) size = sizeof(rc->readbuf.buf);
1208
1209     sshprintf(("%s: ssh: net_read_fillbuf: begin\n", debug_prefix_time(NULL)));
1210     FD_ZERO(&readfds);
1211     FD_SET(rc->read, &readfds);
1212     tv.tv_sec = timeout;
1213     tv.tv_usec = 0;
1214     switch (select(rc->read + 1, &readfds, NULL, NULL, &tv)) {
1215     case 0:
1216         errno = ETIMEDOUT;
1217         /* FALLTHROUGH */
1218     case -1:
1219         sshprintf(("%s: ssh: net_read_fillbuf: case -1\n", debug_prefix_time(NULL)));
1220         return (-1);
1221     case 1:
1222         sshprintf(("%s: ssh: net_read_fillbuf: case 1\n", debug_prefix_time(NULL)));
1223         assert(FD_ISSET(rc->read, &readfds));
1224         break;
1225     default:
1226         sshprintf(("%s: ssh: net_read_fillbuf: case default\n", debug_prefix_time(NULL)));
1227         assert(0);
1228         break;
1229     }
1230     rc->readbuf.left = 0;
1231     rc->readbuf.size = read(rc->read, rc->readbuf.buf, size);
1232     if (rc->readbuf.size < 0)
1233         return (-1);
1234     rc->readbuf.left = rc->readbuf.size;
1235     sshprintf(("%s: ssh: net_read_fillbuf: end %d\n", debug_prefix_time(NULL),rc->readbuf.size));
1236     return (0);
1237 }
1238
1239 #endif  /* SSH_SECURITY */