xref: /dragonfly/crypto/openssh/kex.c (revision ba1276acd1c8c22d225b1bcf370a14c878644f44)
1 /* $OpenBSD: kex.c,v 1.186 2024/05/17 00:30:23 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 #include "includes.h"
27 
28 #include <sys/types.h>
29 #include <errno.h>
30 #include <signal.h>
31 #include <stdarg.h>
32 #include <stdio.h>
33 #include <stdlib.h>
34 #include <string.h>
35 #include <unistd.h>
36 #ifdef HAVE_POLL_H
37 #include <poll.h>
38 #endif
39 
40 #ifdef WITH_OPENSSL
41 #include <openssl/crypto.h>
42 #include <openssl/dh.h>
43 #endif
44 
45 #include "ssh.h"
46 #include "ssh2.h"
47 #include "atomicio.h"
48 #include "version.h"
49 #include "packet.h"
50 #include "compat.h"
51 #include "cipher.h"
52 #include "sshkey.h"
53 #include "kex.h"
54 #include "log.h"
55 #include "mac.h"
56 #include "match.h"
57 #include "misc.h"
58 #include "dispatch.h"
59 #include "monitor.h"
60 #include "myproposal.h"
61 
62 #include "ssherr.h"
63 #include "sshbuf.h"
64 #include "digest.h"
65 #include "xmalloc.h"
66 
67 /* prototype */
68 static int kex_choose_conf(struct ssh *, uint32_t seq);
69 static int kex_input_newkeys(int, u_int32_t, struct ssh *);
70 
71 static const char * const proposal_names[PROPOSAL_MAX] = {
72           "KEX algorithms",
73           "host key algorithms",
74           "ciphers ctos",
75           "ciphers stoc",
76           "MACs ctos",
77           "MACs stoc",
78           "compression ctos",
79           "compression stoc",
80           "languages ctos",
81           "languages stoc",
82 };
83 
84 /*
85  * Fill out a proposal array with dynamically allocated values, which may
86  * be modified as required for compatibility reasons.
87  * Any of the options may be NULL, in which case the default is used.
88  * Array contents must be freed by calling kex_proposal_free_entries.
89  */
90 void
kex_proposal_populate_entries(struct ssh * ssh,char * prop[PROPOSAL_MAX],const char * kexalgos,const char * ciphers,const char * macs,const char * comp,const char * hkalgs)91 kex_proposal_populate_entries(struct ssh *ssh, char *prop[PROPOSAL_MAX],
92     const char *kexalgos, const char *ciphers, const char *macs,
93     const char *comp, const char *hkalgs)
94 {
95           const char *defpropserver[PROPOSAL_MAX] = { KEX_SERVER };
96           const char *defpropclient[PROPOSAL_MAX] = { KEX_CLIENT };
97           const char **defprop = ssh->kex->server ? defpropserver : defpropclient;
98           u_int i;
99           char *cp;
100 
101           if (prop == NULL)
102                     fatal_f("proposal missing");
103 
104           /* Append EXT_INFO signalling to KexAlgorithms */
105           if (kexalgos == NULL)
106                     kexalgos = defprop[PROPOSAL_KEX_ALGS];
107           if ((cp = kex_names_cat(kexalgos, ssh->kex->server ?
108               "ext-info-s,kex-strict-s-v00@openssh.com" :
109               "ext-info-c,kex-strict-c-v00@openssh.com")) == NULL)
110                     fatal_f("kex_names_cat");
111 
112           for (i = 0; i < PROPOSAL_MAX; i++) {
113                     switch(i) {
114                     case PROPOSAL_KEX_ALGS:
115                               prop[i] = compat_kex_proposal(ssh, cp);
116                               break;
117                     case PROPOSAL_ENC_ALGS_CTOS:
118                     case PROPOSAL_ENC_ALGS_STOC:
119                               prop[i] = xstrdup(ciphers ? ciphers : defprop[i]);
120                               break;
121                     case PROPOSAL_MAC_ALGS_CTOS:
122                     case PROPOSAL_MAC_ALGS_STOC:
123                               prop[i]  = xstrdup(macs ? macs : defprop[i]);
124                               break;
125                     case PROPOSAL_COMP_ALGS_CTOS:
126                     case PROPOSAL_COMP_ALGS_STOC:
127                               prop[i] = xstrdup(comp ? comp : defprop[i]);
128                               break;
129                     case PROPOSAL_SERVER_HOST_KEY_ALGS:
130                               prop[i] = xstrdup(hkalgs ? hkalgs : defprop[i]);
131                               break;
132                     default:
133                               prop[i] = xstrdup(defprop[i]);
134                     }
135           }
136           free(cp);
137 }
138 
139 void
kex_proposal_free_entries(char * prop[PROPOSAL_MAX])140 kex_proposal_free_entries(char *prop[PROPOSAL_MAX])
141 {
142           u_int i;
143 
144           for (i = 0; i < PROPOSAL_MAX; i++)
145                     free(prop[i]);
146 }
147 
148 /* put algorithm proposal into buffer */
149 int
kex_prop2buf(struct sshbuf * b,char * proposal[PROPOSAL_MAX])150 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
151 {
152           u_int i;
153           int r;
154 
155           sshbuf_reset(b);
156 
157           /*
158            * add a dummy cookie, the cookie will be overwritten by
159            * kex_send_kexinit(), each time a kexinit is set
160            */
161           for (i = 0; i < KEX_COOKIE_LEN; i++) {
162                     if ((r = sshbuf_put_u8(b, 0)) != 0)
163                               return r;
164           }
165           for (i = 0; i < PROPOSAL_MAX; i++) {
166                     if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
167                               return r;
168           }
169           if ((r = sshbuf_put_u8(b, 0)) != 0 ||   /* first_kex_packet_follows */
170               (r = sshbuf_put_u32(b, 0)) != 0)    /* uint32 reserved */
171                     return r;
172           return 0;
173 }
174 
175 /* parse buffer and return algorithm proposal */
176 int
kex_buf2prop(struct sshbuf * raw,int * first_kex_follows,char *** propp)177 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
178 {
179           struct sshbuf *b = NULL;
180           u_char v;
181           u_int i;
182           char **proposal = NULL;
183           int r;
184 
185           *propp = NULL;
186           if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
187                     return SSH_ERR_ALLOC_FAIL;
188           if ((b = sshbuf_fromb(raw)) == NULL) {
189                     r = SSH_ERR_ALLOC_FAIL;
190                     goto out;
191           }
192           if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
193                     error_fr(r, "consume cookie");
194                     goto out;
195           }
196           /* extract kex init proposal strings */
197           for (i = 0; i < PROPOSAL_MAX; i++) {
198                     if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
199                               error_fr(r, "parse proposal %u", i);
200                               goto out;
201                     }
202                     debug2("%s: %s", proposal_names[i], proposal[i]);
203           }
204           /* first kex follows / reserved */
205           if ((r = sshbuf_get_u8(b, &v)) != 0 ||  /* first_kex_follows */
206               (r = sshbuf_get_u32(b, &i)) != 0) { /* reserved */
207                     error_fr(r, "parse");
208                     goto out;
209           }
210           if (first_kex_follows != NULL)
211                     *first_kex_follows = v;
212           debug2("first_kex_follows %d ", v);
213           debug2("reserved %u ", i);
214           r = 0;
215           *propp = proposal;
216  out:
217           if (r != 0 && proposal != NULL)
218                     kex_prop_free(proposal);
219           sshbuf_free(b);
220           return r;
221 }
222 
223 void
kex_prop_free(char ** proposal)224 kex_prop_free(char **proposal)
225 {
226           u_int i;
227 
228           if (proposal == NULL)
229                     return;
230           for (i = 0; i < PROPOSAL_MAX; i++)
231                     free(proposal[i]);
232           free(proposal);
233 }
234 
235 int
kex_protocol_error(int type,u_int32_t seq,struct ssh * ssh)236 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
237 {
238           int r;
239 
240           /* If in strict mode, any unexpected message is an error */
241           if ((ssh->kex->flags & KEX_INITIAL) && ssh->kex->kex_strict) {
242                     ssh_packet_disconnect(ssh, "strict KEX violation: "
243                         "unexpected packet type %u (seqnr %u)", type, seq);
244           }
245           error_f("type %u seq %u", type, seq);
246           if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
247               (r = sshpkt_put_u32(ssh, seq)) != 0 ||
248               (r = sshpkt_send(ssh)) != 0)
249                     return r;
250           return 0;
251 }
252 
253 static void
kex_reset_dispatch(struct ssh * ssh)254 kex_reset_dispatch(struct ssh *ssh)
255 {
256           ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
257               SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
258 }
259 
260 void
kex_set_server_sig_algs(struct ssh * ssh,const char * allowed_algs)261 kex_set_server_sig_algs(struct ssh *ssh, const char *allowed_algs)
262 {
263           char *alg, *oalgs, *algs, *sigalgs;
264           const char *sigalg;
265 
266           /*
267            * NB. allowed algorithms may contain certificate algorithms that
268            * map to a specific plain signature type, e.g.
269            * rsa-sha2-512-cert-v01@openssh.com => rsa-sha2-512
270            * We need to be careful here to match these, retain the mapping
271            * and only add each signature algorithm once.
272            */
273           if ((sigalgs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
274                     fatal_f("sshkey_alg_list failed");
275           oalgs = algs = xstrdup(allowed_algs);
276           free(ssh->kex->server_sig_algs);
277           ssh->kex->server_sig_algs = NULL;
278           for ((alg = strsep(&algs, ",")); alg != NULL && *alg != '\0';
279               (alg = strsep(&algs, ","))) {
280                     if ((sigalg = sshkey_sigalg_by_name(alg)) == NULL)
281                               continue;
282                     if (!kex_has_any_alg(sigalg, sigalgs))
283                               continue;
284                     /* Don't add an algorithm twice. */
285                     if (ssh->kex->server_sig_algs != NULL &&
286                         kex_has_any_alg(sigalg, ssh->kex->server_sig_algs))
287                               continue;
288                     xextendf(&ssh->kex->server_sig_algs, ",", "%s", sigalg);
289           }
290           free(oalgs);
291           free(sigalgs);
292           if (ssh->kex->server_sig_algs == NULL)
293                     ssh->kex->server_sig_algs = xstrdup("");
294 }
295 
296 static int
kex_compose_ext_info_server(struct ssh * ssh,struct sshbuf * m)297 kex_compose_ext_info_server(struct ssh *ssh, struct sshbuf *m)
298 {
299           int r;
300 
301           if (ssh->kex->server_sig_algs == NULL &&
302               (ssh->kex->server_sig_algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
303                     return SSH_ERR_ALLOC_FAIL;
304           if ((r = sshbuf_put_u32(m, 3)) != 0 ||
305               (r = sshbuf_put_cstring(m, "server-sig-algs")) != 0 ||
306               (r = sshbuf_put_cstring(m, ssh->kex->server_sig_algs)) != 0 ||
307               (r = sshbuf_put_cstring(m,
308               "publickey-hostbound@openssh.com")) != 0 ||
309               (r = sshbuf_put_cstring(m, "0")) != 0 ||
310               (r = sshbuf_put_cstring(m, "ping@openssh.com")) != 0 ||
311               (r = sshbuf_put_cstring(m, "0")) != 0) {
312                     error_fr(r, "compose");
313                     return r;
314           }
315           return 0;
316 }
317 
/*
 * Append the client's EXT_INFO payload to m: a count of one followed by
 * the ext-info-in-auth@openssh.com extension with value "0".
 * Returns 0 on success or the sshbuf error code. ssh is not used.
 */
static int
kex_compose_ext_info_client(struct ssh *ssh, struct sshbuf *m)
{
	int r;

	r = sshbuf_put_u32(m, 1);
	if (r == 0)
		r = sshbuf_put_cstring(m, "ext-info-in-auth@openssh.com");
	if (r == 0)
		r = sshbuf_put_cstring(m, "0");
	if (r != 0)
		error_fr(r, "compose");
	return r;
}
334 
335 static int
kex_maybe_send_ext_info(struct ssh * ssh)336 kex_maybe_send_ext_info(struct ssh *ssh)
337 {
338           int r;
339           struct sshbuf *m = NULL;
340 
341           if ((ssh->kex->flags & KEX_INITIAL) == 0)
342                     return 0;
343           if (!ssh->kex->ext_info_c && !ssh->kex->ext_info_s)
344                     return 0;
345 
346           /* Compose EXT_INFO packet. */
347           if ((m = sshbuf_new()) == NULL)
348                     fatal_f("sshbuf_new failed");
349           if (ssh->kex->ext_info_c &&
350               (r = kex_compose_ext_info_server(ssh, m)) != 0)
351                     goto fail;
352           if (ssh->kex->ext_info_s &&
353               (r = kex_compose_ext_info_client(ssh, m)) != 0)
354                     goto fail;
355 
356           /* Send the actual KEX_INFO packet */
357           debug("Sending SSH2_MSG_EXT_INFO");
358           if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
359               (r = sshpkt_putb(ssh, m)) != 0 ||
360               (r = sshpkt_send(ssh)) != 0) {
361                     error_f("send EXT_INFO");
362                     goto fail;
363           }
364 
365           r = 0;
366 
367  fail:
368           sshbuf_free(m);
369           return r;
370 }
371 
372 int
kex_server_update_ext_info(struct ssh * ssh)373 kex_server_update_ext_info(struct ssh *ssh)
374 {
375           int r;
376 
377           if ((ssh->kex->flags & KEX_HAS_EXT_INFO_IN_AUTH) == 0)
378                     return 0;
379 
380           debug_f("Sending SSH2_MSG_EXT_INFO");
381           if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
382               (r = sshpkt_put_u32(ssh, 1)) != 0 ||
383               (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
384               (r = sshpkt_put_cstring(ssh, ssh->kex->server_sig_algs)) != 0 ||
385               (r = sshpkt_send(ssh)) != 0) {
386                     error_f("send EXT_INFO");
387                     return r;
388           }
389           return 0;
390 }
391 
392 int
kex_send_newkeys(struct ssh * ssh)393 kex_send_newkeys(struct ssh *ssh)
394 {
395           int r;
396 
397           kex_reset_dispatch(ssh);
398           if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
399               (r = sshpkt_send(ssh)) != 0)
400                     return r;
401           debug("SSH2_MSG_NEWKEYS sent");
402           ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
403           if ((r = kex_maybe_send_ext_info(ssh)) != 0)
404                     return r;
405           debug("expecting SSH2_MSG_NEWKEYS");
406           return 0;
407 }
408 
409 /* Check whether an ext_info value contains the expected version string */
410 static int
kex_ext_info_check_ver(struct kex * kex,const char * name,const u_char * val,size_t len,const char * want_ver,u_int flag)411 kex_ext_info_check_ver(struct kex *kex, const char *name,
412     const u_char *val, size_t len, const char *want_ver, u_int flag)
413 {
414           if (memchr(val, '\0', len) != NULL) {
415                     error("SSH2_MSG_EXT_INFO: %s value contains nul byte", name);
416                     return SSH_ERR_INVALID_FORMAT;
417           }
418           debug_f("%s=<%s>", name, val);
419           if (strcmp(val, want_ver) == 0)
420                     kex->flags |= flag;
421           else
422                     debug_f("unsupported version of %s extension", name);
423           return 0;
424 }
425 
426 static int
kex_ext_info_client_parse(struct ssh * ssh,const char * name,const u_char * value,size_t vlen)427 kex_ext_info_client_parse(struct ssh *ssh, const char *name,
428     const u_char *value, size_t vlen)
429 {
430           int r;
431 
432           /* NB. some messages are only accepted in the initial EXT_INFO */
433           if (strcmp(name, "server-sig-algs") == 0) {
434                     /* Ensure no \0 lurking in value */
435                     if (memchr(value, '\0', vlen) != NULL) {
436                               error_f("nul byte in %s", name);
437                               return SSH_ERR_INVALID_FORMAT;
438                     }
439                     debug_f("%s=<%s>", name, value);
440                     free(ssh->kex->server_sig_algs);
441                     ssh->kex->server_sig_algs = xstrdup((const char *)value);
442           } else if (ssh->kex->ext_info_received == 1 &&
443               strcmp(name, "publickey-hostbound@openssh.com") == 0) {
444                     if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
445                         "0", KEX_HAS_PUBKEY_HOSTBOUND)) != 0) {
446                               return r;
447                     }
448           } else if (ssh->kex->ext_info_received == 1 &&
449               strcmp(name, "ping@openssh.com") == 0) {
450                     if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
451                         "0", KEX_HAS_PING)) != 0) {
452                               return r;
453                     }
454           } else
455                     debug_f("%s (unrecognised)", name);
456 
457           return 0;
458 }
459 
460 static int
kex_ext_info_server_parse(struct ssh * ssh,const char * name,const u_char * value,size_t vlen)461 kex_ext_info_server_parse(struct ssh *ssh, const char *name,
462     const u_char *value, size_t vlen)
463 {
464           int r;
465 
466           if (strcmp(name, "ext-info-in-auth@openssh.com") == 0) {
467                     if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
468                         "0", KEX_HAS_EXT_INFO_IN_AUTH)) != 0) {
469                               return r;
470                     }
471           } else
472                     debug_f("%s (unrecognised)", name);
473           return 0;
474 }
475 
476 int
kex_input_ext_info(int type,u_int32_t seq,struct ssh * ssh)477 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
478 {
479           struct kex *kex = ssh->kex;
480           const int max_ext_info = kex->server ? 1 : 2;
481           u_int32_t i, ninfo;
482           char *name;
483           u_char *val;
484           size_t vlen;
485           int r;
486 
487           debug("SSH2_MSG_EXT_INFO received");
488           if (++kex->ext_info_received > max_ext_info) {
489                     error("too many SSH2_MSG_EXT_INFO messages sent by peer");
490                     return dispatch_protocol_error(type, seq, ssh);
491           }
492           ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
493           if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
494                     return r;
495           if (ninfo >= 1024) {
496                     error("SSH2_MSG_EXT_INFO with too many entries, expected "
497                         "<=1024, received %u", ninfo);
498                     return dispatch_protocol_error(type, seq, ssh);
499           }
500           for (i = 0; i < ninfo; i++) {
501                     if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
502                               return r;
503                     if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
504                               free(name);
505                               return r;
506                     }
507                     debug3_f("extension %s", name);
508                     if (kex->server) {
509                               if ((r = kex_ext_info_server_parse(ssh, name,
510                                   val, vlen)) != 0)
511                                         return r;
512                     } else {
513                               if ((r = kex_ext_info_client_parse(ssh, name,
514                                   val, vlen)) != 0)
515                                         return r;
516                     }
517                     free(name);
518                     free(val);
519           }
520           return sshpkt_get_end(ssh);
521 }
522 
523 static int
kex_input_newkeys(int type,u_int32_t seq,struct ssh * ssh)524 kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
525 {
526           struct kex *kex = ssh->kex;
527           int r, initial = (kex->flags & KEX_INITIAL) != 0;
528           char *cp, **prop;
529 
530           debug("SSH2_MSG_NEWKEYS received");
531           if (kex->ext_info_c && initial)
532                     ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_input_ext_info);
533           ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
534           ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
535           if ((r = sshpkt_get_end(ssh)) != 0)
536                     return r;
537           if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
538                     return r;
539           if (initial) {
540                     /* Remove initial KEX signalling from proposal for rekeying */
541                     if ((r = kex_buf2prop(kex->my, NULL, &prop)) != 0)
542                               return r;
543                     if ((cp = match_filter_denylist(prop[PROPOSAL_KEX_ALGS],
544                         kex->server ?
545                         "ext-info-s,kex-strict-s-v00@openssh.com" :
546                         "ext-info-c,kex-strict-c-v00@openssh.com")) == NULL) {
547                               error_f("match_filter_denylist failed");
548                               goto fail;
549                     }
550                     free(prop[PROPOSAL_KEX_ALGS]);
551                     prop[PROPOSAL_KEX_ALGS] = cp;
552                     if ((r = kex_prop2buf(ssh->kex->my, prop)) != 0) {
553                               error_f("kex_prop2buf failed");
554  fail:
555                               kex_proposal_free_entries(prop);
556                               free(prop);
557                               return SSH_ERR_INTERNAL_ERROR;
558                     }
559                     kex_proposal_free_entries(prop);
560                     free(prop);
561           }
562           kex->done = 1;
563           kex->flags &= ~KEX_INITIAL;
564           sshbuf_reset(kex->peer);
565           kex->flags &= ~KEX_INIT_SENT;
566           free(kex->name);
567           kex->name = NULL;
568           return 0;
569 }
570 
571 int
kex_send_kexinit(struct ssh * ssh)572 kex_send_kexinit(struct ssh *ssh)
573 {
574           u_char *cookie;
575           struct kex *kex = ssh->kex;
576           int r;
577 
578           if (kex == NULL) {
579                     error_f("no kex");
580                     return SSH_ERR_INTERNAL_ERROR;
581           }
582           if (kex->flags & KEX_INIT_SENT)
583                     return 0;
584           kex->done = 0;
585 
586           /* generate a random cookie */
587           if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
588                     error_f("bad kex length: %zu < %d",
589                         sshbuf_len(kex->my), KEX_COOKIE_LEN);
590                     return SSH_ERR_INVALID_FORMAT;
591           }
592           if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
593                     error_f("buffer error");
594                     return SSH_ERR_INTERNAL_ERROR;
595           }
596           arc4random_buf(cookie, KEX_COOKIE_LEN);
597 
598           if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
599               (r = sshpkt_putb(ssh, kex->my)) != 0 ||
600               (r = sshpkt_send(ssh)) != 0) {
601                     error_fr(r, "compose reply");
602                     return r;
603           }
604           debug("SSH2_MSG_KEXINIT sent");
605           kex->flags |= KEX_INIT_SENT;
606           return 0;
607 }
608 
609 int
kex_input_kexinit(int type,u_int32_t seq,struct ssh * ssh)610 kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
611 {
612           struct kex *kex = ssh->kex;
613           const u_char *ptr;
614           u_int i;
615           size_t dlen;
616           int r;
617 
618           debug("SSH2_MSG_KEXINIT received");
619           if (kex == NULL) {
620                     error_f("no kex");
621                     return SSH_ERR_INTERNAL_ERROR;
622           }
623           ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_protocol_error);
624           ptr = sshpkt_ptr(ssh, &dlen);
625           if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
626                     return r;
627 
628           /* discard packet */
629           for (i = 0; i < KEX_COOKIE_LEN; i++) {
630                     if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
631                               error_fr(r, "discard cookie");
632                               return r;
633                     }
634           }
635           for (i = 0; i < PROPOSAL_MAX; i++) {
636                     if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
637                               error_fr(r, "discard proposal");
638                               return r;
639                     }
640           }
641           /*
642            * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
643            * KEX method has the server move first, but a server might be using
644            * a custom method or one that we otherwise don't support. We should
645            * be prepared to remember first_kex_follows here so we can eat a
646            * packet later.
647            * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
648            * for cases where the server *doesn't* go first. I guess we should
649            * ignore it when it is set for these cases, which is what we do now.
650            */
651           if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||        /* first_kex_follows */
652               (r = sshpkt_get_u32(ssh, NULL)) != 0 ||       /* reserved */
653               (r = sshpkt_get_end(ssh)) != 0)
654                               return r;
655 
656           if (!(kex->flags & KEX_INIT_SENT))
657                     if ((r = kex_send_kexinit(ssh)) != 0)
658                               return r;
659           if ((r = kex_choose_conf(ssh, seq)) != 0)
660                     return r;
661 
662           if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
663                     return (kex->kex[kex->kex_type])(ssh);
664 
665           error_f("unknown kex type %u", kex->kex_type);
666           return SSH_ERR_INTERNAL_ERROR;
667 }
668 
669 struct kex *
kex_new(void)670 kex_new(void)
671 {
672           struct kex *kex;
673 
674           if ((kex = calloc(1, sizeof(*kex))) == NULL ||
675               (kex->peer = sshbuf_new()) == NULL ||
676               (kex->my = sshbuf_new()) == NULL ||
677               (kex->client_version = sshbuf_new()) == NULL ||
678               (kex->server_version = sshbuf_new()) == NULL ||
679               (kex->session_id = sshbuf_new()) == NULL) {
680                     kex_free(kex);
681                     return NULL;
682           }
683           return kex;
684 }
685 
686 void
kex_free_newkeys(struct newkeys * newkeys)687 kex_free_newkeys(struct newkeys *newkeys)
688 {
689           if (newkeys == NULL)
690                     return;
691           if (newkeys->enc.key) {
692                     explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
693                     free(newkeys->enc.key);
694                     newkeys->enc.key = NULL;
695           }
696           if (newkeys->enc.iv) {
697                     explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
698                     free(newkeys->enc.iv);
699                     newkeys->enc.iv = NULL;
700           }
701           free(newkeys->enc.name);
702           explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
703           free(newkeys->comp.name);
704           explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
705           mac_clear(&newkeys->mac);
706           if (newkeys->mac.key) {
707                     explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
708                     free(newkeys->mac.key);
709                     newkeys->mac.key = NULL;
710           }
711           free(newkeys->mac.name);
712           explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
713           freezero(newkeys, sizeof(*newkeys));
714 }
715 
716 void
kex_free(struct kex * kex)717 kex_free(struct kex *kex)
718 {
719           u_int mode;
720 
721           if (kex == NULL)
722                     return;
723 
724 #ifdef WITH_OPENSSL
725           DH_free(kex->dh);
726 #ifdef OPENSSL_HAS_ECC
727           EC_KEY_free(kex->ec_client_key);
728 #endif /* OPENSSL_HAS_ECC */
729 #endif /* WITH_OPENSSL */
730           for (mode = 0; mode < MODE_MAX; mode++) {
731                     kex_free_newkeys(kex->newkeys[mode]);
732                     kex->newkeys[mode] = NULL;
733           }
734           sshbuf_free(kex->peer);
735           sshbuf_free(kex->my);
736           sshbuf_free(kex->client_version);
737           sshbuf_free(kex->server_version);
738           sshbuf_free(kex->client_pub);
739           sshbuf_free(kex->session_id);
740           sshbuf_free(kex->initial_sig);
741           sshkey_free(kex->initial_hostkey);
742           free(kex->failed_choice);
743           free(kex->hostkey_alg);
744           free(kex->name);
745           free(kex);
746 }
747 
748 int
kex_ready(struct ssh * ssh,char * proposal[PROPOSAL_MAX])749 kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
750 {
751           int r;
752 
753           if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
754                     return r;
755           ssh->kex->flags = KEX_INITIAL;
756           kex_reset_dispatch(ssh);
757           ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
758           return 0;
759 }
760 
761 int
kex_setup(struct ssh * ssh,char * proposal[PROPOSAL_MAX])762 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
763 {
764           int r;
765 
766           if ((r = kex_ready(ssh, proposal)) != 0)
767                     return r;
768           if ((r = kex_send_kexinit(ssh)) != 0) {           /* we start */
769                     kex_free(ssh->kex);
770                     ssh->kex = NULL;
771                     return r;
772           }
773           return 0;
774 }
775 
776 /*
777  * Request key re-exchange, returns 0 on success or a ssherr.h error
778  * code otherwise. Must not be called if KEX is incomplete or in-progress.
779  */
780 int
kex_start_rekex(struct ssh * ssh)781 kex_start_rekex(struct ssh *ssh)
782 {
783           if (ssh->kex == NULL) {
784                     error_f("no kex");
785                     return SSH_ERR_INTERNAL_ERROR;
786           }
787           if (ssh->kex->done == 0) {
788                     error_f("requested twice");
789                     return SSH_ERR_INTERNAL_ERROR;
790           }
791           ssh->kex->done = 0;
792           return kex_send_kexinit(ssh);
793 }
794 
795 static int
choose_enc(struct sshenc * enc,char * client,char * server)796 choose_enc(struct sshenc *enc, char *client, char *server)
797 {
798           char *name = match_list(client, server, NULL);
799 
800           if (name == NULL)
801                     return SSH_ERR_NO_CIPHER_ALG_MATCH;
802           if ((enc->cipher = cipher_by_name(name)) == NULL) {
803                     error_f("unsupported cipher %s", name);
804                     free(name);
805                     return SSH_ERR_INTERNAL_ERROR;
806           }
807           enc->name = name;
808           enc->enabled = 0;
809           enc->iv = NULL;
810           enc->iv_len = cipher_ivlen(enc->cipher);
811           enc->key = NULL;
812           enc->key_len = cipher_keylen(enc->cipher);
813           enc->block_size = cipher_blocksize(enc->cipher);
814           return 0;
815 }
816 
817 static int
choose_mac(struct ssh * ssh,struct sshmac * mac,char * client,char * server)818 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
819 {
820           char *name = match_list(client, server, NULL);
821 
822           if (name == NULL)
823                     return SSH_ERR_NO_MAC_ALG_MATCH;
824           if (mac_setup(mac, name) < 0) {
825                     error_f("unsupported MAC %s", name);
826                     free(name);
827                     return SSH_ERR_INTERNAL_ERROR;
828           }
829           mac->name = name;
830           mac->key = NULL;
831           mac->enabled = 0;
832           return 0;
833 }
834 
835 static int
choose_comp(struct sshcomp * comp,char * client,char * server)836 choose_comp(struct sshcomp *comp, char *client, char *server)
837 {
838           char *name = match_list(client, server, NULL);
839 
840           if (name == NULL)
841                     return SSH_ERR_NO_COMPRESS_ALG_MATCH;
842 #ifdef WITH_ZLIB
843           if (strcmp(name, "zlib@openssh.com") == 0) {
844                     comp->type = COMP_DELAYED;
845           } else if (strcmp(name, "zlib") == 0) {
846                     comp->type = COMP_ZLIB;
847           } else
848 #endif    /* WITH_ZLIB */
849           if (strcmp(name, "none") == 0) {
850                     comp->type = COMP_NONE;
851           } else {
852                     error_f("unsupported compression scheme %s", name);
853                     free(name);
854                     return SSH_ERR_INTERNAL_ERROR;
855           }
856           comp->name = name;
857           return 0;
858 }
859 
860 static int
choose_kex(struct kex * k,char * client,char * server)861 choose_kex(struct kex *k, char *client, char *server)
862 {
863           k->name = match_list(client, server, NULL);
864 
865           debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
866           if (k->name == NULL)
867                     return SSH_ERR_NO_KEX_ALG_MATCH;
868           if (!kex_name_valid(k->name)) {
869                     error_f("unsupported KEX method %s", k->name);
870                     return SSH_ERR_INTERNAL_ERROR;
871           }
872           k->kex_type = kex_type_from_name(k->name);
873           k->hash_alg = kex_hash_from_name(k->name);
874           k->ec_nid = kex_nid_from_name(k->name);
875           return 0;
876 }
877 
878 static int
choose_hostkeyalg(struct kex * k,char * client,char * server)879 choose_hostkeyalg(struct kex *k, char *client, char *server)
880 {
881           free(k->hostkey_alg);
882           k->hostkey_alg = match_list(client, server, NULL);
883 
884           debug("kex: host key algorithm: %s",
885               k->hostkey_alg ? k->hostkey_alg : "(no match)");
886           if (k->hostkey_alg == NULL)
887                     return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
888           k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
889           if (k->hostkey_type == KEY_UNSPEC) {
890                     error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
891                     return SSH_ERR_INTERNAL_ERROR;
892           }
893           k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
894           return 0;
895 }
896 
897 static int
proposals_match(char * my[PROPOSAL_MAX],char * peer[PROPOSAL_MAX])898 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
899 {
900           static int check[] = {
901                     PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
902           };
903           int *idx;
904           char *p;
905 
906           for (idx = &check[0]; *idx != -1; idx++) {
907                     if ((p = strchr(my[*idx], ',')) != NULL)
908                               *p = '\0';
909                     if ((p = strchr(peer[*idx], ',')) != NULL)
910                               *p = '\0';
911                     if (strcmp(my[*idx], peer[*idx]) != 0) {
912                               debug2("proposal mismatch: my %s peer %s",
913                                   my[*idx], peer[*idx]);
914                               return (0);
915                     }
916           }
917           debug2("proposals match");
918           return (1);
919 }
920 
921 static int
kexalgs_contains(char ** peer,const char * ext)922 kexalgs_contains(char **peer, const char *ext)
923 {
924           return kex_has_any_alg(peer[PROPOSAL_KEX_ALGS], ext);
925 }
926 
927 static int
kex_choose_conf(struct ssh * ssh,uint32_t seq)928 kex_choose_conf(struct ssh *ssh, uint32_t seq)
929 {
930           struct kex *kex = ssh->kex;
931           struct newkeys *newkeys;
932           char **my = NULL, **peer = NULL;
933           char **cprop, **sprop;
934           int nenc, nmac, ncomp;
935           u_int mode, ctos, need, dh_need, authlen;
936           int r, first_kex_follows;
937 
938           debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
939           if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
940                     goto out;
941           debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
942           if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
943                     goto out;
944 
945           if (kex->server) {
946                     cprop=peer;
947                     sprop=my;
948           } else {
949                     cprop=my;
950                     sprop=peer;
951           }
952 
953           /* Check whether peer supports ext_info/kex_strict */
954           if ((kex->flags & KEX_INITIAL) != 0) {
955                     if (kex->server) {
956                               kex->ext_info_c = kexalgs_contains(peer, "ext-info-c");
957                               kex->kex_strict = kexalgs_contains(peer,
958                                   "kex-strict-c-v00@openssh.com");
959                     } else {
960                               kex->ext_info_s = kexalgs_contains(peer, "ext-info-s");
961                               kex->kex_strict = kexalgs_contains(peer,
962                                   "kex-strict-s-v00@openssh.com");
963                     }
964                     if (kex->kex_strict) {
965                               debug3_f("will use strict KEX ordering");
966                               if (seq != 0)
967                                         ssh_packet_disconnect(ssh,
968                                             "strict KEX violation: "
969                                             "KEXINIT was not the first packet");
970                     }
971           }
972 
973           /* Check whether client supports rsa-sha2 algorithms */
974           if (kex->server && (kex->flags & KEX_INITIAL)) {
975                     if (kex_has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
976                         "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
977                               kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
978                     if (kex_has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
979                         "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
980                               kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
981           }
982 
983           /* Algorithm Negotiation */
984           if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
985               sprop[PROPOSAL_KEX_ALGS])) != 0) {
986                     kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
987                     peer[PROPOSAL_KEX_ALGS] = NULL;
988                     goto out;
989           }
990           if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
991               sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
992                     kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
993                     peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
994                     goto out;
995           }
996           for (mode = 0; mode < MODE_MAX; mode++) {
997                     if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
998                               r = SSH_ERR_ALLOC_FAIL;
999                               goto out;
1000                     }
1001                     kex->newkeys[mode] = newkeys;
1002                     ctos = (!kex->server && mode == MODE_OUT) ||
1003                         (kex->server && mode == MODE_IN);
1004                     nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
1005                     nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
1006                     ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
1007                     if ((r = choose_enc(&newkeys->enc, cprop[nenc],
1008                         sprop[nenc])) != 0) {
1009                               kex->failed_choice = peer[nenc];
1010                               peer[nenc] = NULL;
1011                               goto out;
1012                     }
1013                     authlen = cipher_authlen(newkeys->enc.cipher);
1014                     /* ignore mac for authenticated encryption */
1015                     if (authlen == 0 &&
1016                         (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
1017                         sprop[nmac])) != 0) {
1018                               kex->failed_choice = peer[nmac];
1019                               peer[nmac] = NULL;
1020                               goto out;
1021                     }
1022                     if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
1023                         sprop[ncomp])) != 0) {
1024                               kex->failed_choice = peer[ncomp];
1025                               peer[ncomp] = NULL;
1026                               goto out;
1027                     }
1028                     debug("kex: %s cipher: %s MAC: %s compression: %s",
1029                         ctos ? "client->server" : "server->client",
1030                         newkeys->enc.name,
1031                         authlen == 0 ? newkeys->mac.name : "<implicit>",
1032                         newkeys->comp.name);
1033           }
1034           need = dh_need = 0;
1035           for (mode = 0; mode < MODE_MAX; mode++) {
1036                     newkeys = kex->newkeys[mode];
1037                     need = MAXIMUM(need, newkeys->enc.key_len);
1038                     need = MAXIMUM(need, newkeys->enc.block_size);
1039                     need = MAXIMUM(need, newkeys->enc.iv_len);
1040                     need = MAXIMUM(need, newkeys->mac.key_len);
1041                     dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
1042                     dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
1043                     dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
1044                     dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
1045           }
1046           /* XXX need runden? */
1047           kex->we_need = need;
1048           kex->dh_need = dh_need;
1049 
1050           /* ignore the next message if the proposals do not match */
1051           if (first_kex_follows && !proposals_match(my, peer))
1052                     ssh->dispatch_skip_packets = 1;
1053           r = 0;
1054  out:
1055           kex_prop_free(my);
1056           kex_prop_free(peer);
1057           return r;
1058 }
1059 
1060 static int
derive_key(struct ssh * ssh,int id,u_int need,u_char * hash,u_int hashlen,const struct sshbuf * shared_secret,u_char ** keyp)1061 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1062     const struct sshbuf *shared_secret, u_char **keyp)
1063 {
1064           struct kex *kex = ssh->kex;
1065           struct ssh_digest_ctx *hashctx = NULL;
1066           char c = id;
1067           u_int have;
1068           size_t mdsz;
1069           u_char *digest;
1070           int r;
1071 
1072           if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1073                     return SSH_ERR_INVALID_ARGUMENT;
1074           if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1075                     r = SSH_ERR_ALLOC_FAIL;
1076                     goto out;
1077           }
1078 
1079           /* K1 = HASH(K || H || "A" || session_id) */
1080           if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1081               ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1082               ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1083               ssh_digest_update(hashctx, &c, 1) != 0 ||
1084               ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1085               ssh_digest_final(hashctx, digest, mdsz) != 0) {
1086                     r = SSH_ERR_LIBCRYPTO_ERROR;
1087                     error_f("KEX hash failed");
1088                     goto out;
1089           }
1090           ssh_digest_free(hashctx);
1091           hashctx = NULL;
1092 
1093           /*
1094            * expand key:
1095            * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1096            * Key = K1 || K2 || ... || Kn
1097            */
1098           for (have = mdsz; need > have; have += mdsz) {
1099                     if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1100                         ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1101                         ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1102                         ssh_digest_update(hashctx, digest, have) != 0 ||
1103                         ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1104                               error_f("KDF failed");
1105                               r = SSH_ERR_LIBCRYPTO_ERROR;
1106                               goto out;
1107                     }
1108                     ssh_digest_free(hashctx);
1109                     hashctx = NULL;
1110           }
1111 #ifdef DEBUG_KEX
1112           fprintf(stderr, "key '%c'== ", c);
1113           dump_digest("key", digest, need);
1114 #endif
1115           *keyp = digest;
1116           digest = NULL;
1117           r = 0;
1118  out:
1119           free(digest);
1120           ssh_digest_free(hashctx);
1121           return r;
1122 }
1123 
1124 #define NKEYS       6
1125 int
kex_derive_keys(struct ssh * ssh,u_char * hash,u_int hashlen,const struct sshbuf * shared_secret)1126 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1127     const struct sshbuf *shared_secret)
1128 {
1129           struct kex *kex = ssh->kex;
1130           u_char *keys[NKEYS];
1131           u_int i, j, mode, ctos;
1132           int r;
1133 
1134           /* save initial hash as session id */
1135           if ((kex->flags & KEX_INITIAL) != 0) {
1136                     if (sshbuf_len(kex->session_id) != 0) {
1137                               error_f("already have session ID at kex");
1138                               return SSH_ERR_INTERNAL_ERROR;
1139                     }
1140                     if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1141                               return r;
1142           } else if (sshbuf_len(kex->session_id) == 0) {
1143                     error_f("no session ID in rekex");
1144                     return SSH_ERR_INTERNAL_ERROR;
1145           }
1146           for (i = 0; i < NKEYS; i++) {
1147                     if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1148                         shared_secret, &keys[i])) != 0) {
1149                               for (j = 0; j < i; j++)
1150                                         free(keys[j]);
1151                               return r;
1152                     }
1153           }
1154           for (mode = 0; mode < MODE_MAX; mode++) {
1155                     ctos = (!kex->server && mode == MODE_OUT) ||
1156                         (kex->server && mode == MODE_IN);
1157                     kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1158                     kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1159                     kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1160           }
1161           return 0;
1162 }
1163 
1164 int
kex_load_hostkey(struct ssh * ssh,struct sshkey ** prvp,struct sshkey ** pubp)1165 kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1166 {
1167           struct kex *kex = ssh->kex;
1168 
1169           *pubp = NULL;
1170           *prvp = NULL;
1171           if (kex->load_host_public_key == NULL ||
1172               kex->load_host_private_key == NULL) {
1173                     error_f("missing hostkey loader");
1174                     return SSH_ERR_INVALID_ARGUMENT;
1175           }
1176           *pubp = kex->load_host_public_key(kex->hostkey_type,
1177               kex->hostkey_nid, ssh);
1178           *prvp = kex->load_host_private_key(kex->hostkey_type,
1179               kex->hostkey_nid, ssh);
1180           if (*pubp == NULL)
1181                     return SSH_ERR_NO_HOSTKEY_LOADED;
1182           return 0;
1183 }
1184 
1185 int
kex_verify_host_key(struct ssh * ssh,struct sshkey * server_host_key)1186 kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1187 {
1188           struct kex *kex = ssh->kex;
1189 
1190           if (kex->verify_host_key == NULL) {
1191                     error_f("missing hostkey verifier");
1192                     return SSH_ERR_INVALID_ARGUMENT;
1193           }
1194           if (server_host_key->type != kex->hostkey_type ||
1195               (kex->hostkey_type == KEY_ECDSA &&
1196               server_host_key->ecdsa_nid != kex->hostkey_nid))
1197                     return SSH_ERR_KEY_TYPE_MISMATCH;
1198           if (kex->verify_host_key(server_host_key, ssh) == -1)
1199                     return  SSH_ERR_SIGNATURE_INVALID;
1200           return 0;
1201 }
1202 
#if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
/*
 * Debugging aid, only built when one of the DEBUG_KEX* macros is defined:
 * print "msg" on a line of its own to stderr, then dump "len" bytes of
 * "digest" there via sshbuf_dump_data().
 */
void
dump_digest(const char *msg, const u_char *digest, int len)
{
	FILE *out = stderr;

	fprintf(out, "%s\n", msg);
	sshbuf_dump_data(digest, len, out);
}
#endif
1211 
1212 /*
1213  * Send a plaintext error message to the peer, suffixed by \r\n.
1214  * Only used during banner exchange, and there only for the server.
1215  */
1216 static void
send_error(struct ssh * ssh,char * msg)1217 send_error(struct ssh *ssh, char *msg)
1218 {
1219           char *crnl = "\r\n";
1220 
1221           if (!ssh->kex->server)
1222                     return;
1223 
1224           if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1225               msg, strlen(msg)) != strlen(msg) ||
1226               atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1227               crnl, strlen(crnl)) != strlen(crnl))
1228                     error_f("write: %.100s", strerror(errno));
1229 }
1230 
1231 /*
1232  * Sends our identification string and waits for the peer's. Will block for
1233  * up to timeout_ms (or indefinitely if timeout_ms <= 0).
1234  * Returns on 0 success or a ssherr.h code on failure.
1235  */
1236 int
kex_exchange_identification(struct ssh * ssh,int timeout_ms,const char * version_addendum)1237 kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1238     const char *version_addendum)
1239 {
1240           int remote_major, remote_minor, mismatch, oerrno = 0;
1241           size_t len, n;
1242           int r, expect_nl;
1243           u_char c;
1244           struct sshbuf *our_version = ssh->kex->server ?
1245               ssh->kex->server_version : ssh->kex->client_version;
1246           struct sshbuf *peer_version = ssh->kex->server ?
1247               ssh->kex->client_version : ssh->kex->server_version;
1248           char *our_version_string = NULL, *peer_version_string = NULL;
1249           char *cp, *remote_version = NULL;
1250 
1251           /* Prepare and send our banner */
1252           sshbuf_reset(our_version);
1253           if (version_addendum != NULL && *version_addendum == '\0')
1254                     version_addendum = NULL;
1255           if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%s%s%s\r\n",
1256               PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1257               version_addendum == NULL ? "" : " ",
1258               version_addendum == NULL ? "" : version_addendum)) != 0) {
1259                     oerrno = errno;
1260                     error_fr(r, "sshbuf_putf");
1261                     goto out;
1262           }
1263 
1264           if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1265               sshbuf_mutable_ptr(our_version),
1266               sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1267                     oerrno = errno;
1268                     debug_f("write: %.100s", strerror(errno));
1269                     r = SSH_ERR_SYSTEM_ERROR;
1270                     goto out;
1271           }
1272           if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1273                     oerrno = errno;
1274                     error_fr(r, "sshbuf_consume_end");
1275                     goto out;
1276           }
1277           our_version_string = sshbuf_dup_string(our_version);
1278           if (our_version_string == NULL) {
1279                     error_f("sshbuf_dup_string failed");
1280                     r = SSH_ERR_ALLOC_FAIL;
1281                     goto out;
1282           }
1283           debug("Local version string %.100s", our_version_string);
1284 
1285           /* Read other side's version identification. */
1286           for (n = 0; ; n++) {
1287                     if (n >= SSH_MAX_PRE_BANNER_LINES) {
1288                               send_error(ssh, "No SSH identification string "
1289                                   "received.");
1290                               error_f("No SSH version received in first %u lines "
1291                                   "from server", SSH_MAX_PRE_BANNER_LINES);
1292                               r = SSH_ERR_INVALID_FORMAT;
1293                               goto out;
1294                     }
1295                     sshbuf_reset(peer_version);
1296                     expect_nl = 0;
1297                     for (;;) {
1298                               if (timeout_ms > 0) {
1299                                         r = waitrfd(ssh_packet_get_connection_in(ssh),
1300                                             &timeout_ms, NULL);
1301                                         if (r == -1 && errno == ETIMEDOUT) {
1302                                                   send_error(ssh, "Timed out waiting "
1303                                                       "for SSH identification string.");
1304                                                   error("Connection timed out during "
1305                                                       "banner exchange");
1306                                                   r = SSH_ERR_CONN_TIMEOUT;
1307                                                   goto out;
1308                                         } else if (r == -1) {
1309                                                   oerrno = errno;
1310                                                   error_f("%s", strerror(errno));
1311                                                   r = SSH_ERR_SYSTEM_ERROR;
1312                                                   goto out;
1313                                         }
1314                               }
1315 
1316                               len = atomicio(read, ssh_packet_get_connection_in(ssh),
1317                                   &c, 1);
1318                               if (len != 1 && errno == EPIPE) {
1319                                         verbose_f("Connection closed by remote host");
1320                                         r = SSH_ERR_CONN_CLOSED;
1321                                         goto out;
1322                               } else if (len != 1) {
1323                                         oerrno = errno;
1324                                         error_f("read: %.100s", strerror(errno));
1325                                         r = SSH_ERR_SYSTEM_ERROR;
1326                                         goto out;
1327                               }
1328                               if (c == '\r') {
1329                                         expect_nl = 1;
1330                                         continue;
1331                               }
1332                               if (c == '\n')
1333                                         break;
1334                               if (c == '\0' || expect_nl) {
1335                                         verbose_f("banner line contains invalid "
1336                                             "characters");
1337                                         goto invalid;
1338                               }
1339                               if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1340                                         oerrno = errno;
1341                                         error_fr(r, "sshbuf_put");
1342                                         goto out;
1343                               }
1344                               if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1345                                         verbose_f("banner line too long");
1346                                         goto invalid;
1347                               }
1348                     }
1349                     /* Is this an actual protocol banner? */
1350                     if (sshbuf_len(peer_version) > 4 &&
1351                         memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1352                               break;
1353                     /* If not, then just log the line and continue */
1354                     if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1355                               error_f("sshbuf_dup_string failed");
1356                               r = SSH_ERR_ALLOC_FAIL;
1357                               goto out;
1358                     }
1359                     /* Do not accept lines before the SSH ident from a client */
1360                     if (ssh->kex->server) {
1361                               verbose_f("client sent invalid protocol identifier "
1362                                   "\"%.256s\"", cp);
1363                               free(cp);
1364                               goto invalid;
1365                     }
1366                     debug_f("banner line %zu: %s", n, cp);
1367                     free(cp);
1368           }
1369           peer_version_string = sshbuf_dup_string(peer_version);
1370           if (peer_version_string == NULL)
1371                     fatal_f("sshbuf_dup_string failed");
1372           /* XXX must be same size for sscanf */
1373           if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1374                     error_f("calloc failed");
1375                     r = SSH_ERR_ALLOC_FAIL;
1376                     goto out;
1377           }
1378 
1379           /*
1380            * Check that the versions match.  In future this might accept
1381            * several versions and set appropriate flags to handle them.
1382            */
1383           if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1384               &remote_major, &remote_minor, remote_version) != 3) {
1385                     error("Bad remote protocol version identification: '%.100s'",
1386                         peer_version_string);
1387  invalid:
1388                     send_error(ssh, "Invalid SSH identification string.");
1389                     r = SSH_ERR_INVALID_FORMAT;
1390                     goto out;
1391           }
1392           debug("Remote protocol version %d.%d, remote software version %.100s",
1393               remote_major, remote_minor, remote_version);
1394           compat_banner(ssh, remote_version);
1395 
1396           mismatch = 0;
1397           switch (remote_major) {
1398           case 2:
1399                     break;
1400           case 1:
1401                     if (remote_minor != 99)
1402                               mismatch = 1;
1403                     break;
1404           default:
1405                     mismatch = 1;
1406                     break;
1407           }
1408           if (mismatch) {
1409                     error("Protocol major versions differ: %d vs. %d",
1410                         PROTOCOL_MAJOR_2, remote_major);
1411                     send_error(ssh, "Protocol major versions differ.");
1412                     r = SSH_ERR_NO_PROTOCOL_VERSION;
1413                     goto out;
1414           }
1415 
1416           if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1417                     logit("probed from %s port %d with %s.  Don't panic.",
1418                         ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1419                         peer_version_string);
1420                     r = SSH_ERR_CONN_CLOSED; /* XXX */
1421                     goto out;
1422           }
1423           if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1424                     logit("scanned from %s port %d with %s.  Don't panic.",
1425                         ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1426                         peer_version_string);
1427                     r = SSH_ERR_CONN_CLOSED; /* XXX */
1428                     goto out;
1429           }
1430           /* success */
1431           r = 0;
1432  out:
1433           free(our_version_string);
1434           free(peer_version_string);
1435           free(remote_version);
1436           if (r == SSH_ERR_SYSTEM_ERROR)
1437                     errno = oerrno;
1438           return r;
1439 }
1440 
1441