Tighten up the inside handlers with a bit of DRY

This commit is contained in:
Dave Russell 2020-09-27 22:37:20 +10:00
parent 2c931d5691
commit 55d72ac46f
2 changed files with 26 additions and 36 deletions

View File

@ -1,22 +1,30 @@
package nebula package nebula
func (f *Interface) handleMessagePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { func (f *Interface) encrypted(h InsideHandler) InsideHandler {
if !f.handleEncrypted(ci, addr, header) { return func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
return if !f.handleEncrypted(ci, addr, header) {
return
}
h(hostInfo, ci, addr, header, out, packet, fwPacket, nb)
f.handleHostRoaming(hostInfo, addr)
f.connectionManager.In(hostInfo.hostId)
} }
}
func (f *Interface) rxMetrics(h InsideHandler) InsideHandler {
return func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
h(hostInfo, ci, addr, header, out, packet, fwPacket, nb)
}
}
func (f *Interface) handleMessagePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.decryptToTun(hostInfo, header.MessageCounter, out, packet, fwPacket, nb) f.decryptToTun(hostInfo, header.MessageCounter, out, packet, fwPacket, nb)
f.handleHostRoaming(hostInfo, addr)
f.connectionManager.In(hostInfo.hostId)
} }
func (f *Interface) handleLighthousePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { func (f *Interface) handleLighthousePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
if !f.handleEncrypted(ci, addr, header) {
return
}
d, err := f.decrypt(hostInfo, header.MessageCounter, out, packet, header, nb) d, err := f.decrypt(hostInfo, header.MessageCounter, out, packet, header, nb)
if err != nil { if err != nil {
hostInfo.logger().WithError(err).WithField("udpAddr", addr). hostInfo.logger().WithError(err).WithField("udpAddr", addr).
@ -29,17 +37,9 @@ func (f *Interface) handleLighthousePacket(hostInfo *HostInfo, ci *ConnectionSta
} }
f.lightHouse.HandleRequest(addr, hostInfo.hostId, d, hostInfo.GetCert(), f) f.lightHouse.HandleRequest(addr, hostInfo.hostId, d, hostInfo.GetCert(), f)
f.handleHostRoaming(hostInfo, addr)
f.connectionManager.In(hostInfo.hostId)
} }
func (f *Interface) handleTestPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { func (f *Interface) handleTestPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
if !f.handleEncrypted(ci, addr, header) {
return
}
d, err := f.decrypt(hostInfo, header.MessageCounter, out, packet, header, nb) d, err := f.decrypt(hostInfo, header.MessageCounter, out, packet, header, nb)
if err != nil { if err != nil {
hostInfo.logger().WithError(err).WithField("udpAddr", addr). hostInfo.logger().WithError(err).WithField("udpAddr", addr).
@ -57,28 +57,18 @@ func (f *Interface) handleTestPacket(hostInfo *HostInfo, ci *ConnectionState, ad
f.handleHostRoaming(hostInfo, addr) f.handleHostRoaming(hostInfo, addr)
f.send(test, testReply, ci, hostInfo, hostInfo.remote, d, nb, out) f.send(test, testReply, ci, hostInfo, hostInfo.remote, d, nb, out)
} }
f.handleHostRoaming(hostInfo, addr)
f.connectionManager.In(hostInfo.hostId)
} }
func (f *Interface) handleHandshakePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { func (f *Interface) handleHandshakePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
HandleIncomingHandshake(f, addr, packet, header, hostInfo) HandleIncomingHandshake(f, addr, packet, header, hostInfo)
} }
func (f *Interface) handleRecvErrorPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { func (f *Interface) handleRecvErrorPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
// TODO: Remove this with recv_error deprecation // TODO: Remove this with recv_error deprecation
f.handleRecvError(addr, header) f.handleRecvError(addr, header)
} }
func (f *Interface) handleCloseTunnelPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { func (f *Interface) handleCloseTunnelPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
if !f.handleEncrypted(ci, addr, header) {
return
}
hostInfo.logger().WithField("udpAddr", addr). hostInfo.logger().WithField("udpAddr", addr).
Info("Close tunnel received, tearing down.") Info("Close tunnel received, tearing down.")

View File

@ -111,23 +111,23 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
ifce.handlers = map[uint8]map[NebulaMessageType]map[NebulaMessageSubType]InsideHandler{ ifce.handlers = map[uint8]map[NebulaMessageType]map[NebulaMessageSubType]InsideHandler{
Version: { Version: {
handshake: { handshake: {
handshakeIXPSK0: ifce.handleHandshakePacket, handshakeIXPSK0: ifce.rxMetrics(ifce.handleHandshakePacket),
}, },
message: { message: {
subTypeNone: ifce.handleMessagePacket, subTypeNone: ifce.encrypted(ifce.handleMessagePacket),
}, },
recvError: { recvError: {
subTypeNone: ifce.handleRecvErrorPacket, subTypeNone: ifce.rxMetrics(ifce.handleRecvErrorPacket),
}, },
lightHouse: { lightHouse: {
subTypeNone: ifce.handleLighthousePacket, subTypeNone: ifce.rxMetrics(ifce.encrypted(ifce.handleLighthousePacket)),
}, },
test: { test: {
testRequest: ifce.handleTestPacket, testRequest: ifce.rxMetrics(ifce.encrypted(ifce.handleTestPacket)),
testReply: ifce.handleTestPacket, testReply: ifce.rxMetrics(ifce.encrypted(ifce.handleTestPacket)),
}, },
closeTunnel: { closeTunnel: {
subTypeNone: ifce.handleCloseTunnelPacket, subTypeNone: ifce.rxMetrics(ifce.encrypted(ifce.handleCloseTunnelPacket)),
}, },
}, },
} }