mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 08:24:25 +01:00
Compare commits
14 Commits
cross-stac
...
channels-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2c6f81c224 | ||
|
|
ad37749c5e | ||
|
|
a0f8cb2098 | ||
|
|
d18d1aea67 | ||
|
|
f5ff534671 | ||
|
|
2ea8a72d5c | ||
|
|
663232e1fc | ||
|
|
2f48529e8b | ||
|
|
f3e1ad64cd | ||
|
|
1d8112a329 | ||
|
|
31eea0cc94 | ||
|
|
dbba4a4c77 | ||
|
|
194fde45da | ||
|
|
f46b83f2c4 |
3
bits.go
3
bits.go
@@ -5,6 +5,7 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO: Pretty sure this is just all sorts of racy now, we need it to be atomic
|
||||||
type Bits struct {
|
type Bits struct {
|
||||||
length uint64
|
length uint64
|
||||||
current uint64
|
current uint64
|
||||||
@@ -43,7 +44,7 @@ func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Not within the window
|
// Not within the window
|
||||||
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
|
l.Error("rejected a packet (top) %d %d\n", b.current, i)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
97
cert/pem.go
97
cert/pem.go
@@ -1,8 +1,10 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/hex"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
)
|
)
|
||||||
@@ -138,6 +140,101 @@ func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Backward compatibility functions for older API
|
||||||
|
func MarshalX25519PublicKey(b []byte) []byte {
|
||||||
|
return MarshalPublicKeyToPEM(Curve_CURVE25519, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MarshalX25519PrivateKey(b []byte) []byte {
|
||||||
|
return MarshalPrivateKeyToPEM(Curve_CURVE25519, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MarshalPublicKey(curve Curve, b []byte) []byte {
|
||||||
|
return MarshalPublicKeyToPEM(curve, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MarshalPrivateKey(curve Curve, b []byte) []byte {
|
||||||
|
return MarshalPrivateKeyToPEM(curve, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NebulaCertificate is a compatibility wrapper for the old API
|
||||||
|
type NebulaCertificate struct {
|
||||||
|
Details NebulaCertificateDetails
|
||||||
|
Signature []byte
|
||||||
|
cert Certificate
|
||||||
|
}
|
||||||
|
|
||||||
|
// NebulaCertificateDetails is a compatibility wrapper for certificate details
|
||||||
|
type NebulaCertificateDetails struct {
|
||||||
|
Name string
|
||||||
|
NotBefore time.Time
|
||||||
|
NotAfter time.Time
|
||||||
|
PublicKey []byte
|
||||||
|
IsCA bool
|
||||||
|
Issuer []byte
|
||||||
|
Curve Curve
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalNebulaCertificateFromPEM provides backward compatibility with the old API
|
||||||
|
func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
|
||||||
|
c, rest, err := UnmarshalCertificateFromPEM(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, rest, err
|
||||||
|
}
|
||||||
|
|
||||||
|
issuerBytes, err := func() ([]byte, error) {
|
||||||
|
issuer := c.Issuer()
|
||||||
|
if issuer == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
decoded, err := hex.DecodeString(issuer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode issuer fingerprint: %w", err)
|
||||||
|
}
|
||||||
|
return decoded, nil
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
return nil, rest, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKey := c.PublicKey()
|
||||||
|
if pubKey != nil {
|
||||||
|
pubKey = append([]byte(nil), pubKey...)
|
||||||
|
}
|
||||||
|
|
||||||
|
sig := c.Signature()
|
||||||
|
if sig != nil {
|
||||||
|
sig = append([]byte(nil), sig...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &NebulaCertificate{
|
||||||
|
Details: NebulaCertificateDetails{
|
||||||
|
Name: c.Name(),
|
||||||
|
NotBefore: c.NotBefore(),
|
||||||
|
NotAfter: c.NotAfter(),
|
||||||
|
PublicKey: pubKey,
|
||||||
|
IsCA: c.IsCA(),
|
||||||
|
Issuer: issuerBytes,
|
||||||
|
Curve: c.Curve(),
|
||||||
|
},
|
||||||
|
Signature: sig,
|
||||||
|
cert: c,
|
||||||
|
}, rest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IssuerString returns the issuer in hex format for compatibility
|
||||||
|
func (n *NebulaCertificate) IssuerString() string {
|
||||||
|
if n.Details.Issuer == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(n.Details.Issuer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Certificate returns the underlying certificate (read-only)
|
||||||
|
func (n *NebulaCertificate) Certificate() Certificate {
|
||||||
|
return n.cert
|
||||||
|
}
|
||||||
|
|
||||||
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
|
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
|
||||||
// consumed data or an error on failure
|
// consumed data or an error on failure
|
||||||
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||||
|
|||||||
@@ -65,8 +65,16 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
ctrl.Start()
|
wait, err := ctrl.Start()
|
||||||
ctrl.ShutdownBlock()
|
if err != nil {
|
||||||
|
util.LogWithContextIfNeeded("Error while running", err, l)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
go ctrl.ShutdownBlock()
|
||||||
|
wait()
|
||||||
|
|
||||||
|
l.Info("Goodbye")
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|||||||
@@ -3,6 +3,9 @@ package main
|
|||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
_ "net/http/pprof"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -58,10 +61,22 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
|
||||||
|
}()
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
ctrl.Start()
|
wait, err := ctrl.Start()
|
||||||
|
if err != nil {
|
||||||
|
util.LogWithContextIfNeeded("Error while running", err, l)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
go ctrl.ShutdownBlock()
|
||||||
notifyReady(l)
|
notifyReady(l)
|
||||||
ctrl.ShutdownBlock()
|
wait()
|
||||||
|
|
||||||
|
l.Info("Goodbye")
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import (
|
|||||||
"github.com/slackhq/nebula/noiseutil"
|
"github.com/slackhq/nebula/noiseutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets.
|
||||||
|
// 4092 should be sufficient for 5Gbps
|
||||||
const ReplayWindow = 1024
|
const ReplayWindow = 1024
|
||||||
|
|
||||||
type ConnectionState struct {
|
type ConnectionState struct {
|
||||||
|
|||||||
56
control.go
56
control.go
@@ -2,9 +2,11 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -13,6 +15,16 @@ import (
|
|||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type RunState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
Stopped RunState = 0 // The control has yet to be started
|
||||||
|
Started RunState = 1 // The control has been started
|
||||||
|
Stopping RunState = 2 // The control is stopping
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrAlreadyStarted = errors.New("nebula is already started")
|
||||||
|
|
||||||
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
||||||
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
||||||
|
|
||||||
@@ -26,6 +38,9 @@ type controlHostLister interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Control struct {
|
type Control struct {
|
||||||
|
stateLock sync.Mutex
|
||||||
|
state RunState
|
||||||
|
|
||||||
f *Interface
|
f *Interface
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -49,10 +64,21 @@ type ControlHostInfo struct {
|
|||||||
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
|
// Start actually runs nebula, this is a nonblocking call.
|
||||||
func (c *Control) Start() {
|
// The returned function can be used to wait for nebula to fully stop.
|
||||||
|
func (c *Control) Start() (func(), error) {
|
||||||
|
c.stateLock.Lock()
|
||||||
|
if c.state != Stopped {
|
||||||
|
c.stateLock.Unlock()
|
||||||
|
return nil, ErrAlreadyStarted
|
||||||
|
}
|
||||||
|
|
||||||
// Activate the interface
|
// Activate the interface
|
||||||
c.f.activate()
|
err := c.f.activate()
|
||||||
|
if err != nil {
|
||||||
|
c.stateLock.Unlock()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Call all the delayed funcs that waited patiently for the interface to be created.
|
// Call all the delayed funcs that waited patiently for the interface to be created.
|
||||||
if c.sshStart != nil {
|
if c.sshStart != nil {
|
||||||
@@ -72,15 +98,33 @@ func (c *Control) Start() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start reading packets.
|
// Start reading packets.
|
||||||
c.f.run()
|
c.state = Started
|
||||||
|
c.stateLock.Unlock()
|
||||||
|
return c.f.run(c.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Control) State() RunState {
|
||||||
|
c.stateLock.Lock()
|
||||||
|
defer c.stateLock.Unlock()
|
||||||
|
return c.state
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) Context() context.Context {
|
func (c *Control) Context() context.Context {
|
||||||
return c.ctx
|
return c.ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
|
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
|
||||||
func (c *Control) Stop() {
|
func (c *Control) Stop() {
|
||||||
|
c.stateLock.Lock()
|
||||||
|
if c.state != Started {
|
||||||
|
c.stateLock.Unlock()
|
||||||
|
// We are stopping or stopped already
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.state = Stopping
|
||||||
|
c.stateLock.Unlock()
|
||||||
|
|
||||||
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
||||||
// being created while we're shutting them all down.
|
// being created while we're shutting them all down.
|
||||||
c.cancel()
|
c.cancel()
|
||||||
@@ -89,7 +133,7 @@ func (c *Control) Stop() {
|
|||||||
if err := c.f.Close(); err != nil {
|
if err := c.f.Close(); err != nil {
|
||||||
c.l.WithError(err).Error("Close interface failed")
|
c.l.WithError(err).Error("Close interface failed")
|
||||||
}
|
}
|
||||||
c.l.Info("Goodbye")
|
c.state = Stopped
|
||||||
}
|
}
|
||||||
|
|
||||||
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ type m = map[string]any
|
|||||||
|
|
||||||
// newSimpleServer creates a nebula instance with many assumptions
|
// newSimpleServer creates a nebula instance with many assumptions
|
||||||
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
|
l := NewTestLogger()
|
||||||
|
|
||||||
var vpnNetworks []netip.Prefix
|
var vpnNetworks []netip.Prefix
|
||||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||||
@@ -54,25 +56,6 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
|||||||
budpIp[3] = 239
|
budpIp[3] = 239
|
||||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||||
}
|
}
|
||||||
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
|
||||||
l := NewTestLogger()
|
|
||||||
|
|
||||||
var vpnNetworks []netip.Prefix
|
|
||||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
|
||||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
vpnNetworks = append(vpnNetworks, vpnIpNet)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(vpnNetworks) == 0 {
|
|
||||||
panic("no vpn networks")
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
||||||
|
|
||||||
caB, err := caCrt.MarshalPEM()
|
caB, err := caCrt.MarshalPEM()
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
package e2e
|
package e2e
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -56,50 +55,3 @@ func TestDropInactiveTunnels(t *testing.T) {
|
|||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCrossStackRelaysWork(t *testing.T) {
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
|
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
|
|
||||||
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
|
|
||||||
|
|
||||||
//myVpnV4 := myVpnIpNet[0]
|
|
||||||
myVpnV6 := myVpnIpNet[1]
|
|
||||||
relayVpnV4 := relayVpnIpNet[0]
|
|
||||||
relayVpnV6 := relayVpnIpNet[1]
|
|
||||||
theirVpnV6 := theirVpnIpNet[0]
|
|
||||||
|
|
||||||
// Teach my how to get to the relay and that their can be reached via the relay
|
|
||||||
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
|
|
||||||
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
|
|
||||||
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
|
|
||||||
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
|
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
|
||||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
|
||||||
defer r.RenderFlow()
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
relayControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
t.Log("Trigger a handshake from me to them via the relay")
|
|
||||||
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
|
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
|
||||||
r.Log("Assert the tunnel works")
|
|
||||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
|
|
||||||
|
|
||||||
t.Log("reply?")
|
|
||||||
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
|
|
||||||
p = r.RouteForAllUntilTxTun(myControl)
|
|
||||||
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
|
|
||||||
|
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
|
||||||
//t.Log("finish up")
|
|
||||||
//myControl.Stop()
|
|
||||||
//theirControl.Stop()
|
|
||||||
//relayControl.Stop()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -132,6 +132,13 @@ listen:
|
|||||||
# Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
|
# Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
|
||||||
# default is 64, does not support reload
|
# default is 64, does not support reload
|
||||||
#batch: 64
|
#batch: 64
|
||||||
|
|
||||||
|
# Control batching between UDP and TUN pipelines
|
||||||
|
#batch:
|
||||||
|
# inbound_size: 32 # packets to queue from UDP before handing to workers
|
||||||
|
# outbound_size: 32 # packets to queue from TUN before handing to workers
|
||||||
|
# flush_interval: 50us # flush partially filled batches after this duration
|
||||||
|
# max_outstanding: 1028 # batches buffered per routine on each channel
|
||||||
# Configure socket buffers for the udp side (outside), leave unset to use the system defaults. Values will be doubled by the kernel
|
# Configure socket buffers for the udp side (outside), leave unset to use the system defaults. Values will be doubled by the kernel
|
||||||
# Default is net.core.rmem_default and net.core.wmem_default (/proc/sys/net/core/rmem_default and /proc/sys/net/core/rmem_default)
|
# Default is net.core.rmem_default and net.core.wmem_default (/proc/sys/net/core/rmem_default and /proc/sys/net/core/rmem_default)
|
||||||
# Maximum is limited by memory in the system, SO_RCVBUFFORCE and SO_SNDBUFFORCE is used to avoid having to raise the system wide
|
# Maximum is limited by memory in the system, SO_RCVBUFFORCE and SO_SNDBUFFORCE is used to avoid having to raise the system wide
|
||||||
|
|||||||
@@ -692,50 +692,6 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_DropIPSpoofing(t *testing.T) {
|
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
|
||||||
l.SetOutput(ob)
|
|
||||||
|
|
||||||
c := cert.CachedCertificate{
|
|
||||||
Certificate: &dummyCert{
|
|
||||||
name: "host-owner",
|
|
||||||
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
c1 := cert.CachedCertificate{
|
|
||||||
Certificate: &dummyCert{
|
|
||||||
name: "host",
|
|
||||||
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
|
|
||||||
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
h1 := HostInfo{
|
|
||||||
ConnectionState: &ConnectionState{
|
|
||||||
peerCert: &c1,
|
|
||||||
},
|
|
||||||
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
|
|
||||||
}
|
|
||||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
|
||||||
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
|
||||||
cp := cert.NewCAPool()
|
|
||||||
|
|
||||||
// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
|
|
||||||
p := firewall.Packet{
|
|
||||||
LocalAddr: netip.MustParseAddr("192.0.2.1"),
|
|
||||||
RemoteAddr: netip.MustParseAddr("192.0.2.3"),
|
|
||||||
LocalPort: 1,
|
|
||||||
RemotePort: 1,
|
|
||||||
Protocol: firewall.ProtoUDP,
|
|
||||||
Fragment: false,
|
|
||||||
}
|
|
||||||
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkLookup(b *testing.B) {
|
func BenchmarkLookup(b *testing.B) {
|
||||||
ml := func(m map[string]struct{}, a [][]string) {
|
ml := func(m map[string]struct{}, a [][]string) {
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
|
|||||||
@@ -292,7 +292,6 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
|
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
|
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m := NebulaControl{
|
m := NebulaControl{
|
||||||
@@ -302,25 +301,37 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
|
|
||||||
switch relayHostInfo.GetCert().Certificate.Version() {
|
switch relayHostInfo.GetCert().Certificate.Version() {
|
||||||
case cert.Version1:
|
case cert.Version1:
|
||||||
err = buildRelayInfoCertV1(&m, hm.f.myVpnNetworks, vpnIp)
|
if !hm.f.myVpnAddrs[0].Is4() {
|
||||||
|
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !vpnIp.Is4() {
|
||||||
|
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
b := hm.f.myVpnAddrs[0].As4()
|
||||||
|
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
||||||
|
b = vpnIp.As4()
|
||||||
|
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||||
case cert.Version2:
|
case cert.Version2:
|
||||||
err = buildRelayInfoCertV2(&m, hm.f.myVpnNetworks, vpnIp)
|
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
|
||||||
|
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
|
||||||
default:
|
default:
|
||||||
err = errors.New("unknown certificate version found while creating relay")
|
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(hm.l).WithError(err).Error("Refusing to relay")
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, err := m.Marshal()
|
msg, err := m.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(hm.l).WithError(err).
|
hostinfo.logger(hm.l).
|
||||||
|
WithError(err).
|
||||||
Error("Failed to marshal Control message to create relay")
|
Error("Failed to marshal Control message to create relay")
|
||||||
} else {
|
} else {
|
||||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
hm.l.WithFields(logrus.Fields{
|
hm.l.WithFields(logrus.Fields{
|
||||||
"relayFrom": m.GetRelayFrom(),
|
"relayFrom": hm.f.myVpnAddrs[0],
|
||||||
"relayTo": vpnIp,
|
"relayTo": vpnIp,
|
||||||
"initiatorRelayIndex": idx,
|
"initiatorRelayIndex": idx,
|
||||||
"relay": relay}).
|
"relay": relay}).
|
||||||
@@ -346,27 +357,39 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
InitiatorRelayIndex: existingRelay.LocalIndex,
|
InitiatorRelayIndex: existingRelay.LocalIndex,
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
|
||||||
switch relayHostInfo.GetCert().Certificate.Version() {
|
switch relayHostInfo.GetCert().Certificate.Version() {
|
||||||
case cert.Version1:
|
case cert.Version1:
|
||||||
err = buildRelayInfoCertV1(&m, hm.f.myVpnNetworks, vpnIp)
|
if !hm.f.myVpnAddrs[0].Is4() {
|
||||||
|
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !vpnIp.Is4() {
|
||||||
|
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
b := hm.f.myVpnAddrs[0].As4()
|
||||||
|
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
||||||
|
b = vpnIp.As4()
|
||||||
|
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||||
case cert.Version2:
|
case cert.Version2:
|
||||||
err = buildRelayInfoCertV2(&m, hm.f.myVpnNetworks, vpnIp)
|
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
|
||||||
|
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
|
||||||
default:
|
default:
|
||||||
err = errors.New("unknown certificate version found while creating relay")
|
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(hm.l).WithError(err).Error("Refusing to relay")
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
msg, err := m.Marshal()
|
msg, err := m.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(hm.l).WithError(err).Error("Failed to marshal Control message to create relay")
|
hostinfo.logger(hm.l).
|
||||||
|
WithError(err).
|
||||||
|
Error("Failed to marshal Control message to create relay")
|
||||||
} else {
|
} else {
|
||||||
// This must send over the hostinfo, not over hm.Hosts[ip]
|
// This must send over the hostinfo, not over hm.Hosts[ip]
|
||||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
hm.l.WithFields(logrus.Fields{
|
hm.l.WithFields(logrus.Fields{
|
||||||
"relayFrom": m.GetRelayFrom(),
|
"relayFrom": hm.f.myVpnAddrs[0],
|
||||||
"relayTo": vpnIp,
|
"relayTo": vpnIp,
|
||||||
"initiatorRelayIndex": existingRelay.LocalIndex,
|
"initiatorRelayIndex": existingRelay.LocalIndex,
|
||||||
"relay": relay}).
|
"relay": relay}).
|
||||||
@@ -701,32 +724,3 @@ func generateIndex(l *logrus.Logger) (uint32, error) {
|
|||||||
func hsTimeout(tries int64, interval time.Duration) time.Duration {
|
func hsTimeout(tries int64, interval time.Duration) time.Duration {
|
||||||
return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval)))
|
return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval)))
|
||||||
}
|
}
|
||||||
|
|
||||||
var errNoRelayTooOld = errors.New("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
|
|
||||||
|
|
||||||
func buildRelayInfoCertV1(m *NebulaControl, myVpnNetworks []netip.Prefix, peerVpnIp netip.Addr) error {
|
|
||||||
relayFrom := myVpnNetworks[0].Addr()
|
|
||||||
if !relayFrom.Is4() {
|
|
||||||
return errNoRelayTooOld
|
|
||||||
}
|
|
||||||
if !peerVpnIp.Is4() {
|
|
||||||
return errNoRelayTooOld
|
|
||||||
}
|
|
||||||
|
|
||||||
b := relayFrom.As4()
|
|
||||||
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
|
||||||
b = peerVpnIp.As4()
|
|
||||||
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildRelayInfoCertV2(m *NebulaControl, myVpnNetworks []netip.Prefix, peerVpnIp netip.Addr) error {
|
|
||||||
for i := range myVpnNetworks {
|
|
||||||
if myVpnNetworks[i].Contains(peerVpnIp) {
|
|
||||||
m.RelayFromAddr = netAddrToProtoAddr(myVpnNetworks[i].Addr())
|
|
||||||
m.RelayToAddr = netAddrToProtoAddr(peerVpnIp)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return errors.New("cannot establish relay, no networks in common")
|
|
||||||
}
|
|
||||||
|
|||||||
10
hostmap.go
10
hostmap.go
@@ -512,16 +512,13 @@ func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
|
|||||||
return hm.queryVpnAddr(vpnIp, nil)
|
return hm.queryVpnAddr(vpnIp, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
var errUnableToFindHost = errors.New("unable to find host")
|
|
||||||
var errUnableToFindHostWithRelay = errors.New("unable to find host with relay")
|
|
||||||
|
|
||||||
func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
|
func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
|
||||||
hm.RLock()
|
hm.RLock()
|
||||||
defer hm.RUnlock()
|
defer hm.RUnlock()
|
||||||
|
|
||||||
h, ok := hm.Hosts[relayHostIp]
|
h, ok := hm.Hosts[relayHostIp]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil, errUnableToFindHost
|
return nil, nil, errors.New("unable to find host")
|
||||||
}
|
}
|
||||||
|
|
||||||
for h != nil {
|
for h != nil {
|
||||||
@@ -534,7 +531,7 @@ func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp net
|
|||||||
h = h.next
|
h = h.next
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil, errUnableToFindHostWithRelay
|
return nil, nil, errors.New("unable to find host with relay")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {
|
func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {
|
||||||
@@ -741,8 +738,7 @@ func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
|||||||
|
|
||||||
i.networks = new(bart.Lite)
|
i.networks = new(bart.Lite)
|
||||||
for _, network := range networks {
|
for _, network := range networks {
|
||||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
i.networks.Insert(network)
|
||||||
i.networks.Insert(nprefix)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, network := range unsafeNetworks {
|
for _, network := range unsafeNetworks {
|
||||||
|
|||||||
341
interface.go
341
interface.go
@@ -6,8 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -18,10 +18,18 @@ import (
|
|||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
|
"github.com/slackhq/nebula/packet"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
const mtu = 9001
|
const (
|
||||||
|
mtu = 9001
|
||||||
|
|
||||||
|
inboundBatchSizeDefault = 32
|
||||||
|
outboundBatchSizeDefault = 32
|
||||||
|
batchFlushIntervalDefault = 50 * time.Microsecond
|
||||||
|
maxOutstandingBatchesDefault = 1028
|
||||||
|
)
|
||||||
|
|
||||||
type InterfaceConfig struct {
|
type InterfaceConfig struct {
|
||||||
HostMap *HostMap
|
HostMap *HostMap
|
||||||
@@ -47,9 +55,17 @@ type InterfaceConfig struct {
|
|||||||
reQueryWait time.Duration
|
reQueryWait time.Duration
|
||||||
|
|
||||||
ConntrackCacheTimeout time.Duration
|
ConntrackCacheTimeout time.Duration
|
||||||
|
BatchConfig BatchConfig
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type BatchConfig struct {
|
||||||
|
InboundBatchSize int
|
||||||
|
OutboundBatchSize int
|
||||||
|
FlushInterval time.Duration
|
||||||
|
MaxOutstandingPerChan int
|
||||||
|
}
|
||||||
|
|
||||||
type Interface struct {
|
type Interface struct {
|
||||||
hostMap *HostMap
|
hostMap *HostMap
|
||||||
outside udp.Conn
|
outside udp.Conn
|
||||||
@@ -87,12 +103,95 @@ type Interface struct {
|
|||||||
|
|
||||||
writers []udp.Conn
|
writers []udp.Conn
|
||||||
readers []io.ReadWriteCloser
|
readers []io.ReadWriteCloser
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
metricHandshakes metrics.Histogram
|
metricHandshakes metrics.Histogram
|
||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
cachedPacketMetrics *cachedPacketMetrics
|
cachedPacketMetrics *cachedPacketMetrics
|
||||||
|
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
|
inPool sync.Pool
|
||||||
|
inbound []chan *packetBatch
|
||||||
|
|
||||||
|
outPool sync.Pool
|
||||||
|
outbound []chan *outboundBatch
|
||||||
|
|
||||||
|
packetBatchPool sync.Pool
|
||||||
|
outboundBatchPool sync.Pool
|
||||||
|
|
||||||
|
inboundBatchSize int
|
||||||
|
outboundBatchSize int
|
||||||
|
batchFlushInterval time.Duration
|
||||||
|
maxOutstandingPerChan int
|
||||||
|
}
|
||||||
|
|
||||||
|
type packetBatch struct {
|
||||||
|
packets []*packet.Packet
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPacketBatch(capacity int) *packetBatch {
|
||||||
|
return &packetBatch{
|
||||||
|
packets: make([]*packet.Packet, 0, capacity),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *packetBatch) add(p *packet.Packet) {
|
||||||
|
b.packets = append(b.packets, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *packetBatch) reset() {
|
||||||
|
for i := range b.packets {
|
||||||
|
b.packets[i] = nil
|
||||||
|
}
|
||||||
|
b.packets = b.packets[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) getPacketBatch() *packetBatch {
|
||||||
|
if v := f.packetBatchPool.Get(); v != nil {
|
||||||
|
b := v.(*packetBatch)
|
||||||
|
b.reset()
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
return newPacketBatch(f.inboundBatchSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) releasePacketBatch(b *packetBatch) {
|
||||||
|
b.reset()
|
||||||
|
f.packetBatchPool.Put(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
type outboundBatch struct {
|
||||||
|
payloads []*[]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newOutboundBatch(capacity int) *outboundBatch {
|
||||||
|
return &outboundBatch{payloads: make([]*[]byte, 0, capacity)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *outboundBatch) add(buf *[]byte) {
|
||||||
|
b.payloads = append(b.payloads, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *outboundBatch) reset() {
|
||||||
|
for i := range b.payloads {
|
||||||
|
b.payloads[i] = nil
|
||||||
|
}
|
||||||
|
b.payloads = b.payloads[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) getOutboundBatch() *outboundBatch {
|
||||||
|
if v := f.outboundBatchPool.Get(); v != nil {
|
||||||
|
b := v.(*outboundBatch)
|
||||||
|
b.reset()
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
return newOutboundBatch(f.outboundBatchSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) releaseOutboundBatch(b *outboundBatch) {
|
||||||
|
b.reset()
|
||||||
|
f.outboundBatchPool.Put(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
type EncWriter interface {
|
type EncWriter interface {
|
||||||
@@ -162,6 +261,20 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cs := c.pki.getCertState()
|
cs := c.pki.getCertState()
|
||||||
|
|
||||||
|
bc := c.BatchConfig
|
||||||
|
if bc.InboundBatchSize <= 0 {
|
||||||
|
bc.InboundBatchSize = inboundBatchSizeDefault
|
||||||
|
}
|
||||||
|
if bc.OutboundBatchSize <= 0 {
|
||||||
|
bc.OutboundBatchSize = outboundBatchSizeDefault
|
||||||
|
}
|
||||||
|
if bc.FlushInterval <= 0 {
|
||||||
|
bc.FlushInterval = batchFlushIntervalDefault
|
||||||
|
}
|
||||||
|
if bc.MaxOutstandingPerChan <= 0 {
|
||||||
|
bc.MaxOutstandingPerChan = maxOutstandingBatchesDefault
|
||||||
|
}
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
pki: c.pki,
|
pki: c.pki,
|
||||||
hostMap: c.HostMap,
|
hostMap: c.HostMap,
|
||||||
@@ -194,9 +307,39 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
|
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
|
||||||
},
|
},
|
||||||
|
|
||||||
|
inbound: make([]chan *packetBatch, c.routines),
|
||||||
|
outbound: make([]chan *outboundBatch, c.routines),
|
||||||
|
|
||||||
l: c.l,
|
l: c.l,
|
||||||
|
|
||||||
|
inboundBatchSize: bc.InboundBatchSize,
|
||||||
|
outboundBatchSize: bc.OutboundBatchSize,
|
||||||
|
batchFlushInterval: bc.FlushInterval,
|
||||||
|
maxOutstandingPerChan: bc.MaxOutstandingPerChan,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for i := 0; i < c.routines; i++ {
|
||||||
|
ifce.inbound[i] = make(chan *packetBatch, ifce.maxOutstandingPerChan)
|
||||||
|
ifce.outbound[i] = make(chan *outboundBatch, ifce.maxOutstandingPerChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
ifce.inPool = sync.Pool{New: func() any {
|
||||||
|
return packet.New()
|
||||||
|
}}
|
||||||
|
|
||||||
|
ifce.outPool = sync.Pool{New: func() any {
|
||||||
|
t := make([]byte, mtu)
|
||||||
|
return &t
|
||||||
|
}}
|
||||||
|
|
||||||
|
ifce.packetBatchPool = sync.Pool{New: func() any {
|
||||||
|
return newPacketBatch(ifce.inboundBatchSize)
|
||||||
|
}}
|
||||||
|
|
||||||
|
ifce.outboundBatchPool = sync.Pool{New: func() any {
|
||||||
|
return newOutboundBatch(ifce.outboundBatchSize)
|
||||||
|
}}
|
||||||
|
|
||||||
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
||||||
ifce.reQueryEvery.Store(c.reQueryEvery)
|
ifce.reQueryEvery.Store(c.reQueryEvery)
|
||||||
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
||||||
@@ -209,7 +352,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
// activate creates the interface on the host. After the interface is created, any
|
// activate creates the interface on the host. After the interface is created, any
|
||||||
// other services that want to bind listeners to its IP may do so successfully. However,
|
// other services that want to bind listeners to its IP may do so successfully. However,
|
||||||
// the interface isn't going to process anything until run() is called.
|
// the interface isn't going to process anything until run() is called.
|
||||||
func (f *Interface) activate() {
|
func (f *Interface) activate() error {
|
||||||
// actually turn on tun dev
|
// actually turn on tun dev
|
||||||
|
|
||||||
addr, err := f.outside.LocalAddr()
|
addr, err := f.outside.LocalAddr()
|
||||||
@@ -230,33 +373,44 @@ func (f *Interface) activate() {
|
|||||||
if i > 0 {
|
if i > 0 {
|
||||||
reader, err = f.inside.NewMultiQueueReader()
|
reader, err = f.inside.NewMultiQueueReader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.Fatal(err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
f.readers[i] = reader
|
f.readers[i] = reader
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := f.inside.Activate(); err != nil {
|
if err = f.inside.Activate(); err != nil {
|
||||||
f.inside.Close()
|
f.inside.Close()
|
||||||
f.l.Fatal(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) run() {
|
func (f *Interface) run(c context.Context) (func(), error) {
|
||||||
// Launch n queues to read packets from udp
|
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
|
// Launch n queues to read packets from udp
|
||||||
|
f.wg.Add(1)
|
||||||
go f.listenOut(i)
|
go f.listenOut(i)
|
||||||
|
|
||||||
|
// Launch n queues to read packets from tun dev
|
||||||
|
f.wg.Add(1)
|
||||||
|
go f.listenIn(f.readers[i], i)
|
||||||
|
|
||||||
|
// Launch n queues to read packets from tun dev
|
||||||
|
f.wg.Add(1)
|
||||||
|
go f.workerIn(i, c)
|
||||||
|
|
||||||
|
// Launch n queues to read packets from tun dev
|
||||||
|
f.wg.Add(1)
|
||||||
|
go f.workerOut(i, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch n queues to read packets from tun dev
|
return f.wg.Wait, nil
|
||||||
for i := 0; i < f.routines; i++ {
|
|
||||||
go f.listenIn(f.readers[i], i)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenOut(i int) {
|
func (f *Interface) listenOut(i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
var li udp.Conn
|
var li udp.Conn
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
li = f.writers[i]
|
li = f.writers[i]
|
||||||
@@ -264,41 +418,142 @@ func (f *Interface) listenOut(i int) {
|
|||||||
li = f.outside
|
li = f.outside
|
||||||
}
|
}
|
||||||
|
|
||||||
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
batch := f.getPacketBatch()
|
||||||
lhh := f.lightHouse.NewRequestHandler()
|
lastFlush := time.Now()
|
||||||
plaintext := make([]byte, udp.MTU)
|
|
||||||
h := &header.H{}
|
|
||||||
fwPacket := &firewall.Packet{}
|
|
||||||
nb := make([]byte, 12, 12)
|
|
||||||
|
|
||||||
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
flush := func(force bool) {
|
||||||
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
if len(batch.packets) == 0 {
|
||||||
|
if force {
|
||||||
|
f.releasePacketBatch(batch)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f.inbound[i] <- batch
|
||||||
|
batch = f.getPacketBatch()
|
||||||
|
lastFlush = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||||
|
p := f.inPool.Get().(*packet.Packet)
|
||||||
|
p.Payload = p.Payload[:mtu]
|
||||||
|
copy(p.Payload, payload)
|
||||||
|
p.Payload = p.Payload[:len(payload)]
|
||||||
|
p.Addr = fromUdpAddr
|
||||||
|
batch.add(p)
|
||||||
|
|
||||||
|
if len(batch.packets) >= f.inboundBatchSize || time.Since(lastFlush) >= f.batchFlushInterval {
|
||||||
|
flush(false)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if len(batch.packets) > 0 {
|
||||||
|
f.inbound[i] <- batch
|
||||||
|
} else {
|
||||||
|
f.releasePacketBatch(batch)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil && !f.closed.Load() {
|
||||||
|
f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
|
||||||
|
//TODO: Trigger Control to close
|
||||||
|
}
|
||||||
|
|
||||||
|
f.l.Debugf("underlay reader %v is done", i)
|
||||||
|
f.wg.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
packet := make([]byte, mtu)
|
batch := f.getOutboundBatch()
|
||||||
out := make([]byte, mtu)
|
lastFlush := time.Now()
|
||||||
fwPacket := &firewall.Packet{}
|
|
||||||
nb := make([]byte, 12, 12)
|
|
||||||
|
|
||||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
flush := func(force bool) {
|
||||||
|
if len(batch.payloads) == 0 {
|
||||||
for {
|
if force {
|
||||||
n, err := reader.Read(packet)
|
f.releaseOutboundBatch(batch)
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
return
|
||||||
f.l.WithError(err).Error("Error while reading outbound packet")
|
|
||||||
// This only seems to happen when something fatal happens to the fd, so exit.
|
|
||||||
os.Exit(2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
f.outbound[i] <- batch
|
||||||
|
batch = f.getOutboundBatch()
|
||||||
|
lastFlush = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
p := f.outPool.Get().(*[]byte)
|
||||||
|
*p = (*p)[:mtu]
|
||||||
|
n, err := reader.Read(*p)
|
||||||
|
if err != nil {
|
||||||
|
if !f.closed.Load() {
|
||||||
|
f.l.WithError(err).Error("Error while reading outbound packet, closing")
|
||||||
|
//TODO: Trigger Control to close
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
*p = (*p)[:n]
|
||||||
|
batch.add(p)
|
||||||
|
|
||||||
|
if len(batch.payloads) >= f.outboundBatchSize || time.Since(lastFlush) >= f.batchFlushInterval {
|
||||||
|
flush(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(batch.payloads) > 0 {
|
||||||
|
f.outbound[i] <- batch
|
||||||
|
} else {
|
||||||
|
f.releaseOutboundBatch(batch)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.l.Debugf("overlay reader %v is done", i)
|
||||||
|
f.wg.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) workerIn(i int, ctx context.Context) {
|
||||||
|
lhh := f.lightHouse.NewRequestHandler()
|
||||||
|
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
|
fwPacket2 := &firewall.Packet{}
|
||||||
|
nb2 := make([]byte, 12, 12)
|
||||||
|
result2 := make([]byte, mtu)
|
||||||
|
h := &header.H{}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case batch := <-f.inbound[i]:
|
||||||
|
for _, p := range batch.packets {
|
||||||
|
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
|
||||||
|
p.Payload = p.Payload[:mtu]
|
||||||
|
f.inPool.Put(p)
|
||||||
|
}
|
||||||
|
f.releasePacketBatch(batch)
|
||||||
|
case <-ctx.Done():
|
||||||
|
f.wg.Done()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) workerOut(i int, ctx context.Context) {
|
||||||
|
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
|
fwPacket1 := &firewall.Packet{}
|
||||||
|
nb1 := make([]byte, 12, 12)
|
||||||
|
result1 := make([]byte, mtu)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case batch := <-f.outbound[i]:
|
||||||
|
for _, data := range batch.payloads {
|
||||||
|
f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
|
||||||
|
*data = (*data)[:mtu]
|
||||||
|
f.outPool.Put(data)
|
||||||
|
}
|
||||||
|
f.releaseOutboundBatch(batch)
|
||||||
|
case <-ctx.Done():
|
||||||
|
f.wg.Done()
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -451,6 +706,7 @@ func (f *Interface) GetCertState() *CertState {
|
|||||||
func (f *Interface) Close() error {
|
func (f *Interface) Close() error {
|
||||||
f.closed.Store(true)
|
f.closed.Store(true)
|
||||||
|
|
||||||
|
// Release the udp readers
|
||||||
for _, u := range f.writers {
|
for _, u := range f.writers {
|
||||||
err := u.Close()
|
err := u.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -458,6 +714,13 @@ func (f *Interface) Close() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Release the tun device
|
// Release the tun readers
|
||||||
return f.inside.Close()
|
for _, u := range f.readers {
|
||||||
|
err := u.Close()
|
||||||
|
if err != nil {
|
||||||
|
f.l.WithError(err).Error("Error while closing tun device")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1425,7 +1425,7 @@ func (d *NebulaMetaDetails) GetRelays() []netip.Addr {
|
|||||||
return relays
|
return relays
|
||||||
}
|
}
|
||||||
|
|
||||||
// findNetworkUnion returns the first netip.Addr of addrs contained in the list of provided netip.Prefix, if able
|
// FindNetworkUnion returns the first netip.Addr contained in the list of provided netip.Prefix, if able
|
||||||
func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr, bool) {
|
func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr, bool) {
|
||||||
for i := range prefixes {
|
for i := range prefixes {
|
||||||
for j := range addrs {
|
for j := range addrs {
|
||||||
@@ -1450,13 +1450,3 @@ func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, er
|
|||||||
return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
|
return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *NebulaControl) GetRelayFrom() netip.Addr {
|
|
||||||
if d.OldRelayFromAddr != 0 {
|
|
||||||
b := [4]byte{}
|
|
||||||
binary.BigEndian.PutUint32(b[:], d.OldRelayFromAddr)
|
|
||||||
return netip.AddrFrom4(b)
|
|
||||||
} else {
|
|
||||||
return protoAddrToNetAddr(d.RelayFromAddr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
26
main.go
26
main.go
@@ -221,6 +221,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
batchCfg := BatchConfig{
|
||||||
|
InboundBatchSize: c.GetInt("batch.inbound_size", inboundBatchSizeDefault),
|
||||||
|
OutboundBatchSize: c.GetInt("batch.outbound_size", outboundBatchSizeDefault),
|
||||||
|
FlushInterval: c.GetDuration("batch.flush_interval", batchFlushIntervalDefault),
|
||||||
|
MaxOutstandingPerChan: c.GetInt("batch.max_outstanding", maxOutstandingBatchesDefault),
|
||||||
|
}
|
||||||
|
|
||||||
ifConfig := &InterfaceConfig{
|
ifConfig := &InterfaceConfig{
|
||||||
HostMap: hostMap,
|
HostMap: hostMap,
|
||||||
Inside: tun,
|
Inside: tun,
|
||||||
@@ -242,6 +249,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
||||||
punchy: punchy,
|
punchy: punchy,
|
||||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||||
|
BatchConfig: batchCfg,
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -284,14 +292,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Control{
|
return &Control{
|
||||||
ifce,
|
f: ifce,
|
||||||
l,
|
l: l,
|
||||||
ctx,
|
ctx: ctx,
|
||||||
cancel,
|
cancel: cancel,
|
||||||
sshStart,
|
sshStart: sshStart,
|
||||||
statsStart,
|
statsStart: statsStart,
|
||||||
dnsStart,
|
dnsStart: dnsStart,
|
||||||
lightHouse.StartUpdateWorker,
|
lighthouseStart: lightHouse.StartUpdateWorker,
|
||||||
connManager.Start,
|
connectionManagerStart: connManager.Start,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
//f.l.Error("in packet ", h)
|
||||||
if ip.IsValid() {
|
if ip.IsValid() {
|
||||||
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
@@ -245,6 +245,7 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//TODO: Seems we have a bunch of stuff racing here, since we don't have a lock on hostinfo anymore we announce roaming in bursts
|
||||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
|
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
|
||||||
Info("Host roamed to new udp ip/port.")
|
Info("Host roamed to new udp ip/port.")
|
||||||
hostinfo.lastRoam = time.Now()
|
hostinfo.lastRoam = time.Now()
|
||||||
@@ -470,7 +471,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
hostinfo.logger(f.l).WithError(err).WithField("fwPacket", fwPacket).Error("Failed to decrypt packet")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
@@ -82,41 +81,3 @@ func prefixToMask(prefix netip.Prefix) netip.Addr {
|
|||||||
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
||||||
return addr
|
return addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func flipBytes(b []byte) []byte {
|
|
||||||
for i := 0; i < len(b); i++ {
|
|
||||||
b[i] ^= 0xFF
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
func orBytes(a []byte, b []byte) []byte {
|
|
||||||
ret := make([]byte, len(a))
|
|
||||||
for i := 0; i < len(a); i++ {
|
|
||||||
ret[i] = a[i] | b[i]
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func getBroadcast(cidr netip.Prefix) netip.Addr {
|
|
||||||
broadcast, _ := netip.AddrFromSlice(
|
|
||||||
orBytes(
|
|
||||||
cidr.Addr().AsSlice(),
|
|
||||||
flipBytes(prefixToMask(cidr).AsSlice()),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return broadcast
|
|
||||||
}
|
|
||||||
|
|
||||||
func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
|
|
||||||
for _, gateway := range gateways {
|
|
||||||
if dest.Addr().Is4() && gateway.Addr().Is4() {
|
|
||||||
return gateway, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if dest.Addr().Is6() && gateway.Addr().Is6() {
|
|
||||||
return gateway, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -294,6 +294,7 @@ func (t *tun) activate6(network netip.Prefix) error {
|
|||||||
Vltime: 0xffffffff,
|
Vltime: 0xffffffff,
|
||||||
Pltime: 0xffffffff,
|
Pltime: 0xffffffff,
|
||||||
},
|
},
|
||||||
|
//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
|
||||||
Flags: _IN6_IFF_NODAD,
|
Flags: _IN6_IFF_NODAD,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -501,6 +501,30 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func flipBytes(b []byte) []byte {
|
||||||
|
for i := 0; i < len(b); i++ {
|
||||||
|
b[i] ^= 0xFF
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
func orBytes(a []byte, b []byte) []byte {
|
||||||
|
ret := make([]byte, len(a))
|
||||||
|
for i := 0; i < len(a); i++ {
|
||||||
|
ret[i] = a[i] | b[i]
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBroadcast(cidr netip.Prefix) netip.Addr {
|
||||||
|
broadcast, _ := netip.AddrFromSlice(
|
||||||
|
orBytes(
|
||||||
|
cidr.Addr().AsSlice(),
|
||||||
|
flipBytes(prefixToMask(cidr).AsSlice()),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return broadcast
|
||||||
|
}
|
||||||
|
|
||||||
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -4,12 +4,13 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
@@ -19,42 +20,11 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
netroute "golang.org/x/net/route"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
type ifreqDestroy struct {
|
||||||
SIOCAIFADDR_IN6 = 0x8080696b
|
Name [16]byte
|
||||||
TUNSIFHEAD = 0x80047442
|
pad [16]byte
|
||||||
TUNSIFMODE = 0x80047458
|
|
||||||
)
|
|
||||||
|
|
||||||
type ifreqAlias4 struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
Addr unix.RawSockaddrInet4
|
|
||||||
DstAddr unix.RawSockaddrInet4
|
|
||||||
MaskAddr unix.RawSockaddrInet4
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreqAlias6 struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
Addr unix.RawSockaddrInet6
|
|
||||||
DstAddr unix.RawSockaddrInet6
|
|
||||||
PrefixMask unix.RawSockaddrInet6
|
|
||||||
Flags uint32
|
|
||||||
Lifetime addrLifetime
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreq struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
data int
|
|
||||||
}
|
|
||||||
|
|
||||||
type addrLifetime struct {
|
|
||||||
Expire uint64
|
|
||||||
Preferred uint64
|
|
||||||
Vltime uint32
|
|
||||||
Pltime uint32
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
@@ -64,18 +34,40 @@ type tun struct {
|
|||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
f *os.File
|
|
||||||
fd int
|
io.ReadWriteCloser
|
||||||
}
|
}
|
||||||
|
|
||||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
func (t *tun) Close() error {
|
||||||
|
if t.ReadWriteCloser != nil {
|
||||||
|
if err := t.ReadWriteCloser.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
|
||||||
|
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
||||||
|
|
||||||
|
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
// Try to open tun device
|
// Try to open tun device
|
||||||
|
var file *os.File
|
||||||
var err error
|
var err error
|
||||||
deviceName := c.GetString("tun.dev", "")
|
deviceName := c.GetString("tun.dev", "")
|
||||||
if deviceName == "" {
|
if deviceName == "" {
|
||||||
@@ -85,23 +77,17 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
|
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = unix.SetNonblock(fd, true)
|
|
||||||
if err != nil {
|
|
||||||
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
|
||||||
}
|
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
f: os.NewFile(uintptr(fd), ""),
|
ReadWriteCloser: file,
|
||||||
fd: fd,
|
Device: deviceName,
|
||||||
Device: deviceName,
|
vpnNetworks: vpnNetworks,
|
||||||
vpnNetworks: vpnNetworks,
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
l: l,
|
||||||
l: l,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
@@ -119,225 +105,40 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Close() error {
|
|
||||||
if t.f != nil {
|
|
||||||
if err := t.f.Close(); err != nil {
|
|
||||||
return fmt.Errorf("error closing tun file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// t.f.Close should have handled it for us but let's be extra sure
|
|
||||||
_ = unix.Close(t.fd)
|
|
||||||
|
|
||||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer syscall.Close(s)
|
|
||||||
|
|
||||||
ifr := ifreq{Name: t.deviceBytes()}
|
|
||||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifr)))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Read(to []byte) (int, error) {
|
|
||||||
rc, err := t.f.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var errno syscall.Errno
|
|
||||||
var n uintptr
|
|
||||||
err = rc.Read(func(fd uintptr) bool {
|
|
||||||
// first 4 bytes is protocol family, in network byte order
|
|
||||||
head := [4]byte{}
|
|
||||||
iovecs := []syscall.Iovec{
|
|
||||||
{&head[0], 4},
|
|
||||||
{&to[0], uint64(len(to))},
|
|
||||||
}
|
|
||||||
|
|
||||||
n, _, errno = syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
|
||||||
if errno.Temporary() {
|
|
||||||
// We got an EAGAIN, EINTR, or EWOULDBLOCK, go again
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
if err == syscall.EBADF || err.Error() == "use of closed file" {
|
|
||||||
// Go doesn't export poll.ErrFileClosing but happily reports it to us so here we are
|
|
||||||
// https://github.com/golang/go/blob/master/src/internal/poll/fd_poll_runtime.go#L121
|
|
||||||
return 0, os.ErrClosed
|
|
||||||
}
|
|
||||||
return 0, fmt.Errorf("failed to make read call for tun: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if errno != 0 {
|
|
||||||
return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno)
|
|
||||||
}
|
|
||||||
|
|
||||||
// fix bytes read number to exclude header
|
|
||||||
bytesRead := int(n)
|
|
||||||
if bytesRead < 0 {
|
|
||||||
return bytesRead, nil
|
|
||||||
} else if bytesRead < 4 {
|
|
||||||
return 0, nil
|
|
||||||
} else {
|
|
||||||
return bytesRead - 4, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write is only valid for single threaded use
|
|
||||||
func (t *tun) Write(from []byte) (int, error) {
|
|
||||||
if len(from) <= 1 {
|
|
||||||
return 0, syscall.EIO
|
|
||||||
}
|
|
||||||
|
|
||||||
ipVer := from[0] >> 4
|
|
||||||
var head [4]byte
|
|
||||||
// first 4 bytes is protocol family, in network byte order
|
|
||||||
if ipVer == 4 {
|
|
||||||
head[3] = syscall.AF_INET
|
|
||||||
} else if ipVer == 6 {
|
|
||||||
head[3] = syscall.AF_INET6
|
|
||||||
} else {
|
|
||||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
|
||||||
}
|
|
||||||
|
|
||||||
rc, err := t.f.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var errno syscall.Errno
|
|
||||||
var n uintptr
|
|
||||||
err = rc.Write(func(fd uintptr) bool {
|
|
||||||
iovecs := []syscall.Iovec{
|
|
||||||
{&head[0], 4},
|
|
||||||
{&from[0], uint64(len(from))},
|
|
||||||
}
|
|
||||||
|
|
||||||
n, _, errno = syscall.Syscall(syscall.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
|
||||||
// According to NetBSD documentation for TUN, writes will only return errors in which
|
|
||||||
// this packet will never be delivered so just go on living life.
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if errno != 0 {
|
|
||||||
return 0, errno
|
|
||||||
}
|
|
||||||
|
|
||||||
return int(n) - 4, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||||
if cidr.Addr().Is4() {
|
var err error
|
||||||
var req ifreqAlias4
|
|
||||||
req.Name = t.deviceBytes()
|
|
||||||
req.Addr = unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: cidr.Addr().As4(),
|
|
||||||
}
|
|
||||||
req.DstAddr = unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: cidr.Addr().As4(),
|
|
||||||
}
|
|
||||||
req.MaskAddr = unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: prefixToMask(cidr).As4(),
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
// TODO use syscalls instead of exec.Command
|
||||||
if err != nil {
|
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||||
return err
|
t.l.Debug("command: ", cmd.String())
|
||||||
}
|
if err = cmd.Run(); err != nil {
|
||||||
defer syscall.Close(s)
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
|
|
||||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if cidr.Addr().Is6() {
|
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
|
||||||
var req ifreqAlias6
|
t.l.Debug("command: ", cmd.String())
|
||||||
req.Name = t.deviceBytes()
|
if err = cmd.Run(); err != nil {
|
||||||
req.Addr = unix.RawSockaddrInet6{
|
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||||
Len: unix.SizeofSockaddrInet6,
|
|
||||||
Family: unix.AF_INET6,
|
|
||||||
Addr: cidr.Addr().As16(),
|
|
||||||
}
|
|
||||||
req.PrefixMask = unix.RawSockaddrInet6{
|
|
||||||
Len: unix.SizeofSockaddrInet6,
|
|
||||||
Family: unix.AF_INET6,
|
|
||||||
Addr: prefixToMask(cidr).As16(),
|
|
||||||
}
|
|
||||||
req.Lifetime = addrLifetime{
|
|
||||||
Vltime: 0xffffffff,
|
|
||||||
Pltime: 0xffffffff,
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer syscall.Close(s)
|
|
||||||
|
|
||||||
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("unknown address type %v", cidr)
|
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
||||||
}
|
t.l.Debug("command: ", cmd.String())
|
||||||
|
if err = cmd.Run(); err != nil {
|
||||||
func (t *tun) Activate() error {
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
mode := int32(unix.IFF_BROADCAST)
|
|
||||||
err := ioctl(uintptr(t.fd), TUNSIFMODE, uintptr(unsafe.Pointer(&mode)))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun device mode: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
v := 1
|
|
||||||
err = ioctl(uintptr(t.fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&v)))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun device head: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun mtu: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range t.vpnNetworks {
|
|
||||||
err = t.addIp(t.vpnNetworks[i])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unsafe path routes
|
||||||
return t.addRoutes(false)
|
return t.addRoutes(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
|
func (t *tun) Activate() error {
|
||||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
for i := range t.vpnNetworks {
|
||||||
if err != nil {
|
err := t.addIp(t.vpnNetworks[i])
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
defer syscall.Close(s)
|
return nil
|
||||||
|
|
||||||
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
|
|
||||||
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
@@ -396,23 +197,21 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if len(r.Via) == 0 || !r.Install {
|
if len(r.Via) == 0 || !r.Install {
|
||||||
// We don't allow route MTUs so only install routes with a via
|
// We don't allow route MTUs so only install routes with a via
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := addRoute(r.Cidr, t.vpnNetworks)
|
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||||
if err != nil {
|
t.l.Debug("command: ", cmd.String())
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
if err := cmd.Run(); err != nil {
|
||||||
|
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
} else {
|
} else {
|
||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
t.l.WithField("route", r).Info("Added route")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -425,8 +224,10 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := delRoute(r.Cidr, t.vpnNetworks)
|
//TODO: CERT-V2 is this right?
|
||||||
if err != nil {
|
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||||
|
t.l.Debug("command: ", cmd.String())
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.WithField("route", r).Info("Removed route")
|
||||||
@@ -441,109 +242,3 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
|
||||||
}
|
|
||||||
defer unix.Close(sock)
|
|
||||||
|
|
||||||
route := &netroute.RouteMessage{
|
|
||||||
Version: unix.RTM_VERSION,
|
|
||||||
Type: unix.RTM_ADD,
|
|
||||||
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
|
|
||||||
Seq: 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
if prefix.Addr().Is4() {
|
|
||||||
gw, err := selectGateway(prefix, gateways)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
|
||||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
gw, err := selectGateway(prefix, gateways)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
|
||||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := route.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = unix.Write(sock, data[:])
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, unix.EEXIST) {
|
|
||||||
// Try to do a change
|
|
||||||
route.Type = unix.RTM_CHANGE
|
|
||||||
data, err = route.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
|
||||||
}
|
|
||||||
_, err = unix.Write(sock, data[:])
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
|
||||||
}
|
|
||||||
defer unix.Close(sock)
|
|
||||||
|
|
||||||
route := netroute.RouteMessage{
|
|
||||||
Version: unix.RTM_VERSION,
|
|
||||||
Type: unix.RTM_DELETE,
|
|
||||||
Seq: 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
if prefix.Addr().Is4() {
|
|
||||||
gw, err := selectGateway(prefix, gateways)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
|
||||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
gw, err := selectGateway(prefix, gateways)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
|
||||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := route.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
|
||||||
}
|
|
||||||
_, err = unix.Write(sock, data[:])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -4,50 +4,23 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
netroute "golang.org/x/net/route"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
SIOCAIFADDR_IN6 = 0x8080691a
|
|
||||||
)
|
|
||||||
|
|
||||||
type ifreqAlias4 struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
Addr unix.RawSockaddrInet4
|
|
||||||
DstAddr unix.RawSockaddrInet4
|
|
||||||
MaskAddr unix.RawSockaddrInet4
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreqAlias6 struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
Addr unix.RawSockaddrInet6
|
|
||||||
DstAddr unix.RawSockaddrInet6
|
|
||||||
PrefixMask unix.RawSockaddrInet6
|
|
||||||
Flags uint32
|
|
||||||
Lifetime [2]uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreq struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
data int
|
|
||||||
}
|
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
Device string
|
Device string
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
@@ -55,46 +28,48 @@ type tun struct {
|
|||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
f *os.File
|
|
||||||
fd int
|
io.ReadWriteCloser
|
||||||
|
|
||||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||||
out []byte
|
out []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
func (t *tun) Close() error {
|
||||||
|
if t.ReadWriteCloser != nil {
|
||||||
|
return t.ReadWriteCloser.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
return nil
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
|
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
|
||||||
|
}
|
||||||
|
|
||||||
|
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
// Try to open tun device
|
|
||||||
var err error
|
|
||||||
deviceName := c.GetString("tun.dev", "")
|
deviceName := c.GetString("tun.dev", "")
|
||||||
if deviceName == "" {
|
if deviceName == "" {
|
||||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
|
||||||
}
|
|
||||||
if !deviceNameRE.MatchString(deviceName) {
|
|
||||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
|
if !deviceNameRE.MatchString(deviceName) {
|
||||||
|
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = unix.SetNonblock(fd, true)
|
|
||||||
if err != nil {
|
|
||||||
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
|
||||||
}
|
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
f: os.NewFile(uintptr(fd), ""),
|
ReadWriteCloser: file,
|
||||||
fd: fd,
|
Device: deviceName,
|
||||||
Device: deviceName,
|
vpnNetworks: vpnNetworks,
|
||||||
vpnNetworks: vpnNetworks,
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
l: l,
|
||||||
l: l,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
@@ -112,154 +87,6 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Close() error {
|
|
||||||
if t.f != nil {
|
|
||||||
if err := t.f.Close(); err != nil {
|
|
||||||
return fmt.Errorf("error closing tun file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// t.f.Close should have handled it for us but let's be extra sure
|
|
||||||
_ = unix.Close(t.fd)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Read(to []byte) (int, error) {
|
|
||||||
buf := make([]byte, len(to)+4)
|
|
||||||
|
|
||||||
n, err := t.f.Read(buf)
|
|
||||||
|
|
||||||
copy(to, buf[4:])
|
|
||||||
return n - 4, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write is only valid for single threaded use
|
|
||||||
func (t *tun) Write(from []byte) (int, error) {
|
|
||||||
buf := t.out
|
|
||||||
if cap(buf) < len(from)+4 {
|
|
||||||
buf = make([]byte, len(from)+4)
|
|
||||||
t.out = buf
|
|
||||||
}
|
|
||||||
buf = buf[:len(from)+4]
|
|
||||||
|
|
||||||
if len(from) == 0 {
|
|
||||||
return 0, syscall.EIO
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine the IP Family for the NULL L2 Header
|
|
||||||
ipVer := from[0] >> 4
|
|
||||||
if ipVer == 4 {
|
|
||||||
buf[3] = syscall.AF_INET
|
|
||||||
} else if ipVer == 6 {
|
|
||||||
buf[3] = syscall.AF_INET6
|
|
||||||
} else {
|
|
||||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
|
||||||
}
|
|
||||||
|
|
||||||
copy(buf[4:], from)
|
|
||||||
|
|
||||||
n, err := t.f.Write(buf)
|
|
||||||
return n - 4, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
|
||||||
if cidr.Addr().Is4() {
|
|
||||||
var req ifreqAlias4
|
|
||||||
req.Name = t.deviceBytes()
|
|
||||||
req.Addr = unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: cidr.Addr().As4(),
|
|
||||||
}
|
|
||||||
req.DstAddr = unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: cidr.Addr().As4(),
|
|
||||||
}
|
|
||||||
req.MaskAddr = unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: prefixToMask(cidr).As4(),
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer syscall.Close(s)
|
|
||||||
|
|
||||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = addRoute(cidr, t.vpnNetworks)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to set route for vpn network %v: %w", cidr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if cidr.Addr().Is6() {
|
|
||||||
var req ifreqAlias6
|
|
||||||
req.Name = t.deviceBytes()
|
|
||||||
req.Addr = unix.RawSockaddrInet6{
|
|
||||||
Len: unix.SizeofSockaddrInet6,
|
|
||||||
Family: unix.AF_INET6,
|
|
||||||
Addr: cidr.Addr().As16(),
|
|
||||||
}
|
|
||||||
req.PrefixMask = unix.RawSockaddrInet6{
|
|
||||||
Len: unix.SizeofSockaddrInet6,
|
|
||||||
Family: unix.AF_INET6,
|
|
||||||
Addr: prefixToMask(cidr).As16(),
|
|
||||||
}
|
|
||||||
req.Lifetime[0] = 0xffffffff
|
|
||||||
req.Lifetime[1] = 0xffffffff
|
|
||||||
|
|
||||||
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer syscall.Close(s)
|
|
||||||
|
|
||||||
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("unknown address type %v", cidr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Activate() error {
|
|
||||||
err := t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun mtu: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range t.vpnNetworks {
|
|
||||||
err = t.addIp(t.vpnNetworks[i])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return t.addRoutes(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
|
|
||||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer syscall.Close(s)
|
|
||||||
|
|
||||||
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
|
|
||||||
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -297,42 +124,63 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||||
|
var err error
|
||||||
|
// TODO use syscalls instead of exec.Command
|
||||||
|
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||||
|
t.l.Debug("command: ", cmd.String())
|
||||||
|
if err = cmd.Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
||||||
|
t.l.Debug("command: ", cmd.String())
|
||||||
|
if err = cmd.Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
|
||||||
|
t.l.Debug("command: ", cmd.String())
|
||||||
|
if err = cmd.Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsafe path routes
|
||||||
|
return t.addRoutes(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Activate() error {
|
||||||
|
for i := range t.vpnNetworks {
|
||||||
|
err := t.addIp(t.vpnNetworks[i])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Networks() []netip.Prefix {
|
|
||||||
return t.vpnNetworks
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Name() string {
|
|
||||||
return t.Device
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if len(r.Via) == 0 || !r.Install {
|
if len(r.Via) == 0 || !r.Install {
|
||||||
// We don't allow route MTUs so only install routes with a via
|
// We don't allow route MTUs so only install routes with a via
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
//TODO: CERT-V2 is this right?
|
||||||
err := addRoute(r.Cidr, t.vpnNetworks)
|
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||||
if err != nil {
|
t.l.Debug("command: ", cmd.String())
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
if err := cmd.Run(); err != nil {
|
||||||
|
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
} else {
|
} else {
|
||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
t.l.WithField("route", r).Info("Added route")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -344,9 +192,10 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
if !r.Install {
|
if !r.Install {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
//TODO: CERT-V2 is this right?
|
||||||
err := delRoute(r.Cidr, t.vpnNetworks)
|
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||||
if err != nil {
|
t.l.Debug("command: ", cmd.String())
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.WithField("route", r).Info("Removed route")
|
||||||
@@ -355,115 +204,52 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) deviceBytes() (o [16]byte) {
|
func (t *tun) Networks() []netip.Prefix {
|
||||||
for i, c := range t.Device {
|
return t.vpnNetworks
|
||||||
o[i] = byte(c)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
func (t *tun) Name() string {
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
return t.Device
|
||||||
if err != nil {
|
}
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
|
||||||
}
|
|
||||||
defer unix.Close(sock)
|
|
||||||
|
|
||||||
route := &netroute.RouteMessage{
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
Version: unix.RTM_VERSION,
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||||
Type: unix.RTM_ADD,
|
}
|
||||||
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
|
|
||||||
Seq: 1,
|
func (t *tun) Read(to []byte) (int, error) {
|
||||||
|
buf := make([]byte, len(to)+4)
|
||||||
|
|
||||||
|
n, err := t.ReadWriteCloser.Read(buf)
|
||||||
|
|
||||||
|
copy(to, buf[4:])
|
||||||
|
return n - 4, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write is only valid for single threaded use
|
||||||
|
func (t *tun) Write(from []byte) (int, error) {
|
||||||
|
buf := t.out
|
||||||
|
if cap(buf) < len(from)+4 {
|
||||||
|
buf = make([]byte, len(from)+4)
|
||||||
|
t.out = buf
|
||||||
|
}
|
||||||
|
buf = buf[:len(from)+4]
|
||||||
|
|
||||||
|
if len(from) == 0 {
|
||||||
|
return 0, syscall.EIO
|
||||||
}
|
}
|
||||||
|
|
||||||
if prefix.Addr().Is4() {
|
// Determine the IP Family for the NULL L2 Header
|
||||||
gw, err := selectGateway(prefix, gateways)
|
ipVer := from[0] >> 4
|
||||||
if err != nil {
|
if ipVer == 4 {
|
||||||
return err
|
buf[3] = syscall.AF_INET
|
||||||
}
|
} else if ipVer == 6 {
|
||||||
route.Addrs = []netroute.Addr{
|
buf[3] = syscall.AF_INET6
|
||||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
|
||||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
gw, err := selectGateway(prefix, gateways)
|
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
|
||||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := route.Marshal()
|
copy(buf[4:], from)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = unix.Write(sock, data[:])
|
n, err := t.ReadWriteCloser.Write(buf)
|
||||||
if err != nil {
|
return n - 4, err
|
||||||
if errors.Is(err, unix.EEXIST) {
|
|
||||||
// Try to do a change
|
|
||||||
route.Type = unix.RTM_CHANGE
|
|
||||||
data, err = route.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
|
||||||
}
|
|
||||||
_, err = unix.Write(sock, data[:])
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
|
||||||
}
|
|
||||||
defer unix.Close(sock)
|
|
||||||
|
|
||||||
route := netroute.RouteMessage{
|
|
||||||
Version: unix.RTM_VERSION,
|
|
||||||
Type: unix.RTM_DELETE,
|
|
||||||
Seq: 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
if prefix.Addr().Is4() {
|
|
||||||
gw, err := selectGateway(prefix, gateways)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
|
||||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
gw, err := selectGateway(prefix, gateways)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
|
||||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := route.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
|
||||||
}
|
|
||||||
_, err = unix.Write(sock, data[:])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
12
packet/packet.go
Normal file
12
packet/packet.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package packet
|
||||||
|
|
||||||
|
import "net/netip"
|
||||||
|
|
||||||
|
type Packet struct {
|
||||||
|
Payload []byte
|
||||||
|
Addr netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
func New() *Packet {
|
||||||
|
return &Packet{Payload: make([]byte, 9001)}
|
||||||
|
}
|
||||||
@@ -155,8 +155,6 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
|||||||
"vpnAddrs": h.vpnAddrs}).
|
"vpnAddrs": h.vpnAddrs}).
|
||||||
Info("handleCreateRelayResponse")
|
Info("handleCreateRelayResponse")
|
||||||
|
|
||||||
//peer == relayFrom
|
|
||||||
//target == relayTo
|
|
||||||
target := m.RelayToAddr
|
target := m.RelayToAddr
|
||||||
targetAddr := protoAddrToNetAddr(target)
|
targetAddr := protoAddrToNetAddr(target)
|
||||||
|
|
||||||
@@ -192,12 +190,11 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
|||||||
InitiatorRelayIndex: peerRelay.RemoteIndex,
|
InitiatorRelayIndex: peerRelay.RemoteIndex,
|
||||||
}
|
}
|
||||||
|
|
||||||
relayFrom := h.vpnAddrs[0]
|
|
||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
peer := peerHostInfo.vpnAddrs[0]
|
peer := peerHostInfo.vpnAddrs[0]
|
||||||
if !peer.Is4() {
|
if !peer.Is4() {
|
||||||
rm.l.WithField("relayFrom", peer).
|
rm.l.WithField("relayFrom", peer).
|
||||||
WithField("relayTo", targetAddr).
|
WithField("relayTo", target).
|
||||||
WithField("initiatorRelayIndex", resp.InitiatorRelayIndex).
|
WithField("initiatorRelayIndex", resp.InitiatorRelayIndex).
|
||||||
WithField("responderRelayIndex", resp.ResponderRelayIndex).
|
WithField("responderRelayIndex", resp.ResponderRelayIndex).
|
||||||
WithField("vpnAddrs", peerHostInfo.vpnAddrs).
|
WithField("vpnAddrs", peerHostInfo.vpnAddrs).
|
||||||
@@ -210,22 +207,7 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
|||||||
b = targetAddr.As4()
|
b = targetAddr.As4()
|
||||||
resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||||
} else {
|
} else {
|
||||||
ok = false
|
resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0])
|
||||||
peerNetworks := h.GetCert().Certificate.Networks()
|
|
||||||
for i := range peerNetworks {
|
|
||||||
if peerNetworks[i].Contains(targetAddr) {
|
|
||||||
relayFrom = peerNetworks[i].Addr()
|
|
||||||
ok = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
rm.l.WithFields(logrus.Fields{"from": f.myVpnNetworks, "to": targetAddr}).
|
|
||||||
Error("cannot establish relay, no networks in common")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.RelayFromAddr = netAddrToProtoAddr(relayFrom)
|
|
||||||
resp.RelayToAddr = target
|
resp.RelayToAddr = target
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -236,8 +218,8 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
|||||||
} else {
|
} else {
|
||||||
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
rm.l.WithFields(logrus.Fields{
|
rm.l.WithFields(logrus.Fields{
|
||||||
"relayFrom": relayFrom,
|
"relayFrom": resp.RelayFromAddr,
|
||||||
"relayTo": targetAddr,
|
"relayTo": resp.RelayToAddr,
|
||||||
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": resp.ResponderRelayIndex,
|
"responderRelayIndex": resp.ResponderRelayIndex,
|
||||||
"vpnAddrs": peerHostInfo.vpnAddrs}).
|
"vpnAddrs": peerHostInfo.vpnAddrs}).
|
||||||
@@ -331,7 +313,8 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
|
|
||||||
msg, err := resp.Marshal()
|
msg, err := resp.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logMsg.WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
|
logMsg.
|
||||||
|
WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
|
||||||
} else {
|
} else {
|
||||||
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
|
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
rm.l.WithFields(logrus.Fields{
|
rm.l.WithFields(logrus.Fields{
|
||||||
@@ -377,10 +360,10 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
Type: NebulaControl_CreateRelayRequest,
|
Type: NebulaControl_CreateRelayRequest,
|
||||||
InitiatorRelayIndex: index,
|
InitiatorRelayIndex: index,
|
||||||
}
|
}
|
||||||
relayFrom := h.vpnAddrs[0]
|
|
||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
if !relayFrom.Is4() {
|
if !h.vpnAddrs[0].Is4() {
|
||||||
rm.l.WithField("relayFrom", relayFrom).
|
rm.l.WithField("relayFrom", h.vpnAddrs[0]).
|
||||||
WithField("relayTo", target).
|
WithField("relayTo", target).
|
||||||
WithField("initiatorRelayIndex", req.InitiatorRelayIndex).
|
WithField("initiatorRelayIndex", req.InitiatorRelayIndex).
|
||||||
WithField("responderRelayIndex", req.ResponderRelayIndex).
|
WithField("responderRelayIndex", req.ResponderRelayIndex).
|
||||||
@@ -389,37 +372,23 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
b := relayFrom.As4()
|
b := h.vpnAddrs[0].As4()
|
||||||
req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
||||||
b = target.As4()
|
b = target.As4()
|
||||||
req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||||
} else {
|
} else {
|
||||||
ok = false
|
req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0])
|
||||||
peerNetworks := h.GetCert().Certificate.Networks()
|
|
||||||
for i := range peerNetworks {
|
|
||||||
if peerNetworks[i].Contains(target) {
|
|
||||||
relayFrom = peerNetworks[i].Addr()
|
|
||||||
ok = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
rm.l.WithFields(logrus.Fields{"from": f.myVpnNetworks, "to": target}).
|
|
||||||
Error("cannot establish relay, no networks in common")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
|
|
||||||
req.RelayToAddr = netAddrToProtoAddr(target)
|
req.RelayToAddr = netAddrToProtoAddr(target)
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, err := req.Marshal()
|
msg, err := req.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logMsg.WithError(err).Error("relayManager Failed to marshal Control message to create relay")
|
logMsg.
|
||||||
|
WithError(err).Error("relayManager Failed to marshal Control message to create relay")
|
||||||
} else {
|
} else {
|
||||||
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
|
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
rm.l.WithFields(logrus.Fields{
|
rm.l.WithFields(logrus.Fields{
|
||||||
"relayFrom": relayFrom,
|
"relayFrom": h.vpnAddrs[0],
|
||||||
"relayTo": target,
|
"relayTo": target,
|
||||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": req.ResponderRelayIndex,
|
"responderRelayIndex": req.ResponderRelayIndex,
|
||||||
@@ -432,7 +401,8 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
if !ok {
|
if !ok {
|
||||||
_, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested)
|
_, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logMsg.WithError(err).Error("relayManager Failed to allocate a local index for relay")
|
logMsg.
|
||||||
|
WithError(err).Error("relayManager Failed to allocate a local index for relay")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,7 +44,10 @@ type Service struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func New(control *nebula.Control) (*Service, error) {
|
func New(control *nebula.Control) (*Service, error) {
|
||||||
control.Start()
|
wait, err := control.Start()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
ctx := control.Context()
|
ctx := control.Context()
|
||||||
eg, ctx := errgroup.WithContext(ctx)
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
@@ -141,6 +144,12 @@ func New(control *nebula.Control) (*Service, error) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Add the nebula wait function to the group
|
||||||
|
eg.Go(func() error {
|
||||||
|
wait()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
return &s, nil
|
return &s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ type EncReader func(
|
|||||||
type Conn interface {
|
type Conn interface {
|
||||||
Rebind() error
|
Rebind() error
|
||||||
LocalAddr() (netip.AddrPort, error)
|
LocalAddr() (netip.AddrPort, error)
|
||||||
ListenOut(r EncReader)
|
ListenOut(r EncReader) error
|
||||||
WriteTo(b []byte, addr netip.AddrPort) error
|
WriteTo(b []byte, addr netip.AddrPort) error
|
||||||
ReloadConfig(c *config.C)
|
ReloadConfig(c *config.C)
|
||||||
Close() error
|
Close() error
|
||||||
@@ -30,8 +30,8 @@ func (NoopConn) Rebind() error {
|
|||||||
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
return netip.AddrPort{}, nil
|
return netip.AddrPort{}, nil
|
||||||
}
|
}
|
||||||
func (NoopConn) ListenOut(_ EncReader) {
|
func (NoopConn) ListenOut(_ EncReader) error {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
|
|||||||
return func() {}
|
return func() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader) {
|
func (u *StdConn) ListenOut(r EncReader) error {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -174,14 +174,17 @@ func (u *StdConn) ListenOut(r EncReader) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, net.ErrClosed) {
|
if errors.Is(err, net.ErrClosed) {
|
||||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
u.l.WithError(err).Error("unexpected udp socket receive error")
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) Rebind() error {
|
func (u *StdConn) Rebind() error {
|
||||||
|
|||||||
@@ -71,15 +71,14 @@ type rawMessage struct {
|
|||||||
Len uint32
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *GenericConn) ListenOut(r EncReader) {
|
func (u *GenericConn) ListenOut(r EncReader) error {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Just read one packet at a time
|
// Just read one packet at a time
|
||||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||||
|
|||||||
499
udp/udp_linux.go
499
udp/udp_linux.go
@@ -5,10 +5,13 @@ package udp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
@@ -17,19 +20,40 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultGSOMaxSegments = 8
|
||||||
|
defaultGSOFlushTimeout = 150 * time.Microsecond
|
||||||
|
defaultGROReadBufferSize = MTU * defaultGSOMaxSegments
|
||||||
|
maxGSOBatchBytes = 0xFFFF
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errGSOFallback = errors.New("udp gso fallback")
|
||||||
|
errGSODisabled = errors.New("udp gso disabled")
|
||||||
|
)
|
||||||
|
|
||||||
type StdConn struct {
|
type StdConn struct {
|
||||||
sysFd int
|
sysFd int
|
||||||
isV4 bool
|
isV4 bool
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
batch int
|
batch int
|
||||||
}
|
|
||||||
|
|
||||||
func maybeIPV4(ip net.IP) (net.IP, bool) {
|
enableGRO bool
|
||||||
ip4 := ip.To4()
|
enableGSO bool
|
||||||
if ip4 != nil {
|
|
||||||
return ip4, true
|
gsoMu sync.Mutex
|
||||||
}
|
gsoBuf []byte
|
||||||
return ip, false
|
gsoAddr netip.AddrPort
|
||||||
|
gsoSegSize int
|
||||||
|
gsoSegments int
|
||||||
|
gsoMaxSegments int
|
||||||
|
gsoMaxBytes int
|
||||||
|
gsoFlushTimeout time.Duration
|
||||||
|
gsoTimer *time.Timer
|
||||||
|
|
||||||
|
groBufSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||||
@@ -55,6 +79,11 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set a read timeout
|
||||||
|
if err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &readTimeout); err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to set SO_RCVTIMEO: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
var sa unix.Sockaddr
|
var sa unix.Sockaddr
|
||||||
if ip.Is4() {
|
if ip.Is4() {
|
||||||
sa4 := &unix.SockaddrInet4{Port: port}
|
sa4 := &unix.SockaddrInet4{Port: port}
|
||||||
@@ -69,7 +98,16 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
|||||||
return nil, fmt.Errorf("unable to bind to socket: %s", err)
|
return nil, fmt.Errorf("unable to bind to socket: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
|
return &StdConn{
|
||||||
|
sysFd: fd,
|
||||||
|
isV4: ip.Is4(),
|
||||||
|
l: l,
|
||||||
|
batch: batch,
|
||||||
|
gsoMaxSegments: defaultGSOMaxSegments,
|
||||||
|
gsoMaxBytes: MTU * defaultGSOMaxSegments,
|
||||||
|
gsoFlushTimeout: defaultGSOFlushTimeout,
|
||||||
|
groBufSize: MTU,
|
||||||
|
}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) Rebind() error {
|
func (u *StdConn) Rebind() error {
|
||||||
@@ -118,20 +156,46 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader) {
|
func (u *StdConn) ListenOut(r EncReader) error {
|
||||||
var ip netip.Addr
|
var (
|
||||||
|
ip netip.Addr
|
||||||
|
controls [][]byte
|
||||||
|
)
|
||||||
|
|
||||||
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
bufSize := u.readBufferSize()
|
||||||
|
msgs, buffers, names := u.PrepareRawMessages(u.batch, bufSize)
|
||||||
read := u.ReadMulti
|
read := u.ReadMulti
|
||||||
if u.batch == 1 {
|
if u.batch == 1 {
|
||||||
read = u.ReadSingle
|
read = u.ReadSingle
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
desired := u.readBufferSize()
|
||||||
|
if len(buffers) == 0 || cap(buffers[0]) < desired {
|
||||||
|
msgs, buffers, names = u.PrepareRawMessages(u.batch, desired)
|
||||||
|
controls = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.enableGRO {
|
||||||
|
if controls == nil {
|
||||||
|
controls = make([][]byte, len(msgs))
|
||||||
|
for i := range controls {
|
||||||
|
controls[i] = make([]byte, unix.CmsgSpace(4))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := range msgs {
|
||||||
|
setRawMessageControl(&msgs[i], controls[i])
|
||||||
|
}
|
||||||
|
} else if controls != nil {
|
||||||
|
for i := range msgs {
|
||||||
|
setRawMessageControl(&msgs[i], nil)
|
||||||
|
}
|
||||||
|
controls = nil
|
||||||
|
}
|
||||||
|
|
||||||
n, err := read(msgs)
|
n, err := read(msgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
@@ -141,11 +205,82 @@ func (u *StdConn) ListenOut(r EncReader) {
|
|||||||
} else {
|
} else {
|
||||||
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
||||||
}
|
}
|
||||||
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
|
addr := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
|
||||||
|
payload := buffers[i][:msgs[i].Len]
|
||||||
|
|
||||||
|
if u.enableGRO && u.l.IsLevelEnabled(logrus.DebugLevel) {
|
||||||
|
ctrlLen := getRawMessageControlLen(&msgs[i])
|
||||||
|
msgFlags := getRawMessageFlags(&msgs[i])
|
||||||
|
u.l.WithFields(logrus.Fields{
|
||||||
|
"tag": "gro-debug",
|
||||||
|
"stage": "recv",
|
||||||
|
"payload_len": len(payload),
|
||||||
|
"ctrl_len": ctrlLen,
|
||||||
|
"msg_flags": msgFlags,
|
||||||
|
}).Debug("gro batch data")
|
||||||
|
if controls != nil && ctrlLen > 0 {
|
||||||
|
maxDump := ctrlLen
|
||||||
|
if maxDump > 16 {
|
||||||
|
maxDump = 16
|
||||||
|
}
|
||||||
|
u.l.WithFields(logrus.Fields{
|
||||||
|
"tag": "gro-debug",
|
||||||
|
"stage": "control-bytes",
|
||||||
|
"control_hex": fmt.Sprintf("%x", controls[i][:maxDump]),
|
||||||
|
"datalen": ctrlLen,
|
||||||
|
}).Debug("gro control dump")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sawControl := false
|
||||||
|
if controls != nil {
|
||||||
|
if ctrlLen := getRawMessageControlLen(&msgs[i]); ctrlLen > 0 {
|
||||||
|
if segSize, segCount := parseGROControl(controls[i][:ctrlLen]); segSize > 0 {
|
||||||
|
sawControl = true
|
||||||
|
if u.l.IsLevelEnabled(logrus.DebugLevel) {
|
||||||
|
u.l.WithFields(logrus.Fields{
|
||||||
|
"tag": "gro-debug",
|
||||||
|
"stage": "control",
|
||||||
|
"seg_size": segSize,
|
||||||
|
"seg_count": segCount,
|
||||||
|
"payloadLen": len(payload),
|
||||||
|
}).Debug("gro control parsed")
|
||||||
|
}
|
||||||
|
segSize = normalizeGROSegSize(segSize, segCount, len(payload))
|
||||||
|
if segSize > 0 && segSize < len(payload) {
|
||||||
|
if u.emitGROSegments(r, addr, payload, segSize) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.enableGRO && len(payload) > MTU {
|
||||||
|
if !sawControl && u.l.IsLevelEnabled(logrus.DebugLevel) {
|
||||||
|
u.l.WithFields(logrus.Fields{
|
||||||
|
"tag": "gro-debug",
|
||||||
|
"stage": "fallback",
|
||||||
|
"payload_len": len(payload),
|
||||||
|
}).Debug("gro control missing; splitting payload by MTU")
|
||||||
|
}
|
||||||
|
if u.emitGROSegments(r, addr, payload, MTU) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r(addr, payload)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) readBufferSize() int {
|
||||||
|
if u.enableGRO && u.groBufSize > MTU {
|
||||||
|
return u.groBufSize
|
||||||
|
}
|
||||||
|
return MTU
|
||||||
|
}
|
||||||
|
|
||||||
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
|
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
|
||||||
for {
|
for {
|
||||||
n, _, err := unix.Syscall6(
|
n, _, err := unix.Syscall6(
|
||||||
@@ -159,6 +294,9 @@ func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if err != 0 {
|
if err != 0 {
|
||||||
|
if err == unix.EAGAIN || err == unix.EINTR {
|
||||||
|
continue
|
||||||
|
}
|
||||||
return 0, &net.OpError{Op: "recvmsg", Err: err}
|
return 0, &net.OpError{Op: "recvmsg", Err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,6 +318,9 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if err != 0 {
|
if err != 0 {
|
||||||
|
if err == unix.EAGAIN || err == unix.EINTR {
|
||||||
|
continue
|
||||||
|
}
|
||||||
return 0, &net.OpError{Op: "recvmmsg", Err: err}
|
return 0, &net.OpError{Op: "recvmmsg", Err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -188,6 +329,14 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
||||||
|
if u.enableGSO && ip.IsValid() {
|
||||||
|
if err := u.queueGSOPacket(b, ip); err == nil {
|
||||||
|
return nil
|
||||||
|
} else if !errors.Is(err, errGSOFallback) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if u.isV4 {
|
if u.isV4 {
|
||||||
return u.writeTo4(b, ip)
|
return u.writeTo4(b, ip)
|
||||||
}
|
}
|
||||||
@@ -221,7 +370,7 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
|||||||
|
|
||||||
func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
|
func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
|
||||||
if !ip.Addr().Is4() {
|
if !ip.Addr().Is4() {
|
||||||
return ErrInvalidIPv6RemoteForSocket
|
return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
|
||||||
}
|
}
|
||||||
|
|
||||||
var rsa unix.RawSockaddrInet4
|
var rsa unix.RawSockaddrInet4
|
||||||
@@ -294,6 +443,94 @@ func (u *StdConn) ReloadConfig(c *config.C) {
|
|||||||
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
u.configureGRO(c)
|
||||||
|
u.configureGSO(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) configureGRO(c *config.C) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
enable := c.GetBool("listen.enable_gro", false)
|
||||||
|
if enable == u.enableGRO {
|
||||||
|
if enable {
|
||||||
|
if size := c.GetInt("listen.gro_read_buffer", 0); size > 0 {
|
||||||
|
u.setGROBufferSize(size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if enable {
|
||||||
|
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
|
||||||
|
u.l.WithError(err).Warn("Failed to enable UDP GRO")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
u.enableGRO = true
|
||||||
|
u.setGROBufferSize(c.GetInt("listen.gro_read_buffer", defaultGROReadBufferSize))
|
||||||
|
u.l.WithField("buffer_size", u.groBufSize).Info("UDP GRO enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT {
|
||||||
|
u.l.WithError(err).Warn("Failed to disable UDP GRO")
|
||||||
|
}
|
||||||
|
u.enableGRO = false
|
||||||
|
u.groBufSize = MTU
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) configureGSO(c *config.C) {
|
||||||
|
enable := c.GetBool("listen.enable_gso", false)
|
||||||
|
if !enable {
|
||||||
|
u.disableGSO()
|
||||||
|
} else {
|
||||||
|
u.enableGSO = true
|
||||||
|
}
|
||||||
|
|
||||||
|
segments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
|
||||||
|
if segments < 1 {
|
||||||
|
segments = 1
|
||||||
|
}
|
||||||
|
u.gsoMaxSegments = segments
|
||||||
|
|
||||||
|
maxBytes := c.GetInt("listen.gso_max_bytes", 0)
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = MTU * segments
|
||||||
|
}
|
||||||
|
if maxBytes > maxGSOBatchBytes {
|
||||||
|
u.l.WithField("requested", maxBytes).Warn("listen.gso_max_bytes larger than UDP limit; clamping")
|
||||||
|
maxBytes = maxGSOBatchBytes
|
||||||
|
}
|
||||||
|
u.gsoMaxBytes = maxBytes
|
||||||
|
|
||||||
|
timeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushTimeout)
|
||||||
|
if timeout < 0 {
|
||||||
|
timeout = 0
|
||||||
|
}
|
||||||
|
u.gsoFlushTimeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) setGROBufferSize(size int) {
|
||||||
|
if size < MTU {
|
||||||
|
size = defaultGROReadBufferSize
|
||||||
|
}
|
||||||
|
if size > maxGSOBatchBytes {
|
||||||
|
size = maxGSOBatchBytes
|
||||||
|
}
|
||||||
|
u.groBufSize = size
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) disableGSO() {
|
||||||
|
u.gsoMu.Lock()
|
||||||
|
defer u.gsoMu.Unlock()
|
||||||
|
u.enableGSO = false
|
||||||
|
_ = u.flushGSOlocked()
|
||||||
|
u.gsoBuf = nil
|
||||||
|
u.gsoSegments = 0
|
||||||
|
u.gsoSegSize = 0
|
||||||
|
u.stopGSOTimerLocked()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
||||||
@@ -305,7 +542,239 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) queueGSOPacket(b []byte, addr netip.AddrPort) error {
|
||||||
|
if len(b) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
u.gsoMu.Lock()
|
||||||
|
defer u.gsoMu.Unlock()
|
||||||
|
|
||||||
|
if !u.enableGSO || !addr.IsValid() || len(b) > u.gsoMaxBytes {
|
||||||
|
if err := u.flushGSOlocked(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return errGSOFallback
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.gsoSegments == 0 {
|
||||||
|
if cap(u.gsoBuf) < u.gsoMaxBytes {
|
||||||
|
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
|
||||||
|
}
|
||||||
|
u.gsoAddr = addr
|
||||||
|
u.gsoSegSize = len(b)
|
||||||
|
} else if addr != u.gsoAddr || len(b) != u.gsoSegSize {
|
||||||
|
if err := u.flushGSOlocked(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if cap(u.gsoBuf) < u.gsoMaxBytes {
|
||||||
|
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
|
||||||
|
}
|
||||||
|
u.gsoAddr = addr
|
||||||
|
u.gsoSegSize = len(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(u.gsoBuf)+len(b) > u.gsoMaxBytes {
|
||||||
|
if err := u.flushGSOlocked(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if cap(u.gsoBuf) < u.gsoMaxBytes {
|
||||||
|
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
|
||||||
|
}
|
||||||
|
u.gsoAddr = addr
|
||||||
|
u.gsoSegSize = len(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
u.gsoBuf = append(u.gsoBuf, b...)
|
||||||
|
u.gsoSegments++
|
||||||
|
|
||||||
|
if u.gsoSegments >= u.gsoMaxSegments || u.gsoFlushTimeout <= 0 {
|
||||||
|
return u.flushGSOlocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
u.scheduleGSOFlushLocked()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) flushGSOlocked() error {
|
||||||
|
if u.gsoSegments == 0 {
|
||||||
|
u.stopGSOTimerLocked()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := append([]byte(nil), u.gsoBuf...)
|
||||||
|
addr := u.gsoAddr
|
||||||
|
segSize := u.gsoSegSize
|
||||||
|
|
||||||
|
u.gsoBuf = u.gsoBuf[:0]
|
||||||
|
u.gsoSegments = 0
|
||||||
|
u.gsoSegSize = 0
|
||||||
|
u.stopGSOTimerLocked()
|
||||||
|
|
||||||
|
if segSize <= 0 {
|
||||||
|
return errGSOFallback
|
||||||
|
}
|
||||||
|
|
||||||
|
err := u.sendSegmented(payload, addr, segSize)
|
||||||
|
if errors.Is(err, errGSODisabled) {
|
||||||
|
u.l.WithField("addr", addr).Warn("UDP GSO disabled by kernel, falling back to sendto")
|
||||||
|
u.enableGSO = false
|
||||||
|
return u.sendSegmentsIndividually(payload, addr, segSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) sendSegmented(payload []byte, addr netip.AddrPort, segSize int) error {
|
||||||
|
if len(payload) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
control := make([]byte, unix.CmsgSpace(2))
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
hdr.Level = unix.SOL_UDP
|
||||||
|
hdr.Type = unix.UDP_SEGMENT
|
||||||
|
setCmsgLen(hdr, unix.CmsgLen(2))
|
||||||
|
binary.NativeEndian.PutUint16(control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(segSize))
|
||||||
|
|
||||||
|
var sa unix.Sockaddr
|
||||||
|
if addr.Addr().Is4() {
|
||||||
|
var sa4 unix.SockaddrInet4
|
||||||
|
sa4.Port = int(addr.Port())
|
||||||
|
sa4.Addr = addr.Addr().As4()
|
||||||
|
sa = &sa4
|
||||||
|
} else {
|
||||||
|
var sa6 unix.SockaddrInet6
|
||||||
|
sa6.Port = int(addr.Port())
|
||||||
|
sa6.Addr = addr.Addr().As16()
|
||||||
|
sa = &sa6
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := unix.SendmsgN(u.sysFd, payload, control, sa, 0); err != nil {
|
||||||
|
if errno, ok := err.(syscall.Errno); ok && (errno == unix.EINVAL || errno == unix.ENOTSUP || errno == unix.EOPNOTSUPP) {
|
||||||
|
return errGSODisabled
|
||||||
|
}
|
||||||
|
return &net.OpError{Op: "sendmsg", Err: err}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) sendSegmentsIndividually(buf []byte, addr netip.AddrPort, segSize int) error {
|
||||||
|
if segSize <= 0 {
|
||||||
|
return errGSOFallback
|
||||||
|
}
|
||||||
|
|
||||||
|
for offset := 0; offset < len(buf); offset += segSize {
|
||||||
|
end := offset + segSize
|
||||||
|
if end > len(buf) {
|
||||||
|
end = len(buf)
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
if u.isV4 {
|
||||||
|
err = u.writeTo4(buf[offset:end], addr)
|
||||||
|
} else {
|
||||||
|
err = u.writeTo6(buf[offset:end], addr)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) scheduleGSOFlushLocked() {
|
||||||
|
if u.gsoTimer == nil {
|
||||||
|
u.gsoTimer = time.AfterFunc(u.gsoFlushTimeout, u.gsoFlushTimer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
u.gsoTimer.Reset(u.gsoFlushTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) stopGSOTimerLocked() {
|
||||||
|
if u.gsoTimer != nil {
|
||||||
|
u.gsoTimer.Stop()
|
||||||
|
u.gsoTimer = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) gsoFlushTimer() {
|
||||||
|
u.gsoMu.Lock()
|
||||||
|
defer u.gsoMu.Unlock()
|
||||||
|
_ = u.flushGSOlocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseGROControl(control []byte) (int, int) {
|
||||||
|
if len(control) == 0 {
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
cmsgs, err := unix.ParseSocketControlMessage(control)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cmsgs {
|
||||||
|
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
|
||||||
|
segSize := int(binary.NativeEndian.Uint16(c.Data[:2]))
|
||||||
|
segCount := 0
|
||||||
|
if len(c.Data) >= 4 {
|
||||||
|
segCount = int(binary.NativeEndian.Uint16(c.Data[2:4]))
|
||||||
|
}
|
||||||
|
return segSize, segCount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) emitGROSegments(r EncReader, addr netip.AddrPort, payload []byte, segSize int) bool {
|
||||||
|
if segSize <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for offset := 0; offset < len(payload); offset += segSize {
|
||||||
|
end := offset + segSize
|
||||||
|
if end > len(payload) {
|
||||||
|
end = len(payload)
|
||||||
|
}
|
||||||
|
segment := make([]byte, end-offset)
|
||||||
|
copy(segment, payload[offset:end])
|
||||||
|
r(addr, segment)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeGROSegSize(segSize, segCount, total int) int {
|
||||||
|
if segSize <= 0 || total <= 0 {
|
||||||
|
return segSize
|
||||||
|
}
|
||||||
|
|
||||||
|
if segSize > total && segCount > 0 {
|
||||||
|
segSize = total / segCount
|
||||||
|
if segSize == 0 {
|
||||||
|
segSize = total
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if segCount <= 1 && segSize > 0 && total > segSize {
|
||||||
|
calculated := total / segSize
|
||||||
|
if calculated <= 1 {
|
||||||
|
calculated = (total + segSize - 1) / segSize
|
||||||
|
}
|
||||||
|
if calculated > 1 {
|
||||||
|
segCount = calculated
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if segSize > MTU {
|
||||||
|
return MTU
|
||||||
|
}
|
||||||
|
|
||||||
|
return segSize
|
||||||
|
}
|
||||||
|
|
||||||
func (u *StdConn) Close() error {
|
func (u *StdConn) Close() error {
|
||||||
|
u.disableGSO()
|
||||||
return syscall.Close(u.sysFd)
|
return syscall.Close(u.sysFd)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -30,13 +30,16 @@ type rawMessage struct {
|
|||||||
Len uint32
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
func (u *StdConn) PrepareRawMessages(n int, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
|
||||||
|
if bufSize <= 0 {
|
||||||
|
bufSize = MTU
|
||||||
|
}
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
buffers := make([][]byte, n)
|
buffers := make([][]byte, n)
|
||||||
names := make([][]byte, n)
|
names := make([][]byte, n)
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
buffers[i] = make([]byte, MTU)
|
buffers[i] = make([]byte, bufSize)
|
||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
vs := []iovec{
|
vs := []iovec{
|
||||||
@@ -52,3 +55,25 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
|||||||
|
|
||||||
return msgs, buffers, names
|
return msgs, buffers, names
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setRawMessageControl(msg *rawMessage, buf []byte) {
|
||||||
|
if len(buf) == 0 {
|
||||||
|
msg.Hdr.Control = nil
|
||||||
|
msg.Hdr.Controllen = 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msg.Hdr.Control = &buf[0]
|
||||||
|
msg.Hdr.Controllen = uint32(len(buf))
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRawMessageControlLen(msg *rawMessage) int {
|
||||||
|
return int(msg.Hdr.Controllen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRawMessageFlags(msg *rawMessage) int {
|
||||||
|
return int(msg.Hdr.Flags)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setCmsgLen(h *unix.Cmsghdr, l int) {
|
||||||
|
h.Len = uint32(l)
|
||||||
|
}
|
||||||
|
|||||||
@@ -33,13 +33,16 @@ type rawMessage struct {
|
|||||||
Pad0 [4]byte
|
Pad0 [4]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
func (u *StdConn) PrepareRawMessages(n int, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
|
||||||
|
if bufSize <= 0 {
|
||||||
|
bufSize = MTU
|
||||||
|
}
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
buffers := make([][]byte, n)
|
buffers := make([][]byte, n)
|
||||||
names := make([][]byte, n)
|
names := make([][]byte, n)
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
buffers[i] = make([]byte, MTU)
|
buffers[i] = make([]byte, bufSize)
|
||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
vs := []iovec{
|
vs := []iovec{
|
||||||
@@ -55,3 +58,25 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
|||||||
|
|
||||||
return msgs, buffers, names
|
return msgs, buffers, names
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setRawMessageControl(msg *rawMessage, buf []byte) {
|
||||||
|
if len(buf) == 0 {
|
||||||
|
msg.Hdr.Control = nil
|
||||||
|
msg.Hdr.Controllen = 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msg.Hdr.Control = &buf[0]
|
||||||
|
msg.Hdr.Controllen = uint64(len(buf))
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRawMessageControlLen(msg *rawMessage) int {
|
||||||
|
return int(msg.Hdr.Controllen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRawMessageFlags(msg *rawMessage) int {
|
||||||
|
return int(msg.Hdr.Flags)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setCmsgLen(h *unix.Cmsghdr, l int) {
|
||||||
|
h.Len = uint64(l)
|
||||||
|
}
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RIOConn) ListenOut(r EncReader) {
|
func (u *RIOConn) ListenOut(r EncReader) error {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
|||||||
Reference in New Issue
Block a user