Tighten up the inside handlers with a bit of DRY

This commit is contained in:
Dave Russell 2020-09-27 22:37:20 +10:00
parent 2c931d5691
commit 55d72ac46f
2 changed files with 26 additions and 36 deletions

View File

@ -1,22 +1,30 @@
package nebula
func (f *Interface) handleMessagePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
func (f *Interface) encrypted(h InsideHandler) InsideHandler {
return func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
if !f.handleEncrypted(ci, addr, header) {
return
}
f.decryptToTun(hostInfo, header.MessageCounter, out, packet, fwPacket, nb)
h(hostInfo, ci, addr, header, out, packet, fwPacket, nb)
f.handleHostRoaming(hostInfo, addr)
f.connectionManager.In(hostInfo.hostId)
}
func (f *Interface) handleLighthousePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
if !f.handleEncrypted(ci, addr, header) {
return
}
func (f *Interface) rxMetrics(h InsideHandler) InsideHandler {
return func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
h(hostInfo, ci, addr, header, out, packet, fwPacket, nb)
}
}
func (f *Interface) handleMessagePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.decryptToTun(hostInfo, header.MessageCounter, out, packet, fwPacket, nb)
}
func (f *Interface) handleLighthousePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
d, err := f.decrypt(hostInfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
hostInfo.logger().WithError(err).WithField("udpAddr", addr).
@ -29,17 +37,9 @@ func (f *Interface) handleLighthousePacket(hostInfo *HostInfo, ci *ConnectionSta
}
f.lightHouse.HandleRequest(addr, hostInfo.hostId, d, hostInfo.GetCert(), f)
f.handleHostRoaming(hostInfo, addr)
f.connectionManager.In(hostInfo.hostId)
}
func (f *Interface) handleTestPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
if !f.handleEncrypted(ci, addr, header) {
return
}
d, err := f.decrypt(hostInfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
hostInfo.logger().WithError(err).WithField("udpAddr", addr).
@ -57,28 +57,18 @@ func (f *Interface) handleTestPacket(hostInfo *HostInfo, ci *ConnectionState, ad
f.handleHostRoaming(hostInfo, addr)
f.send(test, testReply, ci, hostInfo, hostInfo.remote, d, nb, out)
}
f.handleHostRoaming(hostInfo, addr)
f.connectionManager.In(hostInfo.hostId)
}
func (f *Interface) handleHandshakePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
HandleIncomingHandshake(f, addr, packet, header, hostInfo)
}
func (f *Interface) handleRecvErrorPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
// TODO: Remove this with recv_error deprecation
f.handleRecvError(addr, header)
}
func (f *Interface) handleCloseTunnelPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
if !f.handleEncrypted(ci, addr, header) {
return
}
hostInfo.logger().WithField("udpAddr", addr).
Info("Close tunnel received, tearing down.")

View File

@ -111,23 +111,23 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
ifce.handlers = map[uint8]map[NebulaMessageType]map[NebulaMessageSubType]InsideHandler{
Version: {
handshake: {
handshakeIXPSK0: ifce.handleHandshakePacket,
handshakeIXPSK0: ifce.rxMetrics(ifce.handleHandshakePacket),
},
message: {
subTypeNone: ifce.handleMessagePacket,
subTypeNone: ifce.encrypted(ifce.handleMessagePacket),
},
recvError: {
subTypeNone: ifce.handleRecvErrorPacket,
subTypeNone: ifce.rxMetrics(ifce.handleRecvErrorPacket),
},
lightHouse: {
subTypeNone: ifce.handleLighthousePacket,
subTypeNone: ifce.rxMetrics(ifce.encrypted(ifce.handleLighthousePacket)),
},
test: {
testRequest: ifce.handleTestPacket,
testReply: ifce.handleTestPacket,
testRequest: ifce.rxMetrics(ifce.encrypted(ifce.handleTestPacket)),
testReply: ifce.rxMetrics(ifce.encrypted(ifce.handleTestPacket)),
},
closeTunnel: {
subTypeNone: ifce.handleCloseTunnelPacket,
subTypeNone: ifce.rxMetrics(ifce.encrypted(ifce.handleCloseTunnelPacket)),
},
},
}