mirror of https://github.com/slackhq/nebula.git (synced 2025-11-22 08:24:25 +01:00)

Compare commits (6 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 03ab9a1208 |  |
|  | 4258c1388c |  |
|  | 2a2b6424ed |  |
|  | f896e2a863 |  |
|  | 4db6049684 |  |
|  | 8f1dc12618 |  |
.github/ISSUE_TEMPLATE/config.yml (2 changes)

@@ -17,5 +17,5 @@ contact_links:
     about: 'The documentation is the best place to start if you are new to Nebula.'

   - name: 💁 Support/Chat
-    url: https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA
+    url: https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ
     about: 'For faster support, join us on Slack for assistance!'
.github/workflows/gofmt.yml (6 changes)

@@ -14,11 +14,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:

-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4

-      - uses: actions/setup-go@v6
+      - uses: actions/setup-go@v5
        with:
-          go-version: '1.25'
+          go-version: '1.24'
          check-latest: true

      - name: Install goimports
.github/workflows/release.yml (32 changes)

@@ -10,11 +10,11 @@ jobs:
     name: Build Linux/BSD All
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4

-      - uses: actions/setup-go@v6
+      - uses: actions/setup-go@v5
        with:
-          go-version: '1.25'
+          go-version: '1.24'
          check-latest: true

      - name: Build

@@ -24,7 +24,7 @@ jobs:
         mv build/*.tar.gz release

      - name: Upload artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v4
        with:
          name: linux-latest
          path: release

@@ -33,11 +33,11 @@ jobs:
     name: Build Windows
     runs-on: windows-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4

-      - uses: actions/setup-go@v6
+      - uses: actions/setup-go@v5
        with:
-          go-version: '1.25'
+          go-version: '1.24'
          check-latest: true

      - name: Build

@@ -55,7 +55,7 @@ jobs:
         mv dist\windows\wintun build\dist\windows\

      - name: Upload artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v4
        with:
          name: windows-latest
          path: build

@@ -66,11 +66,11 @@ jobs:
       HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
     runs-on: macos-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4

-      - uses: actions/setup-go@v6
+      - uses: actions/setup-go@v5
        with:
-          go-version: '1.25'
+          go-version: '1.24'
          check-latest: true

      - name: Import certificates

@@ -104,7 +104,7 @@ jobs:
         fi

      - name: Upload artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v4
        with:
          name: darwin-latest
          path: ./release/*

@@ -124,11 +124,11 @@ jobs:
       # be overwritten
      - name: Checkout code
        if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
-        uses: actions/checkout@v5
+        uses: actions/checkout@v4

      - name: Download artifacts
        if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
-        uses: actions/download-artifact@v6
+        uses: actions/download-artifact@v4
        with:
          name: linux-latest
          path: artifacts

@@ -160,10 +160,10 @@ jobs:
     needs: [build-linux, build-darwin, build-windows]
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4

      - name: Download artifacts
-        uses: actions/download-artifact@v6
+        uses: actions/download-artifact@v4
        with:
          path: artifacts

.github/workflows/smoke-extra.yml (6 changes)

@@ -20,11 +20,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:

-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4

-      - uses: actions/setup-go@v6
+      - uses: actions/setup-go@v5
        with:
-          go-version: '1.25'
+          go-version-file: 'go.mod'
          check-latest: true

      - name: add hashicorp source
.github/workflows/smoke.yml (6 changes)

@@ -18,11 +18,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:

-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4

-      - uses: actions/setup-go@v6
+      - uses: actions/setup-go@v5
        with:
-          go-version: '1.25'
+          go-version: '1.24'
          check-latest: true

      - name: build
.github/workflows/test.yml (36 changes)

@@ -18,11 +18,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:

-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4

-      - uses: actions/setup-go@v6
+      - uses: actions/setup-go@v5
        with:
-          go-version: '1.25'
+          go-version: '1.24'
          check-latest: true

      - name: Build

@@ -32,9 +32,9 @@ jobs:
        run: make vet

      - name: golangci-lint
-        uses: golangci/golangci-lint-action@v9
+        uses: golangci/golangci-lint-action@v7
        with:
-          version: v2.5
+          version: v2.0

      - name: Test
        run: make test

@@ -45,7 +45,7 @@ jobs:
      - name: Build test mobile
        run: make build-test-mobile

-      - uses: actions/upload-artifact@v5
+      - uses: actions/upload-artifact@v4
        with:
          name: e2e packet flow linux-latest
          path: e2e/mermaid/linux-latest

@@ -56,11 +56,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:

-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4

-      - uses: actions/setup-go@v6
+      - uses: actions/setup-go@v5
        with:
-          go-version: '1.25'
+          go-version: '1.24'
          check-latest: true

      - name: Build

@@ -77,11 +77,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:

-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4

-      - uses: actions/setup-go@v6
+      - uses: actions/setup-go@v5
        with:
-          go-version: '1.25'
+          go-version: '1.22'
          check-latest: true

      - name: Build

@@ -98,11 +98,11 @@ jobs:
       os: [windows-latest, macos-latest]
     steps:

-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4

-      - uses: actions/setup-go@v6
+      - uses: actions/setup-go@v5
        with:
-          go-version: '1.25'
+          go-version: '1.24'
          check-latest: true

      - name: Build nebula

@@ -115,9 +115,9 @@ jobs:
        run: make vet

      - name: golangci-lint
-        uses: golangci/golangci-lint-action@v9
+        uses: golangci/golangci-lint-action@v7
        with:
-          version: v2.5
+          version: v2.0

      - name: Test
        run: make test

@@ -125,7 +125,7 @@ jobs:
      - name: End 2 end
        run: make e2evv

-      - uses: actions/upload-artifact@v5
+      - uses: actions/upload-artifact@v4
        with:
          name: e2e packet flow ${{ matrix.os }}
          path: e2e/mermaid/${{ matrix.os }}
Makefile (3 changes)

@@ -61,7 +61,7 @@ ALL = $(ALL_LINUX) \
     windows-arm64

 e2e:
-    $(TEST_ENV) go test -tags=e2e_testing -count=1 $(TEST_FLAGS) ./e2e
+    $(TEST_ENV) go test -tags=synctrace,e2e_testing -count=1 $(TEST_FLAGS) ./e2e

 e2ev: TEST_FLAGS += -v
 e2ev: e2e

@@ -215,6 +215,7 @@ ifeq ($(words $(MAKECMDGOALS)),1)
     @$(MAKE) service ${.DEFAULT_GOAL} --no-print-directory
 endif

+bin-docker: BUILD_ARGS = -tags=synctrace
 bin-docker: bin build/linux-amd64/nebula build/linux-amd64/nebula-cert

 smoke-docker: bin-docker
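The Makefile hunks above only thread a `synctrace` build tag through the e2e test and docker build invocations. As a generic illustration of what such a tag does (this is not how github.com/wadey/synctrace itself is wired up), a pair of tagged files lets code observe whether the tag was set at build time:

```go
// synctrace_on.go (hypothetical file): compiled only with -tags=synctrace.
//go:build synctrace

package buildcfg

// SynctraceEnabled lets other code check whether tracing was compiled in.
const SynctraceEnabled = true
```

```go
// synctrace_off.go (hypothetical file): compiled when the tag is absent.
//go:build !synctrace

package buildcfg

const SynctraceEnabled = false
```

With that pair in place, `go test -tags=synctrace,e2e_testing ./e2e` selects the first file and a plain build selects the second.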
@@ -12,7 +12,7 @@ Further documentation can be found [here](https://nebula.defined.net/docs/).

 You can read more about Nebula [here](https://medium.com/p/884110a5579).

-You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA).
+You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ).

 ## Supported Platforms

bits.go (2 changes)

@@ -43,7 +43,7 @@ func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
     }

     // Not within the window
-    l.Debugf("rejected a packet (top) %d %d delta %d\n", b.current, i, b.current-i)
+    l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
     return false
 }

@@ -84,11 +84,16 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu

     calculatedRemotes := new(bart.Table[[]*calculatedRemote])

-    rawMap, ok := value.(map[string]any)
+    rawMap, ok := value.(map[any]any)
     if !ok {
         return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
     }
-    for rawCIDR, rawValue := range rawMap {
+    for rawKey, rawValue := range rawMap {
+        rawCIDR, ok := rawKey.(string)
+        if !ok {
+            return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
+        }
+
         cidr, err := netip.ParsePrefix(rawCIDR)
         if err != nil {
             return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)

@@ -124,7 +129,7 @@ func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculat
     }

 func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
-    rawMap, ok := raw.(map[string]any)
+    rawMap, ok := raw.(map[any]any)
     if !ok {
         return nil, fmt.Errorf("invalid type: %T", raw)
     }
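The two hunks above swap `map[string]any` assertions for `map[any]any` and add a per-key string assertion, which tracks how different YAML decoders type their map keys. A small sketch of a helper (hypothetical name `toStringKeyMap`, not part of nebula) that accepts either shape:

```go
package main

import "fmt"

// toStringKeyMap normalizes a YAML-decoded map whose keys may arrive as string
// (yaml.v3 style) or as any (yaml.v2 style), mirroring the defensive key
// assertion added in the hunk above.
func toStringKeyMap(v any) (map[string]any, error) {
	switch m := v.(type) {
	case map[string]any:
		return m, nil
	case map[any]any:
		out := make(map[string]any, len(m))
		for k, val := range m {
			ks, ok := k.(string)
			if !ok {
				return nil, fmt.Errorf("invalid key (type %T): %v", k, k)
			}
			out[ks] = val
		}
		return out, nil
	default:
		return nil, fmt.Errorf("invalid type: %T", v)
	}
}

func main() {
	m, err := toStringKeyMap(map[any]any{"10.0.0.0/8": []any{"mask"}})
	fmt.Println(m, err)
}
```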
@@ -58,9 +58,6 @@ type Certificate interface {
     // PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
     PublicKey() []byte

-    // MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM
-    MarshalPublicKeyPEM() []byte
-
     // Curve identifies which curve was used for the PublicKey and Signature.
     Curve() Curve

@@ -138,7 +135,8 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
     case Version2:
         c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
     default:
-        return nil, ErrUnknownVersion
+        //TODO: CERT-V2 make a static var
+        return nil, fmt.Errorf("unknown certificate version %d", v)
     }

     if err != nil {
|||||||
@@ -83,10 +83,6 @@ func (c *certificateV1) PublicKey() []byte {
|
|||||||
return c.details.publicKey
|
return c.details.publicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *certificateV1) MarshalPublicKeyPEM() []byte {
|
|
||||||
return marshalCertPublicKeyToPEM(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *certificateV1) Signature() []byte {
|
func (c *certificateV1) Signature() []byte {
|
||||||
return c.signature
|
return c.signature
|
||||||
}
|
}
|
||||||
@@ -114,10 +110,8 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
|
|||||||
case Curve_CURVE25519:
|
case Curve_CURVE25519:
|
||||||
return ed25519.Verify(key, b, c.signature)
|
return ed25519.Verify(key, b, c.signature)
|
||||||
case Curve_P256:
|
case Curve_P256:
|
||||||
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
|
x, y := elliptic.Unmarshal(elliptic.P256(), key)
|
||||||
if err != nil {
|
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
|
||||||
return false
|
|
||||||
}
|
|
||||||
hashed := sha256.Sum256(b)
|
hashed := sha256.Sum256(b)
|
||||||
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
||||||
default:
|
default:
|
||||||
|
|||||||
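Both sides of the CheckSignature hunk verify an ASN.1 ECDSA signature over a SHA-256 digest; they differ only in how the uncompressed public key is parsed (`ecdsa.ParseUncompressedPublicKey` is the newer standard-library helper, `elliptic.Unmarshal` the older, now-deprecated one). A standalone sketch of the older path, with the nil check that `elliptic.Unmarshal` requires:

```go
package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/sha256"
	"fmt"
)

// verifyP256 checks an ASN.1-encoded ECDSA signature over msg using an
// uncompressed (65-byte) P-256 public key.
func verifyP256(pub, msg, sig []byte) bool {
	x, y := elliptic.Unmarshal(elliptic.P256(), pub) // returns x == nil for malformed keys
	if x == nil {
		return false
	}
	pk := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
	h := sha256.Sum256(msg)
	return ecdsa.VerifyASN1(pk, h[:], sig)
}

func main() {
	fmt.Println(verifyP256(nil, []byte("msg"), nil)) // false: no key material
}
```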
@@ -1,7 +1,6 @@
 package cert

 import (
-    "crypto/ed25519"
     "fmt"
     "net/netip"
     "testing"

@@ -14,7 +13,6 @@ import (
 )

 func TestCertificateV1_Marshal(t *testing.T) {
-    t.Parallel()
     before := time.Now().Add(time.Second * -60).Round(time.Second)
     after := time.Now().Add(time.Second * 60).Round(time.Second)
     pubKey := []byte("1234567890abcedfghij1234567890ab")

@@ -62,58 +60,6 @@ func TestCertificateV1_Marshal(t *testing.T) {
     assert.Equal(t, nc.Groups(), nc2.Groups())
 }

-func TestCertificateV1_PublicKeyPem(t *testing.T) {
-    t.Parallel()
-    before := time.Now().Add(time.Second * -60).Round(time.Second)
-    after := time.Now().Add(time.Second * 60).Round(time.Second)
-    pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
-
-    nc := certificateV1{
-        details: detailsV1{
-            name:           "testing",
-            networks:       []netip.Prefix{},
-            unsafeNetworks: []netip.Prefix{},
-            groups:         []string{"test-group1", "test-group2", "test-group3"},
-            notBefore:      before,
-            notAfter:       after,
-            publicKey:      pubKey,
-            isCA:           false,
-            issuer:         "1234567890abcedfghij1234567890ab",
-        },
-        signature: []byte("1234567890abcedfghij1234567890ab"),
-    }
-
-    assert.Equal(t, Version1, nc.Version())
-    assert.Equal(t, Curve_CURVE25519, nc.Curve())
-    pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
-    assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
-    assert.False(t, nc.IsCA())
-
-    nc.details.isCA = true
-    assert.Equal(t, Curve_CURVE25519, nc.Curve())
-    pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
-    assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
-    assert.True(t, nc.IsCA())
-
-    pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
-AAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA P256 PUBLIC KEY-----
-`)
-    pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
-    require.NoError(t, err)
-    nc.details.curve = Curve_P256
-    nc.details.publicKey = pubP256Key
-    assert.Equal(t, Curve_P256, nc.Curve())
-    assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
-    assert.True(t, nc.IsCA())
-
-    nc.details.isCA = false
-    assert.Equal(t, Curve_P256, nc.Curve())
-    assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
-    assert.False(t, nc.IsCA())
-}
-
 func TestCertificateV1_Expired(t *testing.T) {
     nc := certificateV1{
         details: detailsV1{
@@ -114,10 +114,6 @@ func (c *certificateV2) PublicKey() []byte {
     return c.publicKey
 }

-func (c *certificateV2) MarshalPublicKeyPEM() []byte {
-    return marshalCertPublicKeyToPEM(c)
-}
-
 func (c *certificateV2) Signature() []byte {
     return c.signature
 }

@@ -153,10 +149,8 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
     case Curve_CURVE25519:
         return ed25519.Verify(key, b, c.signature)
     case Curve_P256:
-        pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
-        if err != nil {
-            return false
-        }
+        x, y := elliptic.Unmarshal(elliptic.P256(), key)
+        pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
         hashed := sha256.Sum256(b)
         return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
     default:
@@ -15,7 +15,6 @@ import (
 )

 func TestCertificateV2_Marshal(t *testing.T) {
-    t.Parallel()
     before := time.Now().Add(time.Second * -60).Round(time.Second)
     after := time.Now().Add(time.Second * 60).Round(time.Second)
     pubKey := []byte("1234567890abcedfghij1234567890ab")

@@ -76,58 +75,6 @@ func TestCertificateV2_Marshal(t *testing.T) {
     assert.Equal(t, nc.Groups(), nc2.Groups())
 }

-func TestCertificateV2_PublicKeyPem(t *testing.T) {
-    t.Parallel()
-    before := time.Now().Add(time.Second * -60).Round(time.Second)
-    after := time.Now().Add(time.Second * 60).Round(time.Second)
-    pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
-
-    nc := certificateV2{
-        details: detailsV2{
-            name:           "testing",
-            networks:       []netip.Prefix{},
-            unsafeNetworks: []netip.Prefix{},
-            groups:         []string{"test-group1", "test-group2", "test-group3"},
-            notBefore:      before,
-            notAfter:       after,
-            isCA:           false,
-            issuer:         "1234567890abcedfghij1234567890ab",
-        },
-        publicKey: pubKey,
-        signature: []byte("1234567890abcedfghij1234567890ab"),
-    }
-
-    assert.Equal(t, Version2, nc.Version())
-    assert.Equal(t, Curve_CURVE25519, nc.Curve())
-    pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
-    assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
-    assert.False(t, nc.IsCA())
-
-    nc.details.isCA = true
-    assert.Equal(t, Curve_CURVE25519, nc.Curve())
-    pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
-    assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
-    assert.True(t, nc.IsCA())
-
-    pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
-AAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA P256 PUBLIC KEY-----
-`)
-    pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
-    require.NoError(t, err)
-    nc.curve = Curve_P256
-    nc.publicKey = pubP256Key
-    assert.Equal(t, Curve_P256, nc.Curve())
-    assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
-    assert.True(t, nc.IsCA())
-
-    nc.details.isCA = false
-    assert.Equal(t, Curve_P256, nc.Curve())
-    assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
-    assert.False(t, nc.IsCA())
-}
-
 func TestCertificateV2_Expired(t *testing.T) {
     nc := certificateV2{
         details: detailsV2{
@@ -20,7 +20,6 @@ var (
     ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
     ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
     ErrCaNotFound = errors.New("could not find ca for the certificate")
-    ErrUnknownVersion = errors.New("certificate version unrecognized")

     ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
     ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")
cert/pem.go (52 changes)

@@ -7,26 +7,19 @@ import (
     "golang.org/x/crypto/ed25519"
 )

-const ( //cert banners
+const (
     CertificateBanner = "NEBULA CERTIFICATE"
     CertificateV2Banner = "NEBULA CERTIFICATE V2"
-)
+    X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
+    X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
+    EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
+    Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
+    Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"

-const ( //key-agreement-key banners
-    X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
-    X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
-    P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
-    P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
-)
-
-/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */
-const ( //signing key banners
+    P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
+    P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
     EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
     ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
-    ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY"
-    EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
-    Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
-    Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
 )

 // UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed

@@ -58,16 +51,6 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {

 }

-func marshalCertPublicKeyToPEM(c Certificate) []byte {
-    if c.IsCA() {
-        return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
-    } else {
-        return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey())
-    }
-}
-
-// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH.
-// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
 func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
     switch curve {
     case Curve_CURVE25519:

@@ -79,19 +62,6 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
     }
 }

-// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing.
-// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
-func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
-    switch curve {
-    case Curve_CURVE25519:
-        return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
-    case Curve_P256:
-        return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
-    default:
-        return nil
-    }
-}
-
 func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
     k, r := pem.Decode(b)
     if k == nil {

@@ -103,7 +73,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
     case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
         expectedLen = 32
         curve = Curve_CURVE25519
-    case P256PublicKeyBanner, ECDSAP256PublicKeyBanner:
+    case P256PublicKeyBanner:
         // Uncompressed
         expectedLen = 65
         curve = Curve_P256
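All of the banner constants above feed Go's encoding/pem. A minimal sketch of the encode/decode round trip those helpers are built on (the banner value here is just an example):

```go
package main

import (
	"encoding/pem"
	"fmt"
)

func main() {
	// Encode raw key bytes under a banner, the way MarshalPublicKeyToPEM does.
	raw := make([]byte, 32)
	block := pem.EncodeToMemory(&pem.Block{Type: "NEBULA X25519 PUBLIC KEY", Bytes: raw})

	// Decode returns the first block plus any unconsumed bytes, which is what
	// lets UnmarshalPublicKeyFromPEM walk a concatenated key bundle.
	k, rest := pem.Decode(block)
	if k == nil {
		panic("no PEM block found")
	}
	fmt.Println(k.Type, len(k.Bytes), len(rest))
}
```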
@@ -177,7 +177,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 }

 func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
-    t.Parallel()
     pubKey := []byte(`# A good key
 -----BEGIN NEBULA ED25519 PUBLIC KEY-----
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=

@@ -231,7 +230,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 }

 func TestUnmarshalX25519PublicKey(t *testing.T) {
-    t.Parallel()
     pubKey := []byte(`# A good key
 -----BEGIN NEBULA X25519 PUBLIC KEY-----
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=

@@ -242,12 +240,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
 AAAAAAAAAAAAAAAAAAAAAAA=
 -----END NEBULA P256 PUBLIC KEY-----
-`)
-    oldPubP256Key := []byte(`# A good key
------BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
-AAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA ECDSA P256 PUBLIC KEY-----
 `)
     shortKey := []byte(`# A short key
 -----BEGIN NEBULA X25519 PUBLIC KEY-----

@@ -264,22 +256,15 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 -END NEBULA X25519 PUBLIC KEY-----`)

-    keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)
+    keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)

     // Success test case
     k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
     assert.Len(t, k, 32)
     require.NoError(t, err)
-    assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem))
+    assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
     assert.Equal(t, Curve_CURVE25519, curve)

-    // Success test case
-    k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
-    assert.Len(t, k, 65)
-    require.NoError(t, err)
-    assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem))
-    assert.Equal(t, Curve_P256, curve)
-
     // Success test case
     k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
     assert.Len(t, k, 65)
cert/sign.go (12 changes)

@@ -7,6 +7,7 @@ import (
     "crypto/rand"
     "crypto/sha256"
     "fmt"
+    "math/big"
     "net/netip"
     "time"
 )

@@ -54,10 +55,15 @@ func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Cert
     }
     return t.SignWith(signer, curve, sp)
 case Curve_P256:
-    pk, err := ecdsa.ParseRawPrivateKey(elliptic.P256(), key)
-    if err != nil {
-        return nil, err
+    pk := &ecdsa.PrivateKey{
+        PublicKey: ecdsa.PublicKey{
+            Curve: elliptic.P256(),
+        },
+        // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
+        D: new(big.Int).SetBytes(key),
     }
+    // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
+    pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
     sp := func(certBytes []byte) ([]byte, error) {
         // We need to hash first for ECDSA
         // - https://pkg.go.dev/crypto/ecdsa#SignASN1
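The right-hand side of the Sign hunk rebuilds an *ecdsa.PrivateKey from the raw 32-byte scalar by setting D and deriving X, Y with ScalarBaseMult, while the left-hand side reaches the same key via the newer ecdsa.ParseRawPrivateKey. A self-contained sketch of the manual route (ScalarBaseMult is deprecated but still works):

```go
package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/sha256"
	"fmt"
	"math/big"
)

// signP256 signs a SHA-256 digest of msg with a P-256 key given only its raw scalar.
func signP256(rawKey, msg []byte) ([]byte, error) {
	pk := &ecdsa.PrivateKey{
		PublicKey: ecdsa.PublicKey{Curve: elliptic.P256()},
		D:         new(big.Int).SetBytes(rawKey),
	}
	pk.X, pk.Y = pk.Curve.ScalarBaseMult(rawKey) // derive the public point from the scalar
	h := sha256.Sum256(msg)
	return ecdsa.SignASN1(rand.Reader, pk, h[:])
}

func main() {
	key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	sig, err := signP256(key.D.Bytes(), []byte("hello"))
	fmt.Println(len(sig) > 0, err)
}
```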
@@ -114,33 +114,6 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
     return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
 }

-func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
-    nc := &cert.TBSCertificate{
-        Version:        v,
-        Curve:          c.Curve(),
-        Name:           c.Name(),
-        Networks:       c.Networks(),
-        UnsafeNetworks: c.UnsafeNetworks(),
-        Groups:         c.Groups(),
-        NotBefore:      time.Unix(c.NotBefore().Unix(), 0),
-        NotAfter:       time.Unix(c.NotAfter().Unix(), 0),
-        PublicKey:      c.PublicKey(),
-        IsCA:           false,
-    }
-
-    c, err := nc.Sign(ca, ca.Curve(), key)
-    if err != nil {
-        panic(err)
-    }
-
-    pem, err := c.MarshalPEM()
-    if err != nil {
-        panic(err)
-    }
-
-    return c, pem
-}
-
 func X25519Keypair() ([]byte, []byte) {
     privkey := make([]byte, 32)
     if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
@@ -3,9 +3,6 @@ package main
 import (
     "flag"
     "fmt"
-    "log"
-    "net/http"
-    _ "net/http/pprof"
     "os"

     "github.com/sirupsen/logrus"

@@ -61,10 +58,6 @@ func main() {
         os.Exit(1)
     }

-    go func() {
-        log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
-    }()
-
     if !*configTest {
         ctrl.Start()
         notifyReady(l)
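The `go func` block above (present on only one side of this comparison) is the stock net/http/pprof pattern: a blank import registers the /debug/pprof handlers on the default mux, and a side goroutine serves them. As a standalone sketch:

```go
package main

import (
	"log"
	"net/http"
	_ "net/http/pprof" // registers /debug/pprof/* on http.DefaultServeMux
)

func main() {
	go func() {
		// Serve the profiling endpoints; the address mirrors the hunk above.
		log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
	}()

	// The real main would continue with application startup here.
	select {}
}
```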
@@ -11,13 +11,13 @@ import (
     "sort"
     "strconv"
     "strings"
-    "sync"
     "syscall"
     "time"

     "dario.cat/mergo"
     "github.com/sirupsen/logrus"
-    "go.yaml.in/yaml/v3"
+    "github.com/wadey/synctrace"
+    "gopkg.in/yaml.v3"
 )

 type C struct {

@@ -27,13 +27,14 @@ type C struct {
     oldSettings map[string]any
     callbacks []func(*C)
     l *logrus.Logger
-    reloadLock sync.Mutex
+    reloadLock synctrace.Mutex
 }

 func NewC(l *logrus.Logger) *C {
     return &C{
         Settings: make(map[string]any),
         l: l,
+        reloadLock: synctrace.NewMutex("config-reload"),
     }
 }

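The config hunks swap sync.Mutex for a named synctrace.Mutex created with synctrace.NewMutex. A trimmed-down sketch of that pattern, under the assumption that the traced type mirrors sync.Mutex's Lock/Unlock (only NewMutex and the field type are visible in the diff itself):

```go
package main

import "github.com/wadey/synctrace"

// reloader is a hypothetical stand-in for config.C: one named, traced mutex
// guarding a settings map.
type reloader struct {
	settings   map[string]any
	reloadLock synctrace.Mutex
}

func newReloader() *reloader {
	return &reloader{
		settings:   make(map[string]any),
		reloadLock: synctrace.NewMutex("config-reload"), // the name shows up in trace output
	}
}

func (r *reloader) reload(newSettings map[string]any) {
	// Assumption: synctrace.Mutex exposes Lock/Unlock like sync.Mutex.
	r.reloadLock.Lock()
	defer r.reloadLock.Unlock()
	r.settings = newSettings
}

func main() {
	r := newReloader()
	r.reload(map[string]any{"tun": map[string]any{"mtu": 1300}})
}
```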
@@ -10,7 +10,7 @@ import (
     "github.com/slackhq/nebula/test"
     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/require"
-    "go.yaml.in/yaml/v3"
+    "gopkg.in/yaml.v3"
 )

 func TestConfig_Load(t *testing.T) {
@@ -4,17 +4,14 @@ import (
     "bytes"
     "context"
     "encoding/binary"
-    "fmt"
     "net/netip"
-    "sync"
-    "sync/atomic"
     "time"

     "github.com/rcrowley/go-metrics"
     "github.com/sirupsen/logrus"
     "github.com/slackhq/nebula/cert"
-    "github.com/slackhq/nebula/config"
     "github.com/slackhq/nebula/header"
+    "github.com/wadey/synctrace"
 )

 type trafficDecision int
@@ -30,124 +27,130 @@ const (
 )

 type connectionManager struct {
+    in     map[uint32]struct{}
+    inLock synctrace.RWMutex
+
+    out     map[uint32]struct{}
+    outLock synctrace.RWMutex
+
     // relayUsed holds which relay localIndexs are in use
     relayUsed     map[uint32]struct{}
-    relayUsedLock *sync.RWMutex
+    relayUsedLock synctrace.RWMutex

     hostMap      *HostMap
     trafficTimer *LockingTimerWheel[uint32]
     intf         *Interface
-    punchy       *Punchy
+    pendingDeletion map[uint32]struct{}
+    punchy          *Punchy

-    // Configuration settings
     checkInterval           time.Duration
     pendingDeletionInterval time.Duration
-    inactivityTimeout       atomic.Int64
-    dropInactive            atomic.Bool
-
-    metricsTxPunchy metrics.Counter
+    metricsTxPunchy         metrics.Counter

     l *logrus.Logger
 }

-func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
-    cm := &connectionManager{
-        hostMap:         hm,
-        l:               l,
-        punchy:          p,
-        relayUsed:       make(map[uint32]struct{}),
-        relayUsedLock:   &sync.RWMutex{},
-        metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
-    }
-
-    cm.reload(c, true)
-    c.RegisterReloadCallback(func(c *config.C) {
-        cm.reload(c, false)
-    })
-
-    return cm
-}
-
-func (cm *connectionManager) reload(c *config.C, initial bool) {
-    if initial {
-        cm.checkInterval = time.Duration(c.GetInt("timers.connection_alive_interval", 5)) * time.Second
-        cm.pendingDeletionInterval = time.Duration(c.GetInt("timers.pending_deletion_interval", 10)) * time.Second
-
-        // We want at least a minimum resolution of 500ms per tick so that we can hit these intervals
-        // pretty close to their configured duration.
-        // The inactivity duration is checked each time a hostinfo ticks through so we don't need the wheel to contain it.
-        minDuration := min(time.Millisecond*500, cm.checkInterval, cm.pendingDeletionInterval)
-        maxDuration := max(cm.checkInterval, cm.pendingDeletionInterval)
-        cm.trafficTimer = NewLockingTimerWheel[uint32](minDuration, maxDuration)
+func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {
+    var max time.Duration
+    if checkInterval < pendingDeletionInterval {
+        max = pendingDeletionInterval
+    } else {
+        max = checkInterval
+    }
+
+    nc := &connectionManager{
+        hostMap:                 intf.hostMap,
+        in:                      make(map[uint32]struct{}),
+        inLock:                  synctrace.NewRWMutex("connection-manager-in"),
+        out:                     make(map[uint32]struct{}),
+        outLock:                 synctrace.NewRWMutex("connection-manager-out"),
+        relayUsed:               make(map[uint32]struct{}),
+        relayUsedLock:           synctrace.NewRWMutex("connection-manager-relay-used"),
+        trafficTimer:            NewLockingTimerWheel[uint32]("traffic-timer", time.Millisecond*500, max),
+        intf:                    intf,
+        pendingDeletion:         make(map[uint32]struct{}),
+        checkInterval:           checkInterval,
+        pendingDeletionInterval: pendingDeletionInterval,
+        punchy:                  punchy,
+        metricsTxPunchy:         metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
+        l:                       l,
     }

-    if initial || c.HasChanged("tunnels.inactivity_timeout") {
-        old := cm.getInactivityTimeout()
-        cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
-        if !initial {
-            cm.l.WithField("oldDuration", old).
-                WithField("newDuration", cm.getInactivityTimeout()).
-                Info("Inactivity timeout has changed")
-        }
-    }
-
-    if initial || c.HasChanged("tunnels.drop_inactive") {
-        old := cm.dropInactive.Load()
-        cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
-        if !initial {
-            cm.l.WithField("oldBool", old).
-                WithField("newBool", cm.dropInactive.Load()).
-                Info("Drop inactive setting has changed")
-        }
-    }
+    nc.Start(ctx)
+    return nc
 }

-func (cm *connectionManager) getInactivityTimeout() time.Duration {
-    return (time.Duration)(cm.inactivityTimeout.Load())
-}
-
-func (cm *connectionManager) In(h *HostInfo) {
-    h.in.Store(true)
-}
-
-func (cm *connectionManager) Out(h *HostInfo) {
-    h.out.Store(true)
-}
-
-func (cm *connectionManager) RelayUsed(localIndex uint32) {
-    cm.relayUsedLock.RLock()
+func (n *connectionManager) In(localIndex uint32) {
+    n.inLock.RLock()
     // If this already exists, return
-    if _, ok := cm.relayUsed[localIndex]; ok {
-        cm.relayUsedLock.RUnlock()
+    if _, ok := n.in[localIndex]; ok {
+        n.inLock.RUnlock()
         return
     }
-    cm.relayUsedLock.RUnlock()
-    cm.relayUsedLock.Lock()
-    cm.relayUsed[localIndex] = struct{}{}
-    cm.relayUsedLock.Unlock()
+    n.inLock.RUnlock()
+    n.inLock.Lock()
+    n.in[localIndex] = struct{}{}
+    n.inLock.Unlock()
+}
+
+func (n *connectionManager) Out(localIndex uint32) {
+    n.outLock.RLock()
+    // If this already exists, return
+    if _, ok := n.out[localIndex]; ok {
+        n.outLock.RUnlock()
+        return
+    }
+    n.outLock.RUnlock()
+    n.outLock.Lock()
+    n.out[localIndex] = struct{}{}
+    n.outLock.Unlock()
+}
+
+func (n *connectionManager) RelayUsed(localIndex uint32) {
+    n.relayUsedLock.RLock()
+    // If this already exists, return
+    if _, ok := n.relayUsed[localIndex]; ok {
+        n.relayUsedLock.RUnlock()
+        return
+    }
+    n.relayUsedLock.RUnlock()
+    n.relayUsedLock.Lock()
+    n.relayUsed[localIndex] = struct{}{}
+    n.relayUsedLock.Unlock()
 }

 // getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
 // resets the state for this local index
-func (cm *connectionManager) getAndResetTrafficCheck(h *HostInfo, now time.Time) (bool, bool) {
-    in := h.in.Swap(false)
-    out := h.out.Swap(false)
-    if in || out {
-        h.lastUsed = now
-    }
+func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
+    n.inLock.Lock()
+    n.outLock.Lock()
+    _, in := n.in[localIndex]
+    _, out := n.out[localIndex]
+    delete(n.in, localIndex)
+    delete(n.out, localIndex)
+    n.inLock.Unlock()
+    n.outLock.Unlock()
     return in, out
 }

-// AddTrafficWatch must be called for every new HostInfo.
-// We will continue to monitor the HostInfo until the tunnel is dropped.
-func (cm *connectionManager) AddTrafficWatch(h *HostInfo) {
-    if h.out.Swap(true) == false {
-        cm.trafficTimer.Add(h.localIndexId, cm.checkInterval)
+func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
+    // Use a write lock directly because it should be incredibly rare that we are ever already tracking this index
+    n.outLock.Lock()
+    if _, ok := n.out[localIndex]; ok {
+        n.outLock.Unlock()
+        return
     }
+    n.out[localIndex] = struct{}{}
+    n.trafficTimer.Add(localIndex, n.checkInterval)
+    n.outLock.Unlock()
 }

-func (cm *connectionManager) Start(ctx context.Context) {
-    clockSource := time.NewTicker(cm.trafficTimer.t.tickDuration)
+func (n *connectionManager) Start(ctx context.Context) {
+    go n.Run(ctx)
+}
+
+func (n *connectionManager) Run(ctx context.Context) {
+    //TODO: this tick should be based on the min wheel tick? Check firewall
+    clockSource := time.NewTicker(500 * time.Millisecond)
     defer clockSource.Stop()

     p := []byte("")
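The In, Out, and RelayUsed methods on the right-hand side above all share one shape: an optimistic read-locked membership check followed by a write-locked insert. The same pattern with only the standard library (no synctrace), as a runnable sketch:

```go
package main

import (
	"fmt"
	"sync"
)

// indexSet records which local indexes have seen traffic, in the style of the
// in/out maps above.
type indexSet struct {
	mu sync.RWMutex
	m  map[uint32]struct{}
}

func (s *indexSet) add(k uint32) {
	s.mu.RLock()
	if _, ok := s.m[k]; ok {
		// Fast path: already tracked, reader lock only.
		s.mu.RUnlock()
		return
	}
	s.mu.RUnlock()

	// Slow path: take the write lock and insert. A duplicate insert is harmless
	// if another goroutine won the race in between.
	s.mu.Lock()
	s.m[k] = struct{}{}
	s.mu.Unlock()
}

func main() {
	s := &indexSet{m: make(map[uint32]struct{})}
	s.add(42)
	s.add(42)
	fmt.Println(len(s.m)) // 1
}
```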
@@ -160,61 +163,61 @@ func (cm *connectionManager) Start(ctx context.Context) {
             return

         case now := <-clockSource.C:
-            cm.trafficTimer.Advance(now)
+            n.trafficTimer.Advance(now)
             for {
-                localIndex, has := cm.trafficTimer.Purge()
+                localIndex, has := n.trafficTimer.Purge()
                 if !has {
                     break
                 }

-                cm.doTrafficCheck(localIndex, p, nb, out, now)
+                n.doTrafficCheck(localIndex, p, nb, out, now)
             }
         }
     }
 }

-func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
-    decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now)
+func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
+    decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now)

     switch decision {
     case deleteTunnel:
-        if cm.hostMap.DeleteHostInfo(hostinfo) {
+        if n.hostMap.DeleteHostInfo(hostinfo) {
             // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
-            cm.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
+            n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
         }

     case closeTunnel:
-        cm.intf.sendCloseTunnel(hostinfo)
-        cm.intf.closeTunnel(hostinfo)
+        n.intf.sendCloseTunnel(hostinfo)
+        n.intf.closeTunnel(hostinfo)

     case swapPrimary:
-        cm.swapPrimary(hostinfo, primary)
+        n.swapPrimary(hostinfo, primary)

     case migrateRelays:
-        cm.migrateRelayUsed(hostinfo, primary)
+        n.migrateRelayUsed(hostinfo, primary)

     case tryRehandshake:
-        cm.tryRehandshake(hostinfo)
+        n.tryRehandshake(hostinfo)

     case sendTestPacket:
-        cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
+        n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
     }

-    cm.resetRelayTrafficCheck(hostinfo)
+    n.resetRelayTrafficCheck(hostinfo)
 }

-func (cm *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
+func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
     if hostinfo != nil {
-        cm.relayUsedLock.Lock()
-        defer cm.relayUsedLock.Unlock()
+        n.relayUsedLock.Lock()
+        defer n.relayUsedLock.Unlock()
         // No need to migrate any relays, delete usage info now.
         for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
-            delete(cm.relayUsed, idx)
+            delete(n.relayUsed, idx)
         }
     }
 }

-func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
+func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
     relayFor := oldhostinfo.relayState.CopyAllRelayFor()

     for _, r := range relayFor {
@@ -224,51 +227,46 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
         var relayFrom netip.Addr
         var relayTo netip.Addr
         switch {
-        case ok:
-            switch existing.State {
-            case Established, PeerRequested, Disestablished:
-                // This relay already exists in newhostinfo, then do nothing.
-                continue
-            case Requested:
-                // The relay exists in a Requested state; re-send the request
-                index = existing.LocalIndex
-                switch r.Type {
-                case TerminalType:
-                    relayFrom = cm.intf.myVpnAddrs[0]
-                    relayTo = existing.PeerAddr
-                case ForwardingType:
-                    relayFrom = existing.PeerAddr
-                    relayTo = newhostinfo.vpnAddrs[0]
-                default:
-                    // should never happen
-                    panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
-                }
+        case ok && existing.State == Established:
+            // This relay already exists in newhostinfo, then do nothing.
+            continue
+        case ok && existing.State == Requested:
+            // The relay exists in a Requested state; re-send the request
+            index = existing.LocalIndex
+            switch r.Type {
+            case TerminalType:
+                relayFrom = n.intf.myVpnAddrs[0]
+                relayTo = existing.PeerAddr
+            case ForwardingType:
+                relayFrom = existing.PeerAddr
+                relayTo = newhostinfo.vpnAddrs[0]
+            default:
+                // should never happen
             }
         case !ok:
-            cm.relayUsedLock.RLock()
-            if _, relayUsed := cm.relayUsed[r.LocalIndex]; !relayUsed {
+            n.relayUsedLock.RLock()
+            if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {
                 // The relay hasn't been used; don't migrate it.
-                cm.relayUsedLock.RUnlock()
+                n.relayUsedLock.RUnlock()
                 continue
             }
-            cm.relayUsedLock.RUnlock()
+            n.relayUsedLock.RUnlock()
            // The relay doesn't exist at all; create some relay state and send the request.
            var err error
-            index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
+            index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
            if err != nil {
-                cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
+                n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
                continue
            }
            switch r.Type {
            case TerminalType:
-                relayFrom = cm.intf.myVpnAddrs[0]
+                relayFrom = n.intf.myVpnAddrs[0]
                relayTo = r.PeerAddr
            case ForwardingType:
                relayFrom = r.PeerAddr
                relayTo = newhostinfo.vpnAddrs[0]
            default:
                // should never happen
-                panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
            }
        }

@@ -281,12 +279,12 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
        switch newhostinfo.GetCert().Certificate.Version() {
        case cert.Version1:
            if !relayFrom.Is4() {
-                cm.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
+                n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
                continue
            }

            if !relayTo.Is4() {
-                cm.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
+                n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
                continue
            }

@@ -298,16 +296,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
            req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
            req.RelayToAddr = netAddrToProtoAddr(relayTo)
        default:
-            newhostinfo.logger(cm.l).Error("Unknown certificate version found while attempting to migrate relay")
+            newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
            continue
        }

        msg, err := req.Marshal()
        if err != nil {
-            cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
+            n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
        } else {
-            cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
-            cm.l.WithFields(logrus.Fields{
+            n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
+            n.l.WithFields(logrus.Fields{
                "relayFrom": req.RelayFromAddr,
                "relayTo": req.RelayToAddr,
                "initiatorRelayIndex": req.InitiatorRelayIndex,
@@ -318,44 +316,46 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
 	}
 }

-func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
-	// Read lock the main hostmap to order decisions based on tunnels being the primary tunnel
-	cm.hostMap.RLock()
-	defer cm.hostMap.RUnlock()
+func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
+	n.hostMap.RLock()
+	defer n.hostMap.RUnlock()

-	hostinfo := cm.hostMap.Indexes[localIndex]
+	hostinfo := n.hostMap.Indexes[localIndex]
 	if hostinfo == nil {
-		cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
+		n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
+		delete(n.pendingDeletion, localIndex)
 		return doNothing, nil, nil
 	}

-	if cm.isInvalidCertificate(now, hostinfo) {
+	if n.isInvalidCertificate(now, hostinfo) {
+		delete(n.pendingDeletion, hostinfo.localIndexId)
 		return closeTunnel, hostinfo, nil
 	}

-	primary := cm.hostMap.Hosts[hostinfo.vpnAddrs[0]]
+	primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
 	mainHostInfo := true
 	if primary != nil && primary != hostinfo {
 		mainHostInfo = false
 	}

 	// Check for traffic on this hostinfo
-	inTraffic, outTraffic := cm.getAndResetTrafficCheck(hostinfo, now)
+	inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)

 	// A hostinfo is determined alive if there is incoming traffic
 	if inTraffic {
 		decision := doNothing
-		if cm.l.Level >= logrus.DebugLevel {
-			hostinfo.logger(cm.l).
+		if n.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(n.l).
 				WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
 				Debug("Tunnel status")
 		}
-		hostinfo.pendingDeletion.Store(false)
+		delete(n.pendingDeletion, hostinfo.localIndexId)

 		if mainHostInfo {
 			decision = tryRehandshake

 		} else {
-			if cm.shouldSwapPrimary(hostinfo) {
+			if n.shouldSwapPrimary(hostinfo, primary) {
 				decision = swapPrimary
 			} else {
 				// migrate the relays to the primary, if in use.
@@ -363,55 +363,46 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
 			}
 		}

-		cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
+		n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)

 		if !outTraffic {
 			// Send a punch packet to keep the NAT state alive
-			cm.sendPunch(hostinfo)
+			n.sendPunch(hostinfo)
 		}

 		return decision, hostinfo, primary
 	}

-	if hostinfo.pendingDeletion.Load() {
+	if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
 		// We have already sent a test packet and nothing was returned, this hostinfo is dead
-		hostinfo.logger(cm.l).
+		hostinfo.logger(n.l).
 			WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
 			Info("Tunnel status")

+		delete(n.pendingDeletion, hostinfo.localIndexId)
 		return deleteTunnel, hostinfo, nil
 	}

 	decision := doNothing
 	if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
 		if !outTraffic {
-			inactiveFor, isInactive := cm.isInactive(hostinfo, now)
-			if isInactive {
-				// Tunnel is inactive, tear it down
-				hostinfo.logger(cm.l).
-					WithField("inactiveDuration", inactiveFor).
-					WithField("primary", mainHostInfo).
-					Info("Dropping tunnel due to inactivity")
-
-				return closeTunnel, hostinfo, primary
-			}
-
 			// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
 			// Just maintain NAT state if configured to do so.
-			cm.sendPunch(hostinfo)
-			cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
+			n.sendPunch(hostinfo)
+			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
 			return doNothing, nil, nil

 		}

-		if cm.punchy.GetTargetEverything() {
+		if n.punchy.GetTargetEverything() {
 			// This is similar to the old punchy behavior with a slight optimization.
 			// We aren't receiving traffic but we are sending it, punch on all known
 			// ips in case we need to re-prime NAT state
-			cm.sendPunch(hostinfo)
+			n.sendPunch(hostinfo)
 		}

-		if cm.l.Level >= logrus.DebugLevel {
-			hostinfo.logger(cm.l).
+		if n.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(n.l).
 				WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
 				Debug("Tunnel status")
 		}
@@ -420,33 +411,17 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
 			decision = sendTestPacket

 		} else {
-			if cm.l.Level >= logrus.DebugLevel {
-				hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
+			if n.l.Level >= logrus.DebugLevel {
+				hostinfo.logger(n.l).Debugf("Hostinfo sadness")
 			}
 		}
	
-	hostinfo.pendingDeletion.Store(true)
-	cm.trafficTimer.Add(hostinfo.localIndexId, cm.pendingDeletionInterval)
+	n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
+	n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
 	return decision, hostinfo, nil
 }

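For readers skimming the diff: the makeTrafficDecision hunks above are largely a swap between two ways of tracking "pending deletion" — per-host atomic flags (`hostinfo.pendingDeletion.Store/Load`) on one side and a set keyed by the tunnel's local index (`n.pendingDeletion`) on the other. A minimal, hypothetical sketch of the set-based bookkeeping, using made-up names and none of the real connectionManager types:

```go
package main

import "fmt"

// pendingSet is a toy stand-in for the pendingDeletion map in the diff:
// membership means "we sent a test packet and saw no reply yet".
type pendingSet map[uint32]struct{}

func (p pendingSet) markPending(localIndex uint32) { p[localIndex] = struct{}{} }
func (p pendingSet) clear(localIndex uint32)       { delete(p, localIndex) }
func (p pendingSet) isPending(localIndex uint32) bool {
	_, ok := p[localIndex]
	return ok
}

func main() {
	p := pendingSet{}
	p.markPending(1099)            // a traffic check saw no inbound traffic
	fmt.Println(p.isPending(1099)) // true: the next check may delete the tunnel
	p.clear(1099)                  // inbound traffic arrived, tunnel is alive
	fmt.Println(p.isPending(1099)) // false
}
```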
-func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time.Duration, bool) {
-	if cm.dropInactive.Load() == false {
-		// We aren't configured to drop inactive tunnels
-		return 0, false
-	}
-
-	inactiveDuration := now.Sub(hostinfo.lastUsed)
-	if inactiveDuration < cm.getInactivityTimeout() {
-		// It's not considered inactive
-		return inactiveDuration, false
-	}
-
-	// The tunnel is inactive
-	return inactiveDuration, true
-}
-
-func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
+func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
 	// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
 	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
 	// Let's sort this out.
@@ -454,127 +429,83 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
 	// Only one side should swap because if both swap then we may never resolve to a single tunnel.
 	// vpn addr is static across all tunnels for this host pair so lets
 	// use that to determine if we should consider swapping.
-	if current.vpnAddrs[0].Compare(cm.intf.myVpnAddrs[0]) < 0 {
+	if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
 		// Their primary vpn addr is less than mine. Do not swap.
 		return false
 	}

-	crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
-	if crt == nil {
-		//my cert was reloaded away. We should definitely swap from this tunnel
-		return true
-	}
+	crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
 	// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
 	// settle down.
 	return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
 }

-func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
-	cm.hostMap.Lock()
+func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
+	n.hostMap.Lock()
 	// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
-	if cm.hostMap.Hosts[current.vpnAddrs[0]] == primary {
-		cm.hostMap.unlockedMakePrimary(current)
+	if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
+		n.hostMap.unlockedMakePrimary(current)
 	}
-	cm.hostMap.Unlock()
+	n.hostMap.Unlock()
 }

-// isInvalidCertificate decides if we should destroy a tunnel.
-// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
-// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
-func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
+// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
+// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
+// check and return true.
+func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
 	remoteCert := hostinfo.GetCert()
 	if remoteCert == nil {
-		return false //don't tear down tunnels for handshakes in progress
-	}
-
-	caPool := cm.intf.pki.GetCAPool()
-	err := caPool.VerifyCachedCertificate(now, remoteCert)
-	if err == nil {
-		return false //cert is still valid! yay!
-	} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
-		// Block listed certificates should always be disconnected
-		hostinfo.logger(cm.l).WithError(err).
-			WithField("fingerprint", remoteCert.Fingerprint).
-			Info("Remote certificate is blocked, tearing down the tunnel")
-		return true
-	} else if cm.intf.disconnectInvalid.Load() {
-		hostinfo.logger(cm.l).WithError(err).
-			WithField("fingerprint", remoteCert.Fingerprint).
-			Info("Remote certificate is no longer valid, tearing down the tunnel")
-		return true
-	} else {
-		//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
 		return false
 	}

+	caPool := n.intf.pki.GetCAPool()
+	err := caPool.VerifyCachedCertificate(now, remoteCert)
+	if err == nil {
+		return false
+	}
+
+	if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
+		// Block listed certificates should always be disconnected
+		return false
+	}
+
+	hostinfo.logger(n.l).WithError(err).
+		WithField("fingerprint", remoteCert.Fingerprint).
+		Info("Remote certificate is no longer valid, tearing down the tunnel")
+
+	return true
 }
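Both versions of isInvalidCertificate above encode the same policy: a block-listed certificate always tears the tunnel down, while any other verification failure only does so when pki.disconnect_invalid is enabled. A standalone sketch of that decision, using a stand-in error value rather than the real cert.ErrBlockListed:

```go
package main

import (
	"errors"
	"fmt"
)

// errBlockListed is a stand-in for cert.ErrBlockListed in the real code.
var errBlockListed = errors.New("certificate is block listed")

// shouldTearDown mirrors the shape of isInvalidCertificate: no error keeps
// the tunnel, a block-listed cert always drops it, and any other validation
// error drops it only when disconnectInvalid is set.
func shouldTearDown(verifyErr error, disconnectInvalid bool) bool {
	if verifyErr == nil {
		return false
	}
	if errors.Is(verifyErr, errBlockListed) {
		return true
	}
	return disconnectInvalid
}

func main() {
	fmt.Println(shouldTearDown(nil, true))                                // false
	fmt.Println(shouldTearDown(errBlockListed, false))                    // true
	fmt.Println(shouldTearDown(errors.New("certificate expired"), false)) // false
	fmt.Println(shouldTearDown(errors.New("certificate expired"), true))  // true
}
```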

-func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
-	if !cm.punchy.GetPunch() {
+func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
+	if !n.punchy.GetPunch() {
 		// Punching is disabled
 		return
 	}

-	if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
-		// Do not punch to lighthouses, we assume our lighthouse update interval is good enough.
-		// In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse
-		// would lose the ability to notify us and punchy.respond would become unreliable.
-		return
-	}
-
-	if cm.punchy.GetTargetEverything() {
-		hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
-			cm.metricsTxPunchy.Inc(1)
-			cm.intf.outside.WriteTo([]byte{1}, addr)
+	if n.punchy.GetTargetEverything() {
+		hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
+			n.metricsTxPunchy.Inc(1)
+			n.intf.outside.WriteTo([]byte{1}, addr)
 		})

 	} else if hostinfo.remote.IsValid() {
-		cm.metricsTxPunchy.Inc(1)
-		cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
+		n.metricsTxPunchy.Inc(1)
+		n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
 	}
 }
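The sendPunch change above only alters which addresses get punched; the mechanism itself is a single throwaway byte sent to each candidate remote so NAT mappings stay warm. A self-contained sketch of that idea using only the standard library (the peer address is a placeholder):

```go
package main

import (
	"log"
	"net"
	"net/netip"
)

// punch sends one throwaway byte to each remote so that stateful NATs and
// firewalls keep a mapping open for the real tunnel traffic.
func punch(conn *net.UDPConn, remotes []netip.AddrPort) {
	for _, addr := range remotes {
		if _, err := conn.WriteToUDPAddrPort([]byte{1}, addr); err != nil {
			log.Printf("punch to %v failed: %v", addr, err)
		}
	}
}

func main() {
	conn, err := net.ListenUDP("udp", nil)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	punch(conn, []netip.AddrPort{
		netip.MustParseAddrPort("192.0.2.10:4242"), // placeholder peer address
	})
}
```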

-func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
-	cs := cm.intf.pki.getCertState()
+func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
+	cs := n.intf.pki.getCertState()
 	curCrt := hostinfo.ConnectionState.myCert
-	curCrtVersion := curCrt.Version()
-	myCrt := cs.getCertificate(curCrtVersion)
-	if myCrt == nil {
-		cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
-			WithField("version", curCrtVersion).
-			WithField("reason", "local certificate removed").
-			Info("Re-handshaking with remote")
-		cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
+	myCrt := cs.getCertificate(curCrt.Version())
+	if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
+		// The current tunnel is using the latest certificate and version, no need to rehandshake.
 		return
 	}
-	peerCrt := hostinfo.ConnectionState.peerCert
-	if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
-		// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
-		if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
-			cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
-				WithField("version", curCrtVersion).
-				WithField("peerVersion", peerCrt.Certificate.Version()).
-				WithField("reason", "local certificate version lower than peer, attempting to correct").
-				Info("Re-handshaking with remote")
-			cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
-				hh.initiatingVersionOverride = peerCrt.Certificate.Version()
-			})
-			return
-		}
-	}
-	if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
-		cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
-			WithField("reason", "local certificate is not current").
-			Info("Re-handshaking with remote")

-		cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
-		return
-	}
-	if curCrtVersion < cs.initiatingVersion {
-		cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
-			WithField("reason", "current cert version < pki.initiatingVersion").
-			Info("Re-handshaking with remote")
+	n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
+		WithField("reason", "local certificate is not current").
+		Info("Re-handshaking with remote")

-		cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
-		return
-	}
+	n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
 }

@@ -1,6 +1,7 @@
 package nebula

 import (
+	"context"
 	"crypto/ed25519"
 	"crypto/rand"
 	"net/netip"
@@ -22,7 +23,7 @@ func newTestLighthouse() *LightHouse {
 		addrMap: map[netip.Addr]*RemoteList{},
 		queryChan: make(chan netip.Addr, 10),
 	}
-	lighthouses := []netip.Addr{}
+	lighthouses := map[netip.Addr]struct{}{}
 	staticList := map[netip.Addr]struct{}{}

 	lh.lighthouses.Store(&lighthouses)
@@ -63,10 +64,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	ifce.pki.cs.Store(cs)

 	// Create manager
-	conf := config.NewC(l)
-	punchy := NewPunchyFromConfig(l, conf)
-	nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
-	nc.intf = ifce
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	punchy := NewPunchyFromConfig(l, config.NewC(l))
+	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
 	p := []byte("")
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
@@ -84,33 +85,32 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)

 	// We saw traffic out to vpnIp
-	nc.Out(hostinfo)
-	nc.In(hostinfo)
-	assert.False(t, hostinfo.pendingDeletion.Load())
+	nc.Out(hostinfo.localIndexId)
+	nc.In(hostinfo.localIndexId)
+	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.True(t, hostinfo.out.Load())
-	assert.True(t, hostinfo.in.Load())
+	assert.Contains(t, nc.out, hostinfo.localIndexId)

 	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.False(t, hostinfo.pendingDeletion.Load())
-	assert.False(t, hostinfo.out.Load())
-	assert.False(t, hostinfo.in.Load())
+	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
+	assert.NotContains(t, nc.out, hostinfo.localIndexId)
+	assert.NotContains(t, nc.in, hostinfo.localIndexId)

 	// Do another traffic check tick, this host should be pending deletion now
-	nc.Out(hostinfo)
-	assert.True(t, hostinfo.out.Load())
+	nc.Out(hostinfo.localIndexId)
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.True(t, hostinfo.pendingDeletion.Load())
-	assert.False(t, hostinfo.out.Load())
-	assert.False(t, hostinfo.in.Load())
+	assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
+	assert.NotContains(t, nc.out, hostinfo.localIndexId)
+	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])

 	// Do a final traffic check tick, the host should now be removed
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs)
+	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
+	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 }

@@ -146,10 +146,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	ifce.pki.cs.Store(cs)

 	// Create manager
-	conf := config.NewC(l)
-	punchy := NewPunchyFromConfig(l, conf)
-	nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
-	nc.intf = ifce
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	punchy := NewPunchyFromConfig(l, config.NewC(l))
+	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
 	p := []byte("")
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
@@ -167,129 +167,33 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)

 	// We saw traffic out to vpnIp
-	nc.Out(hostinfo)
-	nc.In(hostinfo)
-	assert.True(t, hostinfo.in.Load())
-	assert.True(t, hostinfo.out.Load())
-	assert.False(t, hostinfo.pendingDeletion.Load())
+	nc.Out(hostinfo.localIndexId)
+	nc.In(hostinfo.localIndexId)
+	assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)

 	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.False(t, hostinfo.pendingDeletion.Load())
-	assert.False(t, hostinfo.out.Load())
-	assert.False(t, hostinfo.in.Load())
+	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
+	assert.NotContains(t, nc.out, hostinfo.localIndexId)
+	assert.NotContains(t, nc.in, hostinfo.localIndexId)

 	// Do another traffic check tick, this host should be pending deletion now
-	nc.Out(hostinfo)
+	nc.Out(hostinfo.localIndexId)
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.True(t, hostinfo.pendingDeletion.Load())
-	assert.False(t, hostinfo.out.Load())
-	assert.False(t, hostinfo.in.Load())
+	assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
+	assert.NotContains(t, nc.out, hostinfo.localIndexId)
+	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])

 	// We saw traffic, should no longer be pending deletion
-	nc.In(hostinfo)
+	nc.In(hostinfo.localIndexId)
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.False(t, hostinfo.pendingDeletion.Load())
-	assert.False(t, hostinfo.out.Load())
-	assert.False(t, hostinfo.in.Load())
-	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
-}
-
-func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
-	l := test.NewLogger()
-	localrange := netip.MustParsePrefix("10.1.1.1/24")
-	vpnAddrs := []netip.Addr{netip.MustParseAddr("172.1.1.2")}
-	preferredRanges := []netip.Prefix{localrange}
-
-	// Very incomplete mock objects
-	hostMap := newHostMap(l)
-	hostMap.preferredRanges.Store(&preferredRanges)
-
-	cs := &CertState{
-		initiatingVersion: cert.Version1,
-		privateKey: []byte{},
-		v1Cert: &dummyCert{version: cert.Version1},
-		v1HandshakeBytes: []byte{},
-	}
-
-	lh := newTestLighthouse()
-	ifce := &Interface{
-		hostMap: hostMap,
-		inside: &test.NoopTun{},
-		outside: &udp.NoopConn{},
-		firewall: &Firewall{},
-		lightHouse: lh,
-		pki: &PKI{},
-		handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
-		l: l,
-	}
-	ifce.pki.cs.Store(cs)
-
-	// Create manager
-	conf := config.NewC(l)
-	conf.Settings["tunnels"] = map[string]any{
-		"drop_inactive": true,
-	}
-	punchy := NewPunchyFromConfig(l, conf)
-	nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
-	assert.True(t, nc.dropInactive.Load())
-	nc.intf = ifce
-
-	// Add an ip we have established a connection w/ to hostmap
-	hostinfo := &HostInfo{
-		vpnAddrs: vpnAddrs,
-		localIndexId: 1099,
-		remoteIndexId: 9901,
-	}
-	hostinfo.ConnectionState = &ConnectionState{
-		myCert: &dummyCert{version: cert.Version1},
-		H: &noise.HandshakeState{},
-	}
-	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
-
-	// Do a traffic check tick, in and out should be cleared but should not be pending deletion
-	nc.Out(hostinfo)
-	nc.In(hostinfo)
-	assert.True(t, hostinfo.out.Load())
-	assert.True(t, hostinfo.in.Load())
-
-	now := time.Now()
-	decision, _, _ := nc.makeTrafficDecision(hostinfo.localIndexId, now)
-	assert.Equal(t, tryRehandshake, decision)
-	assert.Equal(t, now, hostinfo.lastUsed)
-	assert.False(t, hostinfo.pendingDeletion.Load())
-	assert.False(t, hostinfo.out.Load())
-	assert.False(t, hostinfo.in.Load())
-
-	decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*5))
-	assert.Equal(t, doNothing, decision)
-	assert.Equal(t, now, hostinfo.lastUsed)
-	assert.False(t, hostinfo.pendingDeletion.Load())
-	assert.False(t, hostinfo.out.Load())
-	assert.False(t, hostinfo.in.Load())
-
-	// Do another traffic check tick, should still not be pending deletion
-	decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*10))
-	assert.Equal(t, doNothing, decision)
-	assert.Equal(t, now, hostinfo.lastUsed)
-	assert.False(t, hostinfo.pendingDeletion.Load())
-	assert.False(t, hostinfo.out.Load())
-	assert.False(t, hostinfo.in.Load())
-	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
-
-	// Finally advance beyond the inactivity timeout
-	decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Minute*10))
-	assert.Equal(t, closeTunnel, decision)
-	assert.Equal(t, now, hostinfo.lastUsed)
-	assert.False(t, hostinfo.pendingDeletion.Load())
-	assert.False(t, hostinfo.out.Load())
-	assert.False(t, hostinfo.in.Load())
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 }
@@ -360,10 +264,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	ifce.disconnectInvalid.Store(true)

 	// Create manager
-	conf := config.NewC(l)
-	punchy := NewPunchyFromConfig(l, conf)
-	nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
-	nc.intf = ifce
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	punchy := NewPunchyFromConfig(l, config.NewC(l))
+	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
 	ifce.connectionManager = nc

 	hostinfo := &HostInfo{
@@ -446,10 +350,6 @@ func (d *dummyCert) PublicKey() []byte {
 	return d.publicKey
 }

-func (d *dummyCert) MarshalPublicKeyPEM() []byte {
-	return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey)
-}
-
 func (d *dummyCert) Signature() []byte {
 	return d.signature
 }

@@ -4,16 +4,16 @@ import (
 	"crypto/rand"
 	"encoding/json"
 	"fmt"
-	"sync"
 	"sync/atomic"

 	"github.com/flynn/noise"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/noiseutil"
+	"github.com/wadey/synctrace"
 )

-const ReplayWindow = 4096
+const ReplayWindow = 1024

 type ConnectionState struct {
 	eKey *NebulaCipherState
@@ -24,7 +24,7 @@ type ConnectionState struct {
 	initiator bool
 	messageCounter atomic.Uint64
 	window *Bits
-	writeLock sync.Mutex
+	writeLock synctrace.Mutex
 }

 func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
@@ -76,6 +76,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
 		initiator: initiator,
 		window: b,
 		myCert: crt,
+		writeLock: synctrace.NewMutex("connection-state"),
 	}
 	// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
 	ci.messageCounter.Add(2)

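The connection_state.go hunk above (and the dns hunk further down) trade sync primitives for synctrace equivalents that carry a name, e.g. synctrace.NewMutex("connection-state"). Purely as an illustration of that shape — this is not the synctrace API — a named wrapper over the standard library could look like:

```go
package main

import (
	"fmt"
	"sync"
)

// namedMutex is an illustrative wrapper: it behaves like sync.Mutex but
// carries a label, which is the kind of API the diff switches to.
type namedMutex struct {
	name string
	mu   sync.Mutex
}

func newNamedMutex(name string) *namedMutex { return &namedMutex{name: name} }

func (m *namedMutex) Lock() {
	// A tracing implementation could record m.name and lock timing here.
	m.mu.Lock()
}

func (m *namedMutex) Unlock() { m.mu.Unlock() }

func main() {
	writeLock := newNamedMutex("connection-state")
	writeLock.Lock()
	fmt.Println("holding", writeLock.name)
	writeLock.Unlock()
}
```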
control.go (20 changes)
@@ -26,15 +26,14 @@ type controlHostLister interface {
 }

 type Control struct {
 	f *Interface
 	l *logrus.Logger
 	ctx context.Context
 	cancel context.CancelFunc
 	sshStart func()
 	statsStart func()
 	dnsStart func()
 	lighthouseStart func()
-	connectionManagerStart func(context.Context)
 }

 type ControlHostInfo struct {
@@ -64,9 +63,6 @@ func (c *Control) Start() {
 	if c.dnsStart != nil {
 		go c.dnsStart()
 	}
-	if c.connectionManagerStart != nil {
-		go c.connectionManagerStart(c.ctx)
-	}
 	if c.lighthouseStart != nil {
 		c.lighthouseStart()
 	}

@@ -53,7 +53,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		localIndexId: 201,
 		vpnAddrs: []netip.Addr{vpnIp},
 		relayState: RelayState{
-			relays: nil,
+			relays: map[netip.Addr]struct{}{},
 			relayForByAddr: map[netip.Addr]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},
@@ -72,7 +72,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		localIndexId: 201,
 		vpnAddrs: []netip.Addr{vpnIp2},
 		relayState: RelayState{
-			relays: nil,
+			relays: map[netip.Addr]struct{}{},
 			relayForByAddr: map[netip.Addr]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},

@@ -174,10 +174,6 @@ func (c *Control) GetHostmap() *HostMap {
 	return c.f.hostMap
 }

-func (c *Control) GetF() *Interface {
-	return c.f
-}
-
 func (c *Control) GetCertState() *CertState {
 	return c.f.pki.getCertState()
 }

@@ -6,12 +6,12 @@ import (
 	"net/netip"
 	"strconv"
 	"strings"
-	"sync"

 	"github.com/gaissmai/bart"
 	"github.com/miekg/dns"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/wadey/synctrace"
 )

 // This whole thing should be rewritten to use context
@@ -21,7 +21,7 @@ var dnsServer *dns.Server
 var dnsAddr string

 type dnsRecords struct {
-	sync.RWMutex
+	synctrace.RWMutex
 	l *logrus.Logger
 	dnsMap4 map[string]netip.Addr
 	dnsMap6 map[string]netip.Addr
@@ -31,6 +31,7 @@ type dnsRecords struct {

 func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
 	return &dnsRecords{
+		RWMutex: synctrace.NewRWMutex("dns-records"),
 		l: l,
 		dnsMap4: make(map[string]netip.Addr),
 		dnsMap6: make(map[string]netip.Addr),

@@ -20,7 +20,7 @@ import (
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-	"go.yaml.in/yaml/v3"
+	"gopkg.in/yaml.v3"
 )

 func BenchmarkHotPath(b *testing.B) {
@@ -97,41 +97,6 @@ func TestGoodHandshake(t *testing.T) {
 	theirControl.Stop()
 }

-func TestGoodHandshakeNoOverlap(t *testing.T) {
-	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack!
-
-	// Put their info in our lighthouse
-	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
-
-	// Start the servers
-	myControl.Start()
-	theirControl.Start()
-
-	empty := []byte{}
-	t.Log("do something to cause a handshake")
-	myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty)
-
-	t.Log("Have them consume my stage 0 packet. They have a tunnel now")
-	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
-
-	t.Log("Get their stage 1 packet")
-	stage1Packet := theirControl.GetFromUDP(true)
-
-	t.Log("Have me consume their stage 1 packet. I have a tunnel now")
-	myControl.InjectUDPPacket(stage1Packet)
-
-	t.Log("Wait until we see a test packet come through to make sure we give the tunnel time to complete")
-	myControl.WaitForType(header.Test, 0, theirControl)
-
-	t.Log("Make sure our host infos are correct")
-	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
-
-	myControl.Stop()
-	theirControl.Stop()
-}
-
 func TestWrongResponderHandshake(t *testing.T) {
 	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})

@@ -499,35 +464,6 @@ func TestRelays(t *testing.T) {
 	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
 }

-func TestRelaysDontCareAboutIps(t *testing.T) {
-	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
-
-	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
-	myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
-	relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
-
-	// Build a router so we don't have to reason who gets which packet
-	r := router.NewR(t, myControl, relayControl, theirControl)
-	defer r.RenderFlow()
-
-	// Start the servers
-	myControl.Start()
-	relayControl.Start()
-	theirControl.Start()
-
-	t.Log("Trigger a handshake from me to them via the relay")
-	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
-
-	p := r.RouteForAllUntilTxTun(theirControl)
-	r.Log("Assert the tunnel works")
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
-	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
-}
-
 func TestReestablishRelays(t *testing.T) {
 	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
@@ -570,7 +506,7 @@ func TestReestablishRelays(t *testing.T) {
 	curIndexes := len(myControl.GetHostmap().Indexes)
 	for curIndexes >= start {
 		curIndexes = len(myControl.GetHostmap().Indexes)
-		r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes)
+		r.Logf("Wait for the dead index to go away:start=%v indexes, currnet=%v indexes", start, curIndexes)
 		myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))

 		r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@@ -1116,9 +1052,6 @@ func TestRehandshakingLoser(t *testing.T) {
 	t.Log("Stand up a tunnel between me and them")
 	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)

-	myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
-	theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
-
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)

 	r.Log("Renew their certificate and spin until mine sees it")
@@ -1291,109 +1224,3 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
 	myControl.Stop()
 	theirControl.Stop()
 }
-
-func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
-	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})
-
-	o := m{
-		"static_host_map": m{
-			lhVpnIpNet[0].Addr().String(): []string{lhUdpAddr.String()},
-		},
-		"lighthouse": m{
-			"hosts": []string{lhVpnIpNet[0].Addr().String()},
-			"local_allow_list": m{
-				// Try and block our lighthouse updates from using the actual addresses assigned to this computer
-				// If we start discovering addresses the test router doesn't know about then test traffic cant flow
-				"10.0.0.0/24": true,
-				"::/0": false,
-			},
-		},
-	}
-	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o)
-	theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
-
-	// Build a router so we don't have to reason who gets which packet
-	r := router.NewR(t, lhControl, myControl, theirControl)
-	defer r.RenderFlow()
-
-	// Start the servers
-	lhControl.Start()
-	myControl.Start()
-	theirControl.Start()
-
-	t.Log("Stand up an ipv6 tunnel between me and them")
-	assert.True(t, myVpnIpNet[1].Addr().Is6())
-	assert.True(t, theirVpnIpNet[1].Addr().Is6())
-	assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
-
-	lhControl.Stop()
-	myControl.Stop()
-	theirControl.Stop()
-}
-
-func TestGoodHandshakeUnsafeDest(t *testing.T) {
-	unsafePrefix := "192.168.6.0/24"
-	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil)
-	route := m{"route": unsafePrefix, "via": theirVpnIpNet[0].Addr().String()}
-	myCfg := m{
-		"tun": m{
-			"unsafe_routes": []m{route},
-		},
-	}
-	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg)
-	t.Logf("my config %v", myConfig)
-	// Put their info in our lighthouse
-	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
-
-	spookyDest := netip.MustParseAddr("192.168.6.4")
-
-	// Start the servers
-	myControl.Start()
-	theirControl.Start()
-
-	t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
-	myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
-
-	t.Log("Have them consume my stage 0 packet. They have a tunnel now")
-	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
-
-	t.Log("Get their stage 1 packet so that we can play with it")
-	stage1Packet := theirControl.GetFromUDP(true)
-
-	t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
-	// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
-	badPacket := stage1Packet.Copy()
-	badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
-	myControl.InjectUDPPacket(badPacket)
-
-	t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
-	myControl.InjectUDPPacket(stage1Packet)
-
-	t.Log("Wait until we see my cached packet come through")
-	myControl.WaitForType(1, 0, theirControl)
-
-	t.Log("Make sure our host infos are correct")
-	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
-
-	t.Log("Get that cached packet and make sure it looks right")
-	myCachedPacket := theirControl.GetFromTun(true)
-	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
-
-	//reply
-	theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
-	//wait for reply
-	theirControl.WaitForType(1, 0, myControl)
-	theirCachedPacket := myControl.GetFromTun(true)
-	assertUdpPacket(t, []byte("Hi from the spookyman"), theirCachedPacket, spookyDest, myVpnIpNet[0].Addr(), 80, 80)
-
-	t.Log("Do a bidirectional tunnel test")
-	r := router.NewR(t, myControl, theirControl)
-	defer r.RenderFlow()
-	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
-
-	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
-	myControl.Stop()
-	theirControl.Stop()
-}

@@ -22,14 +22,15 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"gopkg.in/yaml.v3"
|
||||||
"go.yaml.in/yaml/v3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type m = map[string]any
|
type m = map[string]any
|
||||||
|
|
||||||
// newSimpleServer creates a nebula instance with many assumptions
|
// newSimpleServer creates a nebula instance with many assumptions
|
||||||
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
	l := NewTestLogger()

	var vpnNetworks []netip.Prefix
	for _, sn := range strings.Split(sVpnNetworks, ",") {
		vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
@@ -55,54 +56,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
		budpIp[3] = 239
		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
	}
	return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
	_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
}

func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
	return newSimpleServerWithUdpAndUnsafeNetworks(v, caCrt, caKey, name, sVpnNetworks, udpAddr, "", overrides)
}

func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, sUnsafeNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
	l := NewTestLogger()

	var vpnNetworks []netip.Prefix
	for _, sn := range strings.Split(sVpnNetworks, ",") {
		vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
		if err != nil {
			panic(err)
		}
		vpnNetworks = append(vpnNetworks, vpnIpNet)
	}

	if len(vpnNetworks) == 0 {
		panic("no vpn networks")
	}

	firewallInbound := []m{{
		"proto": "any",
		"port": "any",
		"host": "any",
	}}

	var unsafeNetworks []netip.Prefix
	if sUnsafeNetworks != "" {
		firewallInbound = []m{{
			"proto": "any",
			"port": "any",
			"host": "any",
			"local_cidr": "0.0.0.0/0",
		}}

		for _, sn := range strings.Split(sUnsafeNetworks, ",") {
			x, err := netip.ParsePrefix(strings.TrimSpace(sn))
			if err != nil {
				panic(err)
			}
			unsafeNetworks = append(unsafeNetworks, x)
		}
	}

	_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})

	caB, err := caCrt.MarshalPEM()
	if err != nil {
@@ -122,7 +76,11 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
			"port": "any",
			"host": "any",
		}},
		"inbound": firewallInbound,
		"inbound": []m{{
			"proto": "any",
			"port": "any",
			"host": "any",
		}},
	},
	//"handshakes": m{
	//  "try_interval": "1s",
@@ -171,109 +129,6 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
	return control, vpnNetworks, udpAddr, c
}

// newServer creates a nebula instance with fewer assumptions
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
	l := NewTestLogger()

	vpnNetworks := certs[len(certs)-1].Networks()

	var udpAddr netip.AddrPort
	if vpnNetworks[0].Addr().Is4() {
		budpIp := vpnNetworks[0].Addr().As4()
		budpIp[1] -= 128
		udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
	} else {
		budpIp := vpnNetworks[0].Addr().As16()
		// beef for funsies
		budpIp[2] = 190
		budpIp[3] = 239
		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
	}

	caStr := ""
	for _, ca := range caCrt {
		x, err := ca.MarshalPEM()
		if err != nil {
			panic(err)
		}
		caStr += string(x)
	}
	certStr := ""
	for _, c := range certs {
		x, err := c.MarshalPEM()
		if err != nil {
			panic(err)
		}
		certStr += string(x)
	}

	mc := m{
		"pki": m{
			"ca": caStr,
			"cert": certStr,
			"key": string(key),
		},
		//"tun": m{"disabled": true},
		"firewall": m{
			"outbound": []m{{
				"proto": "any",
				"port": "any",
				"host": "any",
			}},
			"inbound": []m{{
				"proto": "any",
				"port": "any",
				"host": "any",
			}},
		},
		//"handshakes": m{
		//  "try_interval": "1s",
		//},
		"listen": m{
			"host": udpAddr.Addr().String(),
			"port": udpAddr.Port(),
		},
		"logging": m{
			"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
			"level": l.Level.String(),
		},
		"timers": m{
			"pending_deletion_interval": 2,
			"connection_alive_interval": 2,
		},
	}

	if overrides != nil {
		final := m{}
		err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
		if err != nil {
			panic(err)
		}
		err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
		if err != nil {
			panic(err)
		}
		mc = final
	}

	cb, err := yaml.Marshal(mc)
	if err != nil {
		panic(err)
	}

	c := config.NewC(l)
	cStr := string(cb)
	c.LoadString(cStr)

	control, err := nebula.Main(c, false, "e2e-test", l, nil)

	if err != nil {
		panic(err)
	}

	return control, vpnNetworks, udpAddr, c
}

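The override handling in newServer above layers a caller-supplied map over the generated defaults with mergo: overrides are merged into an empty map first, then the defaults are merged in, so keys already set by the caller are not overwritten. A minimal sketch of that pattern, assuming the github.com/imdario/mergo import path (newer releases of the same module live at dario.cat/mergo) and a local m alias mirroring the e2e package's map type:

package main

import (
	"fmt"

	"github.com/imdario/mergo"
)

type m = map[string]any

func main() {
	defaults := m{"listen": m{"port": 4242}, "logging": m{"level": "info"}}
	overrides := m{"logging": m{"level": "debug"}}

	// Merge overrides first, then fill in anything still missing from the defaults.
	// Existing keys in final are left alone, so the overrides win.
	final := m{}
	if err := mergo.Merge(&final, overrides, mergo.WithAppendSlice); err != nil {
		panic(err)
	}
	if err := mergo.Merge(&final, defaults, mergo.WithAppendSlice); err != nil {
		panic(err)
	}

	// logging.level comes from the overrides, listen.port from the defaults.
	fmt.Println(final["logging"], final["listen"])
}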
type doneCb func()

func deadline(t *testing.T, seconds time.Duration) doneCb {
@@ -308,10 +163,10 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
	// Get both host infos
	//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
	hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
	require.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
	assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")

	hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
	require.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
	assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")

	// Check that both vpn and real addr are correct
	assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
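The require.NotNil / assert.NotNil swap in this hunk is behavioral, not cosmetic: testify's require variants abort the test immediately on failure, while assert variants record the failure and keep running, which matters when the next statement dereferences the value. A minimal hypothetical test illustrating the difference (standard testify API):

package example

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestNilGuard(t *testing.T) {
	var hostInfo *struct{ Name string }

	// assert: marks the test failed but execution continues past this line.
	assert.NotNil(t, hostInfo, "recorded as a failure, test keeps running")

	// require: fails and stops the test here, protecting the dereference below.
	require.NotNil(t, hostInfo, "stops the test on failure")

	_ = hostInfo.Name // only reached when hostInfo is non-nil
}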
@@ -700,7 +700,6 @@ func (r *R) FlushAll() {
			r.Unlock()
			panic("Can't FlushAll for host: " + p.To.String())
		}
		receiver.InjectUDPPacket(p)
		r.Unlock()
	}
}
@@ -1,367 +0,0 @@
//go:build e2e_testing
// +build e2e_testing

package e2e

import (
	"fmt"
	"net/netip"
	"testing"
	"time"

	"github.com/slackhq/nebula/cert"
	"github.com/slackhq/nebula/cert_test"
	"github.com/slackhq/nebula/e2e/router"
	"github.com/stretchr/testify/assert"
	"gopkg.in/yaml.v3"
)

func TestDropInactiveTunnels(t *testing.T) {
	// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
	// under ideal conditions
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}})
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}})

	// Share our underlay information
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)

	// Start the servers
	myControl.Start()
	theirControl.Start()

	r := router.NewR(t, myControl, theirControl)

	r.Log("Assert the tunnel between me and them works")
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)

	r.Log("Go inactive and wait for the tunnels to get dropped")
	waitStart := time.Now()
	for {
		myIndexes := len(myControl.GetHostmap().Indexes)
		theirIndexes := len(theirControl.GetHostmap().Indexes)
		if myIndexes == 0 && theirIndexes == 0 {
			break
		}

		since := time.Since(waitStart)
		r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
		if since > time.Second*30 {
			t.Fatal("Tunnel should have been declared inactive after 5 seconds and before 30 seconds")
		}

		time.Sleep(1 * time.Second)
		r.FlushAll()
	}

	r.Logf("Inactive tunnels were dropped within %v", time.Since(waitStart))
	myControl.Stop()
	theirControl.Stop()
}

func TestCertUpgrade(t *testing.T) {
	// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
	// under ideal conditions
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	caB, err := ca.MarshalPEM()
	if err != nil {
		panic(err)
	}
	ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})

	ca2B, err := ca2.MarshalPEM()
	if err != nil {
		panic(err)
	}
	caStr := fmt.Sprintf("%s\n%s", caB, ca2B)

	myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
	_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)

	theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
	theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)

	myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})

	// Share our underlay information
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)

	// Start the servers
	myControl.Start()
	theirControl.Start()

	r := router.NewR(t, myControl, theirControl)
	defer r.RenderFlow()

	r.Log("Assert the tunnel between me and them works")
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	r.Log("yay")
	//todo ???
	time.Sleep(1 * time.Second)
	r.FlushAll()

	mc := m{
		"pki": m{
			"ca": caStr,
			"cert": string(myCert2Pem),
			"key": string(myPrivKey),
		},
		//"tun": m{"disabled": true},
		"firewall": myC.Settings["firewall"],
		//"handshakes": m{
		//  "try_interval": "1s",
		//},
		"listen": myC.Settings["listen"],
		"logging": myC.Settings["logging"],
		"timers": myC.Settings["timers"],
	}

	cb, err := yaml.Marshal(mc)
	if err != nil {
		panic(err)
	}

	r.Logf("reload new v2-only config")
	err = myC.ReloadConfigString(string(cb))
	assert.NoError(t, err)
	r.Log("yay, spin until their sees it")
	waitStart := time.Now()
	for {
		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
		c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
		if c == nil {
			r.Log("nil")
		} else {
			version := c.Cert.Version()
			r.Logf("version %d", version)
			if version == cert.Version2 {
				break
			}
		}
		since := time.Since(waitStart)
		if since > time.Second*10 {
			t.Fatal("Cert should be new by now")
		}
		time.Sleep(time.Second)
	}

	r.RenderHostmaps("Final hostmaps", myControl, theirControl)

	myControl.Stop()
	theirControl.Stop()
}

func TestCertDowngrade(t *testing.T) {
	// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
	// under ideal conditions
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	caB, err := ca.MarshalPEM()
	if err != nil {
		panic(err)
	}
	ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})

	ca2B, err := ca2.MarshalPEM()
	if err != nil {
		panic(err)
	}
	caStr := fmt.Sprintf("%s\n%s", caB, ca2B)

	myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
	myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)

	theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
	theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)

	myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})

	// Share our underlay information
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)

	// Start the servers
	myControl.Start()
	theirControl.Start()

	r := router.NewR(t, myControl, theirControl)
	defer r.RenderFlow()

	r.Log("Assert the tunnel between me and them works")
	//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
	//r.Log("yay")
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	r.Log("yay")
	//todo ???
	time.Sleep(1 * time.Second)
	r.FlushAll()

	mc := m{
		"pki": m{
			"ca": caStr,
			"cert": string(myCertPem),
			"key": string(myPrivKey),
		},
		"firewall": myC.Settings["firewall"],
		"listen": myC.Settings["listen"],
		"logging": myC.Settings["logging"],
		"timers": myC.Settings["timers"],
	}

	cb, err := yaml.Marshal(mc)
	if err != nil {
		panic(err)
	}

	r.Logf("reload new v1-only config")
	err = myC.ReloadConfigString(string(cb))
	assert.NoError(t, err)
	r.Log("yay, spin until their sees it")
	waitStart := time.Now()
	for {
		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
		c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
		c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
		if c == nil || c2 == nil {
			r.Log("nil")
		} else {
			version := c.Cert.Version()
			theirVersion := c2.Cert.Version()
			r.Logf("version %d,%d", version, theirVersion)
			if version == cert.Version1 {
				break
			}
		}
		since := time.Since(waitStart)
		if since > time.Second*5 {
			r.Log("it is unusual that the cert is not new yet, but not a failure yet")
		}
		if since > time.Second*10 {
			r.Log("wtf")
			t.Fatal("Cert should be new by now")
		}
		time.Sleep(time.Second)
	}

	r.RenderHostmaps("Final hostmaps", myControl, theirControl)

	myControl.Stop()
	theirControl.Stop()
}

func TestCertMismatchCorrection(t *testing.T) {
	// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
	// under ideal conditions
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})

	myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
	myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)

	theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
	theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)

	myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})

	// Share our underlay information
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)

	// Start the servers
	myControl.Start()
	theirControl.Start()

	r := router.NewR(t, myControl, theirControl)
	defer r.RenderFlow()

	r.Log("Assert the tunnel between me and them works")
	//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
	//r.Log("yay")
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	r.Log("yay")
	//todo ???
	time.Sleep(1 * time.Second)
	r.FlushAll()

	waitStart := time.Now()
	for {
		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
		c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
		c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
		if c == nil || c2 == nil {
			r.Log("nil")
		} else {
			version := c.Cert.Version()
			theirVersion := c2.Cert.Version()
			r.Logf("version %d,%d", version, theirVersion)
			if version == theirVersion {
				break
			}
		}
		since := time.Since(waitStart)
		if since > time.Second*5 {
			r.Log("wtf")
		}
		if since > time.Second*10 {
			r.Log("wtf")
			t.Fatal("Cert should be new by now")
		}
		time.Sleep(time.Second)
	}

	r.RenderHostmaps("Final hostmaps", myControl, theirControl)

	myControl.Stop()
	theirControl.Stop()
}

func TestCrossStackRelaysWork(t *testing.T) {
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
	theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})

	//myVpnV4 := myVpnIpNet[0]
	myVpnV6 := myVpnIpNet[1]
	relayVpnV4 := relayVpnIpNet[0]
	relayVpnV6 := relayVpnIpNet[1]
	theirVpnV6 := theirVpnIpNet[0]

	// Teach my how to get to the relay and that their can be reached via the relay
	myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
	myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
	myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
	relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)

	// Build a router so we don't have to reason who gets which packet
	r := router.NewR(t, myControl, relayControl, theirControl)
	defer r.RenderFlow()

	// Start the servers
	myControl.Start()
	relayControl.Start()
	theirControl.Start()

	t.Log("Trigger a handshake from me to them via the relay")
	myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))

	p := r.RouteForAllUntilTxTun(theirControl)
	r.Log("Assert the tunnel works")
	assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)

	t.Log("reply?")
	theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
	p = r.RouteForAllUntilTxTun(myControl)
	assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)

	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
	//t.Log("finish up")
	//myControl.Stop()
	//theirControl.Stop()
	//relayControl.Stop()
}
@@ -338,18 +338,6 @@ logging:
  # after receiving the response for lighthouse queries
  #trigger_buffer: 64

# Tunnel manager settings
#tunnels:
  # drop_inactive controls whether inactive tunnels are maintained or dropped after the inactive_timeout period has
  # elapsed.
  # In general, it is a good idea to enable this setting. It will be enabled by default in a future release.
  # This setting is reloadable
  #drop_inactive: false

  # inactivity_timeout controls how long a tunnel MUST NOT see any inbound or outbound traffic before being considered
  # inactive and eligible to be dropped.
  # This setting is reloadable
  #inactivity_timeout: 10m

# Nebula security group configuration
firewall:
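The commented tunnels block above is the knob the removed TestDropInactiveTunnels exercised. A minimal sketch of turning it on from an e2e test override; this is a fragment in the style of the tests in this diff, reusing the package's m map type and the newSimpleServer helper (ca and caKey as prepared earlier in such a test), not a standalone program:

// Enable tunnel teardown after 5 seconds of inactivity on this test node.
overrides := m{
	"tunnels": m{
		"drop_inactive":      true,
		"inactivity_timeout": "5s",
	},
}
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", overrides)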
62
firewall.go
@@ -10,7 +10,6 @@ import (
	"reflect"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/gaissmai/bart"
@@ -19,6 +18,7 @@ import (
	"github.com/slackhq/nebula/cert"
	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/firewall"
	"github.com/wadey/synctrace"
)

type FirewallInterface interface {
@@ -76,7 +76,7 @@ type firewallMetrics struct {
}

type FirewallConntrack struct {
	sync.Mutex
	synctrace.Mutex

	Conns map[firewall.Packet]*conn
	TimerWheel *TimerWheel[firewall.Packet]
@@ -164,6 +164,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D

	return &Firewall{
		Conntrack: &FirewallConntrack{
			Mutex: synctrace.NewMutex("firewall-conntrack"),
			Conns: make(map[firewall.Packet]*conn),
			TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
		},
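One side of these hunks swaps the embedded sync.Mutex for a named synctrace mutex so lock contention on the conntrack table can be traced by name. A minimal sketch of that embedding pattern, assuming only the API surface visible in this diff (synctrace.Mutex, synctrace.NewMutex, and the usual Lock/Unlock via embedding):

package main

import "github.com/wadey/synctrace"

// conntrackTable mirrors the FirewallConntrack shape: the embedded, named mutex
// guards the map exactly as a plain sync.Mutex would.
type conntrackTable struct {
	synctrace.Mutex
	conns map[string]int
}

func main() {
	ct := &conntrackTable{
		Mutex: synctrace.NewMutex("firewall-conntrack"),
		conns: map[string]int{},
	}

	ct.Lock()
	ct.conns["10.0.0.1:4242"]++
	ct.Unlock()
}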
@@ -417,45 +418,30 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
	return nil
}

var ErrUnknownNetworkType = errors.New("unknown network type")
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
var ErrPeerRejected = errors.New("remote address is not within a network that we handle")
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks")
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")

// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) error {
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
	// Check if we spoke to this tuple, if we did then allow this packet
	if f.inConns(fp, h, caPool, localCache, now) {
	if f.inConns(fp, h, caPool, localCache) {
		return nil
	}

	// Make sure remote address matches nebula certificate, and determine how to treat it
	// Make sure remote address matches nebula certificate
	if h.networks == nil {
	if h.networks != nil {
		// Simple case: Certificate has one address and no unsafe networks
		if !h.networks.Contains(fp.RemoteAddr) {
		if h.vpnAddrs[0] != fp.RemoteAddr {
			f.metrics(incoming).droppedRemoteAddr.Inc(1)
			return ErrInvalidRemoteIP
		}
	} else {
		nwType, ok := h.networks.Lookup(fp.RemoteAddr)
		// Simple case: Certificate has one address and no unsafe networks
		if !ok {
		if h.vpnAddrs[0] != fp.RemoteAddr {
			f.metrics(incoming).droppedRemoteAddr.Inc(1)
			return ErrInvalidRemoteIP
		}
		switch nwType {
		case NetworkTypeVPN:
			break // nothing special
		case NetworkTypeVPNPeer:
			f.metrics(incoming).droppedRemoteAddr.Inc(1)
			return ErrPeerRejected // reject for now, one day this may have different FW rules
		case NetworkTypeUnsafe:
			break // nothing special, one day this may have different FW rules
		default:
			f.metrics(incoming).droppedRemoteAddr.Inc(1)
			return ErrUnknownNetworkType //should never happen
		}
	}

	// Make sure we are supposed to be handling this local ip address
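One side of this hunk classifies the packet's remote address by network type (VPN address, peer network, unsafe route) rather than doing a plain containment check. A minimal sketch of the same idea using only the standard library; the NetworkType names are taken from the diff, while the slice-based prefix table is a stand-in for the bart-backed structure Nebula actually uses:

package main

import (
	"errors"
	"fmt"
	"net/netip"
)

type networkType int

const (
	networkTypeVPN networkType = iota
	networkTypeVPNPeer
	networkTypeUnsafe
)

var errPeerRejected = errors.New("remote address is not within a network that we handle")

type classifier struct {
	prefixes []netip.Prefix
	types    []networkType
}

// lookup returns the type of the first prefix containing addr, if any.
func (c *classifier) lookup(addr netip.Addr) (networkType, bool) {
	for i, p := range c.prefixes {
		if p.Contains(addr) {
			return c.types[i], true
		}
	}
	return 0, false
}

func main() {
	c := &classifier{
		prefixes: []netip.Prefix{
			netip.MustParsePrefix("10.128.0.2/32"),   // address owned by the peer certificate
			netip.MustParsePrefix("192.168.50.0/24"), // unsafe route advertised by the peer
		},
		types: []networkType{networkTypeVPN, networkTypeUnsafe},
	}

	remote := netip.MustParseAddr("10.128.0.2")
	nwType, ok := c.lookup(remote)
	if !ok {
		fmt.Println("drop: remote address not in certificate networks")
		return
	}
	switch nwType {
	case networkTypeVPN, networkTypeUnsafe:
		fmt.Println("accept:", remote) // nothing special today
	case networkTypeVPNPeer:
		fmt.Println("drop:", errPeerRejected) // rejected for now, may grow its own rules later
	}
}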
@@ -476,7 +462,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
	}

	// We always want to conntrack since it is a faster operation
	f.addConn(fp, incoming, now)
	f.addConn(fp, incoming)

	return nil
}
@@ -505,7 +491,7 @@ func (f *Firewall) EmitStats() {
	metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
}

func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) bool {
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
	if localCache != nil {
		if _, ok := localCache[fp]; ok {
			return true
@@ -517,7 +503,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
	// Purge every time we test
	ep, has := conntrack.TimerWheel.Purge()
	if has {
		f.evict(ep, now)
		f.evict(ep)
	}

	c, ok := conntrack.Conns[fp]
@@ -564,11 +550,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,

	switch fp.Protocol {
	case firewall.ProtoTCP:
		c.Expires = now.Add(f.TCPTimeout)
		c.Expires = time.Now().Add(f.TCPTimeout)
	case firewall.ProtoUDP:
		c.Expires = now.Add(f.UDPTimeout)
		c.Expires = time.Now().Add(f.UDPTimeout)
	default:
		c.Expires = now.Add(f.DefaultTimeout)
		c.Expires = time.Now().Add(f.DefaultTimeout)
	}

	conntrack.Unlock()
@@ -580,7 +566,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
	return true
}

func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) {
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
	var timeout time.Duration
	c := &conn{}

@@ -596,7 +582,7 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) {
	conntrack := f.Conntrack
	conntrack.Lock()
	if _, ok := conntrack.Conns[fp]; !ok {
		conntrack.TimerWheel.Advance(now)
		conntrack.TimerWheel.Advance(time.Now())
		conntrack.TimerWheel.Add(fp, timeout)
	}

@@ -604,14 +590,14 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) {
	// firewall reload
	c.incoming = incoming
	c.rulesVersion = f.rulesVersion
	c.Expires = now.Add(timeout)
	c.Expires = time.Now().Add(timeout)
	conntrack.Conns[fp] = c
	conntrack.Unlock()
}

// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
// Caller must own the connMutex lock!
func (f *Firewall) evict(p firewall.Packet, now time.Time) {
func (f *Firewall) evict(p firewall.Packet) {
	// Are we still tracking this conn?
	conntrack := f.Conntrack
	t, ok := conntrack.Conns[p]
@@ -619,11 +605,11 @@ func (f *Firewall) evict(p firewall.Packet, now time.Time) {
		return
	}

	newT := t.Expires.Sub(now)
	newT := t.Expires.Sub(time.Now())

	// Timeout is in the future, re-add the timer
	if newT > 0 {
		conntrack.TimerWheel.Advance(now)
		conntrack.TimerWheel.Advance(time.Now())
		conntrack.TimerWheel.Add(p, newT)
		return
	}
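One side of these conntrack hunks threads a now time.Time parameter through Drop, inConns, addConn, and evict instead of calling time.Now() at every site, which keeps a single packet's processing on one timestamp and lets tests drive the clock directly. A minimal sketch of that pattern with hypothetical names, standard library only:

package main

import (
	"fmt"
	"time"
)

type conn struct {
	expires time.Time
}

type tracker struct {
	conns   map[string]*conn
	timeout time.Duration
}

// seen refreshes or creates an entry using the caller-supplied clock reading,
// so every check made while handling one packet agrees on the time.
func (t *tracker) seen(key string, now time.Time) {
	c, ok := t.conns[key]
	if !ok {
		c = &conn{}
		t.conns[key] = c
	}
	c.expires = now.Add(t.timeout)
}

func (t *tracker) expired(key string, now time.Time) bool {
	c, ok := t.conns[key]
	return !ok || now.After(c.expires)
}

func main() {
	tr := &tracker{conns: map[string]*conn{}, timeout: time.Minute}

	start := time.Now()
	tr.seen("10.0.0.1:4242->10.0.0.2:80", start)

	// A test can simply hand in a future time instead of sleeping.
	fmt.Println(tr.expired("10.0.0.1:4242->10.0.0.2:80", start.Add(30*time.Second))) // false
	fmt.Println(tr.expired("10.0.0.1:4242->10.0.0.2:80", start.Add(2*time.Minute)))  // true
}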
438
firewall_test.go
@@ -8,8 +8,6 @@ import (
	"testing"
	"time"

	"github.com/gaissmai/bart"
	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/cert"
	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/firewall"
@@ -70,9 +68,6 @@ func TestFirewall_AddRule(t *testing.T) {
	ti, err := netip.ParsePrefix("1.2.3.4/32")
	require.NoError(t, err)

	ti6, err := netip.ParsePrefix("fd12::34/128")
	require.NoError(t, err)

	require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
	// An empty rule is any
	assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
@@ -97,24 +92,12 @@ func TestFirewall_AddRule(t *testing.T) {
	_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
	assert.True(t, ok)

	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
	assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
	_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
	assert.True(t, ok)

	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
	_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
	assert.True(t, ok)

	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
	_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
	assert.True(t, ok)

	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
	require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
	assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
@@ -134,13 +117,6 @@ func TestFirewall_AddRule(t *testing.T) {
	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)

	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
	anyIp6, err := netip.ParsePrefix("::/0")
	require.NoError(t, err)

	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)

	// Test error conditions
	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
	require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -151,8 +127,7 @@ func TestFirewall_Drop(t *testing.T) {
	l := test.NewLogger()
	ob := &bytes.Buffer{}
	l.SetOutput(ob)
	myVpnNetworksTable := new(bart.Lite)
	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
	p := firewall.Packet{
		LocalAddr: netip.MustParseAddr("1.2.3.4"),
		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
@@ -177,7 +152,7 @@ func TestFirewall_Drop(t *testing.T) {
		},
		vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
	}
	h.buildNetworks(myVpnNetworksTable, &c)
	h.buildNetworks(c.networks, c.unsafeNetworks)

	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -224,85 +199,6 @@ func TestFirewall_Drop(t *testing.T) {
	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}

func TestFirewall_DropV6(t *testing.T) {
	l := test.NewLogger()
	ob := &bytes.Buffer{}
	l.SetOutput(ob)

	myVpnNetworksTable := new(bart.Lite)
	myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))

	p := firewall.Packet{
		LocalAddr: netip.MustParseAddr("fd12::34"),
		RemoteAddr: netip.MustParseAddr("fd12::34"),
		LocalPort: 10,
		RemotePort: 90,
		Protocol: firewall.ProtoUDP,
		Fragment: false,
	}

	c := dummyCert{
		name: "host1",
		networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
		groups: []string{"default-group"},
		issuer: "signer-shasum",
	}
	h := HostInfo{
		ConnectionState: &ConnectionState{
			peerCert: &cert.CachedCertificate{
				Certificate: &c,
				InvertedGroups: map[string]struct{}{"default-group": {}},
			},
		},
		vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
	}
	h.buildNetworks(myVpnNetworksTable, &c)

	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
	cp := cert.NewCAPool()

	// Drop outbound
	assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
	// Allow inbound
	resetConntrack(fw)
	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
	// Allow outbound because conntrack
	require.NoError(t, fw.Drop(p, false, &h, cp, nil))

	// test remote mismatch
	oldRemote := p.RemoteAddr
	p.RemoteAddr = netip.MustParseAddr("fd12::56")
	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
	p.RemoteAddr = oldRemote

	// ensure signer doesn't get in the way of group checks
	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)

	// test caSha doesn't drop on match
	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
	require.NoError(t, fw.Drop(p, true, &h, cp, nil))

	// ensure ca name doesn't get in the way of group checks
	cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)

	// test caName doesn't drop on match
	cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}

func BenchmarkFirewallTable_match(b *testing.B) {
	f := &Firewall{}
	ft := FirewallTable{
@@ -312,10 +208,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
	pfix := netip.MustParsePrefix("172.1.1.1/32")
	_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
	_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")

	pfix6 := netip.MustParsePrefix("fd11::11/128")
	_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
	_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
	cp := cert.NewCAPool()

	b.Run("fail on proto", func(b *testing.B) {
@@ -347,15 +239,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
		}
	})
	b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
		c := &cert.CachedCertificate{
			Certificate: &dummyCert{},
		}
		ip := netip.MustParsePrefix("fd99::99/128")
		for n := 0; n < b.N; n++ {
			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
		}
	})

	b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
		c := &cert.CachedCertificate{
@@ -369,18 +252,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
		}
	})
	b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
		c := &cert.CachedCertificate{
			Certificate: &dummyCert{
				name: "nope",
				networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
			},
			InvertedGroups: map[string]struct{}{"nope": {}},
		}
		for n := 0; n < b.N; n++ {
			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
		}
	})

	b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
		c := &cert.CachedCertificate{
@@ -394,18 +265,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
		}
	})
	b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
		c := &cert.CachedCertificate{
			Certificate: &dummyCert{
				name: "nope",
				networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
			},
			InvertedGroups: map[string]struct{}{"nope": {}},
		}
		for n := 0; n < b.N; n++ {
			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
		}
	})

	b.Run("pass on group on any local cidr", func(b *testing.B) {
		c := &cert.CachedCertificate{
@@ -430,17 +289,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
		}
	})
	b.Run("pass on group on specific local cidr6", func(b *testing.B) {
		c := &cert.CachedCertificate{
			Certificate: &dummyCert{
				name: "nope",
			},
			InvertedGroups: map[string]struct{}{"good-group": {}},
		}
		for n := 0; n < b.N; n++ {
			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
		}
	})

	b.Run("pass on name", func(b *testing.B) {
		c := &cert.CachedCertificate{
@@ -459,8 +307,6 @@ func TestFirewall_Drop2(t *testing.T) {
	l := test.NewLogger()
	ob := &bytes.Buffer{}
	l.SetOutput(ob)
	myVpnNetworksTable := new(bart.Lite)
	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))

	p := firewall.Packet{
		LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -486,7 +332,7 @@ func TestFirewall_Drop2(t *testing.T) {
		},
		vpnAddrs: []netip.Addr{network.Addr()},
	}
	h.buildNetworks(myVpnNetworksTable, c.Certificate)
	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())

	c1 := cert.CachedCertificate{
		Certificate: &dummyCert{
@@ -501,7 +347,7 @@ func TestFirewall_Drop2(t *testing.T) {
			peerCert: &c1,
		},
	}
	h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())

	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -518,8 +364,6 @@ func TestFirewall_Drop3(t *testing.T) {
	l := test.NewLogger()
	ob := &bytes.Buffer{}
	l.SetOutput(ob)
	myVpnNetworksTable := new(bart.Lite)
	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))

	p := firewall.Packet{
		LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -551,7 +395,7 @@ func TestFirewall_Drop3(t *testing.T) {
		},
		vpnAddrs: []netip.Addr{network.Addr()},
	}
	h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())

	c2 := cert.CachedCertificate{
		Certificate: &dummyCert{
@@ -566,7 +410,7 @@ func TestFirewall_Drop3(t *testing.T) {
		},
		vpnAddrs: []netip.Addr{network.Addr()},
	}
	h2.buildNetworks(myVpnNetworksTable, c2.Certificate)
	h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())

	c3 := cert.CachedCertificate{
		Certificate: &dummyCert{
@@ -581,7 +425,7 @@ func TestFirewall_Drop3(t *testing.T) {
		},
		vpnAddrs: []netip.Addr{network.Addr()},
	}
	h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
	h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())

	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -603,50 +447,10 @@ func TestFirewall_Drop3(t *testing.T) {
	require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
}

func TestFirewall_Drop3V6(t *testing.T) {
	l := test.NewLogger()
	ob := &bytes.Buffer{}
	l.SetOutput(ob)
	myVpnNetworksTable := new(bart.Lite)
	myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))

	p := firewall.Packet{
		LocalAddr: netip.MustParseAddr("fd12::34"),
		RemoteAddr: netip.MustParseAddr("fd12::34"),
		LocalPort: 1,
		RemotePort: 1,
		Protocol: firewall.ProtoUDP,
		Fragment: false,
	}

	network := netip.MustParsePrefix("fd12::34/120")
	c := cert.CachedCertificate{
		Certificate: &dummyCert{
			name: "host-owner",
			networks: []netip.Prefix{network},
		},
	}
	h := HostInfo{
		ConnectionState: &ConnectionState{
			peerCert: &c,
		},
		vpnAddrs: []netip.Addr{network.Addr()},
	}
	h.buildNetworks(myVpnNetworksTable, c.Certificate)

	// Test a remote address match
	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
	cp := cert.NewCAPool()
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}

func TestFirewall_DropConntrackReload(t *testing.T) {
	l := test.NewLogger()
	ob := &bytes.Buffer{}
	l.SetOutput(ob)
	myVpnNetworksTable := new(bart.Lite)
	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))

	p := firewall.Packet{
		LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -673,7 +477,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
		},
		vpnAddrs: []netip.Addr{network.Addr()},
	}
	h.buildNetworks(myVpnNetworksTable, c.Certificate)
	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())

	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -706,52 +510,6 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
}

func TestFirewall_DropIPSpoofing(t *testing.T) {
	l := test.NewLogger()
	ob := &bytes.Buffer{}
	l.SetOutput(ob)
	myVpnNetworksTable := new(bart.Lite)
	myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))

	c := cert.CachedCertificate{
		Certificate: &dummyCert{
			name: "host-owner",
			networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
		},
	}

	c1 := cert.CachedCertificate{
		Certificate: &dummyCert{
			name: "host",
			networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
			unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
		},
	}
	h1 := HostInfo{
		ConnectionState: &ConnectionState{
			peerCert: &c1,
		},
		vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
	}
	h1.buildNetworks(myVpnNetworksTable, c1.Certificate)

	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)

	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
	cp := cert.NewCAPool()

	// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
	p := firewall.Packet{
		LocalAddr: netip.MustParseAddr("192.0.2.1"),
		RemoteAddr: netip.MustParseAddr("192.0.2.3"),
		LocalPort: 1,
		RemotePort: 1,
		Protocol: firewall.ProtoUDP,
		Fragment: false,
	}
	assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
}

func BenchmarkLookup(b *testing.B) {
	ml := func(m map[string]struct{}, a [][]string) {
		for n := 0; n < b.N; n++ {
@@ -969,21 +727,6 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)

	// Test adding rule with cidr ipv6
	cidr6 := netip.MustParsePrefix("fd00::/8")
	conf = config.NewC(l)
	mf = &mockFirewall{}
	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)

	// Test adding rule with local_cidr ipv6
	conf = config.NewC(l)
	mf = &mockFirewall{}
	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
|
||||||
// Test adding rule with ca_sha
|
// Test adding rule with ca_sha
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
@@ -1063,171 +806,6 @@ func TestFirewall_convertRule(t *testing.T) {
|
|||||||
assert.Equal(t, "group1", r.Group)
|
assert.Equal(t, "group1", r.Group)
|
||||||
}
|
}
|
||||||
|
|
||||||
type testcase struct {
|
|
||||||
h *HostInfo
|
|
||||||
p firewall.Packet
|
|
||||||
c cert.Certificate
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *testcase) Test(t *testing.T, fw *Firewall) {
|
|
||||||
t.Helper()
|
|
||||||
cp := cert.NewCAPool()
|
|
||||||
resetConntrack(fw)
|
|
||||||
err := fw.Drop(c.p, true, c.h, cp, nil)
|
|
||||||
if c.err == nil {
|
|
||||||
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
|
|
||||||
} else {
|
|
||||||
require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
|
|
||||||
c1 := dummyCert{
|
|
||||||
name: "host1",
|
|
||||||
networks: theirPrefixes,
|
|
||||||
groups: []string{"default-group"},
|
|
||||||
issuer: "signer-shasum",
|
|
||||||
}
|
|
||||||
h := HostInfo{
|
|
||||||
ConnectionState: &ConnectionState{
|
|
||||||
peerCert: &cert.CachedCertificate{
|
|
||||||
Certificate: &c1,
|
|
||||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
|
|
||||||
}
|
|
||||||
for i := range theirPrefixes {
|
|
||||||
h.vpnAddrs[i] = theirPrefixes[i].Addr()
|
|
||||||
}
|
|
||||||
h.buildNetworks(setup.myVpnNetworksTable, &c1)
|
|
||||||
p := firewall.Packet{
|
|
||||||
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
|
|
||||||
RemoteAddr: theirPrefixes[0].Addr(),
|
|
||||||
LocalPort: 10,
|
|
||||||
RemotePort: 90,
|
|
||||||
Protocol: firewall.ProtoUDP,
|
|
||||||
Fragment: false,
|
|
||||||
}
|
|
||||||
return testcase{
|
|
||||||
h: &h,
|
|
||||||
p: p,
|
|
||||||
c: &c1,
|
|
||||||
err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type testsetup struct {
|
|
||||||
c dummyCert
|
|
||||||
myVpnNetworksTable *bart.Lite
|
|
||||||
fw *Firewall
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
|
|
||||||
c := dummyCert{
|
|
||||||
name: "me",
|
|
||||||
networks: myPrefixes,
|
|
||||||
groups: []string{"default-group"},
|
|
||||||
issuer: "signer-shasum",
|
|
||||||
}
|
|
||||||
|
|
||||||
return newSetupFromCert(t, l, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
|
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
|
||||||
for _, prefix := range c.Networks() {
|
|
||||||
myVpnNetworksTable.Insert(prefix)
|
|
||||||
}
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
|
||||||
|
|
||||||
return testsetup{
|
|
||||||
c: c,
|
|
||||||
fw: fw,
|
|
||||||
myVpnNetworksTable: myVpnNetworksTable,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
|
||||||
l.SetOutput(ob)
|
|
||||||
|
|
||||||
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
|
|
||||||
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
|
|
||||||
t.Run("allow inbound all matching", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
t.Run("allow inbound local matching", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
|
|
||||||
tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
t.Run("block inbound remote mismatched", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
|
|
||||||
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
t.Run("Block a vpn peer packet", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
twoPrefixes := []netip.Prefix{
|
|
||||||
netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
|
|
||||||
}
|
|
||||||
t.Run("allow inbound one matching", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, nil, twoPrefixes...)
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
t.Run("block inbound multimismatch", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
|
|
||||||
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
t.Run("allow inbound 2nd one matching", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
|
|
||||||
tc := buildTestCase(setup2, nil, twoPrefixes...)
|
|
||||||
tc.p.RemoteAddr = twoPrefixes[1].Addr()
|
|
||||||
tc.Test(t, setup2.fw)
|
|
||||||
})
|
|
||||||
t.Run("allow inbound unsafe route", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
|
|
||||||
c := dummyCert{
|
|
||||||
name: "me",
|
|
||||||
networks: []netip.Prefix{myPrefix},
|
|
||||||
unsafeNetworks: []netip.Prefix{unsafePrefix},
|
|
||||||
groups: []string{"default-group"},
|
|
||||||
issuer: "signer-shasum",
|
|
||||||
}
|
|
||||||
unsafeSetup := newSetupFromCert(t, l, c)
|
|
||||||
tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
|
|
||||||
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
|
|
||||||
tc.err = ErrNoMatchingRule
|
|
||||||
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
|
|
||||||
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, unsafePrefix, "", ""))
|
|
||||||
tc.err = nil
|
|
||||||
tc.Test(t, unsafeSetup.fw) //should pass
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type addRuleCall struct {
|
type addRuleCall struct {
|
||||||
incoming bool
|
incoming bool
|
||||||
proto uint8
|
proto uint8
|
||||||
|
|||||||
51 go.mod
@@ -1,38 +1,40 @@
 module github.com/slackhq/nebula
 
-go 1.25
+go 1.23.0
 
+toolchain go1.24.1
+
 require (
-    dario.cat/mergo v1.0.2
+    dario.cat/mergo v1.0.1
     github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
     github.com/armon/go-radix v1.0.0
     github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
     github.com/flynn/noise v1.1.0
-    github.com/gaissmai/bart v0.25.0
+    github.com/gaissmai/bart v0.20.4
     github.com/gogo/protobuf v1.3.2
     github.com/google/gopacket v1.1.19
-    github.com/kardianos/service v1.2.4
-    github.com/miekg/dns v1.1.68
+    github.com/kardianos/service v1.2.2
+    github.com/miekg/dns v1.1.65
     github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
     github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
-    github.com/prometheus/client_golang v1.23.2
+    github.com/prometheus/client_golang v1.22.0
     github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
     github.com/sirupsen/logrus v1.9.3
     github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
     github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
-    github.com/stretchr/testify v1.11.1
-    github.com/vishvananda/netlink v1.3.1
-    go.yaml.in/yaml/v3 v3.0.4
-    golang.org/x/crypto v0.44.0
+    github.com/stretchr/testify v1.10.0
+    github.com/vishvananda/netlink v1.3.0
+    github.com/wadey/synctrace v0.0.0-20250612192159-94547ef50dfe
+    golang.org/x/crypto v0.37.0
     golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
-    golang.org/x/net v0.46.0
-    golang.org/x/sync v0.18.0
-    golang.org/x/sys v0.38.0
-    golang.org/x/term v0.37.0
+    golang.org/x/net v0.39.0
+    golang.org/x/sync v0.13.0
+    golang.org/x/sys v0.32.0
+    golang.org/x/term v0.31.0
     golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
     golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
     golang.zx2c4.com/wireguard/windows v0.5.3
-    google.golang.org/protobuf v1.36.10
+    google.golang.org/protobuf v1.36.6
     gopkg.in/yaml.v3 v3.0.1
     gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
 )
@@ -41,15 +43,18 @@ require (
     github.com/beorn7/perks v1.0.1 // indirect
     github.com/cespare/xxhash/v2 v2.3.0 // indirect
     github.com/davecgh/go-spew v1.1.1 // indirect
+    github.com/emirpasic/gods v1.18.1 // indirect
     github.com/google/btree v1.1.2 // indirect
+    github.com/google/uuid v1.3.0 // indirect
+    github.com/heimdalr/dag v1.4.0 // indirect
     github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
     github.com/pmezard/go-difflib v1.0.0 // indirect
-    github.com/prometheus/client_model v0.6.2 // indirect
-    github.com/prometheus/common v0.66.1 // indirect
-    github.com/prometheus/procfs v0.16.1 // indirect
-    github.com/vishvananda/netns v0.0.5 // indirect
-    go.yaml.in/yaml/v2 v2.4.2 // indirect
-    golang.org/x/mod v0.24.0 // indirect
-    golang.org/x/time v0.7.0 // indirect
-    golang.org/x/tools v0.33.0 // indirect
+    github.com/prometheus/client_model v0.6.1 // indirect
+    github.com/prometheus/common v0.62.0 // indirect
+    github.com/prometheus/procfs v0.15.1 // indirect
+    github.com/timandy/routine v1.1.5 // indirect
+    github.com/vishvananda/netns v0.0.4 // indirect
+    golang.org/x/mod v0.23.0 // indirect
+    golang.org/x/time v0.5.0 // indirect
+    golang.org/x/tools v0.30.0 // indirect
 )
99 go.sum
@@ -1,6 +1,6 @@
 cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
-dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
-dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
+dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
+dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
 github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@@ -22,10 +22,12 @@ github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
+github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
 github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
 github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
-github.com/gaissmai/bart v0.25.0 h1:eqiokVPqM3F94vJ0bTHXHtH91S8zkKL+bKh+BsGOsJM=
-github.com/gaissmai/bart v0.25.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
+github.com/gaissmai/bart v0.20.4 h1:Ik47r1fy3jRVU+1eYzKSW3ho2UgBVTVnUS8O993584U=
+github.com/gaissmai/bart v0.20.4/go.mod h1:cEed+ge8dalcbpi8wtS9x9m2hn/fNJH5suhdGQOHnYk=
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -33,6 +35,8 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9
 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
 github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
 github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
+github.com/go-test/deep v1.1.0 h1:WOcxcdHcvdgThNXjw0t76K42FXTU7HpNQWHpA2HHNlg=
+github.com/go-test/deep v1.1.0/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
 github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
 github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
 github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
@@ -58,14 +62,18 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
 github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
+github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
+github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/heimdalr/dag v1.4.0 h1:zG3JA4RDVLc55k3AXAgfwa+EgBNZ0TkfOO3C29Ucpmg=
+github.com/heimdalr/dag v1.4.0/go.mod h1:OCh6ghKmU0hPjtwMqWBoNxPmtRioKd1xSu7Zs4sbIqM=
 github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
 github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
 github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
 github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
 github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
 github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
-github.com/kardianos/service v1.2.4 h1:XNlGtZOYNx2u91urOdg/Kfmc+gfmuIo1Dd3rEi2OgBk=
-github.com/kardianos/service v1.2.4/go.mod h1:E4V9ufUuY82F7Ztlu1eN9VXWIQxg8NoLQlmFe0MtrXc=
+github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60=
+github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
 github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
 github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
 github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
@@ -83,8 +91,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
 github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
 github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
-github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
-github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
+github.com/miekg/dns v1.1.65 h1:0+tIPHzUW0GCge7IiK3guGP57VAw7hoPDfApjkMD1Fc=
+github.com/miekg/dns v1.1.65/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck=
 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -106,24 +114,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
 github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
 github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
 github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
-github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
-github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
+github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
+github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
 github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
-github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
+github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
+github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
 github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
 github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
 github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
-github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
-github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
+github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
+github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
 github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
 github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
 github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
 github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
-github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
-github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
+github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
+github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
 github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@@ -143,35 +151,33 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
 github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
-github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
-github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
-github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
-github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
-github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
+github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
+github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/timandy/routine v1.1.5 h1:LSpm7Iijwb9imIPlucl4krpr2EeCeAUvifiQ9Uf5X+M=
+github.com/timandy/routine v1.1.5/go.mod h1:kXslgIosdY8LW0byTyPnenDgn4/azt2euufAq9rK51w=
+github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
+github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
+github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
+github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
+github.com/wadey/synctrace v0.0.0-20250612192159-94547ef50dfe h1:dc8Q42VsX+ABr0drJw27f3smvGfcz7eB8rJx+IkVMAo=
+github.com/wadey/synctrace v0.0.0-20250612192159-94547ef50dfe/go.mod h1:F2VCml4UxGPgAAqqm9T0ZfnVRWITrQS1EMZM+KCAm/Q=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
-go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
-go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
-go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
-go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
-go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
-go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
 golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
-golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
-golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
+golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
+golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
 golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
 golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
 golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
-golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
+golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
+golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
 golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -182,8 +188,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
 golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
-golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
+golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
+golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
 golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -191,8 +197,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
-golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
+golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
+golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
 golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -203,29 +209,30 @@ golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7w
 golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
-golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
+golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
-golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
+golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o=
+golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
-golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
+golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
+golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
 golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
-golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
+golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
+golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -244,8 +251,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
 google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
 google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
-google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
-google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
+google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
+google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
 gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
171 handshake_ix.go
@@ -2,12 +2,14 @@ package nebula
 
 import (
     "net/netip"
+    "slices"
     "time"
 
     "github.com/flynn/noise"
     "github.com/sirupsen/logrus"
     "github.com/slackhq/nebula/cert"
     "github.com/slackhq/nebula/header"
+    "github.com/wadey/synctrace"
 )
 
 // NOISE IX Handshakes
@@ -22,17 +24,13 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
         return false
     }
 
+    // If we're connecting to a v6 address we must use a v2 cert
     cs := f.pki.getCertState()
     v := cs.initiatingVersion
-    if hh.initiatingVersionOverride != cert.VersionPre1 {
-        v = hh.initiatingVersionOverride
-    } else if v < cert.Version2 {
-        // If we're connecting to a v6 address we should encourage use of a V2 cert
-        for _, a := range hh.hostinfo.vpnAddrs {
-            if a.Is6() {
-                v = cert.Version2
-                break
-            }
+    for _, a := range hh.hostinfo.vpnAddrs {
+        if a.Is6() {
+            v = cert.Version2
+            break
         }
     }
 
@@ -51,7 +49,6 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
             WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
             WithField("certVersion", v).
             Error("Unable to handshake with host because no certificate handshake bytes is available")
-        return false
     }
 
     ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
@@ -107,7 +104,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
             WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
             WithField("certVersion", cs.initiatingVersion).
             Error("Unable to handshake with host because no certificate is available")
-        return
     }
 
     ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
@@ -148,8 +144,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
     remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
     if err != nil {
-        fp, fperr := rc.Fingerprint()
-        if fperr != nil {
+        fp, err := rc.Fingerprint()
+        if err != nil {
             fp = "<error generating certificate fingerprint>"
         }
 
@@ -168,19 +164,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
     if remoteCert.Certificate.Version() != ci.myCert.Version() {
         // We started off using the wrong certificate version, lets see if we can match the version that was sent to us
-        myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
-        if myCertOtherVersion == nil {
-            if f.l.Level >= logrus.DebugLevel {
-                f.l.WithError(err).WithFields(m{
-                    "udpAddr": addr,
-                    "handshake": m{"stage": 1, "style": "ix_psk0"},
-                    "cert": remoteCert,
-                }).Debug("Might be unable to handshake with host due to missing certificate version")
-            }
-        } else {
-            // Record the certificate we are actually using
-            ci.myCert = myCertOtherVersion
+        rc := cs.getCertificate(remoteCert.Certificate.Version())
+        if rc == nil {
+            f.l.WithError(err).WithField("udpAddr", addr).
+                WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
+                Info("Unable to handshake with host due to missing certificate version")
+            return
         }
+
+        // Record the certificate we are actually using
+        ci.myCert = rc
     }
 
     if len(remoteCert.Certificate.Networks()) == 0 {
@@ -191,17 +184,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
         return
     }
 
+    var vpnAddrs []netip.Addr
+    var filteredNetworks []netip.Prefix
     certName := remoteCert.Certificate.Name()
    certVersion := remoteCert.Certificate.Version()
     fingerprint := remoteCert.Fingerprint
     issuer := remoteCert.Certificate.Issuer()
-    vpnNetworks := remoteCert.Certificate.Networks()
 
-    anyVpnAddrsInCommon := false
-    vpnAddrs := make([]netip.Addr, len(vpnNetworks))
-    for i, network := range vpnNetworks {
-        if f.myVpnAddrsTable.Contains(network.Addr()) {
-            f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
+    for _, network := range remoteCert.Certificate.Networks() {
+        vpnAddr := network.Addr()
+        if f.myVpnAddrsTable.Contains(vpnAddr) {
+            f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
                 WithField("certName", certName).
                 WithField("certVersion", certVersion).
                 WithField("fingerprint", fingerprint).
@@ -209,10 +202,24 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
                 WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
             return
         }
-        vpnAddrs[i] = network.Addr()
-        if f.myVpnNetworksTable.Contains(network.Addr()) {
-            anyVpnAddrsInCommon = true
+
+        // vpnAddrs outside our vpn networks are of no use to us, filter them out
+        if !f.myVpnNetworksTable.Contains(vpnAddr) {
+            continue
         }
+
+        filteredNetworks = append(filteredNetworks, network)
+        vpnAddrs = append(vpnAddrs, vpnAddr)
+    }
+
+    if len(vpnAddrs) == 0 {
+        f.l.WithError(err).WithField("udpAddr", addr).
+            WithField("certName", certName).
+            WithField("certVersion", certVersion).
+            WithField("fingerprint", fingerprint).
+            WithField("issuer", issuer).
+            WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
+        return
     }
 
     if addr.IsValid() {
@@ -243,36 +250,33 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
         HandshakePacket: make(map[uint8][]byte, 0),
         lastHandshakeTime: hs.Details.Time,
         relayState: RelayState{
-            relays: nil,
+            RWMutex: synctrace.NewRWMutex("relay-state"),
+            relays: map[netip.Addr]struct{}{},
             relayForByAddr: map[netip.Addr]*Relay{},
             relayForByIdx: map[uint32]*Relay{},
         },
     }
 
-    msgRxL := f.l.WithFields(m{
-        "vpnAddrs": vpnAddrs,
-        "udpAddr": addr,
-        "certName": certName,
-        "certVersion": certVersion,
-        "fingerprint": fingerprint,
-        "issuer": issuer,
-        "initiatorIndex": hs.Details.InitiatorIndex,
-        "responderIndex": hs.Details.ResponderIndex,
-        "remoteIndex": h.RemoteIndex,
-        "handshake": m{"stage": 1, "style": "ix_psk0"},
-    })
-
-    if anyVpnAddrsInCommon {
-        msgRxL.Info("Handshake message received")
-    } else {
-        //todo warn if not lighthouse or relay?
-        msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
-    }
+    f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+        WithField("certName", certName).
+        WithField("certVersion", certVersion).
+        WithField("fingerprint", fingerprint).
+        WithField("issuer", issuer).
+        WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
+        WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+        Info("Handshake message received")
 
     hs.Details.ResponderIndex = myIndex
     hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
     if hs.Details.Cert == nil {
-        msgRxL.WithField("myCertVersion", ci.myCert.Version()).
+        f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+            WithField("certName", certName).
+            WithField("certVersion", certVersion).
+            WithField("fingerprint", fingerprint).
+            WithField("issuer", issuer).
+            WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
+            WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+            WithField("certVersion", ci.myCert.Version()).
             Error("Unable to handshake with host because no certificate handshake bytes is available")
         return
     }
@@ -330,7 +334,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
     hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
     hostinfo.SetRemote(addr)
-    hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
+    hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
 
     existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
     if err != nil {
@@ -455,9 +459,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
             Info("Handshake message sent")
     }
 
     f.connectionManager.AddTrafficWatch(hostinfo)
|
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
|
||||||
|
|
||||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
hostinfo.remotes.ResetBlockedRemotes()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -571,22 +575,31 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
correctHostResponded := false
|
var vpnAddrs []netip.Addr
|
||||||
anyVpnAddrsInCommon := false
|
var filteredNetworks []netip.Prefix
|
||||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
for _, network := range vpnNetworks {
|
||||||
for i, network := range vpnNetworks {
|
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||||
vpnAddrs[i] = network.Addr()
|
vpnAddr := network.Addr()
|
||||||
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||||
anyVpnAddrsInCommon = true
|
continue
|
||||||
}
|
|
||||||
if hostinfo.vpnAddrs[0] == network.Addr() {
|
|
||||||
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
|
|
||||||
correctHostResponded = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
filteredNetworks = append(filteredNetworks, network)
|
||||||
|
vpnAddrs = append(vpnAddrs, vpnAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(vpnAddrs) == 0 {
|
||||||
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
|
WithField("issuer", issuer).
|
||||||
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the right host responded
|
// Ensure the right host responded
|
||||||
if !correctHostResponded {
|
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
||||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
||||||
WithField("udpAddr", addr).
|
WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
@@ -598,7 +611,6 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||||
|
|
||||||
// Create a new hostinfo/handshake for the intended vpn ip
|
// Create a new hostinfo/handshake for the intended vpn ip
|
||||||
//TODO is hostinfo.vpnAddrs[0] always the address to use?
|
|
||||||
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
||||||
// Block the current used address
|
// Block the current used address
|
||||||
newHH.hostinfo.remotes = hostinfo.remotes
|
newHH.hostinfo.remotes = hostinfo.remotes
|
||||||
@@ -625,7 +637,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
ci.window.Update(f.l, 2)
|
ci.window.Update(f.l, 2)
|
||||||
|
|
||||||
duration := time.Since(hh.startTime).Nanoseconds()
|
duration := time.Since(hh.startTime).Nanoseconds()
|
||||||
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -633,21 +645,16 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
WithField("durationNs", duration).
|
WithField("durationNs", duration).
|
||||||
WithField("sentCachedPackets", len(hh.packetStore))
|
WithField("sentCachedPackets", len(hh.packetStore)).
|
||||||
if anyVpnAddrsInCommon {
|
Info("Handshake message received")
|
||||||
msgRxL.Info("Handshake message received")
|
|
||||||
} else {
|
|
||||||
//todo warn if not lighthouse or relay?
|
|
||||||
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build up the radix for the firewall if we have subnets in the cert
|
// Build up the radix for the firewall if we have subnets in the cert
|
||||||
hostinfo.vpnAddrs = vpnAddrs
|
hostinfo.vpnAddrs = vpnAddrs
|
||||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
||||||
f.handshakeManager.Complete(hostinfo, f)
|
f.handshakeManager.Complete(hostinfo, f)
|
||||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
|
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
|
||||||
@@ -662,7 +669,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
|
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
hostinfo.remotes.ResetBlockedRemotes()
|
||||||
f.metricHandshakes.Update(duration)
|
f.metricHandshakes.Update(duration)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
|||||||
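The two sides of the handshake hunks above derive the peer's usable vpn addresses differently: the removed (left) code keeps every address from the certificate and only records whether any of them fall inside our own networks, while the added (right) code drops addresses outside our networks and refuses the handshake when nothing is left. A minimal sketch of the two shapes, with the membership test abstracted to a plain predicate (the real code uses the interface's own network table, not shown here):

package sketch

import "net/netip"

// keepAll mirrors the left-hand approach: every certificate address is kept and
// a flag records whether at least one of them overlaps our own vpn networks.
func keepAll(certNetworks []netip.Prefix, inMyNetworks func(netip.Addr) bool) ([]netip.Addr, bool) {
	vpnAddrs := make([]netip.Addr, len(certNetworks))
	anyInCommon := false
	for i, network := range certNetworks {
		vpnAddrs[i] = network.Addr()
		if inMyNetworks(network.Addr()) {
			anyInCommon = true
		}
	}
	return vpnAddrs, anyInCommon
}

// filterToMine mirrors the right-hand approach: addresses outside our vpn
// networks are discarded, and an empty result means the handshake is refused.
func filterToMine(certNetworks []netip.Prefix, inMyNetworks func(netip.Addr) bool) []netip.Addr {
	var vpnAddrs []netip.Addr
	for _, network := range certNetworks {
		if !inMyNetworks(network.Addr()) {
			continue
		}
		vpnAddrs = append(vpnAddrs, network.Addr())
	}
	return vpnAddrs
}

Per the hunks above, the left-hand variant still completes the handshake when there is no overlap and only changes the log message, whereas the right-hand variant rejects it with "No usable vpn addresses from host, refusing handshake".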
@@ -8,7 +8,6 @@ import (
"errors"
"net/netip"
"slices"
- "sync"
"time"

"github.com/rcrowley/go-metrics"
@@ -16,6 +15,7 @@ import (
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
+ "github.com/wadey/synctrace"
)

const (
@@ -45,7 +45,7 @@ type HandshakeConfig struct {

type HandshakeManager struct {
// Mutex for interacting with the vpnIps and indexes maps
- sync.RWMutex
+ synctrace.RWMutex

vpnIps map[netip.Addr]*HandshakeHostInfo
indexes map[uint32]*HandshakeHostInfo
@@ -66,14 +66,13 @@ type HandshakeManager struct {
}

type HandshakeHostInfo struct {
- sync.Mutex
+ synctrace.Mutex

startTime time.Time // Time that we first started trying with this handshake
ready bool // Is the handshake ready
- initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
+ counter int64 // How many attempts have we made so far
- counter int64 // How many attempts have we made so far
+ lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
- lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
+ packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
- packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes

hostinfo *HostInfo
}
@@ -105,6 +104,7 @@ func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType,

func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
return &HandshakeManager{
+ RWMutex: synctrace.NewRWMutex("handshake-manager"),
vpnIps: map[netip.Addr]*HandshakeHostInfo{},
indexes: map[uint32]*HandshakeHostInfo{},
mainHostMap: mainHostMap,
@@ -112,7 +112,7 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig
outside: outside,
config: config,
trigger: make(chan netip.Addr, config.triggerBuffer),
- OutboundHandshakeTimer: NewLockingTimerWheel[netip.Addr](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
+ OutboundHandshakeTimer: NewLockingTimerWheel[netip.Addr]("outbound-handshake-timer", config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
messageMetrics: config.messageMetrics,
metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
@@ -269,12 +269,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
// Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays {
- // Don't relay through the host I'm trying to connect to
+ // Don't relay to myself
if relay == vpnIp {
continue
}

- // Don't relay to myself
+ // Don't relay through the host I'm trying to connect to
if hm.f.myVpnAddrsTable.Contains(relay) {
continue
}
@@ -451,13 +451,15 @@ func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*Han
vpnAddrs: []netip.Addr{vpnAddr},
HandshakePacket: make(map[uint8][]byte, 0),
relayState: RelayState{
- relays: nil,
+ RWMutex: synctrace.NewRWMutex("relay-state"),
+ relays: map[netip.Addr]struct{}{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}

hh := &HandshakeHostInfo{
+ Mutex: synctrace.NewMutex("handshake-hostinfo"),
hostinfo: hostinfo,
startTime: time.Now(),
}
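The right-hand side of these hunks swaps the standard library mutexes for named ones from github.com/wadey/synctrace; only the constructors that actually appear in the diff (NewRWMutex, NewMutex) are relied on below, and the assumption that the returned types are drop-in replacements for sync.RWMutex and sync.Mutex comes from the unchanged Lock/RLock call sites in the surrounding code. A minimal sketch of the embedding pattern:

package sketch

import (
	"net/netip"

	"github.com/wadey/synctrace"
)

// tracked shows the shape used throughout the diff: the zero-value embedded
// lock is replaced by a named one built in the constructor.
type tracked struct {
	synctrace.RWMutex

	hosts map[netip.Addr]struct{}
}

func newTracked() *tracked {
	return &tracked{
		// The name ("tracked-hosts" here) is presumably what identifies the lock
		// in trace output; the diff uses names like "handshake-manager" and "hostmap".
		RWMutex: synctrace.NewRWMutex("tracked-hosts"),
		hosts:   map[netip.Addr]struct{}{},
	}
}

// add and contains use the same Lock/RLock calls the existing code already
// makes, so call sites do not change when the lock type is swapped.
func (t *tracked) add(addr netip.Addr) {
	t.Lock()
	defer t.Unlock()
	t.hosts[addr] = struct{}{}
}

func (t *tracked) contains(addr netip.Addr) bool {
	t.RLock()
	defer t.RUnlock()
	_, ok := t.hosts[addr]
	return ok
}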
100 hostmap.go
@@ -4,8 +4,6 @@ import (
"errors"
"net"
"net/netip"
- "slices"
- "sync"
"sync/atomic"
"time"

@@ -15,12 +13,15 @@ import (
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
+ "github.com/wadey/synctrace"
)

+ // const ProbeLen = 100
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse
const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
const MaxRemotes = 10
+ const maxRecvError = 4

// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
// 5 allows for an initial handshake and each host pair re-handshaking twice
@@ -52,22 +53,22 @@ type Relay struct {
}

type HostMap struct {
- sync.RWMutex //Because we concurrently read and write to our maps
+ synctrace.RWMutex //Because we concurrently read and write to our maps
Indexes map[uint32]*HostInfo
Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
RemoteIndexes map[uint32]*HostInfo
Hosts map[netip.Addr]*HostInfo
preferredRanges atomic.Pointer[[]netip.Prefix]
l *logrus.Logger
}

// For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay
// struct, make a copy of an existing value, edit the fileds in the copy, and
// then store a pointer to the new copy in both realyForBy* maps.
type RelayState struct {
- sync.RWMutex
+ synctrace.RWMutex

- relays []netip.Addr // Ordered set of VpnAddrs of Hosts to use as relays to access this peer
+ relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
// For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
// modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
// the RelayState Lock held)
@@ -78,12 +79,7 @@ type RelayState struct {
func (rs *RelayState) DeleteRelay(ip netip.Addr) {
rs.Lock()
defer rs.Unlock()
- for idx, val := range rs.relays {
+ delete(rs.relays, ip)
- if val == ip {
- rs.relays = append(rs.relays[:idx], rs.relays[idx+1:]...)
- return
- }
- }
}

func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
@@ -128,16 +124,16 @@ func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
rs.Lock()
defer rs.Unlock()
- if !slices.Contains(rs.relays, ip) {
+ rs.relays[ip] = struct{}{}
- rs.relays = append(rs.relays, ip)
- }
}

func (rs *RelayState) CopyRelayIps() []netip.Addr {
- ret := make([]netip.Addr, len(rs.relays))
rs.RLock()
defer rs.RUnlock()
- copy(ret, rs.relays)
+ ret := make([]netip.Addr, 0, len(rs.relays))
+ for ip := range rs.relays {
+ ret = append(ret, ip)
+ }
return ret
}

@@ -212,18 +208,6 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
rs.relayForByIdx[idx] = r
}

- type NetworkType uint8

- const (
- NetworkTypeUnknown NetworkType = iota
- // NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
- NetworkTypeVPN
- // NetworkTypeVPNPeer is a network that does not overlap one of our networks
- NetworkTypeVPNPeer
- // NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
- NetworkTypeUnsafe
- )

type HostInfo struct {
remote netip.AddrPort
remotes *RemoteList
@@ -235,10 +219,11 @@ type HostInfo struct {
// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
// The host may have other vpn addresses that are outside our
// vpn networks but were removed because they are not usable
vpnAddrs []netip.Addr
+ recvError atomic.Uint32

- // networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host.
+ // networks are both all vpn and unsafe networks assigned to this host
- networks *bart.Table[NetworkType]
+ networks *bart.Lite
relayState RelayState

// HandshakePacket records the packets used to create this hostinfo
@@ -265,14 +250,6 @@ type HostInfo struct {
// Used to track other hostinfos for this vpn ip since only 1 can be primary
// Synchronised via hostmap lock and not the hostinfo lock.
next, prev *HostInfo

- //TODO: in, out, and others might benefit from being an atomic.Int32. We could collapse connectionManager pendingDeletion, relayUsed, and in/out into this 1 thing
- in, out, pendingDeletion atomic.Bool

- // lastUsed tracks the last time ConnectionManager checked the tunnel and it was in use.
- // This value will be behind against actual tunnel utilization in the hot path.
- // This should only be used by the ConnectionManagers ticker routine.
- lastUsed time.Time
}

type ViaSender struct {
@@ -311,6 +288,7 @@ func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {

func newHostMap(l *logrus.Logger) *HostMap {
return &HostMap{
+ RWMutex: synctrace.NewRWMutex("hostmap"),
Indexes: map[uint32]*HostInfo{},
Relays: map[uint32]*HostInfo{},
RemoteIndexes: map[uint32]*HostInfo{},
@@ -742,26 +720,26 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
return false
}

- // buildNetworks fills in the networks field of HostInfo. It accepts a cert.Certificate so you never ever mix the network types up.
+ func (i *HostInfo) RecvErrorExceeded() bool {
- func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certificate) {
+ if i.recvError.Add(1) >= maxRecvError {
- if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
+ return true
- if myVpnNetworksTable.Contains(c.Networks()[0].Addr()) {
+ }
- return // Simple case, no BART needed
+ return true
}

+ func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
+ if len(networks) == 1 && len(unsafeNetworks) == 0 {
+ // Simple case, no CIDRTree needed
+ return
}

- i.networks = new(bart.Table[NetworkType])
+ i.networks = new(bart.Lite)
- for _, network := range c.Networks() {
+ for _, network := range networks {
- nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
+ i.networks.Insert(network)
- if myVpnNetworksTable.Contains(network.Addr()) {
- i.networks.Insert(nprefix, NetworkTypeVPN)
- } else {
- i.networks.Insert(nprefix, NetworkTypeVPNPeer)
- }
}

- for _, network := range c.UnsafeNetworks() {
+ for _, network := range unsafeNetworks {
- i.networks.Insert(network, NetworkTypeUnsafe)
+ i.networks.Insert(network)
}
}
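On the right-hand side relays becomes a plain map-backed set, so InsertRelayTo, DeleteRelay and CopyRelayIps reduce to map operations and any insertion-order guarantee is gone. A minimal standalone sketch of the same set semantics (the real methods also take the RelayState lock, omitted here; the type and method names below are illustrative only):

package sketch

import "net/netip"

type relaySet map[netip.Addr]struct{}

// insert and remove are the map forms of InsertRelayTo and DeleteRelay.
func (rs relaySet) insert(ip netip.Addr) { rs[ip] = struct{}{} }
func (rs relaySet) remove(ip netip.Addr) { delete(rs, ip) }

// copyIps matches the shape of CopyRelayIps on the right-hand side: the result
// holds every relay address, but map iteration order is not stable between calls.
func (rs relaySet) copyIps() []netip.Addr {
	ret := make([]netip.Addr, 0, len(rs))
	for ip := range rs {
		ret = append(ret, ip)
	}
	return ret
}

That loss of ordering is also why the ordered assertions in the relay test removed in the next file no longer hold against the map-backed set.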
@@ -7,7 +7,6 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
)

func TestHostMap_MakePrimary(t *testing.T) {
@@ -216,31 +215,3 @@ func TestHostMap_reload(t *testing.T) {
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
}

- func TestHostMap_RelayState(t *testing.T) {
- h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
- a1 := netip.MustParseAddr("::1")
- a2 := netip.MustParseAddr("2001::1")

- h1.relayState.InsertRelayTo(a1)
- assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
- h1.relayState.InsertRelayTo(a2)
- assert.Equal(t, []netip.Addr{a1, a2}, h1.relayState.relays)
- // Ensure that the first relay added is the first one returned in the copy
- currentRelays := h1.relayState.CopyRelayIps()
- require.Len(t, currentRelays, 2)
- assert.Equal(t, a1, currentRelays[0])

- // Deleting the last one in the list works ok
- h1.relayState.DeleteRelay(a2)
- assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)

- // Deleting an element not in the list works ok
- h1.relayState.DeleteRelay(a2)
- assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)

- // Deleting the only element in the list works ok
- h1.relayState.DeleteRelay(a1)
- assert.Equal(t, []netip.Addr{}, h1.relayState.relays)

- }
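The removed TestHostMap_RelayState asserted slice contents and insertion order, which only makes sense while relays is an ordered slice. A hedged sketch of how an equivalent check could be written against a map-backed set, using the toy relaySet from the earlier sketch and assuming testify's ElementsMatch is acceptable for order-insensitive comparison (the real test would exercise HostInfo.relayState through InsertRelayTo, DeleteRelay and CopyRelayIps instead):

package sketch

import (
	"net/netip"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestRelaySet_Unordered(t *testing.T) {
	rs := relaySet{}
	a1 := netip.MustParseAddr("::1")
	a2 := netip.MustParseAddr("2001::1")

	rs.insert(a1)
	rs.insert(a2)
	// Order is not guaranteed, so compare as sets rather than as slices.
	assert.ElementsMatch(t, []netip.Addr{a1, a2}, rs.copyIps())

	rs.remove(a2)
	assert.ElementsMatch(t, []netip.Addr{a1}, rs.copyIps())
}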
108 inside.go
@@ -2,18 +2,16 @@ package nebula

import (
"net/netip"
- "time"

"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil"
- "github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/routing"
)

- func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache, now time.Time) {
+ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket)
if err != nil {
if f.l.Level >= logrus.DebugLevel {
@@ -55,7 +53,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
})

if hostinfo == nil {
- f.rejectInside(packet, out.Payload, q) //todo vector?
+ f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
WithField("fwPacket", fwPacket).
@@ -68,11 +66,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return
}

- dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache, now)
+ dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil {
- f.sendNoMetricsDelayed(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
+ f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)

} else {
- f.rejectInside(packet, out.Payload, q) //todo vector?
+ f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).
WithField("fwPacket", fwPacket).
@@ -121,10 +120,9 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
}

- // Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established
+ // Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
- // it does not check if it is within our vpn networks!
func (f *Interface) Handshake(vpnAddr netip.Addr) {
- f.handshakeManager.GetOrHandshake(vpnAddr, nil)
+ f.getOrHandshakeNoRouting(vpnAddr, nil)
}

// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
@@ -140,6 +138,7 @@ func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback fu
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {

destinationAddr := fwPacket.RemoteAddr

hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
@@ -219,7 +218,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
}

// check if packet is in outbound fw rules
- dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil, time.Now())
+ dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
if dropReason != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp).
@@ -232,10 +231,9 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
}

- // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr.
+ // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
- // This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
- hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
+ hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
})

@@ -290,7 +288,7 @@ func (f *Interface) SendVia(via *HostInfo,
c := via.ConnectionState.messageCounter.Add(1)

out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c)
- f.connectionManager.Out(via)
+ f.connectionManager.Out(via.localIndexId)

// Authenticate the header and payload, but do not encrypt for this message type.
// The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload.
@@ -358,7 +356,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType

//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
- f.connectionManager.Out(hostinfo)
+ f.connectionManager.Out(hostinfo.localIndexId)

// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
// all our addrs and enable a faster roaming.
@@ -411,81 +409,3 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
}
}
}

- func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.Packet, q int) {
- if ci.eKey == nil {
- return
- }
- useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
- fullOut := out.Payload

- if useRelay {
- if len(out.Payload) < header.Len {
- // out always has a capacity of mtu, but not always a length greater than the header.Len.
- // Grow it to make sure the next operation works.
- out.Payload = out.Payload[:header.Len]
- }
- // Save a header's worth of data at the front of the 'out' buffer.
- out.Payload = out.Payload[header.Len:]
- }

- if noiseutil.EncryptLockNeeded {
- // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
- ci.writeLock.Lock()
- }
- c := ci.messageCounter.Add(1)

- //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
- out.Payload = header.Encode(out.Payload, header.Version, t, st, hostinfo.remoteIndexId, c)
- f.connectionManager.Out(hostinfo)

- // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
- // all our addrs and enable a faster roaming.
- if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
- //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
- // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
- f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
- hostinfo.lastRebindCount = f.rebindCount
- if f.l.Level >= logrus.DebugLevel {
- f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
- }
- }

- var err error
- out.Payload, err = ci.eKey.EncryptDanger(out.Payload, out.Payload, p, c, nb)
- if noiseutil.EncryptLockNeeded {
- ci.writeLock.Unlock()
- }
- if err != nil {
- hostinfo.logger(f.l).WithError(err).
- WithField("udpAddr", remote).WithField("counter", c).
- WithField("attemptedCounter", c).
- Error("Failed to encrypt outgoing packet")
- return
- }

- if remote.IsValid() {
- err = f.writers[q].Prep(out, remote)
- if err != nil {
- hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet")
- }
- } else if hostinfo.remote.IsValid() {
- err = f.writers[q].Prep(out, hostinfo.remote)
- if err != nil {
- hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet")
- }
- } else {
- // Try to send via a relay
- for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
- relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
- if err != nil {
- hostinfo.relayState.DeleteRelay(relayIP)
- hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
- continue
- }
- //todo vector!!
- f.SendVia(relayHostInfo, relay, out.Payload, nb, fullOut[:header.Len+len(out.Payload)], true)
- break
- }
- }
- }
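One detail worth noting in the removed consumeInsidePacket signature is the extra now time.Time argument, which lets a single timestamp be taken per batch and reused for every firewall check instead of calling time.Now() once per packet. A small generic sketch of that pattern, deliberately unrelated to nebula's actual types:

package sketch

import "time"

// processBatch captures the clock once and hands the same instant to every
// per-item check, which avoids repeated time.Now() calls in the hot path and
// keeps decisions within one batch consistent.
func processBatch(items [][]byte, check func(item []byte, now time.Time) bool) int {
	now := time.Now()
	accepted := 0
	for _, item := range items {
		if check(item, now) {
			accepted++
		}
	}
	return accepted
}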
131 interface.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
+ "io"
"net/netip"
"os"
"runtime"
@@ -17,31 +18,29 @@ import (
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay"
- "github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/udp"
)

const mtu = 9001
- const batch = 1024 //todo config!

type InterfaceConfig struct {
HostMap *HostMap
Outside udp.Conn
Inside overlay.Device
pki *PKI
- Cipher string
+ Firewall *Firewall
- Firewall *Firewall
+ ServeDns bool
- ServeDns bool
+ HandshakeManager *HandshakeManager
- HandshakeManager *HandshakeManager
+ lightHouse *LightHouse
- lightHouse *LightHouse
+ checkInterval time.Duration
- connectionManager *connectionManager
+ pendingDeletionInterval time.Duration
DropLocalBroadcast bool
DropMulticast bool
routines int
MessageMetrics *MessageMetrics
version string
relayManager *relayManager
punchy *Punchy

tryPromoteEvery uint32
reQueryEvery uint32
@@ -87,18 +86,12 @@ type Interface struct {
conntrackCacheTimeout time.Duration

writers []udp.Conn
- readers []overlay.TunDev
+ readers []io.ReadWriteCloser

metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics

- listenInN int
- listenOutN int

- listenInMetric metrics.Histogram
- listenOutMetric metrics.Histogram

l *logrus.Logger
}

@@ -164,9 +157,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
if c.Firewall == nil {
return nil, errors.New("no firewall rules")
}
- if c.connectionManager == nil {
- return nil, errors.New("no connection manager")
- }

cs := c.pki.getCertState()
ifce := &Interface{
@@ -184,14 +174,14 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
routines: c.routines,
version: c.version,
writers: make([]udp.Conn, c.routines),
- readers: make([]overlay.TunDev, c.routines),
+ readers: make([]io.ReadWriteCloser, c.routines),
myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable,
myVpnAddrs: cs.myVpnAddrs,
myVpnAddrsTable: cs.myVpnAddrsTable,
myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
relayManager: c.relayManager,
- connectionManager: c.connectionManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout,

metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
@@ -203,14 +193,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {

l: c.l,
}
- ifce.listenInMetric = metrics.GetOrRegisterHistogram("vhost.listenIn.n", nil, metrics.NewExpDecaySample(1028, 0.015))
- ifce.listenOutMetric = metrics.GetOrRegisterHistogram("vhost.listenOut.n", nil, metrics.NewExpDecaySample(1028, 0.015))

ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait))

- ifce.connectionManager.intf = ifce
+ ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)

return ifce, nil
}
@@ -234,7 +222,7 @@ func (f *Interface) activate() {
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))

// Prepare n tun queues
- var reader overlay.TunDev = f.inside
+ var reader io.ReadWriteCloser = f.inside
for i := 0; i < f.routines; i++ {
if i > 0 {
reader, err = f.inside.NewMultiQueueReader()
@@ -263,71 +251,40 @@ func (f *Interface) run() {
}
}

- func (f *Interface) listenOut(q int) {
+ func (f *Interface) listenOut(i int) {
runtime.LockOSThread()

var li udp.Conn
- if q > 0 {
+ if i > 0 {
- li = f.writers[q]
+ li = f.writers[i]
} else {
li = f.outside
}

ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler()
+ plaintext := make([]byte, udp.MTU)
- outPackets := make([]*packet.OutPacket, batch)
- for i := 0; i < batch; i++ {
- outPackets[i] = packet.NewOut()
- }

h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)

- toSend := make([][]byte, batch)
+ li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
+ f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
- li.ListenOut(func(pkts []*packet.Packet) {
- toSend = toSend[:0]
- for i := range outPackets {
- outPackets[i].Valid = false
- outPackets[i].SegCounter = 0
- }

- f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l), time.Now())
- //we opportunistically tx, but try to also send stragglers
- if _, err := f.readers[q].WriteMany(outPackets, q); err != nil {
- f.l.WithError(err).Error("Failed to send packets")
- }
- //todo I broke this
- //n := len(toSend)
- //if f.l.Level == logrus.DebugLevel {
- // f.listenOutMetric.Update(int64(n))
- //}
- //f.listenOutN = n

})
}

- func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
+ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread()

+ packet := make([]byte, mtu)
+ out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)

conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)

- packets := make([]*packet.VirtIOPacket, batch)
- outPackets := make([]*packet.Packet, batch)
- for i := 0; i < batch; i++ {
- packets[i] = packet.NewVIO()
- outPackets[i] = packet.New(false) //todo?
- }

for {
- n, err := reader.ReadMany(packets, queueNum)
+ n, err := reader.Read(packet)

- //todo!!
if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return
@@ -338,22 +295,7 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
os.Exit(2)
}

- if f.l.Level == logrus.DebugLevel {
+ f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
- f.listenInMetric.Update(int64(n))
- }
- f.listenInN = n

- now := time.Now()
- for i, pkt := range packets[:n] {
- outPackets[i].OutLen = -1
- f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now)
- reader.RecycleRxSeg(pkt, i == (n-1), queueNum) //todo handle err?
- pkt.Reset()
- }
- _, err = f.writers[queueNum].WriteBatch(outPackets[:n])
- if err != nil {
- f.l.WithError(err).Error("Error while writing outbound packets")
- }
}
}

@@ -491,11 +433,6 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
} else {
certMaxVersion.Update(int64(certState.v1Cert.Version()))
}
- if f.l.Level != logrus.DebugLevel {
- f.listenInMetric.Update(int64(f.listenInN))
- f.listenOutMetric.Update(int64(f.listenOutN))
- }

}
}
}
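The interface.go hunks swap the tun reader type: the removed side reads and writes batches through overlay.TunDev, while the added side goes back to a plain io.ReadWriteCloser and one packet per Read call. A minimal sketch of the single-packet read-loop shape on the right-hand side, using only the standard library (error handling and the nebula-specific processing are elided; the consume callback is a placeholder):

package sketch

import "io"

const mtu = 9001 // same constant the file keeps on both sides

// readLoop mirrors the right-hand shape: one buffer, one Read per packet, and
// the payload handed to a per-packet consumer.
func readLoop(reader io.ReadWriteCloser, consume func(pkt []byte)) error {
	buf := make([]byte, mtu)
	for {
		n, err := reader.Read(buf)
		if err != nil {
			return err
		}
		consume(buf[:n])
	}
}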
260
lighthouse.go
260
lighthouse.go
@@ -9,7 +9,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -21,16 +20,16 @@ import (
|
|||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
|
"github.com/wadey/synctrace"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrHostNotKnown = errors.New("host not known")
|
var ErrHostNotKnown = errors.New("host not known")
|
||||||
var ErrBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr")
|
|
||||||
|
|
||||||
type LightHouse struct {
|
type LightHouse struct {
|
||||||
//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
|
//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
|
||||||
sync.RWMutex //Because we concurrently read and write to our maps
|
synctrace.RWMutex //Because we concurrently read and write to our maps
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
amLighthouse bool
|
amLighthouse bool
|
||||||
|
|
||||||
myVpnNetworks []netip.Prefix
|
myVpnNetworks []netip.Prefix
|
||||||
myVpnNetworksTable *bart.Lite
|
myVpnNetworksTable *bart.Lite
|
||||||
@@ -57,7 +56,7 @@ type LightHouse struct {
|
|||||||
// staticList exists to avoid having a bool in each addrMap entry
|
// staticList exists to avoid having a bool in each addrMap entry
|
||||||
// since static should be rare
|
// since static should be rare
|
||||||
staticList atomic.Pointer[map[netip.Addr]struct{}]
|
staticList atomic.Pointer[map[netip.Addr]struct{}]
|
||||||
lighthouses atomic.Pointer[[]netip.Addr]
|
lighthouses atomic.Pointer[map[netip.Addr]struct{}]
|
||||||
|
|
||||||
interval atomic.Int64
|
interval atomic.Int64
|
||||||
updateCancel context.CancelFunc
|
updateCancel context.CancelFunc
|
||||||
@@ -97,6 +96,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
|
|||||||
}
|
}
|
||||||
|
|
||||||
h := LightHouse{
|
h := LightHouse{
|
||||||
|
RWMutex: synctrace.NewRWMutex("lighthouse"),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
amLighthouse: amLighthouse,
|
amLighthouse: amLighthouse,
|
||||||
myVpnNetworks: cs.myVpnNetworks,
|
myVpnNetworks: cs.myVpnNetworks,
|
||||||
@@ -108,7 +108,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
|
|||||||
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
|
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
lighthouses := make([]netip.Addr, 0)
|
lighthouses := make(map[netip.Addr]struct{})
|
||||||
h.lighthouses.Store(&lighthouses)
|
h.lighthouses.Store(&lighthouses)
|
||||||
staticList := make(map[netip.Addr]struct{})
|
staticList := make(map[netip.Addr]struct{})
|
||||||
h.staticList.Store(&staticList)
|
h.staticList.Store(&staticList)
|
||||||
@@ -144,7 +144,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
|
|||||||
return *lh.staticList.Load()
|
return *lh.staticList.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) GetLighthouses() []netip.Addr {
|
func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
|
||||||
return *lh.lighthouses.Load()
|
return *lh.lighthouses.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -307,12 +307,13 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
}

if initial || c.HasChanged("lighthouse.hosts") {
-lhList, err := lh.parseLighthouses(c)
+lhMap := make(map[netip.Addr]struct{})
+err := lh.parseLighthouses(c, lhMap)
if err != nil {
return err
}

-lh.lighthouses.Store(&lhList)
+lh.lighthouses.Store(&lhMap)
if !initial {
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
lh.l.Info("lighthouse.hosts has changed")
@@ -346,38 +347,36 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return nil
}

-func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
+func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
lhs := c.GetStringSlice("lighthouse.hosts", []string{})
if lh.amLighthouse && len(lhs) != 0 {
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
}
-out := make([]netip.Addr, len(lhs))

for i, host := range lhs {
addr, err := netip.ParseAddr(host)
if err != nil {
-return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
+return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
}

if !lh.myVpnNetworksTable.Contains(addr) {
-lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
-Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
+return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
}
-out[i] = addr
+lhMap[addr] = struct{}{}
}

-if !lh.amLighthouse && len(out) == 0 {
+if !lh.amLighthouse && len(lhMap) == 0 {
lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
}

staticList := lh.GetStaticHostList()
-for i := range out {
-if _, ok := staticList[out[i]]; !ok {
-return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
+for lhAddr, _ := range lhMap {
+if _, ok := staticList[lhAddr]; !ok {
+return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr)
}
}

-return out, nil
+return nil
}

func getStaticMapCadence(c *config.C) (time.Duration, error) {
@@ -432,8 +431,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
}

if !lh.myVpnNetworksTable.Contains(vpnAddr) {
-lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
-Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
+return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
}

vals, ok := v.([]any)
@@ -475,6 +473,7 @@ func (lh *LightHouse) QueryServer(vpnAddr netip.Addr) {
return
}

+synctrace.ChanDebugSend("lighthouse-querychan")
lh.queryChan <- vpnAddr
}

@@ -489,7 +488,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList {
lh.Lock()
defer lh.Unlock()
// Add an entry if we don't already have one
-return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip
+return lh.unlockedGetRemoteList(vpnAddrs)
}

// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
@@ -522,15 +521,11 @@ func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (in
}

func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
-// First we check the static host map. If any of the VpnAddrs to be deleted are present, do nothing.
-staticList := lh.GetStaticHostList()
-for _, addr := range allVpnAddrs {
-if _, ok := staticList[addr]; ok {
-return
-}
+// First we check the static mapping
+// and do nothing if it is there
+if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok {
+return
}

-// None of the VpnAddrs were present. Now we can do the deletes.
lh.Lock()
rm, ok := lh.addrMap[allVpnAddrs[0]]
if ok {
@@ -572,7 +567,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
am.unlockedSetHostnamesResults(hr)

for _, addrPort := range hr.GetAddrs() {
-if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) {
+if !lh.shouldAdd(vpnAddr, addrPort.Addr()) {
continue
}
switch {
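DeleteVpnAddrs above guards the delete with the static host list; the two sides differ on whether every address or only the first one is checked. A minimal sketch of the check-every-address variant, assuming the same map[netip.Addr]struct{} static list and a hypothetical guardedDelete helper:

package sketch

import "net/netip"

// guardedDelete reports whether a delete should proceed, i.e. none of the
// host's addresses are pinned by a static_host_map entry.
func guardedDelete(staticList map[netip.Addr]struct{}, allVpnAddrs []netip.Addr) bool {
	for _, addr := range allVpnAddrs {
		if _, ok := staticList[addr]; ok {
			// At least one address is static; leave the entry alone.
			return false
		}
	}
	return true
}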
@@ -634,30 +629,23 @@ func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool {
return len(calculatedV4) > 0 || len(calculatedV6) > 0
}

-// unlockedGetRemoteList assumes you have the lh lock
+// unlockedGetRemoteList
+// assumes you have the lh lock
func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
-// before we go and make a new remotelist, we need to make sure we don't have one for any of this set of vpnaddrs yet
-for i, addr := range allAddrs {
-am, ok := lh.addrMap[addr]
-if ok {
-if i != 0 {
-lh.addrMap[allAddrs[0]] = am
-}
-return am
+am, ok := lh.addrMap[allAddrs[0]]
+if !ok {
+am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) })
+for _, addr := range allAddrs {
+lh.addrMap[addr] = am
}
}

-am := NewRemoteList(allAddrs, lh.shouldAdd)
-for _, addr := range allAddrs {
-lh.addrMap[addr] = am
-}
return am
}

-func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
-allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
+func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
+allow := lh.GetRemoteAllowList().Allow(vpnAddr, to)
if lh.l.Level >= logrus.TraceLevel {
-lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
+lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow).
Trace("remoteAllowList.Allow")
}
if !allow {
@@ -712,22 +700,19 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
}

func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
-l := lh.GetLighthouses()
-for i := range l {
-if l[i] == vpnAddr {
-return true
-}
+if _, ok := lh.GetLighthouses()[vpnAddr]; ok {
+return true
}
return false
}

-func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
+// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake
+// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially
+func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool {
l := lh.GetLighthouses()
-for i := range vpnAddrs {
-for j := range l {
-if l[j] == vpnAddrs[i] {
-return true
-}
+for _, a := range vpnAddr {
+if _, ok := l[a]; ok {
+return true
}
}
return false
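unlockedGetRemoteList above keeps one RemoteList shared under every vpn address a host owns, all keyed into a single addrMap. A simplified sketch of that lookup-or-create pattern with stand-in types (the real RemoteList carries much more state, and the base-side version also probes the non-primary addresses):

package sketch

import "net/netip"

// remoteList stands in for nebula's RemoteList; only the shape matters here.
type remoteList struct{ vpnAddrs []netip.Addr }

// getOrCreate returns the list already registered for the host's primary
// address, or builds one and registers it under every address the host owns
// so a later query by any of them finds the same object.
func getOrCreate(addrMap map[netip.Addr]*remoteList, allAddrs []netip.Addr) *remoteList {
	if am, ok := addrMap[allAddrs[0]]; ok {
		return am
	}
	am := &remoteList{vpnAddrs: allAddrs}
	for _, addr := range allAddrs {
		addrMap[addr] = am
	}
	return am
}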
@@ -742,9 +727,11 @@ func (lh *LightHouse) startQueryWorker() {
nb := make([]byte, 12, 12)
out := make([]byte, mtu)

+synctrace.ChanDebugRecvLock("lighthouse-querychan")
for {
select {
case <-lh.ctx.Done():
+synctrace.ChanDebugRecvUnlock("lighthouse-querychan")
return
case addr := <-lh.queryChan:
lh.innerQueryServer(addr, nb, out)
@@ -769,7 +756,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
queried := 0
lighthouses := lh.GetLighthouses()

-for _, lhVpnAddr := range lighthouses {
+for lhVpnAddr := range lighthouses {
hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil {
v = hi.ConnectionState.myCert.Version()
@@ -887,7 +874,7 @@ func (lh *LightHouse) SendUpdate() {
updated := 0
lighthouses := lh.GetLighthouses()

-for _, lhVpnAddr := range lighthouses {
+for lhVpnAddr := range lighthouses {
var v cert.Version
hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil {
@@ -945,6 +932,7 @@ func (lh *LightHouse) SendUpdate() {
V4AddrPorts: v4,
V6AddrPorts: v6,
RelayVpnAddrs: relays,
+VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()),
},
}

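startQueryWorker above is a buffered-channel worker that drains lighthouse queries until the context is cancelled; one side wraps the channel operations in synctrace annotations. A plain standard-library sketch of the same worker shape, with the tracing calls omitted (names here are illustrative, not nebula's):

package sketch

import (
	"context"
	"fmt"
	"net/netip"
)

// startQueryWorker consumes addresses to query until ctx is cancelled.
// The channel would typically be buffered, e.g. make(chan netip.Addr, 64),
// mirroring the handshakes.query_buffer setting above.
func startQueryWorker(ctx context.Context, queryChan <-chan netip.Addr) {
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case addr := <-queryChan:
				// The real worker calls innerQueryServer here.
				fmt.Println("querying lighthouses for", addr)
			}
		}
	}()
}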
@@ -1064,19 +1052,19 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
return
}

-queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
-if err != nil {
+useVersion := cert.Version1
+var queryVpnAddr netip.Addr
+if n.Details.OldVpnAddr != 0 {
+b := [4]byte{}
+binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
+queryVpnAddr = netip.AddrFrom4(b)
+useVersion = 1
+} else if n.Details.VpnAddr != nil {
+queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
+useVersion = 2
+} else {
if lhh.l.Level >= logrus.DebugLevel {
-lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
-Debugln("Dropping malformed HostQuery")
-}
-return
-}
-if useVersion == cert.Version1 && queryVpnAddr.Is6() {
-// this case really shouldn't be possible to represent, but reject it anyway.
-if lhh.l.Level >= logrus.DebugLevel {
-lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
-Debugln("invalid vpn addr for v1 handleHostQuery")
+lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery")
}
return
}
@@ -1085,6 +1073,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply
if useVersion == cert.Version1 {
+if !queryVpnAddr.Is4() {
+return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
+}
b := queryVpnAddr.As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
} else {
@@ -1129,9 +1120,8 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
if ok {
whereToPunch = newDest
} else {
-if lhh.l.Level >= logrus.DebugLevel {
-lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
-}
+//TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee
+//choosing to do nothing for now, but maybe we return an error?
}
}

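Both sides of the handleHostQuery hunk decode the v1 OldVpnAddr field, a big-endian uint32, into a netip.Addr. A standalone round trip of that conversion:

package main

import (
	"encoding/binary"
	"fmt"
	"net/netip"
)

func main() {
	// Encode: a v4 netip.Addr packed into the uint32 wire field.
	addr := netip.MustParseAddr("10.128.0.42")
	b := addr.As4()
	old := binary.BigEndian.Uint32(b[:])

	// Decode: the reverse, as done when handling a HostQuery.
	var back [4]byte
	binary.BigEndian.PutUint32(back[:], old)
	fmt.Println(netip.AddrFrom4(back)) // 10.128.0.42
}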
@@ -1190,17 +1180,19 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
if !r.Is4() {
continue
}

b = r.As4()
n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:]))
}

} else if v == cert.Version2 {
for _, r := range c.relay.relay {
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
}

} else {
-if lhh.l.Level >= logrus.DebugLevel {
-lhh.l.WithField("version", v).Debug("unsupported protocol version")
-}
+//TODO: CERT-V2 don't panic
+panic("unsupported version")
}
}
}
@@ -1210,16 +1202,18 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
return
}

-certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
-if err != nil {
-if lhh.l.Level >= logrus.DebugLevel {
-lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
-}
-return
+lhh.lh.Lock()
+var certVpnAddr netip.Addr
+if n.Details.OldVpnAddr != 0 {
+b := [4]byte{}
+binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
+certVpnAddr = netip.AddrFrom4(b)
+} else if n.Details.VpnAddr != nil {
+certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
}
relays := n.Details.GetRelays()

-lhh.lh.Lock()
am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr})
am.Lock()
lhh.lh.Unlock()
@@ -1244,24 +1238,27 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
return
}

-// not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr
var detailsVpnAddr netip.Addr
-var useVersion cert.Version
-if n.Details.OldVpnAddr != 0 { //v1 always sets this field
+useVersion := cert.Version1
+if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
detailsVpnAddr = netip.AddrFrom4(b)
useVersion = cert.Version1
-} else if n.Details.VpnAddr != nil { //this field is "optional" in v2, but if it's set, we should enforce it
+} else if n.Details.VpnAddr != nil {
detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
useVersion = cert.Version2
} else {
-detailsVpnAddr = netip.Addr{}
-useVersion = cert.Version2
+if lhh.l.Level >= logrus.DebugLevel {
+lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification")
+}
+return
}

-//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
-if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
+//TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4?
+//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right?
+//Simple check that the host sent this not someone else
+if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
}
@@ -1275,24 +1272,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
am.Lock()
lhh.lh.Unlock()

-am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
-am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
+am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
+am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetRelay(fromVpnAddrs[0], relays)
am.Unlock()

n = lhh.resetMeta()
n.Type = NebulaMeta_HostUpdateNotificationAck
-switch useVersion {
-case cert.Version1:
+if useVersion == cert.Version1 {
if !fromVpnAddrs[0].Is4() {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
return
}
vpnAddrB := fromVpnAddrs[0].As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
-case cert.Version2:
-// do nothing, we want to send a blank message
-default:
+} else if useVersion == cert.Version2 {
+n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0])
+} else {
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
return
}
@@ -1310,20 +1307,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
//It's possible the lighthouse is communicating with us using a non primary vpn addr,
//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
+//maybe one day we'll have a better idea, if it matters.
if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
return
}

-detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
-if err != nil {
-if lhh.l.Level >= logrus.DebugLevel {
-lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
-}
-return
-}
-
empty := []byte{0}
-punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
+punch := func(vpnPeer netip.AddrPort) {
if !vpnPeer.IsValid() {
return
}
@@ -1335,38 +1325,48 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
}()

if lhh.l.Level >= logrus.DebugLevel {
+var logVpnAddr netip.Addr
+if n.Details.OldVpnAddr != 0 {
+b := [4]byte{}
+binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
+logVpnAddr = netip.AddrFrom4(b)
+} else if n.Details.VpnAddr != nil {
+logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
+}
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
}
}

-remoteAllowList := lhh.lh.GetRemoteAllowList()
for _, a := range n.Details.V4AddrPorts {
-b := protoV4AddrPortToNetAddrPort(a)
-if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
-punch(b, detailsVpnAddr)
-}
+punch(protoV4AddrPortToNetAddrPort(a))
}

for _, a := range n.Details.V6AddrPorts {
-b := protoV6AddrPortToNetAddrPort(a)
-if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
-punch(b, detailsVpnAddr)
-}
+punch(protoV6AddrPortToNetAddrPort(a))
}

// This sends a nebula test packet to the host trying to contact us. In the case
// of a double nat or other difficult scenario, this may help establish
// a tunnel.
if lhh.lh.punchy.GetRespond() {
+var queryVpnAddr netip.Addr
+if n.Details.OldVpnAddr != 0 {
+b := [4]byte{}
+binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
+queryVpnAddr = netip.AddrFrom4(b)
+} else if n.Details.VpnAddr != nil {
+queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
+}

go func() {
time.Sleep(lhh.lh.punchy.GetRespondDelay())
if lhh.l.Level >= logrus.DebugLevel {
-lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
+lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr)
}
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
// managed by a channel.
-w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}()
}
}
@@ -1445,17 +1445,3 @@ func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr,
}
return netip.Addr{}, false
}
-
-func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, error) {
-if d.OldVpnAddr != 0 {
-b := [4]byte{}
-binary.BigEndian.PutUint32(b[:], d.OldVpnAddr)
-detailsVpnAddr := netip.AddrFrom4(b)
-return detailsVpnAddr, cert.Version1, nil
-} else if d.VpnAddr != nil {
-detailsVpnAddr := protoAddrToNetAddr(d.VpnAddr)
-return detailsVpnAddr, cert.Version2, nil
-} else {
-return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
-}
-}
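The removed GetVpnAddrAndVersion helper above centralizes the OldVpnAddr/VpnAddr decoding that otherwise repeats inline in the handlers. A sketch of the same idea with trimmed-down, hypothetical types (the real NebulaMetaDetails is protobuf-generated and the versions come from the cert package):

package sketch

import (
	"encoding/binary"
	"errors"
	"net/netip"
)

// details is a stand-in for the protobuf message; only two fields matter here.
type details struct {
	oldVpnAddr uint32
	vpnAddr    *netip.Addr
}

var errBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr")

// getVpnAddrAndVersion returns the advertised address and which protocol
// version (1 or 2) implied it, or an error if neither field is set.
func (d *details) getVpnAddrAndVersion() (netip.Addr, int, error) {
	if d.oldVpnAddr != 0 {
		var b [4]byte
		binary.BigEndian.PutUint32(b[:], d.oldVpnAddr)
		return netip.AddrFrom4(b), 1, nil
	}
	if d.vpnAddr != nil {
		return *d.vpnAddr, 2, nil
	}
	return netip.Addr{}, 1, errBadDetailsVpnAddr
}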
@@ -14,7 +14,7 @@ import (
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
-"go.yaml.in/yaml/v3"
+"gopkg.in/yaml.v3"
)

func TestOldIPv4Only(t *testing.T) {
@@ -493,123 +493,3 @@ func Test_findNetworkUnion(t *testing.T) {
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok)
}
-
-func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) {
-l := test.NewLogger()
-
-myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
-
-testSameHostNotStatic := netip.MustParseAddr("10.128.0.41")
-testStaticHost := netip.MustParseAddr("10.128.0.42")
-//myVpnIp := netip.MustParseAddr("10.128.0.2")
-
-c := config.NewC(l)
-lh1 := "10.128.0.2"
-c.Settings["lighthouse"] = map[string]any{
-"hosts": []any{lh1},
-"interval": "1s",
-}
-
-c.Settings["listen"] = map[string]any{"port": 4242}
-c.Settings["static_host_map"] = map[string]any{
-lh1: []any{"1.1.1.1:4242"},
-"10.128.0.42": []any{"1.2.3.4:4242"},
-}
-
-myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
-nt := new(bart.Lite)
-nt.Insert(myVpnNet)
-cs := &CertState{
-myVpnNetworks: []netip.Prefix{myVpnNet},
-myVpnNetworksTable: nt,
-}
-lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
-require.NoError(t, err)
-lh.ifce = &mockEncWriter{}
-
-//test that we actually have the static entry:
-out := lh.Query(testStaticHost)
-assert.NotNil(t, out)
-assert.Equal(t, out.vpnAddrs[0], testStaticHost)
-out.Rebuild([]netip.Prefix{}) //why tho
-assert.Equal(t, out.addrs[0], myUdpAddr2)
-
-//bolt on a lower numbered primary IP
-am := lh.unlockedGetRemoteList([]netip.Addr{testStaticHost})
-am.vpnAddrs = []netip.Addr{testSameHostNotStatic, testStaticHost}
-lh.addrMap[testSameHostNotStatic] = am
-out.Rebuild([]netip.Prefix{}) //???
-
-//test that we actually have the static entry:
-out = lh.Query(testStaticHost)
-assert.NotNil(t, out)
-assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
-assert.Equal(t, out.vpnAddrs[1], testStaticHost)
-assert.Equal(t, out.addrs[0], myUdpAddr2)
-
-//test that we actually have the static entry for BOTH:
-out2 := lh.Query(testSameHostNotStatic)
-assert.Same(t, out2, out)
-
-//now do the delete
-lh.DeleteVpnAddrs([]netip.Addr{testSameHostNotStatic, testStaticHost})
-//verify
-out = lh.Query(testSameHostNotStatic)
-assert.NotNil(t, out)
-if out == nil {
-t.Fatal("expected non-nil query for the static host")
-}
-assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
-assert.Equal(t, out.vpnAddrs[1], testStaticHost)
-assert.Equal(t, out.addrs[0], myUdpAddr2)
-}
-
-func TestLighthouse_DeletesWork(t *testing.T) {
-l := test.NewLogger()
-
-myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
-testHost := netip.MustParseAddr("10.128.0.42")
-
-c := config.NewC(l)
-lh1 := "10.128.0.2"
-c.Settings["lighthouse"] = map[string]any{
-"hosts": []any{lh1},
-"interval": "1s",
-}
-
-c.Settings["listen"] = map[string]any{"port": 4242}
-c.Settings["static_host_map"] = map[string]any{
-lh1: []any{"1.1.1.1:4242"},
-}
-
-myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
-nt := new(bart.Lite)
-nt.Insert(myVpnNet)
-cs := &CertState{
-myVpnNetworks: []netip.Prefix{myVpnNet},
-myVpnNetworksTable: nt,
-}
-lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
-require.NoError(t, err)
-lh.ifce = &mockEncWriter{}
-
-//insert the host
-am := lh.unlockedGetRemoteList([]netip.Addr{testHost})
-am.vpnAddrs = []netip.Addr{testHost}
-am.addrs = []netip.AddrPort{myUdpAddr2}
-lh.addrMap[testHost] = am
-am.Rebuild([]netip.Prefix{}) //???
-
-//test that we actually have the entry:
-out := lh.Query(testHost)
-assert.NotNil(t, out)
-assert.Equal(t, out.vpnAddrs[0], testHost)
-out.Rebuild([]netip.Prefix{}) //why tho
-assert.Equal(t, out.addrs[0], myUdpAddr2)
-
-//now do the delete
-lh.DeleteVpnAddrs([]netip.Addr{testHost})
-//verify
-out = lh.Query(testHost)
-assert.Nil(t, out)
-}
50 main.go
@@ -13,7 +13,7 @@ import (
"github.com/slackhq/nebula/sshd"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
-"go.yaml.in/yaml/v3"
+"gopkg.in/yaml.v3"
)

type m = map[string]any
@@ -75,8 +75,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, c)
if err != nil {
-l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
-sshStart = nil
+return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
}
}

@@ -186,7 +185,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg

hostMap := NewHostMapFromConfig(l, c)
punchy := NewPunchyFromConfig(l, c)
-connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
@@ -222,26 +220,31 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
}

+checkInterval := c.GetInt("timers.connection_alive_interval", 5)
+pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
+
ifConfig := &InterfaceConfig{
HostMap: hostMap,
Inside: tun,
Outside: udpConns[0],
pki: pki,
Firewall: fw,
ServeDns: serveDns,
HandshakeManager: handshakeManager,
-connectionManager: connManager,
lightHouse: lightHouse,
+checkInterval: time.Second * time.Duration(checkInterval),
+pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery),
reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
DropMulticast: c.GetBool("tun.drop_multicast", false),
routines: routines,
MessageMetrics: messageMetrics,
version: buildVersion,
relayManager: NewRelayManager(ctx, l, hostMap, c),
punchy: punchy,

ConntrackCacheTimeout: conntrackCacheTimeout,
l: l,
}
@@ -293,6 +296,5 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
statsStart,
dnsStart,
lightHouse.StartUpdateWorker,
-connManager.Start,
}, nil
}
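The main.go hunk above reads the connection liveness timers from config, with defaults, before building InterfaceConfig. A small sketch of that read-with-default pattern over a plain map, standing in for nebula's config.C.GetInt:

package main

import (
	"fmt"
	"time"
)

// getInt is a stand-in for config lookups: return the configured value or a default.
func getInt(settings map[string]any, key string, def int) int {
	if v, ok := settings[key].(int); ok {
		return v
	}
	return def
}

func main() {
	settings := map[string]any{"timers.connection_alive_interval": 7}
	checkInterval := getInt(settings, "timers.connection_alive_interval", 5)
	pendingDeletionInterval := getInt(settings, "timers.pending_deletion_interval", 10)
	fmt.Println(time.Duration(checkInterval)*time.Second, time.Duration(pendingDeletionInterval)*time.Second)
}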
294 outside.go
@@ -7,7 +7,6 @@ import (
"time"

"github.com/google/gopacket/layers"
-"github.com/slackhq/nebula/packet"
"golang.org/x/net/ipv6"

"github.com/sirupsen/logrus"
@@ -20,7 +19,7 @@ const (
minFwPacketLen = 4
)

-func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
+func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet)
if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
@@ -62,7 +61,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []

switch h.Subtype {
case header.MessageNone:
-if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, now) {
+if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
return
}
case header.MessageRelay:
@@ -82,7 +81,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
// Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, ip)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
-f.connectionManager.In(hostinfo)
+f.connectionManager.In(hostinfo.localIndexId)
f.connectionManager.RelayUsed(h.RemoteIndex)

relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
@@ -97,7 +96,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
-f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
+f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return
case ForwardingType:
// Find the target HostInfo relay object
@@ -214,218 +213,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []

f.handleHostRoaming(hostinfo, ip)

-f.connectionManager.In(hostinfo)
+f.connectionManager.In(hostinfo.localIndexId)
-}
-
-func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
-for i, pkt := range packets {
-out[i].Scratch = out[i].Scratch[:0]
-ip := pkt.AddrPort()
-
-//l.Error("in packet ", header, packet[HeaderLen:])
-if ip.IsValid() {
-if f.myVpnNetworksTable.Contains(ip.Addr()) {
-if f.l.Level >= logrus.DebugLevel {
-f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
-}
-return
-}
-}
-
-//todo per-segment!
-for segment := range pkt.Segments() {
-
-err := h.Parse(segment)
-if err != nil {
-// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
-if len(segment) > 1 {
-f.l.WithField("packet", pkt).Infof("Error while parsing inbound packet from %s: %s", ip, err)
-}
-return
-}
-
-var hostinfo *HostInfo
-// verify if we've seen this index before, otherwise respond to the handshake initiation
-if h.Type == header.Message && h.Subtype == header.MessageRelay {
-hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
-} else {
-hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
-}
-
-var ci *ConnectionState
-if hostinfo != nil {
-ci = hostinfo.ConnectionState
-}
-
-switch h.Type {
-case header.Message:
-// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
-if !f.handleEncrypted(ci, ip, h) {
-return
-}
-
-switch h.Subtype {
-case header.MessageNone:
-if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache, now) {
-return
-}
-case header.MessageRelay:
-// The entire body is sent as AD, not encrypted.
-// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
-// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
-// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
-// which will gracefully fail in the DecryptDanger call.
-signedPayload := segment[:len(segment)-hostinfo.ConnectionState.dKey.Overhead()]
-signatureValue := segment[len(segment)-hostinfo.ConnectionState.dKey.Overhead():]
-out[i].Scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(out[i].Scratch, signedPayload, signatureValue, h.MessageCounter, nb)
-if err != nil {
-return
-}
-// Successfully validated the thing. Get rid of the Relay header.
-signedPayload = signedPayload[header.Len:]
-// Pull the Roaming parts up here, and return in all call paths.
-f.handleHostRoaming(hostinfo, ip)
-// Track usage of both the HostInfo and the Relay for the received & authenticated packet
-f.connectionManager.In(hostinfo)
-f.connectionManager.RelayUsed(h.RemoteIndex)
-
-relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
-if !ok {
-// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
-// its internal mapping. This should never happen.
-hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
-return
-}
-
-switch relay.Type {
-case TerminalType:
-// If I am the target of this relay, process the unwrapped packet
-// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
-f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
-return
-case ForwardingType:
-// Find the target HostInfo relay object
-targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
-if err != nil {
-hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
-return
-}
-
-// If that relay is Established, forward the payload through it
-if targetRelay.State == Established {
-switch targetRelay.Type {
-case ForwardingType:
-// Forward this packet through the relay tunnel
-// Find the target HostInfo
-f.SendVia(targetHI, targetRelay, signedPayload, nb, out[i].Scratch, false)
-return
-case TerminalType:
-hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
-}
-} else {
-hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
-return
-}
-}
-}
-
-case header.LightHouse:
-f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-if !f.handleEncrypted(ci, ip, h) {
-return
-}
-
-d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
-if err != nil {
-hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
-WithField("packet", segment).
-Error("Failed to decrypt lighthouse packet")
-return
-}
-
-lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
-
-// Fallthrough to the bottom to record incoming traffic
-
-case header.Test:
-f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-if !f.handleEncrypted(ci, ip, h) {
-return
-}
-
-d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
-if err != nil {
-hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
-WithField("packet", segment).
-Error("Failed to decrypt test packet")
-return
-}
-
-if h.Subtype == header.TestRequest {
-// This testRequest might be from TryPromoteBest, so we should roam
-// to the new IP address before responding
-f.handleHostRoaming(hostinfo, ip)
-f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out[i].Scratch)
-}
-
-// Fallthrough to the bottom to record incoming traffic
-
-// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
-// are unauthenticated
-
-case header.Handshake:
-f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-f.handshakeManager.HandleIncoming(ip, nil, segment, h)
-return
-
-case header.RecvError:
-f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-f.handleRecvError(ip, h)
-return
-
-case header.CloseTunnel:
-f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-if !f.handleEncrypted(ci, ip, h) {
-return
-}
-
-hostinfo.logger(f.l).WithField("udpAddr", ip).
-Info("Close tunnel received, tearing down.")
-
-f.closeTunnel(hostinfo)
-return
-
-case header.Control:
-if !f.handleEncrypted(ci, ip, h) {
-return
-}
-
-d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
-if err != nil {
-hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
-WithField("packet", segment).
-Error("Failed to decrypt Control packet")
-return
-}
-
-f.relayManager.HandleControlMsg(hostinfo, d, f)
-
-default:
-f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
-return
-}
-
-f.handleHostRoaming(hostinfo, ip)
-
-f.connectionManager.In(hostinfo)
-
-}
-_, err := f.readers[q].WriteOne(out[i], false, q)
-if err != nil {
-f.l.WithError(err).Error("Failed to write packet")
-}
-}
}

// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
@@ -466,18 +254,16 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort

}

-// handleEncrypted returns true if a packet should be processed, false otherwise
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
-// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
-if ci == nil {
+// If connectionstate exists and the replay protector allows, process packet
+// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
+if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
if addr.IsValid() {
f.maybeSendRecvError(addr, h.RemoteIndex)
+return false
+} else {
+return false
}
-return false
-}
-// If the window check fails, refuse to process the packet, but don't send a recv error
-if !ci.window.Check(f.l, h.MessageCounter) {
-return false
}

return true
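The handleEncrypted hunk folds the missing-connection-state check and the replay-window check into one guard on one side, while the other keeps them separate so a replay-window miss does not trigger a recv error. A sketch of the two-step shape with stand-in types (not nebula's actual interfaces):

package sketch

import "net/netip"

// replayWindow stands in for the anti-replay window; only Check matters here.
type replayWindow interface{ Check(counter uint64) bool }

// shouldProcess mirrors the two-step guard: missing state asks the peer to
// re-handshake via a recv error, while a replay-window miss just drops the packet.
func shouldProcess(haveState bool, window replayWindow, counter uint64, addr netip.AddrPort, sendRecvError func(netip.AddrPort)) bool {
	if !haveState {
		if addr.IsValid() {
			sendRecvError(addr)
		}
		return false
	}
	if !window.Check(counter) {
		return false
	}
	return true
}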
@@ -677,55 +463,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return out, nil
}

-func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
-var err error
-
-seg, err := f.readers[q].AllocSeg(out, q)
-if err != nil {
-f.l.WithError(err).Errorln("decryptToTunDelayWrite: failed to allocate segment")
-return false
-}
-
-out.SegmentPayloads[seg] = out.SegmentPayloads[seg][:0]
-out.SegmentPayloads[seg], err = hostinfo.ConnectionState.dKey.DecryptDanger(out.SegmentPayloads[seg], inSegment[:header.Len], inSegment[header.Len:], messageCounter, nb)
-if err != nil {
-hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
-return false
-}
-
-err = newPacket(out.SegmentPayloads[seg], true, fwPacket)
-if err != nil {
-hostinfo.logger(f.l).WithError(err).WithField("packet", out).
-Warnf("Error while validating inbound packet")
-return false
-}
-
-if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
-hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
-Debugln("dropping out of window packet")
-return false
-}
-
-dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now)
-if dropReason != nil {
-// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
-// This gives us a buffer to build the reject packet in
-f.rejectOutside(out.SegmentPayloads[seg], hostinfo.ConnectionState, hostinfo, nb, inSegment, q)
-if f.l.Level >= logrus.DebugLevel {
-hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
-WithField("reason", dropReason).
-Debugln("dropping inbound packet")
-}
-return false
-}
-
-f.connectionManager.In(hostinfo)
-pkt.OutLen += len(inSegment)
-out.Segments[seg] = out.Segments[seg][:len(out.SegmentHeaders[seg])+len(out.SegmentPayloads[seg])]
-return true
-}
-
-func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
+func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
var err error

out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
@@ -747,7 +485,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false
}

-dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now)
+dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in
@@ -760,7 +498,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false
}

-f.connectionManager.In(hostinfo)
+f.connectionManager.In(hostinfo.localIndexId)
_, err = f.readers[q].Write(out)
if err != nil {
f.l.WithError(err).Error("Failed to write to tun")
@@ -799,6 +537,10 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
return
}

+if !hostinfo.RecvErrorExceeded() {
+return
+}
+
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
return
@@ -1,16 +1,17 @@
 package overlay

 import (
+"io"
 "net/netip"

 "github.com/slackhq/nebula/routing"
 )

 type Device interface {
-TunDev
+io.ReadWriteCloser
 Activate() error
 Networks() []netip.Prefix
 Name() string
 RoutesFor(netip.Addr) routing.Gateways
-NewMultiQueueReader() (TunDev, error)
+NewMultiQueueReader() (io.ReadWriteCloser, error)
 }
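The hunk above restores the plain io.ReadWriteCloser-based Device contract. As a quick illustration of what a backend now has to provide, here is a do-nothing implementation; it is not part of the diff, and everything beyond the interface methods themselves (the type name, the field, the compile-time check) is made up for the sketch.

    // Not part of the diff: a minimal no-op Device, just to make the
    // restored interface concrete.
    package overlay

    import (
        "io"
        "net/netip"

        "github.com/slackhq/nebula/routing"
    )

    type nopDevice struct {
        io.ReadWriteCloser
        nets []netip.Prefix
    }

    func (d *nopDevice) Activate() error                       { return nil }
    func (d *nopDevice) Networks() []netip.Prefix              { return d.nets }
    func (d *nopDevice) Name() string                          { return "nop0" }
    func (d *nopDevice) RoutesFor(netip.Addr) routing.Gateways { return nil }

    func (d *nopDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
        return d.ReadWriteCloser, nil
    }

    // Compile-time check that nopDevice satisfies Device.
    var _ Device = (*nopDevice)(nil)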
@@ -1,91 +0,0 @@
-package eventfd
-
-import (
-"encoding/binary"
-"syscall"
-
-"golang.org/x/sys/unix"
-)
-
-type EventFD struct {
-fd int
-buf [8]byte
-}
-
-func New() (EventFD, error) {
-fd, err := unix.Eventfd(0, unix.EFD_NONBLOCK)
-if err != nil {
-return EventFD{}, err
-}
-return EventFD{
-fd: fd,
-buf: [8]byte{},
-}, nil
-}
-
-func (e *EventFD) Kick() error {
-binary.LittleEndian.PutUint64(e.buf[:], 1) //is this right???
-_, err := syscall.Write(int(e.fd), e.buf[:])
-return err
-}
-
-func (e *EventFD) Close() error {
-if e.fd != 0 {
-return unix.Close(e.fd)
-}
-return nil
-}
-
-func (e *EventFD) FD() int {
-return e.fd
-}
-
-type Epoll struct {
-fd int
-buf [8]byte
-events []syscall.EpollEvent
-}
-
-func NewEpoll() (Epoll, error) {
-fd, err := unix.EpollCreate1(0)
-if err != nil {
-return Epoll{}, err
-}
-return Epoll{
-fd: fd,
-buf: [8]byte{},
-events: make([]syscall.EpollEvent, 1),
-}, nil
-}
-
-func (ep *Epoll) AddEvent(fdToAdd int) error {
-event := syscall.EpollEvent{
-Events: syscall.EPOLLIN,
-Fd: int32(fdToAdd),
-}
-return syscall.EpollCtl(ep.fd, syscall.EPOLL_CTL_ADD, fdToAdd, &event)
-}
-
-func (ep *Epoll) Block() (int, error) {
-n, err := syscall.EpollWait(ep.fd, ep.events, -1)
-if err != nil {
-//goland:noinspection GoDirectComparisonOfErrors
-if err == syscall.EINTR {
-return 0, nil //??
-}
-return -1, err
-}
-return n, nil
-}
-
-func (ep *Epoll) Clear() error {
-_, err := syscall.Read(int(ep.events[0].Fd), ep.buf[:])
-return err
-}
-
-func (ep *Epoll) Close() error {
-if ep.fd != 0 {
-return unix.Close(ep.fd)
-}
-return nil
-}
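The deleted package above wrapped a Linux eventfd plus an epoll instance so one side could Kick() a waiter that Block()s until woken. The same kick/wait mechanism can be shown with golang.org/x/sys/unix alone; the sketch below is Linux-only, self-contained, and only mirrors the removed helper, it is not taken from the repository.

    // Minimal eventfd + epoll kick/wait demo (Linux only).
    package main

    import (
        "encoding/binary"
        "fmt"

        "golang.org/x/sys/unix"
    )

    func main() {
        efd, err := unix.Eventfd(0, unix.EFD_NONBLOCK)
        if err != nil {
            panic(err)
        }
        defer unix.Close(efd)

        epfd, err := unix.EpollCreate1(0)
        if err != nil {
            panic(err)
        }
        defer unix.Close(epfd)

        ev := unix.EpollEvent{Events: unix.EPOLLIN, Fd: int32(efd)}
        if err := unix.EpollCtl(epfd, unix.EPOLL_CTL_ADD, efd, &ev); err != nil {
            panic(err)
        }

        // Kick: write the 8-byte little-endian value 1 to the eventfd counter.
        buf := make([]byte, 8)
        binary.LittleEndian.PutUint64(buf, 1)
        if _, err := unix.Write(efd, buf); err != nil {
            panic(err)
        }

        // Block until the eventfd is readable, then drain the counter.
        events := make([]unix.EpollEvent, 1)
        if _, err := unix.EpollWait(epfd, events, -1); err != nil {
            panic(err)
        }
        if _, err := unix.Read(efd, buf); err != nil {
            panic(err)
        }
        fmt.Println("kicked and observed")
    }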
@@ -1,30 +1,15 @@
 package overlay

 import (
-"fmt"
-"io"
-"net"
 "net/netip"

 "github.com/sirupsen/logrus"
 "github.com/slackhq/nebula/config"
-"github.com/slackhq/nebula/packet"
 "github.com/slackhq/nebula/util"
 )

 const DefaultMTU = 1300

-type TunDev interface {
-io.WriteCloser
-ReadMany(x []*packet.VirtIOPacket, q int) (int, error)
-
-//todo this interface sux
-AllocSeg(pkt *packet.OutPacket, q int) (int, error)
-WriteOne(x *packet.OutPacket, kick bool, q int) (int, error)
-WriteMany(x []*packet.OutPacket, q int) (int, error)
-RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error
-}
-
 // TODO: We may be able to remove routines
 type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)

@@ -39,11 +24,11 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref
 }
 }

-//func NewFdDeviceFromConfig(fd *int) DeviceFactory {
+func NewFdDeviceFromConfig(fd *int) DeviceFactory {
-// return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
+return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
-// return newTunFromFd(c, l, *fd, vpnNetworks)
+return newTunFromFd(c, l, *fd, vpnNetworks)
-// }
+}
-//}
+}

 func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
 if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
@@ -85,51 +70,3 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {

 return removed
 }
-
-func prefixToMask(prefix netip.Prefix) netip.Addr {
-pLen := 128
-if prefix.Addr().Is4() {
-pLen = 32
-}
-
-addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
-return addr
-}
-
-func flipBytes(b []byte) []byte {
-for i := 0; i < len(b); i++ {
-b[i] ^= 0xFF
-}
-return b
-}
-func orBytes(a []byte, b []byte) []byte {
-ret := make([]byte, len(a))
-for i := 0; i < len(a); i++ {
-ret[i] = a[i] | b[i]
-}
-return ret
-}
-
-func getBroadcast(cidr netip.Prefix) netip.Addr {
-broadcast, _ := netip.AddrFromSlice(
-orBytes(
-cidr.Addr().AsSlice(),
-flipBytes(prefixToMask(cidr).AsSlice()),
-),
-)
-return broadcast
-}
-
-func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
-for _, gateway := range gateways {
-if dest.Addr().Is4() && gateway.Addr().Is4() {
-return gateway, nil
-}
-
-if dest.Addr().Is6() && gateway.Addr().Is6() {
-return gateway, nil
-}
-}
-
-return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
-}
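The helpers dropped at the end of this hunk (prefixToMask, flipBytes, orBytes, getBroadcast) computed a prefix's netmask and broadcast address; only prefixToMask survives, moving into the darwin tun in the next file. For reference, here is a condensed, self-contained version of that math; the combined broadcast helper and the example prefix are illustrative, not the removed code verbatim.

    // Netmask and broadcast address for a prefix, using only the standard library.
    package main

    import (
        "fmt"
        "net"
        "net/netip"
    )

    func prefixToMask(prefix netip.Prefix) netip.Addr {
        pLen := 128
        if prefix.Addr().Is4() {
            pLen = 32
        }
        addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
        return addr
    }

    func broadcast(cidr netip.Prefix) netip.Addr {
        // network address OR inverted mask == broadcast address
        mask := prefixToMask(cidr).AsSlice()
        base := cidr.Masked().Addr().AsSlice()
        out := make([]byte, len(base))
        for i := range base {
            out[i] = base[i] | ^mask[i]
        }
        addr, _ := netip.AddrFromSlice(out)
        return addr
    }

    func main() {
        p := netip.MustParsePrefix("192.168.10.0/24")
        fmt.Println(prefixToMask(p)) // 255.255.255.0
        fmt.Println(broadcast(p))    // 192.168.10.255
    }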
@@ -7,6 +7,7 @@ import (
 "errors"
 "fmt"
 "io"
+"net"
 "net/netip"
 "os"
 "sync/atomic"
@@ -294,6 +295,7 @@ func (t *tun) activate6(network netip.Prefix) error {
 Vltime: 0xffffffff,
 Pltime: 0xffffffff,
 },
+//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
 Flags: _IN6_IFF_NODAD,
 }

@@ -552,3 +554,13 @@ func (t *tun) Name() string {
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
 }
+
+func prefixToMask(prefix netip.Prefix) netip.Addr {
+pLen := 128
+if prefix.Addr().Is4() {
+pLen = 32
+}
+
+addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
+return addr
+}
@@ -9,8 +9,6 @@ import (
 "github.com/rcrowley/go-metrics"
 "github.com/sirupsen/logrus"
 "github.com/slackhq/nebula/iputil"
-"github.com/slackhq/nebula/overlay/virtqueue"
-"github.com/slackhq/nebula/packet"
 "github.com/slackhq/nebula/routing"
 )

@@ -24,10 +22,6 @@ type disabledTun struct {
 l *logrus.Logger
 }

-func (*disabledTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
-return nil
-}
-
 func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
 tun := &disabledTun{
 vpnNetworks: vpnNetworks,
@@ -46,10 +40,6 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo
 return tun
 }

-func (*disabledTun) GetQueues() []*virtqueue.SplitQueue {
-return nil
-}
-
 func (*disabledTun) Activate() error {
 return nil
 }
@@ -115,23 +105,7 @@ func (t *disabledTun) Write(b []byte) (int, error) {
 return len(b), nil
 }

-func (t *disabledTun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
+func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
-return 0, fmt.Errorf("tun_disabled: AllocSeg not implemented")
-}
-
-func (t *disabledTun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
-return 0, fmt.Errorf("tun_disabled: WriteOne not implemented")
-}
-
-func (t *disabledTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
-return 0, fmt.Errorf("tun_disabled: WriteMany not implemented")
-}
-
-func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
-return t.Read(b[0].Payload)
-}
-
-func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
 return t, nil
 }

@@ -10,9 +10,11 @@ import (
 "io"
 "io/fs"
 "net/netip"
+"os"
+"os/exec"
+"strconv"
 "sync/atomic"
 "syscall"
-"time"
 "unsafe"

 "github.com/gaissmai/bart"
@@ -20,18 +22,12 @@ import (
 "github.com/slackhq/nebula/config"
 "github.com/slackhq/nebula/routing"
 "github.com/slackhq/nebula/util"
-netroute "golang.org/x/net/route"
-"golang.org/x/sys/unix"
 )

 const (
 // FIODGNAME is defined in sys/sys/filio.h on FreeBSD
 // For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
 FIODGNAME = 0x80106678
-TUNSIFMODE = 0x8004745e
-TUNSIFHEAD = 0x80047460
-OSIOCAIFADDR_IN6 = 0x8088691b
-IN6_IFF_NODAD = 0x0020
 )

 type fiodgnameArg struct {
@@ -41,159 +37,43 @@ type fiodgnameArg struct {
 }

 type ifreqRename struct {
-Name [unix.IFNAMSIZ]byte
+Name [16]byte
 Data uintptr
 }

 type ifreqDestroy struct {
-Name [unix.IFNAMSIZ]byte
+Name [16]byte
 pad [16]byte
 }

-type ifReq struct {
-Name [unix.IFNAMSIZ]byte
-Flags uint16
-}
-
-type ifreqMTU struct {
-Name [unix.IFNAMSIZ]byte
-MTU int32
-}
-
-type addrLifetime struct {
-Expire uint64
-Preferred uint64
-Vltime uint32
-Pltime uint32
-}
-
-type ifreqAlias4 struct {
-Name [unix.IFNAMSIZ]byte
-Addr unix.RawSockaddrInet4
-DstAddr unix.RawSockaddrInet4
-MaskAddr unix.RawSockaddrInet4
-VHid uint32
-}
-
-type ifreqAlias6 struct {
-Name [unix.IFNAMSIZ]byte
-Addr unix.RawSockaddrInet6
-DstAddr unix.RawSockaddrInet6
-PrefixMask unix.RawSockaddrInet6
-Flags uint32
-Lifetime addrLifetime
-VHid uint32
-}
-
 type tun struct {
 Device string
 vpnNetworks []netip.Prefix
 MTU int
 Routes atomic.Pointer[[]Route]
 routeTree atomic.Pointer[bart.Table[routing.Gateways]]
-linkAddr *netroute.LinkAddr
 l *logrus.Logger
-devFd int
-}
-
-func (t *tun) Read(to []byte) (int, error) {
+io.ReadWriteCloser
-// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
-if t.devFd < 0 {
-return -1, syscall.EINVAL
-}
-
-// first 4 bytes is protocol family, in network byte order
-head := make([]byte, 4)
-
-iovecs := []syscall.Iovec{
-{&head[0], 4},
-{&to[0], uint64(len(to))},
-}
-
-n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
-
-var err error
-if errno != 0 {
-err = syscall.Errno(errno)
-} else {
-err = nil
-}
-// fix bytes read number to exclude header
-bytesRead := int(n)
-if bytesRead < 0 {
-return bytesRead, err
-} else if bytesRead < 4 {
-return 0, err
-} else {
-return bytesRead - 4, err
-}
-}
-
-// Write is only valid for single threaded use
-func (t *tun) Write(from []byte) (int, error) {
-// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
-if t.devFd < 0 {
-return -1, syscall.EINVAL
-}
-
-if len(from) <= 1 {
-return 0, syscall.EIO
-}
-ipVer := from[0] >> 4
-var head []byte
-// first 4 bytes is protocol family, in network byte order
-if ipVer == 4 {
-head = []byte{0, 0, 0, syscall.AF_INET}
-} else if ipVer == 6 {
-head = []byte{0, 0, 0, syscall.AF_INET6}
-} else {
-return 0, fmt.Errorf("unable to determine IP version from packet")
-}
-iovecs := []syscall.Iovec{
-{&head[0], 4},
-{&from[0], uint64(len(from))},
-}
-
-n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
-
-var err error
-if errno != 0 {
-err = syscall.Errno(errno)
-} else {
-err = nil
-}
-
-return int(n) - 4, err
 }

 func (t *tun) Close() error {
-if t.devFd >= 0 {
+if t.ReadWriteCloser != nil {
-err := syscall.Close(t.devFd)
+if err := t.ReadWriteCloser.Close(); err != nil {
+return err
+}
+
+s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
 if err != nil {
-t.l.WithError(err).Error("Error closing device")
+return err
 }
-t.devFd = -1
+defer syscall.Close(s)

-c := make(chan struct{})
+ifreq := ifreqDestroy{Name: t.deviceBytes()}
-go func() {
-// destroying the interface can block if a read() is still pending. Do this asynchronously.
-defer close(c)
-s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
-if err == nil {
-defer syscall.Close(s)
-ifreq := ifreqDestroy{Name: t.deviceBytes()}
-err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
-}
-if err != nil {
-t.l.WithError(err).Error("Error destroying tunnel")
-}
-}()
-
-// wait up to 1 second so we start blocking at the ioctl
+// Destroy the interface
-select {
+err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
-case <-c:
+return err
-case <-time.After(1 * time.Second):
-}
 }

 return nil
@@ -205,37 +85,32 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun,

 func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 // Try to open existing tun device
-var fd int
+var file *os.File
 var err error
 deviceName := c.GetString("tun.dev", "")
 if deviceName != "" {
-fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
+file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
 }
 if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
 // If the device doesn't already exist, request a new one and rename it
-fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
+file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0)
 }
 if err != nil {
 return nil, err
 }

-// Read the name of the interface
+rawConn, err := file.SyscallConn()
+if err != nil {
+return nil, fmt.Errorf("SyscallConn: %v", err)
+}
+
 var name [16]byte
-arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
+var ctrlErr error
-ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
+rawConn.Control(func(fd uintptr) {
+// Read the name of the interface
-if ctrlErr == nil {
+arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
-// set broadcast mode and multicast
+ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
-ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
+})
-ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
-}
-
-if ctrlErr == nil {
-// turn on link-layer mode, to support ipv6
-ifhead := uint32(1)
-ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
-}
-
 if ctrlErr != nil {
 return nil, err
 }
@@ -247,7 +122,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (

 // If the name doesn't match the desired interface name, rename it now
 if ifName != deviceName {
-s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
+s, err := syscall.Socket(
+syscall.AF_INET,
+syscall.SOCK_DGRAM,
+syscall.IPPROTO_IP,
+)
 if err != nil {
 return nil, err
 }
@@ -270,11 +149,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
 }

 t := &tun{
-Device: deviceName,
+ReadWriteCloser: file,
-vpnNetworks: vpnNetworks,
+Device: deviceName,
-MTU: c.GetInt("tun.mtu", DefaultMTU),
+vpnNetworks: vpnNetworks,
-l: l,
+MTU: c.GetInt("tun.mtu", DefaultMTU),
-devFd: fd,
+l: l,
 }

 err = t.reload(c, true)
@@ -293,111 +172,38 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
 }

 func (t *tun) addIp(cidr netip.Prefix) error {
-if cidr.Addr().Is4() {
+var err error
-ifr := ifreqAlias4{
+// TODO use syscalls instead of exec.Command
-Name: t.deviceBytes(),
+cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
-Addr: unix.RawSockaddrInet4{
+t.l.Debug("command: ", cmd.String())
-Len: unix.SizeofSockaddrInet4,
+if err = cmd.Run(); err != nil {
-Family: unix.AF_INET,
+return fmt.Errorf("failed to run 'ifconfig': %s", err)
-Addr: cidr.Addr().As4(),
-},
-DstAddr: unix.RawSockaddrInet4{
-Len: unix.SizeofSockaddrInet4,
-Family: unix.AF_INET,
-Addr: getBroadcast(cidr).As4(),
-},
-MaskAddr: unix.RawSockaddrInet4{
-Len: unix.SizeofSockaddrInet4,
-Family: unix.AF_INET,
-Addr: prefixToMask(cidr).As4(),
-},
-VHid: 0,
-}
-s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
-if err != nil {
-return err
-}
-defer syscall.Close(s)
-// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
-if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
-return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
-}
-return nil
 }

-if cidr.Addr().Is6() {
+cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
-ifr := ifreqAlias6{
+t.l.Debug("command: ", cmd.String())
-Name: t.deviceBytes(),
+if err = cmd.Run(); err != nil {
-Addr: unix.RawSockaddrInet6{
+return fmt.Errorf("failed to run 'route add': %s", err)
-Len: unix.SizeofSockaddrInet6,
-Family: unix.AF_INET6,
-Addr: cidr.Addr().As16(),
-},
-PrefixMask: unix.RawSockaddrInet6{
-Len: unix.SizeofSockaddrInet6,
-Family: unix.AF_INET6,
-Addr: prefixToMask(cidr).As16(),
-},
-Lifetime: addrLifetime{
-Expire: 0,
-Preferred: 0,
-Vltime: 0xffffffff,
-Pltime: 0xffffffff,
-},
-Flags: IN6_IFF_NODAD,
-}
-s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
-if err != nil {
-return err
-}
-defer syscall.Close(s)
-
-if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
-return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
-}
-return nil
 }

-return fmt.Errorf("unknown address type %v", cidr)
+cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
+t.l.Debug("command: ", cmd.String())
+if err = cmd.Run(); err != nil {
+return fmt.Errorf("failed to run 'ifconfig': %s", err)
+}
+
+// Unsafe path routes
+return t.addRoutes(false)
 }

 func (t *tun) Activate() error {
-// Setup our default MTU
-err := t.setMTU()
-if err != nil {
-return err
-}
-
-linkAddr, err := getLinkAddr(t.Device)
-if err != nil {
-return err
-}
-if linkAddr == nil {
-return fmt.Errorf("unable to discover link_addr for tun interface")
-}
-t.linkAddr = linkAddr
-
 for i := range t.vpnNetworks {
 err := t.addIp(t.vpnNetworks[i])
 if err != nil {
 return err
 }
 }
+return nil
-return t.addRoutes(false)
-}
-
-func (t *tun) setMTU() error {
-// Set the MTU on the device
-s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
-if err != nil {
-return err
-}
-defer syscall.Close(s)
-
-ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
-err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
-return err
 }

 func (t *tun) reload(c *config.C, initial bool) error {
@@ -462,16 +268,15 @@ func (t *tun) addRoutes(logErrors bool) error {
 continue
 }

-err := addRoute(r.Cidr, t.linkAddr)
+cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
-if err != nil {
+t.l.Debug("command: ", cmd.String())
-retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
+if err := cmd.Run(); err != nil {
+retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
 if logErrors {
 retErr.Log(t.l)
 } else {
 return retErr
 }
-} else {
-t.l.WithField("route", r).Info("Added route")
 }
 }

@@ -484,8 +289,9 @@ func (t *tun) removeRoutes(routes []Route) error {
 continue
 }

-err := delRoute(r.Cidr, t.linkAddr)
+cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
-if err != nil {
+t.l.Debug("command: ", cmd.String())
+if err := cmd.Run(); err != nil {
 t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 } else {
 t.l.WithField("route", r).Info("Removed route")
@@ -500,120 +306,3 @@ func (t *tun) deviceBytes() (o [16]byte) {
 }
 return
 }
-
-func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
-sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-if err != nil {
-return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
-}
-defer unix.Close(sock)
-
-route := &netroute.RouteMessage{
-Version: unix.RTM_VERSION,
-Type: unix.RTM_ADD,
-Flags: unix.RTF_UP,
-Seq: 1,
-}
-
-if prefix.Addr().Is4() {
-route.Addrs = []netroute.Addr{
-unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
-unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
-unix.RTAX_GATEWAY: gateway,
-}
-} else {
-route.Addrs = []netroute.Addr{
-unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
-unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
-unix.RTAX_GATEWAY: gateway,
-}
-}
-
-data, err := route.Marshal()
-if err != nil {
-return fmt.Errorf("failed to create route.RouteMessage: %w", err)
-}
-
-_, err = unix.Write(sock, data[:])
-if err != nil {
-if errors.Is(err, unix.EEXIST) {
-// Try to do a change
-route.Type = unix.RTM_CHANGE
-data, err = route.Marshal()
-if err != nil {
-return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
-}
-_, err = unix.Write(sock, data[:])
-fmt.Println("DOING CHANGE")
-return err
-}
-return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
-}
-
-return nil
-}
-
-func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
-sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-if err != nil {
-return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
-}
-defer unix.Close(sock)
-
-route := netroute.RouteMessage{
-Version: unix.RTM_VERSION,
-Type: unix.RTM_DELETE,
-Seq: 1,
-}
-
-if prefix.Addr().Is4() {
-route.Addrs = []netroute.Addr{
-unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
-unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
-unix.RTAX_GATEWAY: gateway,
-}
-} else {
-route.Addrs = []netroute.Addr{
-unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
-unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
-unix.RTAX_GATEWAY: gateway,
-}
-}
-
-data, err := route.Marshal()
-if err != nil {
-return fmt.Errorf("failed to create route.RouteMessage: %w", err)
-}
-_, err = unix.Write(sock, data[:])
-if err != nil {
-return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
-}
-
-return nil
-}
-
-// getLinkAddr Gets the link address for the interface of the given name
-func getLinkAddr(name string) (*netroute.LinkAddr, error) {
-rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
-if err != nil {
-return nil, err
-}
-msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
-if err != nil {
-return nil, err
-}
-
-for _, m := range msgs {
-switch m := m.(type) {
-case *netroute.InterfaceMessage:
-if m.Name == name {
-sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
-if ok {
-return sa, nil
-}
-}
-}
-}
-
-return nil, nil
-}
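Most of what this file loses is the readv/writev path that framed every packet with a 4-byte protocol family, which is how BSD tun devices behave in TUNSIFHEAD (link-layer) mode. Below is a minimal sketch of just that framing using golang.org/x/sys/unix; the fd handling is illustrative only and the function is not part of the repository.

    // 4-byte address-family framing for a BSD tun fd in TUNSIFHEAD mode.
    package main

    import (
        "encoding/binary"
        "fmt"

        "golang.org/x/sys/unix"
    )

    func writeFramed(fd int, pkt []byte) (int, error) {
        if len(pkt) == 0 {
            return 0, fmt.Errorf("empty packet")
        }
        head := make([]byte, 4)
        switch pkt[0] >> 4 {
        case 4:
            binary.BigEndian.PutUint32(head, uint32(unix.AF_INET))
        case 6:
            binary.BigEndian.PutUint32(head, uint32(unix.AF_INET6))
        default:
            return 0, fmt.Errorf("unable to determine IP version from packet")
        }
        // writev prepends the header without copying the payload
        n, err := unix.Writev(fd, [][]byte{head, pkt})
        if n >= 4 {
            n -= 4 // report payload bytes only, excluding the header
        }
        return n, err
    }

    func main() {
        // A real caller would pass the tun device fd; -1 here only shows usage.
        _, _ = writeFramed(-1, []byte{0x45, 0x00})
    }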
@@ -5,6 +5,7 @@ package overlay

 import (
 "fmt"
+"io"
 "net"
 "net/netip"
 "os"
@@ -16,19 +17,15 @@ import (
 "github.com/gaissmai/bart"
 "github.com/sirupsen/logrus"
 "github.com/slackhq/nebula/config"
-"github.com/slackhq/nebula/overlay/vhostnet"
-"github.com/slackhq/nebula/packet"
 "github.com/slackhq/nebula/routing"
 "github.com/slackhq/nebula/util"
-"github.com/slackhq/nebula/util/virtio"
 "github.com/vishvananda/netlink"
 "golang.org/x/sys/unix"
 )

 type tun struct {
-file *os.File
+io.ReadWriteCloser
 fd int
-vdev []*vhostnet.Device
 Device string
 vpnNetworks []netip.Prefix
 MaxMTU int
@@ -43,8 +40,7 @@ type tun struct {
 useSystemRoutes bool
 useSystemRoutesBufferSize int

-isV6 bool
-l *logrus.Logger
+l *logrus.Logger
 }

 func (t *tun) Networks() []netip.Prefix {
@@ -106,7 +102,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
 }

 var req ifReq
-req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_TUN_EXCL | unix.IFF_VNET_HDR | unix.IFF_NAPI)
+req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
 if multiqueue {
 req.Flags |= unix.IFF_MULTI_QUEUE
 }
@@ -116,47 +112,20 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
 }
 name := strings.Trim(string(req.Name[:]), "\x00")

-if err = unix.SetNonblock(fd, true); err != nil {
-_ = unix.Close(fd)
-return nil, fmt.Errorf("make file descriptor non-blocking: %w", err)
-}
-
 file := os.NewFile(uintptr(fd), "/dev/net/tun")

-err = unix.IoctlSetPointerInt(fd, unix.TUNSETVNETHDRSZ, virtio.NetHdrSize)
-if err != nil {
-return nil, fmt.Errorf("set vnethdr size: %w", err)
-}
-
-flags := 0
-//flags = //unix.TUN_F_CSUM //| unix.TUN_F_TSO4 | unix.TUN_F_USO4 | unix.TUN_F_TSO6 | unix.TUN_F_USO6
-err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, flags)
-if err != nil {
-return nil, fmt.Errorf("set offloads: %w", err)
-}
-
 t, err := newTunGeneric(c, l, file, vpnNetworks)
 if err != nil {
 return nil, err
 }
-t.fd = fd
-t.Device = name
-
-vdev, err := vhostnet.NewDevice(
+t.Device = name
-vhostnet.WithBackendFD(fd),
-vhostnet.WithQueueSize(8192), //todo config
-)
-if err != nil {
-return nil, err
-}
-t.vdev = []*vhostnet.Device{vdev}
-
 return t, nil
 }

 func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
 t := &tun{
-file: file,
+ReadWriteCloser: file,
 fd: int(file.Fd()),
 vpnNetworks: vpnNetworks,
 TXQueueLen: c.GetInt("tun.tx_queue", 500),
@@ -164,9 +133,6 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
 useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
 l: l,
 }
-if len(vpnNetworks) != 0 {
-t.isV6 = vpnNetworks[0].Addr().Is6() //todo what about multi-IP?
-}

 err := t.reload(c, true)
 if err != nil {
@@ -250,7 +216,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
 return nil
 }

-func (t *tun) NewMultiQueueReader() (TunDev, error) {
+func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 if err != nil {
 return nil, err
@@ -263,17 +229,9 @@ func (t *tun) NewMultiQueueReader() (TunDev, error) {
 return nil, err
 }

-vdev, err := vhostnet.NewDevice(
+file := os.NewFile(uintptr(fd), "/dev/net/tun")
-vhostnet.WithBackendFD(fd),
-vhostnet.WithQueueSize(8192), //todo config
-)
-if err != nil {
-return nil, err
-}
-
-t.vdev = append(t.vdev, vdev)
+return file, nil
-
-return t, nil
 }

 func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -281,6 +239,29 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 return r
 }

+func (t *tun) Write(b []byte) (int, error) {
+var nn int
+maximum := len(b)
+
+for {
+n, err := unix.Write(t.fd, b[nn:maximum])
+if n > 0 {
+nn += n
+}
+if nn == len(b) {
+return nn, err
+}
+
+if err != nil {
+return nn, err
+}
+
+if n == 0 {
+return nn, io.ErrUnexpectedEOF
+}
+}
+}
+
 func (t *tun) deviceBytes() (o [16]byte) {
 for i, c := range t.Device {
 o[i] = byte(c)
@@ -312,6 +293,7 @@ func (t *tun) addIPs(link netlink.Link) error {

 //add all new addresses
 for i := range newAddrs {
+//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
 //AddrReplace still adds new IPs, but if their properties change it will change them as well
 if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
 return err
@@ -379,11 +361,6 @@ func (t *tun) Activate() error {
 t.l.WithError(err).Error("Failed to set tun tx queue length")
 }

-const modeNone = 1
-if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
-t.l.WithError(err).Warn("Failed to disable link local address generation")
-}
-
 if err = t.addIPs(link); err != nil {
 return err
 }
@@ -661,11 +638,6 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 return
 }

-if r.Dst == nil {
-t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
-return
-}
-
 dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
 if !ok {
 t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
@@ -693,14 +665,8 @@ func (t *tun) Close() error {
 close(t.routeChan)
 }

-for _, v := range t.vdev {
+if t.ReadWriteCloser != nil {
-if v != nil {
+_ = t.ReadWriteCloser.Close()
-_ = v.Close()
-}
-}
-
-if t.file != nil {
-_ = t.file.Close()
 }

 if t.ioctlFd > 0 {
@@ -709,65 +675,3 @@ func (t *tun) Close() error {

 return nil
 }
-
-func (t *tun) ReadMany(p []*packet.VirtIOPacket, q int) (int, error) {
-n, err := t.vdev[q].ReceivePackets(p) //we are TXing
-if err != nil {
-return 0, err
-}
-return n, nil
-}
-
-func (t *tun) Write(b []byte) (int, error) {
-maximum := len(b) //we are RXing
-
-//todo garbagey
-out := packet.NewOut()
-x, err := t.AllocSeg(out, 0)
-if err != nil {
-return 0, err
-}
-copy(out.SegmentPayloads[x], b)
-err = t.vdev[0].TransmitPacket(out, true)
-
-if err != nil {
-t.l.WithError(err).Error("Transmitting packet")
-return 0, err
-}
-return maximum, nil
-}
-
-func (t *tun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
-idx, buf, err := t.vdev[q].GetPacketForTx()
-if err != nil {
-return 0, err
-}
-x := pkt.UseSegment(idx, buf, t.isV6)
-return x, nil
-}
-
-func (t *tun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
-if err := t.vdev[q].TransmitPacket(x, kick); err != nil {
-t.l.WithError(err).Error("Transmitting packet")
-return 0, err
-}
-return 1, nil
-}
-
-func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
-maximum := len(x) //we are RXing
-if maximum == 0 {
-return 0, nil
-}
-
-err := t.vdev[q].TransmitPackets(x)
-if err != nil {
-t.l.WithError(err).Error("Transmitting packet")
-return 0, err
-}
-return maximum, nil
-}
-
-func (t *tun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
-return t.vdev[q].ReceiveQueue.OfferDescriptorChains(pkt.Chains, kick)
-}
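The surviving Linux path programs addresses through netlink, and the new TODO comment sits next to netlink.AddrReplace, which is what makes addIPs idempotent. A small standalone sketch of that call follows; the interface name and address are assumptions for illustration, and it needs the usual netlink privileges to succeed.

    // Idempotent address programming with vishvananda/netlink.
    package main

    import (
        "log"

        "github.com/vishvananda/netlink"
    )

    func main() {
        link, err := netlink.LinkByName("nebula1") // assumed interface name
        if err != nil {
            log.Fatal(err)
        }
        addr, err := netlink.ParseAddr("10.10.0.2/16") // assumed VPN address
        if err != nil {
            log.Fatal(err)
        }
        // Safe to call repeatedly: an existing address is updated, not duplicated.
        if err := netlink.AddrReplace(link, addr); err != nil {
            log.Fatal(err)
        }
    }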
@@ -4,12 +4,13 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
@@ -19,42 +20,11 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
netroute "golang.org/x/net/route"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
type ifreqDestroy struct {
|
||||||
SIOCAIFADDR_IN6 = 0x8080696b
|
Name [16]byte
|
||||||
TUNSIFHEAD = 0x80047442
|
pad [16]byte
|
||||||
TUNSIFMODE = 0x80047458
|
|
||||||
)
|
|
||||||
|
|
||||||
type ifreqAlias4 struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
Addr unix.RawSockaddrInet4
|
|
||||||
DstAddr unix.RawSockaddrInet4
|
|
||||||
MaskAddr unix.RawSockaddrInet4
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreqAlias6 struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
Addr unix.RawSockaddrInet6
|
|
||||||
DstAddr unix.RawSockaddrInet6
|
|
||||||
PrefixMask unix.RawSockaddrInet6
|
|
||||||
Flags uint32
|
|
||||||
Lifetime addrLifetime
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreq struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
data int
|
|
||||||
}
|
|
||||||
|
|
||||||
type addrLifetime struct {
|
|
||||||
Expire uint64
|
|
||||||
Preferred uint64
|
|
||||||
Vltime uint32
|
|
||||||
Pltime uint32
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
@@ -64,18 +34,40 @@ type tun struct {
|
|||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
f *os.File
|
|
||||||
fd int
|
io.ReadWriteCloser
|
||||||
}
|
}
|
||||||
|
|
||||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
func (t *tun) Close() error {
|
||||||
|
if t.ReadWriteCloser != nil {
|
||||||
|
if err := t.ReadWriteCloser.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
|
||||||
|
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
||||||
|
|
||||||
|
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
// Try to open tun device
|
// Try to open tun device
|
||||||
|
var file *os.File
|
||||||
var err error
|
var err error
|
||||||
deviceName := c.GetString("tun.dev", "")
|
deviceName := c.GetString("tun.dev", "")
|
||||||
if deviceName == "" {
|
if deviceName == "" {
|
||||||
@@ -85,23 +77,17 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
|
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = unix.SetNonblock(fd, true)
|
|
||||||
if err != nil {
|
|
||||||
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
|
||||||
}
|
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
f: os.NewFile(uintptr(fd), ""),
|
ReadWriteCloser: file,
|
||||||
fd: fd,
|
Device: deviceName,
|
||||||
Device: deviceName,
|
vpnNetworks: vpnNetworks,
|
||||||
vpnNetworks: vpnNetworks,
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
l: l,
|
||||||
l: l,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
@@ -119,225 +105,40 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Close() error {
|
|
||||||
if t.f != nil {
|
|
||||||
if err := t.f.Close(); err != nil {
|
|
||||||
return fmt.Errorf("error closing tun file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// t.f.Close should have handled it for us but let's be extra sure
|
|
||||||
_ = unix.Close(t.fd)
|
|
||||||
|
|
||||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer syscall.Close(s)
|
|
||||||
|
|
||||||
ifr := ifreq{Name: t.deviceBytes()}
|
|
||||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifr)))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Read(to []byte) (int, error) {
|
|
||||||
rc, err := t.f.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var errno syscall.Errno
|
|
||||||
var n uintptr
|
|
||||||
err = rc.Read(func(fd uintptr) bool {
|
|
||||||
// first 4 bytes is protocol family, in network byte order
|
|
||||||
head := [4]byte{}
|
|
||||||
iovecs := []syscall.Iovec{
|
|
||||||
{&head[0], 4},
|
|
||||||
{&to[0], uint64(len(to))},
|
|
||||||
}
|
|
||||||
|
|
||||||
n, _, errno = syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
|
||||||
if errno.Temporary() {
|
|
||||||
// We got an EAGAIN, EINTR, or EWOULDBLOCK, go again
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
if err == syscall.EBADF || err.Error() == "use of closed file" {
|
|
||||||
// Go doesn't export poll.ErrFileClosing but happily reports it to us so here we are
|
|
||||||
// https://github.com/golang/go/blob/master/src/internal/poll/fd_poll_runtime.go#L121
|
|
||||||
return 0, os.ErrClosed
|
|
||||||
}
|
|
||||||
return 0, fmt.Errorf("failed to make read call for tun: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if errno != 0 {
|
|
||||||
return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno)
|
|
||||||
}
|
|
||||||
|
|
||||||
// fix bytes read number to exclude header
|
|
||||||
bytesRead := int(n)
|
|
||||||
if bytesRead < 0 {
|
|
||||||
return bytesRead, nil
|
|
||||||
} else if bytesRead < 4 {
|
|
||||||
return 0, nil
|
|
||||||
} else {
|
|
||||||
return bytesRead - 4, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write is only valid for single threaded use
|
|
||||||
func (t *tun) Write(from []byte) (int, error) {
|
|
||||||
if len(from) <= 1 {
|
|
||||||
return 0, syscall.EIO
|
|
||||||
}
|
|
||||||
|
|
||||||
ipVer := from[0] >> 4
|
|
||||||
var head [4]byte
|
|
||||||
// first 4 bytes is protocol family, in network byte order
|
|
||||||
if ipVer == 4 {
|
|
||||||
head[3] = syscall.AF_INET
|
|
||||||
} else if ipVer == 6 {
|
|
||||||
head[3] = syscall.AF_INET6
|
|
||||||
} else {
|
|
||||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
|
||||||
}
|
|
||||||
|
|
||||||
rc, err := t.f.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var errno syscall.Errno
|
|
||||||
var n uintptr
|
|
||||||
err = rc.Write(func(fd uintptr) bool {
|
|
||||||
iovecs := []syscall.Iovec{
|
|
||||||
{&head[0], 4},
|
|
||||||
{&from[0], uint64(len(from))},
|
|
||||||
}
|
|
||||||
|
|
||||||
n, _, errno = syscall.Syscall(syscall.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
|
||||||
// According to NetBSD documentation for TUN, writes will only return errors in which
|
|
||||||
// this packet will never be delivered so just go on living life.
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if errno != 0 {
|
|
||||||
return 0, errno
|
|
||||||
}
|
|
||||||
|
|
||||||
return int(n) - 4, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||||
if cidr.Addr().Is4() {
|
var err error
|
||||||
var req ifreqAlias4
|
|
||||||
req.Name = t.deviceBytes()
|
|
||||||
req.Addr = unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: cidr.Addr().As4(),
|
|
||||||
}
|
|
||||||
req.DstAddr = unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
-            Family: unix.AF_INET,
-            Addr:   cidr.Addr().As4(),
-        }
-        req.MaskAddr = unix.RawSockaddrInet4{
-            Len:    unix.SizeofSockaddrInet4,
-            Family: unix.AF_INET,
-            Addr:   prefixToMask(cidr).As4(),
-        }
-
-        s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
-        if err != nil {
-            return err
-        }
-        defer syscall.Close(s)
-
-        if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
-            return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
-        }
-
-        return nil
-    }
-
-    if cidr.Addr().Is6() {
-        var req ifreqAlias6
-        req.Name = t.deviceBytes()
-        req.Addr = unix.RawSockaddrInet6{
-            Len:    unix.SizeofSockaddrInet6,
-            Family: unix.AF_INET6,
-            Addr:   cidr.Addr().As16(),
-        }
-        req.PrefixMask = unix.RawSockaddrInet6{
-            Len:    unix.SizeofSockaddrInet6,
-            Family: unix.AF_INET6,
-            Addr:   prefixToMask(cidr).As16(),
-        }
-        req.Lifetime = addrLifetime{
-            Vltime: 0xffffffff,
-            Pltime: 0xffffffff,
-        }
-
-        s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
-        if err != nil {
-            return err
-        }
-        defer syscall.Close(s)
-
-        if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
-            return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
-        }
-
-        return nil
-    }
-
-    return fmt.Errorf("unknown address type %v", cidr)
-}
-
-func (t *tun) Activate() error {
-    mode := int32(unix.IFF_BROADCAST)
-    err := ioctl(uintptr(t.fd), TUNSIFMODE, uintptr(unsafe.Pointer(&mode)))
-    if err != nil {
-        return fmt.Errorf("failed to set tun device mode: %w", err)
-    }
-
-    v := 1
-    err = ioctl(uintptr(t.fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&v)))
-    if err != nil {
-        return fmt.Errorf("failed to set tun device head: %w", err)
-    }
-
-    err = t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
-    if err != nil {
-        return fmt.Errorf("failed to set tun mtu: %w", err)
-    }
-
-    for i := range t.vpnNetworks {
-        err = t.addIp(t.vpnNetworks[i])
-        if err != nil {
-            return err
-        }
-    }
-
-    return t.addRoutes(false)
-}
-
-func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
-    s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
-    if err != nil {
-        return err
-    }
-    defer syscall.Close(s)
-
-    ir := ifreq{Name: t.deviceBytes(), data: int(value)}
-    err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
-    return err
-}
+    // TODO use syscalls instead of exec.Command
+    cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
+    t.l.Debug("command: ", cmd.String())
+    if err = cmd.Run(); err != nil {
+        return fmt.Errorf("failed to run 'ifconfig': %s", err)
+    }
+
+    cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
+    t.l.Debug("command: ", cmd.String())
+    if err = cmd.Run(); err != nil {
+        return fmt.Errorf("failed to run 'route add': %s", err)
+    }
+
+    cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
+    t.l.Debug("command: ", cmd.String())
+    if err = cmd.Run(); err != nil {
+        return fmt.Errorf("failed to run 'ifconfig': %s", err)
+    }
+
+    // Unsafe path routes
+    return t.addRoutes(false)
+}
+
+func (t *tun) Activate() error {
+    for i := range t.vpnNetworks {
+        err := t.addIp(t.vpnNetworks[i])
+        if err != nil {
+            return err
+        }
+    }
+    return nil
+}
 
 func (t *tun) reload(c *config.C, initial bool) error {
@@ -396,23 +197,21 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 
 func (t *tun) addRoutes(logErrors bool) error {
     routes := *t.Routes.Load()
 
     for _, r := range routes {
         if len(r.Via) == 0 || !r.Install {
             // We don't allow route MTUs so only install routes with a via
             continue
         }
 
-        err := addRoute(r.Cidr, t.vpnNetworks)
-        if err != nil {
-            retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
+        cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
+        t.l.Debug("command: ", cmd.String())
+        if err := cmd.Run(); err != nil {
+            retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
             if logErrors {
                 retErr.Log(t.l)
             } else {
                 return retErr
             }
-        } else {
-            t.l.WithField("route", r).Info("Added route")
         }
     }
 
@@ -425,8 +224,10 @@ func (t *tun) removeRoutes(routes []Route) error {
             continue
         }
 
-        err := delRoute(r.Cidr, t.vpnNetworks)
-        if err != nil {
+        //TODO: CERT-V2 is this right?
+        cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
+        t.l.Debug("command: ", cmd.String())
+        if err := cmd.Run(); err != nil {
             t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
         } else {
             t.l.WithField("route", r).Info("Removed route")
@@ -441,109 +242,3 @@ func (t *tun) deviceBytes() (o [16]byte) {
     }
     return
 }
-
-func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
-    sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-    if err != nil {
-        return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
-    }
-    defer unix.Close(sock)
-
-    route := &netroute.RouteMessage{
-        Version: unix.RTM_VERSION,
-        Type:    unix.RTM_ADD,
-        Flags:   unix.RTF_UP | unix.RTF_GATEWAY,
-        Seq:     1,
-    }
-
-    if prefix.Addr().Is4() {
-        gw, err := selectGateway(prefix, gateways)
-        if err != nil {
-            return err
-        }
-        route.Addrs = []netroute.Addr{
-            unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
-            unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
-            unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
-        }
-    } else {
-        gw, err := selectGateway(prefix, gateways)
-        if err != nil {
-            return err
-        }
-        route.Addrs = []netroute.Addr{
-            unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
-            unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
-            unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
-        }
-    }
-
-    data, err := route.Marshal()
-    if err != nil {
-        return fmt.Errorf("failed to create route.RouteMessage: %w", err)
-    }
-
-    _, err = unix.Write(sock, data[:])
-    if err != nil {
-        if errors.Is(err, unix.EEXIST) {
-            // Try to do a change
-            route.Type = unix.RTM_CHANGE
-            data, err = route.Marshal()
-            if err != nil {
-                return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
-            }
-            _, err = unix.Write(sock, data[:])
-            return err
-        }
-        return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
-    }
-
-    return nil
-}
-
-func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
-    sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-    if err != nil {
-        return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
-    }
-    defer unix.Close(sock)
-
-    route := netroute.RouteMessage{
-        Version: unix.RTM_VERSION,
-        Type:    unix.RTM_DELETE,
-        Seq:     1,
-    }
-
-    if prefix.Addr().Is4() {
-        gw, err := selectGateway(prefix, gateways)
-        if err != nil {
-            return err
-        }
-        route.Addrs = []netroute.Addr{
-            unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
-            unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
-            unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
-        }
-    } else {
-        gw, err := selectGateway(prefix, gateways)
-        if err != nil {
-            return err
-        }
-        route.Addrs = []netroute.Addr{
-            unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
-            unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
-            unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
-        }
-    }
-
-    data, err := route.Marshal()
-    if err != nil {
-        return fmt.Errorf("failed to create route.RouteMessage: %w", err)
-    }
-    _, err = unix.Write(sock, data[:])
-    if err != nil {
-        return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
-    }
-
-    return nil
-}
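Note: the removed addIp/addRoute/delRoute code above relies on a prefixToMask helper that is defined outside the hunks shown here. A minimal sketch of a compatible helper, assuming only what the call sites imply (it must return a netip.Addr whose bytes are the network mask so .As4()/.As16() work), not necessarily the repository's implementation:

package overlay

import (
    "net"
    "net/netip"
)

// prefixToMask returns the network mask of prefix as a netip.Addr so callers
// can use .As4() / .As16() on the result, as the ioctl and route code above does.
func prefixToMask(prefix netip.Prefix) netip.Addr {
    bits := 128
    if prefix.Addr().Is4() {
        bits = 32
    }
    // net.CIDRMask produces the mask bytes for the given prefix length.
    mask, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), bits))
    return mask
}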
@@ -4,50 +4,23 @@
 package overlay
 
 import (
-    "errors"
     "fmt"
     "io"
     "net/netip"
     "os"
+    "os/exec"
     "regexp"
+    "strconv"
     "sync/atomic"
     "syscall"
-    "unsafe"
 
     "github.com/gaissmai/bart"
     "github.com/sirupsen/logrus"
     "github.com/slackhq/nebula/config"
     "github.com/slackhq/nebula/routing"
     "github.com/slackhq/nebula/util"
-    netroute "golang.org/x/net/route"
-    "golang.org/x/sys/unix"
 )
 
-const (
-    SIOCAIFADDR_IN6 = 0x8080691a
-)
-
-type ifreqAlias4 struct {
-    Name     [unix.IFNAMSIZ]byte
-    Addr     unix.RawSockaddrInet4
-    DstAddr  unix.RawSockaddrInet4
-    MaskAddr unix.RawSockaddrInet4
-}
-
-type ifreqAlias6 struct {
-    Name       [unix.IFNAMSIZ]byte
-    Addr       unix.RawSockaddrInet6
-    DstAddr    unix.RawSockaddrInet6
-    PrefixMask unix.RawSockaddrInet6
-    Flags      uint32
-    Lifetime   [2]uint32
-}
-
-type ifreq struct {
-    Name [unix.IFNAMSIZ]byte
-    data int
-}
-
 type tun struct {
     Device      string
     vpnNetworks []netip.Prefix
@@ -55,46 +28,48 @@ type tun struct {
     Routes      atomic.Pointer[[]Route]
     routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
     l           *logrus.Logger
-    f           *os.File
-    fd          int
+    io.ReadWriteCloser
 
     // cache out buffer since we need to prepend 4 bytes for tun metadata
     out []byte
 }
 
-var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
-
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
-    return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
-}
+func (t *tun) Close() error {
+    if t.ReadWriteCloser != nil {
+        return t.ReadWriteCloser.Close()
+    }
+
+    return nil
+}
+
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
+    return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
+}
+
+var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 
 func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
-    // Try to open tun device
-    var err error
     deviceName := c.GetString("tun.dev", "")
     if deviceName == "" {
-        return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
-    }
-    if !deviceNameRE.MatchString(deviceName) {
-        return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
+        return nil, fmt.Errorf("a device name in the format of tunN must be specified")
     }
 
-    fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
+    if !deviceNameRE.MatchString(deviceName) {
+        return nil, fmt.Errorf("a device name in the format of tunN must be specified")
+    }
+
+    file, err := os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
     if err != nil {
         return nil, err
     }
 
-    err = unix.SetNonblock(fd, true)
-    if err != nil {
-        l.WithError(err).Warn("Failed to set the tun device as nonblocking")
-    }
-
     t := &tun{
-        f:           os.NewFile(uintptr(fd), ""),
-        fd:          fd,
-        Device:      deviceName,
-        vpnNetworks: vpnNetworks,
-        MTU:         c.GetInt("tun.mtu", DefaultMTU),
-        l:           l,
+        ReadWriteCloser: file,
+        Device:          deviceName,
+        vpnNetworks:     vpnNetworks,
+        MTU:             c.GetInt("tun.mtu", DefaultMTU),
+        l:               l,
     }
 
     err = t.reload(c, true)
@@ -112,154 +87,6 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
     return t, nil
 }
 
-func (t *tun) Close() error {
-    if t.f != nil {
-        if err := t.f.Close(); err != nil {
-            return fmt.Errorf("error closing tun file: %w", err)
-        }
-
-        // t.f.Close should have handled it for us but let's be extra sure
-        _ = unix.Close(t.fd)
-    }
-    return nil
-}
-
-func (t *tun) Read(to []byte) (int, error) {
-    buf := make([]byte, len(to)+4)
-
-    n, err := t.f.Read(buf)
-
-    copy(to, buf[4:])
-    return n - 4, err
-}
-
-// Write is only valid for single threaded use
-func (t *tun) Write(from []byte) (int, error) {
-    buf := t.out
-    if cap(buf) < len(from)+4 {
-        buf = make([]byte, len(from)+4)
-        t.out = buf
-    }
-    buf = buf[:len(from)+4]
-
-    if len(from) == 0 {
-        return 0, syscall.EIO
-    }
-
-    // Determine the IP Family for the NULL L2 Header
-    ipVer := from[0] >> 4
-    if ipVer == 4 {
-        buf[3] = syscall.AF_INET
-    } else if ipVer == 6 {
-        buf[3] = syscall.AF_INET6
-    } else {
-        return 0, fmt.Errorf("unable to determine IP version from packet")
-    }
-
-    copy(buf[4:], from)
-
-    n, err := t.f.Write(buf)
-    return n - 4, err
-}
-
-func (t *tun) addIp(cidr netip.Prefix) error {
-    if cidr.Addr().Is4() {
-        var req ifreqAlias4
-        req.Name = t.deviceBytes()
-        req.Addr = unix.RawSockaddrInet4{
-            Len:    unix.SizeofSockaddrInet4,
-            Family: unix.AF_INET,
-            Addr:   cidr.Addr().As4(),
-        }
-        req.DstAddr = unix.RawSockaddrInet4{
-            Len:    unix.SizeofSockaddrInet4,
-            Family: unix.AF_INET,
-            Addr:   cidr.Addr().As4(),
-        }
-        req.MaskAddr = unix.RawSockaddrInet4{
-            Len:    unix.SizeofSockaddrInet4,
-            Family: unix.AF_INET,
-            Addr:   prefixToMask(cidr).As4(),
-        }
-
-        s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
-        if err != nil {
-            return err
-        }
-        defer syscall.Close(s)
-
-        if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
-            return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
-        }
-
-        err = addRoute(cidr, t.vpnNetworks)
-        if err != nil {
-            return fmt.Errorf("failed to set route for vpn network %v: %w", cidr, err)
-        }
-
-        return nil
-    }
-
-    if cidr.Addr().Is6() {
-        var req ifreqAlias6
-        req.Name = t.deviceBytes()
-        req.Addr = unix.RawSockaddrInet6{
-            Len:    unix.SizeofSockaddrInet6,
-            Family: unix.AF_INET6,
-            Addr:   cidr.Addr().As16(),
-        }
-        req.PrefixMask = unix.RawSockaddrInet6{
-            Len:    unix.SizeofSockaddrInet6,
-            Family: unix.AF_INET6,
-            Addr:   prefixToMask(cidr).As16(),
-        }
-        req.Lifetime[0] = 0xffffffff
-        req.Lifetime[1] = 0xffffffff
-
-        s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
-        if err != nil {
-            return err
-        }
-        defer syscall.Close(s)
-
-        if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
-            return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
-        }
-
-        return nil
-    }
-
-    return fmt.Errorf("unknown address type %v", cidr)
-}
-
-func (t *tun) Activate() error {
-    err := t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
-    if err != nil {
-        return fmt.Errorf("failed to set tun mtu: %w", err)
-    }
-
-    for i := range t.vpnNetworks {
-        err = t.addIp(t.vpnNetworks[i])
-        if err != nil {
-            return err
-        }
-    }
-
-    return t.addRoutes(false)
-}
-
-func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
-    s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
-    if err != nil {
-        return err
-    }
-    defer syscall.Close(s)
-
-    ir := ifreq{Name: t.deviceBytes(), data: int(value)}
-    err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
-    return err
-}
-
 func (t *tun) reload(c *config.C, initial bool) error {
     change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
     if err != nil {
@@ -297,42 +124,63 @@ func (t *tun) reload(c *config.C, initial bool) error {
     return nil
 }
 
+func (t *tun) addIp(cidr netip.Prefix) error {
+    var err error
+    // TODO use syscalls instead of exec.Command
+    cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
+    t.l.Debug("command: ", cmd.String())
+    if err = cmd.Run(); err != nil {
+        return fmt.Errorf("failed to run 'ifconfig': %s", err)
+    }
+
+    cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
+    t.l.Debug("command: ", cmd.String())
+    if err = cmd.Run(); err != nil {
+        return fmt.Errorf("failed to run 'ifconfig': %s", err)
+    }
+
+    cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
+    t.l.Debug("command: ", cmd.String())
+    if err = cmd.Run(); err != nil {
+        return fmt.Errorf("failed to run 'route add': %s", err)
+    }
+
+    // Unsafe path routes
+    return t.addRoutes(false)
+}
+
+func (t *tun) Activate() error {
+    for i := range t.vpnNetworks {
+        err := t.addIp(t.vpnNetworks[i])
+        if err != nil {
+            return err
+        }
+    }
+    return nil
+}
+
 func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
     r, _ := t.routeTree.Load().Lookup(ip)
     return r
 }
 
-func (t *tun) Networks() []netip.Prefix {
-    return t.vpnNetworks
-}
-
-func (t *tun) Name() string {
-    return t.Device
-}
-
-func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
-    return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
-}
-
 func (t *tun) addRoutes(logErrors bool) error {
     routes := *t.Routes.Load()
 
     for _, r := range routes {
         if len(r.Via) == 0 || !r.Install {
             // We don't allow route MTUs so only install routes with a via
             continue
         }
-        err := addRoute(r.Cidr, t.vpnNetworks)
-        if err != nil {
-            retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
+        //TODO: CERT-V2 is this right?
+        cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
+        t.l.Debug("command: ", cmd.String())
+        if err := cmd.Run(); err != nil {
+            retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
             if logErrors {
                 retErr.Log(t.l)
             } else {
                 return retErr
             }
-        } else {
-            t.l.WithField("route", r).Info("Added route")
         }
     }
 
@@ -344,9 +192,10 @@ func (t *tun) removeRoutes(routes []Route) error {
         if !r.Install {
             continue
         }
-        err := delRoute(r.Cidr, t.vpnNetworks)
-        if err != nil {
+        //TODO: CERT-V2 is this right?
+        cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
+        t.l.Debug("command: ", cmd.String())
+        if err := cmd.Run(); err != nil {
             t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
         } else {
             t.l.WithField("route", r).Info("Removed route")
@@ -355,115 +204,52 @@ func (t *tun) removeRoutes(routes []Route) error {
     return nil
 }
 
-func (t *tun) deviceBytes() (o [16]byte) {
-    for i, c := range t.Device {
-        o[i] = byte(c)
-    }
-    return
-}
-
-func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
-    sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-    if err != nil {
-        return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
-    }
-    defer unix.Close(sock)
-
-    route := &netroute.RouteMessage{
-        Version: unix.RTM_VERSION,
-        Type:    unix.RTM_ADD,
-        Flags:   unix.RTF_UP | unix.RTF_GATEWAY,
-        Seq:     1,
-    }
-
-    if prefix.Addr().Is4() {
-        gw, err := selectGateway(prefix, gateways)
-        if err != nil {
-            return err
-        }
-        route.Addrs = []netroute.Addr{
-            unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
-            unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
-            unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
-        }
-    } else {
-        gw, err := selectGateway(prefix, gateways)
-        if err != nil {
-            return err
-        }
-        route.Addrs = []netroute.Addr{
-            unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
-            unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
-            unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
-        }
-    }
-
-    data, err := route.Marshal()
-    if err != nil {
-        return fmt.Errorf("failed to create route.RouteMessage: %w", err)
-    }
-
-    _, err = unix.Write(sock, data[:])
-    if err != nil {
-        if errors.Is(err, unix.EEXIST) {
-            // Try to do a change
-            route.Type = unix.RTM_CHANGE
-            data, err = route.Marshal()
-            if err != nil {
-                return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
-            }
-            _, err = unix.Write(sock, data[:])
-            return err
-        }
-        return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
-    }
-
-    return nil
-}
-
-func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
-    sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-    if err != nil {
-        return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
-    }
-    defer unix.Close(sock)
-
-    route := netroute.RouteMessage{
-        Version: unix.RTM_VERSION,
-        Type:    unix.RTM_DELETE,
-        Seq:     1,
-    }
-
-    if prefix.Addr().Is4() {
-        gw, err := selectGateway(prefix, gateways)
-        if err != nil {
-            return err
-        }
-        route.Addrs = []netroute.Addr{
-            unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
-            unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
-            unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
-        }
-    } else {
-        gw, err := selectGateway(prefix, gateways)
-        if err != nil {
-            return err
-        }
-        route.Addrs = []netroute.Addr{
-            unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
-            unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
-            unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
-        }
-    }
-
-    data, err := route.Marshal()
-    if err != nil {
-        return fmt.Errorf("failed to create route.RouteMessage: %w", err)
-    }
-    _, err = unix.Write(sock, data[:])
-    if err != nil {
-        return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
-    }
-
-    return nil
-}
+func (t *tun) Networks() []netip.Prefix {
+    return t.vpnNetworks
+}
+
+func (t *tun) Name() string {
+    return t.Device
+}
+
+func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+    return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
+}
+
+func (t *tun) Read(to []byte) (int, error) {
+    buf := make([]byte, len(to)+4)
+
+    n, err := t.ReadWriteCloser.Read(buf)
+
+    copy(to, buf[4:])
+    return n - 4, err
+}
+
+// Write is only valid for single threaded use
+func (t *tun) Write(from []byte) (int, error) {
+    buf := t.out
+    if cap(buf) < len(from)+4 {
+        buf = make([]byte, len(from)+4)
+        t.out = buf
+    }
+    buf = buf[:len(from)+4]
+
+    if len(from) == 0 {
+        return 0, syscall.EIO
+    }
+
+    // Determine the IP Family for the NULL L2 Header
+    ipVer := from[0] >> 4
+    if ipVer == 4 {
+        buf[3] = syscall.AF_INET
+    } else if ipVer == 6 {
+        buf[3] = syscall.AF_INET6
+    } else {
+        return 0, fmt.Errorf("unable to determine IP version from packet")
+    }
+
+    copy(buf[4:], from)
+
+    n, err := t.ReadWriteCloser.Write(buf)
+    return n - 4, err
+}
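Note: both BSD variants above resolve the gateway for a syscall-installed route with selectGateway(prefix, gateways), which is defined outside the hunks shown here. A plausible sketch under the assumption that it simply picks the first vpn network of the same address family as the target (the repository's helper may apply additional rules):

package overlay

import (
    "fmt"
    "net/netip"
)

// selectGateway picks one of the node's vpn networks to use as the gateway
// for target. Assumption for illustration only: match on address family.
func selectGateway(target netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
    for _, gw := range gateways {
        if gw.Addr().Is4() == target.Addr().Is4() {
            return gw, nil
        }
    }
    return netip.Prefix{}, fmt.Errorf("no gateway found for %v", target)
}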
@@ -1,13 +1,11 @@
 package overlay
 
 import (
-    "fmt"
     "io"
     "net/netip"
 
     "github.com/sirupsen/logrus"
     "github.com/slackhq/nebula/config"
-    "github.com/slackhq/nebula/packet"
     "github.com/slackhq/nebula/routing"
 )
 
@@ -38,10 +36,6 @@ type UserDevice struct {
     inboundWriter *io.PipeWriter
 }
 
-func (d *UserDevice) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
-    return nil
-}
-
 func (d *UserDevice) Activate() error {
     return nil
 }
@@ -52,7 +46,7 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
     return routing.Gateways{routing.NewGateway(ip, 1)}
 }
 
-func (d *UserDevice) NewMultiQueueReader() (TunDev, error) {
+func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
     return d, nil
 }
 
@@ -71,19 +65,3 @@ func (d *UserDevice) Close() error {
     d.outboundWriter.Close()
     return nil
 }
-
-func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
-    return d.Read(b[0].Payload)
-}
-
-func (d *UserDevice) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
-    return 0, fmt.Errorf("user: AllocSeg not implemented")
-}
-
-func (d *UserDevice) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
-    return 0, fmt.Errorf("user: WriteOne not implemented")
-}
-
-func (d *UserDevice) WriteMany(x []*packet.OutPacket, q int) (int, error) {
-    return 0, fmt.Errorf("user: WriteMany not implemented")
-}
@@ -1,23 +0,0 @@
-Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
-
-MIT License
-
-Copyright (c) 2025 Hetzner Cloud GmbH
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
@@ -1,4 +0,0 @@
-// Package vhost implements the basic ioctl requests needed to interact with the
-// kernel-level virtio server that provides accelerated virtio devices for
-// networking and more.
-package vhost
@@ -1,218 +0,0 @@
-package vhost
-
-import (
-    "fmt"
-    "unsafe"
-
-    "github.com/slackhq/nebula/overlay/virtqueue"
-    "github.com/slackhq/nebula/util/virtio"
-    "golang.org/x/sys/unix"
-)
-
-const (
-    // vhostIoctlGetFeatures can be used to retrieve the features supported by the vhost implementation in the kernel.
-    // Response payload: [virtio.Feature]
-    // Kernel name: VHOST_GET_FEATURES
-    vhostIoctlGetFeatures = 0x8008af00
-
-    // vhostIoctlSetFeatures can be used to communicate the features supported by this virtio implementation to the kernel.
-    // Request payload: [virtio.Feature]
-    // Kernel name: VHOST_SET_FEATURES
-    vhostIoctlSetFeatures = 0x4008af00
-
-    // vhostIoctlSetOwner can be used to set the current process as the exclusive owner of a control file descriptor.
-    // Request payload: none
-    // Kernel name: VHOST_SET_OWNER
-    vhostIoctlSetOwner = 0x0000af01
-
-    // vhostIoctlSetMemoryLayout can be used to set up or modify the memory layout which describes the IOTLB mappings in the kernel.
-    // Request payload: [MemoryLayout] with custom serialization
-    // Kernel name: VHOST_SET_MEM_TABLE
-    vhostIoctlSetMemoryLayout = 0x4008af03
-
-    // vhostIoctlSetQueueSize can be used to set the size of the virtqueue.
-    // Request payload: [QueueState]
-    // Kernel name: VHOST_SET_VRING_NUM
-    vhostIoctlSetQueueSize = 0x4008af10
-
-    // vhostIoctlSetQueueAddress can be used to set the addresses of the different parts of the virtqueue.
-    // Request payload: [QueueAddresses]
-    // Kernel name: VHOST_SET_VRING_ADDR
-    vhostIoctlSetQueueAddress = 0x4028af11
-
-    // vhostIoctlSetAvailableRingBase can be used to set the index of the next available ring entry the device will process.
-    // Request payload: [QueueState]
-    // Kernel name: VHOST_SET_VRING_BASE
-    vhostIoctlSetAvailableRingBase = 0x4008af12
-
-    // vhostIoctlSetQueueKickEventFD can be used to set the event file descriptor to signal the device when descriptor chains were added to the available ring.
-    // Request payload: [QueueFile]
-    // Kernel name: VHOST_SET_VRING_KICK
-    vhostIoctlSetQueueKickEventFD = 0x4008af20
-
-    // vhostIoctlSetQueueCallEventFD can be used to set the event file descriptor that gets signaled by the device when descriptor chains have been used by it.
-    // Request payload: [QueueFile]
-    // Kernel name: VHOST_SET_VRING_CALL
-    vhostIoctlSetQueueCallEventFD = 0x4008af21
-)
-
-// QueueState is an ioctl request payload that can hold a queue index and any 32-bit number.
-// Kernel name: vhost_vring_state
-type QueueState struct {
-    // QueueIndex is the index of the virtqueue.
-    QueueIndex uint32
-    // Num is any 32-bit number, depending on the request.
-    Num uint32
-}
-
-// QueueAddresses is an ioctl request payload that can hold the addresses of the different parts of a virtqueue.
-// Kernel name: vhost_vring_addr
-type QueueAddresses struct {
-    // QueueIndex is the index of the virtqueue.
-    QueueIndex uint32
-    // Flags that are not used in this implementation.
-    Flags uint32
-    // DescriptorTableAddress is the address of the descriptor table in user space memory. It must be 16-byte aligned.
-    DescriptorTableAddress uintptr
-    // UsedRingAddress is the address of the used ring in user space memory. It must be 4-byte aligned.
-    UsedRingAddress uintptr
-    // AvailableRingAddress is the address of the available ring in user space memory. It must be 2-byte aligned.
-    AvailableRingAddress uintptr
-    // LogAddress is used for an optional logging support, not supported by this implementation.
-    LogAddress uintptr
-}
-
-// QueueFile is an ioctl request payload that can hold a queue index and a file descriptor.
-// Kernel name: vhost_vring_file
-type QueueFile struct {
-    // QueueIndex is the index of the virtqueue.
-    QueueIndex uint32
-    // FD is the file descriptor of the file. Pass -1 to unbind from a file.
-    FD int32
-}
-
-// IoctlPtr is a copy of the similarly named unexported function from the Go unix package.
-// This is needed to do custom ioctl requests not supported by the standard library.
-func IoctlPtr(fd int, req uint, arg unsafe.Pointer) error {
-    _, _, err := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(arg))
-    if err != 0 {
-        return fmt.Errorf("ioctl request %d: %w", req, err)
-    }
-    return nil
-}
-
-// GetFeatures requests the supported feature bits from the virtio device associated with the given control file descriptor.
-func GetFeatures(controlFD int) (virtio.Feature, error) {
-    var features virtio.Feature
-    if err := IoctlPtr(controlFD, vhostIoctlGetFeatures, unsafe.Pointer(&features)); err != nil {
-        return 0, fmt.Errorf("get features: %w", err)
-    }
-    return features, nil
-}
-
-// SetFeatures communicates the feature bits supported by this implementation to the virtio device associated with the given control file descriptor.
-func SetFeatures(controlFD int, features virtio.Feature) error {
-    if err := IoctlPtr(controlFD, vhostIoctlSetFeatures, unsafe.Pointer(&features)); err != nil {
-        return fmt.Errorf("set features: %w", err)
-    }
-    return nil
-}
-
-// OwnControlFD sets the current process as the exclusive owner for the given control file descriptor.
-// This must be called before interacting with the control file descriptor in any other way.
-func OwnControlFD(controlFD int) error {
-    if err := IoctlPtr(controlFD, vhostIoctlSetOwner, unsafe.Pointer(nil)); err != nil {
-        return fmt.Errorf("set control file descriptor owner: %w", err)
-    }
-    return nil
-}
-
-// SetMemoryLayout sets up or modifies the memory layout for the kernel-level virtio device associated with the given control file descriptor.
-func SetMemoryLayout(controlFD int, layout MemoryLayout) error {
-    payload := layout.serializePayload()
-    if err := IoctlPtr(controlFD, vhostIoctlSetMemoryLayout, unsafe.Pointer(&payload[0])); err != nil {
-        return fmt.Errorf("set memory layout: %w", err)
-    }
-    return nil
-}
-
-// RegisterQueue registers a virtio queue with the kernel-level virtio server.
-// The virtqueue will be linked to the given control file descriptor and will have the given index.
-// The kernel will use this queue until the control file descriptor is closed.
-func RegisterQueue(controlFD int, queueIndex uint32, queue *virtqueue.SplitQueue) error {
-    if err := IoctlPtr(controlFD, vhostIoctlSetQueueSize, unsafe.Pointer(&QueueState{
-        QueueIndex: queueIndex,
-        Num:        uint32(queue.Size()),
-    })); err != nil {
-        return fmt.Errorf("set queue size: %w", err)
-    }
-
-    if err := IoctlPtr(controlFD, vhostIoctlSetQueueAddress, unsafe.Pointer(&QueueAddresses{
-        QueueIndex:             queueIndex,
-        Flags:                  0,
-        DescriptorTableAddress: queue.DescriptorTable().Address(),
-        UsedRingAddress:        queue.UsedRing().Address(),
-        AvailableRingAddress:   queue.AvailableRing().Address(),
-        LogAddress:             0,
-    })); err != nil {
-        return fmt.Errorf("set queue addresses: %w", err)
-    }
-
-    if err := IoctlPtr(controlFD, vhostIoctlSetAvailableRingBase, unsafe.Pointer(&QueueState{
-        QueueIndex: queueIndex,
-        Num:        0,
-    })); err != nil {
-        return fmt.Errorf("set available ring base: %w", err)
-    }
-
-    if err := IoctlPtr(controlFD, vhostIoctlSetQueueKickEventFD, unsafe.Pointer(&QueueFile{
-        QueueIndex: queueIndex,
-        FD:         int32(queue.KickEventFD()),
-    })); err != nil {
-        return fmt.Errorf("set kick event file descriptor: %w", err)
-    }
-
-    if err := IoctlPtr(controlFD, vhostIoctlSetQueueCallEventFD, unsafe.Pointer(&QueueFile{
-        QueueIndex: queueIndex,
-        FD:         int32(queue.CallEventFD()),
-    })); err != nil {
-        return fmt.Errorf("set call event file descriptor: %w", err)
-    }
-
-    return nil
-}
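Note: the order of these requests matters. A condensed usage sketch of the removed API, mirroring what the removed vhostnet.NewDevice (further below) does — take ownership of the control descriptor first, then negotiate features and describe memory before registering queues; queue indexes 0 and 1 follow the receive/transmit convention used there. Illustrative only, not part of the original files:

package example

import (
    "os"

    "github.com/slackhq/nebula/overlay/vhost"
    "github.com/slackhq/nebula/overlay/virtqueue"
    "github.com/slackhq/nebula/util/virtio"
    "golang.org/x/sys/unix"
)

// setupVhost compresses the ioctl sequence used by the removed vhostnet.NewDevice
// into one function: open control fd, own it, set features, map memory, register queues.
func setupVhost(rx, tx *virtqueue.SplitQueue) (int, error) {
    fd, err := unix.Open("/dev/vhost-net", os.O_RDWR, 0666)
    if err != nil {
        return -1, err
    }
    if err := vhost.OwnControlFD(fd); err != nil { // must happen before any other request
        return -1, err
    }
    if err := vhost.SetFeatures(fd, virtio.FeatureVersion1); err != nil { // advertise supported features
        return -1, err
    }
    layout := vhost.NewMemoryLayoutForQueues([]*virtqueue.SplitQueue{rx, tx})
    if err := vhost.SetMemoryLayout(fd, layout); err != nil { // describe the queues' buffer memory
        return -1, err
    }
    if err := vhost.RegisterQueue(fd, 0, rx); err != nil { // receive queue
        return -1, err
    }
    if err := vhost.RegisterQueue(fd, 1, tx); err != nil { // transmit queue
        return -1, err
    }
    return fd, nil
}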
@@ -1,21 +0,0 @@
-package vhost_test
-
-import (
-    "testing"
-    "unsafe"
-
-    "github.com/slackhq/nebula/overlay/vhost"
-    "github.com/stretchr/testify/assert"
-)
-
-func TestQueueState_Size(t *testing.T) {
-    assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueState{}))
-}
-
-func TestQueueAddresses_Size(t *testing.T) {
-    assert.EqualValues(t, 40, unsafe.Sizeof(vhost.QueueAddresses{}))
-}
-
-func TestQueueFile_Size(t *testing.T) {
-    assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueFile{}))
-}
@@ -1,73 +0,0 @@
-package vhost
-
-import (
-    "encoding/binary"
-    "fmt"
-    "unsafe"
-
-    "github.com/slackhq/nebula/overlay/virtqueue"
-)
-
-// MemoryRegion describes a region of userspace memory which is being made accessible to a vhost device.
-// Kernel name: vhost_memory_region
-type MemoryRegion struct {
-    // GuestPhysicalAddress is the physical address of the memory region within the guest, when virtualization
-    // is used. When no virtualization is used, this should be the same as UserspaceAddress.
-    GuestPhysicalAddress uintptr
-    // Size is the size of the memory region.
-    Size uint64
-    // UserspaceAddress is the virtual address in the userspace of the host where the memory region can be found.
-    UserspaceAddress uintptr
-    // Padding and room for flags. Currently unused.
-    _ uint64
-}
-
-// MemoryLayout is a list of [MemoryRegion]s.
-type MemoryLayout []MemoryRegion
-
-// NewMemoryLayoutForQueues returns a new [MemoryLayout] that describes the memory pages used by the
-// descriptor tables of the given queues.
-func NewMemoryLayoutForQueues(queues []*virtqueue.SplitQueue) MemoryLayout {
-    regions := make([]MemoryRegion, 0)
-    for _, queue := range queues {
-        for address, size := range queue.DescriptorTable().BufferAddresses() {
-            regions = append(regions, MemoryRegion{
-                // There is no virtualization in play here, so the guest address is the same as in the host's userspace.
-                GuestPhysicalAddress: address,
-                Size:                 uint64(size),
-                UserspaceAddress:     address,
-            })
-        }
-    }
-    return regions
-}
-
-// serializePayload serializes the list of memory regions into a format that is compatible to the vhost_memory
-// kernel struct. The returned byte slice can be used as a payload for the vhostIoctlSetMemoryLayout ioctl.
-func (regions MemoryLayout) serializePayload() []byte {
-    regionCount := len(regions)
-    regionSize := int(unsafe.Sizeof(MemoryRegion{}))
-    payload := make([]byte, 8+regionCount*regionSize)
-
-    // The first 32 bits contain the number of memory regions. The following 32 bits are padding.
-    binary.LittleEndian.PutUint32(payload[0:4], uint32(regionCount))
-
-    if regionCount > 0 {
-        // The underlying byte array of the slice should already have the correct format, so just copy that.
-        copied := copy(payload[8:], unsafe.Slice((*byte)(unsafe.Pointer(&regions[0])), regionCount*regionSize))
-        if copied != regionCount*regionSize {
-            panic(fmt.Sprintf("copied only %d bytes of the memory regions, but expected %d",
-                copied, regionCount*regionSize))
-        }
-    }
-
-    return payload
-}
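For reference, the payload produced by serializePayload is an 8-byte header (32-bit region count plus 32 bits of padding) followed by one 32-byte record per region; for the two regions used in the test that follows, that is 8 + 2×32 = 72 bytes. A small check of that arithmetic (illustrative only, not part of the original file):

package vhost

import (
    "fmt"
    "unsafe"
)

// examplePayloadSize prints the expected serialized size for two memory regions.
func examplePayloadSize() {
    regionSize := int(unsafe.Sizeof(MemoryRegion{})) // 32 bytes per region
    fmt.Println(8 + 2*regionSize)                    // prints 72
}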
@@ -1,42 +0,0 @@
-package vhost
-
-import (
-    "testing"
-    "unsafe"
-
-    "github.com/stretchr/testify/assert"
-)
-
-func TestMemoryRegion_Size(t *testing.T) {
-    assert.EqualValues(t, 32, unsafe.Sizeof(MemoryRegion{}))
-}
-
-func TestMemoryLayout_SerializePayload(t *testing.T) {
-    layout := MemoryLayout([]MemoryRegion{
-        {
-            GuestPhysicalAddress: 42,
-            Size:                 100,
-            UserspaceAddress:     142,
-        }, {
-            GuestPhysicalAddress: 99,
-            Size:                 100,
-            UserspaceAddress:     99,
-        },
-    })
-    payload := layout.serializePayload()
-
-    assert.Equal(t, []byte{
-        0x02, 0x00, 0x00, 0x00, // nregions
-        0x00, 0x00, 0x00, 0x00, // padding
-        // region 0
-        0x2a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr
-        0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size
-        0x8e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr
-        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding
-        // region 1
-        0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr
-        0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size
-        0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr
-        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding
-    }, payload)
-}
@@ -1,23 +0,0 @@
-Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
-
-MIT License
-
-Copyright (c) 2025 Hetzner Cloud GmbH
@@ -1,427 +0,0 @@
|
|||||||
package vhostnet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
"github.com/slackhq/nebula/overlay/vhost"
|
|
||||||
"github.com/slackhq/nebula/overlay/virtqueue"
|
|
||||||
"github.com/slackhq/nebula/packet"
|
|
||||||
"github.com/slackhq/nebula/util/virtio"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ErrDeviceClosed is returned when the [Device] is closed while operations are
|
|
||||||
// still running.
|
|
||||||
var ErrDeviceClosed = errors.New("device was closed")
|
|
||||||
|
|
||||||
// The indexes for the receive and transmit queues.
|
|
||||||
const (
|
|
||||||
receiveQueueIndex = 0
|
|
||||||
transmitQueueIndex = 1
|
|
||||||
)
|
|
||||||
|
|
||||||
// Device represents a vhost networking device within the kernel-level virtio
|
|
||||||
// implementation and provides methods to interact with it.
|
|
||||||
type Device struct {
|
|
||||||
initialized bool
|
|
||||||
controlFD int
|
|
||||||
|
|
||||||
fullTable bool
|
|
||||||
ReceiveQueue *virtqueue.SplitQueue
|
|
||||||
TransmitQueue *virtqueue.SplitQueue
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDevice initializes a new vhost networking device within the
|
|
||||||
// kernel-level virtio implementation, sets up the virtqueues and returns a
|
|
||||||
// [Device] instance that can be used to communicate with that vhost device.
|
|
||||||
//
|
|
||||||
// There are multiple options that can be passed to this constructor to
|
|
||||||
// influence device creation:
|
|
||||||
// - [WithQueueSize]
|
|
||||||
// - [WithBackendFD]
|
|
||||||
// - [WithBackendDevice]
|
|
||||||
//
|
|
||||||
// Remember to call [Device.Close] after use to free up resources.
|
|
||||||
func NewDevice(options ...Option) (*Device, error) {
|
|
||||||
var err error
|
|
||||||
opts := optionDefaults
|
|
||||||
opts.apply(options)
|
|
||||||
if err = opts.validate(); err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid options: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
dev := Device{
|
|
||||||
controlFD: -1,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up a partially initialized device when something fails.
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
_ = dev.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Retrieve a new control file descriptor. This will be used to configure
|
|
||||||
// the vhost networking device in the kernel.
|
|
||||||
dev.controlFD, err = unix.Open("/dev/vhost-net", os.O_RDWR, 0666)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get control file descriptor: %w", err)
|
|
||||||
}
|
|
||||||
if err = vhost.OwnControlFD(dev.controlFD); err != nil {
|
|
||||||
return nil, fmt.Errorf("own control file descriptor: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Advertise the supported features. This isn't much for now.
|
|
||||||
// TODO: Add feature options and implement proper feature negotiation.
|
|
||||||
getFeatures, err := vhost.GetFeatures(dev.controlFD) //0x1033D008000 but why
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get features: %w", err)
|
|
||||||
}
|
|
||||||
if getFeatures == 0 {
|
|
||||||
|
|
||||||
}
|
|
||||||
//const funky = virtio.Feature(1 << 27)
|
|
||||||
//features := virtio.FeatureVersion1 | funky // | todo virtio.FeatureNetMergeRXBuffers
|
|
||||||
features := virtio.FeatureVersion1 | virtio.FeatureNetMergeRXBuffers
|
|
||||||
if err = vhost.SetFeatures(dev.controlFD, features); err != nil {
|
|
||||||
return nil, fmt.Errorf("set features: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
itemSize := os.Getpagesize() * 4 //todo config
|
|
||||||
|
|
||||||
// Initialize and register the queues needed for the networking device.
|
|
||||||
if dev.ReceiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize, itemSize); err != nil {
|
|
||||||
return nil, fmt.Errorf("create receive queue: %w", err)
|
|
||||||
}
|
|
||||||
if dev.TransmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize, itemSize); err != nil {
|
|
||||||
return nil, fmt.Errorf("create transmit queue: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up memory mappings for all buffers used by the queues. This has to
|
|
||||||
// happen before a backend for the queues can be registered.
|
|
||||||
memoryLayout := vhost.NewMemoryLayoutForQueues(
|
|
||||||
[]*virtqueue.SplitQueue{dev.ReceiveQueue, dev.TransmitQueue},
|
|
||||||
)
|
|
||||||
if err = vhost.SetMemoryLayout(dev.controlFD, memoryLayout); err != nil {
|
|
||||||
return nil, fmt.Errorf("setup memory layout: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the queue backends. This activates the queues within the kernel.
|
|
||||||
if err = SetQueueBackend(dev.controlFD, receiveQueueIndex, opts.backendFD); err != nil {
|
|
||||||
return nil, fmt.Errorf("set receive queue backend: %w", err)
|
|
||||||
}
|
|
||||||
if err = SetQueueBackend(dev.controlFD, transmitQueueIndex, opts.backendFD); err != nil {
|
|
||||||
return nil, fmt.Errorf("set transmit queue backend: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fully populate the receive queue with available buffers which the device
|
|
||||||
// can write new packets into.
|
|
||||||
if err = dev.refillReceiveQueue(); err != nil {
|
|
||||||
return nil, fmt.Errorf("refill receive queue: %w", err)
|
|
||||||
}
|
|
||||||
if err = dev.refillTransmitQueue(); err != nil {
|
|
||||||
return nil, fmt.Errorf("refill receive queue: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
dev.initialized = true
|
|
||||||
|
|
||||||
// Make sure to clean up even when the device gets garbage collected without
|
|
||||||
// Close being called first.
|
|
||||||
devPtr := &dev
|
|
||||||
runtime.SetFinalizer(devPtr, (*Device).Close)
|
|
||||||
|
|
||||||
return devPtr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// refillReceiveQueue offers as many new device-writable buffers to the device
|
|
||||||
// as the queue can fit. The device will then use these to write received
|
|
||||||
// packets.
|
|
||||||
func (dev *Device) refillReceiveQueue() error {
|
|
||||||
for {
|
|
||||||
_, err := dev.ReceiveQueue.OfferInDescriptorChains()
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
|
|
||||||
// Queue is full, job is done.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("offer descriptor chain: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dev *Device) refillTransmitQueue() error {
|
|
||||||
//for {
|
|
||||||
// desc, err := dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
|
|
||||||
// if err != nil {
|
|
||||||
// if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
|
|
||||||
// // Queue is full, job is done.
|
|
||||||
// return nil
|
|
||||||
// }
|
|
||||||
// return fmt.Errorf("offer descriptor chain: %w", err)
|
|
||||||
// } else {
|
|
||||||
// dev.TransmitQueue.UsedRing().InitOfferSingle(desc, 0)
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close cleans up the vhost networking device within the kernel and releases
|
|
||||||
// all resources used for it.
|
|
||||||
// The implementation will try to release as many resources as possible and
|
|
||||||
// collect potential errors before returning them.
|
|
||||||
func (dev *Device) Close() error {
|
|
||||||
dev.initialized = false
|
|
||||||
|
|
||||||
// Closing the control file descriptor will unregister all queues from the
|
|
||||||
// kernel.
|
|
||||||
if dev.controlFD >= 0 {
|
|
||||||
if err := unix.Close(dev.controlFD); err != nil {
|
|
||||||
// Return an error and do not continue, because the memory used for
|
|
||||||
// the queues should not be released before they were unregistered
|
|
||||||
// from the kernel.
|
|
||||||
return fmt.Errorf("close control file descriptor: %w", err)
|
|
||||||
}
|
|
||||||
dev.controlFD = -1
|
|
||||||
}
|
|
||||||
|
|
||||||
var errs []error
|
|
||||||
|
|
||||||
if dev.ReceiveQueue != nil {
|
|
||||||
if err := dev.ReceiveQueue.Close(); err == nil {
|
|
||||||
dev.ReceiveQueue = nil
|
|
||||||
} else {
|
|
||||||
errs = append(errs, fmt.Errorf("close receive queue: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if dev.TransmitQueue != nil {
|
|
||||||
if err := dev.TransmitQueue.Close(); err == nil {
|
|
||||||
dev.TransmitQueue = nil
|
|
||||||
} else {
|
|
||||||
errs = append(errs, fmt.Errorf("close transmit queue: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(errs) == 0 {
|
|
||||||
// Everything was cleaned up. No need to run the finalizer anymore.
|
|
||||||
runtime.SetFinalizer(dev, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
return errors.Join(errs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ensureInitialized is used as a guard to prevent methods to be called on an
|
|
||||||
// uninitialized instance.
|
|
||||||
func (dev *Device) ensureInitialized() {
|
|
||||||
if !dev.initialized {
|
|
||||||
panic("device is not initialized")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// createQueue creates a new virtqueue and registers it with the vhost device
|
|
||||||
// using the given index.
|
|
||||||
func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {
|
|
||||||
var (
|
|
||||||
queue *virtqueue.SplitQueue
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if queue, err = virtqueue.NewSplitQueue(queueSize, itemSize); err != nil {
|
|
||||||
return nil, fmt.Errorf("create virtqueue: %w", err)
|
|
||||||
}
|
|
||||||
if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {
|
|
||||||
return nil, fmt.Errorf("register virtqueue with index %d: %w", queueIndex, err)
|
|
||||||
}
|
|
||||||
return queue, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// truncateBuffers returns a new list of buffers whose combined length matches
|
|
||||||
// exactly the specified length. When the specified length exceeds the combined
// length of the buffers, the function panics. When it is smaller, the buffer
// list will be truncated accordingly.
|
|
||||||
func truncateBuffers(buffers [][]byte, length int) (out [][]byte) {
|
|
||||||
for _, buffer := range buffers {
|
|
||||||
if length < len(buffer) {
|
|
||||||
out = append(out, buffer[:length])
|
|
||||||
return
|
|
||||||
}
|
|
||||||
out = append(out, buffer)
|
|
||||||
length -= len(buffer)
|
|
||||||
}
|
|
||||||
if length > 0 {
|
|
||||||
panic("length exceeds the combined length of all buffers")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
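// A worked example of the truncation described above (a sketch, not part of
// the original code): truncating two 4-byte buffers to a combined length of 6
// keeps the first buffer whole and the first two bytes of the second.
func truncateBuffersExample() [][]byte {
	buffers := [][]byte{
		{1, 2, 3, 4},
		{5, 6, 7, 8},
	}
	// Returns [][]byte{{1, 2, 3, 4}, {5, 6}}.
	return truncateBuffers(buffers, 6)
}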
|
|
||||||
|
|
||||||
func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
|
|
||||||
var err error
|
|
||||||
var idx uint16
|
|
||||||
if !dev.fullTable {
|
|
||||||
|
|
||||||
idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
|
|
||||||
		if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
|
|
||||||
dev.fullTable = true
|
|
||||||
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, fmt.Errorf("transmit queue: %w", err)
|
|
||||||
}
|
|
||||||
buf, err := dev.TransmitQueue.GetDescriptorItem(idx)
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, fmt.Errorf("get descriptor chain: %w", err)
|
|
||||||
}
|
|
||||||
return idx, buf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dev *Device) TransmitPacket(pkt *packet.OutPacket, kick bool) error {
|
|
||||||
if len(pkt.SegmentIDs) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
for idx := range pkt.SegmentIDs {
|
|
||||||
segmentID := pkt.SegmentIDs[idx]
|
|
||||||
dev.TransmitQueue.SetDescSize(segmentID, len(pkt.Segments[idx]))
|
|
||||||
}
|
|
||||||
err := dev.TransmitQueue.OfferDescriptorChains(pkt.SegmentIDs, false)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("offer descriptor chains: %w", err)
|
|
||||||
}
|
|
||||||
pkt.Reset()
|
|
||||||
if kick {
|
|
||||||
if err := dev.TransmitQueue.Kick(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
|
|
||||||
if len(pkts) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range pkts {
|
|
||||||
if err := dev.TransmitPacket(pkts[i], false); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := dev.TransmitQueue.Kick(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
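// Minimal send-loop sketch for the batched path above. It assumes that
// packet.OutPacket's SegmentIDs and Segments are plain slices, that len(pkts)
// covers len(frames) with pre-allocated entries, and that virtio-net header
// encoding happens elsewhere; none of that is confirmed by this code, so treat
// it purely as an illustration.
func sendLoopSketch(dev *Device, frames [][]byte, pkts []*packet.OutPacket) error {
	for i, frame := range frames {
		id, buf, err := dev.GetPacketForTx()
		if err != nil {
			return err
		}
		// The virtio-net header occupies the start of the descriptor buffer;
		// the frame payload is copied in right after it.
		n := copy(buf[virtio.NetHdrSize:], frame)
		pkts[i].SegmentIDs = append(pkts[i].SegmentIDs, id)
		pkts[i].Segments = append(pkts[i].Segments, buf[:virtio.NetHdrSize+n])
	}
	// A single kick notifies the kernel about the whole batch.
	return dev.TransmitPackets(pkts)
}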
|
|
||||||
|
|
||||||
// TODO: Make above methods cancelable by taking a context.Context argument?
|
|
||||||
// TODO: Implement zero-copy variants to transmit and receive packets?
|
|
||||||
|
|
||||||
// processChains processes as many chains as needed to create one packet. The number of processed chains is returned.
|
|
||||||
func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
|
|
||||||
	// Read the buffers of the first chain so the virtio-net header can be decoded.
|
|
||||||
pkt.Reset()
|
|
||||||
|
|
||||||
err := dev.ReceiveQueue.GetDescriptorInbuffers(uint16(chains[0].DescriptorIndex), &pkt.ChainRefs)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("get descriptor chain: %w", err)
|
|
||||||
}
|
|
||||||
if len(pkt.ChainRefs) == 0 {
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// The specification requires that the first descriptor chain starts
|
|
||||||
// with a virtio-net header. It is not clear, whether it is also
|
|
||||||
// required to be fully contained in the first buffer of that
|
|
||||||
// descriptor chain, but it is reasonable to assume that this is
|
|
||||||
// always the case.
|
|
||||||
// The decode method already does the buffer length check.
|
|
||||||
if err = pkt.Header.Decode(pkt.ChainRefs[0][0:]); err != nil {
|
|
||||||
// The device misbehaved. There is no way we can gracefully
|
|
||||||
// recover from this, because we don't know how many of the
|
|
||||||
// following descriptor chains belong to this packet.
|
|
||||||
return 0, fmt.Errorf("decode vnethdr: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
	// The header is decoded; validate it before touching the payload.
|
|
||||||
if int(pkt.Header.NumBuffers) > len(chains) {
|
|
||||||
return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains))
|
|
||||||
}
|
|
||||||
if int(pkt.Header.NumBuffers) != 1 {
|
|
||||||
return 0, fmt.Errorf("too smol-brain to handle more than one chain right now: %d chains", len(chains))
|
|
||||||
}
|
|
||||||
if chains[0].Length > 16000 {
|
|
||||||
		// TODO: handle oversized packets instead of failing here.
		return 1, fmt.Errorf("packet length %d exceeds the supported maximum", chains[0].Length)
|
|
||||||
}
|
|
||||||
|
|
||||||
	// Slice off the virtio-net header to expose the packet payload.
|
|
||||||
pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
|
|
||||||
pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
|
|
||||||
return 1, nil
|
|
||||||
|
|
||||||
//cursor := n - virtio.NetHdrSize
|
|
||||||
//
|
|
||||||
//if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
|
|
||||||
// pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
|
|
||||||
// return 1, nil
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//i := 1
|
|
||||||
//// we used chain 0 already
|
|
||||||
//for i = 1; i < len(chains); i++ {
|
|
||||||
// n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length))
|
|
||||||
// if err != nil {
|
|
||||||
// // When this fails we may miss to free some descriptor chains. We
|
|
||||||
// // could try to mitigate this by deferring the freeing somehow, but
|
|
||||||
// // it's not worth the hassle. When this method fails, the queue will
|
|
||||||
// // be in a broken state anyway.
|
|
||||||
// return i, fmt.Errorf("get descriptor chain: %w", err)
|
|
||||||
// }
|
|
||||||
// cursor += n
|
|
||||||
//}
|
|
||||||
////todo this has to be wrong
|
|
||||||
//pkt.Payload = pkt.Payload[:cursor]
|
|
||||||
//return i, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
|
|
||||||
//todo optimize?
|
|
||||||
var chains []virtqueue.UsedElement
|
|
||||||
var err error
|
|
||||||
//if len(dev.extraRx) == 0 {
|
|
||||||
chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), len(out))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if len(chains) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
//} else {
|
|
||||||
// chains = dev.extraRx
|
|
||||||
//}
|
|
||||||
|
|
||||||
numPackets := 0
|
|
||||||
chainsIdx := 0
|
|
||||||
for numPackets = 0; chainsIdx < len(chains); numPackets++ {
|
|
||||||
if numPackets >= len(out) {
|
|
||||||
return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets)
|
|
||||||
}
|
|
||||||
numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
chainsIdx += numChains
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that we have copied all buffers, we can recycle the used descriptor chains
|
|
||||||
//if err = dev.ReceiveQueue.OfferDescriptorChains(chains); err != nil {
|
|
||||||
// return 0, err
|
|
||||||
//}
|
|
||||||
|
|
||||||
return numPackets, nil
|
|
||||||
}
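// Minimal receive-loop sketch built on ReceivePackets above (not part of the
// original code). It assumes the zero value of packet.VirtIOPacket is ready to
// use and hands each payload to a caller-supplied handle function; recycling of
// used descriptor chains is still left to the commented-out code above.
func receiveLoopSketch(dev *Device, handle func([]byte)) error {
	batch := make([]*packet.VirtIOPacket, 64)
	for i := range batch {
		batch[i] = new(packet.VirtIOPacket)
	}
	for {
		n, err := dev.ReceivePackets(batch)
		if err != nil {
			return err
		}
		for i := 0; i < n; i++ {
			handle(batch[i].Payload)
		}
	}
}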
|
|
||||||
@@ -1,3 +0,0 @@
// Package vhostnet implements methods to initialize vhost networking devices
|
|
||||||
// within the kernel-level virtio implementation and communicate with them.
|
|
||||||
package vhostnet
|
|
||||||
@@ -1,31 +0,0 @@
package vhostnet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/slackhq/nebula/overlay/vhost"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// vhostNetIoctlSetBackend can be used to attach a virtqueue to a RAW socket
|
|
||||||
// or TAP device.
|
|
||||||
//
|
|
||||||
// Request payload: [vhost.QueueFile]
|
|
||||||
// Kernel name: VHOST_NET_SET_BACKEND
|
|
||||||
vhostNetIoctlSetBackend = 0x4008af30
|
|
||||||
)
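// Decoding sketch for the constant above, using the standard Linux _IOW ioctl
// encoding (not stated in this file): 0x4008af30 = direction 0x40 (write) |
// payload size 0x0008 (vhost.QueueFile: two 32-bit fields) | magic 0xaf
// (VHOST) | command number 0x30.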
|
|
||||||
|
|
||||||
// SetQueueBackend attaches a virtqueue of the vhost networking device
|
|
||||||
// described by controlFD to the given backend file descriptor.
|
|
||||||
// The backend file descriptor can either be a RAW socket or a TAP device. When
|
|
||||||
// it is -1, the queue will be detached.
|
|
||||||
func SetQueueBackend(controlFD int, queueIndex uint32, backendFD int) error {
|
|
||||||
if err := vhost.IoctlPtr(controlFD, vhostNetIoctlSetBackend, unsafe.Pointer(&vhost.QueueFile{
|
|
||||||
QueueIndex: queueIndex,
|
|
||||||
FD: int32(backendFD),
|
|
||||||
})); err != nil {
|
|
||||||
return fmt.Errorf("set queue backend file descriptor: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
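// Usage sketch grounded in the comment above: passing a backend FD of -1
// detaches a queue again, shown here for queue index 0 (conventionally the
// receive queue).
func detachQueueSketch(controlFD int) error {
	return SetQueueBackend(controlFD, 0, -1)
}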
|
|
||||||
@@ -1,69 +0,0 @@
package vhostnet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/slackhq/nebula/overlay/virtqueue"
|
|
||||||
)
|
|
||||||
|
|
||||||
type optionValues struct {
|
|
||||||
queueSize int
|
|
||||||
backendFD int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *optionValues) apply(options []Option) {
|
|
||||||
for _, option := range options {
|
|
||||||
option(o)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *optionValues) validate() error {
|
|
||||||
if o.queueSize == -1 {
|
|
||||||
return errors.New("queue size is required")
|
|
||||||
}
|
|
||||||
if err := virtqueue.CheckQueueSize(o.queueSize); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if o.backendFD == -1 {
|
|
||||||
return errors.New("backend file descriptor is required")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var optionDefaults = optionValues{
|
|
||||||
// Required.
|
|
||||||
queueSize: -1,
|
|
||||||
// Required.
|
|
||||||
backendFD: -1,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Option can be passed to [NewDevice] to influence device creation.
|
|
||||||
type Option func(*optionValues)
|
|
||||||
|
|
||||||
// WithQueueSize returns an [Option] that sets the size of the TX and RX queues
|
|
||||||
// that are to be created for the device. It specifies the number of
|
|
||||||
// entries/buffers each queue can hold. This also affects the memory
|
|
||||||
// consumption.
|
|
||||||
// This is required and must be an integer from 1 to 32768 that is also a power
|
|
||||||
// of 2.
|
|
||||||
func WithQueueSize(queueSize int) Option {
|
|
||||||
return func(o *optionValues) { o.queueSize = queueSize }
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithBackendFD returns an [Option] that sets the file descriptor of the
|
|
||||||
// backend that will be used for the queues of the device. The device will write
|
|
||||||
// and read packets to/from that backend. The file descriptor can either be of a
|
|
||||||
// RAW socket or TUN/TAP device.
|
|
||||||
// Either this or [WithBackendDevice] is required.
|
|
||||||
func WithBackendFD(backendFD int) Option {
|
|
||||||
return func(o *optionValues) { o.backendFD = backendFD }
|
|
||||||
}
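// Usage sketch: these options are consumed by the NewDevice constructor that
// the comments above reference. Its exact signature is not part of this code,
// so the call below is illustrative only.
//
//	dev, err := NewDevice(
//		WithQueueSize(256),   // power of 2 between 1 and 32768
//		WithBackendFD(tapFD), // RAW socket or TUN/TAP file descriptor
//	)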
|
|
||||||
|
|
||||||
//// WithBackendDevice returns an [Option] that sets the given TAP device as the
|
|
||||||
//// backend that will be used for the queues of the device. The device will
|
|
||||||
//// write and read packets to/from that backend. The TAP device should have been
|
|
||||||
//// created with the [tuntap.WithVirtioNetHdr] option enabled.
|
|
||||||
//// Either this or [WithBackendFD] is required.
|
|
||||||
//func WithBackendDevice(dev *tuntap.Device) Option {
|
|
||||||
// return func(o *optionValues) { o.backendFD = int(dev.File().Fd()) }
|
|
||||||
//}
|
|
||||||
@@ -1,23 +0,0 @@
Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
|
|
||||||
|
|
||||||
MIT License
|
|
||||||
|
|
||||||
Copyright (c) 2025 Hetzner Cloud GmbH
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
||||||
@@ -1,140 +0,0 @@
package virtqueue
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
// availableRingFlag is a flag that describes an [AvailableRing].
|
|
||||||
type availableRingFlag uint16
|
|
||||||
|
|
||||||
const (
|
|
||||||
// availableRingFlagNoInterrupt is used by the guest to advise the host to
|
|
||||||
// not interrupt it when consuming a buffer. It's unreliable, so it's simply
|
|
||||||
// an optimization.
|
|
||||||
availableRingFlagNoInterrupt availableRingFlag = 1 << iota
|
|
||||||
)
|
|
||||||
|
|
||||||
// availableRingSize is the number of bytes needed to store an [AvailableRing]
|
|
||||||
// with the given queue size in memory.
|
|
||||||
func availableRingSize(queueSize int) int {
|
|
||||||
return 6 + 2*queueSize
|
|
||||||
}
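// Worked example of the formula above: flags (2 bytes) + ring index (2) +
// queueSize ring entries (2 each) + used-event field (2). For a queue size of
// 256 that is 6 + 2*256 = 518 bytes.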
|
|
||||||
|
|
||||||
// availableRingAlignment is the minimum alignment of an [AvailableRing]
|
|
||||||
// in memory, as required by the virtio spec.
|
|
||||||
const availableRingAlignment = 2
|
|
||||||
|
|
||||||
// AvailableRing is used by the driver to offer descriptor chains to the device.
|
|
||||||
// Each ring entry refers to the head of a descriptor chain. It is only written
|
|
||||||
// to by the driver and read by the device.
|
|
||||||
//
|
|
||||||
// Because the size of the ring depends on the queue size, we cannot define a
|
|
||||||
// Go struct with a static size that maps to the memory of the ring. Instead,
|
|
||||||
// this struct only contains pointers to the corresponding memory areas.
|
|
||||||
type AvailableRing struct {
|
|
||||||
initialized bool
|
|
||||||
|
|
||||||
// flags that describe this ring.
|
|
||||||
flags *availableRingFlag
|
|
||||||
// ringIndex indicates where the driver would put the next entry into the
|
|
||||||
// ring (modulo the queue size).
|
|
||||||
ringIndex *uint16
|
|
||||||
// ring references buffers using the index of the head of the descriptor
|
|
||||||
// chain in the [DescriptorTable]. It wraps around at queue size.
|
|
||||||
ring []uint16
|
|
||||||
// usedEvent is not used by this implementation, but we reserve it anyway to
|
|
||||||
// avoid issues in case a device may try to access it, contrary to the
|
|
||||||
// virtio specification.
|
|
||||||
usedEvent *uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
// newAvailableRing creates an available ring that uses the given underlying
|
|
||||||
// memory. The length of the memory slice must match the size needed for the
|
|
||||||
// ring (see [availableRingSize]) for the given queue size.
|
|
||||||
func newAvailableRing(queueSize int, mem []byte) *AvailableRing {
|
|
||||||
ringSize := availableRingSize(queueSize)
|
|
||||||
if len(mem) != ringSize {
|
|
||||||
panic(fmt.Sprintf("memory size (%v) does not match required size "+
|
|
||||||
"for available ring: %v", len(mem), ringSize))
|
|
||||||
}
|
|
||||||
|
|
||||||
return &AvailableRing{
|
|
||||||
initialized: true,
|
|
||||||
flags: (*availableRingFlag)(unsafe.Pointer(&mem[0])),
|
|
||||||
ringIndex: (*uint16)(unsafe.Pointer(&mem[2])),
|
|
||||||
ring: unsafe.Slice((*uint16)(unsafe.Pointer(&mem[4])), queueSize),
|
|
||||||
usedEvent: (*uint16)(unsafe.Pointer(&mem[ringSize-2])),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Address returns the pointer to the beginning of the ring in memory.
|
|
||||||
// Do not modify the memory directly to not interfere with this implementation.
|
|
||||||
func (r *AvailableRing) Address() uintptr {
|
|
||||||
if !r.initialized {
|
|
||||||
panic("available ring is not initialized")
|
|
||||||
}
|
|
||||||
return uintptr(unsafe.Pointer(r.flags))
|
|
||||||
}
|
|
||||||
|
|
||||||
// offer adds the given descriptor chain heads to the available ring and
|
|
||||||
// advances the ring index accordingly to make the device process the new
|
|
||||||
// descriptor chains.
|
|
||||||
func (r *AvailableRing) offerElements(chains []UsedElement) {
|
|
||||||
//always called under lock
|
|
||||||
//r.mu.Lock()
|
|
||||||
//defer r.mu.Unlock()
|
|
||||||
|
|
||||||
// Add descriptor chain heads to the ring.
|
|
||||||
for offset, x := range chains {
|
|
||||||
// The 16-bit ring index may overflow. This is expected and is not an
|
|
||||||
// issue because the size of the ring array (which equals the queue
|
|
||||||
// size) is always a power of 2 and smaller than the highest possible
|
|
||||||
// 16-bit value.
|
|
||||||
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
|
|
||||||
r.ring[insertIndex] = x.GetHead()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Increase the ring index by the number of descriptor chains added to the
|
|
||||||
// ring.
|
|
||||||
*r.ringIndex += uint16(len(chains))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AvailableRing) offer(chains []uint16) {
|
|
||||||
//always called under lock
|
|
||||||
//r.mu.Lock()
|
|
||||||
//defer r.mu.Unlock()
|
|
||||||
|
|
||||||
// Add descriptor chain heads to the ring.
|
|
||||||
for offset, x := range chains {
|
|
||||||
// The 16-bit ring index may overflow. This is expected and is not an
|
|
||||||
// issue because the size of the ring array (which equals the queue
|
|
||||||
// size) is always a power of 2 and smaller than the highest possible
|
|
||||||
// 16-bit value.
|
|
||||||
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
|
|
||||||
r.ring[insertIndex] = x
|
|
||||||
}
|
|
||||||
|
|
||||||
// Increase the ring index by the number of descriptor chains added to the
|
|
||||||
// ring.
|
|
||||||
*r.ringIndex += uint16(len(chains))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AvailableRing) offerSingle(x uint16) {
|
|
||||||
//always called under lock
|
|
||||||
//r.mu.Lock()
|
|
||||||
//defer r.mu.Unlock()
|
|
||||||
|
|
||||||
offset := 0
|
|
||||||
// Add descriptor chain heads to the ring.
|
|
||||||
|
|
||||||
// The 16-bit ring index may overflow. This is expected and is not an
|
|
||||||
// issue because the size of the ring array (which equals the queue
|
|
||||||
// size) is always a power of 2 and smaller than the highest possible
|
|
||||||
// 16-bit value.
|
|
||||||
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
|
|
||||||
r.ring[insertIndex] = x
|
|
||||||
|
|
||||||
// Increase the ring index by the number of descriptor chains added to the ring.
|
|
||||||
*r.ringIndex += 1
|
|
||||||
}
|
|
||||||
@@ -1,71 +0,0 @@
package virtqueue
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAvailableRing_MemoryLayout(t *testing.T) {
|
|
||||||
const queueSize = 2
|
|
||||||
|
|
||||||
memory := make([]byte, availableRingSize(queueSize))
|
|
||||||
r := newAvailableRing(queueSize, memory)
|
|
||||||
|
|
||||||
*r.flags = 0x01ff
|
|
||||||
*r.ringIndex = 1
|
|
||||||
r.ring[0] = 0x1234
|
|
||||||
r.ring[1] = 0x5678
|
|
||||||
|
|
||||||
assert.Equal(t, []byte{
|
|
||||||
0xff, 0x01,
|
|
||||||
0x01, 0x00,
|
|
||||||
0x34, 0x12,
|
|
||||||
0x78, 0x56,
|
|
||||||
0x00, 0x00,
|
|
||||||
}, memory)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAvailableRing_Offer(t *testing.T) {
|
|
||||||
const queueSize = 8
|
|
||||||
|
|
||||||
chainHeads := []uint16{42, 33, 69}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
startRingIndex uint16
|
|
||||||
expectedRingIndex uint16
|
|
||||||
expectedRing []uint16
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "no overflow",
|
|
||||||
startRingIndex: 0,
|
|
||||||
expectedRingIndex: 3,
|
|
||||||
expectedRing: []uint16{42, 33, 69, 0, 0, 0, 0, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ring overflow",
|
|
||||||
startRingIndex: 6,
|
|
||||||
expectedRingIndex: 9,
|
|
||||||
expectedRing: []uint16{69, 0, 0, 0, 0, 0, 42, 33},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "index overflow",
|
|
||||||
startRingIndex: 65535,
|
|
||||||
expectedRingIndex: 2,
|
|
||||||
expectedRing: []uint16{33, 69, 0, 0, 0, 0, 0, 42},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
memory := make([]byte, availableRingSize(queueSize))
|
|
||||||
r := newAvailableRing(queueSize, memory)
|
|
||||||
*r.ringIndex = tt.startRingIndex
|
|
||||||
|
|
||||||
r.offer(chainHeads)
|
|
||||||
|
|
||||||
assert.Equal(t, tt.expectedRingIndex, *r.ringIndex)
|
|
||||||
assert.Equal(t, tt.expectedRing, r.ring)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,43 +0,0 @@
package virtqueue
|
|
||||||
|
|
||||||
// descriptorFlag is a flag that describes a [Descriptor].
|
|
||||||
type descriptorFlag uint16
|
|
||||||
|
|
||||||
const (
|
|
||||||
// descriptorFlagHasNext marks a descriptor chain as continuing via the next
|
|
||||||
// field.
|
|
||||||
descriptorFlagHasNext descriptorFlag = 1 << iota
|
|
||||||
// descriptorFlagWritable marks a buffer as device write-only (otherwise
|
|
||||||
// device read-only).
|
|
||||||
descriptorFlagWritable
|
|
||||||
// descriptorFlagIndirect means the buffer contains a list of buffer
|
|
||||||
// descriptors to provide an additional layer of indirection.
|
|
||||||
// Only allowed when the [virtio.FeatureIndirectDescriptors] feature was
|
|
||||||
// negotiated.
|
|
||||||
descriptorFlagIndirect
|
|
||||||
)
|
|
||||||
|
|
||||||
// descriptorSize is the number of bytes needed to store a [Descriptor] in
|
|
||||||
// memory.
|
|
||||||
const descriptorSize = 16
|
|
||||||
|
|
||||||
// Descriptor describes (a part of) a buffer which is either read-only for the
|
|
||||||
// device or write-only for the device (depending on [descriptorFlagWritable]).
|
|
||||||
// Multiple descriptors can be chained to produce a "descriptor chain" that can
|
|
||||||
// contain both device-readable and device-writable buffers. Device-readable
|
|
||||||
// descriptors always come first in a chain. A single, large buffer may be
|
|
||||||
// split up by chaining multiple similar descriptors that reference different
|
|
||||||
// memory pages. This is required because buffers may exceed a single page size
// and the memory accessed by the device is expected to be contiguous.
|
|
||||||
type Descriptor struct {
|
|
||||||
	// address is the address of the contiguous memory holding the data for this
|
|
||||||
// descriptor.
|
|
||||||
address uintptr
|
|
||||||
	// length is the number of bytes stored at address.
|
|
||||||
length uint32
|
|
||||||
// flags that describe this descriptor.
|
|
||||||
flags descriptorFlag
|
|
||||||
// next contains the index of the next descriptor continuing this descriptor
|
|
||||||
// chain when the [descriptorFlagHasNext] flag is set.
|
|
||||||
next uint16
|
|
||||||
}
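// Size sketch for the struct above on a 64-bit platform (see the
// TestDescriptor_Size check below): 8 (address) + 4 (length) + 2 (flags) +
// 2 (next) = 16 bytes = descriptorSize.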
|
|
||||||
@@ -1,12 +0,0 @@
package virtqueue
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDescriptor_Size(t *testing.T) {
|
|
||||||
assert.EqualValues(t, descriptorSize, unsafe.Sizeof(Descriptor{}))
|
|
||||||
}
|
|
||||||
@@ -1,641 +0,0 @@
package virtqueue
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// ErrDescriptorChainEmpty is returned when a descriptor chain would contain
|
|
||||||
// no buffers, which is not allowed.
|
|
||||||
ErrDescriptorChainEmpty = errors.New("empty descriptor chains are not allowed")
|
|
||||||
|
|
||||||
// ErrNotEnoughFreeDescriptors is returned when the free descriptors are
|
|
||||||
// exhausted, meaning that the queue is full.
|
|
||||||
ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors, queue is full")
|
|
||||||
|
|
||||||
// ErrInvalidDescriptorChain is returned when a descriptor chain is not
|
|
||||||
// valid for a given operation.
|
|
||||||
ErrInvalidDescriptorChain = errors.New("invalid descriptor chain")
|
|
||||||
)
|
|
||||||
|
|
||||||
// noFreeHead is used to mark when all descriptors are in use and we have no
|
|
||||||
// free chain. This value is impossible to occur as an index naturally, because
|
|
||||||
// it exceeds the maximum queue size.
|
|
||||||
const noFreeHead = uint16(math.MaxUint16)
|
|
||||||
|
|
||||||
// descriptorTableSize is the number of bytes needed to store a
|
|
||||||
// [DescriptorTable] with the given queue size in memory.
|
|
||||||
func descriptorTableSize(queueSize int) int {
|
|
||||||
return descriptorSize * queueSize
|
|
||||||
}
|
|
||||||
|
|
||||||
// descriptorTableAlignment is the minimum alignment of a [DescriptorTable]
|
|
||||||
// in memory, as required by the virtio spec.
|
|
||||||
const descriptorTableAlignment = 16
|
|
||||||
|
|
||||||
// DescriptorTable is a table that holds [Descriptor]s, addressed via their
|
|
||||||
// index in the slice.
|
|
||||||
type DescriptorTable struct {
|
|
||||||
descriptors []Descriptor
|
|
||||||
|
|
||||||
// freeHeadIndex is the index of the head of the descriptor chain which
|
|
||||||
// contains all currently unused descriptors. When all descriptors are in
|
|
||||||
// use, this has the special value of noFreeHead.
|
|
||||||
freeHeadIndex uint16
|
|
||||||
// freeNum tracks the number of descriptors which are currently not in use.
|
|
||||||
freeNum uint16
|
|
||||||
|
|
||||||
bufferBase uintptr
|
|
||||||
bufferSize int
|
|
||||||
itemSize int
|
|
||||||
}
|
|
||||||
|
|
||||||
// newDescriptorTable creates a descriptor table that uses the given underlying
|
|
||||||
// memory. The length of the memory slice must match the size needed for the
|
|
||||||
// descriptor table (see [descriptorTableSize]) for the given queue size.
|
|
||||||
//
|
|
||||||
// Before this descriptor table can be used, [initialize] must be called.
|
|
||||||
func newDescriptorTable(queueSize int, mem []byte, itemSize int) *DescriptorTable {
|
|
||||||
dtSize := descriptorTableSize(queueSize)
|
|
||||||
if len(mem) != dtSize {
|
|
||||||
panic(fmt.Sprintf("memory size (%v) does not match required size "+
|
|
||||||
"for descriptor table: %v", len(mem), dtSize))
|
|
||||||
}
|
|
||||||
|
|
||||||
return &DescriptorTable{
|
|
||||||
descriptors: unsafe.Slice((*Descriptor)(unsafe.Pointer(&mem[0])), queueSize),
|
|
||||||
// We have no free descriptors until they were initialized.
|
|
||||||
freeHeadIndex: noFreeHead,
|
|
||||||
freeNum: 0,
|
|
||||||
itemSize: itemSize, //todo configurable? needs to be page-aligned
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Address returns the pointer to the beginning of the descriptor table in
|
|
||||||
// memory. Do not modify the memory directly to not interfere with this
|
|
||||||
// implementation.
|
|
||||||
func (dt *DescriptorTable) Address() uintptr {
|
|
||||||
if dt.descriptors == nil {
|
|
||||||
panic("descriptor table is not initialized")
|
|
||||||
}
|
|
||||||
//should be same as dt.bufferBase
|
|
||||||
return uintptr(unsafe.Pointer(&dt.descriptors[0]))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dt *DescriptorTable) Size() uintptr {
|
|
||||||
if dt.descriptors == nil {
|
|
||||||
panic("descriptor table is not initialized")
|
|
||||||
}
|
|
||||||
return uintptr(dt.bufferSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BufferAddresses returns a map of pointer->size for all allocations used by the table
|
|
||||||
func (dt *DescriptorTable) BufferAddresses() map[uintptr]int {
|
|
||||||
if dt.descriptors == nil {
|
|
||||||
panic("descriptor table is not initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
return map[uintptr]int{dt.bufferBase: dt.bufferSize}
|
|
||||||
}
|
|
||||||
|
|
||||||
// initializeDescriptors allocates one contiguous memory region and assigns an
// itemSize-sized buffer from it to each descriptor in the table. While this
// preallocation may be a bit wasteful, it makes
|
|
||||||
// dealing with descriptors way easier. Without this preallocation, we would
|
|
||||||
// have to allocate and free memory on demand, increasing complexity.
|
|
||||||
//
|
|
||||||
// All descriptors will be marked as free and will form a free chain. The
|
|
||||||
// addresses of all descriptors will be populated while their length remains
|
|
||||||
// zero.
|
|
||||||
func (dt *DescriptorTable) initializeDescriptors() error {
|
|
||||||
numDescriptors := len(dt.descriptors)
|
|
||||||
|
|
||||||
// Allocate ONE large region for all buffers
|
|
||||||
totalSize := dt.itemSize * numDescriptors
|
|
||||||
basePtr, err := unix.MmapPtr(-1, 0, nil, uintptr(totalSize),
|
|
||||||
unix.PROT_READ|unix.PROT_WRITE,
|
|
||||||
unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("allocate buffer memory for descriptors: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store the base for cleanup later
|
|
||||||
dt.bufferBase = uintptr(basePtr)
|
|
||||||
dt.bufferSize = totalSize
|
|
||||||
|
|
||||||
for i := range dt.descriptors {
|
|
||||||
dt.descriptors[i] = Descriptor{
|
|
||||||
address: dt.bufferBase + uintptr(i*dt.itemSize),
|
|
||||||
length: 0,
|
|
||||||
// All descriptors should form a free chain that loops around.
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: uint16((i + 1) % len(dt.descriptors)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// All descriptors are free to use now.
|
|
||||||
dt.freeHeadIndex = 0
|
|
||||||
dt.freeNum = uint16(len(dt.descriptors))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// releaseBuffers releases all allocated buffers for this descriptor table.
|
|
||||||
// The implementation will try to release as many buffers as possible and
|
|
||||||
// collect potential errors before returning them.
|
|
||||||
// The descriptor table should no longer be used after calling this.
|
|
||||||
func (dt *DescriptorTable) releaseBuffers() error {
|
|
||||||
for i := range dt.descriptors {
|
|
||||||
descriptor := &dt.descriptors[i]
|
|
||||||
descriptor.address = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// As a safety measure, make sure no descriptors can be used anymore.
|
|
||||||
dt.freeHeadIndex = noFreeHead
|
|
||||||
dt.freeNum = 0
|
|
||||||
|
|
||||||
if dt.bufferBase != 0 {
|
|
||||||
// The pointer points to memory not managed by Go, so this conversion
|
|
||||||
// is safe. See https://github.com/golang/go/issues/58625
|
|
||||||
		// Unmap before clearing bufferBase so munmap receives the real address.
		//goland:noinspection GoVetUnsafePointer
		err := unix.MunmapPtr(unsafe.Pointer(dt.bufferBase), uintptr(dt.bufferSize))
		dt.bufferBase = 0
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("release buffer memory: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// createDescriptorChain creates a new descriptor chain within the descriptor
|
|
||||||
// table which contains a number of device-readable buffers (out buffers) and
|
|
||||||
// device-writable buffers (in buffers).
|
|
||||||
//
|
|
||||||
// All buffers in the outBuffers slice will be concatenated by chaining
|
|
||||||
// descriptors, one for each buffer in the slice. The size of the single buffers
|
|
||||||
// must not exceed the per-descriptor buffer size (itemSize) of this table.
|
|
||||||
// When numInBuffers is greater than zero, the given number of device-writable
|
|
||||||
// descriptors will be appended to the end of the chain, each referencing a
|
|
||||||
// whole itemSize-sized buffer.
|
|
||||||
//
|
|
||||||
// The index of the head of the new descriptor chain will be returned. Callers
|
|
||||||
// should make sure to free the descriptor chain using [freeDescriptorChain]
|
|
||||||
// after it was used by the device.
|
|
||||||
//
|
|
||||||
// When there are not enough free descriptors to hold the given number of
|
|
||||||
// buffers, an [ErrNotEnoughFreeDescriptors] will be returned. In this case, the
|
|
||||||
// caller should try again after some descriptor chains were used by the device
|
|
||||||
// and returned back into the free chain.
|
|
||||||
func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffers int) (uint16, error) {
|
|
||||||
// Calculate the number of descriptors needed to build the chain.
|
|
||||||
numDesc := uint16(len(outBuffers) + numInBuffers)
|
|
||||||
|
|
||||||
// Descriptor chains must always contain at least one descriptor.
|
|
||||||
if numDesc < 1 {
|
|
||||||
return 0, ErrDescriptorChainEmpty
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do we still have enough free descriptors?
|
|
||||||
if numDesc > dt.freeNum {
|
|
||||||
return 0, ErrNotEnoughFreeDescriptors
|
|
||||||
}
|
|
||||||
|
|
||||||
// Above validation ensured that there is at least one free descriptor, so
|
|
||||||
// the free descriptor chain head should be valid.
|
|
||||||
if dt.freeHeadIndex == noFreeHead {
|
|
||||||
panic("free descriptor chain head is unset but there should be free descriptors")
|
|
||||||
}
|
|
||||||
|
|
||||||
// To avoid having to iterate over the whole table to find the descriptor
|
|
||||||
// pointing to the head just to replace the free head, we instead always
|
|
||||||
// create descriptor chains from the descriptors coming after the head.
|
|
||||||
// This way we only have to touch the head as a last resort, when all other
|
|
||||||
// descriptors are already used.
|
|
||||||
head := dt.descriptors[dt.freeHeadIndex].next
|
|
||||||
next := head
|
|
||||||
tail := head
|
|
||||||
for i, buffer := range outBuffers {
|
|
||||||
desc := &dt.descriptors[next]
|
|
||||||
checkUnusedDescriptorLength(next, desc)
|
|
||||||
|
|
||||||
if len(buffer) > dt.itemSize {
|
|
||||||
// The caller should already prevent that from happening.
|
|
||||||
panic(fmt.Sprintf("out buffer %d has size %d which exceeds desc length %d", i, len(buffer), dt.itemSize))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy the buffer to the memory referenced by the descriptor.
|
|
||||||
// The descriptor address points to memory not managed by Go, so this
|
|
||||||
// conversion is safe. See https://github.com/golang/go/issues/58625
|
|
||||||
//goland:noinspection GoVetUnsafePointer
|
|
||||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), dt.itemSize), buffer)
|
|
||||||
desc.length = uint32(len(buffer))
|
|
||||||
|
|
||||||
// Clear the flags in case there were any others set.
|
|
||||||
desc.flags = descriptorFlagHasNext
|
|
||||||
|
|
||||||
tail = next
|
|
||||||
next = desc.next
|
|
||||||
}
|
|
||||||
for range numInBuffers {
|
|
||||||
desc := &dt.descriptors[next]
|
|
||||||
checkUnusedDescriptorLength(next, desc)
|
|
||||||
|
|
||||||
// Give the device the maximum available number of bytes to write into.
|
|
||||||
desc.length = uint32(dt.itemSize)
|
|
||||||
|
|
||||||
// Mark the descriptor as device-writable.
|
|
||||||
desc.flags = descriptorFlagHasNext | descriptorFlagWritable
|
|
||||||
|
|
||||||
tail = next
|
|
||||||
next = desc.next
|
|
||||||
}
|
|
||||||
|
|
||||||
// The last descriptor should end the chain.
|
|
||||||
tailDesc := &dt.descriptors[tail]
|
|
||||||
tailDesc.flags &= ^descriptorFlagHasNext
|
|
||||||
tailDesc.next = 0 // Not necessary to clear this, it's just for looks.
|
|
||||||
|
|
||||||
dt.freeNum -= numDesc
|
|
||||||
|
|
||||||
if dt.freeNum == 0 {
|
|
||||||
// The last descriptor in the chain should be the free chain head
|
|
||||||
// itself.
|
|
||||||
if tail != dt.freeHeadIndex {
|
|
||||||
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
|
|
||||||
}
|
|
||||||
|
|
||||||
// When this new chain takes up all remaining descriptors, we no longer
|
|
||||||
// have a free chain.
|
|
||||||
dt.freeHeadIndex = noFreeHead
|
|
||||||
} else {
|
|
||||||
// We took some descriptors out of the free chain, so make sure to close
|
|
||||||
// the circle again.
|
|
||||||
dt.descriptors[dt.freeHeadIndex].next = next
|
|
||||||
}
|
|
||||||
|
|
||||||
return head, nil
|
|
||||||
}
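// Sketch of the intended create/use/free cycle for a descriptor chain, using
// only functions defined in this file (error handling compressed; offering the
// chain to the device via the available/used rings is elided).
func descriptorChainCycleSketch(dt *DescriptorTable, request []byte) error {
	head, err := dt.createDescriptorChain([][]byte{request}, 1)
	if err != nil {
		// Typically ErrNotEnoughFreeDescriptors when the queue is full.
		return err
	}
	// ... offer head to the device and wait until it appears in the used ring ...
	return dt.freeDescriptorChain(head)
}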
|
|
||||||
|
|
||||||
func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, error) {
|
|
||||||
	// TODO: prefill the whole table instead of creating descriptors one at a time.
|
|
||||||
// Do we still have enough free descriptors?
|
|
||||||
|
|
||||||
	if dt.freeNum < 1 {
|
|
||||||
return 0, ErrNotEnoughFreeDescriptors
|
|
||||||
}
|
|
||||||
|
|
||||||
// Above validation ensured that there is at least one free descriptor, so
|
|
||||||
// the free descriptor chain head should be valid.
|
|
||||||
if dt.freeHeadIndex == noFreeHead {
|
|
||||||
panic("free descriptor chain head is unset but there should be free descriptors")
|
|
||||||
}
|
|
||||||
|
|
||||||
// To avoid having to iterate over the whole table to find the descriptor
|
|
||||||
// pointing to the head just to replace the free head, we instead always
|
|
||||||
// create descriptor chains from the descriptors coming after the head.
|
|
||||||
// This way we only have to touch the head as a last resort, when all other
|
|
||||||
// descriptors are already used.
|
|
||||||
head := dt.descriptors[dt.freeHeadIndex].next
|
|
||||||
desc := &dt.descriptors[head]
|
|
||||||
next := desc.next
|
|
||||||
|
|
||||||
checkUnusedDescriptorLength(head, desc)
|
|
||||||
|
|
||||||
// Give the device the maximum available number of bytes to write into.
|
|
||||||
desc.length = uint32(dt.itemSize)
|
|
||||||
desc.flags = 0 // descriptorFlagWritable
|
|
||||||
desc.next = 0 // Not necessary to clear this, it's just for looks.
|
|
||||||
|
|
||||||
dt.freeNum -= 1
|
|
||||||
|
|
||||||
if dt.freeNum == 0 {
|
|
||||||
// The last descriptor in the chain should be the free chain head
|
|
||||||
// itself.
|
|
||||||
if next != dt.freeHeadIndex {
|
|
||||||
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
|
|
||||||
}
|
|
||||||
|
|
||||||
// When this new chain takes up all remaining descriptors, we no longer
|
|
||||||
// have a free chain.
|
|
||||||
dt.freeHeadIndex = noFreeHead
|
|
||||||
} else {
|
|
||||||
// We took some descriptors out of the free chain, so make sure to close
|
|
||||||
// the circle again.
|
|
||||||
dt.descriptors[dt.freeHeadIndex].next = next
|
|
||||||
}
|
|
||||||
|
|
||||||
return head, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
|
|
||||||
// Do we still have enough free descriptors?
|
|
||||||
	if dt.freeNum < 1 {
|
|
||||||
return 0, ErrNotEnoughFreeDescriptors
|
|
||||||
}
|
|
||||||
|
|
||||||
// Above validation ensured that there is at least one free descriptor, so
|
|
||||||
// the free descriptor chain head should be valid.
|
|
||||||
if dt.freeHeadIndex == noFreeHead {
|
|
||||||
panic("free descriptor chain head is unset but there should be free descriptors")
|
|
||||||
}
|
|
||||||
|
|
||||||
// To avoid having to iterate over the whole table to find the descriptor
|
|
||||||
// pointing to the head just to replace the free head, we instead always
|
|
||||||
// create descriptor chains from the descriptors coming after the head.
|
|
||||||
// This way we only have to touch the head as a last resort, when all other
|
|
||||||
// descriptors are already used.
|
|
||||||
head := dt.descriptors[dt.freeHeadIndex].next
|
|
||||||
desc := &dt.descriptors[head]
|
|
||||||
next := desc.next
|
|
||||||
|
|
||||||
checkUnusedDescriptorLength(head, desc)
|
|
||||||
|
|
||||||
// Give the device the maximum available number of bytes to write into.
|
|
||||||
desc.length = uint32(dt.itemSize)
|
|
||||||
desc.flags = descriptorFlagWritable
|
|
||||||
desc.next = 0 // Not necessary to clear this, it's just for looks.
|
|
||||||
|
|
||||||
dt.freeNum -= 1
|
|
||||||
|
|
||||||
if dt.freeNum == 0 {
|
|
||||||
// The last descriptor in the chain should be the free chain head
|
|
||||||
// itself.
|
|
||||||
if next != dt.freeHeadIndex {
|
|
||||||
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
|
|
||||||
}
|
|
||||||
|
|
||||||
// When this new chain takes up all remaining descriptors, we no longer
|
|
||||||
// have a free chain.
|
|
||||||
dt.freeHeadIndex = noFreeHead
|
|
||||||
} else {
|
|
||||||
// We took some descriptors out of the free chain, so make sure to close
|
|
||||||
// the circle again.
|
|
||||||
dt.descriptors[dt.freeHeadIndex].next = next
|
|
||||||
}
|
|
||||||
|
|
||||||
return head, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Implement a zero-copy variant of createDescriptorChain?
|
|
||||||
|
|
||||||
// getDescriptorChain returns the device-readable buffers (out buffers) and
|
|
||||||
// device-writable buffers (in buffers) of the descriptor chain that starts with
|
|
||||||
// the given head index. The descriptor chain must have been created using
|
|
||||||
// [createDescriptorChain] and must not have been freed yet (meaning that the
|
|
||||||
// head index must not be contained in the free chain).
|
|
||||||
//
|
|
||||||
// Be careful to only access the returned buffer slices when the device has not
|
|
||||||
// yet or is no longer using them. They must not be accessed after
|
|
||||||
// [freeDescriptorChain] has been called.
|
|
||||||
func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
|
|
||||||
	if int(head) >= len(dt.descriptors) {
|
|
||||||
return nil, nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over the chain. The iteration is limited to the queue size to
|
|
||||||
// avoid ending up in an endless loop when things go very wrong.
|
|
||||||
next := head
|
|
||||||
for range len(dt.descriptors) {
|
|
||||||
if next == dt.freeHeadIndex {
|
|
||||||
return nil, nil, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
desc := &dt.descriptors[next]
|
|
||||||
|
|
||||||
// The descriptor address points to memory not managed by Go, so this
|
|
||||||
// conversion is safe. See https://github.com/golang/go/issues/58625
|
|
||||||
//goland:noinspection GoVetUnsafePointer
|
|
||||||
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
|
|
||||||
|
|
||||||
if desc.flags&descriptorFlagWritable == 0 {
|
|
||||||
outBuffers = append(outBuffers, bs)
|
|
||||||
} else {
|
|
||||||
inBuffers = append(inBuffers, bs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Is this the tail of the chain?
|
|
||||||
if desc.flags&descriptorFlagHasNext == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect loops.
|
|
||||||
if desc.next == head {
|
|
||||||
return nil, nil, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
next = desc.next
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
|
|
||||||
	if int(head) >= len(dt.descriptors) {
|
|
||||||
return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
	desc := &dt.descriptors[head] // TODO: validate that head is not part of the free chain before dereferencing it.
|
|
||||||
|
|
||||||
// The descriptor address points to memory not managed by Go, so this
|
|
||||||
// conversion is safe. See https://github.com/golang/go/issues/58625
|
|
||||||
//goland:noinspection GoVetUnsafePointer
|
|
||||||
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
|
|
||||||
return bs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
|
|
||||||
	if int(head) >= len(dt.descriptors) {
|
|
||||||
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over the chain. The iteration is limited to the queue size to
|
|
||||||
// avoid ending up in an endless loop when things go very wrong.
|
|
||||||
next := head
|
|
||||||
for range len(dt.descriptors) {
|
|
||||||
if next == dt.freeHeadIndex {
|
|
||||||
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
desc := &dt.descriptors[next]
|
|
||||||
|
|
||||||
// The descriptor address points to memory not managed by Go, so this
|
|
||||||
// conversion is safe. See https://github.com/golang/go/issues/58625
|
|
||||||
//goland:noinspection GoVetUnsafePointer
|
|
||||||
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
|
|
||||||
|
|
||||||
if desc.flags&descriptorFlagWritable == 0 {
|
|
||||||
return fmt.Errorf("there should not be an outbuffer in %d", head)
|
|
||||||
} else {
|
|
||||||
*inBuffers = append(*inBuffers, bs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Is this the tail of the chain?
|
|
||||||
if desc.flags&descriptorFlagHasNext == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect loops.
|
|
||||||
if desc.next == head {
|
|
||||||
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
next = desc.next
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
|
|
||||||
	if int(head) >= len(dt.descriptors) {
|
|
||||||
return 0, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over the chain. The iteration is limited to the queue size to
|
|
||||||
// avoid ending up in an endless loop when things go very wrong.
|
|
||||||
|
|
||||||
length := 0
|
|
||||||
//find length
|
|
||||||
next := head
|
|
||||||
for range len(dt.descriptors) {
|
|
||||||
if next == dt.freeHeadIndex {
|
|
||||||
return 0, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
desc := &dt.descriptors[next]
|
|
||||||
|
|
||||||
if desc.flags&descriptorFlagWritable == 0 {
|
|
||||||
return 0, fmt.Errorf("receive queue contains device-readable buffer")
|
|
||||||
}
|
|
||||||
length += int(desc.length)
|
|
||||||
|
|
||||||
// Is this the tail of the chain?
|
|
||||||
if desc.flags&descriptorFlagHasNext == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect loops.
|
|
||||||
if desc.next == head {
|
|
||||||
return 0, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
next = desc.next
|
|
||||||
}
|
|
||||||
	if maxLen > 0 {
		length = min(maxLen, length)
	}
|
|
||||||
	// Bound the output slice to the exact number of bytes that will be copied.
|
|
||||||
out = out[:length]
|
|
||||||
|
|
||||||
	// Now do the copying, walking the chain again from its head.
	next = head
	copied := 0
|
|
||||||
for range len(dt.descriptors) {
|
|
||||||
desc := &dt.descriptors[next]
|
|
||||||
|
|
||||||
// The descriptor address points to memory not managed by Go, so this
|
|
||||||
// conversion is safe. See https://github.com/golang/go/issues/58625
|
|
||||||
//goland:noinspection GoVetUnsafePointer
|
|
||||||
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), min(uint32(length-copied), desc.length))
|
|
||||||
copied += copy(out[copied:], bs)
|
|
||||||
|
|
||||||
// Is this the tail of the chain?
|
|
||||||
if desc.flags&descriptorFlagHasNext == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// we did this already, no need to detect loops.
|
|
||||||
next = desc.next
|
|
||||||
}
|
|
||||||
if copied != length {
|
|
||||||
panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", length, copied))
|
|
||||||
}
|
|
||||||
|
|
||||||
return length, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// freeDescriptorChain can be used to free a descriptor chain when it is no
|
|
||||||
// longer in use. The descriptor chain that starts with the given index will be
|
|
||||||
// put back into the free chain, so the descriptors can be used for later calls
|
|
||||||
// of [createDescriptorChain].
|
|
||||||
// The descriptor chain must have been created using [createDescriptorChain] and
|
|
||||||
// must not have been freed yet (meaning that the head index must not be
|
|
||||||
// contained in the free chain).
|
|
||||||
func (dt *DescriptorTable) freeDescriptorChain(head uint16) error {
|
|
||||||
	if int(head) >= len(dt.descriptors) {
|
|
||||||
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over the chain. The iteration is limited to the queue size to
|
|
||||||
// avoid ending up in an endless loop when things go very wrong.
|
|
||||||
next := head
|
|
||||||
var tailDesc *Descriptor
|
|
||||||
var chainLen uint16
|
|
||||||
for range len(dt.descriptors) {
|
|
||||||
if next == dt.freeHeadIndex {
|
|
||||||
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
desc := &dt.descriptors[next]
|
|
||||||
chainLen++
|
|
||||||
|
|
||||||
// Set the length of all unused descriptors back to zero.
|
|
||||||
desc.length = 0
|
|
||||||
|
|
||||||
// Unset all flags except the next flag.
|
|
||||||
desc.flags &= descriptorFlagHasNext
|
|
||||||
|
|
||||||
// Is this the tail of the chain?
|
|
||||||
if desc.flags&descriptorFlagHasNext == 0 {
|
|
||||||
tailDesc = desc
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect loops.
|
|
||||||
if desc.next == head {
|
|
||||||
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
next = desc.next
|
|
||||||
}
|
|
||||||
if tailDesc == nil {
|
|
||||||
// A descriptor chain longer than the queue size but without loops
|
|
||||||
// should be impossible.
|
|
||||||
panic(fmt.Sprintf("could not find a tail for descriptor chain starting at %d", head))
|
|
||||||
}
|
|
||||||
|
|
||||||
// The tail descriptor does not have the next flag set, but when it comes
|
|
||||||
// back into the free chain, it should have.
|
|
||||||
tailDesc.flags = descriptorFlagHasNext
|
|
||||||
|
|
||||||
if dt.freeHeadIndex == noFreeHead {
|
|
||||||
// The whole free chain was used up, so we turn this returned descriptor
|
|
||||||
// chain into the new free chain by completing the circle and using its
|
|
||||||
// head.
|
|
||||||
tailDesc.next = head
|
|
||||||
dt.freeHeadIndex = head
|
|
||||||
} else {
|
|
||||||
// Attach the returned chain at the beginning of the free chain but
|
|
||||||
// right after the free chain head.
|
|
||||||
freeHeadDesc := &dt.descriptors[dt.freeHeadIndex]
|
|
||||||
tailDesc.next = freeHeadDesc.next
|
|
||||||
freeHeadDesc.next = head
|
|
||||||
}
|
|
||||||
|
|
||||||
dt.freeNum += chainLen
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkUnusedDescriptorLength asserts that the length of an unused descriptor
|
|
||||||
// is zero, as it should be.
|
|
||||||
// This is not a requirement by the virtio spec but rather a thing we do to
|
|
||||||
// notice when our algorithm goes sideways.
|
|
||||||
func checkUnusedDescriptorLength(index uint16, desc *Descriptor) {
|
|
||||||
if desc.length != 0 {
|
|
||||||
panic(fmt.Sprintf("descriptor %d should be unused but has a non-zero length", index))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,407 +0,0 @@
package virtqueue
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDescriptorTable_InitializeDescriptors(t *testing.T) {
|
|
||||||
const queueSize = 32
|
|
||||||
|
|
||||||
dt := DescriptorTable{
|
|
||||||
descriptors: make([]Descriptor, queueSize),
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.NoError(t, dt.initializeDescriptors())
|
|
||||||
t.Cleanup(func() {
|
|
||||||
assert.NoError(t, dt.releaseBuffers())
|
|
||||||
})
|
|
||||||
|
|
||||||
for i, descriptor := range dt.descriptors {
|
|
||||||
		assert.NotZero(t, descriptor.address)
		assert.Zero(t, descriptor.length)
		assert.EqualValues(t, descriptorFlagHasNext, descriptor.flags)
		assert.EqualValues(t, (i+1)%queueSize, descriptor.next)
	}
}

func TestDescriptorTable_DescriptorChains(t *testing.T) {
	// Use a very short queue size to not make this test overly verbose.
	const queueSize = 8

	pageSize := os.Getpagesize() * 2

	// Initialize descriptor table.
	dt := DescriptorTable{
		descriptors: make([]Descriptor, queueSize),
	}
	assert.NoError(t, dt.initializeDescriptors())
	t.Cleanup(func() {
		assert.NoError(t, dt.releaseBuffers())
	})

	// Some utilities for easier checking if the descriptor table looks as
	// expected.
	type desc struct {
		buffer []byte
		flags  descriptorFlag
		next   uint16
	}
	assertDescriptorTable := func(expected [queueSize]desc) {
		for i := 0; i < queueSize; i++ {
			actualDesc := &dt.descriptors[i]
			expectedDesc := &expected[i]
			assert.Equal(t, uint32(len(expectedDesc.buffer)), actualDesc.length)
			if len(expectedDesc.buffer) > 0 {
				//goland:noinspection GoVetUnsafePointer
				assert.EqualValues(t,
					unsafe.Slice((*byte)(unsafe.Pointer(actualDesc.address)), actualDesc.length),
					expectedDesc.buffer)
			}
			assert.Equal(t, expectedDesc.flags, actualDesc.flags)
			if expectedDesc.flags&descriptorFlagHasNext != 0 {
				assert.Equal(t, expectedDesc.next, actualDesc.next)
			}
		}
	}

	// Initial state: All descriptors are in the free chain.
	assert.Equal(t, uint16(0), dt.freeHeadIndex)
	assert.Equal(t, uint16(8), dt.freeNum)
	assertDescriptorTable([queueSize]desc{
		{
			// Free head.
			flags: descriptorFlagHasNext,
			next:  1,
		},
		{
			flags: descriptorFlagHasNext,
			next:  2,
		},
		{
			flags: descriptorFlagHasNext,
			next:  3,
		},
		{
			flags: descriptorFlagHasNext,
			next:  4,
		},
		{
			flags: descriptorFlagHasNext,
			next:  5,
		},
		{
			flags: descriptorFlagHasNext,
			next:  6,
		},
		{
			flags: descriptorFlagHasNext,
			next:  7,
		},
		{
			flags: descriptorFlagHasNext,
			next:  0,
		},
	})

	// Create the first chain.
	firstChain, err := dt.createDescriptorChain([][]byte{
		makeTestBuffer(t, 26),
		makeTestBuffer(t, 256),
	}, 1)
	assert.NoError(t, err)
	assert.Equal(t, uint16(1), firstChain)

	// Now there should be a new chain next to the free chain.
	assert.Equal(t, uint16(0), dt.freeHeadIndex)
	assert.Equal(t, uint16(5), dt.freeNum)
	assertDescriptorTable([queueSize]desc{
		{
			// Free head.
			flags: descriptorFlagHasNext,
			next:  4,
		},
		{
			// Head of first chain.
			buffer: makeTestBuffer(t, 26),
			flags:  descriptorFlagHasNext,
			next:   2,
		},
		{
			buffer: makeTestBuffer(t, 256),
			flags:  descriptorFlagHasNext,
			next:   3,
		},
		{
			// Tail of first chain.
			buffer: make([]byte, pageSize),
			flags:  descriptorFlagWritable,
		},
		{
			flags: descriptorFlagHasNext,
			next:  5,
		},
		{
			flags: descriptorFlagHasNext,
			next:  6,
		},
		{
			flags: descriptorFlagHasNext,
			next:  7,
		},
		{
			flags: descriptorFlagHasNext,
			next:  0,
		},
	})

	// Create a second chain with only a single in buffer.
	secondChain, err := dt.createDescriptorChain(nil, 1)
	assert.NoError(t, err)
	assert.Equal(t, uint16(4), secondChain)

	// Now there should be two chains next to the free chain.
	assert.Equal(t, uint16(0), dt.freeHeadIndex)
	assert.Equal(t, uint16(4), dt.freeNum)
	assertDescriptorTable([queueSize]desc{
		{
			// Free head.
			flags: descriptorFlagHasNext,
			next:  5,
		},
		{
			// Head of the first chain.
			buffer: makeTestBuffer(t, 26),
			flags:  descriptorFlagHasNext,
			next:   2,
		},
		{
			buffer: makeTestBuffer(t, 256),
			flags:  descriptorFlagHasNext,
			next:   3,
		},
		{
			// Tail of the first chain.
			buffer: make([]byte, pageSize),
			flags:  descriptorFlagWritable,
		},
		{
			// Head and tail of the second chain.
			buffer: make([]byte, pageSize),
			flags:  descriptorFlagWritable,
		},
		{
			flags: descriptorFlagHasNext,
			next:  6,
		},
		{
			flags: descriptorFlagHasNext,
			next:  7,
		},
		{
			flags: descriptorFlagHasNext,
			next:  0,
		},
	})

	// Create a third chain taking up all remaining descriptors.
	thirdChain, err := dt.createDescriptorChain([][]byte{
		makeTestBuffer(t, 42),
		makeTestBuffer(t, 96),
		makeTestBuffer(t, 33),
		makeTestBuffer(t, 222),
	}, 0)
	assert.NoError(t, err)
	assert.Equal(t, uint16(5), thirdChain)

	// Now there should be three chains and no free chain.
	assert.Equal(t, noFreeHead, dt.freeHeadIndex)
	assert.Equal(t, uint16(0), dt.freeNum)
	assertDescriptorTable([queueSize]desc{
		{
			// Tail of the third chain.
			buffer: makeTestBuffer(t, 222),
		},
		{
			// Head of the first chain.
			buffer: makeTestBuffer(t, 26),
			flags:  descriptorFlagHasNext,
			next:   2,
		},
		{
			buffer: makeTestBuffer(t, 256),
			flags:  descriptorFlagHasNext,
			next:   3,
		},
		{
			// Tail of the first chain.
			buffer: make([]byte, pageSize),
			flags:  descriptorFlagWritable,
		},
		{
			// Head and tail of the second chain.
			buffer: make([]byte, pageSize),
			flags:  descriptorFlagWritable,
		},
		{
			// Head of the third chain.
			buffer: makeTestBuffer(t, 42),
			flags:  descriptorFlagHasNext,
			next:   6,
		},
		{
			buffer: makeTestBuffer(t, 96),
			flags:  descriptorFlagHasNext,
			next:   7,
		},
		{
			buffer: makeTestBuffer(t, 33),
			flags:  descriptorFlagHasNext,
			next:   0,
		},
	})

	// Free the third chain.
	assert.NoError(t, dt.freeDescriptorChain(thirdChain))

	// Now there should be two chains and a free chain again.
	assert.Equal(t, uint16(5), dt.freeHeadIndex)
	assert.Equal(t, uint16(4), dt.freeNum)
	assertDescriptorTable([queueSize]desc{
		{
			flags: descriptorFlagHasNext,
			next:  5,
		},
		{
			// Head of the first chain.
			buffer: makeTestBuffer(t, 26),
			flags:  descriptorFlagHasNext,
			next:   2,
		},
		{
			buffer: makeTestBuffer(t, 256),
			flags:  descriptorFlagHasNext,
			next:   3,
		},
		{
			// Tail of the first chain.
			buffer: make([]byte, pageSize),
			flags:  descriptorFlagWritable,
		},
		{
			// Head and tail of the second chain.
			buffer: make([]byte, pageSize),
			flags:  descriptorFlagWritable,
		},
		{
			// Free head.
			flags: descriptorFlagHasNext,
			next:  6,
		},
		{
			flags: descriptorFlagHasNext,
			next:  7,
		},
		{
			flags: descriptorFlagHasNext,
			next:  0,
		},
	})

	// Free the first chain.
	assert.NoError(t, dt.freeDescriptorChain(firstChain))

	// Now there should be only a single chain next to the free chain.
	assert.Equal(t, uint16(5), dt.freeHeadIndex)
	assert.Equal(t, uint16(7), dt.freeNum)
	assertDescriptorTable([queueSize]desc{
		{
			flags: descriptorFlagHasNext,
			next:  5,
		},
		{
			flags: descriptorFlagHasNext,
			next:  2,
		},
		{
			flags: descriptorFlagHasNext,
			next:  3,
		},
		{
			flags: descriptorFlagHasNext,
			next:  6,
		},
		{
			// Head and tail of the second chain.
			buffer: make([]byte, pageSize),
			flags:  descriptorFlagWritable,
		},
		{
			// Free head.
			flags: descriptorFlagHasNext,
			next:  1,
		},
		{
			flags: descriptorFlagHasNext,
			next:  7,
		},
		{
			flags: descriptorFlagHasNext,
			next:  0,
		},
	})

	// Free the second chain.
	assert.NoError(t, dt.freeDescriptorChain(secondChain))

	// Now all descriptors should be in the free chain again.
	assert.Equal(t, uint16(5), dt.freeHeadIndex)
	assert.Equal(t, uint16(8), dt.freeNum)
	assertDescriptorTable([queueSize]desc{
		{
			flags: descriptorFlagHasNext,
			next:  5,
		},
		{
			flags: descriptorFlagHasNext,
			next:  2,
		},
		{
			flags: descriptorFlagHasNext,
			next:  3,
		},
		{
			flags: descriptorFlagHasNext,
			next:  6,
		},
		{
			flags: descriptorFlagHasNext,
			next:  1,
		},
		{
			// Free head.
			flags: descriptorFlagHasNext,
			next:  4,
		},
		{
			flags: descriptorFlagHasNext,
			next:  7,
		},
		{
			flags: descriptorFlagHasNext,
			next:  0,
		},
	})
}

func makeTestBuffer(t *testing.T, length int) []byte {
	t.Helper()
	buf := make([]byte, length)
	for i := 0; i < length; i++ {
		buf[i] = byte(length - i)
	}
	return buf
}
@@ -1,7 +0,0 @@
// Package virtqueue implements the driver-side for a virtio queue as described
// in the specification:
// https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-270006
// This package does not make assumptions about the device that consumes the
// queue. It rather just allocates the queue structures in memory and provides
// methods to interact with it.
package virtqueue
@@ -1,45 +0,0 @@
package virtqueue

import (
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gvisor.dev/gvisor/pkg/eventfd"
)

// Tests how an eventfd and a waiting goroutine can be gracefully closed.
// Extends the eventfd test suite:
// https://github.com/google/gvisor/blob/0799336d64be65eb97d330606c30162dc3440cab/pkg/eventfd/eventfd_test.go
func TestEventFD_CancelWait(t *testing.T) {
	efd, err := eventfd.Create()
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, efd.Close())
	})

	var stop bool

	done := make(chan struct{})
	go func() {
		for !stop {
			_ = efd.Wait()
		}
		close(done)
	}()
	select {
	case <-done:
		t.Fatalf("goroutine ended early")
	case <-time.After(500 * time.Millisecond):
	}

	stop = true
	assert.NoError(t, efd.Notify())
	select {
	case <-done:
		break
	case <-time.After(5 * time.Second):
		t.Error("goroutine did not end")
	}
}
@@ -1,33 +0,0 @@
package virtqueue

import (
	"errors"
	"fmt"
)

// ErrQueueSizeInvalid is returned when a queue size is invalid.
var ErrQueueSizeInvalid = errors.New("queue size is invalid")

// CheckQueueSize checks if the given value would be a valid size for a
// virtqueue and returns an [ErrQueueSizeInvalid], if not.
func CheckQueueSize(queueSize int) error {
	if queueSize <= 0 {
		return fmt.Errorf("%w: %d is too small", ErrQueueSizeInvalid, queueSize)
	}

	// The queue size must always be a power of 2.
	// This ensures that ring indexes wrap correctly when the 16-bit integers
	// overflow.
	if queueSize&(queueSize-1) != 0 {
		return fmt.Errorf("%w: %d is not a power of 2", ErrQueueSizeInvalid, queueSize)
	}

	// The largest power of 2 that fits into a 16-bit integer is 32768.
	// 2 * 32768 would be 65536 which no longer fits.
	if queueSize > 32768 {
		return fmt.Errorf("%w: %d is larger than the maximum possible queue size 32768",
			ErrQueueSizeInvalid, queueSize)
	}

	return nil
}
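
// The power-of-2 requirement above is what lets the 16-bit ring indexes wrap
// cleanly. A minimal standalone illustration of the same three checks
// (the helper name below is ours, not from the diff):

package main

import "fmt"

// isValidQueueSize mirrors CheckQueueSize: positive, a power of two,
// and no larger than 32768.
func isValidQueueSize(queueSize int) bool {
	return queueSize > 0 && queueSize <= 32768 && queueSize&(queueSize-1) == 0
}

func main() {
	for _, n := range []int{0, 24, 256, 32768, 65536} {
		fmt.Println(n, isValidQueueSize(n)) // false, false, true, true, false
	}
}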
@@ -1,59 +0,0 @@
package virtqueue

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestCheckQueueSize(t *testing.T) {
	tests := []struct {
		name        string
		queueSize   int
		containsErr string
	}{
		{
			name:        "negative",
			queueSize:   -1,
			containsErr: "too small",
		},
		{
			name:        "zero",
			queueSize:   0,
			containsErr: "too small",
		},
		{
			name:        "not a power of 2",
			queueSize:   24,
			containsErr: "not a power of 2",
		},
		{
			name:        "too large",
			queueSize:   65536,
			containsErr: "larger than the maximum",
		},
		{
			name:      "valid 1",
			queueSize: 1,
		},
		{
			name:      "valid 256",
			queueSize: 256,
		},

		{
			name:      "valid 32768",
			queueSize: 32768,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := CheckQueueSize(tt.queueSize)
			if tt.containsErr != "" {
				assert.ErrorContains(t, err, tt.containsErr)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
@@ -1,530 +0,0 @@
package virtqueue

import (
	"context"
	"errors"
	"fmt"
	"os"
	"syscall"

	"github.com/slackhq/nebula/overlay/eventfd"
	"golang.org/x/sys/unix"
)

// SplitQueue is a virtqueue that consists of several parts, where each part is
// writeable by either the driver or the device, but not both.
type SplitQueue struct {
	// size is the size of the queue.
	size int
	// buf is the underlying memory used for the queue.
	buf []byte

	descriptorTable *DescriptorTable
	availableRing   *AvailableRing
	usedRing        *UsedRing

	// kickEventFD is used to signal the device when descriptor chains were
	// added to the available ring.
	kickEventFD eventfd.EventFD
	// callEventFD is used by the device to signal when it has used descriptor
	// chains and put them in the used ring.
	callEventFD eventfd.EventFD

	// stop is used by [SplitQueue.Close] to cancel the goroutine that handles
	// used buffer notifications. It blocks until the goroutine ended.
	stop func() error

	itemSize int

	epoll eventfd.Epoll
	more  int
}

// NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size
// specifies the number of entries/buffers the queue can hold. This also affects
// the memory consumption.
func NewSplitQueue(queueSize int, itemSize int) (_ *SplitQueue, err error) {
	if err = CheckQueueSize(queueSize); err != nil {
		return nil, err
	}

	if itemSize%os.Getpagesize() != 0 {
		return nil, errors.New("split queue size must be multiple of os.Getpagesize()")
	}

	sq := SplitQueue{
		size:     queueSize,
		itemSize: itemSize,
	}

	// Clean up a partially initialized queue when something fails.
	defer func() {
		if err != nil {
			_ = sq.Close()
		}
	}()

	// There are multiple ways for how the memory for the virtqueue could be
	// allocated. We could use Go native structs with arrays inside them, but
	// this wouldn't allow us to make the queue size configurable. And including
	// a slice in the Go structs wouldn't work, because this would just put the
	// Go slice descriptor into the memory region which the virtio device will
	// not understand.
	// Additionally, Go does not allow us to ensure a correct alignment of the
	// parts of the virtqueue, as it is required by the virtio specification.
	//
	// To resolve this, let's just allocate the memory manually by allocating
	// one or more memory pages, depending on the queue size. Making the
	// virtqueue start at the beginning of a page is not strictly necessary, as
	// the virtio specification does not require it to be continuous in the
	// physical memory of the host (e.g. the vhost implementation in the kernel
	// always uses copy_from_user to access it), but this makes it very easy to
	// guarantee the alignment. Also, it is not required for the virtqueue parts
	// to be in the same memory region, as we pass separate pointers to them to
	// the device, but this design just makes things easier to implement.
	//
	// One added benefit of allocating the memory manually is, that we have full
	// control over its lifetime and don't risk the garbage collector to collect
	// our valuable structures while the device still works with them.

	// The descriptor table is at the start of the page, so alignment is not an
	// issue here.
	descriptorTableStart := 0
	descriptorTableEnd := descriptorTableStart + descriptorTableSize(queueSize)
	availableRingStart := align(descriptorTableEnd, availableRingAlignment)
	availableRingEnd := availableRingStart + availableRingSize(queueSize)
	usedRingStart := align(availableRingEnd, usedRingAlignment)
	usedRingEnd := usedRingStart + usedRingSize(queueSize)

	sq.buf, err = unix.Mmap(-1, 0, usedRingEnd,
		unix.PROT_READ|unix.PROT_WRITE,
		unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
	if err != nil {
		return nil, fmt.Errorf("allocate virtqueue buffer: %w", err)
	}

	sq.descriptorTable = newDescriptorTable(queueSize, sq.buf[descriptorTableStart:descriptorTableEnd], sq.itemSize)
	sq.availableRing = newAvailableRing(queueSize, sq.buf[availableRingStart:availableRingEnd])
	sq.usedRing = newUsedRing(queueSize, sq.buf[usedRingStart:usedRingEnd])

	sq.kickEventFD, err = eventfd.New()
	if err != nil {
		return nil, fmt.Errorf("create kick event file descriptor: %w", err)
	}
	sq.callEventFD, err = eventfd.New()
	if err != nil {
		return nil, fmt.Errorf("create call event file descriptor: %w", err)
	}

	if err = sq.descriptorTable.initializeDescriptors(); err != nil {
		return nil, fmt.Errorf("initialize descriptors: %w", err)
	}

	sq.epoll, err = eventfd.NewEpoll()
	if err != nil {
		return nil, err
	}
	err = sq.epoll.AddEvent(sq.callEventFD.FD())
	if err != nil {
		return nil, err
	}

	// Consume used buffer notifications in the background.
	sq.stop = sq.startConsumeUsedRing()

	return &sq, nil
}
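
// For context, the offsets computed in NewSplitQueue above depend on per-part
// sizes that come from the virtio 1.2 split-queue layout (16 bytes per
// descriptor, 6 + 2*n bytes for the available ring, 6 + 8*n bytes for the used
// ring, with ring alignments of 16/2/4). Those constants live in files not
// shown in this diff, so the sketch below simply assumes them and recomputes
// the layout for a queue size of 256; it is an illustration, not code from
// the repository.

package main

import "fmt"

// align rounds index up to the next multiple of alignment, matching the
// align helper later in this file.
func align(index, alignment int) int {
	if r := index % alignment; r != 0 {
		return index + alignment - r
	}
	return index
}

func main() {
	const queueSize = 256
	descEnd := 16 * queueSize                // descriptor table: 16 bytes per descriptor (assumed)
	availStart := align(descEnd, 2)          // available ring alignment per spec (assumed)
	availEnd := availStart + 6 + 2*queueSize // flags + idx + used_event + 2-byte ring entries
	usedStart := align(availEnd, 4)          // usedRingAlignment = 4, as in this package
	usedEnd := usedStart + 6 + 8*queueSize   // matches usedRingSize: 6 + 8*queueSize
	fmt.Println(descEnd, availStart, availEnd, usedStart, usedEnd)
	// 4096 4096 4614 4616 6670 — the whole queue fits into two 4 KiB pages here.
}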

// Size returns the size of this queue, which is the number of entries/buffers
// this queue can hold.
func (sq *SplitQueue) Size() int {
	return sq.size
}

// DescriptorTable returns the [DescriptorTable] behind this queue.
func (sq *SplitQueue) DescriptorTable() *DescriptorTable {
	return sq.descriptorTable
}

// AvailableRing returns the [AvailableRing] behind this queue.
func (sq *SplitQueue) AvailableRing() *AvailableRing {
	return sq.availableRing
}

// UsedRing returns the [UsedRing] behind this queue.
func (sq *SplitQueue) UsedRing() *UsedRing {
	return sq.usedRing
}

// KickEventFD returns the kick event file descriptor behind this queue.
// The returned file descriptor should be used with great care to not interfere
// with this implementation.
func (sq *SplitQueue) KickEventFD() int {
	return sq.kickEventFD.FD()
}

// CallEventFD returns the call event file descriptor behind this queue.
// The returned file descriptor should be used with great care to not interfere
// with this implementation.
func (sq *SplitQueue) CallEventFD() int {
	return sq.callEventFD.FD()
}

// startConsumeUsedRing starts a goroutine that runs [consumeUsedRing].
// A function is returned that can be used to gracefully cancel it. todo rename
func (sq *SplitQueue) startConsumeUsedRing() func() error {
	return func() error {

		// The goroutine blocks until it receives a signal on the event file
		// descriptor, so it will never notice the context being canceled.
		// To resolve this, we can just produce a fake-signal ourselves to wake
		// it up.
		if err := sq.callEventFD.Kick(); err != nil {
			return fmt.Errorf("wake up goroutine: %w", err)
		}
		return nil
	}
}

// BlockAndGetHeads waits for the device to signal that it has used descriptor chains and returns all [UsedElement]s
func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, error) {
	var n int
	var err error
	for ctx.Err() == nil {

		// Wait for a signal from the device.
		if n, err = sq.epoll.Block(); err != nil {
			return nil, fmt.Errorf("wait: %w", err)
		}
		if n > 0 {
			stillNeedToTake, out := sq.usedRing.take(-1)
			sq.more = stillNeedToTake
			if stillNeedToTake == 0 {
				_ = sq.epoll.Clear() //???
			}
			return out, nil
		}
	}
	return nil, ctx.Err()
}

func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
	var n int
	var err error
	for ctx.Err() == nil {
		out, ok := sq.usedRing.takeOne()
		if ok {
			return out, nil
		}
		// Wait for a signal from the device.
		if n, err = sq.epoll.Block(); err != nil {
			return 0, fmt.Errorf("wait: %w", err)
		}

		if n > 0 {
			out, ok = sq.usedRing.takeOne()
			if ok {
				_ = sq.epoll.Clear() //???
				return out, nil
			} else {
				continue //???
			}
		}
	}
	return 0, ctx.Err()
}

func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
	var n int
	var err error
	for ctx.Err() == nil {

		//we have leftovers in the fridge
		if sq.more > 0 {
			stillNeedToTake, out := sq.usedRing.take(maxToTake)
			sq.more = stillNeedToTake
			return out, nil
		}
		//look inside the fridge
		stillNeedToTake, out := sq.usedRing.take(maxToTake)
		if len(out) > 0 {
			sq.more = stillNeedToTake
			return out, nil
		}
		//fridge is empty I guess

		// Wait for a signal from the device.
		if n, err = sq.epoll.Block(); err != nil {
			return nil, fmt.Errorf("wait: %w", err)
		}
		if n > 0 {
			_ = sq.epoll.Clear() //???
			stillNeedToTake, out = sq.usedRing.take(maxToTake)
			sq.more = stillNeedToTake
			return out, nil
		}
	}

	return nil, ctx.Err()
}

// OfferDescriptorChain offers a descriptor chain to the device which contains a
// number of device-readable buffers (out buffers) and device-writable buffers
// (in buffers).
//
// All buffers in the outBuffers slice will be concatenated by chaining
// descriptors, one for each buffer in the slice. When a buffer is too large to
// fit into a single descriptor (limited by the system's page size), it will be
// split up into multiple descriptors within the chain.
// When numInBuffers is greater than zero, the given number of device-writable
// descriptors will be appended to the end of the chain, each referencing a
// whole memory page (see [os.Getpagesize]).
//
// When the queue is full and no more descriptor chains can be added, a wrapped
// [ErrNotEnoughFreeDescriptors] will be returned. If you set waitFree to true,
// this method will handle this error and will block instead until there are
// enough free descriptors again.
//
// After defining the descriptor chain in the [DescriptorTable], the index of
// the head of the chain will be made available to the device using the
// [AvailableRing] and will be returned by this method.
// Callers should read from the [SplitQueue.UsedDescriptorChains] channel to be
// notified when the descriptor chain was used by the device and should free the
// used descriptor chains again using [SplitQueue.FreeDescriptorChain] when
// they're done with them. When this does not happen, the queue will run full
// and any further calls to [SplitQueue.OfferDescriptorChain] will stall.

func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
	// Create a descriptor chain for the given buffers.
	var (
		head uint16
		err  error
	)
	for {
		head, err = sq.descriptorTable.createDescriptorForInputs()
		if err == nil {
			break
		}

		// I don't wanna use errors.Is, it's slow
		//goland:noinspection GoDirectComparisonOfErrors
		if err == ErrNotEnoughFreeDescriptors {
			return 0, err
		} else {
			return 0, fmt.Errorf("create descriptor chain: %w", err)
		}
	}

	// Make the descriptor chain available to the device.
	sq.availableRing.offerSingle(head)

	// Notify the device to make it process the updated available ring.
	if err := sq.kickEventFD.Kick(); err != nil {
		return head, fmt.Errorf("notify device: %w", err)
	}

	return head, nil
}

func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte) ([]uint16, error) {
	// TODO change this
	// Each descriptor can only hold a whole memory page, so split large out
	// buffers into multiple smaller ones.
	outBuffers = splitBuffers(outBuffers, sq.itemSize)

	chains := make([]uint16, len(outBuffers))

	// Create a descriptor chain for the given buffers.
	var (
		head uint16
		err  error
	)
	for i := range outBuffers {
		for {
			bufs := [][]byte{prepend, outBuffers[i]}
			head, err = sq.descriptorTable.createDescriptorChain(bufs, 0)
			if err == nil {
				break
			}

			// I don't wanna use errors.Is, it's slow
			//goland:noinspection GoDirectComparisonOfErrors
			if err == ErrNotEnoughFreeDescriptors {
				// Wait for more free descriptors to be put back into the queue.
				// If the number of free descriptors is still not sufficient, we'll
				// land here again.
				//todo should never happen
				syscall.Syscall(syscall.SYS_SCHED_YIELD, 0, 0, 0) // Cheap barrier
				continue
			}
			return nil, fmt.Errorf("create descriptor chain: %w", err)
		}
		chains[i] = head
	}

	// Make the descriptor chain available to the device.
	sq.availableRing.offer(chains)

	// Notify the device to make it process the updated available ring.
	if err := sq.kickEventFD.Kick(); err != nil {
		return chains, fmt.Errorf("notify device: %w", err)
	}

	return chains, nil
}

// GetDescriptorChain returns the device-readable buffers (out buffers) and
// device-writable buffers (in buffers) of the descriptor chain with the given
// head index.
// The head index must be one that was returned by a previous call to
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
// freed yet.
//
// Be careful to only access the returned buffer slices when the device is no
// longer using them. They must not be accessed after
// [SplitQueue.FreeDescriptorChain] has been called.
func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
	return sq.descriptorTable.getDescriptorChain(head)
}

func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
	sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
	return sq.descriptorTable.getDescriptorItem(head)
}

func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
	return sq.descriptorTable.getDescriptorChainContents(head, out, maxLen)
}

func (sq *SplitQueue) GetDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
	return sq.descriptorTable.getDescriptorInbuffers(head, inBuffers)
}

// FreeDescriptorChain frees the descriptor chain with the given head index.
// The head index must be one that was returned by a previous call to
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
// freed yet.
//
// This creates new room in the queue which can be used by following
// [SplitQueue.OfferDescriptorChain] calls.
// When there are outstanding calls for [SplitQueue.OfferDescriptorChain] that
// are waiting for free room in the queue, they may become unblocked by this.
func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
	//not called under lock
	if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
		return fmt.Errorf("free: %w", err)
	}

	return nil
}

func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
	//not called under lock
	sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
}

func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error {
	//todo not doing this may break eventually?
	//not called under lock
	//if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
	//	return fmt.Errorf("free: %w", err)
	//}

	// Make the descriptor chain available to the device.
	sq.availableRing.offer(chains)

	// Notify the device to make it process the updated available ring.
	if kick {
		return sq.Kick()
	}

	return nil
}

func (sq *SplitQueue) Kick() error {
	if err := sq.kickEventFD.Kick(); err != nil {
		return fmt.Errorf("notify device: %w", err)
	}
	return nil
}

// Close releases all resources used for this queue.
// The implementation will try to release as many resources as possible and
// collect potential errors before returning them.
func (sq *SplitQueue) Close() error {
	var errs []error

	if sq.stop != nil {
		// This has to happen before the event file descriptors may be closed.
		if err := sq.stop(); err != nil {
			errs = append(errs, fmt.Errorf("stop consume used ring: %w", err))
		}

		// Make sure that this code block is executed only once.
		sq.stop = nil
	}

	if err := sq.kickEventFD.Close(); err != nil {
		errs = append(errs, fmt.Errorf("close kick event file descriptor: %w", err))
	}
	if err := sq.callEventFD.Close(); err != nil {
		errs = append(errs, fmt.Errorf("close call event file descriptor: %w", err))
	}

	if err := sq.descriptorTable.releaseBuffers(); err != nil {
		errs = append(errs, fmt.Errorf("release descriptor buffers: %w", err))
	}

	if sq.buf != nil {
		if err := unix.Munmap(sq.buf); err == nil {
			sq.buf = nil
		} else {
			errs = append(errs, fmt.Errorf("unmap virtqueue buffer: %w", err))
		}
	}

	return errors.Join(errs...)
}

// ensureInitialized is used as a guard to prevent methods from being called on
// an uninitialized instance.
func (sq *SplitQueue) ensureInitialized() {
	if sq.buf == nil {
		panic("used ring is not initialized")
	}
}

func align(index, alignment int) int {
	remainder := index % alignment
	if remainder == 0 {
		return index
	}
	return index + alignment - remainder
}

// splitBuffers processes a list of buffers and splits each buffer that is
// larger than the size limit into multiple smaller buffers.
// If none of the buffers are too big though, do nothing, to avoid allocation for now
func splitBuffers(buffers [][]byte, sizeLimit int) [][]byte {
	for i := range buffers {
		if len(buffers[i]) > sizeLimit {
			return reallySplitBuffers(buffers, sizeLimit)
		}
	}
	return buffers
}

func reallySplitBuffers(buffers [][]byte, sizeLimit int) [][]byte {
	result := make([][]byte, 0, len(buffers))
	for _, buffer := range buffers {
		for added := 0; added < len(buffer); added += sizeLimit {
			if len(buffer)-added <= sizeLimit {
				result = append(result, buffer[added:])
				break
			}
			result = append(result, buffer[added:added+sizeLimit])
		}
	}

	return result
}
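
// A quick standalone illustration of how reallySplitBuffers slices an
// oversized buffer; the values are chosen to match the "large" case of
// TestSplitBuffers in the next file. This is a sketch, not code from the diff.

package main

import "fmt"

// splitOne is the inner loop of reallySplitBuffers for a single buffer.
func splitOne(buffer []byte, sizeLimit int) [][]byte {
	var result [][]byte
	for added := 0; added < len(buffer); added += sizeLimit {
		end := added + sizeLimit
		if end > len(buffer) {
			end = len(buffer)
		}
		result = append(result, buffer[added:end])
	}
	return result
}

func main() {
	for _, part := range splitOne(make([]byte, 42), 16) {
		fmt.Println(len(part)) // 16, 16, 10
	}
}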
@@ -1,105 +0,0 @@
package virtqueue

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestSplitQueue_MemoryAlignment(t *testing.T) {
	tests := []struct {
		name      string
		queueSize int
	}{
		{
			name:      "minimal queue size",
			queueSize: 1,
		},
		{
			name:      "small queue size",
			queueSize: 8,
		},
		{
			name:      "large queue size",
			queueSize: 256,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sq, err := NewSplitQueue(tt.queueSize)
			require.NoError(t, err)

			assert.Zero(t, sq.descriptorTable.Address()%descriptorTableAlignment)
			assert.Zero(t, sq.availableRing.Address()%availableRingAlignment)
			assert.Zero(t, sq.usedRing.Address()%usedRingAlignment)
		})
	}
}

func TestSplitBuffers(t *testing.T) {
	const sizeLimit = 16
	tests := []struct {
		name     string
		buffers  [][]byte
		expected [][]byte
	}{
		{
			name:     "no buffers",
			buffers:  make([][]byte, 0),
			expected: make([][]byte, 0),
		},
		{
			name: "small",
			buffers: [][]byte{
				make([]byte, 11),
			},
			expected: [][]byte{
				make([]byte, 11),
			},
		},
		{
			name: "exact size",
			buffers: [][]byte{
				make([]byte, sizeLimit),
			},
			expected: [][]byte{
				make([]byte, sizeLimit),
			},
		},
		{
			name: "large",
			buffers: [][]byte{
				make([]byte, 42),
			},
			expected: [][]byte{
				make([]byte, 16),
				make([]byte, 16),
				make([]byte, 10),
			},
		},
		{
			name: "mixed",
			buffers: [][]byte{
				make([]byte, 7),
				make([]byte, 30),
				make([]byte, 15),
				make([]byte, 32),
			},
			expected: [][]byte{
				make([]byte, 7),
				make([]byte, 16),
				make([]byte, 14),
				make([]byte, 15),
				make([]byte, 16),
				make([]byte, 16),
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			actual := splitBuffers(tt.buffers, sizeLimit)
			assert.Equal(t, tt.expected, actual)
		})
	}
}
@@ -1,21 +0,0 @@
package virtqueue

// usedElementSize is the number of bytes needed to store a [UsedElement] in
// memory.
const usedElementSize = 8

// UsedElement is an element of the [UsedRing] and describes a descriptor chain
// that was used by the device.
type UsedElement struct {
	// DescriptorIndex is the index of the head of the used descriptor chain in
	// the [DescriptorTable].
	// The index is 32-bit here for padding reasons.
	DescriptorIndex uint32
	// Length is the number of bytes written into the device writable portion of
	// the buffer described by the descriptor chain.
	Length uint32
}

func (u *UsedElement) GetHead() uint16 {
	return uint16(u.DescriptorIndex)
}
@@ -1,12 +0,0 @@
package virtqueue

import (
	"testing"
	"unsafe"

	"github.com/stretchr/testify/assert"
)

func TestUsedElement_Size(t *testing.T) {
	assert.EqualValues(t, usedElementSize, unsafe.Sizeof(UsedElement{}))
}
@@ -1,184 +0,0 @@
package virtqueue

import (
	"fmt"
	"unsafe"
)

// usedRingFlag is a flag that describes a [UsedRing].
type usedRingFlag uint16

const (
	// usedRingFlagNoNotify is used by the host to advise the guest to not
	// kick it when adding a buffer. It's unreliable, so it's simply an
	// optimization. Guest will still kick when it's out of buffers.
	usedRingFlagNoNotify usedRingFlag = 1 << iota
)

// usedRingSize is the number of bytes needed to store a [UsedRing] with the
// given queue size in memory.
func usedRingSize(queueSize int) int {
	return 6 + usedElementSize*queueSize
}

// usedRingAlignment is the minimum alignment of a [UsedRing] in memory, as
// required by the virtio spec.
const usedRingAlignment = 4

// UsedRing is where the device returns descriptor chains once it is done with
// them. Each ring entry is a [UsedElement]. It is only written to by the device
// and read by the driver.
//
// Because the size of the ring depends on the queue size, we cannot define a
// Go struct with a static size that maps to the memory of the ring. Instead,
// this struct only contains pointers to the corresponding memory areas.
type UsedRing struct {
	initialized bool

	// flags that describe this ring.
	flags *usedRingFlag
	// ringIndex indicates where the device would put the next entry into the
	// ring (modulo the queue size).
	ringIndex *uint16
	// ring contains the [UsedElement]s. It wraps around at queue size.
	ring []UsedElement
	// availableEvent is not used by this implementation, but we reserve it
	// anyway to avoid issues in case a device may try to write to it, contrary
	// to the virtio specification.
	availableEvent *uint16

	// lastIndex is the internal ringIndex up to which all [UsedElement]s were
	// processed.
	lastIndex uint16

	//mu sync.Mutex
}

// newUsedRing creates a used ring that uses the given underlying memory. The
// length of the memory slice must match the size needed for the ring (see
// [usedRingSize]) for the given queue size.
func newUsedRing(queueSize int, mem []byte) *UsedRing {
	ringSize := usedRingSize(queueSize)
	if len(mem) != ringSize {
		panic(fmt.Sprintf("memory size (%v) does not match required size "+
			"for used ring: %v", len(mem), ringSize))
	}

	r := UsedRing{
		initialized:    true,
		flags:          (*usedRingFlag)(unsafe.Pointer(&mem[0])),
		ringIndex:      (*uint16)(unsafe.Pointer(&mem[2])),
		ring:           unsafe.Slice((*UsedElement)(unsafe.Pointer(&mem[4])), queueSize),
		availableEvent: (*uint16)(unsafe.Pointer(&mem[ringSize-2])),
	}
	r.lastIndex = *r.ringIndex
	return &r
}

// Address returns the pointer to the beginning of the ring in memory.
// Do not modify the memory directly to not interfere with this implementation.
func (r *UsedRing) Address() uintptr {
	if !r.initialized {
		panic("used ring is not initialized")
	}
	return uintptr(unsafe.Pointer(r.flags))
}

// take returns all new [UsedElement]s that the device put into the ring and
// that weren't already returned by a previous call to this method.
// had a lock, I removed it
func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
	//r.mu.Lock()
	//defer r.mu.Unlock()

	ringIndex := *r.ringIndex
	if ringIndex == r.lastIndex {
		// Nothing new.
		return 0, nil
	}

	// Calculate the number of new used elements that we can read from the ring.
	// The ring index may wrap, so special handling for that case is needed.
	count := int(ringIndex - r.lastIndex)
	if count < 0 {
		count += 0xffff
	}

	stillNeedToTake := 0

	if maxToTake > 0 {
		stillNeedToTake = count - maxToTake
		if stillNeedToTake < 0 {
			stillNeedToTake = 0
		}
		count = min(count, maxToTake)
	}

	// The number of new elements can never exceed the queue size.
	if count > len(r.ring) {
		panic("used ring contains more new elements than the ring is long")
	}

	elems := make([]UsedElement, count)
	for i := range count {
		elems[i] = r.ring[r.lastIndex%uint16(len(r.ring))]
		r.lastIndex++
	}

	return stillNeedToTake, elems
}

func (r *UsedRing) takeOne() (uint16, bool) {
	//r.mu.Lock()
	//defer r.mu.Unlock()

	ringIndex := *r.ringIndex
	if ringIndex == r.lastIndex {
		// Nothing new.
		return 0xffff, false
	}

	// Calculate the number of new used elements that we can read from the ring.
	// The ring index may wrap, so special handling for that case is needed.
	count := int(ringIndex - r.lastIndex)
	if count < 0 {
		count += 0xffff
	}

	// The number of new elements can never exceed the queue size.
	if count > len(r.ring) {
		panic("used ring contains more new elements than the ring is long")
	}

	if count == 0 {
		return 0xffff, false
	}

	out := r.ring[r.lastIndex%uint16(len(r.ring))].GetHead()
	r.lastIndex++

	return out, true
}

// InitOfferSingle is only used to pre-fill the used queue at startup, and should not be used if the device is running!
func (r *UsedRing) InitOfferSingle(x uint16, size int) {
	//always called under lock
	//r.mu.Lock()
	//defer r.mu.Unlock()

	offset := 0
	// Add descriptor chain heads to the ring.

	// The 16-bit ring index may overflow. This is expected and is not an
	// issue because the size of the ring array (which equals the queue
	// size) is always a power of 2 and smaller than the highest possible
	// 16-bit value.
	insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
	r.ring[insertIndex] = UsedElement{
		DescriptorIndex: uint32(x),
		Length:          uint32(size),
	}

	// Increase the ring index by the number of descriptor chains added to the ring.
	*r.ringIndex += 1
}
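
// The wrap handling in take and takeOne relies on unsigned 16-bit arithmetic:
// ringIndex - lastIndex is computed in uint16 before being converted to int,
// so the count stays correct across the 65535 -> 0 rollover. A tiny standalone
// check of that arithmetic (not part of the diff):

package main

import "fmt"

func main() {
	var lastIndex uint16 = 65534 // driver has consumed entries up to here
	var ringIndex uint16 = 2     // device has since published 4 more, wrapping past 0

	count := int(ringIndex - lastIndex) // uint16 subtraction wraps before the int conversion
	fmt.Println(count)                  // 4, and never negative
}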
@@ -1,136 +0,0 @@
package virtqueue

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestUsedRing_MemoryLayout(t *testing.T) {
	const queueSize = 2

	memory := make([]byte, usedRingSize(queueSize))
	r := newUsedRing(queueSize, memory)

	*r.flags = 0x01ff
	*r.ringIndex = 1
	r.ring[0] = UsedElement{
		DescriptorIndex: 0x0123,
		Length:          0x4567,
	}
	r.ring[1] = UsedElement{
		DescriptorIndex: 0x89ab,
		Length:          0xcdef,
	}

	assert.Equal(t, []byte{
		0xff, 0x01,
		0x01, 0x00,
		0x23, 0x01, 0x00, 0x00,
		0x67, 0x45, 0x00, 0x00,
		0xab, 0x89, 0x00, 0x00,
		0xef, 0xcd, 0x00, 0x00,
		0x00, 0x00,
	}, memory)
}

//func TestUsedRing_Take(t *testing.T) {
//	const queueSize = 8
//
//	tests := []struct {
//		name      string
//		ring      []UsedElement
//		ringIndex uint16
//		lastIndex uint16
//		expected  []UsedElement
//	}{
//		{
//			name: "nothing new",
//			ring: []UsedElement{
//				{DescriptorIndex: 1},
//				{DescriptorIndex: 2},
//				{DescriptorIndex: 3},
//				{DescriptorIndex: 4},
//				{},
//				{},
//				{},
//				{},
//			},
//			ringIndex: 4,
//			lastIndex: 4,
//			expected:  nil,
//		},
//		{
//			name: "no overflow",
//			ring: []UsedElement{
//				{DescriptorIndex: 1},
//				{DescriptorIndex: 2},
//				{DescriptorIndex: 3},
//				{DescriptorIndex: 4},
//				{},
//				{},
//				{},
//				{},
//			},
//			ringIndex: 4,
//			lastIndex: 1,
//			expected: []UsedElement{
//				{DescriptorIndex: 2},
//				{DescriptorIndex: 3},
//				{DescriptorIndex: 4},
//			},
//		},
//		{
//			name: "ring overflow",
//			ring: []UsedElement{
//				{DescriptorIndex: 9},
//				{DescriptorIndex: 10},
//				{DescriptorIndex: 3},
//				{DescriptorIndex: 4},
//				{DescriptorIndex: 5},
//				{DescriptorIndex: 6},
//				{DescriptorIndex: 7},
//				{DescriptorIndex: 8},
//			},
//			ringIndex: 10,
//			lastIndex: 7,
//			expected: []UsedElement{
//				{DescriptorIndex: 8},
//				{DescriptorIndex: 9},
//				{DescriptorIndex: 10},
//			},
//		},
//		{
//			name: "index overflow",
//			ring: []UsedElement{
//				{DescriptorIndex: 9},
//				{DescriptorIndex: 10},
//				{DescriptorIndex: 3},
//				{DescriptorIndex: 4},
//				{DescriptorIndex: 5},
//				{DescriptorIndex: 6},
//				{DescriptorIndex: 7},
//				{DescriptorIndex: 8},
//			},
//			ringIndex: 2,
//			lastIndex: 65535,
//			expected: []UsedElement{
//				{DescriptorIndex: 8},
//				{DescriptorIndex: 9},
//				{DescriptorIndex: 10},
//			},
//		},
//	}
//	for _, tt := range tests {
//		t.Run(tt.name, func(t *testing.T) {
//			memory := make([]byte, usedRingSize(queueSize))
//			r := newUsedRing(queueSize, memory)
//
//			copy(r.ring, tt.ring)
//			*r.ringIndex = tt.ringIndex
//			r.lastIndex = tt.lastIndex
//
//			assert.Equal(t, tt.expected, r.take())
//		})
//	}
//}
@@ -1,70 +0,0 @@
package packet

import (
	"github.com/slackhq/nebula/util/virtio"
	"golang.org/x/sys/unix"
)

type OutPacket struct {
	Segments        [][]byte
	SegmentPayloads [][]byte
	SegmentHeaders  [][]byte
	SegmentIDs      []uint16
	//todo virtio header?
	SegSize      int
	SegCounter   int
	Valid        bool
	wasSegmented bool

	Scratch []byte
}

func NewOut() *OutPacket {
	out := new(OutPacket)
	out.Segments = make([][]byte, 0, 64)
	out.SegmentHeaders = make([][]byte, 0, 64)
	out.SegmentPayloads = make([][]byte, 0, 64)
	out.SegmentIDs = make([]uint16, 0, 64)
	out.Scratch = make([]byte, Size)
	return out
}

func (pkt *OutPacket) Reset() {
	pkt.Segments = pkt.Segments[:0]
	pkt.SegmentPayloads = pkt.SegmentPayloads[:0]
	pkt.SegmentHeaders = pkt.SegmentHeaders[:0]
	pkt.SegmentIDs = pkt.SegmentIDs[:0]
	pkt.SegSize = 0
	pkt.Valid = false
	pkt.wasSegmented = false
}

func (pkt *OutPacket) UseSegment(segID uint16, seg []byte, isV6 bool) int {
	pkt.Valid = true
	pkt.SegmentIDs = append(pkt.SegmentIDs, segID)
	pkt.Segments = append(pkt.Segments, seg) //todo do we need this?

	vhdr := virtio.NetHdr{ //todo
		Flags:      unix.VIRTIO_NET_HDR_F_DATA_VALID,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_NONE,
		HdrLen:     0,
		GSOSize:    0,
		CsumStart:  0,
		CsumOffset: 0,
		NumBuffers: 0,
	}

	hdr := seg[0 : virtio.NetHdrSize+14]
	_ = vhdr.Encode(hdr)
	if isV6 {
		hdr[virtio.NetHdrSize+14-2] = 0x86
		hdr[virtio.NetHdrSize+14-1] = 0xdd
	} else {
		hdr[virtio.NetHdrSize+14-2] = 0x08
		hdr[virtio.NetHdrSize+14-1] = 0x00
	}

	pkt.SegmentHeaders = append(pkt.SegmentHeaders, hdr)
	pkt.SegmentPayloads = append(pkt.SegmentPayloads, seg[virtio.NetHdrSize+14:])
	return len(pkt.SegmentIDs) - 1
}
packet/packet.go

@@ -1,119 +0,0 @@
package packet

import (
	"encoding/binary"
	"iter"
	"net/netip"
	"slices"
	"syscall"
	"unsafe"

	"golang.org/x/sys/unix"
)

const Size = 0xffff

type Packet struct {
	Payload []byte
	Control []byte
	Name    []byte
	SegSize int

	//todo should this hold out as well?
	OutLen int

	wasSegmented bool
	isV4         bool
}

func New(isV4 bool) *Packet {
	return &Packet{
		Payload: make([]byte, Size),
		Control: make([]byte, unix.CmsgSpace(2)),
		Name:    make([]byte, unix.SizeofSockaddrInet6),
		isV4:    isV4,
	}
}

func (p *Packet) AddrPort() netip.AddrPort {
	var ip netip.Addr
	// It's ok to skip the ok check here, the slicing is the only error that can occur and it will panic
	if p.isV4 {
		ip, _ = netip.AddrFromSlice(p.Name[4:8])
	} else {
		ip, _ = netip.AddrFromSlice(p.Name[8:24])
	}
	return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(p.Name[2:4]))
}

func (p *Packet) updateCtrl(ctrlLen int) {
	p.SegSize = len(p.Payload)
	p.wasSegmented = false
	if ctrlLen == 0 {
		return
	}
	if len(p.Control) == 0 {
		return
	}
	cmsgs, err := unix.ParseSocketControlMessage(p.Control)
	if err != nil {
		return // oh well
	}

	for _, c := range cmsgs {
		if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
			p.wasSegmented = true
			p.SegSize = int(binary.LittleEndian.Uint16(c.Data[:2]))
			return
		}
	}
}

// Update sets a Packet into "just received, not processed" state
func (p *Packet) Update(ctrlLen int) {
	p.OutLen = -1
	p.updateCtrl(ctrlLen)
}

func (p *Packet) SetSegSizeForTX() {
	p.SegSize = len(p.Payload)
	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&p.Control[0]))
	hdr.Level = unix.SOL_UDP
	hdr.Type = unix.UDP_SEGMENT
	hdr.SetLen(syscall.CmsgLen(2))
	binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize))
}

func (p *Packet) CompatibleForSegmentationWith(otherP *Packet, currentTotalSize int) bool {
	//same dest
	if !slices.Equal(p.Name, otherP.Name) {
		return false
	}

	//don't get too big
	if len(p.Payload)+currentTotalSize >= 0xffff {
		return false
	}

	//same body len
	//todo allow single different size at end
	if len(p.Payload) != len(otherP.Payload) {
		return false //todo technically you can cram one extra in
	}
	return true
}

func (p *Packet) Segments() iter.Seq[[]byte] {
	return func(yield func([]byte) bool) {
		//cursor := 0
		for offset := 0; offset < len(p.Payload); offset += p.SegSize {
			end := offset + p.SegSize
			if end > len(p.Payload) {
				end = len(p.Payload)
			}
			if !yield(p.Payload[offset:end]) {
				return
			}
		}
	}
}
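
// Segments above walks a GRO-coalesced payload back out as individual
// datagrams using the kernel-reported segment size. The standalone sketch
// below mimics that loop with plain slices (names and sizes are invented
// here, not from the diff):

package main

import "fmt"

// segments splits payload into segSize-sized pieces, with a shorter final
// piece when the payload is not an exact multiple, just like Packet.Segments.
func segments(payload []byte, segSize int) [][]byte {
	var out [][]byte
	for offset := 0; offset < len(payload); offset += segSize {
		end := offset + segSize
		if end > len(payload) {
			end = len(payload)
		}
		out = append(out, payload[offset:end])
	}
	return out
}

func main() {
	// e.g. a 3000-byte GRO buffer with a 1400-byte segment size
	for _, seg := range segments(make([]byte, 3000), 1400) {
		fmt.Println(len(seg)) // 1400, 1400, 200
	}
}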
@@ -1,37 +0,0 @@
package packet

import (
	"github.com/slackhq/nebula/util/virtio"
)

type VirtIOPacket struct {
	Payload   []byte
	Header    virtio.NetHdr
	Chains    []uint16
	ChainRefs [][]byte
	// OfferDescriptorChains(chains []uint16, kick bool) error
}

func NewVIO() *VirtIOPacket {
	out := new(VirtIOPacket)
	out.Payload = nil
	out.ChainRefs = make([][]byte, 0, 4)
	out.Chains = make([]uint16, 0, 8)
	return out
}

func (v *VirtIOPacket) Reset() {
	v.Payload = nil
	v.ChainRefs = v.ChainRefs[:0]
	v.Chains = v.Chains[:0]
}

type VirtIOTXPacket struct {
	VirtIOPacket
}

func NewVIOTX(isV4 bool) *VirtIOTXPacket {
	out := new(VirtIOTXPacket)
	out.VirtIOPacket = *NewVIO()
	return out
}
@@ -180,7 +180,6 @@ func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) {
 		pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
 		pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
 		pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
-		pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, NoiseKeySize),
 	}

 	// Set up the parameters which include the peer's public key

98	pki.go
@@ -100,62 +100,55 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
 	currentState := p.cs.Load()
 	if newState.v1Cert != nil {
 		if currentState.v1Cert == nil {
-			//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
+			return util.NewContextualError("v1 certificate was added, restart required", nil, err)
-		} else {
-			// did IP in cert change? if so, don't set
-			if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
-				return util.NewContextualError(
-					"Networks in new cert was different from old",
-					m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
-					nil,
-				)
-			}
-
-			if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
-				return util.NewContextualError(
-					"Curve in new v1 cert was different from old",
-					m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
-					nil,
-				)
-			}
 		}

+		// did IP in cert change? if so, don't set
+		if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
+			return util.NewContextualError(
+				"Networks in new cert was different from old",
+				m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
+				nil,
+			)
+		}
+
+		if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
+			return util.NewContextualError(
+				"Curve in new cert was different from old",
+				m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
+				nil,
+			)
+		}
+
+	} else if currentState.v1Cert != nil {
+		//TODO: CERT-V2 we should be able to tear this down
+		return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
 	}

 	if newState.v2Cert != nil {
 		if currentState.v2Cert == nil {
-			//adding certs is fine, actually
+			return util.NewContextualError("v2 certificate was added, restart required", nil, err)
-		} else {
-			// did IP in cert change? if so, don't set
-			if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
-				return util.NewContextualError(
-					"Networks in new cert was different from old",
-					m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
-					nil,
-				)
-			}
-
-			if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
-				return util.NewContextualError(
-					"Curve in new cert was different from old",
-					m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
-					nil,
-				)
-			}
 		}

-	} else if currentState.v2Cert != nil {
+		// did IP in cert change? if so, don't set
-		//newState.v1Cert is non-nil bc empty certstates aren't permitted
+		if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
-		if newState.v1Cert == nil {
-			return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
-		}
-		//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
-		if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
 			return util.NewContextualError(
-				"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
+				"Networks in new cert was different from old",
-				m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
+				m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
 				nil,
 			)
 		}

+		if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
+			return util.NewContextualError(
+				"Curve in new cert was different from old",
+				m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
+				nil,
+			)
+		}
+
+	} else if currentState.v2Cert != nil {
+		return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
 	}

 	// Cipher cant be hot swapped so just leave it at what it was before
@@ -180,6 +173,7 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
 
 	p.cs.Store(newState)
 
+	//TODO: CERT-V2 newState needs a stringer that does json
 	if initial {
 		p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
 	} else {
@@ -365,9 +359,7 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
 		return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
 	}
 
-		if v1.Networks()[0] != v2.Networks()[0] {
+		//TODO: CERT-V2 make sure v2 has v1s address
-			return nil, util.NewContextualError("v1 and v2 networks are not the same", nil, nil)
-		}
 
 		cs.initiatingVersion = dv
 	}
@@ -523,13 +515,9 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
 		return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
 	}
 
-	bl := c.GetStringSlice("pki.blocklist", []string{})
+	for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
-	if len(bl) > 0 {
+		l.WithField("fingerprint", fp).Info("Blocklisting cert")
-		for _, fp := range bl {
+		caPool.BlocklistFingerprint(fp)
-			caPool.BlocklistFingerprint(fp)
-		}
-
-		l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates")
 	}
 
 	return caPool, nil
@@ -7,11 +7,11 @@ import (
 	"slices"
 	"sort"
 	"strconv"
-	"sync"
 	"sync/atomic"
 	"time"
 
 	"github.com/sirupsen/logrus"
+	"github.com/wadey/synctrace"
 )
 
 // forEachFunc is used to benefit folks that want to do work inside the lock
@@ -185,12 +185,12 @@ func (hr *hostnamesResults) GetAddrs() []netip.AddrPort {
 // It serves as a local cache of query replies, host update notifications, and locally learned addresses
 type RemoteList struct {
 	// Every interaction with internals requires a lock!
-	sync.RWMutex
+	synctrace.RWMutex
 
 	// The full list of vpn addresses assigned to this host
 	vpnAddrs []netip.Addr
 
-	// A deduplicated set of underlay addresses. Any accessor should lock beforehand.
+	// A deduplicated set of addresses. Any accessor should lock beforehand.
 	addrs []netip.AddrPort
 
 	// A set of relay addresses. VpnIp addresses that the remote identified as relays.
@@ -201,10 +201,8 @@ type RemoteList struct {
 	// For learned addresses, this is the vpnIp that sent the packet
 	cache map[netip.Addr]*cache
 
 	hr *hostnamesResults
+	shouldAdd func(netip.Addr) bool
-	// shouldAdd is a nillable function that decides if x should be added to addrs.
-	shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool
 
 	// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
 	// They should not be tried again during a handshake
@@ -215,8 +213,9 @@ type RemoteList struct {
 }
 
 // NewRemoteList creates a new empty RemoteList
-func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList {
+func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
 	r := &RemoteList{
+		RWMutex:  synctrace.NewRWMutex("remote-list"),
 		vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
 		addrs:    make([]netip.AddrPort, 0),
 		relays:   make([]netip.Addr, 0),
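For reference, a hypothetical construction against the head-side signature above, written as if inside the nebula package; the vpn address literal and the filtering policy are illustrative assumptions, not part of this diff:

shouldAdd := func(addr netip.Addr) bool {
	// example policy: ignore loopback and link-local candidates
	return !addr.IsLoopback() && !addr.IsLinkLocalUnicast()
}
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("10.1.0.5")}, shouldAdd)
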
@@ -370,15 +369,6 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
 	return c
 }
 
-// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake
-func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) {
-	r.Lock()
-	r.badRemotes = nil
-	r.vpnAddrs = make([]netip.Addr, len(vpnAddrs))
-	copy(r.vpnAddrs, vpnAddrs)
-	r.Unlock()
-}
-
 // ResetBlockedRemotes locks and clears the blocked remotes list
 func (r *RemoteList) ResetBlockedRemotes() {
 	r.Lock()
@@ -588,7 +578,7 @@ func (r *RemoteList) unlockedCollect() {
 
 	dnsAddrs := r.hr.GetAddrs()
 	for _, addr := range dnsAddrs {
-		if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) {
+		if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
 			if !r.unlockedIsBad(addr) {
 				addrs = append(addrs, addr)
 			}
@@ -16,8 +16,8 @@ import (
 	"github.com/slackhq/nebula/cert_test"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/overlay"
-	"go.yaml.in/yaml/v3"
 	"golang.org/x/sync/errgroup"
+	"gopkg.in/yaml.v3"
 )
 
 type m = map[string]any
@@ -5,10 +5,10 @@ import (
 	"errors"
 	"fmt"
 	"net"
-	"sync"
 
 	"github.com/armon/go-radix"
 	"github.com/sirupsen/logrus"
+	"github.com/wadey/synctrace"
 	"golang.org/x/crypto/ssh"
 )
 
@@ -28,7 +28,7 @@ type SSHServer struct {
 	listener net.Listener
 
 	// Locks the conns/counter to avoid concurrent map access
-	connsLock sync.Mutex
+	connsLock synctrace.Mutex
 	conns     map[int]*session
 	counter   int
 }
@@ -41,6 +41,7 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
 		l:        l,
 		commands: radix.New(),
 		conns:    make(map[int]*session),
+		connsLock: synctrace.NewMutex("ssh-server-conns"),
 	}
 
 	cc := ssh.CertChecker{
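A minimal sketch of the named-mutex pattern used on the head side here, assuming only the synctrace constructors shown in this diff and that the returned value exposes Lock/Unlock like sync.Mutex; the name and map contents are placeholders:

connsLock := synctrace.NewMutex("example-conns")
conns := map[int]string{}

connsLock.Lock()
conns[1] = "client-a" // placeholder entry guarded by the named mutex
connsLock.Unlock()
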
@@ -1,8 +1,9 @@
 package nebula
 
 import (
-	"sync"
 	"time"
 
+	"github.com/wadey/synctrace"
 )
 
 // How many timer objects should be cached
@@ -34,7 +35,7 @@ type TimerWheel[T any] struct {
 }
 
 type LockingTimerWheel[T any] struct {
-	m sync.Mutex
+	m synctrace.Mutex
 	t *TimerWheel[T]
 }
 
@@ -81,8 +82,9 @@ func NewTimerWheel[T any](min, max time.Duration) *TimerWheel[T] {
 }
 
 // NewLockingTimerWheel is version of TimerWheel that is safe for concurrent use with a small performance penalty
-func NewLockingTimerWheel[T any](min, max time.Duration) *LockingTimerWheel[T] {
+func NewLockingTimerWheel[T any](name string, min, max time.Duration) *LockingTimerWheel[T] {
 	return &LockingTimerWheel[T]{
+		m: synctrace.NewMutex(name),
 		t: NewTimerWheel[T](min, max),
 	}
 }

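A hypothetical call against the head-side constructor above; the wheel name and durations are placeholders, not part of this diff:

// the internal mutex is registered with synctrace under the supplied name
wheel := NewLockingTimerWheel[string]("example-wheel", time.Second, 30*time.Second)
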
@@ -4,13 +4,13 @@ import (
 	"net/netip"
 
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/packet"
 )
 
 const MTU = 9001
 
 type EncReader func(
-	[]*packet.Packet,
+	addr netip.AddrPort,
+	payload []byte,
 )
 
 type Conn interface {
@@ -19,8 +19,6 @@ type Conn interface {
 	ListenOut(r EncReader)
 	WriteTo(b []byte, addr netip.AddrPort) error
 	ReloadConfig(c *config.C)
-	Prep(pkt *packet.Packet, addr netip.AddrPort) error
-	WriteBatch(pkt []*packet.Packet) (int, error)
 	Close() error
 }
 
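A minimal EncReader callback matching the head-side signature above; the logger and the way the callback is registered are assumptions, not part of this diff:

l := logrus.New()
reader := func(addr netip.AddrPort, payload []byte) {
	l.WithField("udpAddr", addr).WithField("bytes", len(payload)).Debug("packet received")
}
// a Conn implementation would then be started with conn.ListenOut(reader)
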
@@ -1,5 +0,0 @@
package udp

import "errors"

var ErrInvalidIPv6RemoteForSocket = errors.New("listener is IPv4, but writing to IPv6 remote")
@@ -3,62 +3,20 @@
 
 package udp
 
+// Darwin support is primarily implemented in udp_generic, besides NewListenConfig
+
 import (
-	"context"
-	"encoding/binary"
-	"errors"
 	"fmt"
 	"net"
 	"net/netip"
 	"syscall"
-	"unsafe"
 
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/config"
 	"golang.org/x/sys/unix"
 )
 
-type StdConn struct {
-	*net.UDPConn
-	isV4  bool
-	sysFd uintptr
-	l     *logrus.Logger
-}
-
-var _ Conn = &StdConn{}
-
 func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
-	lc := NewListenConfig(multi)
+	return NewGenericListener(l, ip, port, multi, batch)
-	pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
-	if err != nil {
-		return nil, err
-	}
-
-	if uc, ok := pc.(*net.UDPConn); ok {
-		c := &StdConn{UDPConn: uc, l: l}
-
-		rc, err := uc.SyscallConn()
-		if err != nil {
-			return nil, fmt.Errorf("failed to open udp socket: %w", err)
-		}
-
-		err = rc.Control(func(fd uintptr) {
-			c.sysFd = fd
-		})
-		if err != nil {
-			return nil, fmt.Errorf("failed to get udp fd: %w", err)
-		}
-
-		la, err := c.LocalAddr()
-		if err != nil {
-			return nil, err
-		}
-		c.isV4 = la.Addr().Is4()
-
-		return c, nil
-	}
-
-	return nil, fmt.Errorf("unexpected PacketConn: %T %#v", pc, pc)
 }
 
 func NewListenConfig(multi bool) net.ListenConfig {
@@ -85,116 +43,16 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 }
 
-//go:linkname sendto golang.org/x/sys/unix.sendto
+func (u *GenericConn) Rebind() error {
-//go:noescape
+	rc, err := u.UDPConn.SyscallConn()
-func sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen int32) (err error)
-
-func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
-	var sa unsafe.Pointer
-	var addrLen int32
-
-	if u.isV4 {
-		if ap.Addr().Is6() {
-			return ErrInvalidIPv6RemoteForSocket
-		}
-
-		var rsa unix.RawSockaddrInet6
-		rsa.Family = unix.AF_INET6
-		rsa.Addr = ap.Addr().As16()
-		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
-		sa = unsafe.Pointer(&rsa)
-		addrLen = syscall.SizeofSockaddrInet4
-	} else {
-		var rsa unix.RawSockaddrInet6
-		rsa.Family = unix.AF_INET6
-		rsa.Addr = ap.Addr().As16()
-		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
-		sa = unsafe.Pointer(&rsa)
-		addrLen = syscall.SizeofSockaddrInet6
-	}
-
-	// Golang stdlib doesn't handle EAGAIN correctly in some situations so we do writes ourselves
-	// See https://github.com/golang/go/issues/73919
-	for {
-		//_, _, err := unix.Syscall6(unix.SYS_SENDTO, u.sysFd, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), 0, sa, addrLen)
-		err := sendto(int(u.sysFd), b, 0, sa, addrLen)
-		if err == nil {
-			// Written, get out before the error handling
-			return nil
-		}
-
-		if errors.Is(err, syscall.EINTR) {
-			// Write was interrupted, retry
-			continue
-		}
-
-		if errors.Is(err, syscall.EAGAIN) {
-			return &net.OpError{Op: "sendto", Err: unix.EWOULDBLOCK}
-		}
-
-		if errors.Is(err, syscall.EBADF) {
-			return net.ErrClosed
-		}
-
-		return &net.OpError{Op: "sendto", Err: err}
-	}
-}
-
-func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
-	a := u.UDPConn.LocalAddr()
-
-	switch v := a.(type) {
-	case *net.UDPAddr:
-		addr, ok := netip.AddrFromSlice(v.IP)
-		if !ok {
-			return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP)
-		}
-		return netip.AddrPortFrom(addr, uint16(v.Port)), nil
-
-	default:
-		return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a)
-	}
-}
-
-func (u *StdConn) ReloadConfig(c *config.C) {
-	// TODO
-}
-
-func NewUDPStatsEmitter(udpConns []Conn) func() {
-	// No UDP stats for non-linux
-	return func() {}
-}
-
-func (u *StdConn) ListenOut(r EncReader) {
-	buffer := make([]byte, MTU)
-
-	for {
-		// Just read one packet at a time
-		n, rua, err := u.ReadFromUDPAddrPort(buffer)
-		if err != nil {
-			if errors.Is(err, net.ErrClosed) {
-				u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
-				return
-			}
-
-			u.l.WithError(err).Error("unexpected udp socket receive error")
-		}
-
-		r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
-	}
-}
-
-func (u *StdConn) Rebind() error {
-	var err error
-	if u.isV4 {
-		err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, 0)
-	} else {
-		err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, 0)
-	}
-
 	if err != nil {
-		u.l.WithError(err).Error("Failed to rebind udp socket")
+		return err
 	}
 
-	return nil
+	return rc.Control(func(fd uintptr) {
+		err := syscall.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
+		if err != nil {
+			u.l.WithError(err).Error("Failed to rebind udp socket")
+		}
+	})
 }
@@ -1,7 +1,6 @@
-//go:build (!linux || android) && !e2e_testing && !darwin
+//go:build (!linux || android) && !e2e_testing
 // +build !linux android
 // +build !e2e_testing
-// +build !darwin
 
 // udp_generic implements the nebula UDP interface in pure Go stdlib. This
 // means it can be used on platforms like Darwin and Windows.
Some files were not shown because too many files have changed in this diff.