Compare commits

..

18 Commits

Author SHA1 Message Date
Wade Simmons
b5b9d33ee7 v1.5.2 (#612)
Update CHANGELOG for Nebula v1.5.2
2021-12-14 16:48:56 -05:00
Wade Simmons
e434ba6523 fix unsafe routes darwin (#610)
With Nebula 1.4.0, if you create an unsafe_route that has a collision with an existing route on the system, the unsafe_route will be silently dropped (and the existing system route remains).

With Nebula 1.5.0, this same situation will cause Nebula to fail to start with an error (EEXIST).

This change restores the Nebula 1.4.0 behavior (but with a WARN log as well).
2021-12-14 11:52:49 -05:00
Wade Simmons
068a93d1f4 fix makeRouteTree allowMTU (#611)
With the previous implementation, we check if route.MTU is greater than zero,
but it will always be because we set it to the default MTU in
parseUnsafeRoutes. This change leaves it as zero in parseUnsafeRoutes so
it can be examined later.
2021-12-14 11:52:28 -05:00
Nate Brown
15fdabc3ab v1.5.1 (#606)
Update CHANGELOG for Nebula v1.5.1
2021-12-13 20:43:25 -05:00
forfuncsake
1110756f0f Allow setup of a CA pool from bytes that contain expired certs (#599)
Co-authored-by: Nate Brown <nbrown.us@gmail.com>
2021-12-09 21:24:56 -06:00
Nate Brown
e31006d546 Be more clear about ipv4 in nebula-cert (#604) 2021-12-07 21:40:30 -06:00
Wade Simmons
949ec78653 don't set ConnectionState to nil (#590)
* don't set ConnectionState to nil

We might have packets processing in another thread, so we can't safely
just set this to nil. Since we removed it from the hostmaps, the next
packets to process should start the handshake over again.

I believe this comment is outdated or incorrect, since the next
handshake will start over with a new HostInfo, I don't think there is
any way a counter reuse could happen:

> We must null the connectionstate or a counter reuse may happen

Here is a panic we saw that I think is related:

    panic: runtime error: invalid memory address or nil pointer dereference
    [signal SIGSEGV: segmentation violation code=0x1 addr=0x20 pc=0x93a037]
    goroutine 59 [running, locked to thread]:
    github.com/slackhq/nebula.(*Firewall).Drop(...)
            github.com/slackhq/nebula/firewall.go:380
    github.com/slackhq/nebula.(*Interface).consumeInsidePacket(...)
            github.com/slackhq/nebula/inside.go:59
    github.com/slackhq/nebula.(*Interface).listenIn(...)
            github.com/slackhq/nebula/interface.go:233
    created by github.com/slackhq/nebula.(*Interface).run
            github.com/slackhq/nebula/interface.go:191

* use closeTunnel
2021-12-06 14:09:05 -05:00
Wade Simmons
127a116bfd update golang.org/x/crypto (#603)
> Version v0.0.0-20211202192323-5770296d904e of golang.org/x/crypto fixes a vulnerability in the golang.org/x/crypto/ssh package which allowed unauthenticated clients to cause a panic in SSH servers.
>
> This issue was discovered and reported by Rod Hynes, Psiphon Inc., and is tracked as CVE-2021-43565 and Issue golang/go#49932.

    Updated  golang.org/x/crypto  089bfa5675...5770296d90
    Updated  golang.org/x/net     4a448f8816...69e39bad7d
2021-12-06 14:07:05 -05:00
Wade Simmons
befce3f990 fix crash with -test (#602)
When running in `-test` mode, `tun` is set to nil. So we should move the
defer into the `!configTest` if block.

    panic: runtime error: invalid memory address or nil pointer dereference
    [signal SIGSEGV: segmentation violation code=0x1 addr=0x28 pc=0x54855c]

    goroutine 1 [running]:
    github.com/slackhq/nebula.Main.func3(0x4000135e80, {0x0, 0x0})
            github.com/slackhq/nebula/main.go:176 +0x2c
    github.com/slackhq/nebula.Main(0x400022e060, 0x1, {0x76faa0, 0x5}, 0x4000230000, 0x0)
            github.com/slackhq/nebula/main.go:316 +0x2414
    main.main()
            github.com/slackhq/nebula/cmd/nebula/main.go:54 +0x540
2021-12-06 14:06:16 -05:00
Wade Simmons
f60ed2b36d overlay: fix tun.RouteFor getting *net.IP (#595)
tun.RouteFor expects the routeTree to have an iputil.VpnIp inside of it
instead of a *net.IP.
2021-12-06 09:35:31 -05:00
Nate Brown
48c47f5841 Warn if no lighthouses were configured on a non lighthouse node (#587) 2021-11-30 10:31:33 -06:00
Wade Simmons
75306487c5 fix wintun package to have // +build comments (#598)
Without these comments, gofmt 1.16.9 will complain. Since we otherwise
still support building with go1.16, lets add the comments to make it
easier to compile and gofmt.

Related: #588
2021-11-30 11:14:15 -05:00
Nate Brown
78d0d46bae Remove WriteRaw, cidrTree -> routeTree to better describe its purpose, remove redundancy from field names (#582) 2021-11-12 12:47:09 -06:00
Nate Brown
467e605d5e Push route handling into overlay, a few more nits fixed (#581) 2021-11-12 11:19:28 -06:00
Nate Brown
2f1f0d602f Cleanup most of the remaining nits (#578) 2021-11-12 10:47:36 -06:00
Nate Brown
e07524a654 Move all of tun into overlay (#577) 2021-11-11 16:37:29 -06:00
Nate Brown
88ce0edf76 Start the overlay package with the old Inside interface (#576) 2021-11-10 21:52:26 -06:00
Nate Brown
4453964e34 Move util to test, contextual errors to util (#575) 2021-11-10 21:47:38 -06:00
57 changed files with 1363 additions and 1102 deletions

View File

@@ -7,6 +7,45 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [1.5.2] - 2021-12-14
### Added
- Warn when a non lighthouse node does not have lighthouse hosts configured. (#587)
### Changed
- No longer fatals if expired CA certificates are present in `pki.ca`, as long as 1 valid CA is present. (#599)
- `nebula-cert` will now enforce ipv4 addresses. (#604)
- Warn on macOS if an unsafe route cannot be created due to a collision with an
existing route. (#610)
- Warn if you set a route MTU on platforms where we don't support it. (#611)
### Fixed
- Rare race condition when tearing down a tunnel due to `recv_error` and sending packets on another thread. (#590)
- Bug in `routes` and `unsafe_routes` handling that was introduced in 1.5.0. (#595)
- `-test` mode no longer results in a crash. (#602)
### Removed
- `x509.ca` config alias for `pki.ca`. (#604)
### Security
- Upgraded `golang.org/x/crypto` to address an issue which allowed unauthenticated clients to cause a panic in SSH
servers. (#603)
## 1.5.1 - 2021-12-13
(This release was skipped due to discovering #610 and #611 after the tag was
created.)
## [1.5.0] - 2021-11-11
### Added
@@ -306,7 +345,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Initial public release.
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.5.0...HEAD
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.5.2...HEAD
[1.5.2]: https://github.com/slackhq/nebula/releases/tag/v1.5.2
[1.5.0]: https://github.com/slackhq/nebula/releases/tag/v1.5.0
[1.4.0]: https://github.com/slackhq/nebula/releases/tag/v1.4.0
[1.3.0]: https://github.com/slackhq/nebula/releases/tag/v1.3.0

View File

@@ -7,12 +7,12 @@ import (
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
func TestNewAllowListFromConfig(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
c := config.NewC(l)
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0": true,

View File

@@ -3,12 +3,12 @@ package nebula
import (
"testing"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
func TestBits(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
b := NewBits(10)
// make sure it is the right size
@@ -76,7 +76,7 @@ func TestBits(t *testing.T) {
}
func TestBitsDupeCounter(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
@@ -101,7 +101,7 @@ func TestBitsDupeCounter(t *testing.T) {
}
func TestBitsOutOfWindowCounter(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
@@ -131,7 +131,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
}
func TestBitsLostCounter(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()

26
cert.go
View File

@@ -124,19 +124,13 @@ func loadCAFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error)
var err error
caPathOrPEM := c.GetString("pki.ca", "")
if caPathOrPEM == "" {
// Support backwards compat with the old x509
//TODO: remove after this is rolled out everywhere - NB 2018/02/23
caPathOrPEM = c.GetString("x509.ca", "")
}
if caPathOrPEM == "" {
return nil, errors.New("no pki.ca path or PEM data provided")
}
if strings.Contains(caPathOrPEM, "-----BEGIN") {
rawCA = []byte(caPathOrPEM)
caPathOrPEM = "<inline>"
} else {
rawCA, err = ioutil.ReadFile(caPathOrPEM)
if err != nil {
@@ -145,7 +139,20 @@ func loadCAFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error)
}
CAs, err := cert.NewCAPoolFromBytes(rawCA)
if err != nil {
if errors.Is(err, cert.ErrExpired) {
var expired int
for _, cert := range CAs.CAs {
if cert.Expired(time.Now()) {
expired++
l.WithField("cert", cert).Warn("expired certificate present in CA pool")
}
}
if expired >= len(CAs.CAs) {
return nil, errors.New("no valid CA certificates present")
}
} else if err != nil {
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
}
@@ -154,7 +161,8 @@ func loadCAFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error)
CAs.BlocklistFingerprint(fp)
}
// Support deprecated config for at leaast one minor release to allow for migrations
// Support deprecated config for at least one minor release to allow for migrations
//TODO: remove in 2022 or later
for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) {
l.WithField("fingerprint", fp).Infof("Blocklisting cert")
l.Warn("pki.blacklist is deprecated and will not be supported in a future release. Please migrate your config to use pki.blocklist")

View File

@@ -1,6 +1,7 @@
package cert
import (
"errors"
"fmt"
"strings"
"time"
@@ -21,19 +22,32 @@ func NewCAPool() *NebulaCAPool {
return &ca
}
// NewCAPoolFromBytes will create a new CA pool from the provided
// input bytes, which must be a PEM-encoded set of nebula certificates.
// If the pool contains any expired certificates, an ErrExpired will be
// returned along with the pool. The caller must handle any such errors.
func NewCAPoolFromBytes(caPEMs []byte) (*NebulaCAPool, error) {
pool := NewCAPool()
var err error
var expired bool
for {
caPEMs, err = pool.AddCACertificate(caPEMs)
if errors.Is(err, ErrExpired) {
expired = true
err = nil
}
if err != nil {
return nil, err
}
if caPEMs == nil || len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
break
}
}
if expired {
return pool, ErrExpired
}
return pool, nil
}
@@ -47,15 +61,11 @@ func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) {
}
if !c.Details.IsCA {
return pemBytes, fmt.Errorf("provided certificate was not a CA; %s", c.Details.Name)
return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotCA)
}
if !c.CheckSignature(c.Details.PublicKey) {
return pemBytes, fmt.Errorf("provided certificate was not self signed; %s", c.Details.Name)
}
if c.Expired(time.Now()) {
return pemBytes, fmt.Errorf("provided CA certificate is expired; %s", c.Details.Name)
return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotSelfSigned)
}
sum, err := c.Sha256Sum()
@@ -64,6 +74,10 @@ func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) {
}
ncp.CAs[sum] = c
if c.Expired(time.Now()) {
return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrExpired)
}
return pemBytes, nil
}

View File

@@ -9,7 +9,7 @@ import (
"time"
"github.com/golang/protobuf/proto"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
@@ -429,6 +429,15 @@ BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF
-----END NEBULA CERTIFICATE-----
`
expired := `
# expired certificate
-----BEGIN NEBULA CERTIFICATE-----
CjkKB2V4cGlyZWQouPmWjQYwufmWjQY6ILCRaoCkJlqHgv5jfDN4lzLHBvDzaQm4
vZxfu144hmgjQAESQG4qlnZi8DncvD/LDZnLgJHOaX1DWCHHEh59epVsC+BNgTie
WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs=
-----END NEBULA CERTIFICATE-----
`
rootCA := NebulaCertificate{
@@ -452,6 +461,19 @@ BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
assert.Nil(t, err)
assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
// expired cert, no valid certs
ppp, err := NewCAPoolFromBytes([]byte(expired))
assert.Equal(t, ErrExpired, err)
assert.Equal(t, ppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired")
// expired cert, with valid certs
pppp, err := NewCAPoolFromBytes(append([]byte(expired), noNewLines...))
assert.Equal(t, ErrExpired, err)
assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired")
assert.Equal(t, len(pppp.CAs), 3)
}
func appendByteSlices(b ...[]byte) []byte {
@@ -752,7 +774,7 @@ func TestNebulaCertificate_Copy(t *testing.T) {
assert.Nil(t, err)
cc := c.Copy()
util.AssertDeepCopyEqual(t, c, cc)
test.AssertDeepCopyEqual(t, c, cc)
}
func TestUnmarshalNebulaCertificate(t *testing.T) {

9
cert/errors.go Normal file
View File

@@ -0,0 +1,9 @@
package cert
import "errors"
var (
ErrExpired = errors.New("certificate is expired")
ErrNotCA = errors.New("certificate is not a CA")
ErrNotSelfSigned = errors.New("certificate is not self-signed")
)

View File

@@ -37,8 +37,8 @@ func newCaFlags() *caFlags {
cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ip and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use")
cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ip and network in CIDR notation. This will limit which subnet addresses and networks subordinate certs can use")
cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses")
cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets")
return &cf
}
@@ -82,6 +82,9 @@ func ca(args []string, out io.Writer, errOut io.Writer) error {
if err != nil {
return newHelpErrorf("invalid ip definition: %s", err)
}
if ip.To4() == nil {
return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs)
}
ipNet.IP = ip
ips = append(ips, ipNet)
@@ -98,6 +101,9 @@ func ca(args []string, out io.Writer, errOut io.Writer) error {
if err != nil {
return newHelpErrorf("invalid subnet definition: %s", err)
}
if s.IP.To4() == nil {
return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
}
subnets = append(subnets, s)
}
}

View File

@@ -31,7 +31,7 @@ func Test_caHelp(t *testing.T) {
" -groups string\n"+
" \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+
" -ips string\n"+
" \tOptional: comma separated list of ip and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use\n"+
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses\n"+
" -name string\n"+
" \tRequired: name of the certificate authority\n"+
" -out-crt string\n"+
@@ -41,7 +41,7 @@ func Test_caHelp(t *testing.T) {
" -out-qr string\n"+
" \tOptional: output a qr code image (png) of the certificate\n"+
" -subnets string\n"+
" \tOptional: comma separated list of ip and network in CIDR notation. This will limit which subnet addresses and networks subordinate certs can use\n",
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n",
ob.String(),
)
}
@@ -55,6 +55,16 @@ func Test_ca(t *testing.T) {
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// ipv4 only ips
assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb), "invalid ip definition: can only be ipv4, have 100::100/100")
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// ipv4 only subnets
assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb), "invalid subnet definition: can only be ipv4, have 100::100/100")
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// failed key write
ob.Reset()
eb.Reset()

View File

@@ -37,14 +37,14 @@ func newSignFlags() *signFlags {
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
sf.ip = sf.set.String("ip", "", "Required: ip and network in CIDR notation to assign the cert")
sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert")
sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of subnet this cert can serve for")
sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for")
return &sf
}
@@ -114,6 +114,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer) error {
if err != nil {
return newHelpErrorf("invalid ip definition: %s", err)
}
if ip.To4() == nil {
return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip)
}
ipNet.IP = ip
groups := []string{}
@@ -135,6 +138,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer) error {
if err != nil {
return newHelpErrorf("invalid subnet definition: %s", err)
}
if s.IP.To4() == nil {
return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
}
subnets = append(subnets, s)
}
}

View File

@@ -39,7 +39,7 @@ func Test_signHelp(t *testing.T) {
" -in-pub string\n"+
" \tOptional (if out-key not set): path to read a previously generated public key\n"+
" -ip string\n"+
" \tRequired: ip and network in CIDR notation to assign the cert\n"+
" \tRequired: ipv4 address and network in CIDR notation to assign the cert\n"+
" -name string\n"+
" \tRequired: name of the cert, usually a hostname\n"+
" -out-crt string\n"+
@@ -49,7 +49,7 @@ func Test_signHelp(t *testing.T) {
" -out-qr string\n"+
" \tOptional: output a qr code image (png) of the certificate\n"+
" -subnets string\n"+
" \tOptional: comma separated list of subnet this cert can serve for\n",
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n",
ob.String(),
)
}
@@ -59,7 +59,6 @@ func Test_signCert(t *testing.T) {
eb := &bytes.Buffer{}
// required args
assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb), "-name is required")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@@ -160,6 +159,13 @@ func Test_signCert(t *testing.T) {
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb), "invalid ip definition: can only be ipv4, have 100::100/100")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// bad subnet cidr
ob.Reset()
eb.Reset()
@@ -168,6 +174,13 @@ func Test_signCert(t *testing.T) {
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
assertHelpError(t, signCert(args, ob, eb), "invalid subnet definition: can only be ipv4, have 100::100/100")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// mismatched ca key
_, caPriv2, _ := ed25519.GenerateKey(rand.Reader)
caKeyF2, err := ioutil.TempFile("", "sign-cert-2.key")

View File

@@ -8,6 +8,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
)
// A version string that can be set with
@@ -60,7 +61,7 @@ func main() {
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
switch v := err.(type) {
case nebula.ContextualError:
case util.ContextualError:
v.Log(l)
os.Exit(1)
case error:

View File

@@ -8,6 +8,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
)
// A version string that can be set with
@@ -54,7 +55,7 @@ func main() {
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
switch v := err.(type) {
case nebula.ContextualError:
case util.ContextualError:
v.Log(l)
os.Exit(1)
case error:

View File

@@ -7,12 +7,12 @@ import (
"testing"
"time"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
func TestConfig_Load(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
dir, err := ioutil.TempDir("", "config-test")
// invalid yaml
c := NewC(l)
@@ -42,7 +42,7 @@ func TestConfig_Load(t *testing.T) {
}
func TestConfig_Get(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
// test simple type
c := NewC(l)
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
@@ -58,14 +58,14 @@ func TestConfig_Get(t *testing.T) {
}
func TestConfig_GetStringSlice(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
c := NewC(l)
c.Settings["slice"] = []interface{}{"one", "two"}
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
}
func TestConfig_GetBool(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
c := NewC(l)
c.Settings["bool"] = true
assert.Equal(t, true, c.GetBool("bool", false))
@@ -93,7 +93,7 @@ func TestConfig_GetBool(t *testing.T) {
}
func TestConfig_HasChanged(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
// No reload has occurred, return false
c := NewC(l)
c.Settings["test"] = "hi"
@@ -115,7 +115,7 @@ func TestConfig_HasChanged(t *testing.T) {
}
func TestConfig_ReloadConfig(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
done := make(chan bool, 1)
dir, err := ioutil.TempDir("", "config-test")
assert.Nil(t, err)

View File

@@ -11,15 +11,15 @@ import (
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
var vpnIp iputil.VpnIp
func Test_NewConnectionManagerTest(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@@ -38,7 +38,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
ifce := &Interface{
hostMap: hostMap,
inside: &Tun{},
inside: &test.NoopTun{},
outside: &udp.Conn{},
certState: cs,
firewall: &Firewall{},
@@ -89,7 +89,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
}
func Test_NewConnectionManagerTest2(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@@ -107,7 +107,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
ifce := &Interface{
hostMap: hostMap,
inside: &Tun{},
inside: &test.NoopTun{},
outside: &udp.Conn{},
certState: cs,
firewall: &Firewall{},
@@ -164,7 +164,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// Disconnect only if disconnectInvalid: true is set.
func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
now := time.Now()
l := util.NewTestLogger()
l := test.NewLogger()
ipNet := net.IPNet{
IP: net.IPv4(172, 1, 1, 2),
Mask: net.IPMask{255, 255, 255, 0},
@@ -216,7 +216,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
ifce := &Interface{
hostMap: hostMap,
inside: &Tun{},
inside: &test.NoopTun{},
outside: &udp.Conn{},
certState: cs,
firewall: &Firewall{},

View File

@@ -9,13 +9,13 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
@@ -94,7 +94,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
// Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
util.AssertDeepCopyEqual(t, &expectedInfo, thi)
test.AssertDeepCopyEqual(t, &expectedInfo, thi)
// Make sure we don't panic if the host info doesn't have a cert yet
assert.NotPanics(t, func() {

View File

@@ -10,6 +10,7 @@ import (
"github.com/google/gopacket/layers"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp"
)
@@ -64,7 +65,7 @@ func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
// GetFromTun will pull a packet off the tun side of nebula
func (c *Control) GetFromTun(block bool) []byte {
return c.f.inside.(*Tun).Get(block)
return c.f.inside.(*overlay.TestTun).Get(block)
}
// GetFromUDP will pull a udp packet off the udp side of nebula
@@ -77,7 +78,7 @@ func (c *Control) GetUDPTxChan() <-chan *udp.Packet {
}
func (c *Control) GetTunTxChan() <-chan []byte {
return c.f.inside.(*Tun).txPackets
return c.f.inside.(*overlay.TestTun).TxPackets
}
// InjectUDPPacket will inject a packet into the udp side of nebula
@@ -91,7 +92,7 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
SrcIP: c.f.inside.CidrNet().IP,
SrcIP: c.f.inside.Cidr().IP,
DstIP: toIp,
}
@@ -114,7 +115,7 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
panic(err)
}
c.f.inside.(*Tun).Send(buffer.Bytes())
c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
}
func (c *Control) GetUDPAddr() string {

View File

@@ -14,12 +14,12 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
func TestNewFirewall(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
c := &cert.NebulaCertificate{}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
conntrack := fw.Conntrack
@@ -58,7 +58,7 @@ func TestNewFirewall(t *testing.T) {
}
func TestFirewall_AddRule(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
@@ -133,7 +133,7 @@ func TestFirewall_AddRule(t *testing.T) {
}
func TestFirewall_Drop(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
@@ -308,7 +308,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}
func TestFirewall_Drop2(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
@@ -367,7 +367,7 @@ func TestFirewall_Drop2(t *testing.T) {
}
func TestFirewall_Drop3(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
@@ -453,7 +453,7 @@ func TestFirewall_Drop3(t *testing.T) {
}
func TestFirewall_DropConntrackReload(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
@@ -635,7 +635,7 @@ func Test_parsePort(t *testing.T) {
}
func TestNewFirewallFromConfig(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
// Test a bad rule definition
c := &cert.NebulaCertificate{}
conf := config.NewC(l)
@@ -685,7 +685,7 @@ func TestNewFirewallFromConfig(t *testing.T) {
}
func TestAddFirewallRulesFromConfig(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
// Test adding tcp rule
conf := config.NewC(l)
mf := &mockFirewall{}
@@ -849,7 +849,7 @@ func TestTCPRTTTracking(t *testing.T) {
}
func TestFirewall_convertRule(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)

4
go.mod
View File

@@ -25,8 +25,8 @@ require (
github.com/stretchr/testify v1.7.0
github.com/vishvananda/netlink v1.1.0
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519
golang.org/x/net v0.0.0-20211101193420-4a448f8816b3
golang.org/x/crypto v0.0.0-20211202192323-5770296d904e
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
golang.zx2c4.com/wireguard/windows v0.5.1

4
go.sum
View File

@@ -243,6 +243,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211202192323-5770296d904e h1:MUP6MR3rJ7Gk9LEia0LP2ytiH6MuCfs7qYz+47jGdD8=
golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -309,6 +311,8 @@ golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qx
golang.org/x/net v0.0.0-20211020060615-d418f374d309/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211101193420-4a448f8816b3 h1:VrJZAjbekhoRn7n5FBujY31gboH+iB3pdLxn3gE9FjU=
golang.org/x/net v0.0.0-20211101193420-4a448f8816b3/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2 h1:CIJ76btIcR3eFI5EgSo6k1qKw9KJexJuRLI9G7Hp5wE=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=

View File

@@ -7,13 +7,13 @@ import (
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@@ -66,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
}
func Test_NewHandshakeManagerTrigger(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")

View File

@@ -35,7 +35,6 @@ type HostMap struct {
Hosts map[iputil.VpnIp]*HostInfo
preferredRanges []*net.IPNet
vpnCIDR *net.IPNet
unsafeRoutes *cidr.Tree4
metricsEnabled bool
l *logrus.Logger
}
@@ -98,7 +97,6 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
Hosts: h,
preferredRanges: preferredRanges,
vpnCIDR: vpnCIDR,
unsafeRoutes: cidr.NewTree4(),
l: l,
}
return &m
@@ -332,15 +330,6 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*Host
return nil, errors.New("unable to find host")
}
func (hm *HostMap) queryUnsafeRoute(ip iputil.VpnIp) iputil.VpnIp {
r := hm.unsafeRoutes.MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
} else {
return 0
}
}
// We already have the hm Lock when this is called, so make sure to not call
// any other methods that might try to grab it again
func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
@@ -408,17 +397,6 @@ func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) {
}
}
func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
for _, r := range *routes {
hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
hm.unsafeRoutes.AddCIDR(r.route, iputil.Ip2VpnIp(*r.via))
}
}
func (i *HostInfo) BindConnectionState(cs *ConnectionState) {
i.ConnectionState = cs
}
// TryPromoteBest handles re-querying lighthouses and probing for better paths
// NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
@@ -559,10 +537,6 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
return false
}
func (i *HostInfo) ClearConnectionState() {
i.ConnectionState = nil
}
func (i *HostInfo) RecvErrorExceeded() bool {
if i.recvError < 3 {
i.recvError += 1

View File

@@ -72,7 +72,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
//TODO: we can find contains without converting back to bytes
if f.hostMap.vpnCIDR.Contains(vpnIp.ToIP()) == false {
vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
vpnIp = f.inside.RouteFor(vpnIp)
if vpnIp == 0 {
return nil
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"io"
"net"
"os"
"runtime"
"sync/atomic"
@@ -16,24 +15,16 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp"
)
const mtu = 9001
type Inside interface {
io.ReadWriteCloser
Activate() error
CidrNet() *net.IPNet
DeviceName() string
WriteRaw([]byte) error
NewMultiQueueReader() (io.ReadWriteCloser, error)
}
type InterfaceConfig struct {
HostMap *HostMap
Outside *udp.Conn
Inside Inside
Inside overlay.Device
certState *CertState
Cipher string
Firewall *Firewall
@@ -57,7 +48,7 @@ type InterfaceConfig struct {
type Interface struct {
hostMap *HostMap
outside *udp.Conn
inside Inside
inside overlay.Device
certState *CertState
cipher string
firewall *Firewall
@@ -156,7 +147,7 @@ func (f *Interface) activate() {
f.l.WithError(err).Error("Failed to get udp listen address")
}
f.l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()).
WithField("build", f.version).WithField("udpAddr", addr).
Info("Nebula interface is active")

View File

@@ -8,8 +8,8 @@ import (
"github.com/golang/protobuf/proto"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
@@ -46,7 +46,7 @@ func TestNewLhQuery(t *testing.T) {
}
func Test_lhStaticMapping(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
lh1 := "10.128.0.2"
lh1IP := net.ParseIP(lh1)
@@ -67,7 +67,7 @@ func Test_lhStaticMapping(t *testing.T) {
}
func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := util.NewTestLogger()
l := test.NewLogger()
lh1 := "10.128.0.2"
lh1IP := net.ParseIP(lh1)
@@ -137,7 +137,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
}
func TestLighthouse_Memory(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
@@ -266,7 +266,7 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr,
//TODO: this is a RemoteList test
//func Test_lhRemoteAllowList(t *testing.T) {
// l := NewTestLogger()
// l := NewLogger()
// c := NewConfig(l)
// c.Settings["remoteallowlist"] = map[interface{}]interface{}{
// "10.20.0.0/12": false,

View File

@@ -1,7 +1,6 @@
package nebula
import (
"errors"
"fmt"
"strings"
"time"
@@ -10,38 +9,6 @@ import (
"github.com/slackhq/nebula/config"
)
type ContextualError struct {
RealError error
Fields map[string]interface{}
Context string
}
func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
return ContextualError{Context: msg, Fields: fields, RealError: realError}
}
func (ce ContextualError) Error() string {
if ce.RealError == nil {
return ce.Context
}
return ce.RealError.Error()
}
func (ce ContextualError) Unwrap() error {
if ce.RealError == nil {
return errors.New(ce.Context)
}
return ce.RealError
}
func (ce *ContextualError) Log(lr *logrus.Logger) {
if ce.RealError != nil {
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
} else {
lr.WithFields(ce.Fields).Error(ce.Context)
}
}
func configLogger(l *logrus.Logger, c *config.C) error {
// set up our logging level
logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))

85
main.go
View File

@@ -10,8 +10,10 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/sshd"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"gopkg.in/yaml.v2"
)
@@ -44,7 +46,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
err := configLogger(l, c)
if err != nil {
return nil, NewContextualError("Failed to configure the logger", nil, err)
return nil, util.NewContextualError("Failed to configure the logger", nil, err)
}
c.RegisterReloadCallback(func(c *config.C) {
@@ -57,33 +59,25 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
caPool, err := loadCAFromConfig(l, c)
if err != nil {
//The errors coming out of loadCA are already nicely formatted
return nil, NewContextualError("Failed to load ca from config", nil, err)
return nil, util.NewContextualError("Failed to load ca from config", nil, err)
}
l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
cs, err := NewCertStateFromConfig(c)
if err != nil {
//The errors coming out of NewCertStateFromConfig are already nicely formatted
return nil, NewContextualError("Failed to load certificate from config", nil, err)
return nil, util.NewContextualError("Failed to load certificate from config", nil, err)
}
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
fw, err := NewFirewallFromConfig(l, cs.certificate, c)
if err != nil {
return nil, NewContextualError("Error while loading firewall rules", nil, err)
return nil, util.NewContextualError("Error while loading firewall rules", nil, err)
}
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
// TODO: make sure mask is 4 bytes
tunCidr := cs.certificate.Details.Ips[0]
routes, err := parseRoutes(c, tunCidr)
if err != nil {
return nil, NewContextualError("Could not parse tun.routes", nil, err)
}
unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
if err != nil {
return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
}
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
wireSSHReload(l, ssh, c)
@@ -91,7 +85,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, c)
if err != nil {
return nil, NewContextualError("Error while configuring the sshd", nil, err)
return nil, util.NewContextualError("Error while configuring the sshd", nil, err)
}
}
@@ -136,39 +130,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
}
var tun Inside
var tun overlay.Device
if !configTest {
c.CatchHUP(ctx)
switch {
case c.GetBool("tun.disabled", false):
tun = newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
case tunFd != nil:
tun, err = newTunFromFd(
l,
*tunFd,
tunCidr,
c.GetInt("tun.mtu", DEFAULT_MTU),
routes,
unsafeRoutes,
c.GetInt("tun.tx_queue", 500),
)
default:
tun, err = newTun(
l,
c.GetString("tun.dev", ""),
tunCidr,
c.GetInt("tun.mtu", DEFAULT_MTU),
routes,
unsafeRoutes,
c.GetInt("tun.tx_queue", 500),
routines > 1,
)
}
tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, tunFd, routines)
if err != nil {
return nil, NewContextualError("Failed to get a tun/tap device", nil, err)
}
return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err)
}
defer func() {
@@ -176,6 +144,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
tun.Close()
}
}()
}
// set up our UDP listener
udpConns := make([]*udp.Conn, routines)
@@ -185,7 +154,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
for i := 0; i < routines; i++ {
udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64))
if err != nil {
return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
}
udpServer.ReloadConfig(c)
udpConns[i] = udpServer
@@ -194,7 +163,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if port == 0 {
uPort, err := udpServer.LocalAddr()
if err != nil {
return nil, NewContextualError("Failed to get listening port", nil, err)
return nil, util.NewContextualError("Failed to get listening port", nil, err)
}
port = int(uPort.Port)
}
@@ -209,7 +178,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
if err != nil {
return nil, NewContextualError("Failed to parse preferred ranges", nil, err)
return nil, util.NewContextualError("Failed to parse preferred ranges", nil, err)
}
preferredRanges = append(preferredRanges, preferredRange)
}
@@ -222,7 +191,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if rawLocalRange != "" {
_, localRange, err := net.ParseCIDR(rawLocalRange)
if err != nil {
return nil, NewContextualError("Failed to parse local_range", nil, err)
return nil, util.NewContextualError("Failed to parse local_range", nil, err)
}
// Check if the entry for local_range was already specified in
@@ -240,8 +209,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
hostMap.addUnsafeRoutes(&unsafeRoutes)
hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
@@ -261,7 +228,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
// fatal if am_lighthouse is enabled but we are using an ephemeral port
if amLighthouse && (c.GetInt("listen.port", 0) == 0) {
return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
return nil, util.NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
}
// warn if am_lighthouse is enabled but upstream lighthouses exists
@@ -274,14 +241,18 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
for i, host := range rawLighthouseHosts {
ip := net.ParseIP(host)
if ip == nil {
return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
}
if !tunCidr.Contains(ip) {
return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
return nil, util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
}
lighthouseHosts[i] = iputil.Ip2VpnIp(ip)
}
if !amLighthouse && len(lighthouseHosts) == 0 {
l.Warn("No lighthouses.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
}
lightHouse := NewLightHouse(
l,
amLighthouse,
@@ -298,13 +269,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
if err != nil {
return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
return nil, util.NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
}
lightHouse.SetRemoteAllowList(remoteAllowList)
localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list")
if err != nil {
return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
return nil, util.NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
}
lightHouse.SetLocalAllowList(localAllowList)
@@ -313,21 +284,21 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
ip := net.ParseIP(fmt.Sprintf("%v", k))
vpnIp := iputil.Ip2VpnIp(ip)
if !tunCidr.Contains(ip) {
return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
return nil, util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
}
vals, ok := v.([]interface{})
if ok {
for _, v := range vals {
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
if err != nil {
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
}
lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
}
} else {
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
if err != nil {
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
}
lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
}
@@ -426,7 +397,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
statsStart, err := startStats(l, c, buildVersion, configTest)
if err != nil {
return nil, NewContextualError("Failed to start stats emitter", nil, err)
return nil, util.NewContextualError("Failed to start stats emitter", nil, err)
}
if configTest {

View File

@@ -349,12 +349,9 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
return
}
// We delete this host from the main hostmap
f.hostMap.DeleteHostInfo(hostinfo)
// We also delete it from pending to allow for
// fast reconnect. We must null the connectionstate
// or a counter reuse may happen
hostinfo.ConnectionState = nil
f.closeTunnel(hostinfo, false)
// We also delete it from pending hostmap to allow for
// fast reconnect.
f.handshakeManager.DeleteHostInfo(hostinfo)
}

17
overlay/device.go Normal file
View File

@@ -0,0 +1,17 @@
package overlay
import (
"io"
"net"
"github.com/slackhq/nebula/iputil"
)
type Device interface {
io.ReadWriteCloser
Activate() error
Cidr() *net.IPNet
Name() string
RouteFor(iputil.VpnIp) iputil.VpnIp
NewMultiQueueReader() (io.ReadWriteCloser, error)
}

View File

@@ -1,29 +1,45 @@
package nebula
package overlay
import (
"fmt"
"math"
"net"
"runtime"
"strconv"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
)
const DEFAULT_MTU = 1300
type route struct {
mtu int
metric int
route *net.IPNet
via *net.IP
type Route struct {
MTU int
Metric int
Cidr *net.IPNet
Via *iputil.VpnIp
}
func parseRoutes(c *config.C, network *net.IPNet) ([]route, error) {
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4, error) {
routeTree := cidr.NewTree4()
for _, r := range routes {
if !allowMTU && r.MTU > 0 {
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
}
if r.Via != nil {
routeTree.AddCIDR(r.Cidr, *r.Via)
}
}
return routeTree, nil
}
func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
var err error
r := c.Get("tun.routes")
if r == nil {
return []route{}, nil
return []Route{}, nil
}
rawRoutes, ok := r.([]interface{})
@@ -32,10 +48,10 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]route, error) {
}
if len(rawRoutes) < 1 {
return []route{}, nil
return []Route{}, nil
}
routes := make([]route, len(rawRoutes))
routes := make([]Route, len(rawRoutes))
for i, r := range rawRoutes {
m, ok := r.(map[interface{}]interface{})
if !ok {
@@ -64,20 +80,20 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]route, error) {
return nil, fmt.Errorf("entry %v.route in tun.routes is not present", i+1)
}
r := route{
mtu: mtu,
r := Route{
MTU: mtu,
}
_, r.route, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
_, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
if err != nil {
return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
}
if !ipWithin(network, r.route) {
if !ipWithin(network, r.Cidr) {
return nil, fmt.Errorf(
"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v",
i+1,
r.route.String(),
r.Cidr.String(),
network.String(),
)
}
@@ -88,12 +104,12 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]route, error) {
return routes, nil
}
func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]route, error) {
func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
var err error
r := c.Get("tun.unsafe_routes")
if r == nil {
return []route{}, nil
return []Route{}, nil
}
rawRoutes, ok := r.([]interface{})
@@ -102,22 +118,19 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]route, error) {
}
if len(rawRoutes) < 1 {
return []route{}, nil
return []Route{}, nil
}
routes := make([]route, len(rawRoutes))
routes := make([]Route, len(rawRoutes))
for i, r := range rawRoutes {
m, ok := r.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1)
}
rMtu, ok := m["mtu"]
if !ok {
rMtu = c.GetInt("tun.mtu", DEFAULT_MTU)
}
mtu, ok := rMtu.(int)
var mtu int
if rMtu, ok := m["mtu"]; ok {
mtu, ok = rMtu.(int)
if !ok {
mtu, err = strconv.Atoi(rMtu.(string))
if err != nil {
@@ -125,9 +138,10 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]route, error) {
}
}
if mtu < 500 {
if mtu != 0 && mtu < 500 {
return nil, fmt.Errorf("entry %v.mtu in tun.unsafe_routes is below 500: %v", i+1, mtu)
}
}
rMetric, ok := m["metric"]
if !ok {
@@ -166,22 +180,24 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]route, error) {
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1)
}
r := route{
via: &nVia,
mtu: mtu,
metric: metric,
viaVpnIp := iputil.Ip2VpnIp(nVia)
r := Route{
Via: &viaVpnIp,
MTU: mtu,
Metric: metric,
}
_, r.route, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
_, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
if err != nil {
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
}
if ipWithin(network, r.route) {
if ipWithin(network, r.Cidr) {
return nil, fmt.Errorf(
"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
i+1,
r.route.String(),
r.Cidr.String(),
network.String(),
)
}

View File

@@ -1,4 +1,4 @@
package nebula
package overlay
import (
"fmt"
@@ -6,12 +6,13 @@ import (
"testing"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
func Test_parseRoutes(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
c := config.NewC(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24")
@@ -91,12 +92,12 @@ func Test_parseRoutes(t *testing.T) {
tested := 0
for _, r := range routes {
if r.mtu == 8000 {
assert.Equal(t, "10.0.0.1/32", r.route.String())
if r.MTU == 8000 {
assert.Equal(t, "10.0.0.1/32", r.Cidr.String())
tested++
} else {
assert.Equal(t, 9000, r.mtu)
assert.Equal(t, "10.0.0.0/29", r.route.String())
assert.Equal(t, 9000, r.MTU)
assert.Equal(t, "10.0.0.0/29", r.Cidr.String())
tested++
}
}
@@ -107,7 +108,7 @@ func Test_parseRoutes(t *testing.T) {
}
func Test_parseUnsafeRoutes(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
c := config.NewC(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24")
@@ -190,7 +191,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Len(t, routes, 1)
assert.Equal(t, DEFAULT_MTU, routes[0].mtu)
assert.Equal(t, 0, routes[0].MTU)
// bad mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
@@ -216,17 +217,17 @@ func Test_parseUnsafeRoutes(t *testing.T) {
tested := 0
for _, r := range routes {
if r.mtu == 8000 {
assert.Equal(t, "1.0.0.1/32", r.route.String())
if r.MTU == 8000 {
assert.Equal(t, "1.0.0.1/32", r.Cidr.String())
tested++
} else if r.mtu == 9000 {
assert.Equal(t, 9000, r.mtu)
assert.Equal(t, "1.0.0.0/29", r.route.String())
} else if r.MTU == 9000 {
assert.Equal(t, 9000, r.MTU)
assert.Equal(t, "1.0.0.0/29", r.Cidr.String())
tested++
} else {
assert.Equal(t, 1500, r.mtu)
assert.Equal(t, 1234, r.metric)
assert.Equal(t, "1.0.0.2/32", r.route.String())
assert.Equal(t, 1500, r.MTU)
assert.Equal(t, 1234, r.Metric)
assert.Equal(t, "1.0.0.2/32", r.Cidr.String())
tested++
}
}
@@ -235,3 +236,35 @@ func Test_parseUnsafeRoutes(t *testing.T) {
t.Fatal("Did not see both unsafe_routes")
}
}
func Test_makeRouteTree(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24")
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
}}
routes, err := parseUnsafeRoutes(c, n)
assert.NoError(t, err)
assert.Len(t, routes, 2)
routeTree, err := makeRouteTree(l, routes, true)
assert.NoError(t, err)
ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2"))
r := routeTree.MostSpecificContains(ip)
assert.NotNil(t, r)
assert.IsType(t, iputil.VpnIp(0), r)
assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1"))
r = routeTree.MostSpecificContains(ip)
assert.NotNil(t, r)
assert.IsType(t, iputil.VpnIp(0), r)
assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1"))
r = routeTree.MostSpecificContains(ip)
assert.Nil(t, r)
}

51
overlay/tun.go Normal file
View File

@@ -0,0 +1,51 @@
package overlay
import (
"net"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
)
const DefaultMTU = 1300
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *int, routines int) (Device, error) {
routes, err := parseRoutes(c, tunCidr)
if err != nil {
return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
}
unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
if err != nil {
return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
}
routes = append(routes, unsafeRoutes...)
switch {
case c.GetBool("tun.disabled", false):
tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
return tun, nil
case fd != nil:
return newTunFromFd(
l,
*fd,
tunCidr,
c.GetInt("tun.mtu", DefaultMTU),
routes,
c.GetInt("tun.tx_queue", 500),
)
default:
return newTun(
l,
c.GetString("tun.dev", ""),
tunCidr,
c.GetInt("tun.mtu", DefaultMTU),
routes,
c.GetInt("tun.tx_queue", 500),
routines > 1,
)
}
}

61
overlay/tun_android.go Normal file
View File

@@ -0,0 +1,61 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import (
"fmt"
"io"
"net"
"os"
"runtime"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil"
)
type tun struct {
io.ReadWriteCloser
fd int
cidr *net.IPNet
l *logrus.Logger
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) {
if len(routes) > 0 {
return nil, fmt.Errorf("routes are not supported in %s", runtime.GOOS)
}
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
return &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
cidr: cidr,
l: l,
}, nil
}
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android")
}
func (t *tun) RouteFor(iputil.VpnIp) iputil.VpnIp {
return 0
}
func (t tun) Activate() error {
return nil
}
func (t *tun) Cidr() *net.IPNet {
return t.cidr
}
func (t *tun) Name() string {
return "android"
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
}

View File

@@ -1,9 +1,10 @@
//go:build !ios && !e2e_testing
// +build !ios,!e2e_testing
package nebula
package overlay
import (
"errors"
"fmt"
"io"
"net"
@@ -12,17 +13,19 @@ import (
"unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/iputil"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
)
type Tun struct {
type tun struct {
io.ReadWriteCloser
Device string
Cidr *net.IPNet
cidr *net.IPNet
DefaultMTU int
TXQueueLen int
UnsafeRoutes []route
Routes []Route
routeTree *cidr.Tree4
l *logrus.Logger
// cache out buffer since we need to prepend 4 bytes for tun metadata
@@ -74,15 +77,10 @@ type ifreqMTU struct {
pad [8]byte
}
type ifreqQLEN struct {
Name [16]byte
Value int32
pad [8]byte
}
func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("route MTU not supported in Darwin")
func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}
ifIndex := -1
@@ -106,7 +104,7 @@ func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, rout
ctlName [96]byte
}{}
copy(ctlInfo.ctlName[:], []byte(utunControlName))
copy(ctlInfo.ctlName[:], utunControlName)
err = ioctl(uintptr(fd), uintptr(_CTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo)))
if err != nil {
@@ -125,7 +123,7 @@ func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, rout
unix.SYS_CONNECT,
uintptr(fd),
uintptr(unsafe.Pointer(&sc)),
uintptr(sockaddrCtlSize),
sockaddrCtlSize,
)
if errno != 0 {
return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
@@ -152,44 +150,44 @@ func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, rout
file := os.NewFile(uintptr(fd), "")
tun := &Tun{
tun := &tun{
ReadWriteCloser: file,
Device: name,
Cidr: cidr,
cidr: cidr,
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
UnsafeRoutes: unsafeRoutes,
Routes: routes,
routeTree: routeTree,
l: l,
}
return tun, nil
}
func (t *Tun) deviceBytes() (o [16]byte) {
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
}
return
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}
func (c *Tun) Close() error {
if c.ReadWriteCloser != nil {
return c.ReadWriteCloser.Close()
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
return t.ReadWriteCloser.Close()
}
return nil
}
func (t *Tun) Activate() error {
func (t *tun) Activate() error {
devName := t.deviceBytes()
var addr, mask [4]byte
copy(addr[:], t.Cidr.IP.To4())
copy(mask[:], t.Cidr.Mask)
copy(addr[:], t.cidr.IP.To4())
copy(mask[:], t.cidr.Mask)
s, err := unix.Socket(
unix.AF_INET,
@@ -231,7 +229,7 @@ func (t *Tun) Activate() error {
// Set the MTU on the device
ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
return fmt.Errorf("Failed to set tun mtu: %v", err)
return fmt.Errorf("failed to set tun mtu: %v", err)
}
/*
@@ -275,6 +273,9 @@ func (t *Tun) Activate() error {
copy(maskAddr.IP[:], mask[:])
err = addRoute(routeSock, routeAddr, maskAddr, linkAddr)
if err != nil {
if errors.Is(err, unix.EEXIST) {
err = fmt.Errorf("unable to add tun route, identical route already exists: %s", t.cidr)
}
return err
}
@@ -285,14 +286,24 @@ func (t *Tun) Activate() error {
}
// Unsafe path routes
for _, r := range t.UnsafeRoutes {
copy(routeAddr.IP[:], r.route.IP.To4())
copy(maskAddr.IP[:], net.IP(r.route.Mask).To4())
for _, r := range t.Routes {
if r.Via == nil {
// We don't allow route MTUs so only install routes with a via
continue
}
copy(routeAddr.IP[:], r.Cidr.IP.To4())
copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
err = addRoute(routeSock, routeAddr, maskAddr, linkAddr)
if err != nil {
if errors.Is(err, unix.EEXIST) {
t.l.WithField("route", r.Cidr).
Warnf("unable to add unsafe_route, identical route already exists")
} else {
return err
}
}
// TODO how to set metric
}
@@ -300,6 +311,15 @@ func (t *Tun) Activate() error {
return nil
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
}
return 0
}
// Get the LinkAddr for the interface of the given name
// TODO: Is there an easier way to fetch this when we create the interface?
// Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers.
@@ -343,19 +363,17 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
data, err := r.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %v", err)
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %v", err)
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
var _ io.ReadWriteCloser = (*Tun)(nil)
func (t *Tun) Read(to []byte) (int, error) {
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)
@@ -366,7 +384,7 @@ func (t *Tun) Read(to []byte) (int, error) {
}
// Write is only valid for single threaded use
func (t *Tun) Write(from []byte) (int, error) {
func (t *tun) Write(from []byte) (int, error) {
buf := t.out
if cap(buf) < len(from)+4 {
buf = make([]byte, len(from)+4)
@@ -385,7 +403,7 @@ func (t *Tun) Write(from []byte) (int, error) {
} else if ipVer == 6 {
buf[3] = syscall.AF_INET6
} else {
return 0, fmt.Errorf("Unable to determine IP version from packet")
return 0, fmt.Errorf("unable to determine IP version from packet")
}
copy(buf[4:], from)
@@ -394,19 +412,14 @@ func (t *Tun) Write(from []byte) (int, error) {
return n - 4, err
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
func (t *tun) Cidr() *net.IPNet {
return t.cidr
}
func (c *Tun) DeviceName() string {
return c.Device
func (t *tun) Name() string {
return t.Device
}
func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err
}
func (t *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
}

View File

@@ -1,4 +1,4 @@
package nebula
package overlay
import (
"encoding/binary"
@@ -9,6 +9,7 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil"
)
type disabledTun struct {
@@ -43,11 +44,15 @@ func (*disabledTun) Activate() error {
return nil
}
func (t *disabledTun) CidrNet() *net.IPNet {
func (*disabledTun) RouteFor(iputil.VpnIp) iputil.VpnIp {
return 0
}
func (t *disabledTun) Cidr() *net.IPNet {
return t.cidr
}
func (*disabledTun) DeviceName() string {
func (*disabledTun) Name() string {
return "disabled"
}
@@ -71,7 +76,8 @@ func (t *disabledTun) Read(b []byte) (int, error) {
func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
// Return early if this is not a simple ICMP Echo Request
if !(len(b) >= 28 && len(b) <= mtu && b[0] == 0x45 && b[9] == 0x01 && b[20] == 0x08) {
//TODO: make constants out of these
if !(len(b) >= 28 && len(b) <= 9001 && b[0] == 0x45 && b[9] == 0x01 && b[20] == 0x08) {
return false
}
@@ -122,11 +128,6 @@ func (t *disabledTun) Write(b []byte) (int, error) {
return len(b), nil
}
func (t *disabledTun) WriteRaw(b []byte) error {
_, err := t.Write(b)
return err
}
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return t, nil
}

122
overlay/tun_freebsd.go Normal file
View File

@@ -0,0 +1,122 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import (
"fmt"
"io"
"net"
"os"
"os/exec"
"regexp"
"strconv"
"strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/iputil"
)
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
type tun struct {
Device string
cidr *net.IPNet
MTU int
Routes []Route
routeTree *cidr.Tree4
l *logrus.Logger
io.ReadWriteCloser
}
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
return t.ReadWriteCloser.Close()
}
return nil
}
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}
if strings.HasPrefix(deviceName, "/dev/") {
deviceName = strings.TrimPrefix(deviceName, "/dev/")
}
if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("tun.dev must match `tun[0-9]+`")
}
return &tun{
Device: deviceName,
cidr: cidr,
MTU: defaultMTU,
Routes: routes,
routeTree: routeTree,
l: l,
}, nil
}
func (t *tun) Activate() error {
var err error
t.ReadWriteCloser, err = os.OpenFile("/dev/"+t.Device, os.O_RDWR, 0)
if err != nil {
return fmt.Errorf("activate failed: %v", err)
}
// TODO use syscalls instead of exec.Command
t.l.Debug("command: ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
if err = exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
t.l.Debug("command: route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
t.l.Debug("command: ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
if err = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
for _, r := range t.Routes {
if r.Via == nil {
// We don't allow route MTUs so only install routes with a via
continue
}
t.l.Debug("command: route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err)
}
}
return nil
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
}
return 0
}
func (t *tun) Cidr() *net.IPNet {
return t.cidr
}
func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}

117
overlay/tun_ios.go Normal file
View File

@@ -0,0 +1,117 @@
//go:build ios && !e2e_testing
// +build ios,!e2e_testing
package overlay
import (
"errors"
"fmt"
"io"
"net"
"os"
"runtime"
"sync"
"syscall"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil"
)
type tun struct {
io.ReadWriteCloser
cidr *net.IPNet
}
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
func newTunFromFd(_ *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) {
if len(routes) > 0 {
return nil, fmt.Errorf("routes are not supported in %s", runtime.GOOS)
}
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
return &tun{
cidr: cidr,
ReadWriteCloser: &tunReadCloser{f: file},
}, nil
}
func (t *tun) Activate() error {
return nil
}
func (t *tun) RouteFor(iputil.VpnIp) iputil.VpnIp {
return 0
}
// The following is hoisted up from water, we do this so we can inject our own fd on iOS
type tunReadCloser struct {
f io.ReadWriteCloser
rMu sync.Mutex
rBuf []byte
wMu sync.Mutex
wBuf []byte
}
func (tr *tunReadCloser) Read(to []byte) (int, error) {
tr.rMu.Lock()
defer tr.rMu.Unlock()
if cap(tr.rBuf) < len(to)+4 {
tr.rBuf = make([]byte, len(to)+4)
}
tr.rBuf = tr.rBuf[:len(to)+4]
n, err := tr.f.Read(tr.rBuf)
copy(to, tr.rBuf[4:])
return n - 4, err
}
func (tr *tunReadCloser) Write(from []byte) (int, error) {
if len(from) == 0 {
return 0, syscall.EIO
}
tr.wMu.Lock()
defer tr.wMu.Unlock()
if cap(tr.wBuf) < len(from)+4 {
tr.wBuf = make([]byte, len(from)+4)
}
tr.wBuf = tr.wBuf[:len(from)+4]
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
tr.wBuf[3] = syscall.AF_INET
} else if ipVer == 6 {
tr.wBuf[3] = syscall.AF_INET6
} else {
return 0, errors.New("unable to determine IP version from packet")
}
copy(tr.wBuf[4:], from)
n, err := tr.f.Write(tr.wBuf)
return n - 4, err
}
func (tr *tunReadCloser) Close() error {
return tr.f.Close()
}
func (t *tun) Cidr() *net.IPNet {
return t.cidr
}
func (t *tun) Name() string {
return "iOS"
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
}

View File

@@ -1,7 +1,7 @@
//go:build !android && !e2e_testing
// +build !android,!e2e_testing
package nebula
package overlay
import (
"fmt"
@@ -12,20 +12,22 @@ import (
"unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/iputil"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
type Tun struct {
type tun struct {
io.ReadWriteCloser
fd int
Device string
Cidr *net.IPNet
cidr *net.IPNet
MaxMTU int
DefaultMTU int
TXQueueLen int
Routes []route
UnsafeRoutes []route
Routes []Route
routeTree *cidr.Tree4
l *logrus.Logger
}
@@ -43,26 +45,6 @@ func ioctl(a1, a2, a3 uintptr) error {
return nil
}
/*
func ipv4(addr string) (o [4]byte, err error) {
ip := net.ParseIP(addr).To4()
if ip == nil {
err = fmt.Errorf("failed to parse addr %s", addr)
return
}
for i, b := range ip {
o[i] = b
}
return
}
*/
const (
cIFF_TUN = 0x0001
cIFF_NO_PI = 0x1000
cIFF_MULTI_QUEUE = 0x0100
)
type ifreqAddr struct {
Name [16]byte
Addr unix.RawSockaddrInet4
@@ -81,34 +63,37 @@ type ifreqQLEN struct {
pad [8]byte
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, true)
if err != nil {
return nil, err
}
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
ifce = &Tun{
return &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
Device: "tun0",
Cidr: cidr,
cidr: cidr,
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
Routes: routes,
UnsafeRoutes: unsafeRoutes,
routeTree: routeTree,
l: l,
}
return
}, nil
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
}
var req ifReq
req.Flags = uint16(cIFF_TUN | cIFF_NO_PI)
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
if multiqueue {
req.Flags |= cIFF_MULTI_QUEUE
req.Flags |= unix.IFF_MULTI_QUEUE
}
copy(req.Name[:], deviceName)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
@@ -120,35 +105,43 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
maxMTU := defaultMTU
for _, r := range routes {
if r.mtu > maxMTU {
maxMTU = r.mtu
if r.MTU == 0 {
r.MTU = defaultMTU
}
if r.MTU > maxMTU {
maxMTU = r.MTU
}
}
ifce = &Tun{
routeTree, err := makeRouteTree(l, routes, true)
if err != nil {
return nil, err
}
return &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
Device: name,
Cidr: cidr,
cidr: cidr,
MaxMTU: maxMTU,
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
Routes: routes,
UnsafeRoutes: unsafeRoutes,
routeTree: routeTree,
l: l,
}
return
}, nil
}
func (c *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
}
var req ifReq
req.Flags = uint16(cIFF_TUN | cIFF_NO_PI | cIFF_MULTI_QUEUE)
copy(req.Name[:], c.Device)
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
copy(req.Name[:], t.Device)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return nil, err
}
@@ -158,46 +151,52 @@ func (c *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return file, nil
}
func (c *Tun) WriteRaw(b []byte) error {
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
}
return 0
}
func (t *tun) Write(b []byte) (int, error) {
var nn int
for {
max := len(b)
n, err := unix.Write(c.fd, b[nn:max])
for {
n, err := unix.Write(t.fd, b[nn:max])
if n > 0 {
nn += n
}
if nn == len(b) {
return err
return nn, err
}
if err != nil {
return err
return nn, err
}
if n == 0 {
return io.ErrUnexpectedEOF
return nn, io.ErrUnexpectedEOF
}
}
}
func (c *Tun) Write(b []byte) (int, error) {
return len(b), c.WriteRaw(b)
}
func (c Tun) deviceBytes() (o [16]byte) {
for i, c := range c.Device {
func (t tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
}
return
}
func (c Tun) Activate() error {
devName := c.deviceBytes()
func (t tun) Activate() error {
devName := t.deviceBytes()
var addr, mask [4]byte
copy(addr[:], c.Cidr.IP.To4())
copy(mask[:], c.Cidr.Mask)
copy(addr[:], t.cidr.IP.To4())
copy(mask[:], t.cidr.Mask)
s, err := unix.Socket(
unix.AF_INET,
@@ -235,17 +234,17 @@ func (c Tun) Activate() error {
}
// Set the MTU on the device
ifm := ifreqMTU{Name: devName, MTU: int32(c.MaxMTU)}
ifm := ifreqMTU{Name: devName, MTU: int32(t.MaxMTU)}
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
c.l.WithError(err).Error("Failed to set tun mtu")
t.l.WithError(err).Error("Failed to set tun mtu")
}
// Set the transmit queue length
ifrq := ifreqQLEN{Name: devName, Value: int32(c.TXQueueLen)}
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
// If we can't set the queue length nebula will still work but it may lead to packet loss
c.l.WithError(err).Error("Failed to set tun tx queue length")
t.l.WithError(err).Error("Failed to set tun tx queue length")
}
// Bring up the interface
@@ -255,59 +254,46 @@ func (c Tun) Activate() error {
}
// Set the routes
link, err := netlink.LinkByName(c.Device)
link, err := netlink.LinkByName(t.Device)
if err != nil {
return fmt.Errorf("failed to get tun device link: %s", err)
}
// Default route
dr := &net.IPNet{IP: c.Cidr.IP.Mask(c.Cidr.Mask), Mask: c.Cidr.Mask}
dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask}
nr := netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: dr,
MTU: c.DefaultMTU,
AdvMSS: c.advMSS(route{}),
MTU: t.DefaultMTU,
AdvMSS: t.advMSS(Route{}),
Scope: unix.RT_SCOPE_LINK,
Src: c.Cidr.IP,
Src: t.cidr.IP,
Protocol: unix.RTPROT_KERNEL,
Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST,
}
err = netlink.RouteReplace(&nr)
if err != nil {
return fmt.Errorf("failed to set mtu %v on the default route %v; %v", c.DefaultMTU, dr, err)
return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
}
// Path routes
for _, r := range c.Routes {
for _, r := range t.Routes {
nr := netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: r.route,
MTU: r.mtu,
AdvMSS: c.advMSS(r),
Dst: r.Cidr,
MTU: r.MTU,
AdvMSS: t.advMSS(r),
Scope: unix.RT_SCOPE_LINK,
}
if r.Metric > 0 {
nr.Priority = r.Metric
}
err = netlink.RouteAdd(&nr)
if err != nil {
return fmt.Errorf("failed to set mtu %v on route %v; %v", r.mtu, r.route, err)
}
}
// Unsafe path routes
for _, r := range c.UnsafeRoutes {
nr := netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: r.route,
MTU: r.mtu,
Priority: r.metric,
AdvMSS: c.advMSS(r),
Scope: unix.RT_SCOPE_LINK,
}
err = netlink.RouteAdd(&nr)
if err != nil {
return fmt.Errorf("failed to set mtu %v on route %v; %v", r.mtu, r.route, err)
return fmt.Errorf("failed to set mtu %v on route %v; %v", r.MTU, r.Cidr, err)
}
}
@@ -320,22 +306,22 @@ func (c Tun) Activate() error {
return nil
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
func (t *tun) Cidr() *net.IPNet {
return t.cidr
}
func (c *Tun) DeviceName() string {
return c.Device
func (t *tun) Name() string {
return t.Device
}
func (c Tun) advMSS(r route) int {
mtu := r.mtu
if r.mtu == 0 {
mtu = c.DefaultMTU
func (t tun) advMSS(r Route) int {
mtu := r.MTU
if r.MTU == 0 {
mtu = t.DefaultMTU
}
// We only need to set advmss if the route MTU does not match the device MTU
if mtu != c.MaxMTU {
if mtu != t.MaxMTU {
return mtu - 40
}
return 0

View File

@@ -1,25 +1,25 @@
//go:build !e2e_testing
// +build !e2e_testing
package nebula
package overlay
import "testing"
var runAdvMSSTests = []struct {
name string
tun Tun
r route
tun tun
r Route
expected int
}{
// Standard case, default MTU is the device max MTU
{"default", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{}, 0},
{"default-min", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{mtu: 1440}, 0},
{"default-low", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{mtu: 1200}, 1160},
{"default", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
{"default-min", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
{"default-low", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},
// Case where we have a route MTU set higher than the default
{"route", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{}, 1400},
{"route-min", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{mtu: 1440}, 1400},
{"route-high", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{mtu: 8941}, 0},
{"route", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
{"route-min", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
{"route-high", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
}
func TestTunAdvMSS(t *testing.T) {

117
overlay/tun_tester.go Normal file
View File

@@ -0,0 +1,117 @@
//go:build e2e_testing
// +build e2e_testing
package overlay
import (
"fmt"
"io"
"net"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/iputil"
)
type TestTun struct {
Device string
cidr *net.IPNet
Routes []Route
routeTree *cidr.Tree4
l *logrus.Logger
rxPackets chan []byte // Packets to receive into nebula
TxPackets chan []byte // Packets transmitted outside by nebula
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*TestTun, error) {
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}
return &TestTun{
Device: deviceName,
cidr: cidr,
Routes: routes,
routeTree: routeTree,
l: l,
rxPackets: make(chan []byte, 1),
TxPackets: make(chan []byte, 1),
}, nil
}
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*TestTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported")
}
// Send will place a byte array onto the receive queue for nebula to consume
// These are unencrypted ip layer frames destined for another nebula node.
// packets should exit the udp side, capture them with udpConn.Get
func (t *TestTun) Send(packet []byte) {
t.l.WithField("dataLen", len(packet)).Info("Tun receiving injected packet")
t.rxPackets <- packet
}
// Get will pull an unencrypted ip layer frame from the transmit queue
// nebula meant to send this message to some application on the local system
// packets were ingested from the udp side, you can send them with udpConn.Send
func (t *TestTun) Get(block bool) []byte {
if block {
return <-t.TxPackets
}
select {
case p := <-t.TxPackets:
return p
default:
return nil
}
}
//********************************************************************************************************************//
// Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************//
func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
}
return 0
}
func (t *TestTun) Activate() error {
return nil
}
func (t *TestTun) Cidr() *net.IPNet {
return t.cidr
}
func (t *TestTun) Name() string {
return t.Device
}
func (t *TestTun) Write(b []byte) (n int, err error) {
packet := make([]byte, len(b), len(b))
copy(packet, b)
t.TxPackets <- packet
return len(b), nil
}
func (t *TestTun) Close() error {
close(t.rxPackets)
return nil
}
func (t *TestTun) Read(b []byte) (int, error) {
p := <-t.rxPackets
copy(b, p)
return len(p), nil
}
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented")
}

View File

@@ -0,0 +1,126 @@
package overlay
import (
"fmt"
"io"
"net"
"os/exec"
"strconv"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/iputil"
"github.com/songgao/water"
)
type waterTun struct {
Device string
cidr *net.IPNet
MTU int
Routes []Route
routeTree *cidr.Tree4
*water.Interface
}
func newWaterTun(l *logrus.Logger, cidr *net.IPNet, defaultMTU int, routes []Route) (*waterTun, error) {
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
return &waterTun{
cidr: cidr,
MTU: defaultMTU,
Routes: routes,
routeTree: routeTree,
}, nil
}
func (t *waterTun) Activate() error {
var err error
t.Interface, err = water.New(water.Config{
DeviceType: water.TUN,
PlatformSpecificParams: water.PlatformSpecificParams{
ComponentID: "tap0901",
Network: t.cidr.String(),
},
})
if err != nil {
return fmt.Errorf("activate failed: %v", err)
}
t.Device = t.Interface.Name()
// TODO use syscalls instead of exec.Command
err = exec.Command(
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
fmt.Sprintf("name=%s", t.Device),
"source=static",
fmt.Sprintf("addr=%s", t.cidr.IP),
fmt.Sprintf("mask=%s", net.IP(t.cidr.Mask)),
"gateway=none",
).Run()
if err != nil {
return fmt.Errorf("failed to run 'netsh' to set address: %s", err)
}
err = exec.Command(
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "interface",
t.Device,
fmt.Sprintf("mtu=%d", t.MTU),
).Run()
if err != nil {
return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err)
}
iface, err := net.InterfaceByName(t.Device)
if err != nil {
return fmt.Errorf("failed to find interface named %s: %v", t.Device, err)
}
for _, r := range t.Routes {
if r.Via == nil {
// We don't allow route MTUs so only install routes with a via
continue
}
err = exec.Command(
"C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(iface.Index), "METRIC", strconv.Itoa(r.Metric),
).Run()
if err != nil {
return fmt.Errorf("failed to add the unsafe_route %s: %v", r.Cidr.String(), err)
}
}
return nil
}
func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
}
return 0
}
func (t *waterTun) Cidr() *net.IPNet {
return t.cidr
}
func (t *waterTun) Name() string {
return t.Device
}
func (t *waterTun) Close() error {
if t.Interface == nil {
return nil
}
return t.Interface.Close()
}
func (t *waterTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}

62
overlay/tun_windows.go Normal file
View File

@@ -0,0 +1,62 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import (
"fmt"
"net"
"os"
"path/filepath"
"runtime"
"syscall"
"github.com/sirupsen/logrus"
)
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (Device, error) {
if len(routes) > 0 {
return nil, fmt.Errorf("route MTU not supported in Windows")
}
useWintun := true
if err := checkWinTunExists(); err != nil {
l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")
useWintun = false
}
if useWintun {
device, err := newWinTun(l, deviceName, cidr, defaultMTU, routes)
if err != nil {
return nil, fmt.Errorf("create Wintun interface failed, %w", err)
}
return device, nil
}
device, err := newWaterTun(l, cidr, defaultMTU, routes)
if err != nil {
return nil, fmt.Errorf("create wintap driver failed, %w", err)
}
return device, nil
}
func checkWinTunExists() error {
myPath, err := os.Executable()
if err != nil {
return err
}
arch := runtime.GOARCH
switch arch {
case "386":
//NOTE: wintun bundles 386 as x86
arch = "x86"
}
_, err = syscall.LoadDLL(filepath.Join(filepath.Dir(myPath), "dist", "windows", "wintun", "bin", arch, "wintun.dll"))
return err
}

View File

@@ -1,4 +1,4 @@
package nebula
package overlay
import (
"crypto"
@@ -7,6 +7,9 @@ import (
"net"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/wintun"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
@@ -14,11 +17,12 @@ import (
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type WinTun struct {
type winTun struct {
Device string
Cidr *net.IPNet
cidr *net.IPNet
MTU int
UnsafeRoutes []route
Routes []Route
routeTree *cidr.Tree4
tun *wintun.NativeTun
}
@@ -42,51 +46,60 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
}
func newWinTun(deviceName string, cidr *net.IPNet, defaultMTU int, unsafeRoutes []route, txQueueLen int) (ifce *WinTun, err error) {
func newWinTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route) (*winTun, error) {
guid, err := generateGUIDByDeviceName(deviceName)
if err != nil {
return nil, fmt.Errorf("Generate GUID failed: %w", err)
return nil, fmt.Errorf("generate GUID failed: %w", err)
}
tunDevice, err := wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
if err != nil {
return nil, fmt.Errorf("Create TUN device failed: %w", err)
return nil, fmt.Errorf("create TUN device failed: %w", err)
}
ifce = &WinTun{
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}
return &winTun{
Device: deviceName,
Cidr: cidr,
cidr: cidr,
MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes,
Routes: routes,
routeTree: routeTree,
tun: tunDevice.(*wintun.NativeTun),
}, nil
}
return ifce, nil
}
func (t *winTun) Activate() error {
luid := winipcfg.LUID(t.tun.LUID())
func (c *WinTun) Activate() error {
luid := winipcfg.LUID(c.tun.LUID())
if err := luid.SetIPAddresses([]net.IPNet{*c.Cidr}); err != nil {
if err := luid.SetIPAddresses([]net.IPNet{*t.cidr}); err != nil {
return fmt.Errorf("failed to set address: %w", err)
}
foundDefault4 := false
routes := make([]*winipcfg.RouteData, 0, len(c.UnsafeRoutes)+1)
routes := make([]*winipcfg.RouteData, 0, len(t.Routes)+1)
for _, r := range t.Routes {
if r.Via == nil {
// We don't allow route MTUs so only install routes with a via
continue
}
for _, r := range c.UnsafeRoutes {
if !foundDefault4 {
if cidr, bits := r.route.Mask.Size(); cidr == 0 && bits != 0 {
if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 {
foundDefault4 = true
}
}
// Add our unsafe route
routes = append(routes, &winipcfg.RouteData{
Destination: *r.route,
NextHop: *r.via,
Metric: uint32(r.metric),
Destination: *r.Cidr,
NextHop: r.Via.ToIP(),
Metric: uint32(r.Metric),
})
}
@@ -99,7 +112,7 @@ func (c *WinTun) Activate() error {
return fmt.Errorf("failed to get ip interface: %w", err)
}
ipif.NLMTU = uint32(c.MTU)
ipif.NLMTU = uint32(t.MTU)
if foundDefault4 {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
@@ -112,35 +125,39 @@ func (c *WinTun) Activate() error {
return nil
}
func (c *WinTun) CidrNet() *net.IPNet {
return c.Cidr
func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
}
func (c *WinTun) DeviceName() string {
return c.Device
return 0
}
func (c *WinTun) Read(b []byte) (int, error) {
return c.tun.Read(b, 0)
func (t *winTun) Cidr() *net.IPNet {
return t.cidr
}
func (c *WinTun) Write(b []byte) (int, error) {
return c.tun.Write(b, 0)
func (t *winTun) Name() string {
return t.Device
}
func (c *WinTun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err
func (t *winTun) Read(b []byte) (int, error) {
return t.tun.Read(b, 0)
}
func (c *WinTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0)
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}
func (c *WinTun) Close() error {
func (t *winTun) Close() error {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
// so to be certain, just remove everything before destroying.
luid := winipcfg.LUID(c.tun.LUID())
luid := winipcfg.LUID(t.tun.LUID())
_ = luid.FlushRoutes(windows.AF_INET)
_ = luid.FlushIPAddresses(windows.AF_INET)
/* We don't support IPV6 yet
@@ -149,5 +166,5 @@ func (c *WinTun) Close() error {
*/
_ = luid.FlushDNS(windows.AF_INET)
return c.tun.Close()
return t.tun.Close()
}

View File

@@ -5,12 +5,12 @@ import (
"time"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
func TestNewPunchyFromConfig(t *testing.T) {
l := util.NewTestLogger()
l := test.NewLogger()
c := config.NewC(l)
// Test defaults

View File

@@ -1,4 +1,4 @@
package util
package test
import (
"fmt"

View File

@@ -1,4 +1,4 @@
package util
package test
import (
"io/ioutil"
@@ -7,7 +7,7 @@ import (
"github.com/sirupsen/logrus"
)
func NewTestLogger() *logrus.Logger {
func NewLogger() *logrus.Logger {
l := logrus.New()
v := os.Getenv("TEST_LOGS")

43
test/tun.go Normal file
View File

@@ -0,0 +1,43 @@
package test
import (
"errors"
"io"
"net"
"github.com/slackhq/nebula/iputil"
)
type NoopTun struct{}
func (NoopTun) RouteFor(iputil.VpnIp) iputil.VpnIp {
return 0
}
func (NoopTun) Activate() error {
return nil
}
func (NoopTun) Cidr() *net.IPNet {
return nil
}
func (NoopTun) Name() string {
return "noop"
}
func (NoopTun) Read([]byte) (int, error) {
return 0, nil
}
func (NoopTun) Write([]byte) (int, error) {
return 0, nil
}
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, errors.New("unsupported")
}
func (NoopTun) Close() error {
return nil
}

View File

@@ -1,86 +0,0 @@
//go:build !e2e_testing
// +build !e2e_testing
package nebula
import (
"fmt"
"io"
"net"
"os"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
type Tun struct {
io.ReadWriteCloser
fd int
Device string
Cidr *net.IPNet
MaxMTU int
DefaultMTU int
TXQueueLen int
Routes []route
UnsafeRoutes []route
l *logrus.Logger
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
ifce = &Tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
Device: "android",
Cidr: cidr,
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
Routes: routes,
UnsafeRoutes: unsafeRoutes,
l: l,
}
return
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTun not supported in Android")
}
func (c *Tun) WriteRaw(b []byte) error {
var nn int
for {
max := len(b)
n, err := unix.Write(c.fd, b[nn:max])
if n > 0 {
nn += n
}
if nn == len(b) {
return err
}
if err != nil {
return err
}
if n == 0 {
return io.ErrUnexpectedEOF
}
}
}
func (c Tun) Activate() error {
return nil
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}
func (t *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
}

View File

@@ -1,107 +0,0 @@
//go:build !e2e_testing
// +build !e2e_testing
package nebula
import (
"fmt"
"io"
"net"
"os"
"os/exec"
"regexp"
"strconv"
"strings"
"github.com/sirupsen/logrus"
)
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
type Tun struct {
Device string
Cidr *net.IPNet
MTU int
UnsafeRoutes []route
l *logrus.Logger
io.ReadWriteCloser
}
func (c *Tun) Close() error {
if c.ReadWriteCloser != nil {
return c.ReadWriteCloser.Close()
}
return nil
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in FreeBSD")
}
if strings.HasPrefix(deviceName, "/dev/") {
deviceName = strings.TrimPrefix(deviceName, "/dev/")
}
if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("tun.dev must match `tun[0-9]+`")
}
return &Tun{
Device: deviceName,
Cidr: cidr,
MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes,
l: l,
}, nil
}
func (c *Tun) Activate() error {
var err error
c.ReadWriteCloser, err = os.OpenFile("/dev/"+c.Device, os.O_RDWR, 0)
if err != nil {
return fmt.Errorf("Activate failed: %v", err)
}
// TODO use syscalls instead of exec.Command
c.l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String())
if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
c.l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
c.l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU))
if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
for _, r := range c.UnsafeRoutes {
c.l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err)
}
}
return nil
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}
func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err
}
func (t *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}

View File

@@ -1,120 +0,0 @@
//go:build ios && !e2e_testing
// +build ios,!e2e_testing
package nebula
import (
"errors"
"fmt"
"io"
"net"
"os"
"sync"
"syscall"
"github.com/sirupsen/logrus"
)
type Tun struct {
io.ReadWriteCloser
Device string
Cidr *net.IPNet
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("route MTU not supported in Darwin")
}
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
ifce = &Tun{
Cidr: cidr,
Device: "iOS",
ReadWriteCloser: &tunReadCloser{f: file},
}
return
}
func (c *Tun) Activate() error {
return nil
}
func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err
}
// The following is hoisted up from water, we do this so we can inject our own fd on iOS
type tunReadCloser struct {
f io.ReadWriteCloser
rMu sync.Mutex
rBuf []byte
wMu sync.Mutex
wBuf []byte
}
func (t *tunReadCloser) Read(to []byte) (int, error) {
t.rMu.Lock()
defer t.rMu.Unlock()
if cap(t.rBuf) < len(to)+4 {
t.rBuf = make([]byte, len(to)+4)
}
t.rBuf = t.rBuf[:len(to)+4]
n, err := t.f.Read(t.rBuf)
copy(to, t.rBuf[4:])
return n - 4, err
}
func (t *tunReadCloser) Write(from []byte) (int, error) {
if len(from) == 0 {
return 0, syscall.EIO
}
t.wMu.Lock()
defer t.wMu.Unlock()
if cap(t.wBuf) < len(from)+4 {
t.wBuf = make([]byte, len(from)+4)
}
t.wBuf = t.wBuf[:len(from)+4]
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
t.wBuf[3] = syscall.AF_INET
} else if ipVer == 6 {
t.wBuf[3] = syscall.AF_INET6
} else {
return 0, errors.New("unable to determine IP version from packet")
}
copy(t.wBuf[4:], from)
n, err := t.f.Write(t.wBuf)
return n - 4, err
}
func (t *tunReadCloser) Close() error {
return t.f.Close()
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}
func (t *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
}

View File

@@ -1,105 +0,0 @@
//go:build e2e_testing
// +build e2e_testing
package nebula
import (
"fmt"
"io"
"net"
"github.com/sirupsen/logrus"
)
type Tun struct {
Device string
Cidr *net.IPNet
MTU int
UnsafeRoutes []route
l *logrus.Logger
rxPackets chan []byte // Packets to receive into nebula
txPackets chan []byte // Packets transmitted outside by nebula
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, _ []route, unsafeRoutes []route, _ int, _ bool) (ifce *Tun, err error) {
return &Tun{
Device: deviceName,
Cidr: cidr,
MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes,
l: l,
rxPackets: make(chan []byte, 1),
txPackets: make(chan []byte, 1),
}, nil
}
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []route, _ []route, _ int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported")
}
// Send will place a byte array onto the receive queue for nebula to consume
// These are unencrypted ip layer frames destined for another nebula node.
// packets should exit the udp side, capture them with udpConn.Get
func (c *Tun) Send(packet []byte) {
c.l.WithField("dataLen", len(packet)).Info("Tun receiving injected packet")
c.rxPackets <- packet
}
// Get will pull an unencrypted ip layer frame from the transmit queue
// nebula meant to send this message to some application on the local system
// packets were ingested from the udp side, you can send them with udpConn.Send
func (c *Tun) Get(block bool) []byte {
if block {
return <-c.txPackets
}
select {
case p := <-c.txPackets:
return p
default:
return nil
}
}
//********************************************************************************************************************//
// Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************//
func (c *Tun) Activate() error {
return nil
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}
func (c *Tun) Write(b []byte) (n int, err error) {
return len(b), c.WriteRaw(b)
}
func (c *Tun) Close() error {
close(c.rxPackets)
return nil
}
func (c *Tun) WriteRaw(b []byte) error {
packet := make([]byte, len(b), len(b))
copy(packet, b)
c.txPackets <- packet
return nil
}
func (c *Tun) Read(b []byte) (int, error) {
p := <-c.rxPackets
copy(b, p)
return len(p), nil
}
func (c *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented")
}

View File

@@ -1,107 +0,0 @@
package nebula
import (
"fmt"
"io"
"net"
"os/exec"
"strconv"
"github.com/songgao/water"
)
type WindowsWaterTun struct {
Device string
Cidr *net.IPNet
MTU int
UnsafeRoutes []route
*water.Interface
}
func newWindowsWaterTun(deviceName string, cidr *net.IPNet, defaultMTU int, unsafeRoutes []route, txQueueLen int) (ifce *WindowsWaterTun, err error) {
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
return &WindowsWaterTun{
Cidr: cidr,
MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes,
}, nil
}
func (c *WindowsWaterTun) Activate() error {
var err error
c.Interface, err = water.New(water.Config{
DeviceType: water.TUN,
PlatformSpecificParams: water.PlatformSpecificParams{
ComponentID: "tap0901",
Network: c.Cidr.String(),
},
})
if err != nil {
return fmt.Errorf("Activate failed: %v", err)
}
c.Device = c.Interface.Name()
// TODO use syscalls instead of exec.Command
err = exec.Command(
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
fmt.Sprintf("name=%s", c.Device),
"source=static",
fmt.Sprintf("addr=%s", c.Cidr.IP),
fmt.Sprintf("mask=%s", net.IP(c.Cidr.Mask)),
"gateway=none",
).Run()
if err != nil {
return fmt.Errorf("failed to run 'netsh' to set address: %s", err)
}
err = exec.Command(
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "interface",
c.Device,
fmt.Sprintf("mtu=%d", c.MTU),
).Run()
if err != nil {
return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err)
}
iface, err := net.InterfaceByName(c.Device)
if err != nil {
return fmt.Errorf("failed to find interface named %s: %v", c.Device, err)
}
for _, r := range c.UnsafeRoutes {
err = exec.Command(
"C:\\Windows\\System32\\route.exe", "add", r.route.String(), r.via.String(), "IF", strconv.Itoa(iface.Index), "METRIC", strconv.Itoa(r.metric),
).Run()
if err != nil {
return fmt.Errorf("failed to add the unsafe_route %s: %v", r.route.String(), err)
}
}
return nil
}
func (c *WindowsWaterTun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *WindowsWaterTun) DeviceName() string {
return c.Device
}
func (c *WindowsWaterTun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err
}
func (c *WindowsWaterTun) Close() error {
if c.Interface == nil {
return nil
}
return c.Interface.Close()
}
func (t *WindowsWaterTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}

View File

@@ -1,74 +0,0 @@
//go:build !e2e_testing
// +build !e2e_testing
package nebula
import (
"fmt"
"io"
"net"
"os"
"path/filepath"
"runtime"
"syscall"
"github.com/sirupsen/logrus"
)
type Tun struct {
Inside
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("route MTU not supported in Windows")
}
useWintun := true
if err = checkWinTunExists(); err != nil {
l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")
useWintun = false
}
var inside Inside
if useWintun {
inside, err = newWinTun(deviceName, cidr, defaultMTU, unsafeRoutes, txQueueLen)
if err != nil {
return nil, fmt.Errorf("Create Wintun interface failed, %w", err)
}
} else {
inside, err = newWindowsWaterTun(deviceName, cidr, defaultMTU, unsafeRoutes, txQueueLen)
if err != nil {
return nil, fmt.Errorf("Create wintap driver failed, %w", err)
}
}
return &Tun{
Inside: inside,
}, nil
}
func checkWinTunExists() error {
myPath, err := os.Executable()
if err != nil {
return err
}
arch := runtime.GOARCH
switch arch {
case "386":
//NOTE: wintun bundles 386 as x86
arch = "x86"
}
_, err = syscall.LoadDLL(filepath.Join(filepath.Dir(myPath), "dist", "windows", "wintun", "bin", arch, "wintun.dll"))
return err
}
func (t *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}

39
util/error.go Normal file
View File

@@ -0,0 +1,39 @@
package util
import (
"errors"
"github.com/sirupsen/logrus"
)
type ContextualError struct {
RealError error
Fields map[string]interface{}
Context string
}
func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
return ContextualError{Context: msg, Fields: fields, RealError: realError}
}
func (ce ContextualError) Error() string {
if ce.RealError == nil {
return ce.Context
}
return ce.RealError.Error()
}
func (ce ContextualError) Unwrap() error {
if ce.RealError == nil {
return errors.New(ce.Context)
}
return ce.RealError
}
func (ce *ContextualError) Log(lr *logrus.Logger) {
if ce.RealError != nil {
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
} else {
lr.WithFields(ce.Fields).Error(ce.Context)
}
}

View File

@@ -1,4 +1,4 @@
package nebula
package util
import (
"errors"
@@ -8,6 +8,8 @@ import (
"github.com/stretchr/testify/assert"
)
type m map[string]interface{}
type TestLogWriter struct {
Logs []string
}

View File

@@ -1,4 +1,5 @@
//go:build windows
// +build windows
/* SPDX-License-Identifier: MIT
*

View File

@@ -1,4 +1,5 @@
//go:build windows
// +build windows
/* SPDX-License-Identifier: MIT
*