Compare commits

..

1 Commits

Author SHA1 Message Date
Jay Wren
2400e2392b lint
* reduce staticcheck warnings
2025-04-14 14:33:46 -04:00
88 changed files with 2317 additions and 3723 deletions

View File

@@ -1,21 +1,13 @@
blank_issues_enabled: true
contact_links:
- name: 💨 Performance Issues
url: https://github.com/slackhq/nebula/discussions/new/choose
about: 'We ask that you create a discussion instead of an issue for performance-related questions. This allows us to have a more open conversation about the issue and helps us to better understand the problem.'
- name: 📄 Documentation Issues
url: https://github.com/definednet/nebula-docs
about: "If you've found an issue with the website documentation, please file it in the nebula-docs repository."
- name: 📱 Mobile Nebula Issues
url: https://github.com/definednet/mobile_nebula
about: "If you're using the mobile Nebula app and have found an issue, please file it in the mobile_nebula repository."
- name: 📘 Documentation
url: https://nebula.defined.net/docs/
about: 'The documentation is the best place to start if you are new to Nebula.'
about: Review documentation.
- name: 💁 Support/Chat
url: https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA
about: 'For faster support, join us on Slack for assistance!'
url: https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU
about: 'This issue tracker is not for support questions. Join us on Slack for assistance!'
- name: 📱 Mobile Nebula
url: https://github.com/definednet/mobile_nebula
about: 'This issue tracker is not for mobile support. Try the Mobile Nebula repo instead!'

View File

@@ -1,11 +0,0 @@
<!--
Thank you for taking the time to submit a pull request!
Please be sure to provide a clear description of what you're trying to achieve with the change.
- If you're submitting a new feature, please explain how to use it and document any new config options in the example config.
- If you're submitting a bugfix, please link the related issue or describe the circumstances surrounding the issue.
- If you're changing a default, explain why you believe the new default is appropriate for most users.
P.S. If you're only updating the README or other docs, please file a pull request here instead: https://github.com/DefinedNet/nebula-docs
-->

View File

@@ -16,9 +16,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.24'
check-latest: true
- name: Install goimports

View File

@@ -12,9 +12,9 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.24'
check-latest: true
- name: Build
@@ -35,9 +35,9 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.24'
check-latest: true
- name: Build
@@ -68,9 +68,9 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.24'
check-latest: true
- name: Import certificates

View File

@@ -22,9 +22,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version-file: 'go.mod'
check-latest: true
- name: add hashicorp source

View File

@@ -20,9 +20,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.24'
check-latest: true
- name: build

View File

@@ -20,9 +20,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.24'
check-latest: true
- name: Build
@@ -32,9 +32,9 @@ jobs:
run: make vet
- name: golangci-lint
uses: golangci/golangci-lint-action@v8
uses: golangci/golangci-lint-action@v7
with:
version: v2.5
version: v2.0
- name: Test
run: make test
@@ -58,9 +58,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.24'
check-latest: true
- name: Build
@@ -79,9 +79,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.22'
check-latest: true
- name: Build
@@ -100,9 +100,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.24'
check-latest: true
- name: Build nebula
@@ -115,9 +115,9 @@ jobs:
run: make vet
- name: golangci-lint
uses: golangci/golangci-lint-action@v8
uses: golangci/golangci-lint-action@v7
with:
version: v2.5
version: v2.0
- name: Test
run: make test

View File

@@ -4,7 +4,7 @@ It lets you seamlessly connect computers anywhere in the world. Nebula is portab
It can be used to connect a small number of computers, but is also able to connect tens of thousands of computers.
Nebula incorporates a number of existing concepts like encryption, security groups, certificates,
and tunneling.
and tunneling, and each of those individual pieces existed before Nebula in various forms.
What makes Nebula different to existing offerings is that it brings all of these ideas together,
resulting in a sum that is greater than its individual parts.
@@ -12,7 +12,7 @@ Further documentation can be found [here](https://nebula.defined.net/docs/).
You can read more about Nebula [here](https://medium.com/p/884110a5579).
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA).
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ).
## Supported Platforms
@@ -28,33 +28,33 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
#### Distribution Packages
- [Arch Linux](https://archlinux.org/packages/extra/x86_64/nebula/)
```sh
sudo pacman -S nebula
```
$ sudo pacman -S nebula
```
- [Fedora Linux](https://src.fedoraproject.org/rpms/nebula)
```sh
sudo dnf install nebula
```
$ sudo dnf install nebula
```
- [Debian Linux](https://packages.debian.org/source/stable/nebula)
```sh
sudo apt install nebula
```
$ sudo apt install nebula
```
- [Alpine Linux](https://pkgs.alpinelinux.org/packages?name=nebula)
```sh
sudo apk add nebula
```
$ sudo apk add nebula
```
- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/n/nebula.rb)
```sh
brew install nebula
```
$ brew install nebula
```
- [Docker](https://hub.docker.com/r/nebulaoss/nebula)
```sh
docker pull nebulaoss/nebula
```
$ docker pull nebulaoss/nebula
```
#### Mobile
@@ -64,10 +64,10 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
## Technical Overview
Nebula is a mutually authenticated peer-to-peer software-defined network based on the [Noise Protocol Framework](https://noiseprotocol.org/).
Nebula is a mutually authenticated peer-to-peer software defined network based on the [Noise Protocol Framework](https://noiseprotocol.org/).
Nebula uses certificates to assert a node's IP address, name, and membership within user-defined groups.
Nebula's user-defined groups allow for provider agnostic traffic filtering between nodes.
Discovery nodes (aka lighthouses) allow individual peers to find each other and optionally use UDP hole punching to establish connections from behind most firewalls or NATs.
Discovery nodes allow individual peers to find each other and optionally use UDP hole punching to establish connections from behind most firewalls or NATs.
Users can move data between nodes in any number of cloud service providers, datacenters, and endpoints, without needing to maintain a particular addressing scheme.
Nebula uses Elliptic-curve Diffie-Hellman (`ECDH`) key exchange and `AES-256-GCM` in its default configuration.
@@ -82,34 +82,28 @@ To set up a Nebula network, you'll need:
#### 2. (Optional, but you really should..) At least one discovery node with a routable IP address, which we call a lighthouse.
Nebula lighthouses allow nodes to find each other, anywhere in the world. A lighthouse is the only node in a Nebula network whose IP should not change. Running a lighthouse requires very few compute resources, and you can easily use the least expensive option from a cloud hosting provider. If you're not sure which provider to use, a number of us have used $6/mo [DigitalOcean](https://digitalocean.com) droplets as lighthouses.
Nebula lighthouses allow nodes to find each other, anywhere in the world. A lighthouse is the only node in a Nebula network whose IP should not change. Running a lighthouse requires very few compute resources, and you can easily use the least expensive option from a cloud hosting provider. If you're not sure which provider to use, a number of us have used $5/mo [DigitalOcean](https://digitalocean.com) droplets as lighthouses.
Once you have launched an instance, ensure that Nebula udp traffic (default port udp/4242) can reach it over the internet.
#### 3. A Nebula certificate authority, which will be the root of trust for a particular Nebula network.
```sh
```
./nebula-cert ca -name "Myorganization, Inc"
```
This will create files named `ca.key` and `ca.cert` in the current directory. The `ca.key` file is the most sensitive file you'll create, because it is the key used to sign the certificates for individual nebula nodes/hosts. Please store this file somewhere safe, preferably with strong encryption.
**Be aware!** By default, certificate authorities have a 1-year lifetime before expiration. See [this guide](https://nebula.defined.net/docs/guides/rotating-certificate-authority/) for details on rotating a CA.
#### 4. Nebula host keys and certificates generated from that certificate authority
This assumes you have four nodes, named lighthouse1, laptop, server1, host3. You can name the nodes any way you'd like, including FQDN. You'll also need to choose IP addresses and the associated subnet. In this example, we are creating a nebula network that will use 192.168.100.x/24 as its network range. This example also demonstrates nebula groups, which can later be used to define traffic rules in a nebula network.
```sh
```
./nebula-cert sign -name "lighthouse1" -ip "192.168.100.1/24"
./nebula-cert sign -name "laptop" -ip "192.168.100.2/24" -groups "laptop,home,ssh"
./nebula-cert sign -name "server1" -ip "192.168.100.9/24" -groups "servers"
./nebula-cert sign -name "host3" -ip "192.168.100.10/24"
```
By default, host certificates will expire 1 second before the CA expires. Use the `-duration` flag to specify a shorter lifetime.
#### 5. Configuration files for each host
Download a copy of the nebula [example configuration](https://github.com/slackhq/nebula/blob/master/examples/config.yml).
* On the lighthouse node, you'll need to ensure `am_lighthouse: true` is set.
@@ -124,13 +118,10 @@ For each host, copy the nebula binary to the host, along with `config.yml` from
**DO NOT COPY `ca.key` TO INDIVIDUAL NODES.**
#### 7. Run nebula on each host
```sh
```
./nebula -config /path/to/config.yml
```
For more detailed instructions, [find the full documentation here](https://nebula.defined.net/docs/).
## Building Nebula from source
Make sure you have [go](https://go.dev/doc/install) installed and clone this repo. Change to the nebula directory.
@@ -149,10 +140,8 @@ The default curve used for cryptographic handshakes and signatures is Curve25519
In addition, Nebula can be built using the [BoringCrypto GOEXPERIMENT](https://github.com/golang/go/blob/go1.20/src/crypto/internal/boring/README.md) by running either of the following make targets:
```sh
make bin-boringcrypto
make release-boringcrypto
```
This is not the recommended default deployment, but may be useful based on your compliance requirements.
@@ -160,3 +149,5 @@ This is not the recommended default deployment, but may be useful based on your
Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang.

View File

@@ -25,14 +25,14 @@ func TestNewAllowListFromConfig(t *testing.T) {
c.Settings["allowlist"] = map[string]any{
"192.168.0.0/16": "abc",
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
_, err = newAllowListFromConfig(c, "allowlist", nil)
require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
c.Settings["allowlist"] = map[string]any{
"192.168.0.0/16": true,
"10.0.0.0/8": false,
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
_, err = newAllowListFromConfig(c, "allowlist", nil)
require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
c.Settings["allowlist"] = map[string]any{
@@ -42,7 +42,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
"fd00::/8": true,
"fd00:fd00::/16": false,
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
_, err = newAllowListFromConfig(c, "allowlist", nil)
require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
c.Settings["allowlist"] = map[string]any{
@@ -75,7 +75,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
`docker.*`: "foo",
},
}
lr, err := NewLocalAllowListFromConfig(c, "allowlist")
_, err = NewLocalAllowListFromConfig(c, "allowlist")
require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
c.Settings["allowlist"] = map[string]any{
@@ -84,7 +84,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
`eth.*`: true,
},
}
lr, err = NewLocalAllowListFromConfig(c, "allowlist")
_, err = NewLocalAllowListFromConfig(c, "allowlist")
require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
c.Settings["allowlist"] = map[string]any{
@@ -92,7 +92,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
`docker.*`: false,
},
}
lr, err = NewLocalAllowListFromConfig(c, "allowlist")
lr, err := NewLocalAllowListFromConfig(c, "allowlist")
if assert.NoError(t, err) {
assert.NotNil(t, lr)
}

10
bits.go
View File

@@ -18,7 +18,7 @@ type Bits struct {
func NewBits(bits uint64) *Bits {
return &Bits{
length: bits,
bits: make([]bool, bits, bits),
bits: make([]bool, bits),
current: 0,
lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil),
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
@@ -28,7 +28,7 @@ func NewBits(bits uint64) *Bits {
func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
// If i is the next number, return true.
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
if i > b.current || (i == 0 && !b.firstSeen && b.current < b.length) {
return true
}
@@ -51,7 +51,7 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true and update current.
if i == b.current+1 {
// Report missed packets, we can only understand what was missed after the first window has been gone through
if i > b.length && b.bits[i%b.length] == false {
if i > b.length && !b.bits[i%b.length] {
b.lostCounter.Inc(1)
}
b.bits[i%b.length] = true
@@ -104,7 +104,7 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
}
// Allow for the 0 packet to come in within the first window
if i == 0 && b.firstSeen == false && b.current < b.length {
if i == 0 && !b.firstSeen && b.current < b.length {
b.firstSeen = true
b.bits[i%b.length] = true
return true
@@ -122,7 +122,7 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
return false
}
if b.bits[i%b.length] == true {
if b.bits[i%b.length] {
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
Debug("Receive window")

View File

@@ -84,11 +84,16 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
calculatedRemotes := new(bart.Table[[]*calculatedRemote])
rawMap, ok := value.(map[string]any)
rawMap, ok := value.(map[any]any)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
}
for rawCIDR, rawValue := range rawMap {
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
cidr, err := netip.ParsePrefix(rawCIDR)
if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
@@ -124,7 +129,7 @@ func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculat
}
func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
rawMap, ok := raw.(map[string]any)
rawMap, ok := raw.(map[any]any)
if !ok {
return nil, fmt.Errorf("invalid type: %T", raw)
}

View File

@@ -58,9 +58,6 @@ type Certificate interface {
// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
PublicKey() []byte
// MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM
MarshalPublicKeyPEM() []byte
// Curve identifies which curve was used for the PublicKey and Signature.
Curve() Curve
@@ -138,7 +135,8 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
case Version2:
c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
default:
return nil, ErrUnknownVersion
//TODO: CERT-V2 make a static var
return nil, fmt.Errorf("unknown certificate version %d", v)
}
if err != nil {

View File

@@ -20,8 +20,6 @@ import (
"google.golang.org/protobuf/proto"
)
const publicKeyLen = 32
type certificateV1 struct {
details detailsV1
signature []byte
@@ -83,10 +81,6 @@ func (c *certificateV1) PublicKey() []byte {
return c.details.publicKey
}
func (c *certificateV1) MarshalPublicKeyPEM() []byte {
return marshalCertPublicKeyToPEM(c)
}
func (c *certificateV1) Signature() []byte {
return c.signature
}
@@ -114,10 +108,8 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
if err != nil {
return false
}
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:

View File

@@ -1,7 +1,6 @@
package cert
import (
"crypto/ed25519"
"fmt"
"net/netip"
"testing"
@@ -14,7 +13,6 @@ import (
)
func TestCertificateV1_Marshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -62,58 +60,6 @@ func TestCertificateV1_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV1_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
nc := certificateV1{
details: detailsV1{
name: "testing",
networks: []netip.Prefix{},
unsafeNetworks: []netip.Prefix{},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
publicKey: pubKey,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
assert.Equal(t, Version1, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.False(t, nc.IsCA())
nc.details.isCA = true
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.True(t, nc.IsCA())
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err)
nc.details.curve = Curve_P256
nc.details.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.True(t, nc.IsCA())
nc.details.isCA = false
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.False(t, nc.IsCA())
}
func TestCertificateV1_Expired(t *testing.T) {
nc := certificateV1{
details: detailsV1{

View File

@@ -114,10 +114,6 @@ func (c *certificateV2) PublicKey() []byte {
return c.publicKey
}
func (c *certificateV2) MarshalPublicKeyPEM() []byte {
return marshalCertPublicKeyToPEM(c)
}
func (c *certificateV2) Signature() []byte {
return c.signature
}
@@ -153,10 +149,8 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
if err != nil {
return false
}
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:

View File

@@ -15,7 +15,6 @@ import (
)
func TestCertificateV2_Marshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -76,58 +75,6 @@ func TestCertificateV2_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV2_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
nc := certificateV2{
details: detailsV2{
name: "testing",
networks: []netip.Prefix{},
unsafeNetworks: []netip.Prefix{},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
publicKey: pubKey,
signature: []byte("1234567890abcedfghij1234567890ab"),
}
assert.Equal(t, Version2, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.False(t, nc.IsCA())
nc.details.isCA = true
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.True(t, nc.IsCA())
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err)
nc.curve = Curve_P256
nc.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.True(t, nc.IsCA())
nc.details.isCA = false
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.False(t, nc.IsCA())
}
func TestCertificateV2_Expired(t *testing.T) {
nc := certificateV2{
details: detailsV2{
@@ -166,14 +113,14 @@ func TestCertificateV2_MarshalJSON(t *testing.T) {
signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"),
}
b, err := nc.MarshalJSON()
_, err := nc.MarshalJSON()
require.ErrorIs(t, err, ErrMissingDetails)
rd, err := nc.details.Marshal()
require.NoError(t, err)
nc.rawDetails = rd
b, err = nc.MarshalJSON()
b, err := nc.MarshalJSON()
require.NoError(t, err)
assert.JSONEq(
t,
@@ -227,8 +174,9 @@ func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
require.ErrorIs(t, err, ErrInvalidPrivateKey)
c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv)
_, _, curve, err = UnmarshalPrivateKeyFromPEM(priv)
assert.Equal(t, err, nil)
assert.Equal(t, curve, Curve_P256)
err = c.VerifyPrivateKey(Curve_P256, priv[:16])
require.ErrorIs(t, err, ErrInvalidPrivateKey)
@@ -314,6 +262,7 @@ func TestCertificateV2_marshalForSigningStability(t *testing.T) {
assert.Equal(t, expectedRawDetails, db)
expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162")
require.NoError(t, err)
b, err := nc.marshalForSigning()
require.NoError(t, err)
assert.Equal(t, expectedForSigning, b)

View File

@@ -227,6 +227,9 @@ func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) {
}
func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) {
// Are we testing the compilers types here?
// No value of int32 is lewss than math.MinInt32.
// By definition these checks can never be true.
if params.Version < math.MinInt32 || params.Version > math.MaxInt32 {
return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32)
}

View File

@@ -26,21 +26,21 @@ func TestNewArgon2Parameters(t *testing.T) {
}
func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) {
passphrase := []byte("DO NOT USE")
passphrase := []byte("DO NOT USE THIS KEY")
privKey := []byte(`# A good key
-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
CjsKC0FFUy0yNTYtR0NNEiwIExCAgAQYAyAEKiCPoDfGQiosxNPTbPn5EsMlc2MI
c0Bt4oz6gTrFQhX3aBJcimhHKeAuhyTGvllD0Z19fe+DFPcLH3h5VrdjVfIAajg0
KrbV3n9UHif/Au5skWmquNJzoW1E4MTdRbvpti6o+WdQ49DxjBFhx0YH8LBqrbPU
0BGkUHmIO7daP24=
CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
qrlJ69wer3ZUHFXA
-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
`)
shortKey := []byte(`# A key which, once decrypted, is too short
-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
CjsKC0FFUy0yNTYtR0NNEiwIExCAgAQYAyAEKiAVJwdfl3r+eqi/vF6S7OMdpjfo
hAzmTCRnr58Su4AqmBJbCv3zleYCEKYJP6UI3S8ekLMGISsgO4hm5leukCCyqT0Z
cQ76yrberpzkJKoPLGisX8f+xdy4aXSZl7oEYWQte1+vqbtl/eY9PGZhxUQdcyq7
hqzIyrRqfUgVuA==
CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7
k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe
GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs
rQr3bdH3Oy/WiYU=
-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
`)
invalidBanner := []byte(`# Invalid banner (not encrypted)
@@ -72,12 +72,14 @@ qrlJ69wer3ZUHFXA
require.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
assert.Nil(t, k)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
assert.Equal(t, curve, Curve_CURVE25519)
// Fail due to invalid banner
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
require.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
assert.Nil(t, k)
assert.Equal(t, rest, invalidPem)
assert.Equal(t, curve, Curve_CURVE25519)
// Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary.
@@ -85,12 +87,14 @@ qrlJ69wer3ZUHFXA
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
assert.Nil(t, k)
assert.Equal(t, rest, invalidPem)
assert.Equal(t, curve, Curve_CURVE25519)
// Fail due to invalid passphrase
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
require.EqualError(t, err, "invalid passphrase or corrupt private key")
assert.Nil(t, k)
assert.Equal(t, []byte{}, rest)
assert.Equal(t, curve, Curve_CURVE25519)
}
func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {

View File

@@ -20,7 +20,6 @@ var (
ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
ErrCaNotFound = errors.New("could not find ca for the certificate")
ErrUnknownVersion = errors.New("certificate version unrecognized")
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")

View File

@@ -21,6 +21,9 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ
switch curve {
case Curve_CURVE25519:
pub, priv, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
case Curve_P256:
privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {

View File

@@ -7,26 +7,19 @@ import (
"golang.org/x/crypto/ed25519"
)
const ( //cert banners
const (
CertificateBanner = "NEBULA CERTIFICATE"
CertificateV2Banner = "NEBULA CERTIFICATE V2"
)
const ( //key-agreement-key banners
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
)
/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */
const ( //signing key banners
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY"
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
)
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
@@ -58,16 +51,6 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
}
func marshalCertPublicKeyToPEM(c Certificate) []byte {
if c.IsCA() {
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
} else {
return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey())
}
}
// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH.
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
switch curve {
case Curve_CURVE25519:
@@ -79,19 +62,6 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
}
}
// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing.
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
switch curve {
case Curve_CURVE25519:
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
case Curve_P256:
return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
default:
return nil
}
}
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
k, r := pem.Decode(b)
if k == nil {
@@ -103,7 +73,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
expectedLen = 32
curve = Curve_CURVE25519
case P256PublicKeyBanner, ECDSAP256PublicKeyBanner:
case P256PublicKeyBanner:
// Uncompressed
expectedLen = 65
curve = Curve_P256

View File

@@ -97,12 +97,14 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
// Fail due to short key
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
require.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
// Fail due to invalid banner
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner")
@@ -110,6 +112,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
// it's missing the requisite pre-encapsulation boundary.
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
}
@@ -159,12 +162,14 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
// Fail due to short key
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
// Fail due to invalid banner
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "bytes did not contain a proper private key banner")
@@ -172,12 +177,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
// it's missing the requisite pre-encapsulation boundary.
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
}
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
t.Parallel()
pubKey := []byte(`# A good key
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -231,7 +236,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
}
func TestUnmarshalX25519PublicKey(t *testing.T) {
t.Parallel()
pubKey := []byte(`# A good key
-----BEGIN NEBULA X25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -242,12 +246,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
oldPubP256Key := []byte(`# A good key
-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA ECDSA P256 PUBLIC KEY-----
`)
shortKey := []byte(`# A short key
-----BEGIN NEBULA X25519 PUBLIC KEY-----
@@ -264,22 +262,15 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-END NEBULA X25519 PUBLIC KEY-----`)
keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)
keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
// Success test case
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
assert.Len(t, k, 32)
require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_CURVE25519, curve)
// Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65)
require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_P256, curve)
// Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65)
@@ -290,12 +281,14 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
// Fail due to short key
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
// Fail due to invalid banner
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
require.EqualError(t, err, "bytes did not contain a proper public key banner")
assert.Equal(t, rest, invalidPem)
@@ -303,6 +296,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
// it's missing the requisite pre-encapsulation boundary.
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
}

View File

@@ -7,6 +7,7 @@ import (
"crypto/rand"
"crypto/sha256"
"fmt"
"math/big"
"net/netip"
"time"
)
@@ -54,10 +55,15 @@ func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Cert
}
return t.SignWith(signer, curve, sp)
case Curve_P256:
pk, err := ecdsa.ParseRawPrivateKey(elliptic.P256(), key)
if err != nil {
return nil, err
pk := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P256(),
},
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
D: new(big.Int).SetBytes(key),
}
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
sp := func(certBytes []byte) ([]byte, error) {
// We need to hash first for ECDSA
// - https://pkg.go.dev/crypto/ecdsa#SignASN1

View File

@@ -37,6 +37,7 @@ func TestCertificateV1_Sign(t *testing.T) {
}
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv)
require.NoError(t, err)
assert.NotNil(t, c)

View File

@@ -22,6 +22,9 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti
switch curve {
case cert.Curve_CURVE25519:
pub, priv, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
case cert.Curve_P256:
privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {

View File

@@ -81,7 +81,7 @@ func parseArgonParameters(memory uint, parallelism uint, iterations uint) (*cert
return cert.NewArgon2Parameters(uint32(memory), uint8(parallelism), uint32(iterations)), nil
}
func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error {
func ca(args []string, out io.Writer, _ io.Writer, pr PasswordReader) error {
cf := newCaFlags()
err := cf.set.Parse(args)
if err != nil {

View File

@@ -29,7 +29,7 @@ func newKeygenFlags() *keygenFlags {
return &cf
}
func keygen(args []string, out io.Writer, errOut io.Writer) error {
func keygen(args []string, _ io.Writer, _ io.Writer) error {
cf := newKeygenFlags()
err := cf.set.Parse(args)
if err != nil {

View File

@@ -3,7 +3,6 @@ package main
import (
"bytes"
"errors"
"fmt"
"io"
"os"
"testing"
@@ -77,7 +76,7 @@ func assertHelpError(t *testing.T, err error, msg string) {
case *helpError:
// good
default:
t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg))
t.Fatalf("err was not a helpError: %q, expected %q", err, msg)
}
require.EqualError(t, err, msg)

View File

@@ -10,7 +10,7 @@ func p11Supported() bool {
return false
}
func p11Flag(set *flag.FlagSet) *string {
func p11Flag(_ *flag.FlagSet) *string {
var ret = ""
return &ret
}

View File

@@ -1,12 +1,12 @@
package main
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"io"
"os"
"strings"
"github.com/skip2/go-qrcode"
"github.com/slackhq/nebula/cert"
@@ -29,7 +29,7 @@ func newPrintFlags() *printFlags {
return &pf
}
func printCert(args []string, out io.Writer, errOut io.Writer) error {
func printCert(args []string, out io.Writer, _ io.Writer) error {
pf := newPrintFlags()
err := pf.set.Parse(args)
if err != nil {
@@ -72,7 +72,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
qrBytes = append(qrBytes, b...)
}
if rawCert == nil || len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" {
if len(rawCert) == 0 || len(bytes.TrimSpace(rawCert)) == 0 {
break
}

View File

@@ -1,12 +1,12 @@
package main
import (
"bytes"
"errors"
"flag"
"fmt"
"io"
"os"
"strings"
"time"
"github.com/slackhq/nebula/cert"
@@ -52,7 +52,7 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while adding ca cert to pool: %w", err)
}
if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
if len(rawCACert) == 0 || len(bytes.TrimSpace(rawCACert)) == 0 {
break
}
}

View File

@@ -97,7 +97,7 @@ func Test_verify(t *testing.T) {
crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
// Slightly evil hack to modify the certificate after it was sealed to generate an invalid signature
pub := crt.PublicKey()
for i, _ := range pub {
for i := range pub {
pub[i] = 0
}
b, _ = crt.MarshalPEM()

View File

@@ -51,10 +51,7 @@ func (p *program) Stop(s service.Service) error {
func fileExists(filename string) bool {
_, err := os.Stat(filename)
if os.IsNotExist(err) {
return false
}
return true
return !os.IsNotExist(err)
}
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {

View File

@@ -63,7 +63,7 @@ func (c *C) Load(path string) error {
func (c *C) LoadString(raw string) error {
if raw == "" {
return errors.New("Empty configuration")
return errors.New("empty configuration")
}
return c.parseRaw([]byte(raw))
}

View File

@@ -4,16 +4,13 @@ import (
"bytes"
"context"
"encoding/binary"
"fmt"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
)
@@ -30,6 +27,12 @@ const (
)
type connectionManager struct {
in map[uint32]struct{}
inLock *sync.RWMutex
out map[uint32]struct{}
outLock *sync.RWMutex
// relayUsed holds which relay localIndexs are in use
relayUsed map[uint32]struct{}
relayUsedLock *sync.RWMutex
@@ -37,121 +40,121 @@ type connectionManager struct {
hostMap *HostMap
trafficTimer *LockingTimerWheel[uint32]
intf *Interface
pendingDeletion map[uint32]struct{}
punchy *Punchy
// Configuration settings
checkInterval time.Duration
pendingDeletionInterval time.Duration
inactivityTimeout atomic.Int64
dropInactive atomic.Bool
metricsTxPunchy metrics.Counter
l *logrus.Logger
}
func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
cm := &connectionManager{
hostMap: hm,
l: l,
punchy: p,
func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {
var max time.Duration
if checkInterval < pendingDeletionInterval {
max = pendingDeletionInterval
} else {
max = checkInterval
}
nc := &connectionManager{
hostMap: intf.hostMap,
in: make(map[uint32]struct{}),
inLock: &sync.RWMutex{},
out: make(map[uint32]struct{}),
outLock: &sync.RWMutex{},
relayUsed: make(map[uint32]struct{}),
relayUsedLock: &sync.RWMutex{},
trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max),
intf: intf,
pendingDeletion: make(map[uint32]struct{}),
checkInterval: checkInterval,
pendingDeletionInterval: pendingDeletionInterval,
punchy: punchy,
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
l: l,
}
cm.reload(c, true)
c.RegisterReloadCallback(func(c *config.C) {
cm.reload(c, false)
})
return cm
nc.Start(ctx)
return nc
}
func (cm *connectionManager) reload(c *config.C, initial bool) {
if initial {
cm.checkInterval = time.Duration(c.GetInt("timers.connection_alive_interval", 5)) * time.Second
cm.pendingDeletionInterval = time.Duration(c.GetInt("timers.pending_deletion_interval", 10)) * time.Second
// We want at least a minimum resolution of 500ms per tick so that we can hit these intervals
// pretty close to their configured duration.
// The inactivity duration is checked each time a hostinfo ticks through so we don't need the wheel to contain it.
minDuration := min(time.Millisecond*500, cm.checkInterval, cm.pendingDeletionInterval)
maxDuration := max(cm.checkInterval, cm.pendingDeletionInterval)
cm.trafficTimer = NewLockingTimerWheel[uint32](minDuration, maxDuration)
}
if initial || c.HasChanged("tunnels.inactivity_timeout") {
old := cm.getInactivityTimeout()
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
if !initial {
cm.l.WithField("oldDuration", old).
WithField("newDuration", cm.getInactivityTimeout()).
Info("Inactivity timeout has changed")
}
}
if initial || c.HasChanged("tunnels.drop_inactive") {
old := cm.dropInactive.Load()
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
if !initial {
cm.l.WithField("oldBool", old).
WithField("newBool", cm.dropInactive.Load()).
Info("Drop inactive setting has changed")
}
}
}
func (cm *connectionManager) getInactivityTimeout() time.Duration {
return (time.Duration)(cm.inactivityTimeout.Load())
}
func (cm *connectionManager) In(h *HostInfo) {
h.in.Store(true)
}
func (cm *connectionManager) Out(h *HostInfo) {
h.out.Store(true)
}
func (cm *connectionManager) RelayUsed(localIndex uint32) {
cm.relayUsedLock.RLock()
func (n *connectionManager) In(localIndex uint32) {
n.inLock.RLock()
// If this already exists, return
if _, ok := cm.relayUsed[localIndex]; ok {
cm.relayUsedLock.RUnlock()
if _, ok := n.in[localIndex]; ok {
n.inLock.RUnlock()
return
}
cm.relayUsedLock.RUnlock()
cm.relayUsedLock.Lock()
cm.relayUsed[localIndex] = struct{}{}
cm.relayUsedLock.Unlock()
n.inLock.RUnlock()
n.inLock.Lock()
n.in[localIndex] = struct{}{}
n.inLock.Unlock()
}
func (n *connectionManager) Out(localIndex uint32) {
n.outLock.RLock()
// If this already exists, return
if _, ok := n.out[localIndex]; ok {
n.outLock.RUnlock()
return
}
n.outLock.RUnlock()
n.outLock.Lock()
n.out[localIndex] = struct{}{}
n.outLock.Unlock()
}
func (n *connectionManager) RelayUsed(localIndex uint32) {
n.relayUsedLock.RLock()
// If this already exists, return
if _, ok := n.relayUsed[localIndex]; ok {
n.relayUsedLock.RUnlock()
return
}
n.relayUsedLock.RUnlock()
n.relayUsedLock.Lock()
n.relayUsed[localIndex] = struct{}{}
n.relayUsedLock.Unlock()
}
// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
// resets the state for this local index
func (cm *connectionManager) getAndResetTrafficCheck(h *HostInfo, now time.Time) (bool, bool) {
in := h.in.Swap(false)
out := h.out.Swap(false)
if in || out {
h.lastUsed = now
}
func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
n.inLock.Lock()
n.outLock.Lock()
_, in := n.in[localIndex]
_, out := n.out[localIndex]
delete(n.in, localIndex)
delete(n.out, localIndex)
n.inLock.Unlock()
n.outLock.Unlock()
return in, out
}
// AddTrafficWatch must be called for every new HostInfo.
// We will continue to monitor the HostInfo until the tunnel is dropped.
func (cm *connectionManager) AddTrafficWatch(h *HostInfo) {
if h.out.Swap(true) == false {
cm.trafficTimer.Add(h.localIndexId, cm.checkInterval)
func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
// Use a write lock directly because it should be incredibly rare that we are ever already tracking this index
n.outLock.Lock()
if _, ok := n.out[localIndex]; ok {
n.outLock.Unlock()
return
}
n.out[localIndex] = struct{}{}
n.trafficTimer.Add(localIndex, n.checkInterval)
n.outLock.Unlock()
}
func (cm *connectionManager) Start(ctx context.Context) {
clockSource := time.NewTicker(cm.trafficTimer.t.tickDuration)
func (n *connectionManager) Start(ctx context.Context) {
go n.Run(ctx)
}
func (n *connectionManager) Run(ctx context.Context) {
//TODO: this tick should be based on the min wheel tick? Check firewall
clockSource := time.NewTicker(500 * time.Millisecond)
defer clockSource.Stop()
p := []byte("")
nb := make([]byte, 12, 12)
nb := make([]byte, 12)
out := make([]byte, mtu)
for {
@@ -160,61 +163,61 @@ func (cm *connectionManager) Start(ctx context.Context) {
return
case now := <-clockSource.C:
cm.trafficTimer.Advance(now)
n.trafficTimer.Advance(now)
for {
localIndex, has := cm.trafficTimer.Purge()
localIndex, has := n.trafficTimer.Purge()
if !has {
break
}
cm.doTrafficCheck(localIndex, p, nb, out, now)
n.doTrafficCheck(localIndex, p, nb, out, now)
}
}
}
}
func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now)
func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now)
switch decision {
case deleteTunnel:
if cm.hostMap.DeleteHostInfo(hostinfo) {
if n.hostMap.DeleteHostInfo(hostinfo) {
// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
cm.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
}
case closeTunnel:
cm.intf.sendCloseTunnel(hostinfo)
cm.intf.closeTunnel(hostinfo)
n.intf.sendCloseTunnel(hostinfo)
n.intf.closeTunnel(hostinfo)
case swapPrimary:
cm.swapPrimary(hostinfo, primary)
n.swapPrimary(hostinfo, primary)
case migrateRelays:
cm.migrateRelayUsed(hostinfo, primary)
n.migrateRelayUsed(hostinfo, primary)
case tryRehandshake:
cm.tryRehandshake(hostinfo)
n.tryRehandshake(hostinfo)
case sendTestPacket:
cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
}
cm.resetRelayTrafficCheck(hostinfo)
n.resetRelayTrafficCheck(hostinfo)
}
func (cm *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
if hostinfo != nil {
cm.relayUsedLock.Lock()
defer cm.relayUsedLock.Unlock()
n.relayUsedLock.Lock()
defer n.relayUsedLock.Unlock()
// No need to migrate any relays, delete usage info now.
for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
delete(cm.relayUsed, idx)
delete(n.relayUsed, idx)
}
}
}
func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
relayFor := oldhostinfo.relayState.CopyAllRelayFor()
for _, r := range relayFor {
@@ -224,51 +227,46 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
var relayFrom netip.Addr
var relayTo netip.Addr
switch {
case ok:
switch existing.State {
case Established, PeerRequested, Disestablished:
case ok && existing.State == Established:
// This relay already exists in newhostinfo, then do nothing.
continue
case Requested:
case ok && existing.State == Requested:
// The relay exists in a Requested state; re-send the request
index = existing.LocalIndex
switch r.Type {
case TerminalType:
relayFrom = cm.intf.myVpnAddrs[0]
relayFrom = n.intf.myVpnAddrs[0]
relayTo = existing.PeerAddr
case ForwardingType:
relayFrom = existing.PeerAddr
relayTo = newhostinfo.vpnAddrs[0]
default:
// should never happen
panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
}
}
case !ok:
cm.relayUsedLock.RLock()
if _, relayUsed := cm.relayUsed[r.LocalIndex]; !relayUsed {
n.relayUsedLock.RLock()
if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {
// The relay hasn't been used; don't migrate it.
cm.relayUsedLock.RUnlock()
n.relayUsedLock.RUnlock()
continue
}
cm.relayUsedLock.RUnlock()
n.relayUsedLock.RUnlock()
// The relay doesn't exist at all; create some relay state and send the request.
var err error
index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
if err != nil {
cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
continue
}
switch r.Type {
case TerminalType:
relayFrom = cm.intf.myVpnAddrs[0]
relayFrom = n.intf.myVpnAddrs[0]
relayTo = r.PeerAddr
case ForwardingType:
relayFrom = r.PeerAddr
relayTo = newhostinfo.vpnAddrs[0]
default:
// should never happen
panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
}
}
@@ -281,12 +279,12 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
switch newhostinfo.GetCert().Certificate.Version() {
case cert.Version1:
if !relayFrom.Is4() {
cm.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !relayTo.Is4() {
cm.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
@@ -298,16 +296,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
req.RelayToAddr = netAddrToProtoAddr(relayTo)
default:
newhostinfo.logger(cm.l).Error("Unknown certificate version found while attempting to migrate relay")
newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
continue
}
msg, err := req.Marshal()
if err != nil {
cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
} else {
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
cm.l.WithFields(logrus.Fields{
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
n.l.WithFields(logrus.Fields{
"relayFrom": req.RelayFromAddr,
"relayTo": req.RelayToAddr,
"initiatorRelayIndex": req.InitiatorRelayIndex,
@@ -318,45 +316,46 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
}
}
func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
// Read lock the main hostmap to order decisions based on tunnels being the primary tunnel
cm.hostMap.RLock()
defer cm.hostMap.RUnlock()
func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
n.hostMap.RLock()
defer n.hostMap.RUnlock()
hostinfo := cm.hostMap.Indexes[localIndex]
hostinfo := n.hostMap.Indexes[localIndex]
if hostinfo == nil {
cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
delete(n.pendingDeletion, localIndex)
return doNothing, nil, nil
}
if cm.isInvalidCertificate(now, hostinfo) {
if n.isInvalidCertificate(now, hostinfo) {
delete(n.pendingDeletion, hostinfo.localIndexId)
return closeTunnel, hostinfo, nil
}
primary := cm.hostMap.Hosts[hostinfo.vpnAddrs[0]]
primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
mainHostInfo := true
if primary != nil && primary != hostinfo {
mainHostInfo = false
}
// Check for traffic on this hostinfo
inTraffic, outTraffic := cm.getAndResetTrafficCheck(hostinfo, now)
inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)
// A hostinfo is determined alive if there is incoming traffic
if inTraffic {
decision := doNothing
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).
if n.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
Debug("Tunnel status")
}
hostinfo.pendingDeletion.Store(false)
delete(n.pendingDeletion, hostinfo.localIndexId)
if mainHostInfo {
decision = tryRehandshake
} else {
if cm.shouldSwapPrimary(hostinfo) {
if n.shouldSwapPrimary(hostinfo) {
decision = swapPrimary
} else {
// migrate the relays to the primary, if in use.
@@ -364,55 +363,46 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
}
}
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
if !outTraffic {
// Send a punch packet to keep the NAT state alive
cm.sendPunch(hostinfo)
n.sendPunch(hostinfo)
}
return decision, hostinfo, primary
}
if hostinfo.pendingDeletion.Load() {
if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
// We have already sent a test packet and nothing was returned, this hostinfo is dead
hostinfo.logger(cm.l).
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
Info("Tunnel status")
delete(n.pendingDeletion, hostinfo.localIndexId)
return deleteTunnel, hostinfo, nil
}
decision := doNothing
if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
if hostinfo.ConnectionState != nil && mainHostInfo {
if !outTraffic {
inactiveFor, isInactive := cm.isInactive(hostinfo, now)
if isInactive {
// Tunnel is inactive, tear it down
hostinfo.logger(cm.l).
WithField("inactiveDuration", inactiveFor).
WithField("primary", mainHostInfo).
Info("Dropping tunnel due to inactivity")
return closeTunnel, hostinfo, primary
}
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
// Just maintain NAT state if configured to do so.
cm.sendPunch(hostinfo)
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
n.sendPunch(hostinfo)
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
return doNothing, nil, nil
}
if cm.punchy.GetTargetEverything() {
if n.punchy.GetTargetEverything() {
// This is similar to the old punchy behavior with a slight optimization.
// We aren't receiving traffic but we are sending it, punch on all known
// ips in case we need to re-prime NAT state
cm.sendPunch(hostinfo)
n.sendPunch(hostinfo)
}
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).
if n.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status")
}
@@ -421,33 +411,17 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
decision = sendTestPacket
} else {
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
if n.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l).Debugf("Hostinfo sadness")
}
}
hostinfo.pendingDeletion.Store(true)
cm.trafficTimer.Add(hostinfo.localIndexId, cm.pendingDeletionInterval)
n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
return decision, hostinfo, nil
}
func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time.Duration, bool) {
if cm.dropInactive.Load() == false {
// We aren't configured to drop inactive tunnels
return 0, false
}
inactiveDuration := now.Sub(hostinfo.lastUsed)
if inactiveDuration < cm.getInactivityTimeout() {
// It's not considered inactive
return inactiveDuration, false
}
// The tunnel is inactive
return inactiveDuration, true
}
func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
func (n *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out.
@@ -455,90 +429,83 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
// Only one side should swap because if both swap then we may never resolve to a single tunnel.
// vpn addr is static across all tunnels for this host pair so lets
// use that to determine if we should consider swapping.
if current.vpnAddrs[0].Compare(cm.intf.myVpnAddrs[0]) < 0 {
if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
// Their primary vpn addr is less than mine. Do not swap.
return false
}
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
// settle down.
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
}
func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
cm.hostMap.Lock()
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
n.hostMap.Lock()
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
if cm.hostMap.Hosts[current.vpnAddrs[0]] == primary {
cm.hostMap.unlockedMakePrimary(current)
if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
n.hostMap.unlockedMakePrimary(current)
}
cm.hostMap.Unlock()
n.hostMap.Unlock()
}
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
// check and return true.
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
remoteCert := hostinfo.GetCert()
if remoteCert == nil {
return false
}
caPool := cm.intf.pki.GetCAPool()
caPool := n.intf.pki.GetCAPool()
err := caPool.VerifyCachedCertificate(now, remoteCert)
if err == nil {
return false
}
if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
// Block listed certificates should always be disconnected
return false
}
hostinfo.logger(cm.l).WithError(err).
hostinfo.logger(n.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel")
return true
}
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
if !cm.punchy.GetPunch() {
func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
if !n.punchy.GetPunch() {
// Punching is disabled
return
}
if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
// Do not punch to lighthouses, we assume our lighthouse update interval is good enough.
// In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse
// would lose the ability to notify us and punchy.respond would become unreliable.
return
}
if cm.punchy.GetTargetEverything() {
hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
cm.metricsTxPunchy.Inc(1)
cm.intf.outside.WriteTo([]byte{1}, addr)
if n.punchy.GetTargetEverything() {
hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
n.metricsTxPunchy.Inc(1)
n.intf.outside.WriteTo([]byte{1}, addr)
})
} else if hostinfo.remote.IsValid() {
cm.metricsTxPunchy.Inc(1)
cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
n.metricsTxPunchy.Inc(1)
n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
}
}
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
cs := cm.intf.pki.getCertState()
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
cs := n.intf.pki.getCertState()
curCrt := hostinfo.ConnectionState.myCert
myCrt := cs.getCertificate(curCrt.Version())
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
// The current tunnel is using the latest certificate and version, no need to rehandshake.
return
}
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
}

View File

@@ -1,6 +1,7 @@
package nebula
import (
"context"
"crypto/ed25519"
"crypto/rand"
"net/netip"
@@ -22,7 +23,7 @@ func newTestLighthouse() *LightHouse {
addrMap: map[netip.Addr]*RemoteList{},
queryChan: make(chan netip.Addr, 10),
}
lighthouses := []netip.Addr{}
lighthouses := map[netip.Addr]struct{}{}
staticList := map[netip.Addr]struct{}{}
lh.lighthouses.Store(&lighthouses)
@@ -63,12 +64,12 @@ func Test_NewConnectionManagerTest(t *testing.T) {
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
nc.intf = ifce
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
punchy := NewPunchyFromConfig(l, config.NewC(l))
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
p := []byte("")
nb := make([]byte, 12, 12)
nb := make([]byte, 12)
out := make([]byte, mtu)
// Add an ip we have established a connection w/ to hostmap
@@ -84,33 +85,32 @@ func Test_NewConnectionManagerTest(t *testing.T) {
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// We saw traffic out to vpnIp
nc.Out(hostinfo)
nc.In(hostinfo)
assert.False(t, hostinfo.pendingDeletion.Load())
nc.Out(hostinfo.localIndexId)
nc.In(hostinfo.localIndexId)
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.True(t, hostinfo.out.Load())
assert.True(t, hostinfo.in.Load())
assert.Contains(t, nc.out, hostinfo.localIndexId)
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
// Do another traffic check tick, this host should be pending deletion now
nc.Out(hostinfo)
assert.True(t, hostinfo.out.Load())
nc.Out(hostinfo.localIndexId)
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.True(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// Do a final traffic check tick, the host should now be removed
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs)
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
}
@@ -146,12 +146,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
nc.intf = ifce
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
punchy := NewPunchyFromConfig(l, config.NewC(l))
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
p := []byte("")
nb := make([]byte, 12, 12)
nb := make([]byte, 12)
out := make([]byte, mtu)
// Add an ip we have established a connection w/ to hostmap
@@ -167,129 +167,33 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// We saw traffic out to vpnIp
nc.Out(hostinfo)
nc.In(hostinfo)
assert.True(t, hostinfo.in.Load())
assert.True(t, hostinfo.out.Load())
assert.False(t, hostinfo.pendingDeletion.Load())
nc.Out(hostinfo.localIndexId)
nc.In(hostinfo.localIndexId)
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
// Do another traffic check tick, this host should be pending deletion now
nc.Out(hostinfo)
nc.Out(hostinfo.localIndexId)
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.True(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// We saw traffic, should no longer be pending deletion
nc.In(hostinfo)
nc.In(hostinfo.localIndexId)
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
}
func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
l := test.NewLogger()
localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnAddrs := []netip.Addr{netip.MustParseAddr("172.1.1.2")}
preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects
hostMap := newHostMap(l)
hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{
initiatingVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
}
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
}
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
conf.Settings["tunnels"] = map[string]any{
"drop_inactive": true,
}
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
assert.True(t, nc.dropInactive.Load())
nc.intf = ifce
// Add an ip we have established a connection w/ to hostmap
hostinfo := &HostInfo{
vpnAddrs: vpnAddrs,
localIndexId: 1099,
remoteIndexId: 9901,
}
hostinfo.ConnectionState = &ConnectionState{
myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// Do a traffic check tick, in and out should be cleared but should not be pending deletion
nc.Out(hostinfo)
nc.In(hostinfo)
assert.True(t, hostinfo.out.Load())
assert.True(t, hostinfo.in.Load())
now := time.Now()
decision, _, _ := nc.makeTrafficDecision(hostinfo.localIndexId, now)
assert.Equal(t, tryRehandshake, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*5))
assert.Equal(t, doNothing, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
// Do another traffic check tick, should still not be pending deletion
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*10))
assert.Equal(t, doNothing, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// Finally advance beyond the inactivity timeout
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Minute*10))
assert.Equal(t, closeTunnel, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
}
@@ -337,7 +241,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
require.NoError(t, err)
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
require.NoError(t, err)
cs := &CertState{
privateKey: []byte{},
v1Cert: &dummyCert{},
@@ -360,10 +264,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
ifce.disconnectInvalid.Store(true)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
nc.intf = ifce
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
punchy := NewPunchyFromConfig(l, config.NewC(l))
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
ifce.connectionManager = nc
hostinfo := &HostInfo{
@@ -446,10 +350,6 @@ func (d *dummyCert) PublicKey() []byte {
return d.publicKey
}
func (d *dummyCert) MarshalPublicKeyPEM() []byte {
return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey)
}
func (d *dummyCert) Signature() []byte {
return d.signature
}

View File

@@ -34,7 +34,6 @@ type Control struct {
statsStart func()
dnsStart func()
lighthouseStart func()
connectionManagerStart func(context.Context)
}
type ControlHostInfo struct {
@@ -64,9 +63,6 @@ func (c *Control) Start() {
if c.dnsStart != nil {
go c.dnsStart()
}
if c.connectionManagerStart != nil {
go c.connectionManagerStart(c.ctx)
}
if c.lighthouseStart != nil {
c.lighthouseStart()
}
@@ -135,7 +131,8 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
if c.f.myVpnAddrsTable.Contains(vpnIp) {
_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
if found {
// Only returning the default certificate since its impossible
// for any other host but ourselves to have more than 1
return c.f.pki.getCertState().GetDefaultCertificate().Copy()
@@ -218,7 +215,7 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
hostInfo.ConnectionState,
hostInfo,
[]byte{},
make([]byte, 12, 12),
make([]byte, 12),
make([]byte, mtu),
)
}
@@ -234,7 +231,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) {
return
}
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12), make([]byte, mtu))
c.f.closeTunnel(h)
c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote).
@@ -285,9 +282,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
CurrentRemote: h.remote,
}
for i, a := range h.vpnAddrs {
chi.VpnAddrs[i] = a
}
copy(chi.VpnAddrs, h.vpnAddrs)
if h.ConnectionState != nil {
chi.MessageCounter = h.ConnectionState.messageCounter.Load()

View File

@@ -27,12 +27,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
ipNet := net.IPNet{
IP: remote1.Addr().AsSlice(),
Mask: net.IPMask{255, 255, 255, 0},
}
ipNet2 := net.IPNet{
IP: remote2.Addr().AsSlice(),
Mask: net.IPMask{255, 255, 255, 0},
}
remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil)
@@ -53,7 +51,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
localIndexId: 201,
vpnAddrs: []netip.Addr{vpnIp},
relayState: RelayState{
relays: nil,
relays: map[netip.Addr]struct{}{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
@@ -72,7 +70,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
localIndexId: 201,
vpnAddrs: []netip.Addr{vpnIp2},
relayState: RelayState{
relays: nil,
relays: map[netip.Addr]struct{}{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},

View File

@@ -26,7 +26,7 @@ type dnsRecords struct {
dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr
hostMap *HostMap
myVpnAddrsTable *bart.Lite
myVpnAddrsTable *bart.Table[struct{}]
}
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
@@ -112,8 +112,8 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
return true
}
//if we found it in this table, it's good
return d.myVpnAddrsTable.Contains(b)
_, found := d.myVpnAddrsTable.Lookup(b)
return found //if we found it in this table, it's good
}
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {

View File

@@ -506,7 +506,7 @@ func TestReestablishRelays(t *testing.T) {
curIndexes := len(myControl.GetHostmap().Indexes)
for curIndexes >= start {
curIndexes = len(myControl.GetHostmap().Indexes)
r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes)
r.Logf("Wait for the dead index to go away:start=%v indexes, currnet=%v indexes", start, curIndexes)
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@@ -1052,9 +1052,6 @@ func TestRehandshakingLoser(t *testing.T) {
t.Log("Stand up a tunnel between me and them")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
r.Log("Renew their certificate and spin until mine sees it")

View File

@@ -700,7 +700,6 @@ func (r *R) FlushAll() {
r.Unlock()
panic("Can't FlushAll for host: " + p.To.String())
}
receiver.InjectUDPPacket(p)
r.Unlock()
}
}

View File

@@ -1,57 +0,0 @@
//go:build e2e_testing
// +build e2e_testing
package e2e
import (
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router"
)
func TestDropInactiveTunnels(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
r.Log("Assert the tunnel between me and them works")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("Go inactive and wait for the tunnels to get dropped")
waitStart := time.Now()
for {
myIndexes := len(myControl.GetHostmap().Indexes)
theirIndexes := len(theirControl.GetHostmap().Indexes)
if myIndexes == 0 && theirIndexes == 0 {
break
}
since := time.Since(waitStart)
r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
if since > time.Second*30 {
t.Fatal("Tunnel should have been declared inactive after 5 seconds and before 30 seconds")
}
time.Sleep(1 * time.Second)
r.FlushAll()
}
r.Logf("Inactive tunnels were dropped within %v", time.Since(waitStart))
myControl.Stop()
theirControl.Stop()
}

View File

@@ -275,10 +275,6 @@ tun:
# On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of
# in nebula configuration files. Default false, not reloadable.
#use_system_route_table: false
# Buffer size for reading routes updates. 0 means default system buffer size. (/proc/sys/net/core/rmem_default).
# If using massive routes updates, for example BGP, you may need to increase this value to avoid packet loss.
# SO_RCVBUFFORCE is used to avoid having to raise the system wide max
#use_system_route_table_buffer_size: 0
# Configure logging level
logging:
@@ -338,18 +334,6 @@ logging:
# after receiving the response for lighthouse queries
#trigger_buffer: 64
# Tunnel manager settings
#tunnels:
# drop_inactive controls whether inactive tunnels are maintained or dropped after the inactive_timeout period has
# elapsed.
# In general, it is a good idea to enable this setting. It will be enabled by default in a future release.
# This setting is reloadable
#drop_inactive: false
# inactivity_timeout controls how long a tunnel MUST NOT see any inbound or outbound traffic before being considered
# inactive and eligible to be dropped.
# This setting is reloadable
#inactivity_timeout: 10m
# Nebula security group configuration
firewall:

View File

@@ -5,12 +5,8 @@ import (
"fmt"
"log"
"net"
"os"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/service"
)
@@ -63,16 +59,7 @@ pki:
if err := cfg.LoadString(configStr); err != nil {
return err
}
logger := logrus.New()
logger.Out = os.Stdout
ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil {
return err
}
svc, err := service.New(ctrl)
svc, err := service.New(&cfg)
if err != nil {
return err
}

View File

@@ -53,7 +53,7 @@ type Firewall struct {
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
// The vpn addresses are a full bit match while the unsafe networks only match the prefix
routableNetworks *bart.Lite
routableNetworks *bart.Table[struct{}]
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
assignedNetworks []netip.Prefix
@@ -125,7 +125,7 @@ type firewallPort map[int32]*FirewallCA
type firewallLocalCIDR struct {
Any bool
LocalCIDR *bart.Lite
LocalCIDR *bart.Table[struct{}]
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
@@ -148,17 +148,17 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
tmax = defaultTimeout
}
routableNetworks := new(bart.Lite)
routableNetworks := new(bart.Table[struct{}])
var assignedNetworks []netip.Prefix
for _, network := range c.Networks() {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
routableNetworks.Insert(nprefix)
routableNetworks.Insert(nprefix, struct{}{})
assignedNetworks = append(assignedNetworks, network)
}
hasUnsafeNetworks := false
for _, n := range c.UnsafeNetworks() {
routableNetworks.Insert(n)
routableNetworks.Insert(n, struct{}{})
hasUnsafeNetworks = true
}
@@ -431,7 +431,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
// Make sure remote address matches nebula certificate
if h.networks != nil {
if !h.networks.Contains(fp.RemoteAddr) {
_, ok := h.networks.Lookup(fp.RemoteAddr)
if !ok {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
@@ -444,7 +445,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
}
// Make sure we are supposed to be handling this local ip address
if !f.routableNetworks.Contains(fp.LocalAddr) {
_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
if !ok {
f.metrics(incoming).droppedLocalAddr.Inc(1)
return ErrInvalidLocalIP
}
@@ -604,7 +606,7 @@ func (f *Firewall) evict(p firewall.Packet) {
return
}
newT := t.Expires.Sub(time.Now())
newT := time.Until(t.Expires)
// Timeout is in the future, re-add the timer
if newT > 0 {
@@ -750,7 +752,7 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
flc := func() *firewallLocalCIDR {
return &firewallLocalCIDR{
LocalCIDR: new(bart.Lite),
LocalCIDR: new(bart.Table[struct{}]),
}
}
@@ -830,7 +832,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
}
// Shortcut path for if groups, hosts, or cidr contained an `any`
if fr.Any.match(p, c) {
if fr.Any.match(p) {
return true
}
@@ -847,21 +849,21 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
found = true
}
if found && sg.LocalCIDR.match(p, c) {
if found && sg.LocalCIDR.match(p) {
return true
}
}
if fr.Hosts != nil {
if flc, ok := fr.Hosts[c.Certificate.Name()]; ok {
if flc.match(p, c) {
if flc.match(p) {
return true
}
}
}
for _, v := range fr.CIDR.Supernets(netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())) {
if v.match(p, c) {
if v.match(p) {
return true
}
}
@@ -877,7 +879,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
}
for _, network := range f.assignedNetworks {
flc.LocalCIDR.Insert(network)
flc.LocalCIDR.Insert(network, struct{}{})
}
return nil
@@ -886,11 +888,11 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
return nil
}
flc.LocalCIDR.Insert(localIp)
flc.LocalCIDR.Insert(localIp, struct{}{})
return nil
}
func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate) bool {
func (flc *firewallLocalCIDR) match(p firewall.Packet) bool {
if flc == nil {
return false
}
@@ -899,7 +901,8 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate
return true
}
return flc.LocalCIDR.Contains(p.LocalAddr)
_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
return ok
}
type rule struct {

View File

@@ -35,22 +35,27 @@ func TestNewFirewall(t *testing.T) {
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
conntrack = fw.Conntrack
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
conntrack = fw.Conntrack
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
conntrack = fw.Conntrack
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
conntrack = fw.Conntrack
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
conntrack = fw.Conntrack
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
}
@@ -68,9 +73,6 @@ func TestFirewall_AddRule(t *testing.T) {
ti, err := netip.ParsePrefix("1.2.3.4/32")
require.NoError(t, err)
ti6, err := netip.ParsePrefix("fd12::34/128")
require.NoError(t, err)
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
// An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
@@ -95,24 +97,12 @@ func TestFirewall_AddRule(t *testing.T) {
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
@@ -132,13 +122,6 @@ func TestFirewall_AddRule(t *testing.T) {
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -221,82 +204,6 @@ func TestFirewall_Drop(t *testing.T) {
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func TestFirewall_DropV6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
c := dummyCert{
name: "host1",
networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &cert.CachedCertificate{
Certificate: &c,
InvertedGroups: map[string]struct{}{"default-group": {}},
},
},
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
}
h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("fd12::56")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func BenchmarkFirewallTable_match(b *testing.B) {
f := &Firewall{}
ft := FirewallTable{
@@ -306,10 +213,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
pfix := netip.MustParsePrefix("172.1.1.1/32")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
pfix6 := netip.MustParsePrefix("fd11::11/128")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) {
@@ -341,15 +244,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
}
})
b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{},
}
ip := netip.MustParsePrefix("fd99::99/128")
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
}
})
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
@@ -363,18 +257,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
}
})
b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
},
InvertedGroups: map[string]struct{}{"nope": {}},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
}
})
b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
@@ -388,18 +270,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
}
})
b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
},
InvertedGroups: map[string]struct{}{"nope": {}},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
}
})
b.Run("pass on group on any local cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
@@ -424,17 +294,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
}
})
b.Run("pass on group on specific local cidr6", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
},
InvertedGroups: map[string]struct{}{"good-group": {}},
}
for n := 0; n < b.N; n++ {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
}
})
b.Run("pass on name", func(b *testing.B) {
c := &cert.CachedCertificate{
@@ -593,42 +452,6 @@ func TestFirewall_Drop3(t *testing.T) {
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
}
func TestFirewall_Drop3V6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"),
LocalPort: 1,
RemotePort: 1,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
network := netip.MustParsePrefix("fd12::34/120")
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host-owner",
networks: []netip.Prefix{network},
},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
// Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func TestFirewall_DropConntrackReload(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
@@ -692,50 +515,6 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
}
func TestFirewall_DropIPSpoofing(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host-owner",
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
},
}
c1 := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host",
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
},
}
h1 := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c1,
},
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
}
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("192.0.2.1"),
RemoteAddr: netip.MustParseAddr("192.0.2.3"),
LocalPort: 1,
RemotePort: 1,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
}
func BenchmarkLookup(b *testing.B) {
ml := func(m map[string]struct{}, a [][]string) {
for n := 0; n < b.N; n++ {
@@ -953,21 +732,6 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
// Test adding rule with cidr ipv6
cidr6 := netip.MustParsePrefix("fd00::/8")
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr ipv6
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
// Test adding rule with ca_sha
conf = config.NewC(l)
mf = &mockFirewall{}

50
go.mod
View File

@@ -1,39 +1,41 @@
module github.com/slackhq/nebula
go 1.25
go 1.23.6
toolchain go1.24.1
require (
dario.cat/mergo v1.0.2
dario.cat/mergo v1.0.1
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.25.0
github.com/gaissmai/bart v0.20.1
github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.4
github.com/miekg/dns v1.1.68
github.com/kardianos/service v1.2.2
github.com/miekg/dns v1.1.65
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
github.com/prometheus/client_golang v1.23.2
github.com/prometheus/client_golang v1.21.1
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.11.1
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.43.0
github.com/stretchr/testify v1.10.0
github.com/vishvananda/netlink v1.3.0
golang.org/x/crypto v0.37.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.45.0
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
golang.org/x/term v0.36.0
golang.org/x/net v0.38.0
golang.org/x/sync v0.13.0
golang.org/x/sys v0.32.0
golang.org/x/term v0.31.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/protobuf v1.36.8
google.golang.org/protobuf v1.36.6
gopkg.in/yaml.v3 v3.0.1
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
)
require (
@@ -41,14 +43,14 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/mod v0.24.0 // indirect
golang.org/x/time v0.7.0 // indirect
golang.org/x/tools v0.33.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
golang.org/x/mod v0.23.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.30.0 // indirect
)

101
go.sum
View File

@@ -1,6 +1,6 @@
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.25.0 h1:eqiokVPqM3F94vJ0bTHXHtH91S8zkKL+bKh+BsGOsJM=
github.com/gaissmai/bart v0.25.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/gaissmai/bart v0.20.1 h1:igNss0zDsSY8e+ophKgD9KJVPKBOo7uSVjyKCL7nIzo=
github.com/gaissmai/bart v0.20.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -53,8 +53,8 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
@@ -64,12 +64,12 @@ github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/
github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kardianos/service v1.2.4 h1:XNlGtZOYNx2u91urOdg/Kfmc+gfmuIo1Dd3rEi2OgBk=
github.com/kardianos/service v1.2.4/go.mod h1:E4V9ufUuY82F7Ztlu1eN9VXWIQxg8NoLQlmFe0MtrXc=
github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60=
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
@@ -83,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
github.com/miekg/dns v1.1.65 h1:0+tIPHzUW0GCge7IiK3guGP57VAw7hoPDfApjkMD1Fc=
github.com/miekg/dns v1.1.65/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -106,24 +106,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk=
github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@@ -143,33 +143,29 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -180,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -189,8 +185,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -201,37 +197,38 @@ golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o=
golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
@@ -242,8 +239,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -257,5 +254,5 @@ gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI=
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe h1:fre4i6mv4iBuz5lCMOzHD1rH1ljqHWSICFmZRbbgp3g=
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU=

View File

@@ -192,7 +192,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
for _, network := range remoteCert.Certificate.Networks() {
vpnAddr := network.Addr()
if f.myVpnAddrsTable.Contains(vpnAddr) {
_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
if found {
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
@@ -203,7 +204,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
}
// vpnAddrs outside our vpn networks are of no use to us, filter them out
if !f.myVpnNetworksTable.Contains(vpnAddr) {
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
continue
}
@@ -249,7 +250,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
HandshakePacket: make(map[uint8][]byte, 0),
lastHandshakeTime: hs.Details.Time,
relayState: RelayState{
relays: nil,
relays: map[netip.Addr]struct{}{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
@@ -342,7 +343,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if existing.SetRemoteIfPreferred(f.hostMap, addr) {
// Send a test packet to ensure the other side has also switched to
// the preferred remote
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12), make([]byte, mtu))
}
msg = existing.HandshakePacket[2]
@@ -385,7 +386,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
Info("Handshake too old")
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12), make([]byte, mtu))
return
case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
@@ -457,11 +458,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
Info("Handshake message sent")
}
f.connectionManager.AddTrafficWatch(hostinfo)
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
return
hostinfo.remotes.ResetBlockedRemotes()
}
func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
@@ -578,7 +577,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
for _, network := range vpnNetworks {
// vpnAddrs outside our vpn networks are of no use to us, filter them out
vpnAddr := network.Addr()
if !f.myVpnNetworksTable.Contains(vpnAddr) {
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
continue
}
@@ -652,14 +651,14 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
f.handshakeManager.Complete(hostinfo, f)
f.connectionManager.AddTrafficWatch(hostinfo)
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
}
if len(hh.packetStore) > 0 {
nb := make([]byte, 12, 12)
nb := make([]byte, 12)
out := make([]byte, mtu)
for _, cp := range hh.packetStore {
cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out)
@@ -667,7 +666,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
}
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
hostinfo.remotes.ResetBlockedRemotes()
f.metricHandshakes.Update(duration)
return false

View File

@@ -274,7 +274,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
}
// Don't relay through the host I'm trying to connect to
if hm.f.myVpnAddrsTable.Contains(relay) {
_, found := hm.f.myVpnAddrsTable.Lookup(relay)
if found {
continue
}
@@ -450,7 +451,7 @@ func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*Han
vpnAddrs: []netip.Addr{vpnAddr},
HandshakePacket: make(map[uint8][]byte, 0),
relayState: RelayState{
relays: nil,
relays: map[netip.Addr]struct{}{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},

View File

@@ -65,30 +65,16 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
assert.NotContains(t, blah.vpnIps, ip)
}
func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
for _, i := range tw.t.wheel {
n := i.Head
for n != nil {
c++
n = n.Next
}
}
return c
}
type mockEncWriter struct {
}
func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) {
return
}
func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) {
return
}
func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) {
return
}
func (mw *mockEncWriter) Handshake(_ netip.Addr) {}

View File

@@ -23,7 +23,7 @@ type m = map[string]any
const (
Version uint8 = 1
Len = 16
Len int = 16
)
type MessageType uint8

View File

@@ -4,7 +4,6 @@ import (
"errors"
"net"
"net/netip"
"slices"
"sync"
"sync/atomic"
"time"
@@ -17,10 +16,12 @@ import (
"github.com/slackhq/nebula/header"
)
// const ProbeLen = 100
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse
const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
const MaxRemotes = 10
const maxRecvError = 4
// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
// 5 allows for an initial handshake and each host pair re-handshaking twice
@@ -67,7 +68,7 @@ type HostMap struct {
type RelayState struct {
sync.RWMutex
relays []netip.Addr // Ordered set of VpnAddrs of Hosts to use as relays to access this peer
relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
// For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
// modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
// the RelayState Lock held)
@@ -78,12 +79,7 @@ type RelayState struct {
func (rs *RelayState) DeleteRelay(ip netip.Addr) {
rs.Lock()
defer rs.Unlock()
for idx, val := range rs.relays {
if val == ip {
rs.relays = append(rs.relays[:idx], rs.relays[idx+1:]...)
return
}
}
delete(rs.relays, ip)
}
func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
@@ -128,16 +124,16 @@ func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
rs.Lock()
defer rs.Unlock()
if !slices.Contains(rs.relays, ip) {
rs.relays = append(rs.relays, ip)
}
rs.relays[ip] = struct{}{}
}
func (rs *RelayState) CopyRelayIps() []netip.Addr {
ret := make([]netip.Addr, len(rs.relays))
rs.RLock()
defer rs.RUnlock()
copy(ret, rs.relays)
ret := make([]netip.Addr, 0, len(rs.relays))
for ip := range rs.relays {
ret = append(ret, ip)
}
return ret
}
@@ -224,9 +220,10 @@ type HostInfo struct {
// The host may have other vpn addresses that are outside our
// vpn networks but were removed because they are not usable
vpnAddrs []netip.Addr
recvError atomic.Uint32
// networks are both all vpn and unsafe networks assigned to this host
networks *bart.Lite
networks *bart.Table[struct{}]
relayState RelayState
// HandshakePacket records the packets used to create this hostinfo
@@ -253,14 +250,6 @@ type HostInfo struct {
// Used to track other hostinfos for this vpn ip since only 1 can be primary
// Synchronised via hostmap lock and not the hostinfo lock.
next, prev *HostInfo
//TODO: in, out, and others might benefit from being an atomic.Int32. We could collapse connectionManager pendingDeletion, relayUsed, and in/out into this 1 thing
in, out, pendingDeletion atomic.Bool
// lastUsed tracks the last time ConnectionManager checked the tunnel and it was in use.
// This value will be behind against actual tunnel utilization in the hot path.
// This should only be used by the ConnectionManagers ticker routine.
lastUsed time.Time
}
type ViaSender struct {
@@ -579,7 +568,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
}
for _, addr := range hostinfo.vpnAddrs {
hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
hm.unlockedInnerAddHostInfo(addr, hostinfo)
}
hm.Indexes[hostinfo.localIndexId] = hostinfo
@@ -592,7 +581,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
}
}
func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) {
func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo) {
existing := hm.Hosts[vpnAddr]
hm.Hosts[vpnAddr] = hostinfo
@@ -659,7 +648,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac
// Try to send a test packet to that host, this should
// cause it to detect a roaming event and switch remotes
ifce.sendTo(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
ifce.sendTo(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12), make([]byte, mtu))
})
}
@@ -730,20 +719,26 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
return false
}
func (i *HostInfo) RecvErrorExceeded() bool {
if i.recvError.Add(1) >= maxRecvError {
return true
}
return true
}
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
if len(networks) == 1 && len(unsafeNetworks) == 0 {
// Simple case, no CIDRTree needed
return
}
i.networks = new(bart.Lite)
i.networks = new(bart.Table[struct{}])
for _, network := range networks {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
i.networks.Insert(nprefix)
i.networks.Insert(network, struct{}{})
}
for _, network := range unsafeNetworks {
i.networks.Insert(network)
i.networks.Insert(network, struct{}{})
}
}
@@ -799,7 +794,7 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
}
addr = addr.Unmap()
if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
isAllowed := allowList.Allow(addr)
if l.Level >= logrus.TraceLevel {
l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")

View File

@@ -7,7 +7,6 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHostMap_MakePrimary(t *testing.T) {
@@ -216,31 +215,3 @@ func TestHostMap_reload(t *testing.T) {
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
}
func TestHostMap_RelayState(t *testing.T) {
h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
a1 := netip.MustParseAddr("::1")
a2 := netip.MustParseAddr("2001::1")
h1.relayState.InsertRelayTo(a1)
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
h1.relayState.InsertRelayTo(a2)
assert.Equal(t, []netip.Addr{a1, a2}, h1.relayState.relays)
// Ensure that the first relay added is the first one returned in the copy
currentRelays := h1.relayState.CopyRelayIps()
require.Len(t, currentRelays, 2)
assert.Equal(t, a1, currentRelays[0])
// Deleting the last one in the list works ok
h1.relayState.DeleteRelay(a2)
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
// Deleting an element not in the list works ok
h1.relayState.DeleteRelay(a2)
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
// Deleting the only element in the list works ok
h1.relayState.DeleteRelay(a1)
assert.Equal(t, []netip.Addr{}, h1.relayState.relays)
}

View File

@@ -22,18 +22,21 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
// Ignore local broadcast packets
if f.dropLocalBroadcast {
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
if found {
return
}
}
if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
if found {
// Immediately forward packets from self to self.
// This should only happen on Darwin-based and FreeBSD hosts, which
// routes packets from the Nebula addr to the Nebula addr through the Nebula
// TUN device.
if immediatelyForwardToSelf {
if err := f.writeTun(q, packet); err != nil {
_, err := f.readers[q].Write(packet)
if err != nil {
f.l.WithError(err).Error("Failed to forward to tun")
}
}
@@ -90,7 +93,8 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
return
}
if err := f.writeTun(q, out); err != nil {
_, err := f.readers[q].Write(out)
if err != nil {
f.l.WithError(err).Error("Failed to write to tun")
}
}
@@ -126,7 +130,8 @@ func (f *Interface) Handshake(vpnAddr netip.Addr) {
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
if f.myVpnNetworksTable.Contains(vpnAddr) {
_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
if found {
return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
}
@@ -286,7 +291,7 @@ func (f *Interface) SendVia(via *HostInfo,
c := via.ConnectionState.messageCounter.Add(1)
out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c)
f.connectionManager.Out(via)
f.connectionManager.Out(via.localIndexId)
// Authenticate the header and payload, but do not encrypt for this message type.
// The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload.
@@ -354,7 +359,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo)
f.connectionManager.Out(hostinfo.localIndexId)
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
// all our addrs and enable a faster roaming.

View File

@@ -28,12 +28,12 @@ type InterfaceConfig struct {
Outside udp.Conn
Inside overlay.Device
pki *PKI
Cipher string
Firewall *Firewall
ServeDns bool
HandshakeManager *HandshakeManager
lightHouse *LightHouse
connectionManager *connectionManager
checkInterval time.Duration
pendingDeletionInterval time.Duration
DropLocalBroadcast bool
DropMulticast bool
routines int
@@ -47,7 +47,6 @@ type InterfaceConfig struct {
reQueryWait time.Duration
ConntrackCacheTimeout time.Duration
batchSize int
l *logrus.Logger
}
@@ -62,11 +61,11 @@ type Interface struct {
serveDns bool
createTime time.Time
lightHouse *LightHouse
myBroadcastAddrsTable *bart.Lite
myBroadcastAddrsTable *bart.Table[struct{}]
myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
myVpnAddrsTable *bart.Lite
myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate
myVpnNetworksTable *bart.Lite
myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate
dropLocalBroadcast bool
dropMulticast bool
routines int
@@ -85,7 +84,6 @@ type Interface struct {
version string
conntrackCacheTimeout time.Duration
batchSize int
writers []udp.Conn
readers []io.ReadWriteCloser
@@ -112,16 +110,6 @@ type EncWriter interface {
GetCertState() *CertState
}
// BatchReader is an interface for readers that support vectorized packet reading
type BatchReader interface {
BatchRead(buffers [][]byte, sizes []int) (int, error)
}
// BatchWriter is an interface for writers that support vectorized packet writing
type BatchWriter interface {
BatchWrite([][]byte) (int, error)
}
type sendRecvErrorConfig uint8
const (
@@ -169,9 +157,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
if c.Firewall == nil {
return nil, errors.New("no firewall rules")
}
if c.connectionManager == nil {
return nil, errors.New("no connection manager")
}
cs := c.pki.getCertState()
ifce := &Interface{
@@ -196,9 +181,8 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
myVpnAddrsTable: cs.myVpnAddrsTable,
myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
relayManager: c.relayManager,
connectionManager: c.connectionManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout,
batchSize: c.batchSize,
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
messageMetrics: c.MessageMetrics,
@@ -214,7 +198,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait))
ifce.connectionManager.intf = ifce
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)
return ifce, nil
}
@@ -292,20 +276,10 @@ func (f *Interface) listenOut(i int) {
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread()
// Check if reader supports batch operations
if batchReader, ok := reader.(BatchReader); ok {
err := f.listenInBatch(batchReader, i)
if err != nil {
f.l.WithError(err).Error("Fatal error in batch packet reader, exiting goroutine")
}
return
}
// Fall back to single-packet mode
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
nb := make([]byte, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
@@ -316,85 +290,15 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
return
}
f.l.WithError(err).Error("Fatal error while reading outbound packet, exiting goroutine")
return
f.l.WithError(err).Error("Error while reading outbound packet")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
}
}
// listenInBatch handles vectorized packet reading for improved performance
func (f *Interface) listenInBatch(reader BatchReader, i int) error {
// Allocate per-packet state and buffers for batch reading
batchSize := f.batchSize
if batchSize <= 0 {
batchSize = 64 // Fallback to default if not configured
}
fwPackets := make([]*firewall.Packet, batchSize)
outBuffers := make([][]byte, batchSize)
nbBuffers := make([][]byte, batchSize)
packets := make([][]byte, batchSize)
sizes := make([]int, batchSize)
for j := 0; j < batchSize; j++ {
fwPackets[j] = &firewall.Packet{}
outBuffers[j] = make([]byte, mtu)
nbBuffers[j] = make([]byte, 12)
packets[j] = make([]byte, mtu)
}
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
n, err := reader.BatchRead(packets, sizes)
if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return nil
}
return fmt.Errorf("error while batch reading outbound packets: %w", err)
}
// Process each packet in the batch
cache := conntrackCache.Get(f.l)
for idx := 0; idx < n; idx++ {
if sizes[idx] > 0 {
// Use modulo to reuse fw packet state if batch is larger than our pre-allocated state
stateIdx := idx % len(fwPackets)
f.consumeInsidePacket(packets[idx][:sizes[idx]], fwPackets[stateIdx], nbBuffers[stateIdx], outBuffers[stateIdx], i, cache)
}
}
}
}
// writeTunBatch attempts to write multiple packets to the TUN device using batch operations if supported
func (f *Interface) writeTunBatch(q int, packets [][]byte) error {
if len(packets) == 0 {
return nil
}
// Check if the reader/writer supports batch operations
if batchWriter, ok := f.readers[q].(BatchWriter); ok {
_, err := batchWriter.BatchWrite(packets)
return err
}
// Fall back to writing packets individually
for _, packet := range packets {
if _, err := f.readers[q].Write(packet); err != nil {
return err
}
}
return nil
}
// writeTun writes a single packet to the TUN device
func (f *Interface) writeTun(q int, packet []byte) error {
_, err := f.readers[q].Write(packet)
return err
}
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
c.RegisterReloadCallback(f.reloadFirewall)
c.RegisterReloadCallback(f.reloadSendRecvError)
@@ -418,7 +322,7 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
func (f *Interface) reloadFirewall(c *config.C) {
//TODO: need to trigger/detect if the certificate changed too
if c.HasChanged("firewall") == false {
if !c.HasChanged("firewall") {
f.l.Debug("No firewall config change detected")
return
}
@@ -520,7 +424,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
certState := f.pki.getCertState()
defaultCrt := certState.GetDefaultCertificate()
certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
certExpirationGauge.Update(int64(time.Until(defaultCrt.NotAfter()) / time.Second))
certInitiatingVersion.Update(int64(defaultCrt.Version()))
// Report the max certificate version we are capable of using

View File

@@ -24,7 +24,6 @@ import (
)
var ErrHostNotKnown = errors.New("host not known")
var ErrBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr")
type LightHouse struct {
//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
@@ -33,7 +32,7 @@ type LightHouse struct {
amLighthouse bool
myVpnNetworks []netip.Prefix
myVpnNetworksTable *bart.Lite
myVpnNetworksTable *bart.Table[struct{}]
punchConn udp.Conn
punchy *Punchy
@@ -57,7 +56,7 @@ type LightHouse struct {
// staticList exists to avoid having a bool in each addrMap entry
// since static should be rare
staticList atomic.Pointer[map[netip.Addr]struct{}]
lighthouses atomic.Pointer[[]netip.Addr]
lighthouses atomic.Pointer[map[netip.Addr]struct{}]
interval atomic.Int64
updateCancel context.CancelFunc
@@ -108,7 +107,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
l: l,
}
lighthouses := make([]netip.Addr, 0)
lighthouses := make(map[netip.Addr]struct{})
h.lighthouses.Store(&lighthouses)
staticList := make(map[netip.Addr]struct{})
h.staticList.Store(&staticList)
@@ -144,7 +143,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
return *lh.staticList.Load()
}
func (lh *LightHouse) GetLighthouses() []netip.Addr {
func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
return *lh.lighthouses.Load()
}
@@ -202,7 +201,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
addr := addrs[0].Unmap()
if lh.myVpnNetworksTable.Contains(addr) {
_, found := lh.myVpnNetworksTable.Lookup(addr)
if found {
lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
continue
@@ -307,12 +307,13 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
}
if initial || c.HasChanged("lighthouse.hosts") {
lhList, err := lh.parseLighthouses(c)
lhMap := make(map[netip.Addr]struct{})
err := lh.parseLighthouses(c, lhMap)
if err != nil {
return err
}
lh.lighthouses.Store(&lhList)
lh.lighthouses.Store(&lhMap)
if !initial {
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
lh.l.Info("lighthouse.hosts has changed")
@@ -346,37 +347,37 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return nil
}
func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
lhs := c.GetStringSlice("lighthouse.hosts", []string{})
if lh.amLighthouse && len(lhs) != 0 {
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
}
out := make([]netip.Addr, len(lhs))
for i, host := range lhs {
addr, err := netip.ParseAddr(host)
if err != nil {
return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
}
if !lh.myVpnNetworksTable.Contains(addr) {
return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
_, found := lh.myVpnNetworksTable.Lookup(addr)
if !found {
return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
}
out[i] = addr
lhMap[addr] = struct{}{}
}
if !lh.amLighthouse && len(out) == 0 {
if !lh.amLighthouse && len(lhMap) == 0 {
lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
}
staticList := lh.GetStaticHostList()
for i := range out {
if _, ok := staticList[out[i]]; !ok {
return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
for lhAddr := range lhMap {
if _, ok := staticList[lhAddr]; !ok {
return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr)
}
}
return out, nil
return nil
}
func getStaticMapCadence(c *config.C) (time.Duration, error) {
@@ -430,7 +431,8 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err)
}
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
_, found := lh.myVpnNetworksTable.Lookup(vpnAddr)
if !found {
return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
}
@@ -487,7 +489,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList {
lh.Lock()
defer lh.Unlock()
// Add an entry if we don't already have one
return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip
return lh.unlockedGetRemoteList(vpnAddrs)
}
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
@@ -520,15 +522,11 @@ func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (in
}
func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
// First we check the static host map. If any of the VpnAddrs to be deleted are present, do nothing.
staticList := lh.GetStaticHostList()
for _, addr := range allVpnAddrs {
if _, ok := staticList[addr]; ok {
// First we check the static mapping
// and do nothing if it is there
if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok {
return
}
}
// None of the VpnAddrs were present. Now we can do the deletes.
lh.Lock()
rm, ok := lh.addrMap[allVpnAddrs[0]]
if ok {
@@ -570,7 +568,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
am.unlockedSetHostnamesResults(hr)
for _, addrPort := range hr.GetAddrs() {
if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) {
if !lh.shouldAdd(vpnAddr, addrPort.Addr()) {
continue
}
switch {
@@ -632,41 +630,32 @@ func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool {
return len(calculatedV4) > 0 || len(calculatedV6) > 0
}
// unlockedGetRemoteList assumes you have the lh lock
// unlockedGetRemoteList
// assumes you have the lh lock
func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
// before we go and make a new remotelist, we need to make sure we don't have one for any of this set of vpnaddrs yet
for i, addr := range allAddrs {
am, ok := lh.addrMap[addr]
if ok {
if i != 0 {
lh.addrMap[allAddrs[0]] = am
}
return am
}
}
am := NewRemoteList(allAddrs, lh.shouldAdd)
am, ok := lh.addrMap[allAddrs[0]]
if !ok {
am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) })
for _, addr := range allAddrs {
lh.addrMap[addr] = am
}
}
return am
}
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
allow := lh.GetRemoteAllowList().Allow(vpnAddr, to)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow).
Trace("remoteAllowList.Allow")
}
if !allow {
return false
}
if lh.myVpnNetworksTable.Contains(to) {
return false
}
_, found := lh.myVpnNetworksTable.Lookup(to)
return true
return !found
}
// unlockedShouldAddV4 checks if to is allowed by our allow list
@@ -682,11 +671,8 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo
return false
}
if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) {
return false
}
return true
_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr())
return !found
}
// unlockedShouldAddV6 checks if to is allowed by our allow list
@@ -702,32 +688,27 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
return false
}
if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) {
return false
}
_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr())
return true
return !found
}
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
l := lh.GetLighthouses()
for i := range l {
if l[i] == vpnAddr {
if _, ok := lh.GetLighthouses()[vpnAddr]; ok {
return true
}
}
return false
}
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake
// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool {
l := lh.GetLighthouses()
for i := range vpnAddrs {
for j := range l {
if l[j] == vpnAddrs[i] {
for _, a := range vpnAddr {
if _, ok := l[a]; ok {
return true
}
}
}
return false
}
@@ -737,7 +718,7 @@ func (lh *LightHouse) startQueryWorker() {
}
go func() {
nb := make([]byte, 12, 12)
nb := make([]byte, 12)
out := make([]byte, mtu)
for {
@@ -767,7 +748,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
queried := 0
lighthouses := lh.GetLighthouses()
for _, lhVpnAddr := range lighthouses {
for lhVpnAddr := range lighthouses {
hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil {
v = hi.ConnectionState.myCert.Version()
@@ -865,7 +846,8 @@ func (lh *LightHouse) SendUpdate() {
lal := lh.GetLocalAllowList()
for _, e := range localAddrs(lh.l, lal) {
if lh.myVpnNetworksTable.Contains(e) {
_, found := lh.myVpnNetworksTable.Lookup(e)
if found {
continue
}
@@ -877,7 +859,7 @@ func (lh *LightHouse) SendUpdate() {
}
}
nb := make([]byte, 12, 12)
nb := make([]byte, 12)
out := make([]byte, mtu)
var v1Update, v2Update []byte
@@ -885,7 +867,7 @@ func (lh *LightHouse) SendUpdate() {
updated := 0
lighthouses := lh.GetLighthouses()
for _, lhVpnAddr := range lighthouses {
for lhVpnAddr := range lighthouses {
var v cert.Version
hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil {
@@ -943,6 +925,7 @@ func (lh *LightHouse) SendUpdate() {
V4AddrPorts: v4,
V6AddrPorts: v6,
RelayVpnAddrs: relays,
VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()),
},
}
@@ -978,7 +961,7 @@ type LightHouseHandler struct {
func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
lhh := &LightHouseHandler{
lh: lh,
nb: make([]byte, 12, 12),
nb: make([]byte, 12),
out: make([]byte, mtu),
l: lh.l,
pb: make([]byte, mtu),
@@ -1062,19 +1045,19 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
return
}
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
useVersion := cert.Version1
var queryVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
queryVpnAddr = netip.AddrFrom4(b)
useVersion = 1
} else if n.Details.VpnAddr != nil {
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
useVersion = 2
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
Debugln("Dropping malformed HostQuery")
}
return
}
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
// this case really shouldn't be possible to represent, but reject it anyway.
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
Debugln("invalid vpn addr for v1 handleHostQuery")
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery")
}
return
}
@@ -1083,6 +1066,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply
if useVersion == cert.Version1 {
if !queryVpnAddr.Is4() {
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
}
b := queryVpnAddr.As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
} else {
@@ -1127,9 +1113,8 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
if ok {
whereToPunch = newDest
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
}
//TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee
//choosing to do nothing for now, but maybe we return an error?
}
}
@@ -1167,7 +1152,7 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
if c.v4.learned != nil {
n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.learned)
}
if c.v4.reported != nil && len(c.v4.reported) > 0 {
if len(c.v4.reported) > 0 {
n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.reported...)
}
}
@@ -1176,7 +1161,7 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
if c.v6.learned != nil {
n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.learned)
}
if c.v6.reported != nil && len(c.v6.reported) > 0 {
if len(c.v6.reported) > 0 {
n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.reported...)
}
}
@@ -1188,17 +1173,19 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
if !r.Is4() {
continue
}
b = r.As4()
n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:]))
}
} else if v == cert.Version2 {
for _, r := range c.relay.relay {
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
}
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("version", v).Debug("unsupported protocol version")
}
//TODO: CERT-V2 don't panic
panic("unsupported version")
}
}
}
@@ -1208,16 +1195,18 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
return
}
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
}
return
lhh.lh.Lock()
var certVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
certVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
}
relays := n.Details.GetRelays()
lhh.lh.Lock()
am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr})
am.Lock()
lhh.lh.Unlock()
@@ -1242,24 +1231,27 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
return
}
// not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr
var detailsVpnAddr netip.Addr
var useVersion cert.Version
if n.Details.OldVpnAddr != 0 { //v1 always sets this field
useVersion := cert.Version1
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
detailsVpnAddr = netip.AddrFrom4(b)
useVersion = cert.Version1
} else if n.Details.VpnAddr != nil { //this field is "optional" in v2, but if it's set, we should enforce it
} else if n.Details.VpnAddr != nil {
detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
useVersion = cert.Version2
} else {
detailsVpnAddr = netip.Addr{}
useVersion = cert.Version2
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification")
}
return
}
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
//TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4?
//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right?
//Simple check that the host sent this not someone else
if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
}
@@ -1273,24 +1265,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
am.Lock()
lhh.lh.Unlock()
am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetRelay(fromVpnAddrs[0], relays)
am.Unlock()
n = lhh.resetMeta()
n.Type = NebulaMeta_HostUpdateNotificationAck
switch useVersion {
case cert.Version1:
if useVersion == cert.Version1 {
if !fromVpnAddrs[0].Is4() {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
return
}
vpnAddrB := fromVpnAddrs[0].As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
case cert.Version2:
// do nothing, we want to send a blank message
default:
} else if useVersion == cert.Version2 {
n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0])
} else {
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
return
}
@@ -1308,20 +1300,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
//It's possible the lighthouse is communicating with us using a non primary vpn addr,
//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
//maybe one day we'll have a better idea, if it matters.
if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
return
}
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
}
return
}
empty := []byte{0}
punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
punch := func(vpnPeer netip.AddrPort) {
if !vpnPeer.IsValid() {
return
}
@@ -1333,31 +1318,48 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
}()
if lhh.l.Level >= logrus.DebugLevel {
var logVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
logVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
}
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
}
}
for _, a := range n.Details.V4AddrPorts {
punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
punch(protoV4AddrPortToNetAddrPort(a))
}
for _, a := range n.Details.V6AddrPorts {
punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
punch(protoV6AddrPortToNetAddrPort(a))
}
// This sends a nebula test packet to the host trying to contact us. In the case
// of a double nat or other difficult scenario, this may help establish
// a tunnel.
if lhh.lh.punchy.GetRespond() {
var queryVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
queryVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
}
go func() {
time.Sleep(lhh.lh.punchy.GetRespondDelay())
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr)
}
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
// managed by a channel.
w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12), make([]byte, mtu))
}()
}
}
@@ -1436,17 +1438,3 @@ func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr,
}
return netip.Addr{}, false
}
func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, error) {
if d.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], d.OldVpnAddr)
detailsVpnAddr := netip.AddrFrom4(b)
return detailsVpnAddr, cert.Version1, nil
} else if d.VpnAddr != nil {
detailsVpnAddr := protoAddrToNetAddr(d.VpnAddr)
return detailsVpnAddr, cert.Version2, nil
} else {
return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
}
}

View File

@@ -31,8 +31,8 @@ func TestOldIPv4Only(t *testing.T) {
func Test_lhStaticMapping(t *testing.T) {
l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
@@ -56,8 +56,8 @@ func Test_lhStaticMapping(t *testing.T) {
func TestReloadLighthouseInterval(t *testing.T) {
l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
@@ -91,8 +91,8 @@ func TestReloadLighthouseInterval(t *testing.T) {
func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
@@ -196,8 +196,8 @@ func TestLighthouse_Memory(t *testing.T) {
c.Settings["listen"] = map[string]any{"port": 4242}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
@@ -281,8 +281,8 @@ func TestLighthouse_reload(t *testing.T) {
c.Settings["listen"] = map[string]any{"port": 4242}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
@@ -484,132 +484,12 @@ func Test_findNetworkUnion(t *testing.T) {
assert.Equal(t, out, afe81)
//falsey cases
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1})
_, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1})
_, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81})
_, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
_, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok)
}
func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) {
l := test.NewLogger()
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
testSameHostNotStatic := netip.MustParseAddr("10.128.0.41")
testStaticHost := netip.MustParseAddr("10.128.0.42")
//myVpnIp := netip.MustParseAddr("10.128.0.2")
c := config.NewC(l)
lh1 := "10.128.0.2"
c.Settings["lighthouse"] = map[string]any{
"hosts": []any{lh1},
"interval": "1s",
}
c.Settings["listen"] = map[string]any{"port": 4242}
c.Settings["static_host_map"] = map[string]any{
lh1: []any{"1.1.1.1:4242"},
"10.128.0.42": []any{"1.2.3.4:4242"},
}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
lh.ifce = &mockEncWriter{}
//test that we actually have the static entry:
out := lh.Query(testStaticHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testStaticHost)
out.Rebuild([]netip.Prefix{}) //why tho
assert.Equal(t, out.addrs[0], myUdpAddr2)
//bolt on a lower numbered primary IP
am := lh.unlockedGetRemoteList([]netip.Addr{testStaticHost})
am.vpnAddrs = []netip.Addr{testSameHostNotStatic, testStaticHost}
lh.addrMap[testSameHostNotStatic] = am
out.Rebuild([]netip.Prefix{}) //???
//test that we actually have the static entry:
out = lh.Query(testStaticHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
assert.Equal(t, out.addrs[0], myUdpAddr2)
//test that we actually have the static entry for BOTH:
out2 := lh.Query(testSameHostNotStatic)
assert.Same(t, out2, out)
//now do the delete
lh.DeleteVpnAddrs([]netip.Addr{testSameHostNotStatic, testStaticHost})
//verify
out = lh.Query(testSameHostNotStatic)
assert.NotNil(t, out)
if out == nil {
t.Fatal("expected non-nil query for the static host")
}
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
assert.Equal(t, out.addrs[0], myUdpAddr2)
}
func TestLighthouse_DeletesWork(t *testing.T) {
l := test.NewLogger()
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
testHost := netip.MustParseAddr("10.128.0.42")
c := config.NewC(l)
lh1 := "10.128.0.2"
c.Settings["lighthouse"] = map[string]any{
"hosts": []any{lh1},
"interval": "1s",
}
c.Settings["listen"] = map[string]any{"port": 4242}
c.Settings["static_host_map"] = map[string]any{
lh1: []any{"1.1.1.1:4242"},
}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
lh.ifce = &mockEncWriter{}
//insert the host
am := lh.unlockedGetRemoteList([]netip.Addr{testHost})
am.vpnAddrs = []netip.Addr{testHost}
am.addrs = []netip.AddrPort{myUdpAddr2}
lh.addrMap[testHost] = am
am.Rebuild([]netip.Prefix{}) //???
//test that we actually have the entry:
out := lh.Query(testHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testHost)
out.Rebuild([]netip.Prefix{}) //why tho
assert.Equal(t, out.addrs[0], myUdpAddr2)
//now do the delete
lh.DeleteVpnAddrs([]netip.Addr{testHost})
//verify
out = lh.Query(testHost)
assert.Nil(t, out)
}

10
main.go
View File

@@ -185,7 +185,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
hostMap := NewHostMapFromConfig(l, c)
punchy := NewPunchyFromConfig(l, c)
connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
@@ -221,6 +220,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
}
checkInterval := c.GetInt("timers.connection_alive_interval", 5)
pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
ifConfig := &InterfaceConfig{
HostMap: hostMap,
Inside: tun,
@@ -229,8 +231,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
Firewall: fw,
ServeDns: serveDns,
HandshakeManager: handshakeManager,
connectionManager: connManager,
lightHouse: lightHouse,
checkInterval: time.Second * time.Duration(checkInterval),
pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery),
reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
@@ -241,8 +244,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
version: buildVersion,
relayManager: NewRelayManager(ctx, l, hostMap, c),
punchy: punchy,
ConntrackCacheTimeout: conntrackCacheTimeout,
batchSize: c.GetInt("tun.batch_size", 64),
l: l,
}
@@ -293,6 +296,5 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
statsStart,
dnsStart,
lightHouse.StartUpdateWorker,
connManager.Start,
}, nil
}

View File

@@ -17,7 +17,7 @@ type MessageMetrics struct {
func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) {
if m != nil {
if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) {
if int(t) < len(m.rx) && int(s) < len(m.rx[t]) {
m.rx[t][s].Inc(i)
} else if m.rxUnknown != nil {
m.rxUnknown.Inc(i)
@@ -26,7 +26,7 @@ func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int
}
func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int64) {
if m != nil {
if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) {
if int(t) < len(m.tx) && int(s) < len(m.tx[t]) {
m.tx[t][s].Inc(i)
} else if m.txUnknown != nil {
m.txUnknown.Inc(i)

View File

@@ -31,7 +31,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
//l.Error("in packet ", header, packet[HeaderLen:])
if ip.IsValid() {
if f.myVpnNetworksTable.Contains(ip.Addr()) {
_, found := f.myVpnNetworksTable.Lookup(ip.Addr())
if found {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
}
@@ -81,7 +82,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
// Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, ip)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo)
f.connectionManager.In(hostinfo.localIndexId)
f.connectionManager.RelayUsed(h.RemoteIndex)
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
@@ -213,7 +214,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
f.handleHostRoaming(hostinfo, ip)
f.connectionManager.In(hostinfo)
f.connectionManager.In(hostinfo.localIndexId)
}
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
@@ -227,7 +228,7 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
// sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12), make([]byte, mtu))
}
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {
@@ -254,18 +255,16 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
}
// handleEncrypted returns true if a packet should be processed, false otherwise
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
if ci == nil {
// If connectionstate exists and the replay protector allows, process packet
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
if addr.IsValid() {
f.maybeSendRecvError(addr, h.RemoteIndex)
}
return false
} else {
return false
}
// If the window check fails, refuse to process the packet, but don't send a recv error
if !ci.window.Check(f.l, h.MessageCounter) {
return false
}
return true
@@ -314,11 +313,12 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
offset := ipv6.HeaderLen // Start at the end of the ipv6 header
next := 0
for {
if protoAt >= dataLen {
if dataLen < offset {
break
}
proto := layers.IPProtocol(data[protoAt])
proto := layers.IPProtocol(data[protoAt])
//fmt.Println(proto, protoAt)
switch proto {
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
fp.Protocol = uint8(proto)
@@ -333,13 +333,12 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
}
fp.Protocol = uint8(proto)
ports := data[offset : offset+4]
if incoming {
fp.RemotePort = binary.BigEndian.Uint16(ports[0:2])
fp.LocalPort = binary.BigEndian.Uint16(ports[2:4])
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
} else {
fp.LocalPort = binary.BigEndian.Uint16(ports[0:2])
fp.RemotePort = binary.BigEndian.Uint16(ports[2:4])
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
}
fp.Fragment = false
@@ -367,7 +366,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
case layers.IPProtocolAH:
// Auth headers, used by IPSec, have a different meaning for header length
if dataLen <= offset+1 {
if dataLen < offset+1 {
break
}
@@ -375,7 +374,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
default:
// Normal ipv6 header length processing
if dataLen <= offset+1 {
if dataLen < offset+1 {
break
}
@@ -501,7 +500,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false
}
f.connectionManager.In(hostinfo)
f.connectionManager.In(hostinfo.localIndexId)
_, err = f.readers[q].Write(out)
if err != nil {
f.l.WithError(err).Error("Failed to write to tun")
@@ -540,6 +539,10 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
return
}
if !hostinfo.RecvErrorExceeded() {
return
}
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
return

View File

@@ -117,45 +117,6 @@ func Test_newPacket_v6(t *testing.T) {
err = newPacket(buffer.Bytes(), true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
// A v6 packet with a hop-by-hop extension
// ICMPv6 Payload (Echo Request)
icmpLayer := layers.ICMPv6{
TypeCode: layers.ICMPv6TypeEchoRequest,
}
// Hop-by-Hop Extension Header
hopOption := layers.IPv6HopByHopOption{}
hopOption.OptionData = []byte{0, 0, 0, 0}
hopByHop := layers.IPv6HopByHop{}
hopByHop.Options = append(hopByHop.Options, &hopOption)
ip = layers.IPv6{
Version: 6,
HopLimit: 128,
NextHeader: layers.IPProtocolIPv6Destination,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
buffer.Clear()
err = gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{
ComputeChecksums: false,
FixLengths: true,
}, &ip, &hopByHop, &icmpLayer)
if err != nil {
panic(err)
}
// Ensure buffer length checks during parsing with the next 2 tests.
// A full IPv6 header and 1 byte in the first extension, but missing
// the length byte.
err = newPacket(buffer.Bytes()[:41], true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
// A full IPv6 header plus 1 full extension, but only 1 byte of the
// next layer, missing length byte
err = newPacket(buffer.Bytes()[:49], true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
// A good ICMP packet
ip = layers.IPv6{
Version: 6,
@@ -327,10 +288,6 @@ func Test_newPacket_v6(t *testing.T) {
assert.Equal(t, uint16(22), p.LocalPort)
assert.False(t, p.Fragment)
// Ensure buffer bounds checking during processing
err = newPacket(b[:41], true, p)
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
// Invalid AH header
b = buffer.Bytes()
err = newPacket(b, true, p)

View File

@@ -1,5 +1,5 @@
//go:build darwin && !ios && !e2e_testing
// +build darwin,!ios,!e2e_testing
//go:build !ios && !e2e_testing
// +build !ios,!e2e_testing
package overlay
@@ -7,28 +7,50 @@ import (
"errors"
"fmt"
"io"
"net"
"net/netip"
"os"
"sync/atomic"
"syscall"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
wgtun "golang.zx2c4.com/wireguard/tun"
)
type tun struct {
io.ReadWriteCloser
Device string
vpnNetworks []netip.Prefix
DefaultMTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
}
// ioctl structures for Darwin network configuration
type ifReq struct {
Name [unix.IFNAMSIZ]byte
Flags uint16
pad [8]byte
}
const (
_SIOCAIFADDR_IN6 = 2155899162
_UTUN_OPT_IFNAME = 2
_IN6_IFF_NODAD = 0x0020
_IN6_IFF_SECURED = 0x0400
utunControlName = "com.apple.net.utun_control"
)
type ifreqMTU struct {
Name [16]byte
MTU int32
@@ -58,61 +80,60 @@ type ifreqAlias6 struct {
Lifetime addrLifetime
}
const (
_SIOCAIFADDR_IN6 = 2155899162
_IN6_IFF_NODAD = 0x0020
)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported on Darwin")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
name := c.GetString("tun.dev", "")
deviceName := "utun"
// Parse device name to handle utun[0-9]+ format
if name != "" && name != "utun" {
ifIndex := -1
if name != "" && name != "utun" {
_, err := fmt.Sscanf(name, "utun%d", &ifIndex)
if err != nil || ifIndex < 0 {
// NOTE: we don't make this error so we don't break existing
// configs that set a name before it was used.
l.Warn("interface name must be utun[0-9]+ on Darwin, ignoring")
} else {
deviceName = name
ifIndex = -1
}
}
mtu := c.GetInt("tun.mtu", DefaultMTU)
// Create WireGuard TUN device
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL)
if err != nil {
return nil, fmt.Errorf("failed to create TUN device: %w", err)
return nil, fmt.Errorf("system socket: %v", err)
}
// Get the actual device name
actualName, err := tunDevice.Name()
var ctlInfo = &unix.CtlInfo{}
copy(ctlInfo.Name[:], utunControlName)
err = unix.IoctlCtlInfo(fd, ctlInfo)
if err != nil {
tunDevice.Close()
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
return nil, fmt.Errorf("CTLIOCGINFO: %v", err)
}
t := &wgTun{
tunDevice: tunDevice,
err = unix.Connect(fd, &unix.SockaddrCtl{
ID: ctlInfo.Id,
Unit: uint32(ifIndex) + 1,
})
if err != nil {
return nil, fmt.Errorf("SYS_CONNECT: %v", err)
}
name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME)
if err != nil {
return nil, fmt.Errorf("failed to retrieve tun name: %w", err)
}
err = unix.SetNonblock(fd, true)
if err != nil {
return nil, fmt.Errorf("SetNonblock: %v", err)
}
t := &tun{
ReadWriteCloser: os.NewFile(uintptr(fd), ""),
Device: name,
vpnNetworks: vpnNetworks,
MaxMTU: mtu,
DefaultMTU: mtu,
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
// Create Darwin-specific route manager
t.routeManager = &tun{}
err = t.reload(c, true)
if err != nil {
tunDevice.Close()
return nil, err
}
@@ -123,251 +144,216 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
})
l.WithField("name", actualName).Info("Created WireGuard TUN device")
return t, nil
}
func (rm *tun) Activate(t *wgTun) error {
name, err := t.tunDevice.Name()
if err != nil {
return fmt.Errorf("failed to get device name: %w", err)
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
}
return
}
// Set the MTU
rm.SetMTU(t, t.MaxMTU)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}
// Add IP addresses
for _, network := range t.vpnNetworks {
if err := rm.addIP(t, name, network); err != nil {
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
return t.ReadWriteCloser.Close()
}
return nil
}
func (t *tun) Activate() error {
devName := t.deviceBytes()
s, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
if err != nil {
return err
}
defer unix.Close(s)
fd := uintptr(s)
// Set the MTU on the device
ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
return fmt.Errorf("failed to set tun mtu: %v", err)
}
// Bring up the interface using ioctl
if err := rm.bringUpInterface(name); err != nil {
return fmt.Errorf("failed to bring up interface: %w", err)
// Get the device flags
ifrf := ifReq{Name: devName}
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to get tun flags: %s", err)
}
// Get the link address for routing
linkAddr, err := getLinkAddr(name)
linkAddr, err := getLinkAddr(t.Device)
if err != nil {
return fmt.Errorf("failed to get link address: %w", err)
return err
}
if linkAddr == nil {
return fmt.Errorf("unable to discover link_addr for tun interface")
}
rm.linkAddr = linkAddr
t.linkAddr = linkAddr
// Set the routes
if err := rm.AddRoutes(t, false); err != nil {
for _, network := range t.vpnNetworks {
if network.Addr().Is4() {
err = t.activate4(network)
if err != nil {
return err
}
} else {
err = t.activate6(network)
if err != nil {
return err
}
}
}
// Run the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to run tun device: %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
}
func (t *tun) activate4(network netip.Prefix) error {
s, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
if err != nil {
return err
}
defer unix.Close(s)
ifr := ifreqAlias4{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: network.Addr().As4(),
},
DstAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: network.Addr().As4(),
},
MaskAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(network).As4(),
},
}
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun v4 address: %s", err)
}
err = addRoute(network, t.linkAddr)
if err != nil {
return err
}
return nil
}
func (rm *tun) bringUpInterface(name string) error {
// Open a socket for ioctl
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
func (t *tun) activate6(network netip.Prefix) error {
s, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
if err != nil {
return fmt.Errorf("failed to create socket: %w", err)
}
defer unix.Close(fd)
// Get current flags
var ifrf ifReq
copy(ifrf.Name[:], name)
if err := ioctl(uintptr(fd), unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to get interface flags: %w", err)
}
// Set IFF_UP and IFF_RUNNING flags
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
if err := ioctl(uintptr(fd), unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to set interface flags: %w", err)
}
return nil
}
func (rm *tun) SetMTU(t *wgTun, mtu int) {
name, err := t.tunDevice.Name()
if err != nil {
t.l.WithError(err).Error("Failed to get device name for MTU set")
return
}
// Open a socket for ioctl
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
if err != nil {
t.l.WithError(err).Error("Failed to create socket for MTU set")
return
}
defer unix.Close(fd)
// Prepare the ioctl request
var ifr ifreqMTU
copy(ifr.Name[:], name)
ifr.MTU = int32(mtu)
// Set the MTU using ioctl
if err := ioctl(uintptr(fd), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifr))); err != nil {
t.l.WithError(err).Error("Failed to set tun mtu via ioctl")
}
}
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
// On Darwin, routes are set via ifconfig and route commands
return nil
}
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if !r.Install {
continue
}
err := rm.addRoute(r.Cidr)
if err != nil {
if errors.Is(err, unix.EEXIST) {
t.l.WithField("route", r.Cidr).
Warnf("unable to add unsafe_route, identical route already exists")
} else {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
return nil
}
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
for _, r := range routes {
if !r.Install {
continue
}
err := rm.delRoute(r.Cidr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
}
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
// Darwin doesn't support multi-queue TUN devices in the same way as Linux
// Return a reader that wraps the same device
return &wgTunReader{
parent: t,
tunDevice: t.tunDevice,
offset: 0,
l: t.l,
}, nil
}
func (rm *tun) addIP(t *wgTun, name string, network netip.Prefix) error {
addr := network.Addr()
if addr.Is4() {
return rm.addIPv4(name, network)
} else {
return rm.addIPv6(name, network)
}
}
func (rm *tun) addIPv4(name string, network netip.Prefix) error {
// Open an IPv4 socket for ioctl
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return fmt.Errorf("failed to create IPv4 socket: %w", err)
return err
}
defer unix.Close(s)
var ifr ifreqAlias4
copy(ifr.Name[:], name)
// Set the address
ifr.Addr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: network.Addr().As4(),
}
// Set the destination address (same as address for point-to-point)
ifr.DstAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: network.Addr().As4(),
}
// Set the netmask
ifr.MaskAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(network).As4(),
}
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set IPv4 address via ioctl: %w", err)
}
return nil
}
func (rm *tun) addIPv6(name string, network netip.Prefix) error {
// Open an IPv6 socket for ioctl
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return fmt.Errorf("failed to create IPv6 socket: %w", err)
}
defer unix.Close(s)
var ifr ifreqAlias6
copy(ifr.Name[:], name)
// Set the address
ifr.Addr = unix.RawSockaddrInet6{
ifr := ifreqAlias6{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: network.Addr().As16(),
}
// Set the prefix mask
ifr.PrefixMask = unix.RawSockaddrInet6{
},
PrefixMask: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(network).As16(),
}
// Set lifetime (never expires)
ifr.Lifetime = addrLifetime{
},
Lifetime: addrLifetime{
// never expires
Vltime: 0xffffffff,
Pltime: 0xffffffff,
},
//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
Flags: _IN6_IFF_NODAD,
}
// Set flags (no DAD - Duplicate Address Detection)
ifr.Flags = _IN6_IFF_NODAD
if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set IPv6 address via ioctl: %w", err)
return fmt.Errorf("failed to set tun address: %s", err)
}
return nil
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, ok := t.routeTree.Load().Lookup(ip)
if ok {
return r
}
return routing.Gateways{}
}
// Get the LinkAddr for the interface of the given name
// Is there an easier way to fetch this when we create the interface?
// Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers.
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
if err != nil {
@@ -393,7 +379,53 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
return nil, nil
}
func (rm *tun) addRoute(prefix netip.Prefix) error {
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
err := addRoute(r.Cidr, t.linkAddr)
if err != nil {
if errors.Is(err, unix.EEXIST) {
t.l.WithField("route", r.Cidr).
Warnf("unable to add unsafe_route, identical route already exists")
} else {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
return nil
}
func (t *tun) removeRoutes(routes []Route) error {
for _, r := range routes {
if !r.Install {
continue
}
err := delRoute(r.Cidr, t.linkAddr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
@@ -411,13 +443,13 @@ func (rm *tun) addRoute(prefix netip.Prefix) error {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: rm.linkAddr,
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: rm.linkAddr,
unix.RTAX_GATEWAY: gateway,
}
}
@@ -434,7 +466,7 @@ func (rm *tun) addRoute(prefix netip.Prefix) error {
return nil
}
func (rm *tun) delRoute(prefix netip.Prefix) error {
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
@@ -451,13 +483,13 @@ func (rm *tun) delRoute(prefix netip.Prefix) error {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: rm.linkAddr,
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: rm.linkAddr,
unix.RTAX_GATEWAY: gateway,
}
}
@@ -465,7 +497,6 @@ func (rm *tun) delRoute(prefix netip.Prefix) error {
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
@@ -474,34 +505,62 @@ func (rm *tun) delRoute(prefix netip.Prefix) error {
return nil
}
func ioctl(a1, a2, a3 uintptr) error {
_, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3)
if errno != 0 {
return errno
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)
n, err := t.ReadWriteCloser.Read(buf)
copy(to, buf[4:])
return n - 4, err
}
return nil
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
buf := t.out
if cap(buf) < len(from)+4 {
buf = make([]byte, len(from)+4)
t.out = buf
}
buf = buf[:len(from)+4]
if len(from) == 0 {
return 0, syscall.EIO
}
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
buf[3] = syscall.AF_INET
} else if ipVer == 6 {
buf[3] = syscall.AF_INET6
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
copy(buf[4:], from)
n, err := t.ReadWriteCloser.Write(buf)
return n - 4, err
}
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
}
func prefixToMask(prefix netip.Prefix) netip.Addr {
bits := prefix.Bits()
pLen := 128
if prefix.Addr().Is4() {
// Create IPv4 netmask from prefix length
mask := ^uint32(0) << (32 - bits)
return netip.AddrFrom4([4]byte{
byte(mask >> 24),
byte(mask >> 16),
byte(mask >> 8),
byte(mask),
})
} else {
// Create IPv6 netmask from prefix length
var mask [16]byte
for i := 0; i < bits/8; i++ {
mask[i] = 0xff
}
if bits%8 != 0 {
mask[bits/8] = ^byte(0) << (8 - bits%8)
}
return netip.AddrFrom16(mask)
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}

View File

@@ -1,77 +1,163 @@
//go:build freebsd && !e2e_testing
// +build freebsd,!e2e_testing
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import (
"bytes"
"errors"
"fmt"
"io"
"io/fs"
"net/netip"
"os"
"os/exec"
"strconv"
"strings"
"sync/atomic"
"syscall"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
"golang.org/x/sys/unix"
wgtun "golang.zx2c4.com/wireguard/tun"
)
type tun struct{}
const (
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
FIODGNAME = 0x80106678
)
type fiodgnameArg struct {
length int32
pad [4]byte
buf unsafe.Pointer
}
// ifreqRename is used for renaming network interfaces on FreeBSD
type ifreqRename struct {
Name [unix.IFNAMSIZ]byte
Name [16]byte
Data uintptr
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported on FreeBSD")
type ifreqDestroy struct {
Name [16]byte
pad [16]byte
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
deviceName := c.GetString("tun.dev", "tun")
mtu := c.GetInt("tun.mtu", DefaultMTU)
type tun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
// Create WireGuard TUN device
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
if err != nil {
return nil, fmt.Errorf("failed to create TUN device: %w", err)
io.ReadWriteCloser
}
// Get the actual device name
actualName, err := tunDevice.Name()
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
if err := t.ReadWriteCloser.Close(); err != nil {
return err
}
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
tunDevice.Close()
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
return err
}
defer syscall.Close(s)
ifreq := ifreqDestroy{Name: t.deviceBytes()}
// Destroy the interface
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
return err
}
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device
var file *os.File
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName != "" {
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
}
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
// If the device doesn't already exist, request a new one and rename it
file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0)
}
if err != nil {
return nil, err
}
rawConn, err := file.SyscallConn()
if err != nil {
return nil, fmt.Errorf("SyscallConn: %v", err)
}
var name [16]byte
var ctrlErr error
rawConn.Control(func(fd uintptr) {
// Read the name of the interface
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
})
if ctrlErr != nil {
return nil, err
}
ifName := string(bytes.TrimRight(name[:], "\x00"))
if deviceName == "" {
deviceName = ifName
}
// If the name doesn't match the desired interface name, rename it now
if actualName != deviceName && deviceName != "" && deviceName != "tun" {
if err := renameInterface(actualName, deviceName); err != nil {
tunDevice.Close()
return nil, fmt.Errorf("failed to rename interface from %s to %s: %w", actualName, deviceName, err)
if ifName != deviceName {
s, err := syscall.Socket(
syscall.AF_INET,
syscall.SOCK_DGRAM,
syscall.IPPROTO_IP,
)
if err != nil {
return nil, err
}
actualName = deviceName
defer syscall.Close(s)
fd := uintptr(s)
var fromName [16]byte
var toName [16]byte
copy(fromName[:], ifName)
copy(toName[:], deviceName)
ifrr := ifreqRename{
Name: fromName,
Data: uintptr(unsafe.Pointer(&toName)),
}
t := &wgTun{
tunDevice: tunDevice,
// Set the device name
ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
}
t := &tun{
ReadWriteCloser: file,
Device: deviceName,
vpnNetworks: vpnNetworks,
MaxMTU: mtu,
DefaultMTU: mtu,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
// Create FreeBSD-specific route manager
t.routeManager = &tun{}
err = t.reload(c, true)
if err != nil {
tunDevice.Close()
return nil, err
}
@@ -82,194 +168,141 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
})
l.WithField("name", actualName).Info("Created WireGuard TUN device")
return t, nil
}
func (rm *tun) Activate(t *wgTun) error {
name, err := t.tunDevice.Name()
func (t *tun) addIp(cidr netip.Prefix) error {
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
}
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return fmt.Errorf("failed to get device name: %w", err)
}
// Set the MTU
rm.SetMTU(t, t.MaxMTU)
// Add IP addresses
for _, network := range t.vpnNetworks {
if err := rm.addIP(t, name, network); err != nil {
return err
}
}
// Bring up the interface
if err := runCommandBSD("ifconfig", name, "up"); err != nil {
return fmt.Errorf("failed to bring up interface: %w", err)
return nil
}
// Set the routes
if err := rm.AddRoutes(t, false); err != nil {
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
func (rm *tun) SetMTU(t *wgTun, mtu int) {
name, err := t.tunDevice.Name()
if err != nil {
t.l.WithError(err).Error("Failed to get device name for MTU set")
return
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
if err := runCommandBSD("ifconfig", name, "mtu", strconv.Itoa(mtu)); err != nil {
t.l.WithError(err).Error("Failed to set tun mtu")
}
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
// On FreeBSD, routes are set via ifconfig and route commands
return nil
func (t *tun) Name() string {
return t.Device
}
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
name, err := t.tunDevice.Name()
if err != nil {
return fmt.Errorf("failed to get device name: %w", err)
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if !r.Install {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
// Add route using route command
args := []string{"add"}
if r.Cidr.Addr().Is6() {
args = append(args, "-inet6")
} else {
args = append(args, "-inet")
}
args = append(args, r.Cidr.String(), "-interface", name)
if r.Metric > 0 {
// FreeBSD doesn't support route metrics directly like Linux
t.l.WithField("route", r).Warn("Route metrics are not fully supported on FreeBSD")
}
err := runCommandBSD("route", args...)
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
return nil
}
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
name, err := t.tunDevice.Name()
if err != nil {
t.l.WithError(err).Error("Failed to get device name for route removal")
return
}
func (t *tun) removeRoutes(routes []Route) error {
for _, r := range routes {
if !r.Install {
continue
}
args := []string{"delete"}
if r.Cidr.Addr().Is6() {
args = append(args, "-inet6")
} else {
args = append(args, "-inet")
}
args = append(args, r.Cidr.String(), "-interface", name)
err := runCommandBSD("route", args...)
if err != nil {
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
}
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
// FreeBSD doesn't support multi-queue TUN devices in the same way as Linux
// Return a reader that wraps the same device
return &wgTunReader{
parent: t,
tunDevice: t.tunDevice,
offset: 0,
l: t.l,
}, nil
}
func (rm *tun) addIP(t *wgTun, name string, network netip.Prefix) error {
addr := network.Addr()
if addr.Is4() {
// For IPv4: ifconfig tun0 10.0.0.1/24
if err := runCommandBSD("ifconfig", name, network.String()); err != nil {
return fmt.Errorf("failed to add IPv4 address: %w", err)
}
} else {
// For IPv6: ifconfig tun0 inet6 add 2001:db8::1/64
if err := runCommandBSD("ifconfig", name, "inet6", "add", network.String()); err != nil {
return fmt.Errorf("failed to add IPv6 address: %w", err)
}
}
return nil
}
func runCommandBSD(name string, args ...string) error {
cmd := exec.Command(name, args...)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("%s %s failed: %w\nOutput: %s", name, strings.Join(args, " "), err, string(output))
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
}
return nil
}
func renameInterface(fromName, toName string) error {
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return fmt.Errorf("failed to create socket: %w", err)
}
defer syscall.Close(s)
fd := uintptr(s)
var fromNameBytes [unix.IFNAMSIZ]byte
var toNameBytes [unix.IFNAMSIZ]byte
copy(fromNameBytes[:], fromName)
copy(toNameBytes[:], toName)
ifrr := ifreqRename{
Name: fromNameBytes,
Data: uintptr(unsafe.Pointer(&toNameBytes)),
}
// Set the device name using SIOCSIFNAME ioctl
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
if errno != 0 {
return fmt.Errorf("SIOCSIFNAME ioctl failed: %w", errno)
}
return nil
return
}

View File

@@ -1,5 +1,5 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
//go:build !android && !e2e_testing
// +build !android,!e2e_testing
package overlay
@@ -9,105 +9,131 @@ import (
"net"
"net/netip"
"os"
"strings"
"sync/atomic"
"time"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
wgtun "golang.zx2c4.com/wireguard/tun"
)
type tun struct {
io.ReadWriteCloser
fd int
Device string
vpnNetworks []netip.Prefix
MaxMTU int
DefaultMTU int
TXQueueLen int
deviceIndex int
ioctlFd uintptr
txQueueLen int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
routeChan chan struct{}
useSystemRoutes bool
useSystemRoutesBufferSize int
l *logrus.Logger
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*wgTun, error) {
deviceName := c.GetString("tun.dev", "")
mtu := c.GetInt("tun.mtu", DefaultMTU)
// Create WireGuard TUN device
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
if err != nil {
return nil, fmt.Errorf("failed to create TUN device: %w", err)
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
// Get the actual device name
actualName, err := tunDevice.Name()
if err != nil {
tunDevice.Close()
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
type ifReq struct {
Name [16]byte
Flags uint16
pad [8]byte
}
t := &wgTun{
tunDevice: tunDevice,
vpnNetworks: vpnNetworks,
MaxMTU: mtu,
DefaultMTU: mtu,
l: l,
type ifreqMTU struct {
Name [16]byte
MTU int32
pad [8]byte
}
// Create Linux-specific route manager
routeManager := &tun{
txQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
}
t.routeManager = routeManager
err = t.reload(c, true)
if err != nil {
tunDevice.Close()
return nil, err
type ifreqQLEN struct {
Name [16]byte
Value int32
pad [8]byte
}
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
l.WithField("name", actualName).Info("Created WireGuard TUN device")
return t, nil
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*wgTun, error) {
// Create TUN device from file descriptor
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
mtu := c.GetInt("tun.mtu", DefaultMTU)
tunDevice, err := wgtun.CreateTUNFromFile(file, mtu)
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
return nil, fmt.Errorf("failed to create TUN device from fd: %w", err)
return nil, err
}
t := &wgTun{
tunDevice: tunDevice,
t.Device = "tun0"
return t, nil
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
if os.IsNotExist(err) {
err = os.MkdirAll("/dev/net", 0755)
if err != nil {
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
}
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
if err != nil {
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
}
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
}
} else {
return nil, err
}
}
var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE
}
copy(req.Name[:], c.GetString("tun.dev", ""))
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return nil, err
}
name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
return nil, err
}
t.Device = name
return t, nil
}
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
vpnNetworks: vpnNetworks,
MaxMTU: mtu,
DefaultMTU: mtu,
TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
l: l,
}
// Create Linux-specific route manager
routeManager := &tun{
txQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
}
t.routeManager = routeManager
err = t.reload(c, true)
err := t.reload(c, true)
if err != nil {
tunDevice.Close()
return nil, err
}
@@ -121,105 +147,269 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
return t, nil
}
func (rm *tun) Activate(t *wgTun) error {
name, err := t.tunDevice.Name()
func (t *tun) reload(c *config.C, initial bool) error {
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return fmt.Errorf("failed to get device name: %w", err)
return err
}
if t.routeManager.useSystemRoutes {
if !initial && !routeChange && !c.HasChanged("tun.mtu") {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, true)
if err != nil {
return err
}
oldDefaultMTU := t.DefaultMTU
oldMaxMTU := t.MaxMTU
newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU)
newMaxMTU := newDefaultMTU
for i, r := range routes {
if r.MTU == 0 {
routes[i].MTU = newDefaultMTU
}
if r.MTU > t.MaxMTU {
newMaxMTU = r.MTU
}
}
t.MaxMTU = newMaxMTU
t.DefaultMTU = newDefaultMTU
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
if oldMaxMTU != newMaxMTU {
t.setMTU()
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
}
if oldDefaultMTU != newDefaultMTU {
for i := range t.vpnNetworks {
err := t.setDefaultRoute(t.vpnNetworks[i])
if err != nil {
t.l.Warn(err)
} else {
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
}
}
}
// Remove first, if the system removes a wanted route hopefully it will be re-added next
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// This should never be called since addRoutes should log its own errors in a reload condition
util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l)
}
}
return nil
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
}
var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
copy(req.Name[:], t.Device)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return nil, err
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
return file, nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *tun) Write(b []byte) (int, error) {
var nn int
maximum := len(b)
for {
n, err := unix.Write(t.fd, b[nn:maximum])
if n > 0 {
nn += n
}
if nn == len(b) {
return nn, err
}
if err != nil {
return nn, err
}
if n == 0 {
return nn, io.ErrUnexpectedEOF
}
}
}
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
}
return
}
func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
for i := range al {
if al[i].Equal(x) {
return true
}
}
return false
}
// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
func (t *tun) addIPs(link netlink.Link) error {
newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
for i := range t.vpnNetworks {
newAddrs[i] = &netlink.Addr{
IPNet: &net.IPNet{
IP: t.vpnNetworks[i].Addr().AsSlice(),
Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
},
Label: t.vpnNetworks[i].Addr().Zone(),
}
}
//add all new addresses
for i := range newAddrs {
//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
//AddrReplace still adds new IPs, but if their properties change it will change them as well
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
return err
}
}
//iterate over remainder, remove whoever shouldn't be there
al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
if err != nil {
return fmt.Errorf("failed to get tun address list: %s", err)
}
for i := range al {
if hasNetlinkAddr(newAddrs, al[i]) {
continue
}
err = netlink.AddrDel(link, &al[i])
if err != nil {
t.l.WithError(err).Error("failed to remove address from tun address list")
} else {
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
}
}
return nil
}
func (t *tun) Activate() error {
devName := t.deviceBytes()
if t.useSystemRoutes {
t.watchRoutes()
}
// Get the netlink device
link, err := netlink.LinkByName(name)
if err != nil {
return fmt.Errorf("failed to get tun device link: %s", err)
}
rm.deviceIndex = link.Attrs().Index
// Open socket for ioctl operations
s, err := unix.Socket(
unix.AF_INET,
unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
if err != nil {
return err
}
rm.ioctlFd = uintptr(s)
t.ioctlFd = uintptr(s)
rm.SetMTU(t, t.MaxMTU)
// Set the device name
ifrf := ifReq{Name: devName}
if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to set tun device name: %s", err)
}
link, err := netlink.LinkByName(t.Device)
if err != nil {
return fmt.Errorf("failed to get tun device link: %s", err)
}
t.deviceIndex = link.Attrs().Index
// Setup our default MTU
t.setMTU()
// Set the transmit queue length
devName := deviceBytes(name)
ifrq := ifreqQLEN{Name: devName, Value: int32(rm.txQueueLen)}
if err = ioctl(t.routeManager.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
// If we can't set the queue length nebula will still work but it may lead to packet loss
t.l.WithError(err).Error("Failed to set tun tx queue length")
}
// Disable IPv6 link-local address generation
const modeNone = 1
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
t.l.WithError(err).Warn("Failed to disable link local address generation")
}
// Add IP addresses
if err = t.routeManager.addIPs(t, link); err != nil {
if err = t.addIPs(link); err != nil {
return err
}
// Bring up the interface
if err = netlink.LinkSetUp(link); err != nil {
ifrf.Flags = ifrf.Flags | unix.IFF_UP
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to bring the tun device up: %s", err)
}
// Set route MTU
//set route MTU
for i := range t.vpnNetworks {
if err = t.routeManager.SetDefaultRoute(t, t.vpnNetworks[i]); err != nil {
if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
return fmt.Errorf("failed to set default route MTU: %w", err)
}
}
// Set the routes
if err = t.routeManager.AddRoutes(t, false); err != nil {
if err = t.addRoutes(false); err != nil {
return err
}
// Run the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to run tun device: %s", err)
}
return nil
}
func (rm *tun) SetMTU(t *wgTun, mtu int) {
name, err := t.tunDevice.Name()
if err != nil {
t.l.WithError(err).Error("Failed to get device name for MTU set")
return
}
link, err := netlink.LinkByName(name)
if err != nil {
t.l.WithError(err).Error("Failed to get link for MTU set")
return
}
if err := netlink.LinkSetMTU(link, mtu); err != nil {
func (t *tun) setMTU() {
// Set the MTU on the device
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
t.l.WithError(err).Error("Failed to set tun mtu")
}
}
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
dr := &net.IPNet{
IP: cidr.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
}
nr := netlink.Route{
LinkIndex: t.routeManager.deviceIndex,
LinkIndex: t.deviceIndex,
Dst: dr,
MTU: t.DefaultMTU,
AdvMSS: advMSS(Route{}, t.DefaultMTU, t.MaxMTU),
AdvMSS: t.advMSS(Route{}),
Scope: unix.RT_SCOPE_LINK,
Src: net.IP(cidr.Addr().AsSlice()),
Protocol: unix.RTPROT_KERNEL,
@@ -229,7 +419,7 @@ func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
err := netlink.RouteReplace(&nr)
if err != nil {
t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
// Retry twice more
//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
for i := 0; i < 2; i++ {
time.Sleep(100 * time.Millisecond)
err = netlink.RouteReplace(&nr)
@@ -247,7 +437,8 @@ func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
return nil
}
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
func (t *tun) addRoutes(logErrors bool) error {
// Path routes
routes := *t.Routes.Load()
for _, r := range routes {
if !r.Install {
@@ -260,10 +451,10 @@ func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
}
nr := netlink.Route{
LinkIndex: t.routeManager.deviceIndex,
LinkIndex: t.deviceIndex,
Dst: dr,
MTU: r.MTU,
AdvMSS: advMSS(r, t.DefaultMTU, t.MaxMTU),
AdvMSS: t.advMSS(r),
Scope: unix.RT_SCOPE_LINK,
}
@@ -287,7 +478,7 @@ func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
return nil
}
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
func (t *tun) removeRoutes(routes []Route) {
for _, r := range routes {
if !r.Install {
continue
@@ -299,10 +490,10 @@ func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
}
nr := netlink.Route{
LinkIndex: t.routeManager.deviceIndex,
LinkIndex: t.deviceIndex,
Dst: dr,
MTU: r.MTU,
AdvMSS: advMSS(r, t.DefaultMTU, t.MaxMTU),
AdvMSS: t.advMSS(r),
Scope: unix.RT_SCOPE_LINK,
}
@@ -319,109 +510,28 @@ func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
}
}
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
// For Linux with WireGuard TUN, we can reuse the same device
// The vectorized I/O will handle batching
return &wgTunReader{
parent: t,
tunDevice: t.tunDevice,
offset: 0,
l: t.l,
}, nil
func (t *tun) Name() string {
return t.Device
}
func deviceBytes(name string) [16]byte {
var o [16]byte
for i, c := range name {
if i >= 16 {
break
}
o[i] = byte(c)
}
return o
}
func advMSS(r Route, defaultMTU, maxMTU int) int {
func (t *tun) advMSS(r Route) int {
mtu := r.MTU
if r.MTU == 0 {
mtu = defaultMTU
mtu = t.DefaultMTU
}
// We only need to set advmss if the route MTU does not match the device MTU
if mtu != maxMTU {
if mtu != t.MaxMTU {
return mtu - 40
}
return 0
}
type ifreqQLEN struct {
Name [16]byte
Value int32
pad [8]byte
}
func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
for i := range al {
if al[i].Equal(x) {
return true
}
}
return false
}
func (rm *tun) addIPs(t *wgTun, link netlink.Link) error {
newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
for i := range t.vpnNetworks {
newAddrs[i] = &netlink.Addr{
IPNet: &net.IPNet{
IP: t.vpnNetworks[i].Addr().AsSlice(),
Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
},
Label: t.vpnNetworks[i].Addr().Zone(),
}
}
// Add all new addresses
for i := range newAddrs {
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
return err
}
}
// Iterate over remainder, remove whoever shouldn't be there
al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
if err != nil {
return fmt.Errorf("failed to get tun address list: %s", err)
}
for i := range al {
if hasNetlinkAddr(newAddrs, al[i]) {
continue
}
err = netlink.AddrDel(link, &al[i])
if err != nil {
t.l.WithError(err).Error("failed to remove address from tun address list")
} else {
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
}
}
return nil
}
// watchRoutes monitors system route changes
func (t *wgTun) watchRoutes() {
func (t *tun) watchRoutes() {
rch := make(chan netlink.RouteUpdate)
doneChan := make(chan struct{})
netlinkOptions := netlink.RouteSubscribeOptions{
ReceiveBufferSize: t.routeManager.useSystemRoutesBufferSize,
ReceiveBufferForceSize: t.routeManager.useSystemRoutesBufferSize != 0,
ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") },
}
if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil {
if err := netlink.RouteSubscribe(rch, doneChan); err != nil {
t.l.WithError(err).Errorf("failed to subscribe to system route changes")
return
}
@@ -431,29 +541,86 @@ func (t *wgTun) watchRoutes() {
go func() {
for {
select {
case r, ok := <-rch:
if ok {
case r := <-rch:
t.updateRoutes(r)
} else {
return
}
case <-doneChan:
// netlink.RouteSubscriber will close the rch for us
return
}
}
}()
}
func (t *wgTun) updateRoutes(r netlink.RouteUpdate) {
gateways := t.getGatewaysFromRoute(&r.Route, t.routeManager.deviceIndex)
if len(gateways) == 0 {
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
return
func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
withinNetworks := false
for i := range t.vpnNetworks {
if t.vpnNetworks[i].Contains(gwAddr) {
withinNetworks = true
break
}
}
if r.Dst == nil {
t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
return withinNetworks
}
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
var gateways routing.Gateways
link, err := netlink.LinkByName(t.Device)
if err != nil {
t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
return gateways
}
// If this route is relevant to our interface and there is a gateway then add it
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(r.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
} else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
}
}
}
for _, p := range r.MultiPath {
// If this route is relevant to our interface and there is a gateway then add it
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(p.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
} else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
// p.Hops+1 = weight of the route
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
}
}
}
}
routing.CalculateBucketsForGateways(gateways)
return gateways
}
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
gateways := t.getGatewaysFromRoute(&r.Route)
if len(gateways) == 0 {
// No gateways relevant to our network, no routing changes required.
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
return
}
@@ -471,6 +638,7 @@ func (t *wgTun) updateRoutes(r netlink.RouteUpdate) {
if r.Type == unix.RTM_NEWROUTE {
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
newTree.Insert(dst, gateways)
} else {
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
newTree.Delete(dst)
@@ -478,71 +646,18 @@ func (t *wgTun) updateRoutes(r netlink.RouteUpdate) {
t.routeTree.Store(newTree)
}
func (t *wgTun) getGatewaysFromRoute(r *netlink.Route, deviceIndex int) routing.Gateways {
var gateways routing.Gateways
name, err := t.tunDevice.Name()
if err != nil {
t.l.Error("Ignoring route update: failed to get device name")
return gateways
func (t *tun) Close() error {
if t.routeChan != nil {
close(t.routeChan)
}
link, err := netlink.LinkByName(name)
if err != nil {
t.l.WithField("DeviceName", name).Error("Ignoring route update: failed to get link by name")
return gateways
if t.ReadWriteCloser != nil {
_ = t.ReadWriteCloser.Close()
}
// If this route is relevant to our interface and there is a gateway then add it
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(r.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
} else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
}
}
if t.ioctlFd > 0 {
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
}
for _, p := range r.MultiPath {
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(p.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
} else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
}
}
}
}
routing.CalculateBucketsForGateways(gateways)
return gateways
}
func (t *wgTun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
for i := range t.vpnNetworks {
if t.vpnNetworks[i].Contains(gwAddr) {
return true
}
}
return false
}
func ioctl(a1, a2, a3 uintptr) error {
_, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3)
if errno != 0 {
return errno
}
return nil
}

View File

@@ -7,26 +7,25 @@ import "testing"
var runAdvMSSTests = []struct {
name string
defaultMTU int
maxMTU int
tun *tun
r Route
expected int
}{
// Standard case, default MTU is the device max MTU
{"default", 1440, 1440, Route{}, 0},
{"default-min", 1440, 1440, Route{MTU: 1440}, 0},
{"default-low", 1440, 1440, Route{MTU: 1200}, 1160},
{"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
{"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
{"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},
// Case where we have a route MTU set higher than the default
{"route", 1440, 8941, Route{}, 1400},
{"route-min", 1440, 8941, Route{MTU: 1440}, 1400},
{"route-high", 1440, 8941, Route{MTU: 8941}, 0},
{"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
{"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
{"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
}
func TestTunAdvMSS(t *testing.T) {
for _, tt := range runAdvMSSTests {
t.Run(tt.name, func(t *testing.T) {
o := advMSS(tt.r, tt.defaultMTU, tt.maxMTU)
o := tt.tun.advMSS(tt.r)
if o != tt.expected {
t.Errorf("got %d, want %d", o, tt.expected)
}

View File

@@ -4,12 +4,13 @@
package overlay
import (
"errors"
"fmt"
"io"
"net/netip"
"os"
"os/exec"
"regexp"
"strconv"
"sync/atomic"
"syscall"
"unsafe"
@@ -19,42 +20,11 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
)
const (
SIOCAIFADDR_IN6 = 0x8080696b
TUNSIFHEAD = 0x80047442
TUNSIFMODE = 0x80047458
)
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime addrLifetime
}
type ifreq struct {
Name [unix.IFNAMSIZ]byte
data int
}
type addrLifetime struct {
Expire uint64
Preferred uint64
Vltime uint32
Pltime uint32
type ifreqDestroy struct {
Name [16]byte
pad [16]byte
}
type tun struct {
@@ -64,18 +34,40 @@ type tun struct {
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
f *os.File
fd int
io.ReadWriteCloser
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
if err := t.ReadWriteCloser.Close(); err != nil {
return err
}
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifreq := ifreqDestroy{Name: t.deviceBytes()}
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
return err
}
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var file *os.File
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName == "" {
@@ -85,19 +77,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
if err != nil {
return nil, err
}
err = unix.SetNonblock(fd, true)
if err != nil {
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
}
t := &tun{
f: os.NewFile(uintptr(fd), ""),
fd: fd,
ReadWriteCloser: file,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
@@ -119,225 +105,40 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
return t, nil
}
func (t *tun) Close() error {
if t.f != nil {
if err := t.f.Close(); err != nil {
return fmt.Errorf("error closing tun file: %w", err)
}
// t.f.Close should have handled it for us but let's be extra sure
_ = unix.Close(t.fd)
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifr := ifreq{Name: t.deviceBytes()}
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifr)))
return err
}
return nil
}
func (t *tun) Read(to []byte) (int, error) {
rc, err := t.f.SyscallConn()
if err != nil {
return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
}
var errno syscall.Errno
var n uintptr
err = rc.Read(func(fd uintptr) bool {
// first 4 bytes is protocol family, in network byte order
head := [4]byte{}
iovecs := []syscall.Iovec{
{&head[0], 4},
{&to[0], uint64(len(to))},
}
n, _, errno = syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
if errno.Temporary() {
// We got an EAGAIN, EINTR, or EWOULDBLOCK, go again
return false
}
return true
})
if err != nil {
if err == syscall.EBADF || err.Error() == "use of closed file" {
// Go doesn't export poll.ErrFileClosing but happily reports it to us so here we are
// https://github.com/golang/go/blob/master/src/internal/poll/fd_poll_runtime.go#L121
return 0, os.ErrClosed
}
return 0, fmt.Errorf("failed to make read call for tun: %w", err)
}
if errno != 0 {
return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno)
}
// fix bytes read number to exclude header
bytesRead := int(n)
if bytesRead < 0 {
return bytesRead, nil
} else if bytesRead < 4 {
return 0, nil
} else {
return bytesRead - 4, nil
}
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head [4]byte
// first 4 bytes is protocol family, in network byte order
if ipVer == 4 {
head[3] = syscall.AF_INET
} else if ipVer == 6 {
head[3] = syscall.AF_INET6
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
rc, err := t.f.SyscallConn()
if err != nil {
return 0, err
}
var errno syscall.Errno
var n uintptr
err = rc.Write(func(fd uintptr) bool {
iovecs := []syscall.Iovec{
{&head[0], 4},
{&from[0], uint64(len(from))},
}
n, _, errno = syscall.Syscall(syscall.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
// According to NetBSD documentation for TUN, writes will only return errors in which
// this packet will never be delivered so just go on living life.
return true
})
if err != nil {
return 0, err
}
if errno != 0 {
return 0, errno
}
return int(n) - 4, err
}
func (t *tun) addIp(cidr netip.Prefix) error {
if cidr.Addr().Is4() {
var req ifreqAlias4
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.DstAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.MaskAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
return nil
}
if cidr.Addr().Is6() {
var req ifreqAlias6
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
}
req.PrefixMask = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
}
req.Lifetime = addrLifetime{
Vltime: 0xffffffff,
Pltime: 0xffffffff,
}
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
}
return fmt.Errorf("unknown address type %v", cidr)
}
func (t *tun) Activate() error {
mode := int32(unix.IFF_BROADCAST)
err := ioctl(uintptr(t.fd), TUNSIFMODE, uintptr(unsafe.Pointer(&mode)))
if err != nil {
return fmt.Errorf("failed to set tun device mode: %w", err)
}
v := 1
err = ioctl(uintptr(t.fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&v)))
if err != nil {
return fmt.Errorf("failed to set tun device head: %w", err)
}
err = t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
if err != nil {
return fmt.Errorf("failed to set tun mtu: %w", err)
}
for i := range t.vpnNetworks {
err = t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
}
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
defer syscall.Close(s)
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
return err
}
return nil
}
func (t *tun) reload(c *config.C, initial bool) error {
@@ -396,23 +197,21 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
err := addRoute(r.Cidr, t.vpnNetworks)
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
@@ -425,8 +224,10 @@ func (t *tun) removeRoutes(routes []Route) error {
continue
}
err := delRoute(r.Cidr, t.vpnNetworks)
if err != nil {
//TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
@@ -441,147 +242,3 @@ func (t *tun) deviceBytes() (o [16]byte) {
}
return
}
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func ioctl(a1, a2, a3 uintptr) error {
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, a1, a2, a3)
if errno != 0 {
return errno
}
return nil
}
func prefixToMask(prefix netip.Prefix) netip.Addr {
bits := prefix.Bits()
if prefix.Addr().Is4() {
mask := ^uint32(0) << (32 - bits)
return netip.AddrFrom4([4]byte{
byte(mask >> 24),
byte(mask >> 16),
byte(mask >> 8),
byte(mask),
})
}
var mask [16]byte
for i := 0; i < bits/8; i++ {
mask[i] = 0xff
}
if bits%8 != 0 {
mask[bits/8] = ^byte(0) << (8 - bits%8)
}
return netip.AddrFrom16(mask)
}
func selectGateway(prefix netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
for _, gw := range gateways {
if prefix.Addr().Is4() == gw.Addr().Is4() {
return gw, nil
}
}
return netip.Prefix{}, fmt.Errorf("no suitable gateway found for prefix %v", prefix)
}

14
overlay/tun_notwin.go Normal file
View File

@@ -0,0 +1,14 @@
//go:build !windows
// +build !windows
package overlay
import "syscall"
func ioctl(a1, a2, a3 uintptr) error {
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, a1, a2, a3)
if errno != 0 {
return errno
}
return nil
}

View File

@@ -1,5 +1,5 @@
//go:build openbsd && !e2e_testing
// +build openbsd,!e2e_testing
//go:build !e2e_testing
// +build !e2e_testing
package overlay
@@ -7,53 +7,73 @@ import (
"fmt"
"io"
"net/netip"
"os"
"os/exec"
"regexp"
"strconv"
"strings"
"sync/atomic"
"syscall"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
wgtun "golang.zx2c4.com/wireguard/tun"
)
type tun struct{}
type tun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported on OpenBSD")
io.ReadWriteCloser
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
deviceName := c.GetString("tun.dev", "tun")
mtu := c.GetInt("tun.mtu", DefaultMTU)
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
return t.ReadWriteCloser.Close()
}
// Create WireGuard TUN device
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
deviceName := c.GetString("tun.dev", "")
if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
}
if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
}
file, err := os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
if err != nil {
return nil, fmt.Errorf("failed to create TUN device: %w", err)
return nil, err
}
// Get the actual device name
actualName, err := tunDevice.Name()
if err != nil {
tunDevice.Close()
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
}
t := &wgTun{
tunDevice: tunDevice,
t := &tun{
ReadWriteCloser: file,
Device: deviceName,
vpnNetworks: vpnNetworks,
MaxMTU: mtu,
DefaultMTU: mtu,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
// Create OpenBSD-specific route manager
t.routeManager = &tun{}
err = t.reload(c, true)
if err != nil {
tunDevice.Close()
return nil, err
}
@@ -64,166 +84,172 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
})
l.WithField("name", actualName).Info("Created WireGuard TUN device")
return t, nil
}
func (rm *tun) Activate(t *wgTun) error {
name, err := t.tunDevice.Name()
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return fmt.Errorf("failed to get device name: %w", err)
}
// Set the MTU
rm.SetMTU(t, t.MaxMTU)
// Add IP addresses
for _, network := range t.vpnNetworks {
if err := rm.addIP(t, name, network); err != nil {
return err
}
if !initial && !change {
return nil
}
// Bring up the interface
if err := runCommandBSD("ifconfig", name, "up"); err != nil {
return fmt.Errorf("failed to bring up interface: %w", err)
}
// Set the routes
if err := rm.AddRoutes(t, false); err != nil {
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
func (rm *tun) SetMTU(t *wgTun, mtu int) {
name, err := t.tunDevice.Name()
func (t *tun) addIp(cidr netip.Prefix) error {
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
}
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
t.l.WithError(err).Error("Failed to get device name for MTU set")
return
}
if err := runCommandBSD("ifconfig", name, "mtu", strconv.Itoa(mtu)); err != nil {
t.l.WithError(err).Error("Failed to set tun mtu")
return err
}
}
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
// On OpenBSD, routes are set via ifconfig and route commands
return nil
}
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
name, err := t.tunDevice.Name()
if err != nil {
return fmt.Errorf("failed to get device name: %w", err)
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if !r.Install {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
// Add route using route command
args := []string{"add"}
if r.Cidr.Addr().Is6() {
args = append(args, "-inet6")
} else {
args = append(args, "-inet")
}
args = append(args, r.Cidr.String(), "-interface", name)
if r.Metric > 0 {
// OpenBSD doesn't support route metrics directly like Linux
t.l.WithField("route", r).Warn("Route metrics are not fully supported on OpenBSD")
}
err := runCommandBSD("route", args...)
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
//TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
return nil
}
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
name, err := t.tunDevice.Name()
if err != nil {
t.l.WithError(err).Error("Failed to get device name for route removal")
return
}
func (t *tun) removeRoutes(routes []Route) error {
for _, r := range routes {
if !r.Install {
continue
}
args := []string{"delete"}
if r.Cidr.Addr().Is6() {
args = append(args, "-inet6")
} else {
args = append(args, "-inet")
}
args = append(args, r.Cidr.String(), "-interface", name)
err := runCommandBSD("route", args...)
if err != nil {
//TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
// OpenBSD doesn't support multi-queue TUN devices in the same way as Linux
// Return a reader that wraps the same device
return &wgTunReader{
parent: t,
tunDevice: t.tunDevice,
offset: 0,
l: t.l,
}, nil
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (rm *tun) addIP(t *wgTun, name string, network netip.Prefix) error {
addr := network.Addr()
if addr.Is4() {
// For IPv4: ifconfig tun0 10.0.0.1/24
if err := runCommandBSD("ifconfig", name, network.String()); err != nil {
return fmt.Errorf("failed to add IPv4 address: %w", err)
func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)
n, err := t.ReadWriteCloser.Read(buf)
copy(to, buf[4:])
return n - 4, err
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
buf := t.out
if cap(buf) < len(from)+4 {
buf = make([]byte, len(from)+4)
t.out = buf
}
buf = buf[:len(from)+4]
if len(from) == 0 {
return 0, syscall.EIO
}
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
buf[3] = syscall.AF_INET
} else if ipVer == 6 {
buf[3] = syscall.AF_INET6
} else {
// For IPv6: ifconfig tun0 inet6 add 2001:db8::1/64
if err := runCommandBSD("ifconfig", name, "inet6", "add", network.String()); err != nil {
return fmt.Errorf("failed to add IPv6 address: %w", err)
}
return 0, fmt.Errorf("unable to determine IP version from packet")
}
return nil
}
copy(buf[4:], from)
func runCommandBSD(name string, args ...string) error {
cmd := exec.Command(name, args...)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("%s %s failed: %w\nOutput: %s", name, strings.Join(args, " "), err, string(output))
}
return nil
n, err := t.ReadWriteCloser.Write(buf)
return n - 4, err
}

View File

@@ -1,242 +0,0 @@
//go:build !android && !netbsd && !e2e_testing
// +build !android,!netbsd,!e2e_testing
package overlay
import (
"fmt"
"io"
"net/netip"
"sync/atomic"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
wgtun "golang.zx2c4.com/wireguard/tun"
)
// wgTun wraps a WireGuard TUN device and implements the overlay.Device interface
type wgTun struct {
tunDevice wgtun.Device
vpnNetworks []netip.Prefix
MaxMTU int
DefaultMTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
routeChan chan struct{}
// Platform-specific route management
routeManager *tun
l *logrus.Logger
}
// BatchReader interface for readers that support vectorized I/O
type BatchReader interface {
BatchRead(buffers [][]byte, sizes []int) (int, error)
}
// BatchWriter interface for writers that support vectorized I/O
type BatchWriter interface {
BatchWrite(packets [][]byte) (int, error)
}
// wgTunReader wraps a single TUN queue for multi-queue support
type wgTunReader struct {
parent *wgTun
tunDevice wgtun.Device
offset int
l *logrus.Logger
}
func (t *wgTun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *wgTun) Name() string {
name, err := t.tunDevice.Name()
if err != nil {
t.l.WithError(err).Error("Failed to get TUN device name")
return "unknown"
}
return name
}
func (t *wgTun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *wgTun) Activate() error {
if t.routeManager == nil {
return fmt.Errorf("route manager not initialized")
}
return t.routeManager.Activate(t)
}
// Read implements single-packet read for backward compatibility
func (t *wgTun) Read(b []byte) (int, error) {
bufs := [][]byte{b}
sizes := []int{0}
n, err := t.tunDevice.Read(bufs, sizes, 0)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrNoProgress
}
return sizes[0], nil
}
// Write implements single-packet write for backward compatibility
func (t *wgTun) Write(b []byte) (int, error) {
bufs := [][]byte{b}
offset := 0
// WireGuard TUN expects the packet data to start at offset 0
n, err := t.tunDevice.Write(bufs, offset)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrShortWrite
}
return len(b), nil
}
func (t *wgTun) Close() error {
if t.routeChan != nil {
close(t.routeChan)
}
if t.tunDevice != nil {
return t.tunDevice.Close()
}
return nil
}
func (t *wgTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
// For WireGuard TUN, we need to create separate TUN device instances for multi-queue
// The platform-specific implementation will handle this
if t.routeManager == nil {
return nil, fmt.Errorf("route manager not initialized for multi-queue reader")
}
return t.routeManager.NewMultiQueueReader(t)
}
func (t *wgTun) reload(c *config.C, initial bool) error {
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
if !initial && !routeChange && !c.HasChanged("tun.mtu") {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, true)
if err != nil {
return err
}
oldDefaultMTU := t.DefaultMTU
oldMaxMTU := t.MaxMTU
newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU)
newMaxMTU := newDefaultMTU
for i, r := range routes {
if r.MTU == 0 {
routes[i].MTU = newDefaultMTU
}
if r.MTU > t.MaxMTU {
newMaxMTU = r.MTU
}
}
t.MaxMTU = newMaxMTU
t.DefaultMTU = newDefaultMTU
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial && t.routeManager != nil {
if oldMaxMTU != newMaxMTU {
t.routeManager.SetMTU(t, t.MaxMTU)
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
}
if oldDefaultMTU != newDefaultMTU {
for i := range t.vpnNetworks {
err := t.routeManager.SetDefaultRoute(t, t.vpnNetworks[i])
if err != nil {
t.l.Warn(err)
} else {
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
}
}
}
// Remove first, if the system removes a wanted route hopefully it will be re-added next
t.routeManager.RemoveRoutes(t, findRemovedRoutes(routes, *oldRoutes))
// Ensure any routes we actually want are installed
err = t.routeManager.AddRoutes(t, true)
if err != nil {
// This should never be called since AddRoutes should log its own errors in a reload condition
util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l)
}
}
return nil
}
// BatchRead reads multiple packets from the TUN device using vectorized I/O
// The caller provides buffers and sizes slices, and this function returns the number of packets read.
func (r *wgTunReader) BatchRead(buffers [][]byte, sizes []int) (int, error) {
return r.tunDevice.Read(buffers, sizes, r.offset)
}
// Read implements io.Reader for wgTunReader (single packet for compatibility)
func (r *wgTunReader) Read(b []byte) (int, error) {
bufs := [][]byte{b}
sizes := []int{0}
n, err := r.tunDevice.Read(bufs, sizes, r.offset)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrNoProgress
}
return sizes[0], nil
}
// Write implements io.Writer for wgTunReader
func (r *wgTunReader) Write(b []byte) (int, error) {
bufs := [][]byte{b}
n, err := r.tunDevice.Write(bufs, r.offset)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrShortWrite
}
return len(b), nil
}
// BatchWrite writes multiple packets to the TUN device using vectorized I/O
func (r *wgTunReader) BatchWrite(packets [][]byte) (int, error) {
return r.tunDevice.Write(packets, r.offset)
}
func (r *wgTunReader) Close() error {
if r.tunDevice != nil {
return r.tunDevice.Close()
}
return nil
}

View File

@@ -1,77 +1,84 @@
//go:build windows && !e2e_testing
// +build windows,!e2e_testing
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import (
"crypto"
"encoding/binary"
"fmt"
"io"
"net/netip"
"os"
"path/filepath"
"runtime"
"sync/atomic"
"syscall"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun"
"golang.org/x/sys/windows"
wgtun "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type tun struct {
luid winipcfg.LUID
type winTun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
tun *wintun.NativeTun
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
deviceName := c.GetString("tun.dev", "Nebula")
mtu := c.GetInt("tun.mtu", DefaultMTU)
// Create WireGuard TUN device
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
err := checkWinTunExists()
if err != nil {
return nil, fmt.Errorf("failed to create TUN device: %w", err)
return nil, fmt.Errorf("can not load the wintun driver: %w", err)
}
// Get the actual device name
actualName, err := tunDevice.Name()
deviceName := c.GetString("tun.dev", "")
guid, err := generateGUIDByDeviceName(deviceName)
if err != nil {
tunDevice.Close()
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
return nil, fmt.Errorf("generate GUID failed: %w", err)
}
t := &wgTun{
tunDevice: tunDevice,
t := &winTun{
Device: deviceName,
vpnNetworks: vpnNetworks,
MaxMTU: mtu,
DefaultMTU: mtu,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
// Create Windows-specific route manager
rm := &tun{}
// Get LUID from the TUN device
// The WireGuard TUN device on Windows should provide a LUID() method
if nativeTun, ok := tunDevice.(interface{ LUID() uint64 }); ok {
rm.luid = winipcfg.LUID(nativeTun.LUID())
} else {
tunDevice.Close()
return nil, fmt.Errorf("failed to get LUID from TUN device")
}
t.routeManager = rm
err = t.reload(c, true)
if err != nil {
tunDevice.Close()
return nil, err
}
var tunDevice wintun.Device
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
// Trying a second time resolves the issue.
l.WithError(err).Debug("Failed to create wintun device, retrying")
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
return nil, fmt.Errorf("create TUN device failed: %w", err)
}
}
t.tun = tunDevice.(*wintun.NativeTun)
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
@@ -79,140 +86,206 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
})
l.WithField("name", actualName).Info("Created WireGuard TUN device")
return t, nil
}
func (rm *tun) Activate(t *wgTun) error {
// Set MTU
err := rm.setMTU(t, t.MaxMTU)
func (t *winTun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return fmt.Errorf("failed to set MTU: %w", err)
}
// Add IP addresses
for _, network := range t.vpnNetworks {
if err := rm.addIP(t, network); err != nil {
return err
}
if !initial && !change {
return nil
}
// Add routes
if err := rm.AddRoutes(t, false); err != nil {
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
func (t *winTun) Activate() error {
luid := winipcfg.LUID(t.tun.LUID())
err := luid.SetIPAddresses(t.vpnNetworks)
if err != nil {
return fmt.Errorf("failed to set address: %w", err)
}
err = t.addRoutes(false)
if err != nil {
return err
}
return nil
}
func (rm *tun) SetMTU(t *wgTun, mtu int) {
if err := rm.setMTU(t, mtu); err != nil {
t.l.WithError(err).Error("Failed to set MTU")
}
}
func (rm *tun) setMTU(t *wgTun, mtu int) error {
// Set MTU using winipcfg
// Note: MTU setting on Windows TUN devices may be handled by the driver
// For now, we'll skip explicit MTU setting as the WireGuard TUN handles it
return nil
}
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
// On Windows, routes are managed differently
return nil
}
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
func (t *winTun) addRoutes(logErrors bool) error {
luid := winipcfg.LUID(t.tun.LUID())
routes := *t.Routes.Load()
foundDefault4 := false
for _, r := range routes {
if !r.Install {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
if r.MTU > 0 {
// Windows route MTU is not directly supported
t.l.WithField("route", r).Debug("Route MTU is not supported on Windows")
}
// Use winipcfg to add the route
// The rm.luid should have the AddRoute method from winipcfg
if len(r.Via) == 0 {
t.l.WithField("route", r).Warn("Route has no via address, skipping")
continue
}
err := rm.luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
// Add our unsafe route
// Windows does not support multipath routes natively, so we install only a single route.
// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally.
// In effect this provides multipath routing support to windows supporting loadbalancing and redundancy.
err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
continue
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
if !foundDefault4 {
if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
foundDefault4 = true
}
}
}
ipif, err := luid.IPInterface(windows.AF_INET)
if err != nil {
return fmt.Errorf("failed to get ip interface: %w", err)
}
ipif.NLMTU = uint32(t.MTU)
if foundDefault4 {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
}
if err := ipif.Set(); err != nil {
return fmt.Errorf("failed to set ip interface: %w", err)
}
return nil
}
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
func (t *winTun) removeRoutes(routes []Route) error {
luid := winipcfg.LUID(t.tun.LUID())
for _, r := range routes {
if !r.Install {
continue
}
if len(r.Via) == 0 {
continue
}
err := rm.luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
// See comment on luid.AddRoute
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
}
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
// Windows doesn't support multi-queue TUN devices
// Return a reader that wraps the same device
return &wgTunReader{
parent: t,
tunDevice: t.tunDevice,
offset: 0,
l: t.l,
}, nil
}
func (rm *tun) addIP(t *wgTun, network netip.Prefix) error {
// Add IP address using winipcfg
// SetIPAddresses expects a slice of prefixes
err := rm.luid.SetIPAddresses([]netip.Prefix{network})
if err != nil {
return fmt.Errorf("failed to add IP address %s: %w", network, err)
}
return nil
}
// generateGUIDByDeviceName generates a GUID based on the device name
func generateGUIDByDeviceName(deviceName string) (*windows.GUID, error) {
// Hash the device name to create a deterministic GUID
h := crypto.SHA256.New()
h.Write([]byte(tunGUIDLabel))
h.Write([]byte(deviceName))
sum := h.Sum(nil)
guid := &windows.GUID{
Data1: binary.LittleEndian.Uint32(sum[0:4]),
Data2: binary.LittleEndian.Uint16(sum[4:6]),
Data3: binary.LittleEndian.Uint16(sum[6:8]),
func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
copy(guid.Data4[:], sum[8:16])
return guid, nil
func (t *winTun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *winTun) Name() string {
return t.Device
}
func (t *winTun) Read(b []byte) (int, error) {
return t.tun.Read(b, 0)
}
func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0)
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}
func (t *winTun) Close() error {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
// so to be certain, just remove everything before destroying.
luid := winipcfg.LUID(t.tun.LUID())
_ = luid.FlushRoutes(windows.AF_INET)
_ = luid.FlushIPAddresses(windows.AF_INET)
_ = luid.FlushRoutes(windows.AF_INET6)
_ = luid.FlushIPAddresses(windows.AF_INET6)
_ = luid.FlushDNS(windows.AF_INET)
_ = luid.FlushDNS(windows.AF_INET6)
return t.tun.Close()
}
func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
// GUID is 128 bit
hash := crypto.MD5.New()
_, err := hash.Write([]byte(tunGUIDLabel))
if err != nil {
return nil, err
}
_, err = hash.Write([]byte(name))
if err != nil {
return nil, err
}
sum := hash.Sum(nil)
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
}
func checkWinTunExists() error {
myPath, err := os.Executable()
if err != nil {
return err
}
arch := runtime.GOARCH
switch arch {
case "386":
//NOTE: wintun bundles 386 as x86
arch = "x86"
}
_, err = syscall.LoadDLL(filepath.Join(filepath.Dir(myPath), "dist", "windows", "wintun", "bin", arch, "wintun.dll"))
return err
}

View File

@@ -1,8 +1,6 @@
package pkclient
import (
"crypto/ecdsa"
"crypto/x509"
"fmt"
"io"
"strconv"
@@ -50,27 +48,6 @@ func FromUrl(pkurl string) (*PKClient, error) {
return New(module, uint(slotid), pin, id, label)
}
func ecKeyToArray(key *ecdsa.PublicKey) []byte {
x := make([]byte, 32)
y := make([]byte, 32)
key.X.FillBytes(x)
key.Y.FillBytes(y)
return append([]byte{0x04}, append(x, y...)...)
}
func formatPubkeyFromPublicKeyInfoAttr(d []byte) ([]byte, error) {
e, err := x509.ParsePKIXPublicKey(d)
if err != nil {
return nil, err
}
switch t := e.(type) {
case *ecdsa.PublicKey:
return ecKeyToArray(e.(*ecdsa.PublicKey)), nil
default:
return nil, fmt.Errorf("unknown public key type: %T", t)
}
}
func (c *PKClient) Test() error {
pub, err := c.GetPubKey()
if err != nil {

View File

@@ -3,6 +3,8 @@
package pkclient
import (
"crypto/ecdsa"
"crypto/x509"
"encoding/asn1"
"errors"
"fmt"
@@ -180,7 +182,6 @@ func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) {
pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, NoiseKeySize),
}
// Set up the parameters which include the peer's public key
@@ -228,3 +229,24 @@ func (c *PKClient) GetPubKey() ([]byte, error) {
return nil, fmt.Errorf("unknown public key length: %d", len(d))
}
}
func ecKeyToArray(key *ecdsa.PublicKey) []byte {
x := make([]byte, 32)
y := make([]byte, 32)
key.X.FillBytes(x)
key.Y.FillBytes(y)
return append([]byte{0x04}, append(x, y...)...)
}
func formatPubkeyFromPublicKeyInfoAttr(d []byte) ([]byte, error) {
e, err := x509.ParsePKIXPublicKey(d)
if err != nil {
return nil, err
}
switch t := e.(type) {
case *ecdsa.PublicKey:
return ecKeyToArray(e.(*ecdsa.PublicKey)), nil
default:
return nil, fmt.Errorf("unknown public key type: %T", t)
}
}

View File

@@ -7,10 +7,10 @@ import "errors"
type PKClient struct {
}
var notImplemented = errors.New("not implemented")
var errNotImplemented = errors.New("not implemented")
func New(hsmPath string, slotId uint, pin string, id string, label string) (*PKClient, error) {
return nil, notImplemented
return nil, errNotImplemented
}
func (c *PKClient) Close() error {
@@ -18,13 +18,13 @@ func (c *PKClient) Close() error {
}
func (c *PKClient) SignASN1(data []byte) ([]byte, error) {
return nil, notImplemented
return nil, errNotImplemented
}
func (c *PKClient) DeriveNoise(_ []byte) ([]byte, error) {
return nil, notImplemented
return nil, errNotImplemented
}
func (c *PKClient) GetPubKey() ([]byte, error) {
return nil, notImplemented
return nil, errNotImplemented
}

23
pki.go
View File

@@ -39,10 +39,10 @@ type CertState struct {
cipher string
myVpnNetworks []netip.Prefix
myVpnNetworksTable *bart.Lite
myVpnNetworksTable *bart.Table[struct{}]
myVpnAddrs []netip.Addr
myVpnAddrsTable *bart.Lite
myVpnBroadcastAddrsTable *bart.Lite
myVpnAddrsTable *bart.Table[struct{}]
myVpnBroadcastAddrsTable *bart.Table[struct{}]
}
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
@@ -173,6 +173,7 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
p.cs.Store(newState)
//TODO: CERT-V2 newState needs a stringer that does json
if initial {
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
} else {
@@ -344,9 +345,9 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
cs := CertState{
privateKey: privateKey,
pkcs11Backed: pkcs11backed,
myVpnNetworksTable: new(bart.Lite),
myVpnAddrsTable: new(bart.Lite),
myVpnBroadcastAddrsTable: new(bart.Lite),
myVpnNetworksTable: new(bart.Table[struct{}]),
myVpnAddrsTable: new(bart.Table[struct{}]),
myVpnBroadcastAddrsTable: new(bart.Table[struct{}]),
}
if v1 != nil && v2 != nil {
@@ -358,9 +359,7 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
}
if v1.Networks()[0] != v2.Networks()[0] {
return nil, util.NewContextualError("v1 and v2 networks are not the same", nil, nil)
}
//TODO: CERT-V2 make sure v2 has v1s address
cs.initiatingVersion = dv
}
@@ -416,16 +415,16 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
for _, network := range crt.Networks() {
cs.myVpnNetworks = append(cs.myVpnNetworks, network)
cs.myVpnNetworksTable.Insert(network)
cs.myVpnNetworksTable.Insert(network, struct{}{})
cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()))
cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{})
if network.Addr().Is4() {
addr := network.Masked().Addr().As4()
mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()))
cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{})
}
}

View File

@@ -241,13 +241,15 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
logMsg.Info("handleCreateRelayRequest")
// Is the source of the relay me? This should never happen, but did happen due to
// an issue migrating relays over to newly re-handshaked host info objects.
if f.myVpnAddrsTable.Contains(from) {
_, found := f.myVpnAddrsTable.Lookup(from)
if found {
logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
return
}
// Is the target of the relay me?
if f.myVpnAddrsTable.Contains(target) {
_, found = f.myVpnAddrsTable.Lookup(target)
if found {
existingRelay, ok := h.relayState.QueryRelayForByIp(from)
if ok {
switch existingRelay.State {

View File

@@ -190,7 +190,7 @@ type RemoteList struct {
// The full list of vpn addresses assigned to this host
vpnAddrs []netip.Addr
// A deduplicated set of underlay addresses. Any accessor should lock beforehand.
// A deduplicated set of addresses. Any accessor should lock beforehand.
addrs []netip.AddrPort
// A set of relay addresses. VpnIp addresses that the remote identified as relays.
@@ -202,9 +202,7 @@ type RemoteList struct {
cache map[netip.Addr]*cache
hr *hostnamesResults
// shouldAdd is a nillable function that decides if x should be added to addrs.
shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool
shouldAdd func(netip.Addr) bool
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake
@@ -215,7 +213,7 @@ type RemoteList struct {
}
// NewRemoteList creates a new empty RemoteList
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList {
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
r := &RemoteList{
vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
addrs: make([]netip.AddrPort, 0),
@@ -265,9 +263,7 @@ func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort
r.RLock()
defer r.RUnlock()
c := make([]netip.AddrPort, len(r.addrs))
for i, v := range r.addrs {
c[i] = v
}
copy(c, r.addrs)
return c
}
@@ -328,9 +324,7 @@ func (r *RemoteList) CopyCache() *CacheMap {
}
if mc.relay != nil {
for _, a := range mc.relay.relay {
c.Relay = append(c.Relay, a)
}
c.Relay = append(c.Relay, mc.relay.relay...)
}
}
@@ -364,21 +358,10 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
defer r.RUnlock()
c := make([]netip.AddrPort, len(r.badRemotes))
for i, v := range r.badRemotes {
c[i] = v
}
copy(c, r.badRemotes)
return c
}
// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake
func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) {
r.Lock()
r.badRemotes = nil
r.vpnAddrs = make([]netip.Addr, len(vpnAddrs))
copy(r.vpnAddrs, vpnAddrs)
r.Unlock()
}
// ResetBlockedRemotes locks and clears the blocked remotes list
func (r *RemoteList) ResetBlockedRemotes() {
r.Lock()
@@ -580,15 +563,13 @@ func (r *RemoteList) unlockedCollect() {
}
if c.relay != nil {
for _, v := range c.relay.relay {
relays = append(relays, v)
}
relays = append(relays, c.relay.relay...)
}
}
dnsAddrs := r.hr.GetAddrs()
for _, addr := range dnsAddrs {
if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) {
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
if !r.unlockedIsBad(addr) {
addrs = append(addrs, addr)
}
@@ -646,15 +627,15 @@ func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) {
a4 := a.Addr().Is4()
b4 := b.Addr().Is4()
switch {
case a4 == false && b4 == true:
case !a4 && b4:
// If i is v6 and j is v4, i is less than j
return true
case a4 == true && b4 == false:
case a4 && !b4:
// If j is v6 and i is v4, i is not less than j
return false
case a4 == true && b4 == true:
case a4 && b4:
// i and j are both ipv4
aPrivate := a.Addr().IsPrivate()
bPrivate := b.Addr().IsPrivate()
@@ -702,7 +683,6 @@ func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) {
}
r.addrs = r.addrs[:a+1]
return
}
// minInt returns the minimum integer of a or b

View File

@@ -9,10 +9,13 @@ import (
"math"
"net"
"net/netip"
"os"
"strings"
"sync"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay"
"golang.org/x/sync/errgroup"
"gvisor.dev/gvisor/pkg/buffer"
@@ -43,7 +46,14 @@ type Service struct {
}
}
func New(control *nebula.Control) (*Service, error) {
func New(config *config.C) (*Service, error) {
logger := logrus.New()
logger.Out = os.Stdout
control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil {
return nil, err
}
control.Start()
ctx := control.Context()

View File

@@ -5,17 +5,13 @@ import (
"context"
"errors"
"net/netip"
"os"
"testing"
"time"
"dario.cat/mergo"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay"
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v3"
)
@@ -75,15 +71,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n
panic(err)
}
logger := logrus.New()
logger.Out = os.Stdout
control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil {
panic(err)
}
s, err := New(control)
s, err := New(&c)
if err != nil {
panic(err)
}

20
ssh.go
View File

@@ -527,11 +527,11 @@ func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error {
return err
}
func sshVersion(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
return w.WriteLine(fmt.Sprintf("%s", ifce.version))
func sshVersion(ifce *Interface, _ any, _ []string, w sshd.StringWriter) error {
return w.WriteLine(ifce.version)
}
func sshQueryLighthouse(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
func sshQueryLighthouse(ifce *Interface, _ any, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine("No vpn address was provided")
}
@@ -584,7 +584,7 @@ func sshCloseTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) er
hostInfo.ConnectionState,
hostInfo,
[]byte{},
make([]byte, 12, 12),
make([]byte, 12),
make([]byte, mtu),
)
}
@@ -614,12 +614,12 @@ func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) e
hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
return w.WriteLine("Tunnel already exists")
}
hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnAddr)
if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
return w.WriteLine("Tunnel already handshaking")
}
var addr netip.AddrPort
@@ -735,7 +735,7 @@ func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error {
return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a))
}
func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
func sshLogLevel(l *logrus.Logger, _ any, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
}
@@ -749,7 +749,7 @@ func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) erro
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
}
func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
func sshLogFormat(l *logrus.Logger, _ any, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
}
@@ -822,10 +822,10 @@ func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) erro
return w.WriteLine(cert.String())
}
func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
func sshPrintRelays(ifce *Interface, fs any, _ []string, w sshd.StringWriter) error {
args, ok := fs.(*sshPrintTunnelFlags)
if !ok {
w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type"))
w.WriteLine("sshPrintRelays failed to convert args type")
return nil
}

View File

@@ -23,7 +23,6 @@ type SSHServer struct {
trustedCAs []ssh.PublicKey
// List of available commands
helpCommand *Command
commands *radix.Tree
listener net.Listener
@@ -43,7 +42,7 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
conns: make(map[int]*session),
}
cc := ssh.CertChecker{
cc := &ssh.CertChecker{
IsUserAuthority: func(auth ssh.PublicKey) bool {
for _, ca := range s.trustedCAs {
if bytes.Equal(ca.Marshal(), auth.Marshal()) {
@@ -77,10 +76,11 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
},
}
s.certChecker = cc
s.config = &ssh.ServerConfig{
PublicKeyCallback: cc.Authenticate,
ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"),
ServerVersion: "SSH-2.0-Nebula???",
}
s.RegisterCommand(&Command{

View File

@@ -170,7 +170,6 @@ func (s *session) dispatchCommand(line string, w StringWriter) {
}
_ = execCommand(c, args[1:], w)
return
}
func (s *session) Close() {

View File

@@ -30,15 +30,11 @@ func (NoopConn) Rebind() error {
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return netip.AddrPort{}, nil
}
func (NoopConn) ListenOut(_ EncReader) {
return
}
func (NoopConn) ListenOut(_ EncReader) {}
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil
}
func (NoopConn) ReloadConfig(_ *config.C) {
return
}
func (NoopConn) ReloadConfig(_ *config.C) {}
func (NoopConn) Close() error {
return nil
}

View File

@@ -1,5 +0,0 @@
package udp
import "errors"
var ErrInvalidIPv6RemoteForSocket = errors.New("listener is IPv4, but writing to IPv6 remote")

View File

@@ -3,62 +3,20 @@
package udp
// Darwin support is primarily implemented in udp_generic, besides NewListenConfig
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"syscall"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"golang.org/x/sys/unix"
)
type StdConn struct {
*net.UDPConn
isV4 bool
sysFd uintptr
l *logrus.Logger
}
var _ Conn = &StdConn{}
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
if err != nil {
return nil, err
}
if uc, ok := pc.(*net.UDPConn); ok {
c := &StdConn{UDPConn: uc, l: l}
rc, err := uc.SyscallConn()
if err != nil {
return nil, fmt.Errorf("failed to open udp socket: %w", err)
}
err = rc.Control(func(fd uintptr) {
c.sysFd = fd
})
if err != nil {
return nil, fmt.Errorf("failed to get udp fd: %w", err)
}
la, err := c.LocalAddr()
if err != nil {
return nil, err
}
c.isV4 = la.Addr().Is4()
return c, nil
}
return nil, fmt.Errorf("unexpected PacketConn: %T %#v", pc, pc)
return NewGenericListener(l, ip, port, multi, batch)
}
func NewListenConfig(multi bool) net.ListenConfig {
@@ -85,116 +43,16 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
//go:linkname sendto golang.org/x/sys/unix.sendto
//go:noescape
func sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen int32) (err error)
func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
var sa unsafe.Pointer
var addrLen int32
if u.isV4 {
if ap.Addr().Is6() {
return ErrInvalidIPv6RemoteForSocket
}
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
rsa.Addr = ap.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
sa = unsafe.Pointer(&rsa)
addrLen = syscall.SizeofSockaddrInet4
} else {
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
rsa.Addr = ap.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
sa = unsafe.Pointer(&rsa)
addrLen = syscall.SizeofSockaddrInet6
}
// Golang stdlib doesn't handle EAGAIN correctly in some situations so we do writes ourselves
// See https://github.com/golang/go/issues/73919
for {
//_, _, err := unix.Syscall6(unix.SYS_SENDTO, u.sysFd, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), 0, sa, addrLen)
err := sendto(int(u.sysFd), b, 0, sa, addrLen)
if err == nil {
// Written, get out before the error handling
return nil
}
if errors.Is(err, syscall.EINTR) {
// Write was interrupted, retry
continue
}
if errors.Is(err, syscall.EAGAIN) {
return &net.OpError{Op: "sendto", Err: unix.EWOULDBLOCK}
}
if errors.Is(err, syscall.EBADF) {
return net.ErrClosed
}
return &net.OpError{Op: "sendto", Err: err}
}
}
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr()
switch v := a.(type) {
case *net.UDPAddr:
addr, ok := netip.AddrFromSlice(v.IP)
if !ok {
return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP)
}
return netip.AddrPortFrom(addr, uint16(v.Port)), nil
default:
return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a)
}
}
func (u *StdConn) ReloadConfig(c *config.C) {
// TODO
}
func NewUDPStatsEmitter(udpConns []Conn) func() {
// No UDP stats for non-linux
return func() {}
}
func (u *StdConn) ListenOut(r EncReader) {
buffer := make([]byte, MTU)
for {
// Just read one packet at a time
n, rua, err := u.ReadFromUDPAddrPort(buffer)
func (u *GenericConn) Rebind() error {
rc, err := u.UDPConn.SyscallConn()
if err != nil {
if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
u.l.WithError(err).Error("unexpected udp socket receive error")
}
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
}
}
func (u *StdConn) Rebind() error {
var err error
if u.isV4 {
err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, 0)
} else {
err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, 0)
return err
}
return rc.Control(func(fd uintptr) {
err := syscall.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
if err != nil {
u.l.WithError(err).Error("Failed to rebind udp socket")
}
return nil
})
}

View File

@@ -1,7 +1,6 @@
//go:build (!linux || android) && !e2e_testing && !darwin
//go:build (!linux || android) && !e2e_testing
// +build !linux android
// +build !e2e_testing
// +build !darwin
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
// means it can be used on platforms like Darwin and Windows.
@@ -34,7 +33,7 @@ func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, b
if uc, ok := pc.(*net.UDPConn); ok {
return &GenericConn{UDPConn: uc, l: l}, nil
}
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
return nil, fmt.Errorf("unexpected PacketConn: %T %#v", pc, pc)
}
func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
@@ -67,10 +66,6 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
return func() {}
}
type rawMessage struct {
Len uint32
}
func (u *GenericConn) ListenOut(r EncReader) {
buffer := make([]byte, MTU)

View File

@@ -221,7 +221,7 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
if !ip.Addr().Is4() {
return ErrInvalidIPv6RemoteForSocket
return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
}
var rsa unix.RawSockaddrInet4

View File

@@ -92,25 +92,6 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
// Enable v4 for this socket
syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
// Disable reporting of PORT_UNREACHABLE and NET_UNREACHABLE errors from the UDP socket receive call.
// These errors are returned on Windows during UDP receives based on the receipt of ICMP packets. Disable
// the UDP receive error returns with these ioctl calls.
ret := uint32(0)
flag := uint32(0)
size := uint32(unsafe.Sizeof(flag))
err = syscall.WSAIoctl(syscall.Handle(u.sock), syscall.SIO_UDP_CONNRESET, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, nil, 0)
if err != nil {
return err
}
ret = 0
flag = 0
size = uint32(unsafe.Sizeof(flag))
SIO_UDP_NETRESET := uint32(syscall.IOC_IN | syscall.IOC_VENDOR | 15)
err = syscall.WSAIoctl(syscall.Handle(u.sock), SIO_UDP_NETRESET, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, nil, 0)
if err != nil {
return err
}
err = u.rx.Open()
if err != nil {
return err
@@ -141,13 +122,9 @@ func (u *RIOConn) ListenOut(r EncReader) {
// Just read one packet at a time
n, rua, err := u.receive(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
u.l.WithError(err).Error("unexpected udp socket receive error")
continue
}
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
}