refactor readOutsidePackets (#1642)

* refactor readOutsidePackets

They layout of this method is confusing and relys on certain parts to
return early for things to work correctly.

Change the ordering of the logic so that we do this:

- Handle unencrypted packets
- Decrypt packet
- Handle encrypted packets

This way, nothing can sneak through unencrypted to where it shouldn't
be.

* fix comment

* code review comments

* check for expected type/subtype

* check header version

* log header

* need to handle TestReply

* clean roaming / connectionManager

* dont need to roam here now, we do it earlier

* cleanup metrics and errors

* rxInvalid

* debug logger checks

* ErrOutOfWindow
This commit is contained in:
Wade Simmons
2026-05-06 12:23:27 -04:00
committed by GitHub
parent ff91c37529
commit 4fb5cdb4fa
3 changed files with 210 additions and 225 deletions

View File

@@ -174,6 +174,10 @@ func (h *H) SubTypeName() string {
return SubTypeName(h.Type, h.Subtype)
}
func (h *H) IsValidSubType() bool {
return IsValidSubType(h.Type, h.Subtype)
}
// SubTypeName will transform a nebula message sub type into a human string
func SubTypeName(t MessageType, s MessageSubType) string {
if n, ok := subTypeMap[t]; ok {
@@ -185,6 +189,16 @@ func SubTypeName(t MessageType, s MessageSubType) string {
return "unknown"
}
func IsValidSubType(t MessageType, s MessageSubType) bool {
if n, ok := subTypeMap[t]; ok {
if _, ok := (*n)[s]; ok {
return true
}
}
return false
}
// NewHeader turns bytes into a header
func NewHeader(b []byte) (*H, error) {
h := new(H)

View File

@@ -13,6 +13,8 @@ type MessageMetrics struct {
rxUnknown metrics.Counter
txUnknown metrics.Counter
rxInvalid metrics.Counter
}
func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) {
@@ -33,6 +35,11 @@ func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int
}
}
}
func (m *MessageMetrics) RxInvalid(i int64) {
if m != nil && m.rxInvalid != nil {
m.rxInvalid.Inc(i)
}
}
func newMessageMetrics() *MessageMetrics {
gen := func(t string) [][]metrics.Counter {
@@ -56,6 +63,7 @@ func newMessageMetrics() *MessageMetrics {
rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil),
txUnknown: metrics.GetOrRegisterCounter("messages.tx.other", nil),
rxInvalid: metrics.GetOrRegisterCounter("messages.rx.invalid", nil),
}
}

View File

@@ -20,23 +20,46 @@ const (
minFwPacketLen = 4
)
var ErrOutOfWindow = errors.New("out of window packet")
func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet)
if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
// TODO: record metrics for rx holepunch/punchy packets?
if len(packet) > 1 {
f.l.Info("Error while parsing inbound packet",
f.messageMetrics.RxInvalid(1)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Error while parsing inbound packet",
"from", via,
"error", err,
"packet", packet,
)
}
}
return
}
if h.Version != header.Version {
f.messageMetrics.RxInvalid(1)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Unexpected header version received", "from", via)
}
return
}
// Check before processing to see if this is a expected type/subtype
if !h.IsValidSubType() {
f.messageMetrics.RxInvalid(1)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Unexpected packet received", "from", via)
}
return
}
//l.Error("in packet ", header, packet[HeaderLen:])
if !via.IsRelayed {
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
f.messageMetrics.RxInvalid(1)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Refusing to process double encrypted packet", "from", via)
}
@@ -44,31 +67,108 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
}
}
// don't keep Rx metrics for message type, since you can see those in the tun metrics
if h.Type != header.Message {
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
}
// Unencrypted packets
switch h.Type {
case header.Handshake:
f.handshakeManager.HandleIncoming(via, packet, h)
return
case header.RecvError:
f.handleRecvError(via.UdpAddr, h)
return
}
// Relay packets are special
isMessageRelay := (h.Type == header.Message && h.Subtype == header.MessageRelay)
var hostinfo *HostInfo
// verify if we've seen this index before, otherwise respond to the handshake initiation
if h.Type == header.Message && h.Subtype == header.MessageRelay {
if isMessageRelay {
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
} else {
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
}
var ci *ConnectionState
if hostinfo != nil {
ci = hostinfo.ConnectionState
// At this point we should have a valid existing tunnel, verify and send
// recvError if necessary
if hostinfo == nil || hostinfo.ConnectionState == nil {
if !via.IsRelayed {
f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex)
}
return
}
// All remaining packets are encrypted
ci := hostinfo.ConnectionState
if !ci.window.Check(f.l, h.MessageCounter) {
return
}
// Relay packets are special
if isMessageRelay {
f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache)
return
}
out, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Failed to decrypt packet",
"error", err,
"from", via,
"header", h,
)
}
return
}
// Roam before we respond
f.handleHostRoaming(hostinfo, via)
f.connectionManager.In(hostinfo)
switch h.Type {
case header.Message:
if !f.handleEncrypted(ci, via, h) {
switch h.Subtype {
case header.MessageNone:
f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache)
default:
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h)
return
}
case header.LightHouse:
//TODO: assert via is not relayed
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, out, f)
case header.Test:
switch h.Subtype {
case header.MessageNone:
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
case header.TestReply:
// No-op, useful for the Roaming and connectionManager side-effects above
case header.TestRequest:
f.send(header.Test, header.TestReply, ci, hostinfo, out, nb, out)
default:
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected test subtype seen", "from", via, "header", h)
return
}
case header.MessageRelay:
case header.CloseTunnel:
hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via)
f.closeTunnel(hostinfo)
case header.Control:
f.relayManager.HandleControlMsg(hostinfo, out, f)
default:
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message type seen", "from", via, "header", h)
}
}
func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
// The entire body is sent as AD, not encrypted.
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
@@ -76,6 +176,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
// which will gracefully fail in the DecryptDanger call.
signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()]
signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():]
var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb)
if err != nil {
return
@@ -111,7 +212,6 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
IsRelayed: true,
}
f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return
case ForwardingType:
// Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
@@ -131,9 +231,14 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
// Forward this packet through the relay tunnel
// Find the target HostInfo
f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false)
return
case TerminalType:
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
return
default:
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Unexpected targetRelay Type", "from", via, "relayType", targetRelay.Type)
}
return
}
} else {
hostinfo.logger(f.l).Info("Unexpected target relay state",
@@ -143,116 +248,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
)
return
}
}
}
case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
//TODO: assert via is not relayed
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic
case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt test packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, via)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
}
// Fallthrough to the bottom to record incoming traffic
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
// are unauthenticated
case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handshakeManager.HandleIncoming(via, packet, h)
return
case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(via.UdpAddr, h)
return
case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) {
return
}
_, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via)
f.closeTunnel(hostinfo)
return
case header.Control:
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt Control packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
f.relayManager.HandleControlMsg(hostinfo, d, f)
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Unexpected packet received", "from", via)
hostinfo.logger(f.l).Debug("Unexpected relay type", "from", via, "relayType", relay.Type)
}
return
}
f.handleHostRoaming(hostinfo, via)
f.connectionManager.In(hostinfo)
}
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
@@ -300,23 +300,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
}
// handleEncrypted returns true if a packet should be processed, false otherwise
func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool {
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
if ci == nil {
if !via.IsRelayed {
f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex)
}
return false
}
// If the window check fails, refuse to process the packet, but don't send a recv error
if !ci.window.Check(f.l, h.MessageCounter) {
return false
}
return true
}
var (
ErrPacketTooShort = errors.New("packet is too short")
ErrUnknownIPVersion = errors.New("packet is an unknown ip version")
@@ -523,38 +506,20 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
}
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping out of window packet", "header", h)
}
return nil, errors.New("out of window packet")
return nil, ErrOutOfWindow
}
return out, nil
}
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err)
return false
}
err = newPacket(out, true, fwPacket)
func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(out, true, fwPacket)
if err != nil {
hostinfo.logger(f.l).Warn("Error while validating inbound packet",
"error", err,
"packet", out,
)
return false
}
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket)
}
return false
return
}
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
@@ -568,15 +533,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
"reason", dropReason,
)
}
return false
return
}
f.connectionManager.In(hostinfo)
_, err = f.readers[q].Write(out)
if err != nil {
f.l.Error("Failed to write to tun", "error", err)
}
return true
}
func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {