mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 16:34:25 +01:00
Compare commits
6 Commits
no-exit
...
cross-stac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f597aa71e3 | ||
|
|
20b7219fbe | ||
|
|
3b53c27170 | ||
|
|
526236c5fa | ||
|
|
0ab2882b78 | ||
|
|
889d49ff82 |
@@ -114,33 +114,6 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
|
|||||||
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
|
|
||||||
nc := &cert.TBSCertificate{
|
|
||||||
Version: v,
|
|
||||||
Curve: c.Curve(),
|
|
||||||
Name: c.Name(),
|
|
||||||
Networks: c.Networks(),
|
|
||||||
UnsafeNetworks: c.UnsafeNetworks(),
|
|
||||||
Groups: c.Groups(),
|
|
||||||
NotBefore: time.Unix(c.NotBefore().Unix(), 0),
|
|
||||||
NotAfter: time.Unix(c.NotAfter().Unix(), 0),
|
|
||||||
PublicKey: c.PublicKey(),
|
|
||||||
IsCA: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
c, err := nc.Sign(ca, ca.Curve(), key)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pem, err := c.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return c, pem
|
|
||||||
}
|
|
||||||
|
|
||||||
func X25519Keypair() ([]byte, []byte) {
|
func X25519Keypair() ([]byte, []byte) {
|
||||||
privkey := make([]byte, 32)
|
privkey := make([]byte, 32)
|
||||||
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
||||||
|
|||||||
@@ -65,16 +65,8 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
wait, err := ctrl.Start()
|
ctrl.Start()
|
||||||
if err != nil {
|
ctrl.ShutdownBlock()
|
||||||
util.LogWithContextIfNeeded("Error while running", err, l)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
go ctrl.ShutdownBlock()
|
|
||||||
wait()
|
|
||||||
|
|
||||||
l.Info("Goodbye")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|||||||
@@ -3,9 +3,6 @@ package main
|
|||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
_ "net/http/pprof"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -61,22 +58,10 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
|
||||||
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
|
|
||||||
}()
|
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
wait, err := ctrl.Start()
|
ctrl.Start()
|
||||||
if err != nil {
|
|
||||||
util.LogWithContextIfNeeded("Error while running", err, l)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
go ctrl.ShutdownBlock()
|
|
||||||
notifyReady(l)
|
notifyReady(l)
|
||||||
wait()
|
ctrl.ShutdownBlock()
|
||||||
|
|
||||||
l.Info("Goodbye")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|||||||
@@ -354,6 +354,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
|
|
||||||
if mainHostInfo {
|
if mainHostInfo {
|
||||||
decision = tryRehandshake
|
decision = tryRehandshake
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if cm.shouldSwapPrimary(hostinfo) {
|
if cm.shouldSwapPrimary(hostinfo) {
|
||||||
decision = swapPrimary
|
decision = swapPrimary
|
||||||
@@ -460,10 +461,6 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
||||||
if crt == nil {
|
|
||||||
//my cert was reloaded away. We should definitely swap from this tunnel
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
||||||
// settle down.
|
// settle down.
|
||||||
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
||||||
@@ -478,34 +475,31 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
|
|||||||
cm.hostMap.Unlock()
|
cm.hostMap.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// isInvalidCertificate decides if we should destroy a tunnel.
|
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
|
||||||
// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
|
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
|
||||||
// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
|
// check and return true.
|
||||||
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
||||||
remoteCert := hostinfo.GetCert()
|
remoteCert := hostinfo.GetCert()
|
||||||
if remoteCert == nil {
|
if remoteCert == nil {
|
||||||
return false //don't tear down tunnels for handshakes in progress
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
caPool := cm.intf.pki.GetCAPool()
|
caPool := cm.intf.pki.GetCAPool()
|
||||||
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return false //cert is still valid! yay!
|
|
||||||
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
|
||||||
// Block listed certificates should always be disconnected
|
|
||||||
hostinfo.logger(cm.l).WithError(err).
|
|
||||||
WithField("fingerprint", remoteCert.Fingerprint).
|
|
||||||
Info("Remote certificate is blocked, tearing down the tunnel")
|
|
||||||
return true
|
|
||||||
} else if cm.intf.disconnectInvalid.Load() {
|
|
||||||
hostinfo.logger(cm.l).WithError(err).
|
|
||||||
WithField("fingerprint", remoteCert.Fingerprint).
|
|
||||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
|
||||||
return true
|
|
||||||
} else {
|
|
||||||
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
|
||||||
|
// Block listed certificates should always be disconnected
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
hostinfo.logger(cm.l).WithError(err).
|
||||||
|
WithField("fingerprint", remoteCert.Fingerprint).
|
||||||
|
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||||
@@ -536,45 +530,15 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
|||||||
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||||
cs := cm.intf.pki.getCertState()
|
cs := cm.intf.pki.getCertState()
|
||||||
curCrt := hostinfo.ConnectionState.myCert
|
curCrt := hostinfo.ConnectionState.myCert
|
||||||
curCrtVersion := curCrt.Version()
|
myCrt := cs.getCertificate(curCrt.Version())
|
||||||
myCrt := cs.getCertificate(curCrtVersion)
|
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
||||||
if myCrt == nil {
|
// The current tunnel is using the latest certificate and version, no need to rehandshake.
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
|
||||||
WithField("version", curCrtVersion).
|
|
||||||
WithField("reason", "local certificate removed").
|
|
||||||
Info("Re-handshaking with remote")
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peerCrt := hostinfo.ConnectionState.peerCert
|
|
||||||
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
|
||||||
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
|
||||||
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
|
||||||
WithField("version", curCrtVersion).
|
|
||||||
WithField("peerVersion", peerCrt.Certificate.Version()).
|
|
||||||
WithField("reason", "local certificate version lower than peer, attempting to correct").
|
|
||||||
Info("Re-handshaking with remote")
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
|
||||||
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
|
||||||
WithField("reason", "local certificate is not current").
|
|
||||||
Info("Re-handshaking with remote")
|
|
||||||
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
return
|
WithField("reason", "local certificate is not current").
|
||||||
}
|
Info("Re-handshaking with remote")
|
||||||
if curCrtVersion < cs.initiatingVersion {
|
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
|
||||||
WithField("reason", "current cert version < pki.initiatingVersion").
|
|
||||||
Info("Re-handshaking with remote")
|
|
||||||
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
56
control.go
56
control.go
@@ -2,11 +2,9 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"sync"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -15,16 +13,6 @@ import (
|
|||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RunState int
|
|
||||||
|
|
||||||
const (
|
|
||||||
Stopped RunState = 0 // The control has yet to be started
|
|
||||||
Started RunState = 1 // The control has been started
|
|
||||||
Stopping RunState = 2 // The control is stopping
|
|
||||||
)
|
|
||||||
|
|
||||||
var ErrAlreadyStarted = errors.New("nebula is already started")
|
|
||||||
|
|
||||||
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
||||||
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
||||||
|
|
||||||
@@ -38,9 +26,6 @@ type controlHostLister interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Control struct {
|
type Control struct {
|
||||||
stateLock sync.Mutex
|
|
||||||
state RunState
|
|
||||||
|
|
||||||
f *Interface
|
f *Interface
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -64,21 +49,10 @@ type ControlHostInfo struct {
|
|||||||
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start actually runs nebula, this is a nonblocking call.
|
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
|
||||||
// The returned function can be used to wait for nebula to fully stop.
|
func (c *Control) Start() {
|
||||||
func (c *Control) Start() (func(), error) {
|
|
||||||
c.stateLock.Lock()
|
|
||||||
if c.state != Stopped {
|
|
||||||
c.stateLock.Unlock()
|
|
||||||
return nil, ErrAlreadyStarted
|
|
||||||
}
|
|
||||||
|
|
||||||
// Activate the interface
|
// Activate the interface
|
||||||
err := c.f.activate()
|
c.f.activate()
|
||||||
if err != nil {
|
|
||||||
c.stateLock.Unlock()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Call all the delayed funcs that waited patiently for the interface to be created.
|
// Call all the delayed funcs that waited patiently for the interface to be created.
|
||||||
if c.sshStart != nil {
|
if c.sshStart != nil {
|
||||||
@@ -98,33 +72,15 @@ func (c *Control) Start() (func(), error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start reading packets.
|
// Start reading packets.
|
||||||
c.state = Started
|
c.f.run()
|
||||||
c.stateLock.Unlock()
|
|
||||||
return c.f.run()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Control) State() RunState {
|
|
||||||
c.stateLock.Lock()
|
|
||||||
defer c.stateLock.Unlock()
|
|
||||||
return c.state
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) Context() context.Context {
|
func (c *Control) Context() context.Context {
|
||||||
return c.ctx
|
return c.ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
|
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
|
||||||
func (c *Control) Stop() {
|
func (c *Control) Stop() {
|
||||||
c.stateLock.Lock()
|
|
||||||
if c.state != Started {
|
|
||||||
c.stateLock.Unlock()
|
|
||||||
// We are stopping or stopped already
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.state = Stopping
|
|
||||||
c.stateLock.Unlock()
|
|
||||||
|
|
||||||
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
||||||
// being created while we're shutting them all down.
|
// being created while we're shutting them all down.
|
||||||
c.cancel()
|
c.cancel()
|
||||||
@@ -133,7 +89,7 @@ func (c *Control) Stop() {
|
|||||||
if err := c.f.Close(); err != nil {
|
if err := c.f.Close(); err != nil {
|
||||||
c.l.WithError(err).Error("Close interface failed")
|
c.l.WithError(err).Error("Close interface failed")
|
||||||
}
|
}
|
||||||
c.state = Stopped
|
c.l.Info("Goodbye")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
||||||
|
|||||||
@@ -29,8 +29,6 @@ type m = map[string]any
|
|||||||
|
|
||||||
// newSimpleServer creates a nebula instance with many assumptions
|
// newSimpleServer creates a nebula instance with many assumptions
|
||||||
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
l := NewTestLogger()
|
|
||||||
|
|
||||||
var vpnNetworks []netip.Prefix
|
var vpnNetworks []netip.Prefix
|
||||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||||
@@ -56,6 +54,25 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
|||||||
budpIp[3] = 239
|
budpIp[3] = 239
|
||||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||||
}
|
}
|
||||||
|
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
|
l := NewTestLogger()
|
||||||
|
|
||||||
|
var vpnNetworks []netip.Prefix
|
||||||
|
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||||
|
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
vpnNetworks = append(vpnNetworks, vpnIpNet)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(vpnNetworks) == 0 {
|
||||||
|
panic("no vpn networks")
|
||||||
|
}
|
||||||
|
|
||||||
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
||||||
|
|
||||||
caB, err := caCrt.MarshalPEM()
|
caB, err := caCrt.MarshalPEM()
|
||||||
@@ -129,109 +146,6 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
|||||||
return control, vpnNetworks, udpAddr, c
|
return control, vpnNetworks, udpAddr, c
|
||||||
}
|
}
|
||||||
|
|
||||||
// newServer creates a nebula instance with fewer assumptions
|
|
||||||
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
|
||||||
l := NewTestLogger()
|
|
||||||
|
|
||||||
vpnNetworks := certs[len(certs)-1].Networks()
|
|
||||||
|
|
||||||
var udpAddr netip.AddrPort
|
|
||||||
if vpnNetworks[0].Addr().Is4() {
|
|
||||||
budpIp := vpnNetworks[0].Addr().As4()
|
|
||||||
budpIp[1] -= 128
|
|
||||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
|
|
||||||
} else {
|
|
||||||
budpIp := vpnNetworks[0].Addr().As16()
|
|
||||||
// beef for funsies
|
|
||||||
budpIp[2] = 190
|
|
||||||
budpIp[3] = 239
|
|
||||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
|
||||||
}
|
|
||||||
|
|
||||||
caStr := ""
|
|
||||||
for _, ca := range caCrt {
|
|
||||||
x, err := ca.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
caStr += string(x)
|
|
||||||
}
|
|
||||||
certStr := ""
|
|
||||||
for _, c := range certs {
|
|
||||||
x, err := c.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
certStr += string(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
mc := m{
|
|
||||||
"pki": m{
|
|
||||||
"ca": caStr,
|
|
||||||
"cert": certStr,
|
|
||||||
"key": string(key),
|
|
||||||
},
|
|
||||||
//"tun": m{"disabled": true},
|
|
||||||
"firewall": m{
|
|
||||||
"outbound": []m{{
|
|
||||||
"proto": "any",
|
|
||||||
"port": "any",
|
|
||||||
"host": "any",
|
|
||||||
}},
|
|
||||||
"inbound": []m{{
|
|
||||||
"proto": "any",
|
|
||||||
"port": "any",
|
|
||||||
"host": "any",
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
//"handshakes": m{
|
|
||||||
// "try_interval": "1s",
|
|
||||||
//},
|
|
||||||
"listen": m{
|
|
||||||
"host": udpAddr.Addr().String(),
|
|
||||||
"port": udpAddr.Port(),
|
|
||||||
},
|
|
||||||
"logging": m{
|
|
||||||
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
|
|
||||||
"level": l.Level.String(),
|
|
||||||
},
|
|
||||||
"timers": m{
|
|
||||||
"pending_deletion_interval": 2,
|
|
||||||
"connection_alive_interval": 2,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if overrides != nil {
|
|
||||||
final := m{}
|
|
||||||
err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
mc = final
|
|
||||||
}
|
|
||||||
|
|
||||||
cb, err := yaml.Marshal(mc)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c := config.NewC(l)
|
|
||||||
cStr := string(cb)
|
|
||||||
c.LoadString(cStr)
|
|
||||||
|
|
||||||
control, err := nebula.Main(c, false, "e2e-test", l, nil)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return control, vpnNetworks, udpAddr, c
|
|
||||||
}
|
|
||||||
|
|
||||||
type doneCb func()
|
type doneCb func()
|
||||||
|
|
||||||
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
package e2e
|
package e2e
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -12,8 +11,6 @@ import (
|
|||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDropInactiveTunnels(t *testing.T) {
|
func TestDropInactiveTunnels(t *testing.T) {
|
||||||
@@ -60,261 +57,49 @@ func TestDropInactiveTunnels(t *testing.T) {
|
|||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertUpgrade(t *testing.T) {
|
func TestCrossStackRelaysWork(t *testing.T) {
|
||||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
// under ideal conditions
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
|
||||||
caB, err := ca.MarshalPEM()
|
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
|
||||||
if err != nil {
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
|
|
||||||
ca2B, err := ca2.MarshalPEM()
|
//myVpnV4 := myVpnIpNet[0]
|
||||||
if err != nil {
|
myVpnV6 := myVpnIpNet[1]
|
||||||
panic(err)
|
relayVpnV4 := relayVpnIpNet[0]
|
||||||
}
|
relayVpnV6 := relayVpnIpNet[1]
|
||||||
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
theirVpnV6 := theirVpnIpNet[0]
|
||||||
|
|
||||||
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
// Teach my how to get to the relay and that their can be reached via the relay
|
||||||
_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
|
||||||
|
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
|
||||||
|
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
|
||||||
|
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
|
||||||
|
|
||||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
// Build a router so we don't have to reason who gets which packet
|
||||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
|
||||||
|
|
||||||
// Share our underlay information
|
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
|
||||||
|
|
||||||
// Start the servers
|
// Start the servers
|
||||||
myControl.Start()
|
myControl.Start()
|
||||||
|
relayControl.Start()
|
||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
r := router.NewR(t, myControl, theirControl)
|
t.Log("Trigger a handshake from me to them via the relay")
|
||||||
defer r.RenderFlow()
|
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
|
||||||
|
|
||||||
r.Log("Assert the tunnel between me and them works")
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
r.Log("Assert the tunnel works")
|
||||||
r.Log("yay")
|
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
|
||||||
//todo ???
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
r.FlushAll()
|
|
||||||
|
|
||||||
mc := m{
|
t.Log("reply?")
|
||||||
"pki": m{
|
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
|
||||||
"ca": caStr,
|
p = r.RouteForAllUntilTxTun(myControl)
|
||||||
"cert": string(myCert2Pem),
|
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
|
||||||
"key": string(myPrivKey),
|
|
||||||
},
|
|
||||||
//"tun": m{"disabled": true},
|
|
||||||
"firewall": myC.Settings["firewall"],
|
|
||||||
//"handshakes": m{
|
|
||||||
// "try_interval": "1s",
|
|
||||||
//},
|
|
||||||
"listen": myC.Settings["listen"],
|
|
||||||
"logging": myC.Settings["logging"],
|
|
||||||
"timers": myC.Settings["timers"],
|
|
||||||
}
|
|
||||||
|
|
||||||
cb, err := yaml.Marshal(mc)
|
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||||
if err != nil {
|
//t.Log("finish up")
|
||||||
panic(err)
|
//myControl.Stop()
|
||||||
}
|
//theirControl.Stop()
|
||||||
|
//relayControl.Stop()
|
||||||
r.Logf("reload new v2-only config")
|
|
||||||
err = myC.ReloadConfigString(string(cb))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
r.Log("yay, spin until their sees it")
|
|
||||||
waitStart := time.Now()
|
|
||||||
for {
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
|
||||||
if c == nil {
|
|
||||||
r.Log("nil")
|
|
||||||
} else {
|
|
||||||
version := c.Cert.Version()
|
|
||||||
r.Logf("version %d", version)
|
|
||||||
if version == cert.Version2 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
since := time.Since(waitStart)
|
|
||||||
if since > time.Second*10 {
|
|
||||||
t.Fatal("Cert should be new by now")
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
|
||||||
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCertDowngrade(t *testing.T) {
|
|
||||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
|
||||||
// under ideal conditions
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
caB, err := ca.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
|
|
||||||
ca2B, err := ca2.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
|
||||||
|
|
||||||
myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
|
||||||
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
|
||||||
|
|
||||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
|
||||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
|
||||||
|
|
||||||
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
|
||||||
|
|
||||||
// Share our underlay information
|
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
r := router.NewR(t, myControl, theirControl)
|
|
||||||
defer r.RenderFlow()
|
|
||||||
|
|
||||||
r.Log("Assert the tunnel between me and them works")
|
|
||||||
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
|
||||||
//r.Log("yay")
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
r.Log("yay")
|
|
||||||
//todo ???
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
r.FlushAll()
|
|
||||||
|
|
||||||
mc := m{
|
|
||||||
"pki": m{
|
|
||||||
"ca": caStr,
|
|
||||||
"cert": string(myCertPem),
|
|
||||||
"key": string(myPrivKey),
|
|
||||||
},
|
|
||||||
"firewall": myC.Settings["firewall"],
|
|
||||||
"listen": myC.Settings["listen"],
|
|
||||||
"logging": myC.Settings["logging"],
|
|
||||||
"timers": myC.Settings["timers"],
|
|
||||||
}
|
|
||||||
|
|
||||||
cb, err := yaml.Marshal(mc)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Logf("reload new v1-only config")
|
|
||||||
err = myC.ReloadConfigString(string(cb))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
r.Log("yay, spin until their sees it")
|
|
||||||
waitStart := time.Now()
|
|
||||||
for {
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
|
||||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
|
||||||
if c == nil || c2 == nil {
|
|
||||||
r.Log("nil")
|
|
||||||
} else {
|
|
||||||
version := c.Cert.Version()
|
|
||||||
theirVersion := c2.Cert.Version()
|
|
||||||
r.Logf("version %d,%d", version, theirVersion)
|
|
||||||
if version == cert.Version1 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
since := time.Since(waitStart)
|
|
||||||
if since > time.Second*5 {
|
|
||||||
r.Log("it is unusual that the cert is not new yet, but not a failure yet")
|
|
||||||
}
|
|
||||||
if since > time.Second*10 {
|
|
||||||
r.Log("wtf")
|
|
||||||
t.Fatal("Cert should be new by now")
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
|
||||||
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCertMismatchCorrection(t *testing.T) {
|
|
||||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
|
||||||
// under ideal conditions
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
|
|
||||||
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
|
||||||
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
|
||||||
|
|
||||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
|
||||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
|
||||||
|
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
|
||||||
|
|
||||||
// Share our underlay information
|
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
r := router.NewR(t, myControl, theirControl)
|
|
||||||
defer r.RenderFlow()
|
|
||||||
|
|
||||||
r.Log("Assert the tunnel between me and them works")
|
|
||||||
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
|
||||||
//r.Log("yay")
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
r.Log("yay")
|
|
||||||
//todo ???
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
r.FlushAll()
|
|
||||||
|
|
||||||
waitStart := time.Now()
|
|
||||||
for {
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
|
||||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
|
||||||
if c == nil || c2 == nil {
|
|
||||||
r.Log("nil")
|
|
||||||
} else {
|
|
||||||
version := c.Cert.Version()
|
|
||||||
theirVersion := c2.Cert.Version()
|
|
||||||
r.Logf("version %d,%d", version, theirVersion)
|
|
||||||
if version == theirVersion {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
since := time.Since(waitStart)
|
|
||||||
if since > time.Second*5 {
|
|
||||||
r.Log("wtf")
|
|
||||||
}
|
|
||||||
if since > time.Second*10 {
|
|
||||||
r.Log("wtf")
|
|
||||||
t.Fatal("Cert should be new by now")
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
|
||||||
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
}
|
}
|
||||||
|
|||||||
29
firewall.go
29
firewall.go
@@ -417,6 +417,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrUnknownNetworkType = errors.New("unknown network type")
|
||||||
|
var ErrPeerRejected = errors.New("remote IP is not within a subnet that we handle")
|
||||||
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
|
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
|
||||||
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
|
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
|
||||||
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||||
@@ -429,18 +431,31 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure remote address matches nebula certificate
|
// Make sure remote address matches nebula certificate, and determine how to treat it
|
||||||
if h.networks != nil {
|
if h.networks == nil {
|
||||||
if !h.networks.Contains(fp.RemoteAddr) {
|
|
||||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
|
||||||
return ErrInvalidRemoteIP
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Simple case: Certificate has one address and no unsafe networks
|
// Simple case: Certificate has one address and no unsafe networks
|
||||||
if h.vpnAddrs[0] != fp.RemoteAddr {
|
if h.vpnAddrs[0] != fp.RemoteAddr {
|
||||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
return ErrInvalidRemoteIP
|
return ErrInvalidRemoteIP
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
|
||||||
|
if !ok {
|
||||||
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
|
return ErrInvalidRemoteIP
|
||||||
|
}
|
||||||
|
switch nwType {
|
||||||
|
case NetworkTypeVPN:
|
||||||
|
break // nothing special
|
||||||
|
case NetworkTypeVPNPeer:
|
||||||
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
|
return ErrPeerRejected // reject for now, one day this may have different FW rules
|
||||||
|
case NetworkTypeUnsafe:
|
||||||
|
break // nothing special, one day this may have different FW rules
|
||||||
|
default:
|
||||||
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
|
return ErrUnknownNetworkType //should never happen
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure we are supposed to be handling this local ip address
|
// Make sure we are supposed to be handling this local ip address
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gaissmai/bart"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
@@ -149,7 +150,8 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -174,7 +176,7 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
||||||
}
|
}
|
||||||
h.buildNetworks(c.networks, c.unsafeNetworks)
|
h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -226,6 +228,9 @@ func TestFirewall_DropV6(t *testing.T) {
|
|||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||||
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
||||||
@@ -250,7 +255,7 @@ func TestFirewall_DropV6(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
|
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
|
||||||
}
|
}
|
||||||
h.buildNetworks(c.networks, c.unsafeNetworks)
|
h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -453,6 +458,8 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -478,7 +485,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
c1 := cert.CachedCertificate{
|
c1 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -493,7 +500,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
peerCert: &c1,
|
peerCert: &c1,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -510,6 +517,8 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -541,7 +550,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
c2 := cert.CachedCertificate{
|
c2 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -556,7 +565,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
|
h2.buildNetworks(myVpnNetworksTable, c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
c3 := cert.CachedCertificate{
|
c3 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -571,7 +580,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
|
h3.buildNetworks(myVpnNetworksTable, c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -597,6 +606,8 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||||
@@ -620,7 +631,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
// Test a remote address match
|
// Test a remote address match
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
@@ -633,6 +644,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -659,7 +672,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -696,6 +709,8 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
|
||||||
|
|
||||||
c := cert.CachedCertificate{
|
c := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -717,7 +732,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
|
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
|
||||||
}
|
}
|
||||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
|
|
||||||
|
|||||||
149
handshake_ix.go
149
handshake_ix.go
@@ -23,17 +23,13 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we're connecting to a v6 address we must use a v2 cert
|
||||||
cs := f.pki.getCertState()
|
cs := f.pki.getCertState()
|
||||||
v := cs.initiatingVersion
|
v := cs.initiatingVersion
|
||||||
if hh.initiatingVersionOverride != cert.VersionPre1 {
|
for _, a := range hh.hostinfo.vpnAddrs {
|
||||||
v = hh.initiatingVersionOverride
|
if a.Is6() {
|
||||||
} else if v < cert.Version2 {
|
v = cert.Version2
|
||||||
// If we're connecting to a v6 address we should encourage use of a V2 cert
|
break
|
||||||
for _, a := range hh.hostinfo.vpnAddrs {
|
|
||||||
if a.Is6() {
|
|
||||||
v = cert.Version2
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,7 +48,6 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||||
WithField("certVersion", v).
|
WithField("certVersion", v).
|
||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
||||||
@@ -108,7 +103,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||||
WithField("certVersion", cs.initiatingVersion).
|
WithField("certVersion", cs.initiatingVersion).
|
||||||
Error("Unable to handshake with host because no certificate is available")
|
Error("Unable to handshake with host because no certificate is available")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
||||||
@@ -149,8 +143,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
|
|
||||||
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fp, fperr := rc.Fingerprint()
|
fp, err := rc.Fingerprint()
|
||||||
if fperr != nil {
|
if err != nil {
|
||||||
fp = "<error generating certificate fingerprint>"
|
fp = "<error generating certificate fingerprint>"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,19 +163,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
|
|
||||||
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
||||||
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
||||||
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
rc := cs.getCertificate(remoteCert.Certificate.Version())
|
||||||
if myCertOtherVersion == nil {
|
if rc == nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
f.l.WithError(err).WithFields(m{
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
||||||
"udpAddr": addr,
|
Info("Unable to handshake with host due to missing certificate version")
|
||||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
return
|
||||||
"cert": remoteCert,
|
|
||||||
}).Debug("Might be unable to handshake with host due to missing certificate version")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Record the certificate we are actually using
|
|
||||||
ci.myCert = myCertOtherVersion
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Record the certificate we are actually using
|
||||||
|
ci.myCert = rc
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||||
@@ -192,17 +183,18 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var vpnAddrs []netip.Addr
|
|
||||||
var filteredNetworks []netip.Prefix
|
|
||||||
certName := remoteCert.Certificate.Name()
|
certName := remoteCert.Certificate.Name()
|
||||||
certVersion := remoteCert.Certificate.Version()
|
certVersion := remoteCert.Certificate.Version()
|
||||||
fingerprint := remoteCert.Fingerprint
|
fingerprint := remoteCert.Fingerprint
|
||||||
issuer := remoteCert.Certificate.Issuer()
|
issuer := remoteCert.Certificate.Issuer()
|
||||||
|
vpnNetworks := remoteCert.Certificate.Networks()
|
||||||
|
|
||||||
for _, network := range remoteCert.Certificate.Networks() {
|
anyVpnAddrsInCommon := false
|
||||||
|
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||||
|
for i, network := range vpnNetworks {
|
||||||
vpnAddr := network.Addr()
|
vpnAddr := network.Addr()
|
||||||
if f.myVpnAddrsTable.Contains(vpnAddr) {
|
if f.myVpnAddrsTable.Contains(vpnAddr) {
|
||||||
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
|
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -210,24 +202,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
vpnAddrs[i] = network.Addr()
|
||||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
if f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||||
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
anyVpnAddrsInCommon = true
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
filteredNetworks = append(filteredNetworks, network)
|
|
||||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(vpnAddrs) == 0 {
|
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if addr.IsValid() {
|
if addr.IsValid() {
|
||||||
@@ -264,26 +242,30 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
msgRxL := f.l.WithFields(m{
|
||||||
WithField("certName", certName).
|
"vpnAddrs": vpnAddrs,
|
||||||
WithField("certVersion", certVersion).
|
"udpAddr": addr,
|
||||||
WithField("fingerprint", fingerprint).
|
"certName": certName,
|
||||||
WithField("issuer", issuer).
|
"certVersion": certVersion,
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
"fingerprint": fingerprint,
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"issuer": issuer,
|
||||||
Info("Handshake message received")
|
"initiatorIndex": hs.Details.InitiatorIndex,
|
||||||
|
"responderIndex": hs.Details.ResponderIndex,
|
||||||
|
"remoteIndex": h.RemoteIndex,
|
||||||
|
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
})
|
||||||
|
|
||||||
|
if anyVpnAddrsInCommon {
|
||||||
|
msgRxL.Info("Handshake message received")
|
||||||
|
} else {
|
||||||
|
//todo warn if not lighthouse or relay?
|
||||||
|
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
||||||
|
}
|
||||||
|
|
||||||
hs.Details.ResponderIndex = myIndex
|
hs.Details.ResponderIndex = myIndex
|
||||||
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
||||||
if hs.Details.Cert == nil {
|
if hs.Details.Cert == nil {
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
msgRxL.WithField("myCertVersion", ci.myCert.Version()).
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
WithField("certVersion", ci.myCert.Version()).
|
|
||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -341,7 +323,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
|
|
||||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
||||||
hostinfo.SetRemote(addr)
|
hostinfo.SetRemote(addr)
|
||||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -582,30 +564,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
var vpnAddrs []netip.Addr
|
anyVpnAddrsInCommon := false
|
||||||
var filteredNetworks []netip.Prefix
|
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||||
for _, network := range vpnNetworks {
|
for i, network := range vpnNetworks {
|
||||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
vpnAddrs[i] = network.Addr()
|
||||||
vpnAddr := network.Addr()
|
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
||||||
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
anyVpnAddrsInCommon = true
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
filteredNetworks = append(filteredNetworks, network)
|
|
||||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(vpnAddrs) == 0 {
|
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the right host responded
|
// Ensure the right host responded
|
||||||
|
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
|
||||||
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
||||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
||||||
WithField("udpAddr", addr).
|
WithField("udpAddr", addr).
|
||||||
@@ -618,6 +587,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||||
|
|
||||||
// Create a new hostinfo/handshake for the intended vpn ip
|
// Create a new hostinfo/handshake for the intended vpn ip
|
||||||
|
//TODO is hostinfo.vpnAddrs[0] always the address to use?
|
||||||
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
||||||
// Block the current used address
|
// Block the current used address
|
||||||
newHH.hostinfo.remotes = hostinfo.remotes
|
newHH.hostinfo.remotes = hostinfo.remotes
|
||||||
@@ -644,7 +614,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
ci.window.Update(f.l, 2)
|
ci.window.Update(f.l, 2)
|
||||||
|
|
||||||
duration := time.Since(hh.startTime).Nanoseconds()
|
duration := time.Since(hh.startTime).Nanoseconds()
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -652,12 +622,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
WithField("durationNs", duration).
|
WithField("durationNs", duration).
|
||||||
WithField("sentCachedPackets", len(hh.packetStore)).
|
WithField("sentCachedPackets", len(hh.packetStore))
|
||||||
Info("Handshake message received")
|
if anyVpnAddrsInCommon {
|
||||||
|
msgRxL.Info("Handshake message received")
|
||||||
|
} else {
|
||||||
|
//todo warn if not lighthouse or relay?
|
||||||
|
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
||||||
|
}
|
||||||
|
|
||||||
// Build up the radix for the firewall if we have subnets in the cert
|
// Build up the radix for the firewall if we have subnets in the cert
|
||||||
hostinfo.vpnAddrs = vpnAddrs
|
hostinfo.vpnAddrs = vpnAddrs
|
||||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
||||||
f.handshakeManager.Complete(hostinfo, f)
|
f.handshakeManager.Complete(hostinfo, f)
|
||||||
|
|||||||
@@ -68,12 +68,11 @@ type HandshakeManager struct {
|
|||||||
type HandshakeHostInfo struct {
|
type HandshakeHostInfo struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
startTime time.Time // Time that we first started trying with this handshake
|
startTime time.Time // Time that we first started trying with this handshake
|
||||||
ready bool // Is the handshake ready
|
ready bool // Is the handshake ready
|
||||||
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
|
counter int64 // How many attempts have we made so far
|
||||||
counter int64 // How many attempts have we made so far
|
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
|
||||||
|
|
||||||
hostinfo *HostInfo
|
hostinfo *HostInfo
|
||||||
}
|
}
|
||||||
|
|||||||
39
hostmap.go
39
hostmap.go
@@ -212,6 +212,18 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
|
|||||||
rs.relayForByIdx[idx] = r
|
rs.relayForByIdx[idx] = r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NetworkType uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
NetworkTypeUnknown NetworkType = iota
|
||||||
|
// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
|
||||||
|
NetworkTypeVPN
|
||||||
|
// NetworkTypeVPNPeer is a network that does not overlap one of our networks
|
||||||
|
NetworkTypeVPNPeer
|
||||||
|
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
|
||||||
|
NetworkTypeUnsafe
|
||||||
|
)
|
||||||
|
|
||||||
type HostInfo struct {
|
type HostInfo struct {
|
||||||
remote netip.AddrPort
|
remote netip.AddrPort
|
||||||
remotes *RemoteList
|
remotes *RemoteList
|
||||||
@@ -220,13 +232,11 @@ type HostInfo struct {
|
|||||||
remoteIndexId uint32
|
remoteIndexId uint32
|
||||||
localIndexId uint32
|
localIndexId uint32
|
||||||
|
|
||||||
// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
|
// vpnAddrs is a list of vpn addresses assigned to this host
|
||||||
// The host may have other vpn addresses that are outside our
|
|
||||||
// vpn networks but were removed because they are not usable
|
|
||||||
vpnAddrs []netip.Addr
|
vpnAddrs []netip.Addr
|
||||||
|
|
||||||
// networks are both all vpn and unsafe networks assigned to this host
|
// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host.
|
||||||
networks *bart.Lite
|
networks *bart.Table[NetworkType]
|
||||||
relayState RelayState
|
relayState RelayState
|
||||||
|
|
||||||
// HandshakePacket records the packets used to create this hostinfo
|
// HandshakePacket records the packets used to create this hostinfo
|
||||||
@@ -730,20 +740,27 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, networks, unsafeNetworks []netip.Prefix) {
|
||||||
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
||||||
// Simple case, no CIDRTree needed
|
if myVpnNetworksTable.Contains(networks[0].Addr()) {
|
||||||
return
|
return // Simple case, no CIDRTree needed
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
i.networks = new(bart.Lite)
|
i.networks = new(bart.Table[NetworkType])
|
||||||
for _, network := range networks {
|
for _, network := range networks {
|
||||||
|
var nwType NetworkType
|
||||||
|
if myVpnNetworksTable.Contains(network.Addr()) {
|
||||||
|
nwType = NetworkTypeVPN
|
||||||
|
} else {
|
||||||
|
nwType = NetworkTypeVPNPeer
|
||||||
|
}
|
||||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||||
i.networks.Insert(nprefix)
|
i.networks.Insert(nprefix, nwType)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, network := range unsafeNetworks {
|
for _, network := range unsafeNetworks {
|
||||||
i.networks.Insert(network)
|
i.networks.Insert(network, NetworkTypeUnsafe)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
56
interface.go
56
interface.go
@@ -6,8 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -87,7 +87,6 @@ type Interface struct {
|
|||||||
|
|
||||||
writers []udp.Conn
|
writers []udp.Conn
|
||||||
readers []io.ReadWriteCloser
|
readers []io.ReadWriteCloser
|
||||||
wg sync.WaitGroup
|
|
||||||
|
|
||||||
metricHandshakes metrics.Histogram
|
metricHandshakes metrics.Histogram
|
||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
@@ -210,7 +209,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
// activate creates the interface on the host. After the interface is created, any
|
// activate creates the interface on the host. After the interface is created, any
|
||||||
// other services that want to bind listeners to its IP may do so successfully. However,
|
// other services that want to bind listeners to its IP may do so successfully. However,
|
||||||
// the interface isn't going to process anything until run() is called.
|
// the interface isn't going to process anything until run() is called.
|
||||||
func (f *Interface) activate() error {
|
func (f *Interface) activate() {
|
||||||
// actually turn on tun dev
|
// actually turn on tun dev
|
||||||
|
|
||||||
addr, err := f.outside.LocalAddr()
|
addr, err := f.outside.LocalAddr()
|
||||||
@@ -231,38 +230,33 @@ func (f *Interface) activate() error {
|
|||||||
if i > 0 {
|
if i > 0 {
|
||||||
reader, err = f.inside.NewMultiQueueReader()
|
reader, err = f.inside.NewMultiQueueReader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
f.l.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
f.readers[i] = reader
|
f.readers[i] = reader
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = f.inside.Activate(); err != nil {
|
if err := f.inside.Activate(); err != nil {
|
||||||
f.inside.Close()
|
f.inside.Close()
|
||||||
return err
|
f.l.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) run() (func(), error) {
|
func (f *Interface) run() {
|
||||||
// Launch n queues to read packets from udp
|
// Launch n queues to read packets from udp
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
f.wg.Add(1)
|
|
||||||
go f.listenOut(i)
|
go f.listenOut(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch n queues to read packets from tun dev
|
// Launch n queues to read packets from tun dev
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
f.wg.Add(1)
|
|
||||||
go f.listenIn(f.readers[i], i)
|
go f.listenIn(f.readers[i], i)
|
||||||
}
|
}
|
||||||
|
|
||||||
return f.wg.Wait, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenOut(i int) {
|
func (f *Interface) listenOut(i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
var li udp.Conn
|
var li udp.Conn
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
li = f.writers[i]
|
li = f.writers[i]
|
||||||
@@ -277,21 +271,14 @@ func (f *Interface) listenOut(i int) {
|
|||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||||
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil && !f.closed.Load() {
|
|
||||||
f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
|
|
||||||
//TODO: Trigger Control to close
|
|
||||||
}
|
|
||||||
|
|
||||||
f.l.Debugf("underlay reader %v is done", i)
|
|
||||||
f.wg.Done()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
packet := make([]byte, mtu)
|
packet := make([]byte, mtu)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
@@ -302,18 +289,17 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
|||||||
for {
|
for {
|
||||||
n, err := reader.Read(packet)
|
n, err := reader.Read(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !f.closed.Load() {
|
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
||||||
f.l.WithError(err).Error("Error while reading outbound packet, closing")
|
return
|
||||||
//TODO: Trigger Control to close
|
|
||||||
}
|
}
|
||||||
break
|
|
||||||
|
f.l.WithError(err).Error("Error while reading outbound packet")
|
||||||
|
// This only seems to happen when something fatal happens to the fd, so exit.
|
||||||
|
os.Exit(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.Debugf("overlay reader %v is done", i)
|
|
||||||
f.wg.Done()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||||
@@ -465,7 +451,6 @@ func (f *Interface) GetCertState() *CertState {
|
|||||||
func (f *Interface) Close() error {
|
func (f *Interface) Close() error {
|
||||||
f.closed.Store(true)
|
f.closed.Store(true)
|
||||||
|
|
||||||
// Release the udp readers
|
|
||||||
for _, u := range f.writers {
|
for _, u := range f.writers {
|
||||||
err := u.Close()
|
err := u.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -473,13 +458,6 @@ func (f *Interface) Close() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Release the tun readers
|
// Release the tun device
|
||||||
for _, u := range f.readers {
|
return f.inside.Close()
|
||||||
err := u.Close()
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).Error("Error while closing tun device")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1017,17 +1017,17 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
|
|||||||
return lhh.meta
|
return lhh.meta
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs []netip.Addr, p []byte, w EncWriter) {
|
func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, hostinfo *HostInfo, p []byte, w EncWriter) {
|
||||||
n := lhh.resetMeta()
|
n := lhh.resetMeta()
|
||||||
err := n.Unmarshal(p)
|
err := n.Unmarshal(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
|
lhh.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", rAddr).
|
||||||
Error("Failed to unmarshal lighthouse packet")
|
Error("Failed to unmarshal lighthouse packet")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.Details == nil {
|
if n.Details == nil {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
|
lhh.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", rAddr).
|
||||||
Error("Invalid lighthouse update")
|
Error("Invalid lighthouse update")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1036,24 +1036,24 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
|
|||||||
|
|
||||||
switch n.Type {
|
switch n.Type {
|
||||||
case NebulaMeta_HostQuery:
|
case NebulaMeta_HostQuery:
|
||||||
lhh.handleHostQuery(n, fromVpnAddrs, rAddr, w)
|
lhh.handleHostQuery(n, hostinfo, rAddr, w)
|
||||||
|
|
||||||
case NebulaMeta_HostQueryReply:
|
case NebulaMeta_HostQueryReply:
|
||||||
lhh.handleHostQueryReply(n, fromVpnAddrs)
|
lhh.handleHostQueryReply(n, hostinfo.vpnAddrs)
|
||||||
|
|
||||||
case NebulaMeta_HostUpdateNotification:
|
case NebulaMeta_HostUpdateNotification:
|
||||||
lhh.handleHostUpdateNotification(n, fromVpnAddrs, w)
|
lhh.handleHostUpdateNotification(n, hostinfo, w)
|
||||||
|
|
||||||
case NebulaMeta_HostMovedNotification:
|
case NebulaMeta_HostMovedNotification:
|
||||||
case NebulaMeta_HostPunchNotification:
|
case NebulaMeta_HostPunchNotification:
|
||||||
lhh.handleHostPunchNotification(n, fromVpnAddrs, w)
|
lhh.handleHostPunchNotification(n, hostinfo.vpnAddrs, w)
|
||||||
|
|
||||||
case NebulaMeta_HostUpdateNotificationAck:
|
case NebulaMeta_HostUpdateNotificationAck:
|
||||||
// noop
|
// noop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, hostinfo *HostInfo, addr netip.AddrPort, w EncWriter) {
|
||||||
// Exit if we don't answer queries
|
// Exit if we don't answer queries
|
||||||
if !lhh.lh.amLighthouse {
|
if !lhh.lh.amLighthouse {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
@@ -1065,7 +1065,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|||||||
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
|
lhh.l.WithField("from", hostinfo.vpnAddrs).WithField("details", n.Details).
|
||||||
Debugln("Dropping malformed HostQuery")
|
Debugln("Dropping malformed HostQuery")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -1073,7 +1073,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|||||||
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
||||||
// this case really shouldn't be possible to represent, but reject it anyway.
|
// this case really shouldn't be possible to represent, but reject it anyway.
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
|
lhh.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
|
||||||
Debugln("invalid vpn addr for v1 handleHostQuery")
|
Debugln("invalid vpn addr for v1 handleHostQuery")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -1099,14 +1099,14 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply")
|
lhh.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).Error("Failed to marshal lighthouse host query reply")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
|
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
|
||||||
w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
w.SendMessageToHostInfo(header.LightHouse, 0, hostinfo, lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
||||||
|
|
||||||
lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w)
|
lhh.sendHostPunchNotification(n, hostinfo.vpnAddrs, queryVpnAddr, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendHostPunchNotification signals the other side to punch some zero byte udp packets
|
// sendHostPunchNotification signals the other side to punch some zero byte udp packets
|
||||||
@@ -1115,20 +1115,34 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
|||||||
found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) {
|
found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) {
|
||||||
n = lhh.resetMeta()
|
n = lhh.resetMeta()
|
||||||
n.Type = NebulaMeta_HostPunchNotification
|
n.Type = NebulaMeta_HostPunchNotification
|
||||||
targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
|
punchNotifDestHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
|
||||||
var useVersion cert.Version
|
var useVersion cert.Version
|
||||||
if targetHI == nil {
|
if punchNotifDestHI == nil {
|
||||||
useVersion = lhh.lh.ifce.GetCertState().initiatingVersion
|
useVersion = lhh.lh.ifce.GetCertState().initiatingVersion
|
||||||
} else {
|
} else {
|
||||||
crt := targetHI.GetCert().Certificate
|
|
||||||
useVersion = crt.Version()
|
|
||||||
// we can only retarget if we have a hostinfo
|
// we can only retarget if we have a hostinfo
|
||||||
newDest, ok := findNetworkUnion(crt.Networks(), fromVpnAddrs)
|
punchNotifDestCrt := punchNotifDestHI.GetCert().Certificate
|
||||||
|
useVersion = punchNotifDestCrt.Version()
|
||||||
|
punchNotifDestNetworks := punchNotifDestCrt.Networks()
|
||||||
|
|
||||||
|
//if we (the lighthouse) don't have a network in common with punchNotifDest, try to find one
|
||||||
|
if !lhh.lh.myVpnNetworksTable.Contains(punchNotifDest) {
|
||||||
|
newPunchNotifDest, ok := findNetworkUnion(lhh.lh.myVpnNetworks, punchNotifDestHI.vpnAddrs)
|
||||||
|
if ok {
|
||||||
|
punchNotifDest = newPunchNotifDest
|
||||||
|
} else {
|
||||||
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
|
lhh.l.WithField("to", punchNotifDestNetworks).Debugln("unable to notify host to host, no addresses in common")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newWhereToPunch, ok := findNetworkUnion(punchNotifDestNetworks, fromVpnAddrs)
|
||||||
if ok {
|
if ok {
|
||||||
whereToPunch = newDest
|
whereToPunch = newWhereToPunch
|
||||||
} else {
|
} else {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
|
lhh.l.WithFields(m{"from": fromVpnAddrs, "to": punchNotifDestNetworks}).Debugln("unable to punch to host, no addresses in common with requestor")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1234,7 +1248,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, hostinfo *HostInfo, w EncWriter) {
|
||||||
|
fromVpnAddrs := hostinfo.vpnAddrs
|
||||||
if !lhh.lh.amLighthouse {
|
if !lhh.lh.amLighthouse {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs)
|
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs)
|
||||||
@@ -1302,7 +1317,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
}
|
}
|
||||||
|
|
||||||
lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1)
|
lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1)
|
||||||
w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
w.SendMessageToHostInfo(header.LightHouse, 0, hostinfo, lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
||||||
|
|||||||
@@ -132,8 +132,13 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
mw := &mockEncWriter{}
|
mw := &mockEncWriter{}
|
||||||
|
hostinfo := &HostInfo{
|
||||||
hi := []netip.Addr{vpnIp2}
|
ConnectionState: &ConnectionState{
|
||||||
|
eKey: nil,
|
||||||
|
dKey: nil,
|
||||||
|
},
|
||||||
|
vpnAddrs: []netip.Addr{vpnIp2},
|
||||||
|
}
|
||||||
b.Run("notfound", func(b *testing.B) {
|
b.Run("notfound", func(b *testing.B) {
|
||||||
lhh := lh.NewRequestHandler()
|
lhh := lh.NewRequestHandler()
|
||||||
req := &NebulaMeta{
|
req := &NebulaMeta{
|
||||||
@@ -146,7 +151,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
p, err := req.Marshal()
|
p, err := req.Marshal()
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
lhh.HandleRequest(rAddr, hi, p, mw)
|
lhh.HandleRequest(rAddr, hostinfo, p, mw)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("found", func(b *testing.B) {
|
b.Run("found", func(b *testing.B) {
|
||||||
@@ -162,7 +167,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
lhh.HandleRequest(rAddr, hi, p, mw)
|
lhh.HandleRequest(rAddr, hostinfo, p, mw)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -326,7 +331,14 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
|
|||||||
w := &testEncWriter{
|
w := &testEncWriter{
|
||||||
metaFilter: &filter,
|
metaFilter: &filter,
|
||||||
}
|
}
|
||||||
lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
|
hostinfo := &HostInfo{
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
eKey: nil,
|
||||||
|
dKey: nil,
|
||||||
|
},
|
||||||
|
vpnAddrs: []netip.Addr{myVpnIp},
|
||||||
|
}
|
||||||
|
lhh.HandleRequest(fromAddr, hostinfo, b, w)
|
||||||
return w.lastReply
|
return w.lastReply
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,9 +367,15 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
hostinfo := &HostInfo{
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
eKey: nil,
|
||||||
|
dKey: nil,
|
||||||
|
},
|
||||||
|
vpnAddrs: []netip.Addr{vpnIp},
|
||||||
|
}
|
||||||
w := &testEncWriter{}
|
w := &testEncWriter{}
|
||||||
lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
|
lhh.HandleRequest(fromAddr, hostinfo, b, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
type testLhReply struct {
|
type testLhReply struct {
|
||||||
|
|||||||
18
main.go
18
main.go
@@ -284,14 +284,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Control{
|
return &Control{
|
||||||
f: ifce,
|
ifce,
|
||||||
l: l,
|
l,
|
||||||
ctx: ctx,
|
ctx,
|
||||||
cancel: cancel,
|
cancel,
|
||||||
sshStart: sshStart,
|
sshStart,
|
||||||
statsStart: statsStart,
|
statsStart,
|
||||||
dnsStart: dnsStart,
|
dnsStart,
|
||||||
lighthouseStart: lightHouse.StartUpdateWorker,
|
lightHouse.StartUpdateWorker,
|
||||||
connectionManagerStart: connManager.Start,
|
connManager.Start,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//f.l.Error("in packet ", h)
|
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||||
if ip.IsValid() {
|
if ip.IsValid() {
|
||||||
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
@@ -138,7 +138,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
|
lhf.HandleRequest(ip, hostinfo, d, f)
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
// Fallthrough to the bottom to record incoming traffic
|
||||||
|
|
||||||
|
|||||||
83
pki.go
83
pki.go
@@ -100,62 +100,55 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
|||||||
currentState := p.cs.Load()
|
currentState := p.cs.Load()
|
||||||
if newState.v1Cert != nil {
|
if newState.v1Cert != nil {
|
||||||
if currentState.v1Cert == nil {
|
if currentState.v1Cert == nil {
|
||||||
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
|
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
|
||||||
} else {
|
|
||||||
// did IP in cert change? if so, don't set
|
|
||||||
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
|
||||||
return util.NewContextualError(
|
|
||||||
"Networks in new cert was different from old",
|
|
||||||
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
|
||||||
return util.NewContextualError(
|
|
||||||
"Curve in new v1 cert was different from old",
|
|
||||||
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// did IP in cert change? if so, don't set
|
||||||
|
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||||
|
return util.NewContextualError(
|
||||||
|
"Networks in new cert was different from old",
|
||||||
|
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
||||||
|
return util.NewContextualError(
|
||||||
|
"Curve in new cert was different from old",
|
||||||
|
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if currentState.v1Cert != nil {
|
||||||
|
//TODO: CERT-V2 we should be able to tear this down
|
||||||
|
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if newState.v2Cert != nil {
|
if newState.v2Cert != nil {
|
||||||
if currentState.v2Cert == nil {
|
if currentState.v2Cert == nil {
|
||||||
//adding certs is fine, actually
|
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
|
||||||
} else {
|
|
||||||
// did IP in cert change? if so, don't set
|
|
||||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
|
||||||
return util.NewContextualError(
|
|
||||||
"Networks in new cert was different from old",
|
|
||||||
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
|
||||||
return util.NewContextualError(
|
|
||||||
"Curve in new cert was different from old",
|
|
||||||
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if currentState.v2Cert != nil {
|
// did IP in cert change? if so, don't set
|
||||||
//newState.v1Cert is non-nil bc empty certstates aren't permitted
|
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
||||||
if newState.v1Cert == nil {
|
|
||||||
return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
|
|
||||||
}
|
|
||||||
//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
|
|
||||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
|
|
||||||
return util.NewContextualError(
|
return util.NewContextualError(
|
||||||
"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
|
"Networks in new cert was different from old",
|
||||||
m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
|
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
||||||
|
return util.NewContextualError(
|
||||||
|
"Curve in new cert was different from old",
|
||||||
|
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if currentState.v2Cert != nil {
|
||||||
|
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cipher cant be hot swapped so just leave it at what it was before
|
// Cipher cant be hot swapped so just leave it at what it was before
|
||||||
|
|||||||
@@ -44,10 +44,7 @@ type Service struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func New(control *nebula.Control) (*Service, error) {
|
func New(control *nebula.Control) (*Service, error) {
|
||||||
wait, err := control.Start()
|
control.Start()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := control.Context()
|
ctx := control.Context()
|
||||||
eg, ctx := errgroup.WithContext(ctx)
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
@@ -144,12 +141,6 @@ func New(control *nebula.Control) (*Service, error) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Add the nebula wait function to the group
|
|
||||||
eg.Go(func() error {
|
|
||||||
wait()
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
return &s, nil
|
return &s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ type EncReader func(
|
|||||||
type Conn interface {
|
type Conn interface {
|
||||||
Rebind() error
|
Rebind() error
|
||||||
LocalAddr() (netip.AddrPort, error)
|
LocalAddr() (netip.AddrPort, error)
|
||||||
ListenOut(r EncReader) error
|
ListenOut(r EncReader)
|
||||||
WriteTo(b []byte, addr netip.AddrPort) error
|
WriteTo(b []byte, addr netip.AddrPort) error
|
||||||
ReloadConfig(c *config.C)
|
ReloadConfig(c *config.C)
|
||||||
Close() error
|
Close() error
|
||||||
@@ -30,8 +30,8 @@ func (NoopConn) Rebind() error {
|
|||||||
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
return netip.AddrPort{}, nil
|
return netip.AddrPort{}, nil
|
||||||
}
|
}
|
||||||
func (NoopConn) ListenOut(_ EncReader) error {
|
func (NoopConn) ListenOut(_ EncReader) {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
|
|||||||
return func() {}
|
return func() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader) error {
|
func (u *StdConn) ListenOut(r EncReader) {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -173,7 +173,8 @@ func (u *StdConn) ListenOut(r EncReader) error {
|
|||||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, net.ErrClosed) {
|
if errors.Is(err, net.ErrClosed) {
|
||||||
return err
|
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
u.l.WithError(err).Error("unexpected udp socket receive error")
|
||||||
|
|||||||
@@ -71,14 +71,15 @@ type rawMessage struct {
|
|||||||
Len uint32
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *GenericConn) ListenOut(r EncReader) error {
|
func (u *GenericConn) ListenOut(r EncReader) {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Just read one packet at a time
|
// Just read one packet at a time
|
||||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
@@ -18,8 +17,6 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
|
|
||||||
|
|
||||||
type StdConn struct {
|
type StdConn struct {
|
||||||
sysFd int
|
sysFd int
|
||||||
isV4 bool
|
isV4 bool
|
||||||
@@ -27,6 +24,14 @@ type StdConn struct {
|
|||||||
batch int
|
batch int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func maybeIPV4(ip net.IP) (net.IP, bool) {
|
||||||
|
ip4 := ip.To4()
|
||||||
|
if ip4 != nil {
|
||||||
|
return ip4, true
|
||||||
|
}
|
||||||
|
return ip, false
|
||||||
|
}
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||||
af := unix.AF_INET6
|
af := unix.AF_INET6
|
||||||
if ip.Is4() {
|
if ip.Is4() {
|
||||||
@@ -50,11 +55,6 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set a read timeout
|
|
||||||
if err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &readTimeout); err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to set SO_RCVTIMEO: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var sa unix.Sockaddr
|
var sa unix.Sockaddr
|
||||||
if ip.Is4() {
|
if ip.Is4() {
|
||||||
sa4 := &unix.SockaddrInet4{Port: port}
|
sa4 := &unix.SockaddrInet4{Port: port}
|
||||||
@@ -118,7 +118,7 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader) error {
|
func (u *StdConn) ListenOut(r EncReader) {
|
||||||
var ip netip.Addr
|
var ip netip.Addr
|
||||||
|
|
||||||
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
||||||
@@ -130,7 +130,8 @@ func (u *StdConn) ListenOut(r EncReader) error {
|
|||||||
for {
|
for {
|
||||||
n, err := read(msgs)
|
n, err := read(msgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
@@ -158,9 +159,6 @@ func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if err != 0 {
|
if err != 0 {
|
||||||
if err == unix.EAGAIN || err == unix.EINTR {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return 0, &net.OpError{Op: "recvmsg", Err: err}
|
return 0, &net.OpError{Op: "recvmsg", Err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -182,9 +180,6 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if err != 0 {
|
if err != 0 {
|
||||||
if err == unix.EAGAIN || err == unix.EINTR {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return 0, &net.OpError{Op: "recvmmsg", Err: err}
|
return 0, &net.OpError{Op: "recvmmsg", Err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RIOConn) ListenOut(r EncReader) error {
|
func (u *RIOConn) ListenOut(r EncReader) {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -142,7 +142,8 @@ func (u *RIOConn) ListenOut(r EncReader) error {
|
|||||||
n, rua, err := u.receive(buffer)
|
n, rua, err := u.receive(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, net.ErrClosed) {
|
if errors.Is(err, net.ErrClosed) {
|
||||||
return err
|
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
u.l.WithError(err).Error("unexpected udp socket receive error")
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ package udp
|
|||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -107,11 +106,11 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TesterConn) ListenOut(r EncReader) error {
|
func (u *TesterConn) ListenOut(r EncReader) {
|
||||||
for {
|
for {
|
||||||
p, ok := <-u.RxPackets
|
p, ok := <-u.RxPackets
|
||||||
if !ok {
|
if !ok {
|
||||||
return os.ErrClosed
|
return
|
||||||
}
|
}
|
||||||
r(p.From, p.Data)
|
r(p.From, p.Data)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user