From 9d310e72c21e81930f3b67c54df5d75d0d7f21b6 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 15 Nov 2024 10:11:34 -0600 Subject: [PATCH] Resolve some todos (#1274) --- connection_manager.go | 23 ++++---- control.go | 14 ++--- firewall.go | 1 - interface.go | 2 +- overlay/tun_windows.go | 5 +- pki.go | 11 ---- ssh.go | 131 ++++++++++++++++++++--------------------- 7 files changed, 86 insertions(+), 101 deletions(-) diff --git a/connection_manager.go b/connection_manager.go index 9c38b5d87..9d8d0715f 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -426,17 +426,17 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // Let's sort this out. - //TODO: current.vpnIp should become an array of vpnIps + // Only one side should swap because if both swap then we may never resolve to a single tunnel. + // vpn addr is static across all tunnels for this host pair so lets + // use that to determine if we should consider swapping. if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 { - // Only one side should flip primary because if both flip then we may never resolve to a single tunnel. - // vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping. - // The remotes vpn ip is lower than mine. I will not flip. + // Their primary vpn addr is less than mine. Do not swap. return false } - //TODO: we should favor v2 over v1 certificates if configured to send them - - crt := n.intf.pki.getCertificate(current.ConnectionState.myCert.Version()) + crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version()) + // If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things + // settle down. return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature()) } @@ -495,13 +495,14 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { } func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { - crt := n.intf.pki.getCertificate(hostinfo.ConnectionState.myCert.Version()) - if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), crt.Signature()) { + cs := n.intf.pki.getCertState() + curCrt := hostinfo.ConnectionState.myCert + myCrt := cs.getCertificate(curCrt.Version()) + if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true { + // The current tunnel is using the latest certificate and version, no need to rehandshake. return } - //TODO: we should favor v2 over v1 certificates if configured to send them - n.l.WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("reason", "local certificate is not current"). Info("Re-handshaking with remote") diff --git a/control.go b/control.go index 866e9db25..5302e564a 100644 --- a/control.go +++ b/control.go @@ -133,9 +133,9 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate { _, found := c.f.myVpnAddrsTable.Lookup(vpnIp) if found { - //TODO: we might have 2 certs.... - //TODO: this should return our latest version cert - return c.f.pki.getDefaultCertificate().Copy() + // Only returning the default certificate since its impossible + // for any other host but ourselves to have more than 1 + return c.f.pki.getCertState().GetDefaultCertificate().Copy() } hi := c.f.hostMap.QueryVpnAddr(vpnIp) if hi == nil { @@ -228,13 +228,9 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { // the int returned is a count of tunnels closed func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { //TODO: this is probably better as a function in ConnectionManager or HostMap directly - lighthouses := c.f.lightHouse.GetLighthouses() - shutdown := func(h *HostInfo) { - if excludeLighthouses { - if _, ok := lighthouses[h.vpnAddrs[0]]; ok { - return - } + if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) { + return } c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.closeTunnel(h) diff --git a/firewall.go b/firewall.go index d1e306348..0aae7d6cf 100644 --- a/firewall.go +++ b/firewall.go @@ -23,7 +23,6 @@ import ( ) type FirewallInterface interface { - //TODO: name these better addr, localAddr. Are they vpnAddrs? AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error } diff --git a/interface.go b/interface.go index f8f2b0960..9b9315697 100644 --- a/interface.go +++ b/interface.go @@ -419,7 +419,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { f.firewall.EmitStats() f.handshakeManager.EmitStats() udpStats() - certExpirationGauge.Update(int64(f.pki.getDefaultCertificate().NotAfter().Sub(time.Now()) / time.Second)) + certExpirationGauge.Update(int64(f.pki.getCertState().GetDefaultCertificate().NotAfter().Sub(time.Now()) / time.Second)) //TODO: we should also report the default certificate version } } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index db5b5fbd5..289999d8a 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -239,11 +239,12 @@ func (t *winTun) Close() error { luid := winipcfg.LUID(t.tun.LUID()) _ = luid.FlushRoutes(windows.AF_INET) _ = luid.FlushIPAddresses(windows.AF_INET) - /* We don't support IPV6 yet + _ = luid.FlushRoutes(windows.AF_INET6) _ = luid.FlushIPAddresses(windows.AF_INET6) - */ + _ = luid.FlushDNS(windows.AF_INET) + _ = luid.FlushDNS(windows.AF_INET6) return t.tun.Close() } diff --git a/pki.go b/pki.go index 69c60863b..c7c93ac5b 100644 --- a/pki.go +++ b/pki.go @@ -70,16 +70,6 @@ func (p *PKI) getCertState() *CertState { return p.cs.Load() } -// TODO: We should remove this -func (p *PKI) getDefaultCertificate() cert.Certificate { - return p.cs.Load().GetDefaultCertificate() -} - -// TODO: We should remove this -func (p *PKI) getCertificate(v cert.Version) cert.Certificate { - return p.cs.Load().getCertificate(v) -} - func (p *PKI) reload(c *config.C, initial bool) error { err := p.reloadCerts(c, initial) if err != nil { @@ -300,7 +290,6 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { // Load the certificate crt, rawCert, err = loadCertificate(rawCert) if err != nil { - //TODO: check error return nil, err } diff --git a/ssh.go b/ssh.go index 5244e4a4a..1062a77f9 100644 --- a/ssh.go +++ b/ssh.go @@ -320,7 +320,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "print-cert", - ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn ip", + ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshPrintCertFlags{} @@ -336,7 +336,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "print-tunnel", - ShortDescription: "Prints json details about a tunnel for the provided vpn ip", + ShortDescription: "Prints json details about a tunnel for the provided vpn addr", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshPrintTunnelFlags{} @@ -364,7 +364,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "change-remote", - ShortDescription: "Changes the remote address used in the tunnel for the provided vpn ip", + ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshChangeRemoteFlags{} @@ -378,7 +378,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "close-tunnel", - ShortDescription: "Closes a tunnel for the provided vpn ip", + ShortDescription: "Closes a tunnel for the provided vpn addr", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshCloseTunnelFlags{} @@ -392,7 +392,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "create-tunnel", - ShortDescription: "Creates a tunnel for the provided vpn ip and address", + ShortDescription: "Creates a tunnel for the provided vpn address", Help: "The lighthouses will be queried for real addresses but you can provide one as well.", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) @@ -407,8 +407,8 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "query-lighthouse", - ShortDescription: "Query the lighthouses for the provided vpn ip", - Help: "This command is asynchronous. Only currently known udp ips will be printed.", + ShortDescription: "Query the lighthouses for the provided vpn address", + Help: "This command is asynchronous. Only currently known udp addresses will be printed.", Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { return sshQueryLighthouse(f, fs, a, w) }, @@ -465,8 +465,8 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr } type lighthouseInfo struct { - VpnIp string `json:"vpnIp"` - Addrs *CacheMap `json:"addrs"` + VpnAddr string `json:"vpnAddr"` + Addrs *CacheMap `json:"addrs"` } lightHouse.RLock() @@ -474,15 +474,15 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr x := 0 for k, v := range lightHouse.addrMap { addrMap[x] = lighthouseInfo{ - VpnIp: k.String(), - Addrs: v.CopyCache(), + VpnAddr: k.String(), + Addrs: v.CopyCache(), } x++ } lightHouse.RUnlock() sort.Slice(addrMap, func(i, j int) bool { - return strings.Compare(addrMap[i].VpnIp, addrMap[j].VpnIp) < 0 + return strings.Compare(addrMap[i].VpnAddr, addrMap[j].VpnAddr) < 0 }) if fs.Json || fs.Pretty { @@ -503,7 +503,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr if err != nil { return err } - err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, string(b))) + err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddr, string(b))) if err != nil { return err } @@ -541,20 +541,20 @@ func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } var cm *CacheMap - rl := ifce.lightHouse.Query(vpnIp) + rl := ifce.lightHouse.Query(vpnAddr) if rl != nil { cm = rl.CopyCache() } @@ -569,21 +569,21 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0])) } if !flags.LocalOnly { @@ -610,24 +610,24 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already exists")) } - hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnIp) + hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnAddr) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) } @@ -640,7 +640,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW } } - hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil) + hostInfo = ifce.handshakeManager.StartHandshake(vpnAddr, nil) if addr.IsValid() { hostInfo.SetRemote(addr) } @@ -656,7 +656,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } if flags.Address == "" { @@ -668,18 +668,18 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine("Address could not be parsed") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0])) } hostInfo.SetRemote(addr) @@ -785,21 +785,20 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return nil } - //TODO: This should return both certs - cert := ifce.pki.getDefaultCertificate() + cert := ifce.pki.getCertState().GetDefaultCertificate() if len(a) > 0 { - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0])) } cert = hostInfo.GetCert().Certificate @@ -857,15 +856,15 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr Error error Type string State string - PeerIp netip.Addr + PeerAddr netip.Addr LocalIndex uint32 RemoteIndex uint32 RelayedThrough []netip.Addr } type RelayOutput struct { - NebulaIp netip.Addr - RelayForIps []RelayFor + NebulaAddr netip.Addr + RelayForAddrs []RelayFor } type CmdOutput struct { @@ -881,16 +880,16 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr } for k, v := range relays { - ro := RelayOutput{NebulaIp: v.vpnAddrs[0]} + ro := RelayOutput{NebulaAddr: v.vpnAddrs[0]} co.Relays = append(co.Relays, &ro) relayHI := ifce.hostMap.QueryVpnAddr(v.vpnAddrs[0]) if relayHI == nil { - ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")}) + ro.RelayForAddrs = append(ro.RelayForAddrs, RelayFor{Error: errors.New("could not find hostinfo")}) continue } - for _, vpnIp := range relayHI.relayState.CopyRelayForIps() { + for _, vpnAddr := range relayHI.relayState.CopyRelayForIps() { rf := RelayFor{Error: nil} - r, ok := relayHI.relayState.GetRelayForByAddr(vpnIp) + r, ok := relayHI.relayState.GetRelayForByAddr(vpnAddr) if ok { t := "" switch r.Type { @@ -914,19 +913,19 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr rf.LocalIndex = r.LocalIndex rf.RemoteIndex = r.RemoteIndex - rf.PeerIp = r.PeerAddr + rf.PeerAddr = r.PeerAddr rf.Type = t rf.State = s if rf.LocalIndex != k { rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k) } } - relayedHI := ifce.hostMap.QueryVpnAddr(vpnIp) + relayedHI := ifce.hostMap.QueryVpnAddr(vpnAddr) if relayedHI != nil { rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...) } - ro.RelayForIps = append(ro.RelayForIps, rf) + ro.RelayForAddrs = append(ro.RelayForAddrs, rf) } } err := enc.Encode(co) @@ -944,21 +943,21 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0])) } enc := json.NewEncoder(w.GetWriter())