diff --git a/header/header.go b/header/header.go index f22509b8..b973141f 100644 --- a/header/header.go +++ b/header/header.go @@ -174,6 +174,10 @@ func (h *H) SubTypeName() string { return SubTypeName(h.Type, h.Subtype) } +func (h *H) IsValidSubType() bool { + return IsValidSubType(h.Type, h.Subtype) +} + // SubTypeName will transform a nebula message sub type into a human string func SubTypeName(t MessageType, s MessageSubType) string { if n, ok := subTypeMap[t]; ok { @@ -185,6 +189,16 @@ func SubTypeName(t MessageType, s MessageSubType) string { return "unknown" } +func IsValidSubType(t MessageType, s MessageSubType) bool { + if n, ok := subTypeMap[t]; ok { + if _, ok := (*n)[s]; ok { + return true + } + } + + return false +} + // NewHeader turns bytes into a header func NewHeader(b []byte) (*H, error) { h := new(H) diff --git a/message_metrics.go b/message_metrics.go index 10e8472c..45de9a5c 100644 --- a/message_metrics.go +++ b/message_metrics.go @@ -13,6 +13,8 @@ type MessageMetrics struct { rxUnknown metrics.Counter txUnknown metrics.Counter + + rxInvalid metrics.Counter } func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) { @@ -33,6 +35,11 @@ func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int } } } +func (m *MessageMetrics) RxInvalid(i int64) { + if m != nil && m.rxInvalid != nil { + m.rxInvalid.Inc(i) + } +} func newMessageMetrics() *MessageMetrics { gen := func(t string) [][]metrics.Counter { @@ -56,6 +63,7 @@ func newMessageMetrics() *MessageMetrics { rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil), txUnknown: metrics.GetOrRegisterCounter("messages.tx.other", nil), + rxInvalid: metrics.GetOrRegisterCounter("messages.rx.invalid", nil), } } diff --git a/outside.go b/outside.go index 1e00a0a9..17013ed3 100644 --- a/outside.go +++ b/outside.go @@ -20,23 +20,46 @@ const ( minFwPacketLen = 4 ) +var ErrOutOfWindow = errors.New("out of window packet") + func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors + // TODO: record metrics for rx holepunch/punchy packets? if len(packet) > 1 { - f.l.Info("Error while parsing inbound packet", - "from", via, - "error", err, - "packet", packet, - ) + f.messageMetrics.RxInvalid(1) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Error while parsing inbound packet", + "from", via, + "error", err, + "packet", packet, + ) + } + } + return + } + + if h.Version != header.Version { + f.messageMetrics.RxInvalid(1) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Unexpected header version received", "from", via) + } + return + } + + // Check before processing to see if this is a expected type/subtype + if !h.IsValidSubType() { + f.messageMetrics.RxInvalid(1) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Unexpected packet received", "from", via) } return } - //l.Error("in packet ", header, packet[HeaderLen:]) if !via.IsRelayed { if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) { + f.messageMetrics.RxInvalid(1) if f.l.Enabled(context.Background(), slog.LevelDebug) { f.l.Debug("Refusing to process double encrypted packet", "from", via) } @@ -44,215 +67,192 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, } } + // don't keep Rx metrics for message type, since you can see those in the tun metrics + if h.Type != header.Message { + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + } + + // Unencrypted packets + switch h.Type { + case header.Handshake: + f.handshakeManager.HandleIncoming(via, packet, h) + return + + case header.RecvError: + f.handleRecvError(via.UdpAddr, h) + return + } + + // Relay packets are special + isMessageRelay := (h.Type == header.Message && h.Subtype == header.MessageRelay) + var hostinfo *HostInfo - // verify if we've seen this index before, otherwise respond to the handshake initiation - if h.Type == header.Message && h.Subtype == header.MessageRelay { + if isMessageRelay { hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) } else { hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) } - var ci *ConnectionState - if hostinfo != nil { - ci = hostinfo.ConnectionState + // At this point we should have a valid existing tunnel, verify and send + // recvError if necessary + if hostinfo == nil || hostinfo.ConnectionState == nil { + if !via.IsRelayed { + f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex) + } + return } + // All remaining packets are encrypted + ci := hostinfo.ConnectionState + if !ci.window.Check(f.l, h.MessageCounter) { + return + } + + // Relay packets are special + if isMessageRelay { + f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache) + + return + } + + out, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) + if err != nil { + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Failed to decrypt packet", + "error", err, + "from", via, + "header", h, + ) + } + return + } + + // Roam before we respond + f.handleHostRoaming(hostinfo, via) + f.connectionManager.In(hostinfo) + switch h.Type { case header.Message: - if !f.handleEncrypted(ci, via, h) { - return - } - switch h.Subtype { case header.MessageNone: - if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) { - return - } - case header.MessageRelay: - // The entire body is sent as AD, not encrypted. - // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. - // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's - // otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice - // which will gracefully fail in the DecryptDanger call. - signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()] - signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():] - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb) - if err != nil { - return - } - // Successfully validated the thing. Get rid of the Relay header. - signedPayload = signedPayload[header.Len:] - // Pull the Roaming parts up here, and return in all call paths. - f.handleHostRoaming(hostinfo, via) - // Track usage of both the HostInfo and the Relay for the received & authenticated packet - f.connectionManager.In(hostinfo) - f.connectionManager.RelayUsed(h.RemoteIndex) - - relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) - if !ok { - // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing - // its internal mapping. This should never happen. - hostinfo.logger(f.l).Error("HostInfo missing remote relay index", - "vpnAddrs", hostinfo.vpnAddrs, - "remoteIndex", h.RemoteIndex, - ) - return - } - - switch relay.Type { - case TerminalType: - // If I am the target of this relay, process the unwrapped packet - // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. - via = ViaSender{ - UdpAddr: via.UdpAddr, - relayHI: hostinfo, - remoteIdx: relay.RemoteIndex, - relay: relay, - IsRelayed: true, - } - f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) - return - case ForwardingType: - // Find the target HostInfo relay object - targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) - if err != nil { - hostinfo.logger(f.l).Info("Failed to find target host info by ip", - "relayTo", relay.PeerAddr, - "error", err, - "hostinfo.vpnAddrs", hostinfo.vpnAddrs, - ) - return - } - - // If that relay is Established, forward the payload through it - if targetRelay.State == Established { - switch targetRelay.Type { - case ForwardingType: - // Forward this packet through the relay tunnel - // Find the target HostInfo - f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false) - return - case TerminalType: - hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") - } - } else { - hostinfo.logger(f.l).Info("Unexpected target relay state", - "relayTo", relay.PeerAddr, - "relayFrom", hostinfo.vpnAddrs[0], - "targetRelayState", targetRelay.State, - ) - return - } - } + f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache) + default: + hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h) + return } case header.LightHouse: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { - return - } - - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet", - "error", err, - "from", via, - "packet", packet, - ) - return - } - //TODO: assert via is not relayed - lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f) - - // Fallthrough to the bottom to record incoming traffic + lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, out, f) case header.Test: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { + switch h.Subtype { + case header.TestReply: + // No-op, useful for the Roaming and connectionManager side-effects above + case header.TestRequest: + f.send(header.Test, header.TestReply, ci, hostinfo, out, nb, out) + default: + hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected test subtype seen", "from", via, "header", h) return } - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt test packet", - "error", err, - "from", via, - "packet", packet, - ) - return - } - - if h.Subtype == header.TestRequest { - // This testRequest might be from TryPromoteBest, so we should roam - // to the new IP address before responding - f.handleHostRoaming(hostinfo, via) - f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) - } - - // Fallthrough to the bottom to record incoming traffic - - // Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they - // are unauthenticated - - case header.Handshake: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handshakeManager.HandleIncoming(via, packet, h) - return - - case header.RecvError: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handleRecvError(via.UdpAddr, h) - return - case header.CloseTunnel: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { - return - } - _, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet", - "error", err, - "from", via, - "packet", packet, - ) - return - } - hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via) - f.closeTunnel(hostinfo) - return case header.Control: - if !f.handleEncrypted(ci, via, h) { - return - } - - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt Control packet", - "error", err, - "from", via, - "packet", packet, - ) - return - } - - f.relayManager.HandleControlMsg(hostinfo, d, f) + f.relayManager.HandleControlMsg(hostinfo, out, f) default: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if f.l.Enabled(context.Background(), slog.LevelDebug) { - hostinfo.logger(f.l).Debug("Unexpected packet received", "from", via) - } + hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message type seen", "from", via, "header", h) + } +} + +func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { + // The entire body is sent as AD, not encrypted. + // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. + // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's + // otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice + // which will gracefully fail in the DecryptDanger call. + signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()] + signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():] + var err error + out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb) + if err != nil { + return + } + // Successfully validated the thing. Get rid of the Relay header. + signedPayload = signedPayload[header.Len:] + // Pull the Roaming parts up here, and return in all call paths. + f.handleHostRoaming(hostinfo, via) + // Track usage of both the HostInfo and the Relay for the received & authenticated packet + f.connectionManager.In(hostinfo) + f.connectionManager.RelayUsed(h.RemoteIndex) + + relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) + if !ok { + // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing + // its internal mapping. This should never happen. + hostinfo.logger(f.l).Error("HostInfo missing remote relay index", + "vpnAddrs", hostinfo.vpnAddrs, + "remoteIndex", h.RemoteIndex, + ) return } - f.handleHostRoaming(hostinfo, via) + switch relay.Type { + case TerminalType: + // If I am the target of this relay, process the unwrapped packet + // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. + via = ViaSender{ + UdpAddr: via.UdpAddr, + relayHI: hostinfo, + remoteIdx: relay.RemoteIndex, + relay: relay, + IsRelayed: true, + } + f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + case ForwardingType: + // Find the target HostInfo relay object + targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) + if err != nil { + hostinfo.logger(f.l).Info("Failed to find target host info by ip", + "relayTo", relay.PeerAddr, + "error", err, + "hostinfo.vpnAddrs", hostinfo.vpnAddrs, + ) + return + } - f.connectionManager.In(hostinfo) + // If that relay is Established, forward the payload through it + if targetRelay.State == Established { + switch targetRelay.Type { + case ForwardingType: + // Forward this packet through the relay tunnel + // Find the target HostInfo + f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false) + case TerminalType: + hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") + return + default: + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Unexpected targetRelay Type", "from", via, "relayType", targetRelay.Type) + } + return + } + } else { + hostinfo.logger(f.l).Info("Unexpected target relay state", + "relayTo", relay.PeerAddr, + "relayFrom", hostinfo.vpnAddrs[0], + "targetRelayState", targetRelay.State, + ) + return + } + default: + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Unexpected relay type", "from", via, "relayType", relay.Type) + } + } } // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote @@ -300,23 +300,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) { } -// handleEncrypted returns true if a packet should be processed, false otherwise -func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool { - // If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect - if ci == nil { - if !via.IsRelayed { - f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex) - } - return false - } - // If the window check fails, refuse to process the packet, but don't send a recv error - if !ci.window.Check(f.l, h.MessageCounter) { - return false - } - - return true -} - var ( ErrPacketTooShort = errors.New("packet is too short") ErrUnknownIPVersion = errors.New("packet is an unknown ip version") @@ -523,38 +506,20 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] } if !hostinfo.ConnectionState.window.Update(f.l, mc) { - if f.l.Enabled(context.Background(), slog.LevelDebug) { - hostinfo.logger(f.l).Debug("dropping out of window packet", "header", h) - } - return nil, errors.New("out of window packet") + return nil, ErrOutOfWindow } return out, nil } -func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { - var err error - - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err) - return false - } - - err = newPacket(out, true, fwPacket) +func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) { + err := newPacket(out, true, fwPacket) if err != nil { hostinfo.logger(f.l).Warn("Error while validating inbound packet", "error", err, "packet", out, ) - return false - } - - if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { - if f.l.Enabled(context.Background(), slog.LevelDebug) { - hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket) - } - return false + return } dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) @@ -568,15 +533,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out "reason", dropReason, ) } - return false + return } - f.connectionManager.In(hostinfo) _, err = f.readers[q].Write(out) if err != nil { f.l.Error("Failed to write to tun", "error", err) } - return true } func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {