Resolve some todos (#1274)

This commit is contained in:
Nate Brown 2024-11-15 10:11:34 -06:00 committed by GitHub
parent 5380fef7b0
commit 9d310e72c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 86 additions and 101 deletions

View File

@ -426,17 +426,17 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out.
//TODO: current.vpnIp should become an array of vpnIps
// Only one side should swap because if both swap then we may never resolve to a single tunnel.
// vpn addr is static across all tunnels for this host pair so lets
// use that to determine if we should consider swapping.
if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
// The remotes vpn ip is lower than mine. I will not flip.
// Their primary vpn addr is less than mine. Do not swap.
return false
}
//TODO: we should favor v2 over v1 certificates if configured to send them
crt := n.intf.pki.getCertificate(current.ConnectionState.myCert.Version())
crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
// settle down.
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
}
@ -495,13 +495,14 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
}
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
crt := n.intf.pki.getCertificate(hostinfo.ConnectionState.myCert.Version())
if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), crt.Signature()) {
cs := n.intf.pki.getCertState()
curCrt := hostinfo.ConnectionState.myCert
myCrt := cs.getCertificate(curCrt.Version())
if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
// The current tunnel is using the latest certificate and version, no need to rehandshake.
return
}
//TODO: we should favor v2 over v1 certificates if configured to send them
n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")

View File

@ -133,9 +133,9 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
if found {
//TODO: we might have 2 certs....
//TODO: this should return our latest version cert
return c.f.pki.getDefaultCertificate().Copy()
// Only returning the default certificate since its impossible
// for any other host but ourselves to have more than 1
return c.f.pki.getCertState().GetDefaultCertificate().Copy()
}
hi := c.f.hostMap.QueryVpnAddr(vpnIp)
if hi == nil {
@ -228,13 +228,9 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
// the int returned is a count of tunnels closed
func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
lighthouses := c.f.lightHouse.GetLighthouses()
shutdown := func(h *HostInfo) {
if excludeLighthouses {
if _, ok := lighthouses[h.vpnAddrs[0]]; ok {
return
}
if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) {
return
}
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.f.closeTunnel(h)

View File

@ -23,7 +23,6 @@ import (
)
type FirewallInterface interface {
//TODO: name these better addr, localAddr. Are they vpnAddrs?
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
}

View File

@ -419,7 +419,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
f.firewall.EmitStats()
f.handshakeManager.EmitStats()
udpStats()
certExpirationGauge.Update(int64(f.pki.getDefaultCertificate().NotAfter().Sub(time.Now()) / time.Second))
certExpirationGauge.Update(int64(f.pki.getCertState().GetDefaultCertificate().NotAfter().Sub(time.Now()) / time.Second))
//TODO: we should also report the default certificate version
}
}

View File

@ -239,11 +239,12 @@ func (t *winTun) Close() error {
luid := winipcfg.LUID(t.tun.LUID())
_ = luid.FlushRoutes(windows.AF_INET)
_ = luid.FlushIPAddresses(windows.AF_INET)
/* We don't support IPV6 yet
_ = luid.FlushRoutes(windows.AF_INET6)
_ = luid.FlushIPAddresses(windows.AF_INET6)
*/
_ = luid.FlushDNS(windows.AF_INET)
_ = luid.FlushDNS(windows.AF_INET6)
return t.tun.Close()
}

11
pki.go
View File

@ -70,16 +70,6 @@ func (p *PKI) getCertState() *CertState {
return p.cs.Load()
}
// TODO: We should remove this
func (p *PKI) getDefaultCertificate() cert.Certificate {
return p.cs.Load().GetDefaultCertificate()
}
// TODO: We should remove this
func (p *PKI) getCertificate(v cert.Version) cert.Certificate {
return p.cs.Load().getCertificate(v)
}
func (p *PKI) reload(c *config.C, initial bool) error {
err := p.reloadCerts(c, initial)
if err != nil {
@ -300,7 +290,6 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
// Load the certificate
crt, rawCert, err = loadCertificate(rawCert)
if err != nil {
//TODO: check error
return nil, err
}

131
ssh.go
View File

@ -320,7 +320,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "print-cert",
ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn ip",
ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr",
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshPrintCertFlags{}
@ -336,7 +336,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "print-tunnel",
ShortDescription: "Prints json details about a tunnel for the provided vpn ip",
ShortDescription: "Prints json details about a tunnel for the provided vpn addr",
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshPrintTunnelFlags{}
@ -364,7 +364,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "change-remote",
ShortDescription: "Changes the remote address used in the tunnel for the provided vpn ip",
ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr",
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshChangeRemoteFlags{}
@ -378,7 +378,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "close-tunnel",
ShortDescription: "Closes a tunnel for the provided vpn ip",
ShortDescription: "Closes a tunnel for the provided vpn addr",
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshCloseTunnelFlags{}
@ -392,7 +392,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "create-tunnel",
ShortDescription: "Creates a tunnel for the provided vpn ip and address",
ShortDescription: "Creates a tunnel for the provided vpn address",
Help: "The lighthouses will be queried for real addresses but you can provide one as well.",
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
@ -407,8 +407,8 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "query-lighthouse",
ShortDescription: "Query the lighthouses for the provided vpn ip",
Help: "This command is asynchronous. Only currently known udp ips will be printed.",
ShortDescription: "Query the lighthouses for the provided vpn address",
Help: "This command is asynchronous. Only currently known udp addresses will be printed.",
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshQueryLighthouse(f, fs, a, w)
},
@ -465,8 +465,8 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
}
type lighthouseInfo struct {
VpnIp string `json:"vpnIp"`
Addrs *CacheMap `json:"addrs"`
VpnAddr string `json:"vpnAddr"`
Addrs *CacheMap `json:"addrs"`
}
lightHouse.RLock()
@ -474,15 +474,15 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
x := 0
for k, v := range lightHouse.addrMap {
addrMap[x] = lighthouseInfo{
VpnIp: k.String(),
Addrs: v.CopyCache(),
VpnAddr: k.String(),
Addrs: v.CopyCache(),
}
x++
}
lightHouse.RUnlock()
sort.Slice(addrMap, func(i, j int) bool {
return strings.Compare(addrMap[i].VpnIp, addrMap[j].VpnIp) < 0
return strings.Compare(addrMap[i].VpnAddr, addrMap[j].VpnAddr) < 0
})
if fs.Json || fs.Pretty {
@ -503,7 +503,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
if err != nil {
return err
}
err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, string(b)))
err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddr, string(b)))
if err != nil {
return err
}
@ -541,20 +541,20 @@ func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter
func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine("No vpn ip was provided")
return w.WriteLine("No vpn address was provided")
}
vpnIp, err := netip.ParseAddr(a[0])
vpnAddr, err := netip.ParseAddr(a[0])
if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
}
if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
if !vpnAddr.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
}
var cm *CacheMap
rl := ifce.lightHouse.Query(vpnIp)
rl := ifce.lightHouse.Query(vpnAddr)
if rl != nil {
cm = rl.CopyCache()
}
@ -569,21 +569,21 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
}
if len(a) == 0 {
return w.WriteLine("No vpn ip was provided")
return w.WriteLine("No vpn address was provided")
}
vpnIp, err := netip.ParseAddr(a[0])
vpnAddr, err := netip.ParseAddr(a[0])
if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
}
if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
if !vpnAddr.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
}
hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0]))
}
if !flags.LocalOnly {
@ -610,24 +610,24 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
}
if len(a) == 0 {
return w.WriteLine("No vpn ip was provided")
return w.WriteLine("No vpn address was provided")
}
vpnIp, err := netip.ParseAddr(a[0])
vpnAddr, err := netip.ParseAddr(a[0])
if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
}
if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
if !vpnAddr.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
}
hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
}
hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnIp)
hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnAddr)
if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
}
@ -640,7 +640,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
}
}
hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil)
hostInfo = ifce.handshakeManager.StartHandshake(vpnAddr, nil)
if addr.IsValid() {
hostInfo.SetRemote(addr)
}
@ -656,7 +656,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
}
if len(a) == 0 {
return w.WriteLine("No vpn ip was provided")
return w.WriteLine("No vpn address was provided")
}
if flags.Address == "" {
@ -668,18 +668,18 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("Address could not be parsed")
}
vpnIp, err := netip.ParseAddr(a[0])
vpnAddr, err := netip.ParseAddr(a[0])
if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
}
if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
if !vpnAddr.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
}
hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0]))
}
hostInfo.SetRemote(addr)
@ -785,21 +785,20 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
return nil
}
//TODO: This should return both certs
cert := ifce.pki.getDefaultCertificate()
cert := ifce.pki.getCertState().GetDefaultCertificate()
if len(a) > 0 {
vpnIp, err := netip.ParseAddr(a[0])
vpnAddr, err := netip.ParseAddr(a[0])
if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0]))
}
if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
if !vpnAddr.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0]))
}
hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0]))
}
cert = hostInfo.GetCert().Certificate
@ -857,15 +856,15 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
Error error
Type string
State string
PeerIp netip.Addr
PeerAddr netip.Addr
LocalIndex uint32
RemoteIndex uint32
RelayedThrough []netip.Addr
}
type RelayOutput struct {
NebulaIp netip.Addr
RelayForIps []RelayFor
NebulaAddr netip.Addr
RelayForAddrs []RelayFor
}
type CmdOutput struct {
@ -881,16 +880,16 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
}
for k, v := range relays {
ro := RelayOutput{NebulaIp: v.vpnAddrs[0]}
ro := RelayOutput{NebulaAddr: v.vpnAddrs[0]}
co.Relays = append(co.Relays, &ro)
relayHI := ifce.hostMap.QueryVpnAddr(v.vpnAddrs[0])
if relayHI == nil {
ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")})
ro.RelayForAddrs = append(ro.RelayForAddrs, RelayFor{Error: errors.New("could not find hostinfo")})
continue
}
for _, vpnIp := range relayHI.relayState.CopyRelayForIps() {
for _, vpnAddr := range relayHI.relayState.CopyRelayForIps() {
rf := RelayFor{Error: nil}
r, ok := relayHI.relayState.GetRelayForByAddr(vpnIp)
r, ok := relayHI.relayState.GetRelayForByAddr(vpnAddr)
if ok {
t := ""
switch r.Type {
@ -914,19 +913,19 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
rf.LocalIndex = r.LocalIndex
rf.RemoteIndex = r.RemoteIndex
rf.PeerIp = r.PeerAddr
rf.PeerAddr = r.PeerAddr
rf.Type = t
rf.State = s
if rf.LocalIndex != k {
rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k)
}
}
relayedHI := ifce.hostMap.QueryVpnAddr(vpnIp)
relayedHI := ifce.hostMap.QueryVpnAddr(vpnAddr)
if relayedHI != nil {
rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...)
}
ro.RelayForIps = append(ro.RelayForIps, rf)
ro.RelayForAddrs = append(ro.RelayForAddrs, rf)
}
}
err := enc.Encode(co)
@ -944,21 +943,21 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
}
if len(a) == 0 {
return w.WriteLine("No vpn ip was provided")
return w.WriteLine("No vpn address was provided")
}
vpnIp, err := netip.ParseAddr(a[0])
vpnAddr, err := netip.ParseAddr(a[0])
if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0]))
}
if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
if !vpnAddr.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0]))
}
hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0]))
}
enc := json.NewEncoder(w.GetWriter())