Compare commits

..

2 Commits

Author SHA1 Message Date
Nate Brown
0681195da3 Fix testify lint errors 2025-03-13 00:43:08 -05:00
Nate Brown
2c9cc63c1a PSK support for v2 2025-03-12 23:25:28 -05:00
109 changed files with 1803 additions and 4180 deletions

View File

@@ -1,21 +1,13 @@
blank_issues_enabled: true
contact_links:
- name: 💨 Performance Issues
url: https://github.com/slackhq/nebula/discussions/new/choose
about: 'We ask that you create a discussion instead of an issue for performance-related questions. This allows us to have a more open conversation about the issue and helps us to better understand the problem.'
- name: 📄 Documentation Issues
url: https://github.com/definednet/nebula-docs
about: "If you've found an issue with the website documentation, please file it in the nebula-docs repository."
- name: 📱 Mobile Nebula Issues
url: https://github.com/definednet/mobile_nebula
about: "If you're using the mobile Nebula app and have found an issue, please file it in the mobile_nebula repository."
- name: 📘 Documentation
url: https://nebula.defined.net/docs/
about: 'The documentation is the best place to start if you are new to Nebula.'
about: Review documentation.
- name: 💁 Support/Chat
url: https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA
about: 'For faster support, join us on Slack for assistance!'
url: https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU
about: 'This issue tracker is not for support questions. Join us on Slack for assistance!'
- name: 📱 Mobile Nebula
url: https://github.com/definednet/mobile_nebula
about: 'This issue tracker is not for mobile support. Try the Mobile Nebula repo instead!'

View File

@@ -1,11 +0,0 @@
<!--
Thank you for taking the time to submit a pull request!
Please be sure to provide a clear description of what you're trying to achieve with the change.
- If you're submitting a new feature, please explain how to use it and document any new config options in the example config.
- If you're submitting a bugfix, please link the related issue or describe the circumstances surrounding the issue.
- If you're changing a default, explain why you believe the new default is appropriate for most users.
P.S. If you're only updating the README or other docs, please file a pull request here instead: https://github.com/DefinedNet/nebula-docs
-->

View File

@@ -16,9 +16,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.23'
check-latest: true
- name: Install goimports

View File

@@ -12,9 +12,9 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.23'
check-latest: true
- name: Build
@@ -35,9 +35,9 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.23'
check-latest: true
- name: Build
@@ -68,14 +68,14 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.23'
check-latest: true
- name: Import certificates
if: env.HAS_SIGNING_CREDS == 'true'
uses: Apple-Actions/import-codesign-certs@v5
uses: Apple-Actions/import-codesign-certs@v3
with:
p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}

View File

@@ -22,9 +22,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version-file: 'go.mod'
check-latest: true
- name: add hashicorp source

View File

@@ -20,9 +20,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.23'
check-latest: true
- name: build

View File

@@ -20,9 +20,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.23'
check-latest: true
- name: Build
@@ -32,9 +32,9 @@ jobs:
run: make vet
- name: golangci-lint
uses: golangci/golangci-lint-action@v8
uses: golangci/golangci-lint-action@v6
with:
version: v2.5
version: v1.64
- name: Test
run: make test
@@ -58,9 +58,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.23'
check-latest: true
- name: Build
@@ -79,9 +79,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.22'
check-latest: true
- name: Build
@@ -100,9 +100,9 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
- uses: actions/setup-go@v5
with:
go-version: '1.25'
go-version: '1.23'
check-latest: true
- name: Build nebula
@@ -115,9 +115,9 @@ jobs:
run: make vet
- name: golangci-lint
uses: golangci/golangci-lint-action@v8
uses: golangci/golangci-lint-action@v6
with:
version: v2.5
version: v1.64
- name: Test
run: make test

View File

@@ -1,23 +1,9 @@
version: "2"
# yaml-language-server: $schema=https://golangci-lint.run/jsonschema/golangci.jsonschema.json
linters:
default: none
# Disable all linters.
# Default: false
disable-all: true
# Enable specific linter
# https://golangci-lint.run/usage/linters/#enabled-by-default
enable:
- testifylint
exclusions:
generated: lax
presets:
- comments
- common-false-positives
- legacy
- std-error-handling
paths:
- third_party$
- builtin$
- examples$
formatters:
exclusions:
generated: lax
paths:
- third_party$
- builtin$
- examples$

View File

@@ -7,13 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Changed
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
intended to target an `unsafe_routes` entry must explicitly declare it via the
`local_cidr` field. This is almost always the intended behavior. This flag is
deprecated and will be removed in a future release.
## [1.9.4] - 2024-09-09
### Added

View File

@@ -4,7 +4,7 @@ It lets you seamlessly connect computers anywhere in the world. Nebula is portab
It can be used to connect a small number of computers, but is also able to connect tens of thousands of computers.
Nebula incorporates a number of existing concepts like encryption, security groups, certificates,
and tunneling.
and tunneling, and each of those individual pieces existed before Nebula in various forms.
What makes Nebula different to existing offerings is that it brings all of these ideas together,
resulting in a sum that is greater than its individual parts.
@@ -12,7 +12,7 @@ Further documentation can be found [here](https://nebula.defined.net/docs/).
You can read more about Nebula [here](https://medium.com/p/884110a5579).
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA).
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ).
## Supported Platforms
@@ -28,33 +28,33 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
#### Distribution Packages
- [Arch Linux](https://archlinux.org/packages/extra/x86_64/nebula/)
```sh
sudo pacman -S nebula
```
$ sudo pacman -S nebula
```
- [Fedora Linux](https://src.fedoraproject.org/rpms/nebula)
```sh
sudo dnf install nebula
```
$ sudo dnf install nebula
```
- [Debian Linux](https://packages.debian.org/source/stable/nebula)
```sh
sudo apt install nebula
```
$ sudo apt install nebula
```
- [Alpine Linux](https://pkgs.alpinelinux.org/packages?name=nebula)
```sh
sudo apk add nebula
```
$ sudo apk add nebula
```
- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/n/nebula.rb)
```sh
brew install nebula
```
$ brew install nebula
```
- [Docker](https://hub.docker.com/r/nebulaoss/nebula)
```sh
docker pull nebulaoss/nebula
```
$ docker pull nebulaoss/nebula
```
#### Mobile
@@ -64,10 +64,10 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
## Technical Overview
Nebula is a mutually authenticated peer-to-peer software-defined network based on the [Noise Protocol Framework](https://noiseprotocol.org/).
Nebula is a mutually authenticated peer-to-peer software defined network based on the [Noise Protocol Framework](https://noiseprotocol.org/).
Nebula uses certificates to assert a node's IP address, name, and membership within user-defined groups.
Nebula's user-defined groups allow for provider agnostic traffic filtering between nodes.
Discovery nodes (aka lighthouses) allow individual peers to find each other and optionally use UDP hole punching to establish connections from behind most firewalls or NATs.
Discovery nodes allow individual peers to find each other and optionally use UDP hole punching to establish connections from behind most firewalls or NATs.
Users can move data between nodes in any number of cloud service providers, datacenters, and endpoints, without needing to maintain a particular addressing scheme.
Nebula uses Elliptic-curve Diffie-Hellman (`ECDH`) key exchange and `AES-256-GCM` in its default configuration.
@@ -82,34 +82,28 @@ To set up a Nebula network, you'll need:
#### 2. (Optional, but you really should..) At least one discovery node with a routable IP address, which we call a lighthouse.
Nebula lighthouses allow nodes to find each other, anywhere in the world. A lighthouse is the only node in a Nebula network whose IP should not change. Running a lighthouse requires very few compute resources, and you can easily use the least expensive option from a cloud hosting provider. If you're not sure which provider to use, a number of us have used $6/mo [DigitalOcean](https://digitalocean.com) droplets as lighthouses.
Nebula lighthouses allow nodes to find each other, anywhere in the world. A lighthouse is the only node in a Nebula network whose IP should not change. Running a lighthouse requires very few compute resources, and you can easily use the least expensive option from a cloud hosting provider. If you're not sure which provider to use, a number of us have used $5/mo [DigitalOcean](https://digitalocean.com) droplets as lighthouses.
Once you have launched an instance, ensure that Nebula udp traffic (default port udp/4242) can reach it over the internet.
Once you have launched an instance, ensure that Nebula udp traffic (default port udp/4242) can reach it over the internet.
#### 3. A Nebula certificate authority, which will be the root of trust for a particular Nebula network.
```sh
./nebula-cert ca -name "Myorganization, Inc"
```
This will create files named `ca.key` and `ca.cert` in the current directory. The `ca.key` file is the most sensitive file you'll create, because it is the key used to sign the certificates for individual nebula nodes/hosts. Please store this file somewhere safe, preferably with strong encryption.
**Be aware!** By default, certificate authorities have a 1-year lifetime before expiration. See [this guide](https://nebula.defined.net/docs/guides/rotating-certificate-authority/) for details on rotating a CA.
```
./nebula-cert ca -name "Myorganization, Inc"
```
This will create files named `ca.key` and `ca.cert` in the current directory. The `ca.key` file is the most sensitive file you'll create, because it is the key used to sign the certificates for individual nebula nodes/hosts. Please store this file somewhere safe, preferably with strong encryption.
#### 4. Nebula host keys and certificates generated from that certificate authority
This assumes you have four nodes, named lighthouse1, laptop, server1, host3. You can name the nodes any way you'd like, including FQDN. You'll also need to choose IP addresses and the associated subnet. In this example, we are creating a nebula network that will use 192.168.100.x/24 as its network range. This example also demonstrates nebula groups, which can later be used to define traffic rules in a nebula network.
```sh
```
./nebula-cert sign -name "lighthouse1" -ip "192.168.100.1/24"
./nebula-cert sign -name "laptop" -ip "192.168.100.2/24" -groups "laptop,home,ssh"
./nebula-cert sign -name "server1" -ip "192.168.100.9/24" -groups "servers"
./nebula-cert sign -name "host3" -ip "192.168.100.10/24"
```
By default, host certificates will expire 1 second before the CA expires. Use the `-duration` flag to specify a shorter lifetime.
#### 5. Configuration files for each host
Download a copy of the nebula [example configuration](https://github.com/slackhq/nebula/blob/master/examples/config.yml).
* On the lighthouse node, you'll need to ensure `am_lighthouse: true` is set.
@@ -124,13 +118,10 @@ For each host, copy the nebula binary to the host, along with `config.yml` from
**DO NOT COPY `ca.key` TO INDIVIDUAL NODES.**
#### 7. Run nebula on each host
```sh
```
./nebula -config /path/to/config.yml
```
For more detailed instructions, [find the full documentation here](https://nebula.defined.net/docs/).
## Building Nebula from source
Make sure you have [go](https://go.dev/doc/install) installed and clone this repo. Change to the nebula directory.
@@ -149,10 +140,8 @@ The default curve used for cryptographic handshakes and signatures is Curve25519
In addition, Nebula can be built using the [BoringCrypto GOEXPERIMENT](https://github.com/golang/go/blob/go1.20/src/crypto/internal/boring/README.md) by running either of the following make targets:
```sh
make bin-boringcrypto
make release-boringcrypto
```
make bin-boringcrypto
make release-boringcrypto
This is not the recommended default deployment, but may be useful based on your compliance requirements.
@@ -160,3 +149,5 @@ This is not the recommended default deployment, but may be useful based on your
Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang.

View File

@@ -36,7 +36,7 @@ type AllowListNameRule struct {
func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) {
var nameRules []AllowListNameRule
handleKey := func(key string, value any) (bool, error) {
handleKey := func(key string, value interface{}) (bool, error) {
if key == "interfaces" {
var err error
nameRules, err = getAllowListInterfaces(k, value)
@@ -70,7 +70,7 @@ func NewRemoteAllowListFromConfig(c *config.C, k, rangesKey string) (*RemoteAllo
// If the handleKey func returns true, the rest of the parsing is skipped
// for this key. This allows parsing of special values like `interfaces`.
func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value any) (bool, error)) (*AllowList, error) {
func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
r := c.Get(k)
if r == nil {
return nil, nil
@@ -81,8 +81,8 @@ func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, va
// If the handleKey func returns true, the rest of the parsing is skipped
// for this key. This allows parsing of special values like `interfaces`.
func newAllowList(k string, raw any, handleKey func(key string, value any) (bool, error)) (*AllowList, error) {
rawMap, ok := raw.(map[string]any)
func newAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
rawMap, ok := raw.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
}
@@ -100,7 +100,12 @@ func newAllowList(k string, raw any, handleKey func(key string, value any) (bool
rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
for rawCIDR, rawValue := range rawMap {
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
if handleKey != nil {
handled, err := handleKey(rawCIDR, rawValue)
if err != nil {
@@ -111,7 +116,7 @@ func newAllowList(k string, raw any, handleKey func(key string, value any) (bool
}
}
value, ok := config.AsBool(rawValue)
value, ok := rawValue.(bool)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
}
@@ -168,18 +173,22 @@ func newAllowList(k string, raw any, handleKey func(key string, value any) (bool
return &AllowList{cidrTree: tree}, nil
}
func getAllowListInterfaces(k string, v any) ([]AllowListNameRule, error) {
func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
var nameRules []AllowListNameRule
rawRules, ok := v.(map[string]any)
rawRules, ok := v.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
}
firstEntry := true
var allValues bool
for name, rawAllow := range rawRules {
allow, ok := config.AsBool(rawAllow)
for rawName, rawAllow := range rawRules {
name, ok := rawName.(string)
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
}
allow, ok := rawAllow.(bool)
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
}
@@ -215,11 +224,16 @@ func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error
remoteAllowRanges := new(bart.Table[*AllowList])
rawMap, ok := value.(map[string]any)
rawMap, ok := value.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
}
for rawCIDR, rawValue := range rawMap {
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
if err != nil {
return nil, err

View File

@@ -15,27 +15,27 @@ import (
func TestNewAllowListFromConfig(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
c.Settings["allowlist"] = map[string]any{
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0": true,
}
r, err := newAllowListFromConfig(c, "allowlist", nil)
require.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
assert.Nil(t, r)
c.Settings["allowlist"] = map[string]any{
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": "abc",
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
c.Settings["allowlist"] = map[string]any{
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": true,
"10.0.0.0/8": false,
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
c.Settings["allowlist"] = map[string]any{
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
@@ -45,7 +45,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
r, err = newAllowListFromConfig(c, "allowlist", nil)
require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
c.Settings["allowlist"] = map[string]any{
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
@@ -55,7 +55,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
assert.NotNil(t, r)
}
c.Settings["allowlist"] = map[string]any{
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
@@ -70,16 +70,16 @@ func TestNewAllowListFromConfig(t *testing.T) {
// Test interface names
c.Settings["allowlist"] = map[string]any{
"interfaces": map[string]any{
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: "foo",
},
}
lr, err := NewLocalAllowListFromConfig(c, "allowlist")
require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
c.Settings["allowlist"] = map[string]any{
"interfaces": map[string]any{
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: false,
`eth.*`: true,
},
@@ -87,8 +87,8 @@ func TestNewAllowListFromConfig(t *testing.T) {
lr, err = NewLocalAllowListFromConfig(c, "allowlist")
require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
c.Settings["allowlist"] = map[string]any{
"interfaces": map[string]any{
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: false,
},
}

View File

@@ -5,7 +5,6 @@ import (
"github.com/sirupsen/logrus"
)
// TODO: Pretty sure this is just all sorts of racy now, we need it to be atomic
type Bits struct {
length uint64
current uint64
@@ -44,7 +43,7 @@ func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
}
// Not within the window
l.Error("rejected a packet (top) %d %d\n", b.current, i)
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
return false
}

View File

@@ -84,11 +84,16 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
calculatedRemotes := new(bart.Table[[]*calculatedRemote])
rawMap, ok := value.(map[string]any)
rawMap, ok := value.(map[any]any)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
}
for rawCIDR, rawValue := range rawMap {
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
cidr, err := netip.ParsePrefix(rawCIDR)
if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
@@ -124,7 +129,7 @@ func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculat
}
func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
rawMap, ok := raw.(map[string]any)
rawMap, ok := raw.(map[any]any)
if !ok {
return nil, fmt.Errorf("invalid type: %T", raw)
}

View File

@@ -58,9 +58,6 @@ type Certificate interface {
// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
PublicKey() []byte
// MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM
MarshalPublicKeyPEM() []byte
// Curve identifies which curve was used for the PublicKey and Signature.
Curve() Curve
@@ -138,7 +135,8 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
case Version2:
c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
default:
return nil, ErrUnknownVersion
//TODO: CERT-V2 make a static var
return nil, fmt.Errorf("unknown certificate version %d", v)
}
if err != nil {

View File

@@ -41,7 +41,7 @@ type detailsV1 struct {
curve Curve
}
type m = map[string]any
type m map[string]interface{}
func (c *certificateV1) Version() Version {
return Version1
@@ -83,10 +83,6 @@ func (c *certificateV1) PublicKey() []byte {
return c.details.publicKey
}
func (c *certificateV1) MarshalPublicKeyPEM() []byte {
return marshalCertPublicKeyToPEM(c)
}
func (c *certificateV1) Signature() []byte {
return c.signature
}
@@ -114,10 +110,8 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
if err != nil {
return false
}
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:

View File

@@ -1,7 +1,6 @@
package cert
import (
"crypto/ed25519"
"fmt"
"net/netip"
"testing"
@@ -14,7 +13,6 @@ import (
)
func TestCertificateV1_Marshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -62,58 +60,6 @@ func TestCertificateV1_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV1_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
nc := certificateV1{
details: detailsV1{
name: "testing",
networks: []netip.Prefix{},
unsafeNetworks: []netip.Prefix{},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
publicKey: pubKey,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
assert.Equal(t, Version1, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.False(t, nc.IsCA())
nc.details.isCA = true
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.True(t, nc.IsCA())
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err)
nc.details.curve = Curve_P256
nc.details.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.True(t, nc.IsCA())
nc.details.isCA = false
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.False(t, nc.IsCA())
}
func TestCertificateV1_Expired(t *testing.T) {
nc := certificateV1{
details: detailsV1{

View File

@@ -114,10 +114,6 @@ func (c *certificateV2) PublicKey() []byte {
return c.publicKey
}
func (c *certificateV2) MarshalPublicKeyPEM() []byte {
return marshalCertPublicKeyToPEM(c)
}
func (c *certificateV2) Signature() []byte {
return c.signature
}
@@ -153,10 +149,8 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
if err != nil {
return false
}
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:

View File

@@ -15,7 +15,6 @@ import (
)
func TestCertificateV2_Marshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -76,58 +75,6 @@ func TestCertificateV2_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV2_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
nc := certificateV2{
details: detailsV2{
name: "testing",
networks: []netip.Prefix{},
unsafeNetworks: []netip.Prefix{},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
publicKey: pubKey,
signature: []byte("1234567890abcedfghij1234567890ab"),
}
assert.Equal(t, Version2, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.False(t, nc.IsCA())
nc.details.isCA = true
assert.Equal(t, Curve_CURVE25519, nc.Curve())
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
assert.True(t, nc.IsCA())
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err)
nc.curve = Curve_P256
nc.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.True(t, nc.IsCA())
nc.details.isCA = false
assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
assert.False(t, nc.IsCA())
}
func TestCertificateV2_Expired(t *testing.T) {
nc := certificateV2{
details: detailsV2{

View File

@@ -10,14 +10,14 @@ import (
func TestNewArgon2Parameters(t *testing.T) {
p := NewArgon2Parameters(64*1024, 4, 3)
assert.Equal(t, &Argon2Parameters{
assert.EqualValues(t, &Argon2Parameters{
version: argon2.Version,
Memory: 64 * 1024,
Parallelism: 4,
Iterations: 3,
}, p)
p = NewArgon2Parameters(2*1024*1024, 2, 1)
assert.Equal(t, &Argon2Parameters{
assert.EqualValues(t, &Argon2Parameters{
version: argon2.Version,
Memory: 2 * 1024 * 1024,
Parallelism: 2,
@@ -26,21 +26,21 @@ func TestNewArgon2Parameters(t *testing.T) {
}
func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) {
passphrase := []byte("DO NOT USE")
passphrase := []byte("DO NOT USE THIS KEY")
privKey := []byte(`# A good key
-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
CjsKC0FFUy0yNTYtR0NNEiwIExCAgAQYAyAEKiCPoDfGQiosxNPTbPn5EsMlc2MI
c0Bt4oz6gTrFQhX3aBJcimhHKeAuhyTGvllD0Z19fe+DFPcLH3h5VrdjVfIAajg0
KrbV3n9UHif/Au5skWmquNJzoW1E4MTdRbvpti6o+WdQ49DxjBFhx0YH8LBqrbPU
0BGkUHmIO7daP24=
CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
qrlJ69wer3ZUHFXA
-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
`)
shortKey := []byte(`# A key which, once decrypted, is too short
-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
CjsKC0FFUy0yNTYtR0NNEiwIExCAgAQYAyAEKiAVJwdfl3r+eqi/vF6S7OMdpjfo
hAzmTCRnr58Su4AqmBJbCv3zleYCEKYJP6UI3S8ekLMGISsgO4hm5leukCCyqT0Z
cQ76yrberpzkJKoPLGisX8f+xdy4aXSZl7oEYWQte1+vqbtl/eY9PGZhxUQdcyq7
hqzIyrRqfUgVuA==
CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7
k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe
GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs
rQr3bdH3Oy/WiYU=
-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
`)
invalidBanner := []byte(`# Invalid banner (not encrypted)

View File

@@ -20,7 +20,6 @@ var (
ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
ErrCaNotFound = errors.New("could not find ca for the certificate")
ErrUnknownVersion = errors.New("certificate version unrecognized")
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")

View File

@@ -1,34 +1,25 @@
package cert
import (
"encoding/hex"
"encoding/pem"
"fmt"
"time"
"golang.org/x/crypto/ed25519"
)
const ( //cert banners
CertificateBanner = "NEBULA CERTIFICATE"
CertificateV2Banner = "NEBULA CERTIFICATE V2"
)
const (
CertificateBanner = "NEBULA CERTIFICATE"
CertificateV2Banner = "NEBULA CERTIFICATE V2"
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
const ( //key-agreement-key banners
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
)
/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */
const ( //signing key banners
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY"
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
)
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
@@ -60,16 +51,6 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
}
func marshalCertPublicKeyToPEM(c Certificate) []byte {
if c.IsCA() {
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
} else {
return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey())
}
}
// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH.
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
switch curve {
case Curve_CURVE25519:
@@ -81,19 +62,6 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
}
}
// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing.
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
switch curve {
case Curve_CURVE25519:
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
case Curve_P256:
return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
default:
return nil
}
}
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
k, r := pem.Decode(b)
if k == nil {
@@ -105,7 +73,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
expectedLen = 32
curve = Curve_CURVE25519
case P256PublicKeyBanner, ECDSAP256PublicKeyBanner:
case P256PublicKeyBanner:
// Uncompressed
expectedLen = 65
curve = Curve_P256
@@ -140,101 +108,6 @@ func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte {
}
}
// Backward compatibility functions for older API
func MarshalX25519PublicKey(b []byte) []byte {
return MarshalPublicKeyToPEM(Curve_CURVE25519, b)
}
func MarshalX25519PrivateKey(b []byte) []byte {
return MarshalPrivateKeyToPEM(Curve_CURVE25519, b)
}
func MarshalPublicKey(curve Curve, b []byte) []byte {
return MarshalPublicKeyToPEM(curve, b)
}
func MarshalPrivateKey(curve Curve, b []byte) []byte {
return MarshalPrivateKeyToPEM(curve, b)
}
// NebulaCertificate is a compatibility wrapper for the old API
type NebulaCertificate struct {
Details NebulaCertificateDetails
Signature []byte
cert Certificate
}
// NebulaCertificateDetails is a compatibility wrapper for certificate details
type NebulaCertificateDetails struct {
Name string
NotBefore time.Time
NotAfter time.Time
PublicKey []byte
IsCA bool
Issuer []byte
Curve Curve
}
// UnmarshalNebulaCertificateFromPEM provides backward compatibility with the old API
func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
c, rest, err := UnmarshalCertificateFromPEM(b)
if err != nil {
return nil, rest, err
}
issuerBytes, err := func() ([]byte, error) {
issuer := c.Issuer()
if issuer == "" {
return nil, nil
}
decoded, err := hex.DecodeString(issuer)
if err != nil {
return nil, fmt.Errorf("failed to decode issuer fingerprint: %w", err)
}
return decoded, nil
}()
if err != nil {
return nil, rest, err
}
pubKey := c.PublicKey()
if pubKey != nil {
pubKey = append([]byte(nil), pubKey...)
}
sig := c.Signature()
if sig != nil {
sig = append([]byte(nil), sig...)
}
return &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: c.Name(),
NotBefore: c.NotBefore(),
NotAfter: c.NotAfter(),
PublicKey: pubKey,
IsCA: c.IsCA(),
Issuer: issuerBytes,
Curve: c.Curve(),
},
Signature: sig,
cert: c,
}, rest, nil
}
// IssuerString returns the issuer in hex format for compatibility
func (n *NebulaCertificate) IssuerString() string {
if n.Details.Issuer == nil {
return ""
}
return hex.EncodeToString(n.Details.Issuer)
}
// Certificate returns the underlying certificate (read-only)
func (n *NebulaCertificate) Certificate() Certificate {
return n.cert
}
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
// consumed data or an error on failure
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {

View File

@@ -177,7 +177,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
}
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
t.Parallel()
pubKey := []byte(`# A good key
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -231,7 +230,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
}
func TestUnmarshalX25519PublicKey(t *testing.T) {
t.Parallel()
pubKey := []byte(`# A good key
-----BEGIN NEBULA X25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -242,12 +240,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
oldPubP256Key := []byte(`# A good key
-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA ECDSA P256 PUBLIC KEY-----
`)
shortKey := []byte(`# A short key
-----BEGIN NEBULA X25519 PUBLIC KEY-----
@@ -264,22 +256,15 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-END NEBULA X25519 PUBLIC KEY-----`)
keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)
keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
// Success test case
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
assert.Len(t, k, 32)
require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_CURVE25519, curve)
// Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65)
require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_P256, curve)
// Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65)

View File

@@ -7,6 +7,7 @@ import (
"crypto/rand"
"crypto/sha256"
"fmt"
"math/big"
"net/netip"
"time"
)
@@ -54,10 +55,15 @@ func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Cert
}
return t.SignWith(signer, curve, sp)
case Curve_P256:
pk, err := ecdsa.ParseRawPrivateKey(elliptic.P256(), key)
if err != nil {
return nil, err
pk := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P256(),
},
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
D: new(big.Int).SetBytes(key),
}
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
sp := func(certBytes []byte) ([]byte, error) {
// We need to hash first for ECDSA
// - https://pkg.go.dev/crypto/ecdsa#SignASN1

View File

@@ -90,26 +90,26 @@ func Test_ca(t *testing.T) {
assertHelpError(t, ca(
[]string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
), "-name is required")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// ipv4 only ips
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// ipv4 only subnets
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// failed key write
ob.Reset()
eb.Reset()
args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// create temp key file
keyF, err := os.CreateTemp("", "test.key")
@@ -121,8 +121,8 @@ func Test_ca(t *testing.T) {
eb.Reset()
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// create temp cert file
crtF, err := os.CreateTemp("", "test.crt")
@@ -135,8 +135,8 @@ func Test_ca(t *testing.T) {
eb.Reset()
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.NoError(t, ca(args, ob, eb, nopw))
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// read cert and key files
rb, _ := os.ReadFile(keyF.Name())
@@ -158,7 +158,7 @@ func Test_ca(t *testing.T) {
assert.Empty(t, lCrt.UnsafeNetworks())
assert.Len(t, lCrt.PublicKey(), 32)
assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
assert.Empty(t, lCrt.Issuer())
assert.Equal(t, "", lCrt.Issuer())
assert.True(t, lCrt.CheckSignature(lCrt.PublicKey()))
// test encrypted key
@@ -169,7 +169,7 @@ func Test_ca(t *testing.T) {
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.NoError(t, ca(args, ob, eb, testpw))
assert.Equal(t, pwPromptOb, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", eb.String())
// read encrypted key file and verify default params
rb, _ = os.ReadFile(keyF.Name())
@@ -197,7 +197,7 @@ func Test_ca(t *testing.T) {
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.Error(t, ca(args, ob, eb, errpw))
assert.Equal(t, pwPromptOb, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", eb.String())
// test when user fails to enter a password
os.Remove(keyF.Name())
@@ -207,7 +207,7 @@ func Test_ca(t *testing.T) {
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
assert.Empty(t, eb.String())
assert.Equal(t, "", eb.String())
// create valid cert/key for overwrite tests
os.Remove(keyF.Name())
@@ -222,8 +222,8 @@ func Test_ca(t *testing.T) {
eb.Reset()
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// test that we won't overwrite existing key file
os.Remove(keyF.Name())
@@ -231,8 +231,8 @@ func Test_ca(t *testing.T) {
eb.Reset()
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
os.Remove(keyF.Name())
}

View File

@@ -37,20 +37,20 @@ func Test_keygen(t *testing.T) {
// required args
assertHelpError(t, keygen([]string{"-out-pub", "nope"}, ob, eb), "-out-key is required")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assertHelpError(t, keygen([]string{"-out-key", "nope"}, ob, eb), "-out-pub is required")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// failed key write
ob.Reset()
eb.Reset()
args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"}
require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// create temp key file
keyF, err := os.CreateTemp("", "test.key")
@@ -62,8 +62,8 @@ func Test_keygen(t *testing.T) {
eb.Reset()
args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()}
require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// create temp pub file
pubF, err := os.CreateTemp("", "test.pub")
@@ -75,8 +75,8 @@ func Test_keygen(t *testing.T) {
eb.Reset()
args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()}
require.NoError(t, keygen(args, ob, eb))
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// read cert and key files
rb, _ := os.ReadFile(keyF.Name())

View File

@@ -17,7 +17,7 @@ func (he *helpError) Error() string {
return he.s
}
func newHelpErrorf(s string, v ...any) error {
func newHelpErrorf(s string, v ...interface{}) error {
return &helpError{s: fmt.Sprintf(s, v...)}
}

View File

@@ -43,16 +43,16 @@ func Test_printCert(t *testing.T) {
// no path
err := printCert([]string{}, ob, eb)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assertHelpError(t, err, "-path is required")
// no cert at path
ob.Reset()
eb.Reset()
err = printCert([]string{"-path", "does_not_exist"}, ob, eb)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
// invalid cert at path
@@ -64,8 +64,8 @@ func Test_printCert(t *testing.T) {
tf.WriteString("-----BEGIN NOPE-----")
err = printCert([]string{"-path", tf.Name()}, ob, eb)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
// test multiple certs
@@ -155,7 +155,7 @@ func Test_printCert(t *testing.T) {
`,
ob.String(),
)
assert.Empty(t, eb.String())
assert.Equal(t, "", eb.String())
// test json
ob.Reset()
@@ -177,7 +177,7 @@ func Test_printCert(t *testing.T) {
`,
ob.String(),
)
assert.Empty(t, eb.String())
assert.Equal(t, "", eb.String())
}
// NewTestCaCert will generate a CA cert

View File

@@ -38,19 +38,19 @@ func Test_verify(t *testing.T) {
// required args
assertHelpError(t, verify([]string{"-ca", "derp"}, ob, eb), "-crt is required")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assertHelpError(t, verify([]string{"-crt", "derp"}, ob, eb), "-ca is required")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// no ca at path
ob.Reset()
eb.Reset()
err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
// invalid ca at path
@@ -62,8 +62,8 @@ func Test_verify(t *testing.T) {
caFile.WriteString("-----BEGIN NOPE-----")
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
// make a ca for later
@@ -76,8 +76,8 @@ func Test_verify(t *testing.T) {
// no crt at path
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
// invalid crt at path
@@ -89,8 +89,8 @@ func Test_verify(t *testing.T) {
certFile.WriteString("-----BEGIN NOPE-----")
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
// unverifiable cert at path
@@ -106,8 +106,8 @@ func Test_verify(t *testing.T) {
certFile.Write(b)
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
require.ErrorIs(t, err, cert.ErrSignatureMismatch)
// verified cert at path
@@ -118,7 +118,7 @@ func Test_verify(t *testing.T) {
certFile.Write(b)
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
require.NoError(t, err)
}

View File

@@ -65,16 +65,8 @@ func main() {
}
if !*configTest {
wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
wait()
l.Info("Goodbye")
ctrl.Start()
ctrl.ShutdownBlock()
}
os.Exit(0)

View File

@@ -3,9 +3,6 @@ package main
import (
"flag"
"fmt"
"log"
"net/http"
_ "net/http/pprof"
"os"
"github.com/sirupsen/logrus"
@@ -61,22 +58,10 @@ func main() {
os.Exit(1)
}
go func() {
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
}()
if !*configTest {
wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
ctrl.Start()
notifyReady(l)
wait()
l.Info("Goodbye")
ctrl.ShutdownBlock()
}
os.Exit(0)

View File

@@ -17,14 +17,14 @@ import (
"dario.cat/mergo"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
"gopkg.in/yaml.v2"
)
type C struct {
path string
files []string
Settings map[string]any
oldSettings map[string]any
Settings map[interface{}]interface{}
oldSettings map[interface{}]interface{}
callbacks []func(*C)
l *logrus.Logger
reloadLock sync.Mutex
@@ -32,7 +32,7 @@ type C struct {
func NewC(l *logrus.Logger) *C {
return &C{
Settings: make(map[string]any),
Settings: make(map[interface{}]interface{}),
l: l,
}
}
@@ -92,8 +92,8 @@ func (c *C) HasChanged(k string) bool {
}
var (
nv any
ov any
nv interface{}
ov interface{}
)
if k == "" {
@@ -147,7 +147,7 @@ func (c *C) ReloadConfig() {
c.reloadLock.Lock()
defer c.reloadLock.Unlock()
c.oldSettings = make(map[string]any)
c.oldSettings = make(map[interface{}]interface{})
for k, v := range c.Settings {
c.oldSettings[k] = v
}
@@ -167,7 +167,7 @@ func (c *C) ReloadConfigString(raw string) error {
c.reloadLock.Lock()
defer c.reloadLock.Unlock()
c.oldSettings = make(map[string]any)
c.oldSettings = make(map[interface{}]interface{})
for k, v := range c.Settings {
c.oldSettings[k] = v
}
@@ -201,7 +201,7 @@ func (c *C) GetStringSlice(k string, d []string) []string {
return d
}
rv, ok := r.([]any)
rv, ok := r.([]interface{})
if !ok {
return d
}
@@ -215,13 +215,13 @@ func (c *C) GetStringSlice(k string, d []string) []string {
}
// GetMap will get the map for k or return the default d if not found or invalid
func (c *C) GetMap(k string, d map[string]any) map[string]any {
func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
r := c.Get(k)
if r == nil {
return d
}
v, ok := r.(map[string]any)
v, ok := r.(map[interface{}]interface{})
if !ok {
return d
}
@@ -243,7 +243,7 @@ func (c *C) GetInt(k string, d int) int {
// GetUint32 will get the uint32 for k or return the default d if not found or invalid
func (c *C) GetUint32(k string, d uint32) uint32 {
r := c.GetInt(k, int(d))
if r < 0 || uint64(r) > uint64(math.MaxUint32) {
if uint64(r) > uint64(math.MaxUint32) {
return d
}
return uint32(r)
@@ -266,22 +266,6 @@ func (c *C) GetBool(k string, d bool) bool {
return v
}
func AsBool(v any) (value bool, ok bool) {
switch x := v.(type) {
case bool:
return x, true
case string:
switch x {
case "y", "yes":
return true, true
case "n", "no":
return false, true
}
}
return false, false
}
// GetDuration will get the duration for k or return the default d if not found or invalid
func (c *C) GetDuration(k string, d time.Duration) time.Duration {
r := c.GetString(k, "")
@@ -292,7 +276,7 @@ func (c *C) GetDuration(k string, d time.Duration) time.Duration {
return v
}
func (c *C) Get(k string) any {
func (c *C) Get(k string) interface{} {
return c.get(k, c.Settings)
}
@@ -300,10 +284,10 @@ func (c *C) IsSet(k string) bool {
return c.get(k, c.Settings) != nil
}
func (c *C) get(k string, v any) any {
func (c *C) get(k string, v interface{}) interface{} {
parts := strings.Split(k, ".")
for _, p := range parts {
m, ok := v.(map[string]any)
m, ok := v.(map[interface{}]interface{})
if !ok {
return nil
}
@@ -362,7 +346,7 @@ func (c *C) addFile(path string, direct bool) error {
}
func (c *C) parseRaw(b []byte) error {
var m map[string]any
var m map[interface{}]interface{}
err := yaml.Unmarshal(b, &m)
if err != nil {
@@ -374,7 +358,7 @@ func (c *C) parseRaw(b []byte) error {
}
func (c *C) parse() error {
var m map[string]any
var m map[interface{}]interface{}
for _, path := range c.files {
b, err := os.ReadFile(path)
@@ -382,7 +366,7 @@ func (c *C) parse() error {
return err
}
var nm map[string]any
var nm map[interface{}]interface{}
err = yaml.Unmarshal(b, &nm)
if err != nil {
return err

View File

@@ -10,7 +10,7 @@ import (
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
"gopkg.in/yaml.v2"
)
func TestConfig_Load(t *testing.T) {
@@ -19,7 +19,7 @@ func TestConfig_Load(t *testing.T) {
// invalid yaml
c := NewC(l)
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[string]interface {}")
require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
// simple multi config merge
c = NewC(l)
@@ -31,8 +31,8 @@ func TestConfig_Load(t *testing.T) {
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644)
require.NoError(t, c.Load(dir))
expected := map[string]any{
"outer": map[string]any{
expected := map[interface{}]interface{}{
"outer": map[interface{}]interface{}{
"inner": "override",
},
"new": "hi",
@@ -44,12 +44,12 @@ func TestConfig_Get(t *testing.T) {
l := test.NewLogger()
// test simple type
c := NewC(l)
c.Settings["firewall"] = map[string]any{"outbound": "hi"}
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
assert.Equal(t, "hi", c.Get("firewall.outbound"))
// test complex type
inner := []map[string]any{{"port": "1", "code": "2"}}
c.Settings["firewall"] = map[string]any{"outbound": inner}
inner := []map[interface{}]interface{}{{"port": "1", "code": "2"}}
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": inner}
assert.EqualValues(t, inner, c.Get("firewall.outbound"))
// test missing
@@ -59,7 +59,7 @@ func TestConfig_Get(t *testing.T) {
func TestConfig_GetStringSlice(t *testing.T) {
l := test.NewLogger()
c := NewC(l)
c.Settings["slice"] = []any{"one", "two"}
c.Settings["slice"] = []interface{}{"one", "two"}
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
}
@@ -101,14 +101,14 @@ func TestConfig_HasChanged(t *testing.T) {
// Test key change
c = NewC(l)
c.Settings["test"] = "hi"
c.oldSettings = map[string]any{"test": "no"}
c.oldSettings = map[interface{}]interface{}{"test": "no"}
assert.True(t, c.HasChanged("test"))
assert.True(t, c.HasChanged(""))
// No key change
c = NewC(l)
c.Settings["test"] = "hi"
c.oldSettings = map[string]any{"test": "hi"}
c.oldSettings = map[interface{}]interface{}{"test": "hi"}
assert.False(t, c.HasChanged("test"))
assert.False(t, c.HasChanged(""))
}
@@ -184,11 +184,11 @@ firewall:
`),
}
var m map[string]any
var m map[any]any
// merge the same way config.parse() merges
for _, b := range configs {
var nm map[string]any
var nm map[any]any
err := yaml.Unmarshal(b, &nm)
require.NoError(t, err)
@@ -205,15 +205,15 @@ firewall:
t.Logf("Merged Config as YAML:\n%s", mYaml)
// If a bug is present, some items might be replaced instead of merged like we expect
expected := map[string]any{
"firewall": map[string]any{
expected := map[any]any{
"firewall": map[any]any{
"inbound": []any{
map[string]any{"host": "any", "port": "any", "proto": "icmp"},
map[string]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"},
map[string]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}},
map[any]any{"host": "any", "port": "any", "proto": "icmp"},
map[any]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"},
map[any]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}},
"outbound": []any{
map[string]any{"host": "any", "port": "any", "proto": "any"}}},
"listen": map[string]any{
map[any]any{"host": "any", "port": "any", "proto": "any"}}},
"listen": map[any]any{
"host": "0.0.0.0",
"port": 4242,
},

View File

@@ -4,16 +4,13 @@ import (
"bytes"
"context"
"encoding/binary"
"fmt"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
)
@@ -30,124 +27,130 @@ const (
)
type connectionManager struct {
in map[uint32]struct{}
inLock *sync.RWMutex
out map[uint32]struct{}
outLock *sync.RWMutex
// relayUsed holds which relay localIndexs are in use
relayUsed map[uint32]struct{}
relayUsedLock *sync.RWMutex
hostMap *HostMap
trafficTimer *LockingTimerWheel[uint32]
intf *Interface
punchy *Punchy
// Configuration settings
hostMap *HostMap
trafficTimer *LockingTimerWheel[uint32]
intf *Interface
pendingDeletion map[uint32]struct{}
punchy *Punchy
checkInterval time.Duration
pendingDeletionInterval time.Duration
inactivityTimeout atomic.Int64
dropInactive atomic.Bool
metricsTxPunchy metrics.Counter
metricsTxPunchy metrics.Counter
l *logrus.Logger
}
func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
cm := &connectionManager{
hostMap: hm,
l: l,
punchy: p,
relayUsed: make(map[uint32]struct{}),
relayUsedLock: &sync.RWMutex{},
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {
var max time.Duration
if checkInterval < pendingDeletionInterval {
max = pendingDeletionInterval
} else {
max = checkInterval
}
cm.reload(c, true)
c.RegisterReloadCallback(func(c *config.C) {
cm.reload(c, false)
})
return cm
}
func (cm *connectionManager) reload(c *config.C, initial bool) {
if initial {
cm.checkInterval = time.Duration(c.GetInt("timers.connection_alive_interval", 5)) * time.Second
cm.pendingDeletionInterval = time.Duration(c.GetInt("timers.pending_deletion_interval", 10)) * time.Second
// We want at least a minimum resolution of 500ms per tick so that we can hit these intervals
// pretty close to their configured duration.
// The inactivity duration is checked each time a hostinfo ticks through so we don't need the wheel to contain it.
minDuration := min(time.Millisecond*500, cm.checkInterval, cm.pendingDeletionInterval)
maxDuration := max(cm.checkInterval, cm.pendingDeletionInterval)
cm.trafficTimer = NewLockingTimerWheel[uint32](minDuration, maxDuration)
nc := &connectionManager{
hostMap: intf.hostMap,
in: make(map[uint32]struct{}),
inLock: &sync.RWMutex{},
out: make(map[uint32]struct{}),
outLock: &sync.RWMutex{},
relayUsed: make(map[uint32]struct{}),
relayUsedLock: &sync.RWMutex{},
trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max),
intf: intf,
pendingDeletion: make(map[uint32]struct{}),
checkInterval: checkInterval,
pendingDeletionInterval: pendingDeletionInterval,
punchy: punchy,
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
l: l,
}
if initial || c.HasChanged("tunnels.inactivity_timeout") {
old := cm.getInactivityTimeout()
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
if !initial {
cm.l.WithField("oldDuration", old).
WithField("newDuration", cm.getInactivityTimeout()).
Info("Inactivity timeout has changed")
}
}
if initial || c.HasChanged("tunnels.drop_inactive") {
old := cm.dropInactive.Load()
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
if !initial {
cm.l.WithField("oldBool", old).
WithField("newBool", cm.dropInactive.Load()).
Info("Drop inactive setting has changed")
}
}
nc.Start(ctx)
return nc
}
func (cm *connectionManager) getInactivityTimeout() time.Duration {
return (time.Duration)(cm.inactivityTimeout.Load())
}
func (cm *connectionManager) In(h *HostInfo) {
h.in.Store(true)
}
func (cm *connectionManager) Out(h *HostInfo) {
h.out.Store(true)
}
func (cm *connectionManager) RelayUsed(localIndex uint32) {
cm.relayUsedLock.RLock()
func (n *connectionManager) In(localIndex uint32) {
n.inLock.RLock()
// If this already exists, return
if _, ok := cm.relayUsed[localIndex]; ok {
cm.relayUsedLock.RUnlock()
if _, ok := n.in[localIndex]; ok {
n.inLock.RUnlock()
return
}
cm.relayUsedLock.RUnlock()
cm.relayUsedLock.Lock()
cm.relayUsed[localIndex] = struct{}{}
cm.relayUsedLock.Unlock()
n.inLock.RUnlock()
n.inLock.Lock()
n.in[localIndex] = struct{}{}
n.inLock.Unlock()
}
func (n *connectionManager) Out(localIndex uint32) {
n.outLock.RLock()
// If this already exists, return
if _, ok := n.out[localIndex]; ok {
n.outLock.RUnlock()
return
}
n.outLock.RUnlock()
n.outLock.Lock()
n.out[localIndex] = struct{}{}
n.outLock.Unlock()
}
func (n *connectionManager) RelayUsed(localIndex uint32) {
n.relayUsedLock.RLock()
// If this already exists, return
if _, ok := n.relayUsed[localIndex]; ok {
n.relayUsedLock.RUnlock()
return
}
n.relayUsedLock.RUnlock()
n.relayUsedLock.Lock()
n.relayUsed[localIndex] = struct{}{}
n.relayUsedLock.Unlock()
}
// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
// resets the state for this local index
func (cm *connectionManager) getAndResetTrafficCheck(h *HostInfo, now time.Time) (bool, bool) {
in := h.in.Swap(false)
out := h.out.Swap(false)
if in || out {
h.lastUsed = now
}
func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
n.inLock.Lock()
n.outLock.Lock()
_, in := n.in[localIndex]
_, out := n.out[localIndex]
delete(n.in, localIndex)
delete(n.out, localIndex)
n.inLock.Unlock()
n.outLock.Unlock()
return in, out
}
// AddTrafficWatch must be called for every new HostInfo.
// We will continue to monitor the HostInfo until the tunnel is dropped.
func (cm *connectionManager) AddTrafficWatch(h *HostInfo) {
if h.out.Swap(true) == false {
cm.trafficTimer.Add(h.localIndexId, cm.checkInterval)
func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
// Use a write lock directly because it should be incredibly rare that we are ever already tracking this index
n.outLock.Lock()
if _, ok := n.out[localIndex]; ok {
n.outLock.Unlock()
return
}
n.out[localIndex] = struct{}{}
n.trafficTimer.Add(localIndex, n.checkInterval)
n.outLock.Unlock()
}
func (cm *connectionManager) Start(ctx context.Context) {
clockSource := time.NewTicker(cm.trafficTimer.t.tickDuration)
func (n *connectionManager) Start(ctx context.Context) {
go n.Run(ctx)
}
func (n *connectionManager) Run(ctx context.Context) {
//TODO: this tick should be based on the min wheel tick? Check firewall
clockSource := time.NewTicker(500 * time.Millisecond)
defer clockSource.Stop()
p := []byte("")
@@ -160,61 +163,61 @@ func (cm *connectionManager) Start(ctx context.Context) {
return
case now := <-clockSource.C:
cm.trafficTimer.Advance(now)
n.trafficTimer.Advance(now)
for {
localIndex, has := cm.trafficTimer.Purge()
localIndex, has := n.trafficTimer.Purge()
if !has {
break
}
cm.doTrafficCheck(localIndex, p, nb, out, now)
n.doTrafficCheck(localIndex, p, nb, out, now)
}
}
}
}
func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now)
func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now)
switch decision {
case deleteTunnel:
if cm.hostMap.DeleteHostInfo(hostinfo) {
if n.hostMap.DeleteHostInfo(hostinfo) {
// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
cm.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
}
case closeTunnel:
cm.intf.sendCloseTunnel(hostinfo)
cm.intf.closeTunnel(hostinfo)
n.intf.sendCloseTunnel(hostinfo)
n.intf.closeTunnel(hostinfo)
case swapPrimary:
cm.swapPrimary(hostinfo, primary)
n.swapPrimary(hostinfo, primary)
case migrateRelays:
cm.migrateRelayUsed(hostinfo, primary)
n.migrateRelayUsed(hostinfo, primary)
case tryRehandshake:
cm.tryRehandshake(hostinfo)
n.tryRehandshake(hostinfo)
case sendTestPacket:
cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
}
cm.resetRelayTrafficCheck(hostinfo)
n.resetRelayTrafficCheck(hostinfo)
}
func (cm *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
if hostinfo != nil {
cm.relayUsedLock.Lock()
defer cm.relayUsedLock.Unlock()
n.relayUsedLock.Lock()
defer n.relayUsedLock.Unlock()
// No need to migrate any relays, delete usage info now.
for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
delete(cm.relayUsed, idx)
delete(n.relayUsed, idx)
}
}
}
func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
relayFor := oldhostinfo.relayState.CopyAllRelayFor()
for _, r := range relayFor {
@@ -224,51 +227,46 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
var relayFrom netip.Addr
var relayTo netip.Addr
switch {
case ok:
switch existing.State {
case Established, PeerRequested, Disestablished:
// This relay already exists in newhostinfo, then do nothing.
continue
case Requested:
// The relay exists in a Requested state; re-send the request
index = existing.LocalIndex
switch r.Type {
case TerminalType:
relayFrom = cm.intf.myVpnAddrs[0]
relayTo = existing.PeerAddr
case ForwardingType:
relayFrom = existing.PeerAddr
relayTo = newhostinfo.vpnAddrs[0]
default:
// should never happen
panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
}
case ok && existing.State == Established:
// This relay already exists in newhostinfo, then do nothing.
continue
case ok && existing.State == Requested:
// The relay exists in a Requested state; re-send the request
index = existing.LocalIndex
switch r.Type {
case TerminalType:
relayFrom = n.intf.myVpnAddrs[0]
relayTo = existing.PeerAddr
case ForwardingType:
relayFrom = existing.PeerAddr
relayTo = newhostinfo.vpnAddrs[0]
default:
// should never happen
}
case !ok:
cm.relayUsedLock.RLock()
if _, relayUsed := cm.relayUsed[r.LocalIndex]; !relayUsed {
n.relayUsedLock.RLock()
if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {
// The relay hasn't been used; don't migrate it.
cm.relayUsedLock.RUnlock()
n.relayUsedLock.RUnlock()
continue
}
cm.relayUsedLock.RUnlock()
n.relayUsedLock.RUnlock()
// The relay doesn't exist at all; create some relay state and send the request.
var err error
index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
if err != nil {
cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
continue
}
switch r.Type {
case TerminalType:
relayFrom = cm.intf.myVpnAddrs[0]
relayFrom = n.intf.myVpnAddrs[0]
relayTo = r.PeerAddr
case ForwardingType:
relayFrom = r.PeerAddr
relayTo = newhostinfo.vpnAddrs[0]
default:
// should never happen
panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
}
}
@@ -281,12 +279,12 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
switch newhostinfo.GetCert().Certificate.Version() {
case cert.Version1:
if !relayFrom.Is4() {
cm.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !relayTo.Is4() {
cm.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
@@ -298,16 +296,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
req.RelayToAddr = netAddrToProtoAddr(relayTo)
default:
newhostinfo.logger(cm.l).Error("Unknown certificate version found while attempting to migrate relay")
newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
continue
}
msg, err := req.Marshal()
if err != nil {
cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
} else {
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
cm.l.WithFields(logrus.Fields{
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
n.l.WithFields(logrus.Fields{
"relayFrom": req.RelayFromAddr,
"relayTo": req.RelayToAddr,
"initiatorRelayIndex": req.InitiatorRelayIndex,
@@ -318,45 +316,46 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
}
}
func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
// Read lock the main hostmap to order decisions based on tunnels being the primary tunnel
cm.hostMap.RLock()
defer cm.hostMap.RUnlock()
func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
n.hostMap.RLock()
defer n.hostMap.RUnlock()
hostinfo := cm.hostMap.Indexes[localIndex]
hostinfo := n.hostMap.Indexes[localIndex]
if hostinfo == nil {
cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
delete(n.pendingDeletion, localIndex)
return doNothing, nil, nil
}
if cm.isInvalidCertificate(now, hostinfo) {
if n.isInvalidCertificate(now, hostinfo) {
delete(n.pendingDeletion, hostinfo.localIndexId)
return closeTunnel, hostinfo, nil
}
primary := cm.hostMap.Hosts[hostinfo.vpnAddrs[0]]
primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
mainHostInfo := true
if primary != nil && primary != hostinfo {
mainHostInfo = false
}
// Check for traffic on this hostinfo
inTraffic, outTraffic := cm.getAndResetTrafficCheck(hostinfo, now)
inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)
// A hostinfo is determined alive if there is incoming traffic
if inTraffic {
decision := doNothing
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).
if n.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
Debug("Tunnel status")
}
hostinfo.pendingDeletion.Store(false)
delete(n.pendingDeletion, hostinfo.localIndexId)
if mainHostInfo {
decision = tryRehandshake
} else {
if cm.shouldSwapPrimary(hostinfo) {
if n.shouldSwapPrimary(hostinfo, primary) {
decision = swapPrimary
} else {
// migrate the relays to the primary, if in use.
@@ -364,55 +363,46 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
}
}
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
if !outTraffic {
// Send a punch packet to keep the NAT state alive
cm.sendPunch(hostinfo)
n.sendPunch(hostinfo)
}
return decision, hostinfo, primary
}
if hostinfo.pendingDeletion.Load() {
if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
// We have already sent a test packet and nothing was returned, this hostinfo is dead
hostinfo.logger(cm.l).
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
Info("Tunnel status")
delete(n.pendingDeletion, hostinfo.localIndexId)
return deleteTunnel, hostinfo, nil
}
decision := doNothing
if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
if !outTraffic {
inactiveFor, isInactive := cm.isInactive(hostinfo, now)
if isInactive {
// Tunnel is inactive, tear it down
hostinfo.logger(cm.l).
WithField("inactiveDuration", inactiveFor).
WithField("primary", mainHostInfo).
Info("Dropping tunnel due to inactivity")
return closeTunnel, hostinfo, primary
}
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
// Just maintain NAT state if configured to do so.
cm.sendPunch(hostinfo)
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
n.sendPunch(hostinfo)
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
return doNothing, nil, nil
}
if cm.punchy.GetTargetEverything() {
if n.punchy.GetTargetEverything() {
// This is similar to the old punchy behavior with a slight optimization.
// We aren't receiving traffic but we are sending it, punch on all known
// ips in case we need to re-prime NAT state
cm.sendPunch(hostinfo)
n.sendPunch(hostinfo)
}
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).
if n.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status")
}
@@ -421,33 +411,17 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
decision = sendTestPacket
} else {
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
if n.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l).Debugf("Hostinfo sadness")
}
}
hostinfo.pendingDeletion.Store(true)
cm.trafficTimer.Add(hostinfo.localIndexId, cm.pendingDeletionInterval)
n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
return decision, hostinfo, nil
}
func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time.Duration, bool) {
if cm.dropInactive.Load() == false {
// We aren't configured to drop inactive tunnels
return 0, false
}
inactiveDuration := now.Sub(hostinfo.lastUsed)
if inactiveDuration < cm.getInactivityTimeout() {
// It's not considered inactive
return inactiveDuration, false
}
// The tunnel is inactive
return inactiveDuration, true
}
func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out.
@@ -455,90 +429,83 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
// Only one side should swap because if both swap then we may never resolve to a single tunnel.
// vpn addr is static across all tunnels for this host pair so lets
// use that to determine if we should consider swapping.
if current.vpnAddrs[0].Compare(cm.intf.myVpnAddrs[0]) < 0 {
if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
// Their primary vpn addr is less than mine. Do not swap.
return false
}
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
// settle down.
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
}
func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
cm.hostMap.Lock()
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
n.hostMap.Lock()
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
if cm.hostMap.Hosts[current.vpnAddrs[0]] == primary {
cm.hostMap.unlockedMakePrimary(current)
if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
n.hostMap.unlockedMakePrimary(current)
}
cm.hostMap.Unlock()
n.hostMap.Unlock()
}
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
// check and return true.
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
remoteCert := hostinfo.GetCert()
if remoteCert == nil {
return false
}
caPool := cm.intf.pki.GetCAPool()
caPool := n.intf.pki.GetCAPool()
err := caPool.VerifyCachedCertificate(now, remoteCert)
if err == nil {
return false
}
if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
// Block listed certificates should always be disconnected
return false
}
hostinfo.logger(cm.l).WithError(err).
hostinfo.logger(n.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel")
return true
}
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
if !cm.punchy.GetPunch() {
func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
if !n.punchy.GetPunch() {
// Punching is disabled
return
}
if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
// Do not punch to lighthouses, we assume our lighthouse update interval is good enough.
// In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse
// would lose the ability to notify us and punchy.respond would become unreliable.
return
}
if cm.punchy.GetTargetEverything() {
hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
cm.metricsTxPunchy.Inc(1)
cm.intf.outside.WriteTo([]byte{1}, addr)
if n.punchy.GetTargetEverything() {
hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
n.metricsTxPunchy.Inc(1)
n.intf.outside.WriteTo([]byte{1}, addr)
})
} else if hostinfo.remote.IsValid() {
cm.metricsTxPunchy.Inc(1)
cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
n.metricsTxPunchy.Inc(1)
n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
}
}
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
cs := cm.intf.pki.getCertState()
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
cs := n.intf.pki.getCertState()
curCrt := hostinfo.ConnectionState.myCert
myCrt := cs.getCertificate(curCrt.Version())
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
// The current tunnel is using the latest certificate and version, no need to rehandshake.
return
}
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
}

View File

@@ -1,6 +1,7 @@
package nebula
import (
"context"
"crypto/ed25519"
"crypto/rand"
"net/netip"
@@ -22,7 +23,7 @@ func newTestLighthouse() *LightHouse {
addrMap: map[netip.Addr]*RemoteList{},
queryChan: make(chan netip.Addr, 10),
}
lighthouses := []netip.Addr{}
lighthouses := map[netip.Addr]struct{}{}
staticList := map[netip.Addr]struct{}{}
lh.lighthouses.Store(&lighthouses)
@@ -43,10 +44,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{
initiatingVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
defaultVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
}
lh := newTestLighthouse()
@@ -63,10 +64,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
nc.intf = ifce
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
punchy := NewPunchyFromConfig(l, config.NewC(l))
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
p := []byte("")
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
@@ -84,33 +85,32 @@ func Test_NewConnectionManagerTest(t *testing.T) {
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// We saw traffic out to vpnIp
nc.Out(hostinfo)
nc.In(hostinfo)
assert.False(t, hostinfo.pendingDeletion.Load())
nc.Out(hostinfo.localIndexId)
nc.In(hostinfo.localIndexId)
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.True(t, hostinfo.out.Load())
assert.True(t, hostinfo.in.Load())
assert.Contains(t, nc.out, hostinfo.localIndexId)
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
// Do another traffic check tick, this host should be pending deletion now
nc.Out(hostinfo)
assert.True(t, hostinfo.out.Load())
nc.Out(hostinfo.localIndexId)
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.True(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// Do a final traffic check tick, the host should now be removed
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs)
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
}
@@ -126,10 +126,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{
initiatingVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
defaultVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
}
lh := newTestLighthouse()
@@ -146,10 +146,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
nc.intf = ifce
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
punchy := NewPunchyFromConfig(l, config.NewC(l))
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
p := []byte("")
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
@@ -167,129 +167,33 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// We saw traffic out to vpnIp
nc.Out(hostinfo)
nc.In(hostinfo)
assert.True(t, hostinfo.in.Load())
assert.True(t, hostinfo.out.Load())
assert.False(t, hostinfo.pendingDeletion.Load())
nc.Out(hostinfo.localIndexId)
nc.In(hostinfo.localIndexId)
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
// Do another traffic check tick, this host should be pending deletion now
nc.Out(hostinfo)
nc.Out(hostinfo.localIndexId)
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.True(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// We saw traffic, should no longer be pending deletion
nc.In(hostinfo)
nc.In(hostinfo.localIndexId)
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
}
func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
l := test.NewLogger()
localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnAddrs := []netip.Addr{netip.MustParseAddr("172.1.1.2")}
preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects
hostMap := newHostMap(l)
hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{
initiatingVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
}
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
}
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
conf.Settings["tunnels"] = map[string]any{
"drop_inactive": true,
}
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
assert.True(t, nc.dropInactive.Load())
nc.intf = ifce
// Add an ip we have established a connection w/ to hostmap
hostinfo := &HostInfo{
vpnAddrs: vpnAddrs,
localIndexId: 1099,
remoteIndexId: 9901,
}
hostinfo.ConnectionState = &ConnectionState{
myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// Do a traffic check tick, in and out should be cleared but should not be pending deletion
nc.Out(hostinfo)
nc.In(hostinfo)
assert.True(t, hostinfo.out.Load())
assert.True(t, hostinfo.in.Load())
now := time.Now()
decision, _, _ := nc.makeTrafficDecision(hostinfo.localIndexId, now)
assert.Equal(t, tryRehandshake, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*5))
assert.Equal(t, doNothing, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
// Do another traffic check tick, should still not be pending deletion
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*10))
assert.Equal(t, doNothing, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// Finally advance beyond the inactivity timeout
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Minute*10))
assert.Equal(t, closeTunnel, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
}
@@ -360,10 +264,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
ifce.disconnectInvalid.Store(true)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
nc.intf = ifce
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
punchy := NewPunchyFromConfig(l, config.NewC(l))
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
ifce.connectionManager = nc
hostinfo := &HostInfo{
@@ -446,10 +350,6 @@ func (d *dummyCert) PublicKey() []byte {
return d.publicKey
}
func (d *dummyCert) MarshalPublicKeyPEM() []byte {
return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey)
}
func (d *dummyCert) Signature() []byte {
return d.signature
}

View File

@@ -13,8 +13,6 @@ import (
"github.com/slackhq/nebula/noiseutil"
)
// TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets.
// 4092 should be sufficient for 5Gbps
const ReplayWindow = 1024
type ConnectionState struct {
@@ -29,7 +27,7 @@ type ConnectionState struct {
writeLock sync.Mutex
}
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern, psk []byte) (*ConnectionState, error) {
var dhFunc noise.DHFunc
switch crt.Curve() {
case cert.Curve_CURVE25519:
@@ -58,13 +56,12 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
b.Update(l, 0)
hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: ncs,
Random: rand.Reader,
Pattern: pattern,
Initiator: initiator,
StaticKeypair: static,
//NOTE: These should come from CertState (pki.go) when we finally implement it
PresharedKey: []byte{},
CipherSuite: ncs,
Random: rand.Reader,
Pattern: pattern,
Initiator: initiator,
StaticKeypair: static,
PresharedKey: psk,
PresharedKeyPlacement: 0,
})
if err != nil {

View File

@@ -2,11 +2,9 @@ package nebula
import (
"context"
"errors"
"net/netip"
"os"
"os/signal"
"sync"
"syscall"
"github.com/sirupsen/logrus"
@@ -15,16 +13,6 @@ import (
"github.com/slackhq/nebula/overlay"
)
type RunState int
const (
Stopped RunState = 0 // The control has yet to be started
Started RunState = 1 // The control has been started
Stopping RunState = 2 // The control is stopping
)
var ErrAlreadyStarted = errors.New("nebula is already started")
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
@@ -38,18 +26,14 @@ type controlHostLister interface {
}
type Control struct {
stateLock sync.Mutex
state RunState
f *Interface
l *logrus.Logger
ctx context.Context
cancel context.CancelFunc
sshStart func()
statsStart func()
dnsStart func()
lighthouseStart func()
connectionManagerStart func(context.Context)
f *Interface
l *logrus.Logger
ctx context.Context
cancel context.CancelFunc
sshStart func()
statsStart func()
dnsStart func()
lighthouseStart func()
}
type ControlHostInfo struct {
@@ -64,21 +48,10 @@ type ControlHostInfo struct {
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
}
// Start actually runs nebula, this is a nonblocking call.
// The returned function can be used to wait for nebula to fully stop.
func (c *Control) Start() (func(), error) {
c.stateLock.Lock()
if c.state != Stopped {
c.stateLock.Unlock()
return nil, ErrAlreadyStarted
}
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
func (c *Control) Start() {
// Activate the interface
err := c.f.activate()
if err != nil {
c.stateLock.Unlock()
return nil, err
}
c.f.activate()
// Call all the delayed funcs that waited patiently for the interface to be created.
if c.sshStart != nil {
@@ -90,41 +63,20 @@ func (c *Control) Start() (func(), error) {
if c.dnsStart != nil {
go c.dnsStart()
}
if c.connectionManagerStart != nil {
go c.connectionManagerStart(c.ctx)
}
if c.lighthouseStart != nil {
c.lighthouseStart()
}
// Start reading packets.
c.state = Started
c.stateLock.Unlock()
return c.f.run(c.ctx)
}
func (c *Control) State() RunState {
c.stateLock.Lock()
defer c.stateLock.Unlock()
return c.state
c.f.run()
}
func (c *Control) Context() context.Context {
return c.ctx
}
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
func (c *Control) Stop() {
c.stateLock.Lock()
if c.state != Started {
c.stateLock.Unlock()
// We are stopping or stopped already
return
}
c.state = Stopping
c.stateLock.Unlock()
// Stop the handshakeManager (and other services), to prevent new tunnels from
// being created while we're shutting them all down.
c.cancel()
@@ -133,7 +85,7 @@ func (c *Control) Stop() {
if err := c.f.Close(); err != nil {
c.l.WithError(err).Error("Close interface failed")
}
c.state = Stopped
c.l.Info("Goodbye")
}
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
@@ -179,7 +131,8 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
if c.f.myVpnAddrsTable.Contains(vpnIp) {
_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
if found {
// Only returning the default certificate since its impossible
// for any other host but ourselves to have more than 1
return c.f.pki.getCertState().GetDefaultCertificate().Copy()

View File

@@ -53,7 +53,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
localIndexId: 201,
vpnAddrs: []netip.Addr{vpnIp},
relayState: RelayState{
relays: nil,
relays: map[netip.Addr]struct{}{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
@@ -72,7 +72,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
localIndexId: 201,
vpnAddrs: []netip.Addr{vpnIp2},
relayState: RelayState{
relays: nil,
relays: map[netip.Addr]struct{}{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
@@ -101,7 +101,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
// Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
assert.Equal(t, &expectedInfo, thi)
assert.EqualValues(t, &expectedInfo, thi)
test.AssertDeepCopyEqual(t, &expectedInfo, thi)
// Make sure we don't panic if the host info doesn't have a cert yet
@@ -110,7 +110,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
})
}
func assertFields(t *testing.T, expected []string, actualStruct any) {
func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
val := reflect.ValueOf(actualStruct).Elem()
fields := make([]string, val.NumField())
for i := 0; i < val.NumField(); i++ {

View File

@@ -26,7 +26,7 @@ type dnsRecords struct {
dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr
hostMap *HostMap
myVpnAddrsTable *bart.Lite
myVpnAddrsTable *bart.Table[struct{}]
}
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
@@ -112,8 +112,8 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
return true
}
//if we found it in this table, it's good
return d.myVpnAddrsTable.Contains(b)
_, found := d.myVpnAddrsTable.Lookup(b)
return found //if we found it in this table, it's good
}
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {

View File

@@ -38,24 +38,24 @@ func TestParsequery(t *testing.T) {
func Test_getDnsServerAddr(t *testing.T) {
c := config.NewC(nil)
c.Settings["lighthouse"] = map[string]any{
"dns": map[string]any{
c.Settings["lighthouse"] = map[interface{}]interface{}{
"dns": map[interface{}]interface{}{
"host": "0.0.0.0",
"port": "1",
},
}
assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c))
c.Settings["lighthouse"] = map[string]any{
"dns": map[string]any{
c.Settings["lighthouse"] = map[interface{}]interface{}{
"dns": map[interface{}]interface{}{
"host": "::",
"port": "1",
},
}
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
c.Settings["lighthouse"] = map[string]any{
"dns": map[string]any{
c.Settings["lighthouse"] = map[interface{}]interface{}{
"dns": map[interface{}]interface{}{
"host": "[::]",
"port": "1",
},
@@ -63,8 +63,8 @@ func Test_getDnsServerAddr(t *testing.T) {
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
// Make sure whitespace doesn't mess us up
c.Settings["lighthouse"] = map[string]any{
"dns": map[string]any{
c.Settings["lighthouse"] = map[interface{}]interface{}{
"dns": map[interface{}]interface{}{
"host": "[::] ",
"port": "1",
},

View File

@@ -20,7 +20,7 @@ import (
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
"gopkg.in/yaml.v2"
)
func BenchmarkHotPath(b *testing.B) {
@@ -506,7 +506,7 @@ func TestReestablishRelays(t *testing.T) {
curIndexes := len(myControl.GetHostmap().Indexes)
for curIndexes >= start {
curIndexes = len(myControl.GetHostmap().Indexes)
r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes)
r.Logf("Wait for the dead index to go away:start=%v indexes, currnet=%v indexes", start, curIndexes)
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@@ -991,7 +991,7 @@ func TestRehandshaking(t *testing.T) {
require.NoError(t, err)
var theirNewConfig m
require.NoError(t, yaml.Unmarshal(rc, &theirNewConfig))
theirFirewall := theirNewConfig["firewall"].(map[string]any)
theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{})
theirFirewall["inbound"] = []m{{
"proto": "any",
"port": "any",
@@ -1052,9 +1052,6 @@ func TestRehandshakingLoser(t *testing.T) {
t.Log("Stand up a tunnel between me and them")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
r.Log("Renew their certificate and spin until mine sees it")
@@ -1090,7 +1087,7 @@ func TestRehandshakingLoser(t *testing.T) {
require.NoError(t, err)
var myNewConfig m
require.NoError(t, yaml.Unmarshal(rc, &myNewConfig))
theirFirewall := myNewConfig["firewall"].(map[string]any)
theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{})
theirFirewall["inbound"] = []m{{
"proto": "any",
"port": "any",
@@ -1227,3 +1224,135 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
myControl.Stop()
theirControl.Stop()
}
func TestPSK(t *testing.T) {
tests := []struct {
name string
myPskMode nebula.PskMode
theirPskMode nebula.PskMode
}{
// All accepting
{
name: "both accepting",
myPskMode: nebula.PskAccepting,
theirPskMode: nebula.PskAccepting,
},
// accepting and sending both ways
{
name: "accepting to sending",
myPskMode: nebula.PskAccepting,
theirPskMode: nebula.PskSending,
},
{
name: "sending to accepting",
myPskMode: nebula.PskSending,
theirPskMode: nebula.PskAccepting,
},
// All sending
{
name: "sending to sending",
myPskMode: nebula.PskSending,
theirPskMode: nebula.PskSending,
},
// enforced and sending both ways
{
name: "enforced to sending",
myPskMode: nebula.PskEnforced,
theirPskMode: nebula.PskSending,
},
{
name: "sending to enforced",
myPskMode: nebula.PskSending,
theirPskMode: nebula.PskEnforced,
},
// All enforced
{
name: "both enforced",
myPskMode: nebula.PskEnforced,
theirPskMode: nebula.PskEnforced,
},
// Enforced can technically handshake with an accepting node, but it is bad to be in this state
{
name: "enforced to accepting",
myPskMode: nebula.PskEnforced,
theirPskMode: nebula.PskAccepting,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var myPskSettings, theirPskSettings m
switch test.myPskMode {
case nebula.PskAccepting:
myPskSettings = m{"psk": &m{"mode": "accepting", "keys": []string{"garbage0", "this is a key"}}}
case nebula.PskSending:
myPskSettings = m{"psk": &m{"mode": "sending", "keys": []string{"this is a key", "garbage1"}}}
case nebula.PskEnforced:
myPskSettings = m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key", "garbage2"}}}
}
switch test.theirPskMode {
case nebula.PskAccepting:
theirPskSettings = m{"psk": &m{"mode": "accepting", "keys": []string{"garbage3", "this is a key"}}}
case nebula.PskSending:
theirPskSettings = m{"psk": &m{"mode": "sending", "keys": []string{"this is a key", "garbage4"}}}
case nebula.PskEnforced:
theirPskSettings = m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key", "garbage5"}}}
}
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
myControl, myVpnIp, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.0.0.1/24", myPskSettings)
theirControl, theirVpnIp, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.0.0.2/24", theirPskSettings)
myControl.InjectLightHouseAddr(theirVpnIp[0].Addr(), theirUdpAddr)
r := router.NewR(t, myControl, theirControl)
// Start the servers
myControl.Start()
theirControl.Start()
t.Log("Route until we see our cached packet flow")
myControl.InjectTunUDPPacket(theirVpnIp[0].Addr(), 80, myVpnIp[0].Addr(), 80, []byte("Hi from me"))
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
h := &header.H{}
err := h.Parse(p.Data)
if err != nil {
panic(err)
}
// If this is the stage 1 handshake packet and I am configured to send with a psk, my cert name should
// not appear. It would likely be more obvious to unmarshal the payload and check but this works fine for now
if test.myPskMode == nebula.PskEnforced || test.myPskMode == nebula.PskSending {
if h.Type == 0 && h.MessageCounter == 1 {
assert.NotContains(t, string(p.Data), "test me")
}
}
if p.To == theirUdpAddr && h.Type == 1 {
return router.RouteAndExit
}
return router.KeepRouting
})
t.Log("My cached packet should be received by them")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp[0].Addr(), theirVpnIp[0].Addr(), 80, 80)
t.Log("Test the tunnel with them")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
assertTunnel(t, myVpnIp[0].Addr(), theirVpnIp[0].Addr(), myControl, theirControl, r)
myControl.Stop()
theirControl.Stop()
//TODO: assert hostmaps
})
}
}

View File

@@ -22,10 +22,10 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
"gopkg.in/yaml.v2"
)
type m = map[string]any
type m map[string]interface{}
// newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {

View File

@@ -111,10 +111,6 @@ type ExitFunc func(packet *udp.Packet, receiver *nebula.Control) ExitType
func NewR(t testing.TB, controls ...*nebula.Control) *R {
ctx, cancel := context.WithCancel(context.Background())
if err := os.MkdirAll("mermaid", 0755); err != nil {
panic(err)
}
r := &R{
controls: make(map[netip.AddrPort]*nebula.Control),
vpnControls: make(map[netip.Addr]*nebula.Control),
@@ -194,6 +190,9 @@ func (r *R) renderFlow() {
return
}
if err := os.MkdirAll(filepath.Dir(r.fn), 0755); err != nil {
panic(err)
}
f, err := os.OpenFile(r.fn, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0644)
if err != nil {
panic(err)
@@ -700,7 +699,6 @@ func (r *R) FlushAll() {
r.Unlock()
panic("Can't FlushAll for host: " + p.To.String())
}
receiver.InjectUDPPacket(p)
r.Unlock()
}
}

View File

@@ -1,57 +0,0 @@
//go:build e2e_testing
// +build e2e_testing
package e2e
import (
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router"
)
func TestDropInactiveTunnels(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
r.Log("Assert the tunnel between me and them works")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("Go inactive and wait for the tunnels to get dropped")
waitStart := time.Now()
for {
myIndexes := len(myControl.GetHostmap().Indexes)
theirIndexes := len(theirControl.GetHostmap().Indexes)
if myIndexes == 0 && theirIndexes == 0 {
break
}
since := time.Since(waitStart)
r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
if since > time.Second*30 {
t.Fatal("Tunnel should have been declared inactive after 5 seconds and before 30 seconds")
}
time.Sleep(1 * time.Second)
r.FlushAll()
}
r.Logf("Inactive tunnels were dropped within %v", time.Since(waitStart))
myControl.Stop()
theirControl.Stop()
}

View File

@@ -13,11 +13,43 @@ pki:
# disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
#disconnect_invalid: true
# initiating_version controls which certificate version is used when initiating handshakes.
# default_version controls which certificate version is used in handshakes.
# This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
# Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
# After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
# initiating_version: 1
# default_version: 1
# psk can be used to mask the contents of handshakes.
psk:
# `mode` defines how the pre shared keys can be used in a handshake.
# `accepting` (the default) will initiate handshakes using an empty key and will try to use any keys provided when
# receiving handshakes, including an empty key.
# `sending` will initiate handshakes with the first key provided and will try to use any keys provided when
# receiving handshakes, including an empty key.
# `enforced` will initiate handshakes with the first psk key provided and will try to use any keys provided when
# responding to handshakes. An empty key will not be allowed.
#
# To change a mesh from not using a psk to enforcing psk:
# 1. Leave `mode` as `accepting` and configure `psk.keys` to match on all nodes in the mesh and reload.
# 2. Change `mode` to `sending` on all nodes in the mesh and reload.
# 3. Change `mode` to `enforced` on all nodes in the mesh and reload.
#mode: accepting
# The keys provided are sent through hkdf to ensure the shared secret used in the noise protocol is the
# correct byte length.
#
# Only the first key is used for outbound handshakes but all keys provided will be tried in the order specified, on
# incoming handshakes. This is to allow for psk rotation.
#
# To rotate a primary key:
# 1. Put the new key in the 2nd slot on every node in the mesh and reload.
# 2. Move the key from the 2nd slot to the 1st slot, the old primary key is now in the 2nd slot, reload.
# 3. Remove the old primary key once it is no longer in use on every node in the mesh and reload.
#keys:
# - shared secret string, this one is used in all outbound handshakes # This is the primary key used when sending handshakes
# - this is a fallback key, received handshakes can use this
# - another fallback, received handshakes can use this one too
# - "\x68\x65\x6c\x6c\x6f\x20\x66\x72\x69\x65\x6e\x64\x73" # for raw bytes if you desire
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
@@ -239,28 +271,7 @@ tun:
# Unsafe routes allows you to route traffic over nebula to non-nebula nodes
# Unsafe routes should be avoided unless you have hosts/services that cannot run nebula
# Supports weighted ECMP if you define a list of gateways, this can be used for load balancing or redundancy to hosts outside of nebula
# NOTES:
# * You will only see a single gateway in the routing table if you are not on linux
# * If a gateway is not reachable through the overlay another gateway will be selected to send the traffic through, ignoring weights
#
# unsafe_routes:
# # Multiple gateways without defining a weight defaults to a weight of 1, this will balance traffic equally between the three gateways
# - route: 192.168.87.0/24
# via:
# - gateway: 10.0.0.1
# - gateway: 10.0.0.2
# - gateway: 10.0.0.3
# # Multiple gateways with a weight, this will balance traffic accordingly
# - route: 192.168.87.0/24
# via:
# - gateway: 10.0.0.1
# weight: 10
# - gateway: 10.0.0.2
# weight: 5
#
# NOTE: The nebula certificate of the "via" node(s) *MUST* have the "route" defined as a subnet in its certificate
# `via`: single node or list of gateways to use for this route
# NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate
# `mtu`: will default to tun mtu if this option is not specified
# `metric`: will default to 0 if this option is not specified
# `install`: will default to true, controls whether this route is installed in the systems routing table.
@@ -275,10 +286,6 @@ tun:
# On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of
# in nebula configuration files. Default false, not reloadable.
#use_system_route_table: false
# Buffer size for reading routes updates. 0 means default system buffer size. (/proc/sys/net/core/rmem_default).
# If using massive routes updates, for example BGP, you may need to increase this value to avoid packet loss.
# SO_RCVBUFFORCE is used to avoid having to raise the system wide max
#use_system_route_table_buffer_size: 0
# Configure logging level
logging:
@@ -338,19 +345,6 @@ logging:
# after receiving the response for lighthouse queries
#trigger_buffer: 64
# Tunnel manager settings
#tunnels:
# drop_inactive controls whether inactive tunnels are maintained or dropped after the inactive_timeout period has
# elapsed.
# In general, it is a good idea to enable this setting. It will be enabled by default in a future release.
# This setting is reloadable
#drop_inactive: false
# inactivity_timeout controls how long a tunnel MUST NOT see any inbound or outbound traffic before being considered
# inactive and eligible to be dropped.
# This setting is reloadable
#inactivity_timeout: 10m
# Nebula security group configuration
firewall:
# Action to take when a packet is not allowed by the firewall rules.
@@ -362,11 +356,11 @@ firewall:
outbound_action: drop
inbound_action: drop
# THIS FLAG IS DEPRECATED AND WILL BE REMOVED IN A FUTURE RELEASE. (Defaults to false.)
# This setting only affects nebula hosts exposing unsafe_routes. When set to false, each inbound rule must contain a
# `local_cidr` if the intention is to allow traffic to flow to an unsafe route. When set to true, every firewall rule
# will apply to all configured unsafe_routes regardless of the actual destination of the packet, unless `local_cidr`
# is explicitly defined. This is usually not the desired behavior and should be avoided!
# Controls the default value for local_cidr. Default is true, will be deprecated after v1.9 and defaulted to false.
# This setting only affects nebula hosts with subnets encoded in their certificate. A nebula host acting as an
# unsafe router with `default_local_cidr_any: true` will expose their unsafe routes to every inbound rule regardless
# of the actual destination for the packet. Setting this to false requires each inbound rule to contain a `local_cidr`
# if the intention is to allow traffic to flow to an unsafe route.
#default_local_cidr_any: false
conntrack:
@@ -384,9 +378,11 @@ firewall:
# group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes.
# By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
# If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes.
# If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network.
# Otherwise the default is any vpn network assigned to via the certificate.
# `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release.
# If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation.
# ca_name: An issuing CA name
# ca_sha: An issuing CA shasum

View File

@@ -5,12 +5,8 @@ import (
"fmt"
"log"
"net"
"os"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/service"
)
@@ -63,16 +59,7 @@ pki:
if err := cfg.LoadString(configStr); err != nil {
return err
}
logger := logrus.New()
logger.Out = os.Stdout
ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil {
return err
}
svc, err := service.New(ctrl)
svc, err := service.New(&cfg)
if err != nil {
return err
}

View File

@@ -53,7 +53,7 @@ type Firewall struct {
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
// The vpn addresses are a full bit match while the unsafe networks only match the prefix
routableNetworks *bart.Lite
routableNetworks *bart.Table[struct{}]
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
assignedNetworks []netip.Prefix
@@ -125,7 +125,7 @@ type firewallPort map[int32]*FirewallCA
type firewallLocalCIDR struct {
Any bool
LocalCIDR *bart.Lite
LocalCIDR *bart.Table[struct{}]
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
@@ -148,17 +148,17 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
tmax = defaultTimeout
}
routableNetworks := new(bart.Lite)
routableNetworks := new(bart.Table[struct{}])
var assignedNetworks []netip.Prefix
for _, network := range c.Networks() {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
routableNetworks.Insert(nprefix)
routableNetworks.Insert(nprefix, struct{}{})
assignedNetworks = append(assignedNetworks, network)
}
hasUnsafeNetworks := false
for _, n := range c.UnsafeNetworks() {
routableNetworks.Insert(n)
routableNetworks.Insert(n, struct{}{})
hasUnsafeNetworks = true
}
@@ -331,7 +331,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return nil
}
rs, ok := r.([]any)
rs, ok := r.([]interface{})
if !ok {
return fmt.Errorf("%s failed to parse, should be an array of rules", table)
}
@@ -431,7 +431,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
// Make sure remote address matches nebula certificate
if h.networks != nil {
if !h.networks.Contains(fp.RemoteAddr) {
_, ok := h.networks.Lookup(fp.RemoteAddr)
if !ok {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
@@ -444,7 +445,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
}
// Make sure we are supposed to be handling this local ip address
if !f.routableNetworks.Contains(fp.LocalAddr) {
_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
if !ok {
f.metrics(incoming).droppedLocalAddr.Inc(1)
return ErrInvalidLocalIP
}
@@ -750,7 +752,7 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
flc := func() *firewallLocalCIDR {
return &firewallLocalCIDR{
LocalCIDR: new(bart.Lite),
LocalCIDR: new(bart.Table[struct{}]),
}
}
@@ -877,7 +879,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
}
for _, network := range f.assignedNetworks {
flc.LocalCIDR.Insert(network)
flc.LocalCIDR.Insert(network, struct{}{})
}
return nil
@@ -886,7 +888,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
return nil
}
flc.LocalCIDR.Insert(localIp)
flc.LocalCIDR.Insert(localIp, struct{}{})
return nil
}
@@ -899,7 +901,8 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate
return true
}
return flc.LocalCIDR.Contains(p.LocalAddr)
_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
return ok
}
type rule struct {
@@ -915,15 +918,15 @@ type rule struct {
CASha string
}
func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
r := rule{}
m, ok := p.(map[string]any)
m, ok := p.(map[interface{}]interface{})
if !ok {
return r, errors.New("could not parse rule")
}
toString := func(k string, m map[string]any) string {
toString := func(k string, m map[interface{}]interface{}) string {
v, ok := m[k]
if !ok {
return ""
@@ -941,7 +944,7 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
r.CASha = toString("ca_sha", m)
// Make sure group isn't an array
if v, ok := m["group"].([]any); ok {
if v, ok := m["group"].([]interface{}); ok {
if len(v) > 1 {
return r, errors.New("group should contain a single value, an array with more than one entry was provided")
}

View File

@@ -6,7 +6,7 @@ import (
"net/netip"
)
type m = map[string]any
type m map[string]interface{}
const (
ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever

View File

@@ -68,9 +68,6 @@ func TestFirewall_AddRule(t *testing.T) {
ti, err := netip.ParsePrefix("1.2.3.4/32")
require.NoError(t, err)
ti6, err := netip.ParsePrefix("fd12::34/128")
require.NoError(t, err)
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
// An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
@@ -95,24 +92,12 @@ func TestFirewall_AddRule(t *testing.T) {
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
@@ -132,13 +117,6 @@ func TestFirewall_AddRule(t *testing.T) {
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -221,82 +199,6 @@ func TestFirewall_Drop(t *testing.T) {
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func TestFirewall_DropV6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
c := dummyCert{
name: "host1",
networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &cert.CachedCertificate{
Certificate: &c,
InvertedGroups: map[string]struct{}{"default-group": {}},
},
},
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
}
h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("fd12::56")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func BenchmarkFirewallTable_match(b *testing.B) {
f := &Firewall{}
ft := FirewallTable{
@@ -306,10 +208,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
pfix := netip.MustParsePrefix("172.1.1.1/32")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
pfix6 := netip.MustParsePrefix("fd11::11/128")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) {
@@ -341,15 +239,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
}
})
b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{},
}
ip := netip.MustParsePrefix("fd99::99/128")
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
}
})
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
@@ -363,18 +252,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
}
})
b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
},
InvertedGroups: map[string]struct{}{"nope": {}},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
}
})
b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
@@ -388,18 +265,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
}
})
b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
networks: []netip.Prefix{netip.MustParsePrefix("fd99:99/128")},
},
InvertedGroups: map[string]struct{}{"nope": {}},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
}
})
b.Run("pass on group on any local cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
@@ -424,17 +289,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
}
})
b.Run("pass on group on specific local cidr6", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
},
InvertedGroups: map[string]struct{}{"good-group": {}},
}
for n := 0; n < b.N; n++ {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
}
})
b.Run("pass on name", func(b *testing.B) {
c := &cert.CachedCertificate{
@@ -593,42 +447,6 @@ func TestFirewall_Drop3(t *testing.T) {
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
}
func TestFirewall_Drop3V6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"),
LocalPort: 1,
RemotePort: 1,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
network := netip.MustParsePrefix("fd12::34/120")
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host-owner",
networks: []netip.Prefix{network},
},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
// Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func TestFirewall_DropConntrackReload(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
@@ -813,53 +631,53 @@ func TestNewFirewallFromConfig(t *testing.T) {
require.NoError(t, err)
conf := config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
// Test code/port error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test local_cidr parse error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test both group and groups
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
}
@@ -869,28 +687,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
// Test adding tcp rule
conf := config.NewC(l)
mf := &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding udp rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding icmp rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding any rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
@@ -898,64 +716,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
cidr := netip.MustParsePrefix("10.0.0.0/8")
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
// Test adding rule with cidr ipv6
cidr6 := netip.MustParsePrefix("fd00::/8")
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr ipv6
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
// Test adding rule with ca_sha
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
// Test single group
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test single groups
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test multiple AND groups
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
@@ -963,7 +766,7 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
conf = config.NewC(l)
mf = &mockFirewall{}
mf.nextCallReturn = errors.New("test error")
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
}
@@ -973,8 +776,8 @@ func TestFirewall_convertRule(t *testing.T) {
l.SetOutput(ob)
// Ensure group array of 1 is converted and a warning is printed
c := map[string]any{
"group": []any{"group1"},
c := map[interface{}]interface{}{
"group": []interface{}{"group1"},
}
r, err := convertRule(l, c, "test", 1)
@@ -984,17 +787,17 @@ func TestFirewall_convertRule(t *testing.T) {
// Ensure group array of > 1 is errord
ob.Reset()
c = map[string]any{
"group": []any{"group1", "group2"},
c = map[interface{}]interface{}{
"group": []interface{}{"group1", "group2"},
}
r, err = convertRule(l, c, "test", 1)
assert.Empty(t, ob.String())
assert.Equal(t, "", ob.String())
require.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
// Make sure a well formed group is alright
ob.Reset()
c = map[string]any{
c = map[interface{}]interface{}{
"group": "group1",
}

36
go.mod
View File

@@ -1,38 +1,40 @@
module github.com/slackhq/nebula
go 1.25
go 1.23.6
toolchain go1.23.7
require (
dario.cat/mergo v1.0.2
dario.cat/mergo v1.0.1
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.20.4
github.com/gaissmai/bart v0.18.1
github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.2
github.com/miekg/dns v1.1.65
github.com/miekg/dns v1.1.63
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
github.com/prometheus/client_golang v1.22.0
github.com/prometheus/client_golang v1.21.1
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.10.0
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.37.0
github.com/vishvananda/netlink v1.3.0
golang.org/x/crypto v0.36.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.39.0
golang.org/x/sync v0.13.0
golang.org/x/sys v0.32.0
golang.org/x/term v0.31.0
golang.org/x/net v0.37.0
golang.org/x/sync v0.12.0
golang.org/x/sys v0.31.0
golang.org/x/term v0.30.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/protobuf v1.36.6
gopkg.in/yaml.v3 v3.0.1
google.golang.org/protobuf v1.36.5
gopkg.in/yaml.v2 v2.4.0
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
)
@@ -41,13 +43,15 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/mod v0.23.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
golang.org/x/mod v0.18.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.30.0 // indirect
golang.org/x/tools v0.22.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

66
go.sum
View File

@@ -1,6 +1,6 @@
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.20.4 h1:Ik47r1fy3jRVU+1eYzKSW3ho2UgBVTVnUS8O993584U=
github.com/gaissmai/bart v0.20.4/go.mod h1:cEed+ge8dalcbpi8wtS9x9m2hn/fNJH5suhdGQOHnYk=
github.com/gaissmai/bart v0.18.1 h1:bX2j560JC1MJpoEDevBGvXL5OZ1mkls320Vl8Igb5QQ=
github.com/gaissmai/bart v0.18.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -53,8 +53,8 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
@@ -68,8 +68,8 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
@@ -83,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.1.65 h1:0+tIPHzUW0GCge7IiK3guGP57VAw7hoPDfApjkMD1Fc=
github.com/miekg/dns v1.1.65/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck=
github.com/miekg/dns v1.1.63 h1:8M5aAw6OMZfFXTT7K5V0Eu5YiiL8l7nUAkyN6C9YwaY=
github.com/miekg/dns v1.1.63/go.mod h1:6NGHfjhpmr5lt3XPLuyfDJi5AXbNIPM9PY6H6sF1Nfs=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -106,8 +106,8 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk=
github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@@ -145,10 +145,10 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
@@ -156,16 +156,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -176,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -185,8 +185,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -204,11 +204,11 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o=
golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -219,8 +219,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -239,8 +239,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -251,6 +251,8 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -25,7 +25,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
// If we're connecting to a v6 address we must use a v2 cert
cs := f.pki.getCertState()
v := cs.initiatingVersion
v := cs.defaultVersion
for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() {
v = cert.Version2
@@ -50,7 +50,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
Error("Unable to handshake with host because no certificate handshake bytes is available")
}
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX, cs.psk.primary)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
@@ -71,8 +71,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
hsBytes, err := hs.Marshal()
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("certVersion", v).
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return false
}
@@ -101,38 +100,57 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if crt == nil {
f.l.WithField("udpAddr", addr).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.initiatingVersion).
WithField("certVersion", cs.defaultVersion).
Error("Unable to handshake with host because no certificate is available")
}
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
var (
err error
ci *ConnectionState
msg []byte
)
hs := &NebulaHandshake{}
for _, psk := range cs.psk.keys {
ci, err = NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX, psk)
if err != nil {
//TODO: should be bother logging this, if we have multiple psks and the error is unrelated it will be verbose.
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state")
continue
}
msg, _, _, err = ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
// Calls to ReadMessage with an incorrect psk should fail, try the next one if we have one
continue
}
// Sometimes ReadMessage returns fine with a nil psk even if the handshake is using a psk, ensure our protobuf
// comes out clean as well
err = hs.Unmarshal(msg)
if err == nil {
// There was no error, we can continue with this handshake
break
}
// The unmarshal failed, try the next psk if we have one
}
// We finished with an error, log it and get out
if err != nil || hs.Details == nil {
// We aren't logging the error here because we can't be sure of the failure when using psk
f.l.WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state")
Error("Was unable to decrypt the handshake")
return
}
// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to call noise.ReadMessage")
return
}
hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed unmarshal handshake message")
return
}
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
@@ -186,16 +204,15 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
var vpnAddrs []netip.Addr
var filteredNetworks []netip.Prefix
certName := remoteCert.Certificate.Name()
certVersion := remoteCert.Certificate.Version()
fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer()
for _, network := range remoteCert.Certificate.Networks() {
vpnAddr := network.Addr()
if f.myVpnAddrsTable.Contains(vpnAddr) {
_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
if found {
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
@@ -203,7 +220,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
}
// vpnAddrs outside our vpn networks are of no use to us, filter them out
if !f.myVpnNetworksTable.Contains(vpnAddr) {
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
continue
}
@@ -214,7 +231,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
@@ -234,7 +250,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
@@ -249,7 +264,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
HandshakePacket: make(map[uint8][]byte, 0),
lastHandshakeTime: hs.Details.Time,
relayState: RelayState{
relays: nil,
relays: map[netip.Addr]struct{}{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
@@ -257,7 +272,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -269,7 +283,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if hs.Details.Cert == nil {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -287,7 +300,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
@@ -299,7 +311,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
@@ -307,7 +318,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
@@ -375,7 +385,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("oldHandshakeTime", existing.lastHandshakeTime).
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
WithField("fingerprint", fingerprint).
@@ -391,7 +400,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -404,7 +412,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
// And we forget to update it here
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -421,7 +428,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if err != nil {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -430,7 +436,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
} else {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -449,7 +454,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -457,9 +461,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
Info("Handshake message sent")
}
f.connectionManager.AddTrafficWatch(hostinfo)
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
hostinfo.remotes.ResetBlockedRemotes()
return
}
@@ -554,7 +558,6 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
vpnNetworks := remoteCert.Certificate.Networks()
certName := remoteCert.Certificate.Name()
certVersion := remoteCert.Certificate.Version()
fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer()
@@ -578,7 +581,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
for _, network := range vpnNetworks {
// vpnAddrs outside our vpn networks are of no use to us, filter them out
vpnAddr := network.Addr()
if !f.myVpnNetworksTable.Contains(vpnAddr) {
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
continue
}
@@ -589,7 +592,6 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
@@ -599,9 +601,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
// Ensure the right host responded
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("udpAddr", addr).WithField("certName", certName).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Incorrect host responded to handshake")
@@ -637,7 +637,6 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
duration := time.Since(hh.startTime).Nanoseconds()
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -652,7 +651,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
f.handshakeManager.Complete(hostinfo, f)
f.connectionManager.AddTrafficWatch(hostinfo)
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
@@ -667,7 +666,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
}
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
hostinfo.remotes.ResetBlockedRemotes()
f.metricHandshakes.Update(duration)
return false

View File

@@ -274,7 +274,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
}
// Don't relay through the host I'm trying to connect to
if hm.f.myVpnAddrsTable.Contains(relay) {
_, found := hm.f.myVpnAddrsTable.Lookup(relay)
if found {
continue
}
@@ -450,7 +451,7 @@ func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*Han
vpnAddrs: []netip.Addr{vpnAddr},
HandshakePacket: make(map[uint8][]byte, 0),
relayState: RelayState{
relays: nil,
relays: map[netip.Addr]struct{}{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},

View File

@@ -10,6 +10,7 @@ import (
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
@@ -23,11 +24,15 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
lh := newTestLighthouse()
psk, err := NewPsk(PskAccepting, nil)
require.NoError(t, err)
cs := &CertState{
initiatingVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
defaultVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
psk: psk,
}
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
@@ -98,5 +103,5 @@ func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
}
func (mw *mockEncWriter) GetCertState() *CertState {
return &CertState{initiatingVersion: cert.Version2}
return &CertState{defaultVersion: cert.Version2}
}

View File

@@ -19,7 +19,7 @@ import (
// |-----------------------------------------------------------------------|
// | payload... |
type m = map[string]any
type m map[string]interface{}
const (
Version uint8 = 1

View File

@@ -4,7 +4,6 @@ import (
"errors"
"net"
"net/netip"
"slices"
"sync"
"sync/atomic"
"time"
@@ -17,10 +16,12 @@ import (
"github.com/slackhq/nebula/header"
)
// const ProbeLen = 100
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse
const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
const MaxRemotes = 10
const maxRecvError = 4
// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
// 5 allows for an initial handshake and each host pair re-handshaking twice
@@ -67,7 +68,7 @@ type HostMap struct {
type RelayState struct {
sync.RWMutex
relays []netip.Addr // Ordered set of VpnAddrs of Hosts to use as relays to access this peer
relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
// For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
// modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
// the RelayState Lock held)
@@ -78,12 +79,7 @@ type RelayState struct {
func (rs *RelayState) DeleteRelay(ip netip.Addr) {
rs.Lock()
defer rs.Unlock()
for idx, val := range rs.relays {
if val == ip {
rs.relays = append(rs.relays[:idx], rs.relays[idx+1:]...)
return
}
}
delete(rs.relays, ip)
}
func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
@@ -128,16 +124,16 @@ func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
rs.Lock()
defer rs.Unlock()
if !slices.Contains(rs.relays, ip) {
rs.relays = append(rs.relays, ip)
}
rs.relays[ip] = struct{}{}
}
func (rs *RelayState) CopyRelayIps() []netip.Addr {
ret := make([]netip.Addr, len(rs.relays))
rs.RLock()
defer rs.RUnlock()
copy(ret, rs.relays)
ret := make([]netip.Addr, 0, len(rs.relays))
for ip := range rs.relays {
ret = append(ret, ip)
}
return ret
}
@@ -223,10 +219,11 @@ type HostInfo struct {
// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
// The host may have other vpn addresses that are outside our
// vpn networks but were removed because they are not usable
vpnAddrs []netip.Addr
vpnAddrs []netip.Addr
recvError atomic.Uint32
// networks are both all vpn and unsafe networks assigned to this host
networks *bart.Lite
networks *bart.Table[struct{}]
relayState RelayState
// HandshakePacket records the packets used to create this hostinfo
@@ -253,14 +250,6 @@ type HostInfo struct {
// Used to track other hostinfos for this vpn ip since only 1 can be primary
// Synchronised via hostmap lock and not the hostinfo lock.
next, prev *HostInfo
//TODO: in, out, and others might benefit from being an atomic.Int32. We could collapse connectionManager pendingDeletion, relayUsed, and in/out into this 1 thing
in, out, pendingDeletion atomic.Bool
// lastUsed tracks the last time ConnectionManager checked the tunnel and it was in use.
// This value will be behind against actual tunnel utilization in the hot path.
// This should only be used by the ConnectionManagers ticker routine.
lastUsed time.Time
}
type ViaSender struct {
@@ -730,19 +719,26 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
return false
}
func (i *HostInfo) RecvErrorExceeded() bool {
if i.recvError.Add(1) >= maxRecvError {
return true
}
return true
}
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
if len(networks) == 1 && len(unsafeNetworks) == 0 {
// Simple case, no CIDRTree needed
return
}
i.networks = new(bart.Lite)
i.networks = new(bart.Table[struct{}])
for _, network := range networks {
i.networks.Insert(network)
i.networks.Insert(network, struct{}{})
}
for _, network := range unsafeNetworks {
i.networks.Insert(network)
i.networks.Insert(network, struct{}{})
}
}

View File

@@ -7,7 +7,6 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHostMap_MakePrimary(t *testing.T) {
@@ -211,36 +210,8 @@ func TestHostMap_reload(t *testing.T) {
assert.Empty(t, hm.GetPreferredRanges())
c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]")
assert.Equal(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
}
func TestHostMap_RelayState(t *testing.T) {
h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
a1 := netip.MustParseAddr("::1")
a2 := netip.MustParseAddr("2001::1")
h1.relayState.InsertRelayTo(a1)
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
h1.relayState.InsertRelayTo(a2)
assert.Equal(t, []netip.Addr{a1, a2}, h1.relayState.relays)
// Ensure that the first relay added is the first one returned in the copy
currentRelays := h1.relayState.CopyRelayIps()
require.Len(t, currentRelays, 2)
assert.Equal(t, a1, currentRelays[0])
// Deleting the last one in the list works ok
h1.relayState.DeleteRelay(a2)
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
// Deleting an element not in the list works ok
h1.relayState.DeleteRelay(a2)
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
// Deleting the only element in the list works ok
h1.relayState.DeleteRelay(a1)
assert.Equal(t, []netip.Addr{}, h1.relayState.relays)
assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
}

106
inside.go
View File

@@ -8,7 +8,6 @@ import (
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/routing"
)
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
@@ -22,12 +21,14 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
// Ignore local broadcast packets
if f.dropLocalBroadcast {
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
if found {
return
}
}
if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
if found {
// Immediately forward packets from self to self.
// This should only happen on Darwin-based and FreeBSD hosts, which
// routes packets from the Nebula addr to the Nebula addr through the Nebula
@@ -48,7 +49,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return
}
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
@@ -120,93 +121,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
}
// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
func (f *Interface) Handshake(vpnAddr netip.Addr) {
f.getOrHandshakeNoRouting(vpnAddr, nil)
f.getOrHandshake(vpnAddr, nil)
}
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
// getOrHandshake returns nil if the vpnAddr is not routable.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
if f.myVpnNetworksTable.Contains(vpnAddr) {
return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
}
return nil, false
}
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
destinationAddr := fwPacket.RemoteAddr
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
// Host is inside the mesh, no routing required
if hostinfo != nil {
return hostinfo, ready
}
gateways := f.inside.RoutesFor(destinationAddr)
switch len(gateways) {
case 0:
return nil, false
case 1:
// Single gateway route
return f.handshakeManager.GetOrHandshake(gateways[0].Addr(), cacheCallback)
default:
// Multi gateway route, perform ECMP categorization
gatewayAddr, balancingOk := routing.BalancePacket(fwPacket, gateways)
if !balancingOk {
// This happens if the gateway buckets were not calculated, this _should_ never happen
f.l.Error("Gateway buckets not calculated, fallback from ECMP to random routing. Please report this bug.")
}
var handshakeInfoForChosenGateway *HandshakeHostInfo
var hhReceiver = func(hh *HandshakeHostInfo) {
handshakeInfoForChosenGateway = hh
}
// Store the handshakeHostInfo for later.
// If this node is not reachable we will attempt other nodes, if none are reachable we will
// cache the packet for this gateway.
if hostinfo, ready = f.handshakeManager.GetOrHandshake(gatewayAddr, hhReceiver); ready {
return hostinfo, true
}
// It appears the selected gateway cannot be reached, find another gateway to fallback on.
// The current implementation breaks ECMP but that seems better than no connectivity.
// If ECMP is also required when a gateway is down then connectivity status
// for each gateway needs to be kept and the weights recalculated when they go up or down.
// This would also need to interact with unsafe_route updates through reloading the config or
// use of the use_system_route_table option
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("destination", destinationAddr).
WithField("originalGateway", gatewayAddr).
Debugln("Calculated gateway for ECMP not available, attempting other gateways")
}
for i := range gateways {
// Skip the gateway that failed previously
if gateways[i].Addr() == gatewayAddr {
continue
}
// We do not need the HandshakeHostInfo since we cache the packet in the originally chosen gateway
if hostinfo, ready = f.handshakeManager.GetOrHandshake(gateways[i].Addr(), nil); ready {
return hostinfo, true
}
}
// No gateways reachable, cache the packet in the originally chosen gateway
cacheCallback(handshakeInfoForChosenGateway)
return hostinfo, false
func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
if !found {
vpnAddr = f.inside.RouteFor(vpnAddr)
if !vpnAddr.IsValid() {
return nil, false
}
}
return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
}
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
@@ -233,7 +163,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
})
@@ -288,7 +218,7 @@ func (f *Interface) SendVia(via *HostInfo,
c := via.ConnectionState.messageCounter.Add(1)
out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c)
f.connectionManager.Out(via)
f.connectionManager.Out(via.localIndexId)
// Authenticate the header and payload, but do not encrypt for this message type.
// The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload.
@@ -356,7 +286,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo)
f.connectionManager.Out(hostinfo.localIndexId)
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
// all our addrs and enable a faster roaming.

View File

@@ -6,8 +6,8 @@ import (
"fmt"
"io"
"net/netip"
"os"
"runtime"
"sync"
"sync/atomic"
"time"
@@ -18,30 +18,29 @@ import (
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/udp"
)
const mtu = 9001
type InterfaceConfig struct {
HostMap *HostMap
Outside udp.Conn
Inside overlay.Device
pki *PKI
Cipher string
Firewall *Firewall
ServeDns bool
HandshakeManager *HandshakeManager
lightHouse *LightHouse
connectionManager *connectionManager
DropLocalBroadcast bool
DropMulticast bool
routines int
MessageMetrics *MessageMetrics
version string
relayManager *relayManager
punchy *Punchy
HostMap *HostMap
Outside udp.Conn
Inside overlay.Device
pki *PKI
Firewall *Firewall
ServeDns bool
HandshakeManager *HandshakeManager
lightHouse *LightHouse
checkInterval time.Duration
pendingDeletionInterval time.Duration
DropLocalBroadcast bool
DropMulticast bool
routines int
MessageMetrics *MessageMetrics
version string
relayManager *relayManager
punchy *Punchy
tryPromoteEvery uint32
reQueryEvery uint32
@@ -62,11 +61,11 @@ type Interface struct {
serveDns bool
createTime time.Time
lightHouse *LightHouse
myBroadcastAddrsTable *bart.Lite
myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
myVpnAddrsTable *bart.Lite
myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate
myVpnNetworksTable *bart.Lite
myBroadcastAddrsTable *bart.Table[struct{}]
myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate
myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate
dropLocalBroadcast bool
dropMulticast bool
routines int
@@ -88,19 +87,12 @@ type Interface struct {
writers []udp.Conn
readers []io.ReadWriteCloser
wg sync.WaitGroup
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics
l *logrus.Logger
inPool sync.Pool
inbound chan *packet.Packet
outPool sync.Pool
outbound chan *[]byte
}
type EncWriter interface {
@@ -165,9 +157,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
if c.Firewall == nil {
return nil, errors.New("no firewall rules")
}
if c.connectionManager == nil {
return nil, errors.New("no connection manager")
}
cs := c.pki.getCertState()
ifce := &Interface{
@@ -192,7 +181,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
myVpnAddrsTable: cs.myVpnAddrsTable,
myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
relayManager: c.relayManager,
connectionManager: c.connectionManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout,
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
@@ -202,27 +191,14 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
},
//TODO: configurable size
inbound: make(chan *packet.Packet, 1028),
outbound: make(chan *[]byte, 1028),
l: c.l,
}
ifce.inPool = sync.Pool{New: func() any {
return packet.New()
}}
ifce.outPool = sync.Pool{New: func() any {
t := make([]byte, mtu)
return &t
}}
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait))
ifce.connectionManager.intf = ifce
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)
return ifce, nil
}
@@ -230,7 +206,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
// activate creates the interface on the host. After the interface is created, any
// other services that want to bind listeners to its IP may do so successfully. However,
// the interface isn't going to process anything until run() is called.
func (f *Interface) activate() error {
func (f *Interface) activate() {
// actually turn on tun dev
addr, err := f.outside.LocalAddr()
@@ -251,44 +227,33 @@ func (f *Interface) activate() error {
if i > 0 {
reader, err = f.inside.NewMultiQueueReader()
if err != nil {
return err
f.l.Fatal(err)
}
}
f.readers[i] = reader
}
if err = f.inside.Activate(); err != nil {
if err := f.inside.Activate(); err != nil {
f.inside.Close()
return err
f.l.Fatal(err)
}
return nil
}
func (f *Interface) run(c context.Context) (func(), error) {
func (f *Interface) run() {
// Launch n queues to read packets from udp
for i := 0; i < f.routines; i++ {
// Launch n queues to read packets from udp
f.wg.Add(1)
go f.listenOut(i)
// Launch n queues to read packets from tun dev
f.wg.Add(1)
go f.listenIn(f.readers[i], i)
// Launch n queues to read packets from tun dev
f.wg.Add(1)
go f.workerIn(i, c)
// Launch n queues to read packets from tun dev
f.wg.Add(1)
go f.workerOut(i, c)
}
return f.wg.Wait, nil
// Launch n queues to read packets from tun dev
for i := 0; i < f.routines; i++ {
go f.listenIn(f.readers[i], i)
}
}
func (f *Interface) listenOut(i int) {
runtime.LockOSThread()
var li udp.Conn
if i > 0 {
li = f.writers[i]
@@ -296,97 +261,41 @@ func (f *Interface) listenOut(i int) {
li = f.outside
}
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
p := f.inPool.Get().(*packet.Packet)
//TODO: have the listener store this in the msgs array after a read instead of doing a copy
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler()
plaintext := make([]byte, udp.MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
p.Payload = p.Payload[:mtu]
copy(p.Payload, payload)
p.Payload = p.Payload[:len(payload)]
p.Addr = fromUdpAddr
f.inbound <- p
//select {
//case f.inbound <- p:
//default:
// f.l.Error("Dropped packet from inbound channel")
//}
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
})
if err != nil && !f.closed.Load() {
f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
//TODO: Trigger Control to close
}
f.l.Debugf("underlay reader %v is done", i)
f.wg.Done()
}
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread()
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
p := f.outPool.Get().(*[]byte)
*p = (*p)[:mtu]
n, err := reader.Read(*p)
n, err := reader.Read(packet)
if err != nil {
if !f.closed.Load() {
f.l.WithError(err).Error("Error while reading outbound packet, closing")
//TODO: Trigger Control to close
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return
}
break
f.l.WithError(err).Error("Error while reading outbound packet")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
}
*p = (*p)[:n]
//TODO: nonblocking channel write
f.outbound <- p
//select {
//case f.outbound <- p:
//default:
// f.l.Error("Dropped packet from outbound channel")
//}
}
f.l.Debugf("overlay reader %v is done", i)
f.wg.Done()
}
func (f *Interface) workerIn(i int, ctx context.Context) {
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
fwPacket2 := &firewall.Packet{}
nb2 := make([]byte, 12, 12)
result2 := make([]byte, mtu)
h := &header.H{}
for {
select {
case p := <-f.inbound:
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
p.Payload = p.Payload[:mtu]
f.inPool.Put(p)
case <-ctx.Done():
f.wg.Done()
return
}
}
}
func (f *Interface) workerOut(i int, ctx context.Context) {
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
fwPacket1 := &firewall.Packet{}
nb1 := make([]byte, 12, 12)
result1 := make([]byte, mtu)
for {
select {
case data := <-f.outbound:
f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
*data = (*data)[:mtu]
f.outPool.Put(data)
case <-ctx.Done():
f.wg.Done()
return
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
}
}
@@ -501,7 +410,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
udpStats := udp.NewUDPStatsEmitter(f.writers)
certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
certInitiatingVersion := metrics.GetOrRegisterGauge("certificate.initiating_version", nil)
certDefaultVersion := metrics.GetOrRegisterGauge("certificate.default_version", nil)
certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
for {
@@ -516,7 +425,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
certState := f.pki.getCertState()
defaultCrt := certState.GetDefaultCertificate()
certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
certInitiatingVersion.Update(int64(defaultCrt.Version()))
certDefaultVersion.Update(int64(defaultCrt.Version()))
// Report the max certificate version we are capable of using
if certState.v2Cert != nil {
@@ -539,7 +448,6 @@ func (f *Interface) GetCertState() *CertState {
func (f *Interface) Close() error {
f.closed.Store(true)
// Release the udp readers
for _, u := range f.writers {
err := u.Close()
if err != nil {
@@ -547,13 +455,6 @@ func (f *Interface) Close() error {
}
}
// Release the tun readers
for _, u := range f.readers {
err := u.Close()
if err != nil {
f.l.WithError(err).Error("Error while closing tun device")
}
}
return nil
// Release the tun device
return f.inside.Close()
}

View File

@@ -24,7 +24,6 @@ import (
)
var ErrHostNotKnown = errors.New("host not known")
var ErrBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr")
type LightHouse struct {
//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
@@ -33,7 +32,7 @@ type LightHouse struct {
amLighthouse bool
myVpnNetworks []netip.Prefix
myVpnNetworksTable *bart.Lite
myVpnNetworksTable *bart.Table[struct{}]
punchConn udp.Conn
punchy *Punchy
@@ -57,7 +56,7 @@ type LightHouse struct {
// staticList exists to avoid having a bool in each addrMap entry
// since static should be rare
staticList atomic.Pointer[map[netip.Addr]struct{}]
lighthouses atomic.Pointer[[]netip.Addr]
lighthouses atomic.Pointer[map[netip.Addr]struct{}]
interval atomic.Int64
updateCancel context.CancelFunc
@@ -108,7 +107,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
l: l,
}
lighthouses := make([]netip.Addr, 0)
lighthouses := make(map[netip.Addr]struct{})
h.lighthouses.Store(&lighthouses)
staticList := make(map[netip.Addr]struct{})
h.staticList.Store(&staticList)
@@ -144,7 +143,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
return *lh.staticList.Load()
}
func (lh *LightHouse) GetLighthouses() []netip.Addr {
func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
return *lh.lighthouses.Load()
}
@@ -202,7 +201,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
addr := addrs[0].Unmap()
if lh.myVpnNetworksTable.Contains(addr) {
_, found := lh.myVpnNetworksTable.Lookup(addr)
if found {
lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
continue
@@ -307,12 +307,13 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
}
if initial || c.HasChanged("lighthouse.hosts") {
lhList, err := lh.parseLighthouses(c)
lhMap := make(map[netip.Addr]struct{})
err := lh.parseLighthouses(c, lhMap)
if err != nil {
return err
}
lh.lighthouses.Store(&lhList)
lh.lighthouses.Store(&lhMap)
if !initial {
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
lh.l.Info("lighthouse.hosts has changed")
@@ -346,37 +347,37 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return nil
}
func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
lhs := c.GetStringSlice("lighthouse.hosts", []string{})
if lh.amLighthouse && len(lhs) != 0 {
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
}
out := make([]netip.Addr, len(lhs))
for i, host := range lhs {
addr, err := netip.ParseAddr(host)
if err != nil {
return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
}
if !lh.myVpnNetworksTable.Contains(addr) {
return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
_, found := lh.myVpnNetworksTable.Lookup(addr)
if !found {
return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
}
out[i] = addr
lhMap[addr] = struct{}{}
}
if !lh.amLighthouse && len(out) == 0 {
if !lh.amLighthouse && len(lhMap) == 0 {
lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
}
staticList := lh.GetStaticHostList()
for i := range out {
if _, ok := staticList[out[i]]; !ok {
return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
for lhAddr, _ := range lhMap {
if _, ok := staticList[lhAddr]; !ok {
return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr)
}
}
return out, nil
return nil
}
func getStaticMapCadence(c *config.C) (time.Duration, error) {
@@ -421,7 +422,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
return err
}
shm := c.GetMap("static_host_map", map[string]any{})
shm := c.GetMap("static_host_map", map[interface{}]interface{}{})
i := 0
for k, v := range shm {
@@ -430,13 +431,14 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err)
}
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
_, found := lh.myVpnNetworksTable.Lookup(vpnAddr)
if !found {
return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
}
vals, ok := v.([]any)
vals, ok := v.([]interface{})
if !ok {
vals = []any{v}
vals = []interface{}{v}
}
remoteAddrs := []string{}
for _, v := range vals {
@@ -487,7 +489,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList {
lh.Lock()
defer lh.Unlock()
// Add an entry if we don't already have one
return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip
return lh.unlockedGetRemoteList(vpnAddrs)
}
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
@@ -520,15 +522,11 @@ func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (in
}
func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
// First we check the static host map. If any of the VpnAddrs to be deleted are present, do nothing.
staticList := lh.GetStaticHostList()
for _, addr := range allVpnAddrs {
if _, ok := staticList[addr]; ok {
return
}
// First we check the static mapping
// and do nothing if it is there
if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok {
return
}
// None of the VpnAddrs were present. Now we can do the deletes.
lh.Lock()
rm, ok := lh.addrMap[allVpnAddrs[0]]
if ok {
@@ -570,7 +568,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
am.unlockedSetHostnamesResults(hr)
for _, addrPort := range hr.GetAddrs() {
if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) {
if !lh.shouldAdd(vpnAddr, addrPort.Addr()) {
continue
}
switch {
@@ -632,37 +630,31 @@ func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool {
return len(calculatedV4) > 0 || len(calculatedV6) > 0
}
// unlockedGetRemoteList assumes you have the lh lock
// unlockedGetRemoteList
// assumes you have the lh lock
func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
// before we go and make a new remotelist, we need to make sure we don't have one for any of this set of vpnaddrs yet
for i, addr := range allAddrs {
am, ok := lh.addrMap[addr]
if ok {
if i != 0 {
lh.addrMap[allAddrs[0]] = am
}
return am
am, ok := lh.addrMap[allAddrs[0]]
if !ok {
am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) })
for _, addr := range allAddrs {
lh.addrMap[addr] = am
}
}
am := NewRemoteList(allAddrs, lh.shouldAdd)
for _, addr := range allAddrs {
lh.addrMap[addr] = am
}
return am
}
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
allow := lh.GetRemoteAllowList().Allow(vpnAddr, to)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow).
Trace("remoteAllowList.Allow")
}
if !allow {
return false
}
if lh.myVpnNetworksTable.Contains(to) {
_, found := lh.myVpnNetworksTable.Lookup(to)
if found {
return false
}
@@ -682,7 +674,8 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo
return false
}
if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) {
_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr())
if found {
return false
}
@@ -702,7 +695,8 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
return false
}
if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) {
_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr())
if found {
return false
}
@@ -710,22 +704,19 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
}
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
l := lh.GetLighthouses()
for i := range l {
if l[i] == vpnAddr {
return true
}
if _, ok := lh.GetLighthouses()[vpnAddr]; ok {
return true
}
return false
}
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake
// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool {
l := lh.GetLighthouses()
for i := range vpnAddrs {
for j := range l {
if l[j] == vpnAddrs[i] {
return true
}
for _, a := range vpnAddr {
if _, ok := l[a]; ok {
return true
}
}
return false
@@ -767,12 +758,12 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
queried := 0
lighthouses := lh.GetLighthouses()
for _, lhVpnAddr := range lighthouses {
for lhVpnAddr := range lighthouses {
hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil {
v = hi.ConnectionState.myCert.Version()
} else {
v = lh.ifce.GetCertState().initiatingVersion
v = lh.ifce.GetCertState().defaultVersion
}
if v == cert.Version1 {
@@ -865,7 +856,8 @@ func (lh *LightHouse) SendUpdate() {
lal := lh.GetLocalAllowList()
for _, e := range localAddrs(lh.l, lal) {
if lh.myVpnNetworksTable.Contains(e) {
_, found := lh.myVpnNetworksTable.Lookup(e)
if found {
continue
}
@@ -885,13 +877,13 @@ func (lh *LightHouse) SendUpdate() {
updated := 0
lighthouses := lh.GetLighthouses()
for _, lhVpnAddr := range lighthouses {
for lhVpnAddr := range lighthouses {
var v cert.Version
hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil {
v = hi.ConnectionState.myCert.Version()
} else {
v = lh.ifce.GetCertState().initiatingVersion
v = lh.ifce.GetCertState().defaultVersion
}
if v == cert.Version1 {
if v1Update == nil {
@@ -943,6 +935,7 @@ func (lh *LightHouse) SendUpdate() {
V4AddrPorts: v4,
V6AddrPorts: v6,
RelayVpnAddrs: relays,
VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()),
},
}
@@ -1062,19 +1055,19 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
return
}
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
useVersion := cert.Version1
var queryVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
queryVpnAddr = netip.AddrFrom4(b)
useVersion = 1
} else if n.Details.VpnAddr != nil {
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
useVersion = 2
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
Debugln("Dropping malformed HostQuery")
}
return
}
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
// this case really shouldn't be possible to represent, but reject it anyway.
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
Debugln("invalid vpn addr for v1 handleHostQuery")
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery")
}
return
}
@@ -1083,6 +1076,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply
if useVersion == cert.Version1 {
if !queryVpnAddr.Is4() {
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
}
b := queryVpnAddr.As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
} else {
@@ -1118,7 +1114,7 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
var useVersion cert.Version
if targetHI == nil {
useVersion = lhh.lh.ifce.GetCertState().initiatingVersion
useVersion = lhh.lh.ifce.GetCertState().defaultVersion
} else {
crt := targetHI.GetCert().Certificate
useVersion = crt.Version()
@@ -1127,9 +1123,8 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
if ok {
whereToPunch = newDest
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
}
//TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee
//choosing to do nothing for now, but maybe we return an error?
}
}
@@ -1188,17 +1183,19 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
if !r.Is4() {
continue
}
b = r.As4()
n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:]))
}
} else if v == cert.Version2 {
for _, r := range c.relay.relay {
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
}
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("version", v).Debug("unsupported protocol version")
}
//TODO: CERT-V2 don't panic
panic("unsupported version")
}
}
}
@@ -1208,16 +1205,18 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
return
}
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
}
return
lhh.lh.Lock()
var certVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
certVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
}
relays := n.Details.GetRelays()
lhh.lh.Lock()
am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr})
am.Lock()
lhh.lh.Unlock()
@@ -1242,24 +1241,27 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
return
}
// not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr
var detailsVpnAddr netip.Addr
var useVersion cert.Version
if n.Details.OldVpnAddr != 0 { //v1 always sets this field
useVersion := cert.Version1
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
detailsVpnAddr = netip.AddrFrom4(b)
useVersion = cert.Version1
} else if n.Details.VpnAddr != nil { //this field is "optional" in v2, but if it's set, we should enforce it
} else if n.Details.VpnAddr != nil {
detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
useVersion = cert.Version2
} else {
detailsVpnAddr = netip.Addr{}
useVersion = cert.Version2
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification")
}
return
}
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
//TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4?
//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right?
//Simple check that the host sent this not someone else
if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
}
@@ -1273,24 +1275,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
am.Lock()
lhh.lh.Unlock()
am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetRelay(fromVpnAddrs[0], relays)
am.Unlock()
n = lhh.resetMeta()
n.Type = NebulaMeta_HostUpdateNotificationAck
switch useVersion {
case cert.Version1:
if useVersion == cert.Version1 {
if !fromVpnAddrs[0].Is4() {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
return
}
vpnAddrB := fromVpnAddrs[0].As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
case cert.Version2:
// do nothing, we want to send a blank message
default:
} else if useVersion == cert.Version2 {
n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0])
} else {
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
return
}
@@ -1308,20 +1310,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
//It's possible the lighthouse is communicating with us using a non primary vpn addr,
//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
//maybe one day we'll have a better idea, if it matters.
if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
return
}
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
}
return
}
empty := []byte{0}
punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
punch := func(vpnPeer netip.AddrPort) {
if !vpnPeer.IsValid() {
return
}
@@ -1333,31 +1328,48 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
}()
if lhh.l.Level >= logrus.DebugLevel {
var logVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
logVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
}
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
}
}
for _, a := range n.Details.V4AddrPorts {
punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
punch(protoV4AddrPortToNetAddrPort(a))
}
for _, a := range n.Details.V6AddrPorts {
punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
punch(protoV6AddrPortToNetAddrPort(a))
}
// This sends a nebula test packet to the host trying to contact us. In the case
// of a double nat or other difficult scenario, this may help establish
// a tunnel.
if lhh.lh.punchy.GetRespond() {
var queryVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
queryVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
}
go func() {
time.Sleep(lhh.lh.punchy.GetRespondDelay())
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr)
}
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
// managed by a channel.
w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}()
}
}
@@ -1436,17 +1448,3 @@ func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr,
}
return netip.Addr{}, false
}
func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, error) {
if d.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], d.OldVpnAddr)
detailsVpnAddr := netip.AddrFrom4(b)
return detailsVpnAddr, cert.Version1, nil
} else if d.VpnAddr != nil {
detailsVpnAddr := protoAddrToNetAddr(d.VpnAddr)
return detailsVpnAddr, cert.Version2, nil
} else {
return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
}
}

View File

@@ -14,7 +14,7 @@ import (
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
"gopkg.in/yaml.v2"
)
func TestOldIPv4Only(t *testing.T) {
@@ -31,8 +31,8 @@ func TestOldIPv4Only(t *testing.T) {
func Test_lhStaticMapping(t *testing.T) {
l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
@@ -40,15 +40,15 @@ func Test_lhStaticMapping(t *testing.T) {
lh1 := "10.128.0.2"
c := config.NewC(l)
c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1}}
c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
lh2 := "10.128.0.3"
c = config.NewC(l)
c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1, lh2}}
c.Settings["static_host_map"] = map[string]any{lh1: []any{"100.1.1.1:4242"}}
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
}
@@ -56,8 +56,8 @@ func Test_lhStaticMapping(t *testing.T) {
func TestReloadLighthouseInterval(t *testing.T) {
l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
@@ -65,12 +65,12 @@ func TestReloadLighthouseInterval(t *testing.T) {
lh1 := "10.128.0.2"
c := config.NewC(l)
c.Settings["lighthouse"] = map[string]any{
"hosts": []any{lh1},
c.Settings["lighthouse"] = map[interface{}]interface{}{
"hosts": []interface{}{lh1},
"interval": "1s",
}
c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
lh.ifce = &mockEncWriter{}
@@ -91,8 +91,8 @@ func TestReloadLighthouseInterval(t *testing.T) {
func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
@@ -192,12 +192,12 @@ func TestLighthouse_Memory(t *testing.T) {
theirVpnIp := netip.MustParseAddr("10.128.0.3")
c := config.NewC(l)
c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
c.Settings["listen"] = map[string]any{"port": 4242}
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
@@ -277,12 +277,12 @@ func TestLighthouse_Memory(t *testing.T) {
func TestLighthouse_reload(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
c.Settings["listen"] = map[string]any{"port": 4242}
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
@@ -291,9 +291,9 @@ func TestLighthouse_reload(t *testing.T) {
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
nc := map[string]any{
"static_host_map": map[string]any{
"10.128.0.2": []any{"1.1.1.1:4242"},
nc := map[interface{}]interface{}{
"static_host_map": map[interface{}]interface{}{
"10.128.0.2": []interface{}{"1.1.1.1:4242"},
},
}
rc, err := yaml.Marshal(nc)
@@ -417,7 +417,7 @@ func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
}
func (tw *testEncWriter) GetCertState() *CertState {
return &CertState{initiatingVersion: tw.protocolVersion}
return &CertState{defaultVersion: tw.protocolVersion}
}
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
@@ -493,123 +493,3 @@ func Test_findNetworkUnion(t *testing.T) {
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok)
}
func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) {
l := test.NewLogger()
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
testSameHostNotStatic := netip.MustParseAddr("10.128.0.41")
testStaticHost := netip.MustParseAddr("10.128.0.42")
//myVpnIp := netip.MustParseAddr("10.128.0.2")
c := config.NewC(l)
lh1 := "10.128.0.2"
c.Settings["lighthouse"] = map[string]any{
"hosts": []any{lh1},
"interval": "1s",
}
c.Settings["listen"] = map[string]any{"port": 4242}
c.Settings["static_host_map"] = map[string]any{
lh1: []any{"1.1.1.1:4242"},
"10.128.0.42": []any{"1.2.3.4:4242"},
}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
lh.ifce = &mockEncWriter{}
//test that we actually have the static entry:
out := lh.Query(testStaticHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testStaticHost)
out.Rebuild([]netip.Prefix{}) //why tho
assert.Equal(t, out.addrs[0], myUdpAddr2)
//bolt on a lower numbered primary IP
am := lh.unlockedGetRemoteList([]netip.Addr{testStaticHost})
am.vpnAddrs = []netip.Addr{testSameHostNotStatic, testStaticHost}
lh.addrMap[testSameHostNotStatic] = am
out.Rebuild([]netip.Prefix{}) //???
//test that we actually have the static entry:
out = lh.Query(testStaticHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
assert.Equal(t, out.addrs[0], myUdpAddr2)
//test that we actually have the static entry for BOTH:
out2 := lh.Query(testSameHostNotStatic)
assert.Same(t, out2, out)
//now do the delete
lh.DeleteVpnAddrs([]netip.Addr{testSameHostNotStatic, testStaticHost})
//verify
out = lh.Query(testSameHostNotStatic)
assert.NotNil(t, out)
if out == nil {
t.Fatal("expected non-nil query for the static host")
}
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
assert.Equal(t, out.addrs[0], myUdpAddr2)
}
func TestLighthouse_DeletesWork(t *testing.T) {
l := test.NewLogger()
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
testHost := netip.MustParseAddr("10.128.0.42")
c := config.NewC(l)
lh1 := "10.128.0.2"
c.Settings["lighthouse"] = map[string]any{
"hosts": []any{lh1},
"interval": "1s",
}
c.Settings["listen"] = map[string]any{"port": 4242}
c.Settings["static_host_map"] = map[string]any{
lh1: []any{"1.1.1.1:4242"},
}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
lh.ifce = &mockEncWriter{}
//insert the host
am := lh.unlockedGetRemoteList([]netip.Addr{testHost})
am.vpnAddrs = []netip.Addr{testHost}
am.addrs = []netip.AddrPort{myUdpAddr2}
lh.addrMap[testHost] = am
am.Rebuild([]netip.Prefix{}) //???
//test that we actually have the entry:
out := lh.Query(testHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testHost)
out.Rebuild([]netip.Prefix{}) //why tho
assert.Equal(t, out.addrs[0], myUdpAddr2)
//now do the delete
lh.DeleteVpnAddrs([]netip.Addr{testHost})
//verify
out = lh.Query(testHost)
assert.Nil(t, out)
}

65
main.go
View File

@@ -13,10 +13,10 @@ import (
"github.com/slackhq/nebula/sshd"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"gopkg.in/yaml.v3"
"gopkg.in/yaml.v2"
)
type m = map[string]any
type m map[string]interface{}
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
ctx, cancel := context.WithCancel(context.Background())
@@ -185,7 +185,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
hostMap := NewHostMapFromConfig(l, c)
punchy := NewPunchyFromConfig(l, c)
connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
@@ -221,26 +220,31 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
}
checkInterval := c.GetInt("timers.connection_alive_interval", 5)
pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
ifConfig := &InterfaceConfig{
HostMap: hostMap,
Inside: tun,
Outside: udpConns[0],
pki: pki,
Firewall: fw,
ServeDns: serveDns,
HandshakeManager: handshakeManager,
connectionManager: connManager,
lightHouse: lightHouse,
tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery),
reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
DropMulticast: c.GetBool("tun.drop_multicast", false),
routines: routines,
MessageMetrics: messageMetrics,
version: buildVersion,
relayManager: NewRelayManager(ctx, l, hostMap, c),
punchy: punchy,
HostMap: hostMap,
Inside: tun,
Outside: udpConns[0],
pki: pki,
Firewall: fw,
ServeDns: serveDns,
HandshakeManager: handshakeManager,
lightHouse: lightHouse,
checkInterval: time.Second * time.Duration(checkInterval),
pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery),
reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
DropMulticast: c.GetBool("tun.drop_multicast", false),
routines: routines,
MessageMetrics: messageMetrics,
version: buildVersion,
relayManager: NewRelayManager(ctx, l, hostMap, c),
punchy: punchy,
ConntrackCacheTimeout: conntrackCacheTimeout,
l: l,
}
@@ -284,14 +288,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
return &Control{
f: ifce,
l: l,
ctx: ctx,
cancel: cancel,
sshStart: sshStart,
statsStart: statsStart,
dnsStart: dnsStart,
lighthouseStart: lightHouse.StartUpdateWorker,
connectionManagerStart: connManager.Start,
ifce,
l,
ctx,
cancel,
sshStart,
statsStart,
dnsStart,
lightHouse.StartUpdateWorker,
}, nil
}

View File

@@ -29,9 +29,10 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return
}
//f.l.Error("in packet ", h)
//l.Error("in packet ", header, packet[HeaderLen:])
if ip.IsValid() {
if f.myVpnNetworksTable.Contains(ip.Addr()) {
_, found := f.myVpnNetworksTable.Lookup(ip.Addr())
if found {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
}
@@ -81,7 +82,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
// Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, ip)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo)
f.connectionManager.In(hostinfo.localIndexId)
f.connectionManager.RelayUsed(h.RemoteIndex)
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
@@ -213,7 +214,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
f.handleHostRoaming(hostinfo, ip)
f.connectionManager.In(hostinfo)
f.connectionManager.In(hostinfo.localIndexId)
}
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
@@ -245,7 +246,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
return
}
//TODO: Seems we have a bunch of stuff racing here, since we don't have a lock on hostinfo anymore we announce roaming in bursts
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now()
@@ -255,18 +255,16 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
}
// handleEncrypted returns true if a packet should be processed, false otherwise
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
if ci == nil {
// If connectionstate exists and the replay protector allows, process packet
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
if addr.IsValid() {
f.maybeSendRecvError(addr, h.RemoteIndex)
return false
} else {
return false
}
return false
}
// If the window check fails, refuse to process the packet, but don't send a recv error
if !ci.window.Check(f.l, h.MessageCounter) {
return false
}
return true
@@ -315,11 +313,12 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
offset := ipv6.HeaderLen // Start at the end of the ipv6 header
next := 0
for {
if protoAt >= dataLen {
if dataLen < offset {
break
}
proto := layers.IPProtocol(data[protoAt])
proto := layers.IPProtocol(data[protoAt])
//fmt.Println(proto, protoAt)
switch proto {
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
fp.Protocol = uint8(proto)
@@ -367,7 +366,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
case layers.IPProtocolAH:
// Auth headers, used by IPSec, have a different meaning for header length
if dataLen <= offset+1 {
if dataLen < offset+1 {
break
}
@@ -375,7 +374,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
default:
// Normal ipv6 header length processing
if dataLen <= offset+1 {
if dataLen < offset+1 {
break
}
@@ -471,7 +470,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("fwPacket", fwPacket).Error("Failed to decrypt packet")
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
return false
}
@@ -501,7 +500,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false
}
f.connectionManager.In(hostinfo)
f.connectionManager.In(hostinfo.localIndexId)
_, err = f.readers[q].Write(out)
if err != nil {
f.l.WithError(err).Error("Failed to write to tun")
@@ -540,6 +539,10 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
return
}
if !hostinfo.RecvErrorExceeded() {
return
}
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
return

View File

@@ -117,45 +117,6 @@ func Test_newPacket_v6(t *testing.T) {
err = newPacket(buffer.Bytes(), true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
// A v6 packet with a hop-by-hop extension
// ICMPv6 Payload (Echo Request)
icmpLayer := layers.ICMPv6{
TypeCode: layers.ICMPv6TypeEchoRequest,
}
// Hop-by-Hop Extension Header
hopOption := layers.IPv6HopByHopOption{}
hopOption.OptionData = []byte{0, 0, 0, 0}
hopByHop := layers.IPv6HopByHop{}
hopByHop.Options = append(hopByHop.Options, &hopOption)
ip = layers.IPv6{
Version: 6,
HopLimit: 128,
NextHeader: layers.IPProtocolIPv6Destination,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
buffer.Clear()
err = gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{
ComputeChecksums: false,
FixLengths: true,
}, &ip, &hopByHop, &icmpLayer)
if err != nil {
panic(err)
}
// Ensure buffer length checks during parsing with the next 2 tests.
// A full IPv6 header and 1 byte in the first extension, but missing
// the length byte.
err = newPacket(buffer.Bytes()[:41], true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
// A full IPv6 header plus 1 full extension, but only 1 byte of the
// next layer, missing length byte
err = newPacket(buffer.Bytes()[:49], true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
// A good ICMP packet
ip = layers.IPv6{
Version: 6,
@@ -327,10 +288,6 @@ func Test_newPacket_v6(t *testing.T) {
assert.Equal(t, uint16(22), p.LocalPort)
assert.False(t, p.Fragment)
// Ensure buffer bounds checking during processing
err = newPacket(b[:41], true, p)
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
// Invalid AH header
b = buffer.Bytes()
err = newPacket(b, true, p)

View File

@@ -3,8 +3,6 @@ package overlay
import (
"io"
"net/netip"
"github.com/slackhq/nebula/routing"
)
type Device interface {
@@ -12,6 +10,6 @@ type Device interface {
Activate() error
Networks() []netip.Prefix
Name() string
RoutesFor(netip.Addr) routing.Gateways
RouteFor(netip.Addr) netip.Addr
NewMultiQueueReader() (io.ReadWriteCloser, error)
}

View File

@@ -11,14 +11,13 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
)
type Route struct {
MTU int
Metric int
Cidr netip.Prefix
Via routing.Gateways
Via netip.Addr
Install bool
}
@@ -48,17 +47,15 @@ func (r Route) String() string {
return s
}
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
routeTree := new(bart.Table[routing.Gateways])
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) {
routeTree := new(bart.Table[netip.Addr])
for _, r := range routes {
if !allowMTU && r.MTU > 0 {
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
}
gateways := r.Via
if len(gateways) > 0 {
routing.CalculateBucketsForGateways(gateways)
routeTree.Insert(r.Cidr, gateways)
if r.Via.IsValid() {
routeTree.Insert(r.Cidr, r.Via)
}
}
return routeTree, nil
@@ -72,7 +69,7 @@ func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
return []Route{}, nil
}
rawRoutes, ok := r.([]any)
rawRoutes, ok := r.([]interface{})
if !ok {
return nil, fmt.Errorf("tun.routes is not an array")
}
@@ -83,7 +80,7 @@ func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
routes := make([]Route, len(rawRoutes))
for i, r := range rawRoutes {
m, ok := r.(map[string]any)
m, ok := r.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("entry %v in tun.routes is invalid", i+1)
}
@@ -151,7 +148,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
return []Route{}, nil
}
rawRoutes, ok := r.([]any)
rawRoutes, ok := r.([]interface{})
if !ok {
return nil, fmt.Errorf("tun.unsafe_routes is not an array")
}
@@ -162,7 +159,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
routes := make([]Route, len(rawRoutes))
for i, r := range rawRoutes {
m, ok := r.(map[string]any)
m, ok := r.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1)
}
@@ -204,63 +201,14 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1)
}
var gateways routing.Gateways
via, ok := rVia.(string)
if !ok {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
}
switch via := rVia.(type) {
case string:
viaIp, err := netip.ParseAddr(via)
if err != nil {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
}
gateways = routing.Gateways{routing.NewGateway(viaIp, 1)}
case []any:
gateways = make(routing.Gateways, len(via))
for ig, v := range via {
gatewayMap, ok := v.(map[string]any)
if !ok {
return nil, fmt.Errorf("entry %v in tun.unsafe_routes[%v].via is invalid", i+1, ig+1)
}
rGateway, ok := gatewayMap["gateway"]
if !ok {
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not present", i+1, ig+1)
}
parsedGateway, ok := rGateway.(string)
if !ok {
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not a string", i+1, ig+1)
}
gatewayIp, err := netip.ParseAddr(parsedGateway)
if err != nil {
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] failed to parse address: %v", i+1, ig+1, err)
}
rGatewayWeight, ok := gatewayMap["weight"]
if !ok {
rGatewayWeight = 1
}
gatewayWeight, ok := rGatewayWeight.(int)
if !ok {
_, err = strconv.ParseInt(rGatewayWeight.(string), 10, 32)
if err != nil {
return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not an integer", i+1, ig+1)
}
}
if gatewayWeight < 1 || gatewayWeight > math.MaxInt32 {
return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not in range (1-%d) : %v", i+1, ig+1, math.MaxInt32, gatewayWeight)
}
gateways[ig] = routing.NewGateway(gatewayIp, gatewayWeight)
}
default:
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string or list of gateways: found %T", i+1, rVia)
viaVpnIp, err := netip.ParseAddr(via)
if err != nil {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
}
rRoute, ok := m["route"]
@@ -278,7 +226,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
}
r := Route{
Via: gateways,
Via: viaVpnIp,
MTU: mtu,
Metric: metric,
Install: install,

View File

@@ -6,7 +6,6 @@ import (
"testing"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -24,75 +23,75 @@ func Test_parseRoutes(t *testing.T) {
assert.Empty(t, routes)
// not an array
c.Settings["tun"] = map[string]any{"routes": "hi"}
c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "tun.routes is not an array")
// no routes
c.Settings["tun"] = map[string]any{"routes": []any{}}
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
routes, err = parseRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Empty(t, routes)
// weird route
c.Settings["tun"] = map[string]any{"routes": []any{"asdf"}}
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1 in tun.routes is invalid")
// no mtu
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{}}}
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
// bad mtu
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "nope"}}}
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
// low mtu
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "499"}}}
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
// missing route
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500"}}}
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.route in tun.routes is not present")
// unparsable route
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "nope"}}}
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// below network range
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "1.0.0.0/8"}}}
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
// above network range
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "10.0.1.0/24"}}}
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
// Not in multiple ranges
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "192.0.0.0/24"}}}
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}}
routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
// happy case
c.Settings["tun"] = map[string]any{"routes": []any{
map[string]any{"mtu": "9000", "route": "10.0.0.0/29"},
map[string]any{"mtu": "8000", "route": "10.0.0.1/32"},
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
}}
routes, err = parseRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
@@ -129,129 +128,105 @@ func Test_parseUnsafeRoutes(t *testing.T) {
assert.Empty(t, routes)
// not an array
c.Settings["tun"] = map[string]any{"unsafe_routes": "hi"}
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "tun.unsafe_routes is not an array")
// no routes
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{}}
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Empty(t, routes)
// weird route
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{"asdf"}}
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
// no via
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{}}}
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
// invalid via
for _, invalidValue := range []any{
for _, invalidValue := range []interface{}{
127, false, nil, 1.0, []string{"1", "2"},
} {
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": invalidValue}}}
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string or list of gateways: found %T", invalidValue))
require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
}
// Unparsable list of via
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": []string{"1", "2"}}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not a string or list of gateways: found []string")
// unparsable via
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": "nope"}}}
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
// unparsable gateway
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "1"}}}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] failed to parse address: ParseAddr(\"1\"): unable to parse IP")
// missing gateway element
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"weight": "1"}}}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] is not present")
// unparsable weight element
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "10.0.0.1", "weight": "a"}}}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry .weight in tun.unsafe_routes[1].via[1] is not an integer")
// missing route
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500"}}}
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
// unparsable route
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// within network range
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
// below network range
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1)
require.NoError(t, err)
// above network range
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1)
require.NoError(t, err)
// no mtu
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1)
assert.Equal(t, 0, routes[0].MTU)
// bad mtu
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "nope"}}}
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
// low mtu
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "499"}}}
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
// bad install
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
// happy case
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{
map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"},
map[string]any{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0},
map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"},
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0},
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
@@ -288,9 +263,9 @@ func Test_makeRouteTree(t *testing.T) {
n, err := netip.ParsePrefix("10.0.0.0/24")
require.NoError(t, err)
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{
map[string]any{"via": "192.168.0.1", "route": "1.0.0.0/28"},
map[string]any{"via": "192.168.0.2", "route": "1.0.0.1/32"},
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
}}
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
@@ -305,7 +280,7 @@ func Test_makeRouteTree(t *testing.T) {
nip, err := netip.ParseAddr("192.168.0.1")
require.NoError(t, err)
assert.Equal(t, nip, r[0].Addr())
assert.Equal(t, nip, r)
ip, err = netip.ParseAddr("1.0.0.1")
require.NoError(t, err)
@@ -314,91 +289,10 @@ func Test_makeRouteTree(t *testing.T) {
nip, err = netip.ParseAddr("192.168.0.2")
require.NoError(t, err)
assert.Equal(t, nip, r[0].Addr())
assert.Equal(t, nip, r)
ip, err = netip.ParseAddr("1.1.0.1")
require.NoError(t, err)
r, ok = routeTree.Lookup(ip)
assert.False(t, ok)
}
func Test_makeMultipathUnsafeRouteTree(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
n, err := netip.ParsePrefix("10.0.0.0/24")
require.NoError(t, err)
c.Settings["tun"] = map[string]any{
"unsafe_routes": []any{
map[string]any{
"route": "192.168.86.0/24",
"via": "192.168.100.10",
},
map[string]any{
"route": "192.168.87.0/24",
"via": []any{
map[string]any{
"gateway": "10.0.0.1",
},
map[string]any{
"gateway": "10.0.0.2",
},
map[string]any{
"gateway": "10.0.0.3",
},
},
},
map[string]any{
"route": "192.168.89.0/24",
"via": []any{
map[string]any{
"gateway": "10.0.0.1",
"weight": 10,
},
map[string]any{
"gateway": "10.0.0.2",
"weight": 5,
},
},
},
},
}
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Len(t, routes, 3)
routeTree, err := makeRouteTree(l, routes, true)
require.NoError(t, err)
ip, err := netip.ParseAddr("192.168.86.1")
require.NoError(t, err)
r, ok := routeTree.Lookup(ip)
assert.True(t, ok)
nip, err := netip.ParseAddr("192.168.100.10")
require.NoError(t, err)
assert.Equal(t, nip, r[0].Addr())
ip, err = netip.ParseAddr("192.168.87.1")
require.NoError(t, err)
r, ok = routeTree.Lookup(ip)
assert.True(t, ok)
expectedGateways := routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 1),
routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 1),
routing.NewGateway(netip.MustParseAddr("10.0.0.3"), 1)}
routing.CalculateBucketsForGateways(expectedGateways)
assert.ElementsMatch(t, expectedGateways, r)
ip, err = netip.ParseAddr("192.168.89.1")
require.NoError(t, err)
r, ok = routeTree.Lookup(ip)
assert.True(t, ok)
expectedGateways = routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 10),
routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 5)}
routing.CalculateBucketsForGateways(expectedGateways)
assert.ElementsMatch(t, expectedGateways, r)
}

View File

@@ -1,7 +1,6 @@
package overlay
import (
"net"
"net/netip"
"github.com/sirupsen/logrus"
@@ -71,13 +70,3 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
return removed
}
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}

View File

@@ -13,7 +13,6 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
)
@@ -22,7 +21,7 @@ type tun struct {
fd int
vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
}
@@ -57,7 +56,7 @@ func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, erro
return nil, fmt.Errorf("newTun not supported in Android")
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"net"
"net/netip"
"os"
"sync/atomic"
@@ -16,7 +17,6 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
@@ -28,7 +28,7 @@ type tun struct {
vpnNetworks []netip.Prefix
DefaultMTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
@@ -342,12 +342,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
r, ok := t.routeTree.Load().Lookup(ip)
if ok {
return r
}
return routing.Gateways{}
return netip.Addr{}
}
// Get the LinkAddr for the interface of the given name
@@ -382,7 +382,7 @@ func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
@@ -393,7 +393,7 @@ func (t *tun) addRoutes(logErrors bool) error {
t.l.WithField("route", r.Cidr).
Warnf("unable to add unsafe_route, identical route already exists")
} else {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
@@ -553,3 +553,13 @@ func (t *tun) Name() string {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
}
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}

View File

@@ -9,7 +9,6 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/routing"
)
type disabledTun struct {
@@ -44,8 +43,8 @@ func (*disabledTun) Activate() error {
return nil
}
func (*disabledTun) RoutesFor(addr netip.Addr) routing.Gateways {
return routing.Gateways{}
func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
return netip.Addr{}
}
func (t *disabledTun) Networks() []netip.Prefix {

View File

@@ -10,28 +10,23 @@ import (
"io"
"io/fs"
"net/netip"
"os"
"os/exec"
"strconv"
"sync/atomic"
"syscall"
"time"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
)
const (
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
FIODGNAME = 0x80106678
TUNSIFMODE = 0x8004745e
TUNSIFHEAD = 0x80047460
OSIOCAIFADDR_IN6 = 0x8088691b
IN6_IFF_NODAD = 0x0020
FIODGNAME = 0x80106678
)
type fiodgnameArg struct {
@@ -41,159 +36,43 @@ type fiodgnameArg struct {
}
type ifreqRename struct {
Name [unix.IFNAMSIZ]byte
Name [16]byte
Data uintptr
}
type ifreqDestroy struct {
Name [unix.IFNAMSIZ]byte
Name [16]byte
pad [16]byte
}
type ifReq struct {
Name [unix.IFNAMSIZ]byte
Flags uint16
}
type ifreqMTU struct {
Name [unix.IFNAMSIZ]byte
MTU int32
}
type addrLifetime struct {
Expire uint64
Preferred uint64
Vltime uint32
Pltime uint32
}
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
VHid uint32
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime addrLifetime
VHid uint32
}
type tun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
devFd int
}
func (t *tun) Read(to []byte) (int, error) {
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
// first 4 bytes is protocol family, in network byte order
head := make([]byte, 4)
iovecs := []syscall.Iovec{
{&head[0], 4},
{&to[0], uint64(len(to))},
}
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
// fix bytes read number to exclude header
bytesRead := int(n)
if bytesRead < 0 {
return bytesRead, err
} else if bytesRead < 4 {
return 0, err
} else {
return bytesRead - 4, err
}
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head []byte
// first 4 bytes is protocol family, in network byte order
if ipVer == 4 {
head = []byte{0, 0, 0, syscall.AF_INET}
} else if ipVer == 6 {
head = []byte{0, 0, 0, syscall.AF_INET6}
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
iovecs := []syscall.Iovec{
{&head[0], 4},
{&from[0], uint64(len(from))},
}
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
return int(n) - 4, err
io.ReadWriteCloser
}
func (t *tun) Close() error {
if t.devFd >= 0 {
err := syscall.Close(t.devFd)
if t.ReadWriteCloser != nil {
if err := t.ReadWriteCloser.Close(); err != nil {
return err
}
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
t.l.WithError(err).Error("Error closing device")
return err
}
t.devFd = -1
defer syscall.Close(s)
c := make(chan struct{})
go func() {
// destroying the interface can block if a read() is still pending. Do this asynchronously.
defer close(c)
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err == nil {
defer syscall.Close(s)
ifreq := ifreqDestroy{Name: t.deviceBytes()}
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
}
if err != nil {
t.l.WithError(err).Error("Error destroying tunnel")
}
}()
ifreq := ifreqDestroy{Name: t.deviceBytes()}
// wait up to 1 second so we start blocking at the ioctl
select {
case <-c:
case <-time.After(1 * time.Second):
}
// Destroy the interface
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
return err
}
return nil
@@ -205,37 +84,32 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun,
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device
var fd int
var file *os.File
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName != "" {
fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
}
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
// If the device doesn't already exist, request a new one and rename it
fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0)
}
if err != nil {
return nil, err
}
// Read the name of the interface
rawConn, err := file.SyscallConn()
if err != nil {
return nil, fmt.Errorf("SyscallConn: %v", err)
}
var name [16]byte
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
if ctrlErr == nil {
// set broadcast mode and multicast
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
}
if ctrlErr == nil {
// turn on link-layer mode, to support ipv6
ifhead := uint32(1)
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
}
var ctrlErr error
rawConn.Control(func(fd uintptr) {
// Read the name of the interface
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
})
if ctrlErr != nil {
return nil, err
}
@@ -247,7 +121,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
// If the name doesn't match the desired interface name, rename it now
if ifName != deviceName {
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
s, err := syscall.Socket(
syscall.AF_INET,
syscall.SOCK_DGRAM,
syscall.IPPROTO_IP,
)
if err != nil {
return nil, err
}
@@ -270,11 +148,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
t := &tun{
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
devFd: fd,
ReadWriteCloser: file,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err = t.reload(c, true)
@@ -293,111 +171,38 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
func (t *tun) addIp(cidr netip.Prefix) error {
if cidr.Addr().Is4() {
ifr := ifreqAlias4{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
},
DstAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: getBroadcast(cidr).As4(),
},
MaskAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
},
VHid: 0,
}
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
if cidr.Addr().Is6() {
ifr := ifreqAlias6{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
},
PrefixMask: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
},
Lifetime: addrLifetime{
Expire: 0,
Preferred: 0,
Vltime: 0xffffffff,
Pltime: 0xffffffff,
},
Flags: IN6_IFF_NODAD,
}
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
return fmt.Errorf("unknown address type %v", cidr)
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
}
func (t *tun) Activate() error {
// Setup our default MTU
err := t.setMTU()
if err != nil {
return err
}
linkAddr, err := getLinkAddr(t.Device)
if err != nil {
return err
}
if linkAddr == nil {
return fmt.Errorf("unable to discover link_addr for tun interface")
}
t.linkAddr = linkAddr
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return t.addRoutes(false)
}
func (t *tun) setMTU() error {
// Set the MTU on the device
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
return err
return nil
}
func (t *tun) reload(c *config.C, initial bool) error {
@@ -437,7 +242,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
@@ -457,21 +262,20 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
err := addRoute(r.Cidr, t.linkAddr)
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
@@ -484,8 +288,9 @@ func (t *tun) removeRoutes(routes []Route) error {
continue
}
err := delRoute(r.Cidr, t.linkAddr)
if err != nil {
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
@@ -500,144 +305,3 @@ func (t *tun) deviceBytes() (o [16]byte) {
}
return
}
func flipBytes(b []byte) []byte {
for i := 0; i < len(b); i++ {
b[i] ^= 0xFF
}
return b
}
func orBytes(a []byte, b []byte) []byte {
ret := make([]byte, len(a))
for i := 0; i < len(a); i++ {
ret[i] = a[i] | b[i]
}
return ret
}
func getBroadcast(cidr netip.Prefix) netip.Addr {
broadcast, _ := netip.AddrFromSlice(
orBytes(
cidr.Addr().AsSlice(),
flipBytes(prefixToMask(cidr).AsSlice()),
),
)
return broadcast
}
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
fmt.Println("DOING CHANGE")
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
// getLinkAddr Gets the link address for the interface of the given name
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
if err != nil {
return nil, err
}
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
if err != nil {
return nil, err
}
for _, m := range msgs {
switch m := m.(type) {
case *netroute.InterfaceMessage:
if m.Name == name {
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
if ok {
return sa, nil
}
}
}
}
return nil, nil
}

View File

@@ -16,7 +16,6 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
)
@@ -24,7 +23,7 @@ type tun struct {
io.ReadWriteCloser
vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
}
@@ -80,7 +79,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}

View File

@@ -17,7 +17,6 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
@@ -34,11 +33,10 @@ type tun struct {
deviceIndex int
ioctlFd uintptr
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
routeChan chan struct{}
useSystemRoutes bool
useSystemRoutesBufferSize int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
routeChan chan struct{}
useSystemRoutes bool
l *logrus.Logger
}
@@ -125,13 +123,12 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
l: l,
ReadWriteCloser: file,
fd: int(file.Fd()),
vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
l: l,
}
err := t.reload(c, true)
@@ -234,7 +231,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return file, nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
@@ -293,6 +290,7 @@ func (t *tun) addIPs(link netlink.Link) error {
//add all new addresses
for i := range newAddrs {
//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
//AddrReplace still adds new IPs, but if their properties change it will change them as well
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
return err
@@ -360,11 +358,6 @@ func (t *tun) Activate() error {
t.l.WithError(err).Error("Failed to set tun tx queue length")
}
const modeNone = 1
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
t.l.WithError(err).Warn("Failed to disable link local address generation")
}
if err = t.addIPs(link); err != nil {
return err
}
@@ -470,7 +463,7 @@ func (t *tun) addRoutes(logErrors bool) error {
err := netlink.RouteReplace(&nr)
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
@@ -537,13 +530,7 @@ func (t *tun) watchRoutes() {
rch := make(chan netlink.RouteUpdate)
doneChan := make(chan struct{})
netlinkOptions := netlink.RouteSubscribeOptions{
ReceiveBufferSize: t.useSystemRoutesBufferSize,
ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0,
ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") },
}
if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil {
if err := netlink.RouteSubscribe(rch, doneChan); err != nil {
t.l.WithError(err).Errorf("failed to subscribe to system route changes")
return
}
@@ -553,14 +540,8 @@ func (t *tun) watchRoutes() {
go func() {
for {
select {
case r, ok := <-rch:
if ok {
t.updateRoutes(r)
} else {
// may be should do something here as
// netlink stops sending updates
return
}
case r := <-rch:
t.updateRoutes(r)
case <-doneChan:
// netlink.RouteSubscriber will close the rch for us
return
@@ -569,7 +550,20 @@ func (t *tun) watchRoutes() {
}()
}
func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
if r.Gw == nil {
// Not a gateway route, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route")
return
}
gwAddr, ok := netip.AddrFromSlice(r.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
return
}
gwAddr = gwAddr.Unmap()
withinNetworks := false
for i := range t.vpnNetworks {
if t.vpnNetworks[i].Contains(gwAddr) {
@@ -577,73 +571,9 @@ func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
break
}
}
return withinNetworks
}
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
var gateways routing.Gateways
link, err := netlink.LinkByName(t.Device)
if err != nil {
t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
return gateways
}
// If this route is relevant to our interface and there is a gateway then add it
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(r.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
} else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
}
}
}
for _, p := range r.MultiPath {
// If this route is relevant to our interface and there is a gateway then add it
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(p.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
} else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
// p.Hops+1 = weight of the route
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
}
}
}
}
routing.CalculateBucketsForGateways(gateways)
return gateways
}
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
gateways := t.getGatewaysFromRoute(&r.Route)
if len(gateways) == 0 {
// No gateways relevant to our network, no routing changes required.
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
return
}
if r.Dst == nil {
t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
if !withinNetworks {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
return
}
@@ -659,12 +589,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
newTree := t.routeTree.Load().Clone()
if r.Type == unix.RTM_NEWROUTE {
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
newTree.Insert(dst, gateways)
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
newTree.Insert(dst, gwAddr)
} else {
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
newTree.Delete(dst)
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
}
t.routeTree.Store(newTree)
}

View File

@@ -18,7 +18,6 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
)
@@ -32,7 +31,7 @@ type tun struct {
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
io.ReadWriteCloser
@@ -178,7 +177,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
@@ -198,7 +197,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
@@ -206,7 +205,7 @@ func (t *tun) addRoutes(logErrors bool) error {
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {

View File

@@ -17,7 +17,6 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
)
@@ -26,7 +25,7 @@ type tun struct {
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
io.ReadWriteCloser
@@ -159,7 +158,7 @@ func (t *tun) Activate() error {
return nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
@@ -167,7 +166,7 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
@@ -175,7 +174,7 @@ func (t *tun) addRoutes(logErrors bool) error {
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {

View File

@@ -13,14 +13,13 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
)
type TestTun struct {
Device string
vpnNetworks []netip.Prefix
Routes []Route
routeTree *bart.Table[routing.Gateways]
routeTree *bart.Table[netip.Addr]
l *logrus.Logger
closed atomic.Bool
@@ -87,7 +86,7 @@ func (t *TestTun) Get(block bool) []byte {
// Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************//
func (t *TestTun) RoutesFor(ip netip.Addr) routing.Gateways {
func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Lookup(ip)
return r
}

View File

@@ -18,7 +18,6 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun"
"golang.org/x/sys/windows"
@@ -32,7 +31,7 @@ type winTun struct {
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
tun *wintun.NativeTun
@@ -148,18 +147,15 @@ func (t *winTun) addRoutes(logErrors bool) error {
foundDefault4 := false
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
// Add our unsafe route
// Windows does not support multipath routes natively, so we install only a single route.
// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally.
// In effect this provides multipath routing support to windows supporting loadbalancing and redundancy.
err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
continue
@@ -202,8 +198,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
continue
}
// See comment on luid.AddRoute
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
err := luid.DeleteRoute(r.Cidr, r.Via)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
@@ -213,7 +208,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
return nil
}
func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways {
func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}

View File

@@ -6,7 +6,6 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
)
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
@@ -39,13 +38,9 @@ type UserDevice struct {
func (d *UserDevice) Activate() error {
return nil
}
func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
func (d *UserDevice) Name() string { return "faketun0" }
func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
return routing.Gateways{routing.NewGateway(ip, 1)}
}
func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
func (d *UserDevice) Name() string { return "faketun0" }
func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return d, nil
}

View File

@@ -1,12 +0,0 @@
package packet
import "net/netip"
type Packet struct {
Payload []byte
Addr netip.AddrPort
}
func New() *Packet {
return &Packet{Payload: make([]byte, 9001)}
}

View File

@@ -180,7 +180,6 @@ func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) {
pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, NoiseKeySize),
}
// Set up the parameters which include the peer's public key

75
pki.go
View File

@@ -33,16 +33,18 @@ type CertState struct {
v2Cert cert.Certificate
v2HandshakeBytes []byte
initiatingVersion cert.Version
privateKey []byte
pkcs11Backed bool
cipher string
defaultVersion cert.Version
privateKey []byte
pkcs11Backed bool
cipher string
psk *Psk
myVpnNetworks []netip.Prefix
myVpnNetworksTable *bart.Lite
myVpnNetworksTable *bart.Table[struct{}]
myVpnAddrs []netip.Addr
myVpnAddrsTable *bart.Lite
myVpnBroadcastAddrsTable *bart.Lite
myVpnAddrsTable *bart.Table[struct{}]
myVpnBroadcastAddrsTable *bart.Table[struct{}]
}
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
@@ -171,8 +173,19 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
}
}
psk, err := NewPskFromConfig(c)
if err != nil {
return util.NewContextualError("Failed to load psk from config", nil, err)
}
if len(psk.keys) > 0 {
p.l.WithField("pskMode", psk.mode).WithField("keysLen", len(psk.keys)).
Info("pre shared keys are in use")
}
newState.psk = psk
p.cs.Store(newState)
//TODO: CERT-V2 newState needs a stringer that does json
if initial {
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
} else {
@@ -193,7 +206,7 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
}
func (cs *CertState) GetDefaultCertificate() cert.Certificate {
c := cs.getCertificate(cs.initiatingVersion)
c := cs.getCertificate(cs.defaultVersion)
if c == nil {
panic("No default certificate found")
}
@@ -316,37 +329,37 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
return nil, errors.New("no certificates found in pki.cert")
}
useInitiatingVersion := uint32(1)
useDefaultVersion := uint32(1)
if v1 == nil {
// The only condition that requires v2 as the default is if only a v2 certificate is present
// We do this to avoid having to configure it specifically in the config file
useInitiatingVersion = 2
useDefaultVersion = 2
}
rawInitiatingVersion := c.GetUint32("pki.initiating_version", useInitiatingVersion)
var initiatingVersion cert.Version
switch rawInitiatingVersion {
rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion)
var defaultVersion cert.Version
switch rawDefaultVersion {
case 1:
if v1 == nil {
return nil, fmt.Errorf("can not use pki.initiating_version 1 without a v1 certificate in pki.cert")
return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert")
}
initiatingVersion = cert.Version1
defaultVersion = cert.Version1
case 2:
initiatingVersion = cert.Version2
defaultVersion = cert.Version2
default:
return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion)
return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion)
}
return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey)
return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey)
}
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
cs := CertState{
privateKey: privateKey,
pkcs11Backed: pkcs11backed,
myVpnNetworksTable: new(bart.Lite),
myVpnAddrsTable: new(bart.Lite),
myVpnBroadcastAddrsTable: new(bart.Lite),
myVpnNetworksTable: new(bart.Table[struct{}]),
myVpnAddrsTable: new(bart.Table[struct{}]),
myVpnBroadcastAddrsTable: new(bart.Table[struct{}]),
}
if v1 != nil && v2 != nil {
@@ -358,11 +371,9 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
}
if v1.Networks()[0] != v2.Networks()[0] {
return nil, util.NewContextualError("v1 and v2 networks are not the same", nil, nil)
}
//TODO: CERT-V2 make sure v2 has v1s address
cs.initiatingVersion = dv
cs.defaultVersion = dv
}
if v1 != nil {
@@ -381,8 +392,8 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
cs.v1Cert = v1
cs.v1HandshakeBytes = v1hs
if cs.initiatingVersion == 0 {
cs.initiatingVersion = cert.Version1
if cs.defaultVersion == 0 {
cs.defaultVersion = cert.Version1
}
}
@@ -402,8 +413,8 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
cs.v2Cert = v2
cs.v2HandshakeBytes = v2hs
if cs.initiatingVersion == 0 {
cs.initiatingVersion = cert.Version2
if cs.defaultVersion == 0 {
cs.defaultVersion = cert.Version2
}
}
@@ -416,16 +427,16 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
for _, network := range crt.Networks() {
cs.myVpnNetworks = append(cs.myVpnNetworks, network)
cs.myVpnNetworksTable.Insert(network)
cs.myVpnNetworksTable.Insert(network, struct{}{})
cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()))
cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{})
if network.Addr().Is4() {
addr := network.Masked().Addr().As4()
mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()))
cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{})
}
}

150
psk.go Normal file
View File

@@ -0,0 +1,150 @@
package nebula
import (
"crypto/sha256"
"errors"
"fmt"
"io"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"golang.org/x/crypto/hkdf"
)
var ErrNotAPskMode = errors.New("not a psk mode")
var ErrKeyTooShort = errors.New("key is too short")
var ErrNotEnoughPskKeys = errors.New("at least 1 key is required")
// MinPskLength is the minimum bytes that we accept for a user defined psk, the choice is arbitrary
const MinPskLength = 8
type PskMode int
const (
PskAccepting PskMode = 0
PskSending PskMode = 1
PskEnforced PskMode = 2
)
func NewPskMode(m string) (PskMode, error) {
switch m {
case "accepting":
return PskAccepting, nil
case "sending":
return PskSending, nil
case "enforced":
return PskEnforced, nil
}
return PskAccepting, ErrNotAPskMode
}
func (p PskMode) String() string {
switch p {
case PskAccepting:
return "accepting"
case PskSending:
return "sending"
case PskEnforced:
return "enforced"
}
return "unknown"
}
func (p PskMode) IsValid() bool {
switch p {
case PskAccepting, PskSending, PskEnforced:
return true
default:
return false
}
}
type Psk struct {
// pskMode sets how psk works, ignored, allowed for incoming, or enforced for all
mode PskMode
// primary is the key to use when sending, it may be nil
primary []byte
// keys holds all pre-computed psk hkdfs
// Handshakes iterate this directly
keys [][]byte
}
// NewPskFromConfig is a helper for initial boot and config reloading.
func NewPskFromConfig(c *config.C) (*Psk, error) {
sMode := c.GetString("psk.mode", "accepting")
mode, err := NewPskMode(sMode)
if err != nil {
return nil, util.NewContextualError("Could not parse psk.mode", m{"mode": mode}, err)
}
return NewPsk(
mode,
c.GetStringSlice("psk.keys", nil),
)
}
// NewPsk creates a new Psk object and handles the caching of all accepted keys
func NewPsk(mode PskMode, keys []string) (*Psk, error) {
if !mode.IsValid() {
return nil, ErrNotAPskMode
}
psk := &Psk{
mode: mode,
}
err := psk.cachePsks(keys)
if err != nil {
return nil, err
}
return psk, nil
}
// cachePsks generates all psks we accept and caches them to speed up handshaking
func (p *Psk) cachePsks(keys []string) error {
if p.mode != PskAccepting && len(keys) < 1 {
return ErrNotEnoughPskKeys
}
p.keys = [][]byte{}
for i, rk := range keys {
k, err := sha256KdfFromString(rk)
if err != nil {
return fmt.Errorf("failed to generate key for position %v: %w", i, err)
}
p.keys = append(p.keys, k)
}
if p.mode != PskAccepting {
// We are either sending or enforcing, the primary key must the first slot
p.primary = p.keys[0]
}
if p.mode != PskEnforced {
// If we are not enforcing psk use then a nil psk is allowed
p.keys = append(p.keys, nil)
}
return nil
}
// sha256KdfFromString generates a useful key to use from a provided secret
func sha256KdfFromString(secret string) ([]byte, error) {
if len(secret) < MinPskLength {
return nil, ErrKeyTooShort
}
hmacKey := make([]byte, sha256.Size)
_, err := io.ReadFull(hkdf.New(sha256.New, []byte(secret), nil, nil), hmacKey)
if err != nil {
return nil, err
}
return hmacKey, nil
}

72
psk_test.go Normal file
View File

@@ -0,0 +1,72 @@
package nebula
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewPsk(t *testing.T) {
t.Run("mode accepting", func(t *testing.T) {
p, err := NewPsk(PskAccepting, nil)
require.NoError(t, err)
assert.Equal(t, PskAccepting, p.mode)
assert.Nil(t, p.keys[0])
assert.Nil(t, p.primary)
p, err = NewPsk(PskAccepting, []string{"1234567"})
require.ErrorIs(t, err, ErrKeyTooShort)
p, err = NewPsk(PskAccepting, []string{"hi there friends"})
require.NoError(t, err)
assert.Equal(t, PskAccepting, p.mode)
assert.Nil(t, p.primary)
assert.Len(t, p.keys, 2)
assert.Nil(t, p.keys[1])
expectedCache := []byte{
0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65,
0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef,
}
assert.Equal(t, expectedCache, p.keys[0])
})
t.Run("mode sending", func(t *testing.T) {
p, err := NewPsk(PskSending, nil)
require.ErrorIs(t, err, ErrNotEnoughPskKeys)
p, err = NewPsk(PskSending, []string{"1234567"})
require.ErrorIs(t, err, ErrKeyTooShort)
p, err = NewPsk(PskSending, []string{"hi there friends"})
require.NoError(t, err)
assert.Equal(t, PskSending, p.mode)
assert.Len(t, p.keys, 2)
assert.Nil(t, p.keys[1])
expectedCache := []byte{
0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65,
0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef,
}
assert.Equal(t, expectedCache, p.keys[0])
assert.Equal(t, p.keys[0], p.primary)
})
t.Run("mode enforced", func(t *testing.T) {
p, err := NewPsk(PskEnforced, nil)
require.ErrorIs(t, err, ErrNotEnoughPskKeys)
p, err = NewPsk(PskEnforced, []string{"hi there friends"})
require.NoError(t, err)
assert.Equal(t, PskEnforced, p.mode)
assert.Len(t, p.keys, 1)
expectedCache := []byte{
0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65,
0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef,
}
assert.Equal(t, expectedCache, p.keys[0])
assert.Equal(t, p.keys[0], p.primary)
})
}

View File

@@ -27,7 +27,7 @@ func TestNewPunchyFromConfig(t *testing.T) {
assert.True(t, p.GetPunch())
// punchy.punch
c.Settings["punchy"] = map[string]any{"punch": true}
c.Settings["punchy"] = map[interface{}]interface{}{"punch": true}
p = NewPunchyFromConfig(l, c)
assert.True(t, p.GetPunch())
@@ -37,18 +37,18 @@ func TestNewPunchyFromConfig(t *testing.T) {
assert.True(t, p.GetRespond())
// punchy.respond
c.Settings["punchy"] = map[string]any{"respond": true}
c.Settings["punchy"] = map[interface{}]interface{}{"respond": true}
c.Settings["punch_back"] = false
p = NewPunchyFromConfig(l, c)
assert.True(t, p.GetRespond())
// punchy.delay
c.Settings["punchy"] = map[string]any{"delay": "1m"}
c.Settings["punchy"] = map[interface{}]interface{}{"delay": "1m"}
p = NewPunchyFromConfig(l, c)
assert.Equal(t, time.Minute, p.GetDelay())
// punchy.respond_delay
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
c.Settings["punchy"] = map[interface{}]interface{}{"respond_delay": "1m"}
p = NewPunchyFromConfig(l, c)
assert.Equal(t, time.Minute, p.GetRespondDelay())
}

View File

@@ -241,13 +241,15 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
logMsg.Info("handleCreateRelayRequest")
// Is the source of the relay me? This should never happen, but did happen due to
// an issue migrating relays over to newly re-handshaked host info objects.
if f.myVpnAddrsTable.Contains(from) {
_, found := f.myVpnAddrsTable.Lookup(from)
if found {
logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
return
}
// Is the target of the relay me?
if f.myVpnAddrsTable.Contains(target) {
_, found = f.myVpnAddrsTable.Lookup(target)
if found {
existingRelay, ok := h.relayState.QueryRelayForByIp(from)
if ok {
switch existingRelay.State {

View File

@@ -190,7 +190,7 @@ type RemoteList struct {
// The full list of vpn addresses assigned to this host
vpnAddrs []netip.Addr
// A deduplicated set of underlay addresses. Any accessor should lock beforehand.
// A deduplicated set of addresses. Any accessor should lock beforehand.
addrs []netip.AddrPort
// A set of relay addresses. VpnIp addresses that the remote identified as relays.
@@ -201,10 +201,8 @@ type RemoteList struct {
// For learned addresses, this is the vpnIp that sent the packet
cache map[netip.Addr]*cache
hr *hostnamesResults
// shouldAdd is a nillable function that decides if x should be added to addrs.
shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool
hr *hostnamesResults
shouldAdd func(netip.Addr) bool
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake
@@ -215,7 +213,7 @@ type RemoteList struct {
}
// NewRemoteList creates a new empty RemoteList
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList {
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
r := &RemoteList{
vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
addrs: make([]netip.AddrPort, 0),
@@ -370,15 +368,6 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
return c
}
// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake
func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) {
r.Lock()
r.badRemotes = nil
r.vpnAddrs = make([]netip.Addr, len(vpnAddrs))
copy(r.vpnAddrs, vpnAddrs)
r.Unlock()
}
// ResetBlockedRemotes locks and clears the blocked remotes list
func (r *RemoteList) ResetBlockedRemotes() {
r.Lock()
@@ -588,7 +577,7 @@ func (r *RemoteList) unlockedCollect() {
dnsAddrs := r.hr.GetAddrs()
for _, addr := range dnsAddrs {
if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) {
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
if !r.unlockedIsBad(addr) {
addrs = append(addrs, addr)
}

View File

@@ -1,39 +0,0 @@
package routing
import (
"net/netip"
"github.com/slackhq/nebula/firewall"
)
// Hashes the packet source and destination port and always returns a positive integer
// Based on 'Prospecting for Hash Functions'
// - https://nullprogram.com/blog/2018/07/31/
// - https://github.com/skeeto/hash-prospector
// [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501
func hashPacket(p *firewall.Packet) int {
x := (uint32(p.LocalPort) << 16) | uint32(p.RemotePort)
x ^= x >> 16
x *= 0x21f0aaad
x ^= x >> 15
x *= 0xd35a2d97
x ^= x >> 15
return int(x) & 0x7FFFFFFF
}
// For this function to work correctly it requires that the buckets for the gateways have been calculated
// If the contract is violated balancing will not work properly and the second return value will return false
func BalancePacket(fwPacket *firewall.Packet, gateways []Gateway) (netip.Addr, bool) {
hash := hashPacket(fwPacket)
for i := range gateways {
if hash <= gateways[i].BucketUpperBound() {
return gateways[i].Addr(), true
}
}
// If you land here then the buckets for the gateways are not properly calculated
// Fallback to random routing and let the caller know
return gateways[hash%len(gateways)].Addr(), false
}

View File

@@ -1,144 +0,0 @@
package routing
import (
"net/netip"
"testing"
"github.com/slackhq/nebula/firewall"
"github.com/stretchr/testify/assert"
)
func TestPacketsAreBalancedEqually(t *testing.T) {
gateways := []Gateway{}
gw1Addr := netip.MustParseAddr("1.0.0.1")
gw2Addr := netip.MustParseAddr("1.0.0.2")
gw3Addr := netip.MustParseAddr("1.0.0.3")
gateways = append(gateways, NewGateway(gw1Addr, 1))
gateways = append(gateways, NewGateway(gw2Addr, 1))
gateways = append(gateways, NewGateway(gw3Addr, 1))
CalculateBucketsForGateways(gateways)
gw1count := 0
gw2count := 0
gw3count := 0
iterationCount := uint16(65535)
for i := uint16(0); i < iterationCount; i++ {
packet := firewall.Packet{
LocalAddr: netip.MustParseAddr("192.168.1.1"),
RemoteAddr: netip.MustParseAddr("10.0.0.1"),
LocalPort: i,
RemotePort: 65535 - i,
Protocol: 6, // TCP
Fragment: false,
}
selectedGw, ok := BalancePacket(&packet, gateways)
assert.True(t, ok)
switch selectedGw {
case gw1Addr:
gw1count += 1
case gw2Addr:
gw2count += 1
case gw3Addr:
gw3count += 1
}
}
// Assert packets are balanced, allow variation of up to 100 packets per gateway
assert.InDeltaf(t, iterationCount/3, gw1count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
assert.InDeltaf(t, iterationCount/3, gw2count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
assert.InDeltaf(t, iterationCount/3, gw3count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
}
func TestPacketsAreBalancedByPriority(t *testing.T) {
gateways := []Gateway{}
gw1Addr := netip.MustParseAddr("1.0.0.1")
gw2Addr := netip.MustParseAddr("1.0.0.2")
gateways = append(gateways, NewGateway(gw1Addr, 10))
gateways = append(gateways, NewGateway(gw2Addr, 5))
CalculateBucketsForGateways(gateways)
gw1count := 0
gw2count := 0
iterationCount := uint16(65535)
for i := uint16(0); i < iterationCount; i++ {
packet := firewall.Packet{
LocalAddr: netip.MustParseAddr("192.168.1.1"),
RemoteAddr: netip.MustParseAddr("10.0.0.1"),
LocalPort: i,
RemotePort: 65535 - i,
Protocol: 6, // TCP
Fragment: false,
}
selectedGw, ok := BalancePacket(&packet, gateways)
assert.True(t, ok)
switch selectedGw {
case gw1Addr:
gw1count += 1
case gw2Addr:
gw2count += 1
}
}
iterationCountAsFloat := float32(iterationCount)
assert.InDeltaf(t, iterationCountAsFloat*(2.0/3.0), gw1count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(2.0/3.0), gw1count)
assert.InDeltaf(t, iterationCountAsFloat*(1.0/3.0), gw2count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(1.0/3.0), gw2count)
}
func TestBalancePacketDistributsRandomlyAndReturnsFalseIfBucketsNotCalculated(t *testing.T) {
gateways := []Gateway{}
gw1Addr := netip.MustParseAddr("1.0.0.1")
gw2Addr := netip.MustParseAddr("1.0.0.2")
gateways = append(gateways, NewGateway(gw1Addr, 10))
gateways = append(gateways, NewGateway(gw2Addr, 5))
iterationCount := uint16(65535)
gw1count := 0
gw2count := 0
for i := uint16(0); i < iterationCount; i++ {
packet := firewall.Packet{
LocalAddr: netip.MustParseAddr("192.168.1.1"),
RemoteAddr: netip.MustParseAddr("10.0.0.1"),
LocalPort: i,
RemotePort: 65535 - i,
Protocol: 6, // TCP
Fragment: false,
}
selectedGw, ok := BalancePacket(&packet, gateways)
assert.False(t, ok)
switch selectedGw {
case gw1Addr:
gw1count += 1
case gw2Addr:
gw2count += 1
}
}
assert.Equal(t, int(iterationCount), (gw1count + gw2count))
assert.NotEqual(t, 0, gw1count)
assert.NotEqual(t, 0, gw2count)
}

View File

@@ -1,70 +0,0 @@
package routing
import (
"fmt"
"net/netip"
)
const (
// Sentinal value
BucketNotCalculated = -1
)
type Gateways []Gateway
func (g Gateways) String() string {
str := ""
for i, gw := range g {
str += gw.String()
if i < len(g)-1 {
str += ", "
}
}
return str
}
type Gateway struct {
addr netip.Addr
weight int
bucketUpperBound int
}
func NewGateway(addr netip.Addr, weight int) Gateway {
return Gateway{addr: addr, weight: weight, bucketUpperBound: BucketNotCalculated}
}
func (g *Gateway) BucketUpperBound() int {
return g.bucketUpperBound
}
func (g *Gateway) Addr() netip.Addr {
return g.addr
}
func (g *Gateway) String() string {
return fmt.Sprintf("{addr: %s, weight: %d}", g.addr, g.weight)
}
// Divide and round to nearest integer
func divideAndRound(v uint64, d uint64) uint64 {
var tmp uint64 = v + d/2
return tmp / d
}
// Implements Hash-Threshold mapping, equivalent to the implementation in the linux kernel.
// After this function returns each gateway will have a
// positive bucketUpperBound with a maximum value of 2147483647 (INT_MAX)
func CalculateBucketsForGateways(gateways []Gateway) {
var totalWeight int = 0
for i := range gateways {
totalWeight += gateways[i].weight
}
var loopWeight int = 0
for i := range gateways {
loopWeight += gateways[i].weight
gateways[i].bucketUpperBound = int(divideAndRound(uint64(loopWeight)<<31, uint64(totalWeight))) - 1
}
}

View File

@@ -1,34 +0,0 @@
package routing
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRebalance3_2Split(t *testing.T) {
gateways := []Gateway{}
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 10})
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 5})
CalculateBucketsForGateways(gateways)
assert.Equal(t, 1431655764, gateways[0].bucketUpperBound) // INT_MAX/3*2
assert.Equal(t, 2147483647, gateways[1].bucketUpperBound) // INT_MAX
}
func TestRebalanceEqualSplit(t *testing.T) {
gateways := []Gateway{}
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
CalculateBucketsForGateways(gateways)
assert.Equal(t, 715827882, gateways[0].bucketUpperBound) // INT_MAX/3
assert.Equal(t, 1431655764, gateways[1].bucketUpperBound) // INT_MAX/3*2
assert.Equal(t, 2147483647, gateways[2].bucketUpperBound) // INT_MAX
}

View File

@@ -54,11 +54,7 @@ func New(config *config.C) (*Service, error) {
if err != nil {
return nil, err
}
wait, err := control.Start()
if err != nil {
return nil, err
}
control.Start()
ctx := control.Context()
eg, ctx := errgroup.WithContext(ctx)
@@ -155,12 +151,6 @@ func New(config *config.C) (*Service, error) {
}
})
// Add the nebula wait function to the group
eg.Go(func() error {
wait()
return nil
})
return &s, nil
}

View File

@@ -5,22 +5,18 @@ import (
"context"
"errors"
"net/netip"
"os"
"testing"
"time"
"dario.cat/mergo"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay"
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v3"
"gopkg.in/yaml.v2"
)
type m = map[string]any
type m map[string]interface{}
func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
_, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{})
@@ -75,15 +71,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n
panic(err)
}
logger := logrus.New()
logger.Out = os.Stdout
control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil {
panic(err)
}
s, err := New(control)
s, err := New(&c)
if err != nil {
panic(err)
}

92
ssh.go
View File

@@ -124,10 +124,10 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
}
rawKeys := c.Get("sshd.authorized_users")
keys, ok := rawKeys.([]any)
keys, ok := rawKeys.([]interface{})
if ok {
for _, rk := range keys {
kDef, ok := rk.(map[string]any)
kDef, ok := rk.(map[interface{}]interface{})
if !ok {
l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring")
continue
@@ -148,7 +148,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
continue
}
case []any:
case []interface{}:
for _, subK := range v {
sk, ok := subK.(string)
if !ok {
@@ -190,7 +190,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "list-hostmap",
ShortDescription: "List all known previously connected hosts",
Flags: func() (*flag.FlagSet, any) {
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshListHostMapFlags{}
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
@@ -198,7 +198,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
return fl, &s
},
Callback: func(fs any, a []string, w sshd.StringWriter) error {
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshListHostMap(f.hostMap, fs, w)
},
})
@@ -206,7 +206,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "list-pending-hostmap",
ShortDescription: "List all handshaking hosts",
Flags: func() (*flag.FlagSet, any) {
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshListHostMapFlags{}
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
@@ -214,7 +214,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
return fl, &s
},
Callback: func(fs any, a []string, w sshd.StringWriter) error {
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshListHostMap(f.handshakeManager, fs, w)
},
})
@@ -222,14 +222,14 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "list-lighthouse-addrmap",
ShortDescription: "List all lighthouse map entries",
Flags: func() (*flag.FlagSet, any) {
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshListHostMapFlags{}
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
return fl, &s
},
Callback: func(fs any, a []string, w sshd.StringWriter) error {
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshListLighthouseMap(f.lightHouse, fs, w)
},
})
@@ -237,7 +237,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "reload",
ShortDescription: "Reloads configuration from disk, same as sending HUP to the process",
Callback: func(fs any, a []string, w sshd.StringWriter) error {
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshReload(c, w)
},
})
@@ -251,7 +251,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "stop-cpu-profile",
ShortDescription: "Stops a cpu profile and writes output to the previously provided file",
Callback: func(fs any, a []string, w sshd.StringWriter) error {
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
pprof.StopCPUProfile()
return w.WriteLine("If a CPU profile was running it is now stopped")
},
@@ -278,7 +278,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "log-level",
ShortDescription: "Gets or sets the current log level",
Callback: func(fs any, a []string, w sshd.StringWriter) error {
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshLogLevel(l, fs, a, w)
},
})
@@ -286,7 +286,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "log-format",
ShortDescription: "Gets or sets the current log format",
Callback: func(fs any, a []string, w sshd.StringWriter) error {
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshLogFormat(l, fs, a, w)
},
})
@@ -294,7 +294,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "version",
ShortDescription: "Prints the currently running version of nebula",
Callback: func(fs any, a []string, w sshd.StringWriter) error {
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshVersion(f, fs, a, w)
},
})
@@ -302,14 +302,14 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "device-info",
ShortDescription: "Prints information about the network device.",
Flags: func() (*flag.FlagSet, any) {
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshDeviceInfoFlags{}
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
return fl, &s
},
Callback: func(fs any, a []string, w sshd.StringWriter) error {
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshDeviceInfo(f, fs, w)
},
})
@@ -317,7 +317,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "print-cert",
ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr",
Flags: func() (*flag.FlagSet, any) {
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshPrintCertFlags{}
fl.BoolVar(&s.Json, "json", false, "outputs as json")
@@ -325,7 +325,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
fl.BoolVar(&s.Raw, "raw", false, "raw prints the PEM encoded certificate, not compatible with -json or -pretty")
return fl, &s
},
Callback: func(fs any, a []string, w sshd.StringWriter) error {
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshPrintCert(f, fs, a, w)
},
})
@@ -333,13 +333,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "print-tunnel",
ShortDescription: "Prints json details about a tunnel for the provided vpn addr",
Flags: func() (*flag.FlagSet, any) {
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshPrintTunnelFlags{}
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json")
return fl, &s
},
Callback: func(fs any, a []string, w sshd.StringWriter) error {
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshPrintTunnel(f, fs, a, w)
},
})
@@ -347,13 +347,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "print-relays",
ShortDescription: "Prints json details about all relay info",
Flags: func() (*flag.FlagSet, any) {
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshPrintTunnelFlags{}
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json")
return fl, &s
},
Callback: func(fs any, a []string, w sshd.StringWriter) error {
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshPrintRelays(f, fs, a, w)
},
})
@@ -361,13 +361,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "change-remote",
ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr",
Flags: func() (*flag.FlagSet, any) {
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshChangeRemoteFlags{}
fl.StringVar(&s.Address, "address", "", "The new remote address, ip:port")
return fl, &s
},
Callback: func(fs any, a []string, w sshd.StringWriter) error {
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshChangeRemote(f, fs, a, w)
},
})
@@ -375,13 +375,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "close-tunnel",
ShortDescription: "Closes a tunnel for the provided vpn addr",
Flags: func() (*flag.FlagSet, any) {
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshCloseTunnelFlags{}
fl.BoolVar(&s.LocalOnly, "local-only", false, "Disables notifying the remote that the tunnel is shutting down")
return fl, &s
},
Callback: func(fs any, a []string, w sshd.StringWriter) error {
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshCloseTunnel(f, fs, a, w)
},
})
@@ -390,13 +390,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
Name: "create-tunnel",
ShortDescription: "Creates a tunnel for the provided vpn address",
Help: "The lighthouses will be queried for real addresses but you can provide one as well.",
Flags: func() (*flag.FlagSet, any) {
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshCreateTunnelFlags{}
fl.StringVar(&s.Address, "address", "", "Optionally provide a real remote address, ip:port ")
return fl, &s
},
Callback: func(fs any, a []string, w sshd.StringWriter) error {
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshCreateTunnel(f, fs, a, w)
},
})
@@ -405,13 +405,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
Name: "query-lighthouse",
ShortDescription: "Query the lighthouses for the provided vpn address",
Help: "This command is asynchronous. Only currently known udp addresses will be printed.",
Callback: func(fs any, a []string, w sshd.StringWriter) error {
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshQueryLighthouse(f, fs, a, w)
},
})
}
func sshListHostMap(hl controlHostLister, a any, w sshd.StringWriter) error {
func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error {
fs, ok := a.(*sshListHostMapFlags)
if !ok {
return nil
@@ -451,7 +451,7 @@ func sshListHostMap(hl controlHostLister, a any, w sshd.StringWriter) error {
return nil
}
func sshListLighthouseMap(lightHouse *LightHouse, a any, w sshd.StringWriter) error {
func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWriter) error {
fs, ok := a.(*sshListHostMapFlags)
if !ok {
return nil
@@ -505,7 +505,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a any, w sshd.StringWriter) er
return nil
}
func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error {
func sshStartCpuProfile(fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
err := w.WriteLine("No path to write profile provided")
return err
@@ -527,11 +527,11 @@ func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error {
return err
}
func sshVersion(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
return w.WriteLine(fmt.Sprintf("%s", ifce.version))
}
func sshQueryLighthouse(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine("No vpn address was provided")
}
@@ -553,7 +553,7 @@ func sshQueryLighthouse(ifce *Interface, fs any, a []string, w sshd.StringWriter
return json.NewEncoder(w.GetWriter()).Encode(cm)
}
func sshCloseTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
flags, ok := fs.(*sshCloseTunnelFlags)
if !ok {
return nil
@@ -593,7 +593,7 @@ func sshCloseTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) er
return w.WriteLine("Closed")
}
func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
flags, ok := fs.(*sshCreateTunnelFlags)
if !ok {
return nil
@@ -638,7 +638,7 @@ func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) e
return w.WriteLine("Created")
}
func sshChangeRemote(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
flags, ok := fs.(*sshChangeRemoteFlags)
if !ok {
return nil
@@ -675,7 +675,7 @@ func sshChangeRemote(ifce *Interface, fs any, a []string, w sshd.StringWriter) e
return w.WriteLine("Changed")
}
func sshGetHeapProfile(fs any, a []string, w sshd.StringWriter) error {
func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine("No path to write profile provided")
}
@@ -696,7 +696,7 @@ func sshGetHeapProfile(fs any, a []string, w sshd.StringWriter) error {
return err
}
func sshMutexProfileFraction(fs any, a []string, w sshd.StringWriter) error {
func sshMutexProfileFraction(fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
rate := runtime.SetMutexProfileFraction(-1)
return w.WriteLine(fmt.Sprintf("Current value: %d", rate))
@@ -711,7 +711,7 @@ func sshMutexProfileFraction(fs any, a []string, w sshd.StringWriter) error {
return w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate))
}
func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error {
func sshGetMutexProfile(fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine("No path to write profile provided")
}
@@ -735,7 +735,7 @@ func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error {
return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a))
}
func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
}
@@ -749,7 +749,7 @@ func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) erro
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
}
func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
}
@@ -767,7 +767,7 @@ func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) err
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
}
func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
args, ok := fs.(*sshPrintCertFlags)
if !ok {
return nil
@@ -822,7 +822,7 @@ func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) erro
return w.WriteLine(cert.String())
}
func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
args, ok := fs.(*sshPrintTunnelFlags)
if !ok {
w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type"))
@@ -919,7 +919,7 @@ func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) er
return nil
}
func sshPrintTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
args, ok := fs.(*sshPrintTunnelFlags)
if !ok {
return nil
@@ -951,7 +951,7 @@ func sshPrintTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) er
return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges()))
}
func sshDeviceInfo(ifce *Interface, fs any, w sshd.StringWriter) error {
func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error {
data := struct {
Name string `json:"name"`

View File

@@ -12,7 +12,7 @@ import (
// CommandFlags is a function called before help or command execution to parse command line flags
// It should return a flag.FlagSet instance and a pointer to the struct that will contain parsed flags
type CommandFlags func() (*flag.FlagSet, any)
type CommandFlags func() (*flag.FlagSet, interface{})
// CommandCallback is the function called when your command should execute.
// fs will be a a pointer to the struct provided by Command.Flags callback, if there was one. -h and -help are reserved
@@ -21,7 +21,7 @@ type CommandFlags func() (*flag.FlagSet, any)
// w is the writer to use when sending messages back to the client.
// If an error is returned by the callback it is logged locally, the callback should handle messaging errors to the user
// where appropriate
type CommandCallback func(fs any, a []string, w StringWriter) error
type CommandCallback func(fs interface{}, a []string, w StringWriter) error
type Command struct {
Name string
@@ -34,7 +34,7 @@ type Command struct {
func execCommand(c *Command, args []string, w StringWriter) error {
var (
fl *flag.FlagSet
fs any
fs interface{}
)
if c.Flags != nil {
@@ -85,7 +85,7 @@ func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) {
func matchCommand(c *radix.Tree, cmd string) []string {
cmds := make([]string, 0)
c.WalkPrefix(cmd, func(found string, v any) bool {
c.WalkPrefix(cmd, func(found string, v interface{}) bool {
cmds = append(cmds, found)
return false
})
@@ -95,7 +95,7 @@ func matchCommand(c *radix.Tree, cmd string) []string {
func allCommands(c *radix.Tree) []*Command {
cmds := make([]*Command, 0)
c.WalkPrefix("", func(found string, v any) bool {
c.WalkPrefix("", func(found string, v interface{}) bool {
cmd, ok := v.(*Command)
if ok {
cmds = append(cmds, cmd)

View File

@@ -86,7 +86,7 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
s.RegisterCommand(&Command{
Name: "help",
ShortDescription: "prints available commands or help <command> for specific usage info",
Callback: func(a any, args []string, w StringWriter) error {
Callback: func(a interface{}, args []string, w StringWriter) error {
return helpCallback(s.commands, args, w)
},
})

View File

@@ -9,13 +9,13 @@ import (
"github.com/armon/go-radix"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
"golang.org/x/crypto/ssh/terminal"
)
type session struct {
l *logrus.Entry
c *ssh.ServerConn
term *term.Terminal
term *terminal.Terminal
commands *radix.Tree
exitChan chan bool
}
@@ -31,7 +31,7 @@ func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.New
s.commands.Insert("logout", &Command{
Name: "logout",
ShortDescription: "Ends the current session",
Callback: func(a any, args []string, w StringWriter) error {
Callback: func(a interface{}, args []string, w StringWriter) error {
s.Close()
return nil
},
@@ -106,8 +106,8 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
}
}
func (s *session) createTerm(channel ssh.Channel) *term.Terminal {
term := term.NewTerminal(channel, s.c.User()+"@nebula > ")
func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal {
term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ")
term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
// key 9 is tab
if key == 9 {

View File

@@ -13,7 +13,7 @@ import (
// AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory
// There is currently a special case for `time.loc` (as this code traverses into unexported fields)
func AssertDeepCopyEqual(t *testing.T, a any, b any) {
func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) {
v1 := reflect.ValueOf(a)
v2 := reflect.ValueOf(b)

View File

@@ -4,14 +4,12 @@ import (
"errors"
"io"
"net/netip"
"github.com/slackhq/nebula/routing"
)
type NoopTun struct{}
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {
return routing.Gateways{}
func (NoopTun) RouteFor(addr netip.Addr) netip.Addr {
return netip.Addr{}
}
func (NoopTun) Activate() error {

View File

@@ -16,7 +16,7 @@ type EncReader func(
type Conn interface {
Rebind() error
LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader) error
ListenOut(r EncReader)
WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C)
Close() error

Some files were not shown because too many files have changed in this diff Show More