Compare commits


6 Commits

Author SHA1 Message Date
Nate Brown
ee8e4d2017 Start of the changelog 2025-11-18 23:00:04 -06:00
Nate Brown
8d656fb890 Pull in v1.9.5-v1.9.7 CHANGELOG 2025-11-18 21:58:26 -06:00
Wade Simmons
27ea667aee add more tests around bits counters (#1441)
Co-authored-by: Nate Brown <nbrown.us@gmail.com>
2025-11-18 16:42:21 -06:00
Hal Martin
4df8bcb1f5 nebula-cert: support reading CA passphrase from env (#1421)
* nebula-cert: support reading CA passphrase from env

This patch extends the `nebula-cert` command to support reading
the CA passphrase from the environment variable `CA_PASSPHRASE`.

Currently `nebula-cert` depends on an interactive session to obtain
the CA passphrase. This presents a challenge for automation tools like
Ansible. With this change, Ansible can store the CA passphrase in a
vault and supply it to `nebula-cert` via the `CA_PASSPHRASE`
environment variable for non-interactive signing.

Signed-off-by: Hal Martin <1230969+halmartin@users.noreply.github.com>

* name the variable NEBULA_CA_PASSPHRASE
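
A hypothetical non-interactive run under the final variable name (paths and
passphrase are illustrative):

    NEBULA_CA_PASSPHRASE='example-passphrase' nebula-cert sign \
        -ca-crt ca.crt -ca-key ca.key -name host1 -ip 192.168.100.1/24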

---------

Signed-off-by: Hal Martin <1230969+halmartin@users.noreply.github.com>
Co-authored-by: JackDoan <me@jackdoan.com>
2025-11-17 14:41:08 -06:00
Wade Simmons
36c890eaad populate default Build version if missing (#1386)
* populate default Build version if missing

Use the Go module information built into the binary if the Build var
wasn't set during the build.

This means if you install via a specific tag, you get:

    go install github.com/slackhq/nebula/cmd/nebula@v1.9.5

    $ nebula -version
    Version: 1.9.5

And if you install master, you get:

    go install github.com/slackhq/nebula/cmd/nebula@master

    $ nebula -version
    Version: 1.9.5-0.20250408154034-18279ed17b10

* also default in the library

* cleanup
2025-11-14 08:58:15 -05:00
dependabot[bot]
44001244f2 Bump github.com/gaissmai/bart from 0.25.0 to 0.26.0 (#1508)
* Bump github.com/gaissmai/bart from 0.25.0 to 0.26.0

Bumps [github.com/gaissmai/bart](https://github.com/gaissmai/bart) from 0.25.0 to 0.26.0.
- [Release notes](https://github.com/gaissmai/bart/releases)
- [Commits](https://github.com/gaissmai/bart/compare/v0.25.0...v0.26.0)

---
updated-dependencies:
- dependency-name: github.com/gaissmai/bart
  dependency-version: 0.26.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* fix tests

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Wade Simmons <wsimmons@slack-corp.com>
2025-11-13 13:16:48 -05:00
63 changed files with 437 additions and 5003 deletions

View File

@@ -7,12 +7,64 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [1.10.0] - ????
### Added
- PKCS11 support for P256 keys when built with `pkcs11` tag (#1153)
- ASN.1 based v2 nebula certificates with support for ipv6 and multiple ip addresses.
Certificates now have a unified interface for external implementations. (#1212, #1216, #1345)
**TODO: External documentation link!**
- Add the ability to mark packets on linux to better target nebula packets in iptables/nftables. (#1331)
- Add ECMP support for `unsafe_routes`. (#1332)
### Changed
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
intended to target an `unsafe_routes` entry must explicitly declare it via the
`local_cidr` field. This is almost always the intended behavior. This flag is
deprecated and will be removed in a future release.
deprecated and will be removed in a future release. (#1373)
### Fixed
- Fix moving a udp address from one vpn address to another in the `static_host_map`
which could cause rapid re-handshaking with an incorrect remote. (#1259)
- Improve smoke tests in environments where the docker network is not the default. (#1347)
## [1.9.7] - 2025-10-10
### Security
- Fix an issue where Nebula could incorrectly accept and process a packet from an erroneous source IP when the sender's
certificate is configured with unsafe_routes (cert v1/v2) or multiple IPs (cert v2). (#1494)
### Changed
- Disable sending `recv_error` messages when a packet is received outside the allowable counter window. (#1459)
- Improve error messages and remove some unnecessary fatal conditions in the Windows and generic udp listener. (#1543)
## [1.9.6] - 2025-07-15
### Added
- Support dropping inactive tunnels. This is disabled by default in this release but can be enabled with `tunnels.drop_inactive`. See example config for more details. (#1413)
### Fixed
- Fix Darwin freeze due to presence of some Network Extensions (#1426)
- Ensure the same relay tunnel is always used when multiple relay tunnels are present (#1422)
- Fix Windows freeze due to ICMP error handling (#1412)
- Fix relay migration panic (#1403)
## [1.9.5] - 2024-12-05
### Added
- Gracefully ignore v2 certificates. (#1282)
### Fixed
- Fix relays that refuse to re-establish after one of the remote tunnel pairs breaks. (#1277)
## [1.9.4] - 2024-09-09
@@ -671,7 +723,11 @@ created.)
- Initial public release.
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.4...HEAD
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.0...HEAD
[1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0
[1.9.7]: https://github.com/slackhq/nebula/releases/tag/v1.9.7
[1.9.6]: https://github.com/slackhq/nebula/releases/tag/v1.9.6
[1.9.5]: https://github.com/slackhq/nebula/releases/tag/v1.9.5
[1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4
[1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3
[1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2

bits.go (109 changes)
View File

@@ -9,14 +9,13 @@ type Bits struct {
length uint64
current uint64
bits []bool
firstSeen bool
lostCounter metrics.Counter
dupeCounter metrics.Counter
outOfWindowCounter metrics.Counter
}
func NewBits(bits uint64) *Bits {
return &Bits{
b := &Bits{
length: bits,
bits: make([]bool, bits, bits),
current: 0,
@@ -24,34 +23,37 @@ func NewBits(bits uint64) *Bits {
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
}
// There is no counter value 0, mark it to avoid counting a lost packet later.
b.bits[0] = true
b.current = 0
return b
}
func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true.
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
if i > b.current {
return true
}
// If i is within the window, check if it's been set already. The first window will fail this check
if i > b.current-b.length {
return !b.bits[i%b.length]
}
// If i is within the first window
if i < b.length {
// If i is within the window, check if it's been set already.
if i > b.current-b.length || i < b.length && b.current < b.length {
return !b.bits[i%b.length]
}
// Not within the window
l.Debugf("rejected a packet (top) %d %d delta %d\n", b.current, i, b.current-i)
if l.Level >= logrus.DebugLevel {
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
}
return false
}
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true and update current.
if i == b.current+1 {
// Report missed packets, we can only understand what was missed after the first window has been gone through
if i > b.length && b.bits[i%b.length] == false {
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
// The very first window can only be tracked as lost once we are on the 2nd window or greater
if b.bits[i%b.length] == false && i > b.length {
b.lostCounter.Inc(1)
}
b.bits[i%b.length] = true
@@ -59,61 +61,32 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
return true
}
// If i packet is greater than current but less than the maximum length of our bitmap,
// flip everything in between to false and move ahead.
if i > b.current && i < b.current+b.length {
// In between current and i need to be zero'd to allow those packets to come in later
for n := b.current + 1; n < i; n++ {
// If i is a jump, adjust the window, record lost, update current, and return true
if i > b.current {
lost := int64(0)
// Zero out the bits between the current and the new counter value, limited by the window size,
// since the window is shifting
for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
if b.bits[n%b.length] == false && n > b.length {
lost++
}
b.bits[n%b.length] = false
}
b.bits[i%b.length] = true
b.current = i
//l.Debugf("missed %d packets between %d and %d\n", i-b.current, i, b.current)
return true
}
// If i is greater than the delta between current and the total length of our bitmap,
// just flip everything in the map and move ahead.
if i >= b.current+b.length {
// The current window loss will be accounted for later, only record the jump as loss up until then
lost := maxInt64(0, int64(i-b.current-b.length))
//TODO: explain this
if b.current == 0 {
lost++
}
for n := range b.bits {
// Don't want to count the first window as a loss
//TODO: this is likely wrong, we are wanting to track only the bit slots that we aren't going to track anymore and this is marking everything as missed
//if b.bits[n] == false {
// lost++
//}
b.bits[n] = false
}
// Only record any skipped packets as a result of the window moving further than the window length
// Any loss within the new window will be accounted for in future calls
lost += max(0, int64(i-b.current-b.length))
b.lostCounter.Inc(lost)
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": true, "currentCounter": b.current, "incomingCounter": i, "reason": "window shifting"}).
Debug("Receive window")
}
b.bits[i%b.length] = true
b.current = i
return true
}
// Allow for the 0 packet to come in within the first window
if i == 0 && b.firstSeen == false && b.current < b.length {
b.firstSeen = true
b.bits[i%b.length] = true
return true
}
// If i is within the window of current minus length (the total past window size),
// allow it and flip to true but to NOT change current. We also have to account for the first window
if ((b.current >= b.length && i > b.current-b.length) || (b.current < b.length && i < b.length)) && i <= b.current {
if b.current == i {
// If i is within the current window but below the current counter,
// Check to see if it's a duplicate
if i > b.current-b.length || i < b.length && b.current < b.length {
if b.current == i || b.bits[i%b.length] == true {
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
Debug("Receive window")
@@ -122,18 +95,8 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
return false
}
if b.bits[i%b.length] == true {
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
Debug("Receive window")
}
b.dupeCounter.Inc(1)
return false
}
b.bits[i%b.length] = true
return true
}
// In all other cases, fail and don't change current.
@@ -147,11 +110,3 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
}
return false
}
func maxInt64(a, b int64) int64 {
if a > b {
return a
}
return b
}
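
A standalone sketch of the sliding-window logic the rewritten Check/Update
converge on (simplified: no logger and no metrics counters, so these are not
nebula's actual types):

    package main

    import "fmt"

    // window tracks which of the last `length` counters have been seen.
    type window struct {
        length  uint64
        current uint64
        seen    []bool
    }

    func newWindow(n uint64) *window {
        w := &window{length: n, seen: make([]bool, n)}
        w.seen[0] = true // counter 0 is never sent; pre-mark it so it is not counted as lost
        return w
    }

    // update reports whether counter i is new and acceptable, sliding the window on a jump.
    func (w *window) update(i uint64) bool {
        switch {
        case i > w.current:
            // Ahead of the window: clear the slots being shifted out, then advance.
            for n := w.current + 1; n <= min(i, w.current+w.length); n++ {
                w.seen[n%w.length] = false
            }
            w.seen[i%w.length] = true
            w.current = i
            return true
        case i > w.current-w.length || (i < w.length && w.current < w.length):
            // Inside the window; the second clause covers the first, not-yet-full window.
            if w.seen[i%w.length] {
                return false // duplicate
            }
            w.seen[i%w.length] = true
            return true
        default:
            return false // older than the window: out of window
        }
    }

    func main() {
        w := newWindow(10)
        // In order, duplicate, jump ahead, then too old:
        fmt.Println(w.update(1), w.update(2), w.update(2), w.update(15), w.update(5))
        // Output: true true false true false
    }

The nebula code layers the duplicate, lost, and out-of-window counters plus
debug logging on top of this shape.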

View File

@@ -15,48 +15,41 @@ func TestBits(t *testing.T) {
assert.Len(t, b.bits, 10)
// This is initialized to zero - receive one. This should work.
assert.True(t, b.Check(l, 1))
u := b.Update(l, 1)
assert.True(t, u)
assert.True(t, b.Update(l, 1))
assert.EqualValues(t, 1, b.current)
g := []bool{false, true, false, false, false, false, false, false, false, false}
g := []bool{true, true, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
// Receive two
assert.True(t, b.Check(l, 2))
u = b.Update(l, 2)
assert.True(t, u)
assert.True(t, b.Update(l, 2))
assert.EqualValues(t, 2, b.current)
g = []bool{false, true, true, false, false, false, false, false, false, false}
g = []bool{true, true, true, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
// Receive two again - it will fail
assert.False(t, b.Check(l, 2))
u = b.Update(l, 2)
assert.False(t, u)
assert.False(t, b.Update(l, 2))
assert.EqualValues(t, 2, b.current)
// Jump ahead to 15, which should clear everything and set the 6th element
assert.True(t, b.Check(l, 15))
u = b.Update(l, 15)
assert.True(t, u)
assert.True(t, b.Update(l, 15))
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, false, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Mark 14, which is allowed because it is in the window
assert.True(t, b.Check(l, 14))
u = b.Update(l, 14)
assert.True(t, u)
assert.True(t, b.Update(l, 14))
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Mark 5, which is not allowed because it is not in the window
assert.False(t, b.Check(l, 5))
u = b.Update(l, 5)
assert.False(t, u)
assert.False(t, b.Update(l, 5))
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits)
@@ -69,10 +62,29 @@ func TestBits(t *testing.T) {
// Walk through a few windows in order
b = NewBits(10)
for i := uint64(0); i <= 100; i++ {
for i := uint64(1); i <= 100; i++ {
assert.True(t, b.Check(l, i), "Error while checking %v", i)
assert.True(t, b.Update(l, i), "Error while updating %v", i)
}
assert.False(t, b.Check(l, 1), "Out of window check")
}
func TestBitsLargeJumps(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b = NewBits(10)
b.lostCounter.Clear()
assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54
assert.Equal(t, int64(45), b.lostCounter.Count())
assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99
assert.Equal(t, int64(89), b.lostCounter.Count())
assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
assert.Equal(t, int64(188), b.lostCounter.Count())
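// Where these totals come from: slots in the very first window (counters 1-10)
// are exempt from loss accounting, so the jump 0->55 records only
// max(0, 55-0-10) = 45. The jump 55->100 adds 9 unseen in-window slots (56-64)
// plus max(0, 100-55-10) = 35, reaching 89. The jump 100->200 adds 9 more
// (101-109) plus max(0, 200-100-10) = 90, reaching 188.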
}
func TestBitsDupeCounter(t *testing.T) {
@@ -124,8 +136,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
assert.False(t, b.Update(l, 0))
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
//TODO: make sure lostCounter doesn't increase on orderly increments
assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
}
@@ -137,8 +148,6 @@ func TestBitsLostCounter(t *testing.T) {
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
//assert.True(t, b.Update(0))
assert.True(t, b.Update(l, 0))
assert.True(t, b.Update(l, 20))
assert.True(t, b.Update(l, 21))
assert.True(t, b.Update(l, 22))
@@ -149,7 +158,7 @@ func TestBitsLostCounter(t *testing.T) {
assert.True(t, b.Update(l, 27))
assert.True(t, b.Update(l, 28))
assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
@@ -158,8 +167,6 @@ func TestBitsLostCounter(t *testing.T) {
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 0))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count())
// 10 will set 0 index, 0 was already set, no lost packets
@@ -214,6 +221,62 @@ func TestBitsLostCounter(t *testing.T) {
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func TestBitsLostCounterIssue1(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 4))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 1))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 2))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 3))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 5))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 6))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 7))
assert.Equal(t, int64(0), b.lostCounter.Count())
// assert.True(t, b.Update(l, 8))
assert.True(t, b.Update(l, 10))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 11))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 14))
assert.Equal(t, int64(0), b.lostCounter.Count())
// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
assert.True(t, b.Update(l, 19))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 12))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 13))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 15))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 16))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 17))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 18))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 20))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 21))
// We missed packet 8 above
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func BenchmarkBits(b *testing.B) {
z := NewBits(10)
for n := 0; n < b.N; n++ {

View File

@@ -173,6 +173,8 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
var passphrase []byte
if !isP11 && *cf.encryption {
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
if len(passphrase) == 0 {
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
@@ -192,6 +194,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
}
}
}
var curve cert.Curve
var pub, rawPriv []byte

View File

@@ -171,6 +171,17 @@ func Test_ca(t *testing.T) {
assert.Equal(t, pwPromptOb, ob.String())
assert.Empty(t, eb.String())
// test encrypted key with passphrase environment variable
os.Remove(keyF.Name())
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
require.NoError(t, ca(args, ob, eb, testpw))
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
// read encrypted key file and verify default params
rb, _ = os.ReadFile(keyF.Name())
k, _ := pem.Decode(rb)

View File

@@ -5,10 +5,28 @@ import (
"fmt"
"io"
"os"
"runtime/debug"
"strings"
)
// A version string that can be set with
//
// -ldflags "-X main.Build=SOMEVERSION"
//
// at compile-time.
var Build string
func init() {
if Build == "" {
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
Build = strings.TrimPrefix(info.Main.Version, "v")
}
}
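
For reference, the ldflags route described in the comment above still takes
precedence; the fallback only fires when Build is left empty. An illustrative
build command (binary path assumed):

    go build -ldflags "-X main.Build=1.2.3" ./cmd/nebula-cert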
type helpError struct {
s string
}

View File

@@ -116,8 +116,10 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
// naively attempt to decode the private key as though it is not encrypted
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
// ask for a passphrase until we get one
var passphrase []byte
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
if len(passphrase) == 0 {
// ask for a passphrase until we get one
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
@@ -135,7 +137,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if len(passphrase) == 0 {
return fmt.Errorf("cannot open encrypted ca-key without passphrase")
}
}
curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
if err != nil {
return fmt.Errorf("error while parsing encrypted ca-key: %s", err)

View File

@@ -379,6 +379,15 @@ func Test_signCert(t *testing.T) {
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
// test with the proper password in the environment
os.Remove(crtF.Name())
os.Remove(keyF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
require.NoError(t, signCert(args, ob, eb, testpw))
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
// test with the wrong password
ob.Reset()
eb.Reset()
@@ -389,6 +398,17 @@ func Test_signCert(t *testing.T) {
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
// test with the wrong password in environment
ob.Reset()
eb.Reset()
os.Setenv("NEBULA_CA_PASSPHRASE", "invalid password")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing encrypted ca-key: invalid passphrase or corrupt private key")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
// test with the user not entering a password
ob.Reset()
eb.Reset()

View File

@@ -4,6 +4,8 @@ import (
"flag"
"fmt"
"os"
"runtime/debug"
"strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
@@ -18,6 +20,17 @@ import (
// at compile-time.
var Build string
func init() {
if Build == "" {
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
Build = strings.TrimPrefix(info.Main.Version, "v")
}
}
func main() {
serviceFlag := flag.String("service", "", "Control the system service.")
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")

View File

@@ -3,10 +3,9 @@ package main
import (
"flag"
"fmt"
"log"
"net/http"
_ "net/http/pprof"
"os"
"runtime/debug"
"strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
@@ -21,6 +20,17 @@ import (
// at compile-time.
var Build string
func init() {
if Build == "" {
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
Build = strings.TrimPrefix(info.Main.Version, "v")
}
}
func main() {
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config")
@@ -61,10 +71,6 @@ func main() {
os.Exit(1)
}
go func() {
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
}()
if !*configTest {
ctrl.Start()
notifyReady(l)

View File

@@ -13,7 +13,7 @@ import (
"github.com/slackhq/nebula/noiseutil"
)
const ReplayWindow = 4096
const ReplayWindow = 1024
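// ReplayWindow is the size handed to NewBits for each tunnel's receive window,
// so this constant change pairs with the bits.go rework above (and appears to
// revert the experimental 4096 value back to 1024).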
type ConnectionState struct {
eKey *NebulaCipherState

View File

@@ -425,9 +425,9 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) error {
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(fp, h, caPool, localCache, now) {
if f.inConns(fp, h, caPool, localCache) {
return nil
}
@@ -476,7 +476,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
}
// We always want to conntrack since it is a faster operation
f.addConn(fp, incoming, now)
f.addConn(fp, incoming)
return nil
}
@@ -505,7 +505,7 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
}
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) bool {
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return true
@@ -517,7 +517,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
// Purge every time we test
ep, has := conntrack.TimerWheel.Purge()
if has {
f.evict(ep, now)
f.evict(ep)
}
c, ok := conntrack.Conns[fp]
@@ -564,11 +564,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
switch fp.Protocol {
case firewall.ProtoTCP:
c.Expires = now.Add(f.TCPTimeout)
c.Expires = time.Now().Add(f.TCPTimeout)
case firewall.ProtoUDP:
c.Expires = now.Add(f.UDPTimeout)
c.Expires = time.Now().Add(f.UDPTimeout)
default:
c.Expires = now.Add(f.DefaultTimeout)
c.Expires = time.Now().Add(f.DefaultTimeout)
}
conntrack.Unlock()
@@ -580,7 +580,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
return true
}
func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) {
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
var timeout time.Duration
c := &conn{}
@@ -596,7 +596,7 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) {
conntrack := f.Conntrack
conntrack.Lock()
if _, ok := conntrack.Conns[fp]; !ok {
conntrack.TimerWheel.Advance(now)
conntrack.TimerWheel.Advance(time.Now())
conntrack.TimerWheel.Add(fp, timeout)
}
@@ -604,14 +604,14 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) {
// firewall reload
c.incoming = incoming
c.rulesVersion = f.rulesVersion
c.Expires = now.Add(timeout)
c.Expires = time.Now().Add(timeout)
conntrack.Conns[fp] = c
conntrack.Unlock()
}
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
// Caller must own the connMutex lock!
func (f *Firewall) evict(p firewall.Packet, now time.Time) {
func (f *Firewall) evict(p firewall.Packet) {
// Are we still tracking this conn?
conntrack := f.Conntrack
t, ok := conntrack.Conns[p]
@@ -619,11 +619,11 @@ func (f *Firewall) evict(p firewall.Packet, now time.Time) {
return
}
newT := t.Expires.Sub(now)
newT := t.Expires.Sub(time.Now())
// Timeout is in the future, re-add the timer
if newT > 0 {
conntrack.TimerWheel.Advance(now)
conntrack.TimerWheel.Advance(time.Now())
conntrack.TimerWheel.Add(p, newT)
return
}

View File

@@ -106,13 +106,13 @@ func TestFirewall_AddRule(t *testing.T) {
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)

go.mod (4 changes)
View File

@@ -8,7 +8,7 @@ require (
github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.25.0
github.com/gaissmai/bart v0.26.0
github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.4
@@ -50,6 +50,6 @@ require (
github.com/vishvananda/netns v0.0.5 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/mod v0.24.0 // indirect
golang.org/x/time v0.7.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.33.0 // indirect
)

go.sum (8 changes)
View File

@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.25.0 h1:eqiokVPqM3F94vJ0bTHXHtH91S8zkKL+bKh+BsGOsJM=
github.com/gaissmai/bart v0.25.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0=
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -217,8 +217,8 @@ golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=

View File

@@ -2,18 +2,16 @@ package nebula
import (
"net/netip"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/routing"
)
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache, now time.Time) {
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket)
if err != nil {
if f.l.Level >= logrus.DebugLevel {
@@ -55,7 +53,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
})
if hostinfo == nil {
f.rejectInside(packet, out.Payload, q) //todo vector?
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
WithField("fwPacket", fwPacket).
@@ -68,11 +66,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return
}
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache, now)
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil {
f.sendNoMetricsDelayed(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
} else {
f.rejectInside(packet, out.Payload, q) //todo vector?
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).
WithField("fwPacket", fwPacket).
@@ -219,7 +218,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
}
// check if packet is in outbound fw rules
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil, time.Now())
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
if dropReason != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp).
@@ -411,81 +410,3 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
}
}
}
func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.Packet, q int) {
if ci.eKey == nil {
return
}
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
fullOut := out.Payload
if useRelay {
if len(out.Payload) < header.Len {
// out always has a capacity of mtu, but not always a length greater than the header.Len.
// Grow it to make sure the next operation works.
out.Payload = out.Payload[:header.Len]
}
// Save a header's worth of data at the front of the 'out' buffer.
out.Payload = out.Payload[header.Len:]
}
if noiseutil.EncryptLockNeeded {
// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
ci.writeLock.Lock()
}
c := ci.messageCounter.Add(1)
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
out.Payload = header.Encode(out.Payload, header.Version, t, st, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo)
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
// all our addrs and enable a faster roaming.
if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
}
}
var err error
out.Payload, err = ci.eKey.EncryptDanger(out.Payload, out.Payload, p, c, nb)
if noiseutil.EncryptLockNeeded {
ci.writeLock.Unlock()
}
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", c).
Error("Failed to encrypt outgoing packet")
return
}
if remote.IsValid() {
err = f.writers[q].Prep(out, remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
} else if hostinfo.remote.IsValid() {
err = f.writers[q].Prep(out, hostinfo.remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
} else {
// Try to send via a relay
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil {
hostinfo.relayState.DeleteRelay(relayIP)
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
continue
}
//todo vector!!
f.SendVia(relayHostInfo, relay, out.Payload, nb, fullOut[:header.Len+len(out.Payload)], true)
break
}
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"net/netip"
"os"
"runtime"
@@ -17,12 +18,10 @@ import (
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/udp"
)
const mtu = 9001
const batch = 1024 //todo config!
type InterfaceConfig struct {
HostMap *HostMap
@@ -87,18 +86,12 @@ type Interface struct {
conntrackCacheTimeout time.Duration
writers []udp.Conn
readers []overlay.TunDev
readers []io.ReadWriteCloser
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics
listenInN int
listenOutN int
listenInMetric metrics.Histogram
listenOutMetric metrics.Histogram
l *logrus.Logger
}
@@ -184,7 +177,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
routines: c.routines,
version: c.version,
writers: make([]udp.Conn, c.routines),
readers: make([]overlay.TunDev, c.routines),
readers: make([]io.ReadWriteCloser, c.routines),
myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable,
myVpnAddrs: cs.myVpnAddrs,
@@ -203,8 +196,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
l: c.l,
}
ifce.listenInMetric = metrics.GetOrRegisterHistogram("vhost.listenIn.n", nil, metrics.NewExpDecaySample(1028, 0.015))
ifce.listenOutMetric = metrics.GetOrRegisterHistogram("vhost.listenOut.n", nil, metrics.NewExpDecaySample(1028, 0.015))
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery)
@@ -234,7 +225,7 @@ func (f *Interface) activate() {
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues
var reader overlay.TunDev = f.inside
var reader io.ReadWriteCloser = f.inside
for i := 0; i < f.routines; i++ {
if i > 0 {
reader, err = f.inside.NewMultiQueueReader()
@@ -263,71 +254,40 @@ func (f *Interface) run() {
}
}
func (f *Interface) listenOut(q int) {
func (f *Interface) listenOut(i int) {
runtime.LockOSThread()
var li udp.Conn
if q > 0 {
li = f.writers[q]
if i > 0 {
li = f.writers[i]
} else {
li = f.outside
}
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler()
outPackets := make([]*packet.OutPacket, batch)
for i := 0; i < batch; i++ {
outPackets[i] = packet.NewOut()
}
plaintext := make([]byte, udp.MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
toSend := make([][]byte, batch)
li.ListenOut(func(pkts []*packet.Packet) {
toSend = toSend[:0]
for i := range outPackets {
outPackets[i].Valid = false
outPackets[i].SegCounter = 0
}
f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l), time.Now())
//we opportunistically tx, but try to also send stragglers
if _, err := f.readers[q].WriteMany(outPackets, q); err != nil {
f.l.WithError(err).Error("Failed to send packets")
}
//todo I broke this
//n := len(toSend)
//if f.l.Level == logrus.DebugLevel {
// f.listenOutMetric.Update(int64(n))
//}
//f.listenOutN = n
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
})
}
func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread()
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
packets := make([]*packet.VirtIOPacket, batch)
outPackets := make([]*packet.Packet, batch)
for i := 0; i < batch; i++ {
packets[i] = packet.NewVIO()
outPackets[i] = packet.New(false) //todo?
}
for {
n, err := reader.ReadMany(packets, queueNum)
//todo!!
n, err := reader.Read(packet)
if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return
@@ -338,22 +298,7 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
os.Exit(2)
}
if f.l.Level == logrus.DebugLevel {
f.listenInMetric.Update(int64(n))
}
f.listenInN = n
now := time.Now()
for i, pkt := range packets[:n] {
outPackets[i].OutLen = -1
f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now)
reader.RecycleRxSeg(pkt, i == (n-1), queueNum) //todo handle err?
pkt.Reset()
}
_, err = f.writers[queueNum].WriteBatch(outPackets[:n])
if err != nil {
f.l.WithError(err).Error("Error while writing outbound packets")
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
}
}
@@ -491,11 +436,6 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
} else {
certMaxVersion.Update(int64(certState.v1Cert.Version()))
}
if f.l.Level != logrus.DebugLevel {
f.listenInMetric.Update(int64(f.listenInN))
f.listenOutMetric.Update(int64(f.listenOutN))
}
}
}
}

main.go (21 changes)
View File

@@ -5,6 +5,8 @@ import (
"fmt"
"net"
"net/netip"
"runtime/debug"
"strings"
"time"
"github.com/sirupsen/logrus"
@@ -27,6 +29,10 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
}()
if buildVersion == "" {
buildVersion = moduleVersion()
}
l := logger
l.Formatter = &logrus.TextFormatter{
FullTimestamp: true,
@@ -296,3 +302,18 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
connManager.Start,
}, nil
}
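// When nebula is embedded as a library, the running binary's main module is the
// caller's, so the nebula version appears in BuildInfo.Deps rather than
// BuildInfo.Main; hence the dependency scan below.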
func moduleVersion() string {
info, ok := debug.ReadBuildInfo()
if !ok {
return ""
}
for _, dep := range info.Deps {
if dep.Path == "github.com/slackhq/nebula" {
return strings.TrimPrefix(dep.Version, "v")
}
}
return ""
}

View File

@@ -7,7 +7,6 @@ import (
"time"
"github.com/google/gopacket/layers"
"github.com/slackhq/nebula/packet"
"golang.org/x/net/ipv6"
"github.com/sirupsen/logrus"
@@ -20,7 +19,7 @@ const (
minFwPacketLen = 4
)
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet)
if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
@@ -62,7 +61,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
switch h.Subtype {
case header.MessageNone:
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, now) {
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
return
}
case header.MessageRelay:
@@ -97,7 +96,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return
case ForwardingType:
// Find the target HostInfo relay object
@@ -217,217 +216,6 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
f.connectionManager.In(hostinfo)
}
func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
for i, pkt := range packets {
out[i].Scratch = out[i].Scratch[:0]
ip := pkt.AddrPort()
//l.Error("in packet ", header, packet[HeaderLen:])
if ip.IsValid() {
if f.myVpnNetworksTable.Contains(ip.Addr()) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
}
return
}
}
//todo per-segment!
for segment := range pkt.Segments() {
err := h.Parse(segment)
if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(segment) > 1 {
f.l.WithField("packet", pkt).Infof("Error while parsing inbound packet from %s: %s", ip, err)
}
return
}
var hostinfo *HostInfo
// verify if we've seen this index before, otherwise respond to the handshake initiation
if h.Type == header.Message && h.Subtype == header.MessageRelay {
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
} else {
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
}
var ci *ConnectionState
if hostinfo != nil {
ci = hostinfo.ConnectionState
}
switch h.Type {
case header.Message:
// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
if !f.handleEncrypted(ci, ip, h) {
return
}
switch h.Subtype {
case header.MessageNone:
if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache, now) {
return
}
case header.MessageRelay:
// The entire body is sent as AD, not encrypted.
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
// which will gracefully fail in the DecryptDanger call.
signedPayload := segment[:len(segment)-hostinfo.ConnectionState.dKey.Overhead()]
signatureValue := segment[len(segment)-hostinfo.ConnectionState.dKey.Overhead():]
out[i].Scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(out[i].Scratch, signedPayload, signatureValue, h.MessageCounter, nb)
if err != nil {
return
}
// Successfully validated the thing. Get rid of the Relay header.
signedPayload = signedPayload[header.Len:]
// Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, ip)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo)
f.connectionManager.RelayUsed(h.RemoteIndex)
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
if !ok {
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen.
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
return
}
switch relay.Type {
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
return
case ForwardingType:
// Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
if err != nil {
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
return
}
// If that relay is Established, forward the payload through it
if targetRelay.State == Established {
switch targetRelay.Type {
case ForwardingType:
// Forward this packet through the relay tunnel
// Find the target HostInfo
f.SendVia(targetHI, targetRelay, signedPayload, nb, out[i].Scratch, false)
return
case TerminalType:
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
}
} else {
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
return
}
}
}
case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", segment).
Error("Failed to decrypt lighthouse packet")
return
}
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic
case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", segment).
Error("Failed to decrypt test packet")
return
}
if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, ip)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out[i].Scratch)
}
// Fallthrough to the bottom to record incoming traffic
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
// are unauthenticated
case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handshakeManager.HandleIncoming(ip, nil, segment, h)
return
case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(ip, h)
return
case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
return
}
hostinfo.logger(f.l).WithField("udpAddr", ip).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo)
return
case header.Control:
if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", segment).
Error("Failed to decrypt Control packet")
return
}
f.relayManager.HandleControlMsg(hostinfo, d, f)
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
return
}
f.handleHostRoaming(hostinfo, ip)
f.connectionManager.In(hostinfo)
}
_, err := f.readers[q].WriteOne(out[i], false, q)
if err != nil {
f.l.WithError(err).Error("Failed to write packet")
}
}
}
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
func (f *Interface) closeTunnel(hostInfo *HostInfo) {
final := f.hostMap.DeleteHostInfo(hostInfo)
@@ -677,55 +465,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return out, nil
}
func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
var err error
seg, err := f.readers[q].AllocSeg(out, q)
if err != nil {
f.l.WithError(err).Errorln("decryptToTunDelayWrite: failed to allocate segment")
return false
}
out.SegmentPayloads[seg] = out.SegmentPayloads[seg][:0]
out.SegmentPayloads[seg], err = hostinfo.ConnectionState.dKey.DecryptDanger(out.SegmentPayloads[seg], inSegment[:header.Len], inSegment[header.Len:], messageCounter, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
return false
}
err = newPacket(out.SegmentPayloads[seg], true, fwPacket)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
Warnf("Error while validating inbound packet")
return false
}
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
Debugln("dropping out of window packet")
return false
}
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now)
if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in
f.rejectOutside(out.SegmentPayloads[seg], hostinfo.ConnectionState, hostinfo, nb, inSegment, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping inbound packet")
}
return false
}
f.connectionManager.In(hostinfo)
pkt.OutLen += len(inSegment)
out.Segments[seg] = out.Segments[seg][:len(out.SegmentHeaders[seg])+len(out.SegmentPayloads[seg])]
return true
}
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
@@ -747,7 +487,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false
}
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now)
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in

View File

@@ -1,16 +1,17 @@
package overlay
import (
"io"
"net/netip"
"github.com/slackhq/nebula/routing"
)
type Device interface {
TunDev
io.ReadWriteCloser
Activate() error
Networks() []netip.Prefix
Name() string
RoutesFor(netip.Addr) routing.Gateways
NewMultiQueueReader() (TunDev, error)
NewMultiQueueReader() (io.ReadWriteCloser, error)
}

View File

@@ -1,91 +0,0 @@
package eventfd
import (
"encoding/binary"
"syscall"
"golang.org/x/sys/unix"
)
type EventFD struct {
fd int
buf [8]byte
}
func New() (EventFD, error) {
fd, err := unix.Eventfd(0, unix.EFD_NONBLOCK)
if err != nil {
return EventFD{}, err
}
return EventFD{
fd: fd,
buf: [8]byte{},
}, nil
}
func (e *EventFD) Kick() error {
binary.LittleEndian.PutUint64(e.buf[:], 1) //is this right???
_, err := syscall.Write(int(e.fd), e.buf[:])
return err
}
func (e *EventFD) Close() error {
if e.fd != 0 {
return unix.Close(e.fd)
}
return nil
}
func (e *EventFD) FD() int {
return e.fd
}
type Epoll struct {
fd int
buf [8]byte
events []syscall.EpollEvent
}
func NewEpoll() (Epoll, error) {
fd, err := unix.EpollCreate1(0)
if err != nil {
return Epoll{}, err
}
return Epoll{
fd: fd,
buf: [8]byte{},
events: make([]syscall.EpollEvent, 1),
}, nil
}
func (ep *Epoll) AddEvent(fdToAdd int) error {
event := syscall.EpollEvent{
Events: syscall.EPOLLIN,
Fd: int32(fdToAdd),
}
return syscall.EpollCtl(ep.fd, syscall.EPOLL_CTL_ADD, fdToAdd, &event)
}
func (ep *Epoll) Block() (int, error) {
n, err := syscall.EpollWait(ep.fd, ep.events, -1)
if err != nil {
//goland:noinspection GoDirectComparisonOfErrors
if err == syscall.EINTR {
return 0, nil //??
}
return -1, err
}
return n, nil
}
func (ep *Epoll) Clear() error {
_, err := syscall.Read(int(ep.events[0].Fd), ep.buf[:])
return err
}
func (ep *Epoll) Close() error {
if ep.fd != 0 {
return unix.Close(ep.fd)
}
return nil
}

View File

@@ -2,29 +2,16 @@ package overlay
import (
"fmt"
"io"
"net"
"net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/util"
)
const DefaultMTU = 1300
type TunDev interface {
io.WriteCloser
ReadMany(x []*packet.VirtIOPacket, q int) (int, error)
//todo this interface sux
AllocSeg(pkt *packet.OutPacket, q int) (int, error)
WriteOne(x *packet.OutPacket, kick bool, q int) (int, error)
WriteMany(x []*packet.OutPacket, q int) (int, error)
RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error
}
// TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
@@ -39,11 +26,11 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref
}
}
//func NewFdDeviceFromConfig(fd *int) DeviceFactory {
// return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
// return newTunFromFd(c, l, *fd, vpnNetworks)
// }
//}
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, vpnNetworks)
}
}
func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {

View File

@@ -9,8 +9,6 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/routing"
)
@@ -24,10 +22,6 @@ type disabledTun struct {
l *logrus.Logger
}
func (*disabledTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
return nil
}
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
tun := &disabledTun{
vpnNetworks: vpnNetworks,
@@ -46,10 +40,6 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo
return tun
}
func (*disabledTun) GetQueues() []*virtqueue.SplitQueue {
return nil
}
func (*disabledTun) Activate() error {
return nil
}
@@ -115,23 +105,7 @@ func (t *disabledTun) Write(b []byte) (int, error) {
return len(b), nil
}
func (t *disabledTun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
return 0, fmt.Errorf("tun_disabled: AllocSeg not implemented")
}
func (t *disabledTun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
return 0, fmt.Errorf("tun_disabled: WriteOne not implemented")
}
func (t *disabledTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
return 0, fmt.Errorf("tun_disabled: WriteMany not implemented")
}
func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
return t.Read(b[0].Payload)
}
func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return t, nil
}

View File

@@ -5,6 +5,7 @@ package overlay
import (
"fmt"
"io"
"net"
"net/netip"
"os"
@@ -16,19 +17,15 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/vhostnet"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/util/virtio"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
type tun struct {
file *os.File
io.ReadWriteCloser
fd int
vdev []*vhostnet.Device
Device string
vpnNetworks []netip.Prefix
MaxMTU int
@@ -43,7 +40,6 @@ type tun struct {
useSystemRoutes bool
useSystemRoutesBufferSize int
isV6 bool
l *logrus.Logger
}
@@ -106,7 +102,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
}
var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_TUN_EXCL | unix.IFF_VNET_HDR | unix.IFF_NAPI)
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE
}
@@ -116,47 +112,20 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
}
name := strings.Trim(string(req.Name[:]), "\x00")
if err = unix.SetNonblock(fd, true); err != nil {
_ = unix.Close(fd)
return nil, fmt.Errorf("make file descriptor non-blocking: %w", err)
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
err = unix.IoctlSetPointerInt(fd, unix.TUNSETVNETHDRSZ, virtio.NetHdrSize)
if err != nil {
return nil, fmt.Errorf("set vnethdr size: %w", err)
}
flags := 0
// Offloads such as unix.TUN_F_CSUM or unix.TUN_F_TSO4 are intentionally left disabled for now.
err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, flags)
if err != nil {
return nil, fmt.Errorf("set offloads: %w", err)
}
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
return nil, err
}
t.fd = fd
t.Device = name
vdev, err := vhostnet.NewDevice(
vhostnet.WithBackendFD(fd),
vhostnet.WithQueueSize(8192), // TODO: make the queue size configurable
)
if err != nil {
return nil, err
}
t.vdev = []*vhostnet.Device{vdev}
t.Device = name
return t, nil
}
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
t := &tun{
file: file,
ReadWriteCloser: file,
fd: int(file.Fd()),
vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500),
@@ -164,9 +133,6 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
l: l,
}
if len(vpnNetworks) != 0 {
t.isV6 = vpnNetworks[0].Addr().Is6() // TODO: handle hosts with both v4 and v6 vpn networks
}
err := t.reload(c, true)
if err != nil {
@@ -250,7 +216,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) NewMultiQueueReader() (TunDev, error) {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
@@ -263,17 +229,9 @@ func (t *tun) NewMultiQueueReader() (TunDev, error) {
return nil, err
}
vdev, err := vhostnet.NewDevice(
vhostnet.WithBackendFD(fd),
vhostnet.WithQueueSize(8192), // TODO: make the queue size configurable
)
if err != nil {
return nil, err
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
t.vdev = append(t.vdev, vdev)
return t, nil
return file, nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -281,6 +239,29 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
return r
}
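// Write loops because a tun file descriptor may accept only part of the
// buffer per syscall; progress is accumulated in nn until the whole packet
// has been written.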
func (t *tun) Write(b []byte) (int, error) {
var nn int
maximum := len(b)
for {
n, err := unix.Write(t.fd, b[nn:maximum])
if n > 0 {
nn += n
}
if nn == len(b) {
return nn, err
}
if err != nil {
return nn, err
}
if n == 0 {
return nn, io.ErrUnexpectedEOF
}
}
}
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
@@ -693,14 +674,8 @@ func (t *tun) Close() error {
close(t.routeChan)
}
for _, v := range t.vdev {
if v != nil {
_ = v.Close()
}
}
if t.file != nil {
_ = t.file.Close()
if t.ReadWriteCloser != nil {
_ = t.ReadWriteCloser.Close()
}
if t.ioctlFd > 0 {
@@ -709,65 +684,3 @@ func (t *tun) Close() error {
return nil
}
func (t *tun) ReadMany(p []*packet.VirtIOPacket, q int) (int, error) {
n, err := t.vdev[q].ReceivePackets(p) // packets read from the tun are nebula's outbound (TX) traffic
if err != nil {
return 0, err
}
return n, nil
}
func (t *tun) Write(b []byte) (int, error) {
maximum := len(b) // packets written to the tun are nebula's inbound (RX) traffic
// TODO: this allocates a new OutPacket per write; reuse a buffer instead
out := packet.NewOut()
x, err := t.AllocSeg(out, 0)
if err != nil {
return 0, err
}
copy(out.SegmentPayloads[x], b)
err = t.vdev[0].TransmitPacket(out, true)
if err != nil {
t.l.WithError(err).Error("Transmitting packet")
return 0, err
}
return maximum, nil
}
func (t *tun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
idx, buf, err := t.vdev[q].GetPacketForTx()
if err != nil {
return 0, err
}
x := pkt.UseSegment(idx, buf, t.isV6)
return x, nil
}
func (t *tun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
if err := t.vdev[q].TransmitPacket(x, kick); err != nil {
t.l.WithError(err).Error("Transmitting packet")
return 0, err
}
return 1, nil
}
func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
maximum := len(x) //we are RXing
if maximum == 0 {
return 0, nil
}
err := t.vdev[q].TransmitPackets(x)
if err != nil {
t.l.WithError(err).Error("Transmitting packet")
return 0, err
}
return maximum, nil
}
func (t *tun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
return t.vdev[q].ReceiveQueue.OfferDescriptorChains(pkt.Chains, kick)
}

View File

@@ -1,13 +1,11 @@
package overlay
import (
"fmt"
"io"
"net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/routing"
)
@@ -38,10 +36,6 @@ type UserDevice struct {
inboundWriter *io.PipeWriter
}
func (d *UserDevice) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
return nil
}
func (d *UserDevice) Activate() error {
return nil
}
@@ -52,7 +46,7 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
return routing.Gateways{routing.NewGateway(ip, 1)}
}
func (d *UserDevice) NewMultiQueueReader() (TunDev, error) {
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return d, nil
}
@@ -71,19 +65,3 @@ func (d *UserDevice) Close() error {
d.outboundWriter.Close()
return nil
}
func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
return d.Read(b[0].Payload)
}
func (d *UserDevice) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
return 0, fmt.Errorf("user: AllocSeg not implemented")
}
func (d *UserDevice) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
return 0, fmt.Errorf("user: WriteOne not implemented")
}
func (d *UserDevice) WriteMany(x []*packet.OutPacket, q int) (int, error) {
return 0, fmt.Errorf("user: WriteMany not implemented")
}

View File

@@ -1,23 +0,0 @@
Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
MIT License
Copyright (c) 2025 Hetzner Cloud GmbH
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,4 +0,0 @@
// Package vhost implements the basic ioctl requests needed to interact with the
// kernel-level virtio server that provides accelerated virtio devices for
// networking and more.
package vhost

View File

@@ -1,218 +0,0 @@
package vhost
import (
"fmt"
"unsafe"
"github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/util/virtio"
"golang.org/x/sys/unix"
)
const (
// vhostIoctlGetFeatures can be used to retrieve the features supported by
// the vhost implementation in the kernel.
//
// Response payload: [virtio.Feature]
// Kernel name: VHOST_GET_FEATURES
vhostIoctlGetFeatures = 0x8008af00
// vhostIoctlSetFeatures can be used to communicate the features supported
// by this virtio implementation to the kernel.
//
// Request payload: [virtio.Feature]
// Kernel name: VHOST_SET_FEATURES
vhostIoctlSetFeatures = 0x4008af00
// vhostIoctlSetOwner can be used to set the current process as the
// exclusive owner of a control file descriptor.
//
// Request payload: none
// Kernel name: VHOST_SET_OWNER
vhostIoctlSetOwner = 0x0000af01
// vhostIoctlSetMemoryLayout can be used to set up or modify the memory
// layout which describes the IOTLB mappings in the kernel.
//
// Request payload: [MemoryLayout] with custom serialization
// Kernel name: VHOST_SET_MEM_TABLE
vhostIoctlSetMemoryLayout = 0x4008af03
// vhostIoctlSetQueueSize can be used to set the size of the virtqueue.
//
// Request payload: [QueueState]
// Kernel name: VHOST_SET_VRING_NUM
vhostIoctlSetQueueSize = 0x4008af10
// vhostIoctlSetQueueAddress can be used to set the addresses of the
// different parts of the virtqueue.
//
// Request payload: [QueueAddresses]
// Kernel name: VHOST_SET_VRING_ADDR
vhostIoctlSetQueueAddress = 0x4028af11
// vhostIoctlSetAvailableRingBase can be used to set the index of the next
// available ring entry the device will process.
//
// Request payload: [QueueState]
// Kernel name: VHOST_SET_VRING_BASE
vhostIoctlSetAvailableRingBase = 0x4008af12
// vhostIoctlSetQueueKickEventFD can be used to set the event file
// descriptor to signal the device when descriptor chains were added to the
// available ring.
//
// Request payload: [QueueFile]
// Kernel name: VHOST_SET_VRING_KICK
vhostIoctlSetQueueKickEventFD = 0x4008af20
// vhostIoctlSetQueueCallEventFD can be used to set the event file
// descriptor that gets signaled by the device when descriptor chains have
// been used by it.
//
// Request payload: [QueueFile]
// Kernel name: VHOST_SET_VRING_CALL
vhostIoctlSetQueueCallEventFD = 0x4008af21
)
// QueueState is an ioctl request payload that can hold a queue index and any
// 32-bit number.
//
// Kernel name: vhost_vring_state
type QueueState struct {
// QueueIndex is the index of the virtqueue.
QueueIndex uint32
// Num is any 32-bit number, depending on the request.
Num uint32
}
// QueueAddresses is an ioctl request payload that can hold the addresses of the
// different parts of a virtqueue.
//
// Kernel name: vhost_vring_addr
type QueueAddresses struct {
// QueueIndex is the index of the virtqueue.
QueueIndex uint32
// Flags that are not used in this implementation.
Flags uint32
// DescriptorTableAddress is the address of the descriptor table in user
// space memory. It must be 16-byte aligned.
DescriptorTableAddress uintptr
// UsedRingAddress is the address of the used ring in user space memory. It
// must be 4-byte aligned.
UsedRingAddress uintptr
// AvailableRingAddress is the address of the available ring in user space
// memory. It must be 2-byte aligned.
AvailableRingAddress uintptr
// LogAddress is used for an optional logging support, not supported by this
// implementation.
LogAddress uintptr
}
// QueueFile is an ioctl request payload that can hold a queue index and a file
// descriptor.
//
// Kernel name: vhost_vring_file
type QueueFile struct {
// QueueIndex is the index of the virtqueue.
QueueIndex uint32
// FD is the file descriptor of the file. Pass -1 to unbind from a file.
FD int32
}
// IoctlPtr is a copy of the similarly named unexported function from the Go
// unix package. This is needed to do custom ioctl requests not supported by the
// standard library.
func IoctlPtr(fd int, req uint, arg unsafe.Pointer) error {
_, _, err := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(arg))
if err != 0 {
return fmt.Errorf("ioctl request %d: %w", req, err)
}
return nil
}
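// The request numbers above follow the Linux _IOC bit layout:
// dir<<30 | size<<16 | type<<8 | nr, where the vhost ioctl type is 0xAF.
// A sketch with a hypothetical helper (not part of this package) that
// reconstructs some of the constants; the payload sizes match the struct
// sizes asserted in the tests:
//
//	func ioc(dir, size, nr uint32) uint32 {
//		return dir<<30 | size<<16 | 0xAF<<8 | nr
//	}
//	ioc(1, 8, 0x00)  // 0x4008af00 == vhostIoctlSetFeatures (_IOW, 8-byte payload)
//	ioc(2, 8, 0x00)  // 0x8008af00 == vhostIoctlGetFeatures (_IOR)
//	ioc(1, 40, 0x11) // 0x4028af11 == vhostIoctlSetQueueAddress (40 == sizeof(QueueAddresses))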
// GetFeatures requests the supported feature bits from the virtio device
// associated with the given control file descriptor.
func GetFeatures(controlFD int) (virtio.Feature, error) {
var features virtio.Feature
if err := IoctlPtr(controlFD, vhostIoctlGetFeatures, unsafe.Pointer(&features)); err != nil {
return 0, fmt.Errorf("get features: %w", err)
}
return features, nil
}
// SetFeatures communicates the feature bits supported by this implementation
// to the virtio device associated with the given control file descriptor.
func SetFeatures(controlFD int, features virtio.Feature) error {
if err := IoctlPtr(controlFD, vhostIoctlSetFeatures, unsafe.Pointer(&features)); err != nil {
return fmt.Errorf("set features: %w", err)
}
return nil
}
// OwnControlFD sets the current process as the exclusive owner for the
// given control file descriptor. This must be called before interacting with
// the control file descriptor in any other way.
func OwnControlFD(controlFD int) error {
if err := IoctlPtr(controlFD, vhostIoctlSetOwner, unsafe.Pointer(nil)); err != nil {
return fmt.Errorf("set control file descriptor owner: %w", err)
}
return nil
}
// SetMemoryLayout sets up or modifies the memory layout for the kernel-level
// virtio device associated with the given control file descriptor.
func SetMemoryLayout(controlFD int, layout MemoryLayout) error {
payload := layout.serializePayload()
if err := IoctlPtr(controlFD, vhostIoctlSetMemoryLayout, unsafe.Pointer(&payload[0])); err != nil {
return fmt.Errorf("set memory layout: %w", err)
}
return nil
}
// RegisterQueue registers a virtio queue with the kernel-level virtio server.
// The virtqueue will be linked to the given control file descriptor and will
// have the given index. The kernel will use this queue until the control file
// descriptor is closed.
func RegisterQueue(controlFD int, queueIndex uint32, queue *virtqueue.SplitQueue) error {
if err := IoctlPtr(controlFD, vhostIoctlSetQueueSize, unsafe.Pointer(&QueueState{
QueueIndex: queueIndex,
Num: uint32(queue.Size()),
})); err != nil {
return fmt.Errorf("set queue size: %w", err)
}
if err := IoctlPtr(controlFD, vhostIoctlSetQueueAddress, unsafe.Pointer(&QueueAddresses{
QueueIndex: queueIndex,
Flags: 0,
DescriptorTableAddress: queue.DescriptorTable().Address(),
UsedRingAddress: queue.UsedRing().Address(),
AvailableRingAddress: queue.AvailableRing().Address(),
LogAddress: 0,
})); err != nil {
return fmt.Errorf("set queue addresses: %w", err)
}
if err := IoctlPtr(controlFD, vhostIoctlSetAvailableRingBase, unsafe.Pointer(&QueueState{
QueueIndex: queueIndex,
Num: 0,
})); err != nil {
return fmt.Errorf("set available ring base: %w", err)
}
if err := IoctlPtr(controlFD, vhostIoctlSetQueueKickEventFD, unsafe.Pointer(&QueueFile{
QueueIndex: queueIndex,
FD: int32(queue.KickEventFD()),
})); err != nil {
return fmt.Errorf("set kick event file descriptor: %w", err)
}
if err := IoctlPtr(controlFD, vhostIoctlSetQueueCallEventFD, unsafe.Pointer(&QueueFile{
QueueIndex: queueIndex,
FD: int32(queue.CallEventFD()),
})); err != nil {
return fmt.Errorf("set call event file descriptor: %w", err)
}
return nil
}

View File

@@ -1,21 +0,0 @@
package vhost_test
import (
"testing"
"unsafe"
"github.com/slackhq/nebula/overlay/vhost"
"github.com/stretchr/testify/assert"
)
func TestQueueState_Size(t *testing.T) {
assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueState{}))
}
func TestQueueAddresses_Size(t *testing.T) {
assert.EqualValues(t, 40, unsafe.Sizeof(vhost.QueueAddresses{}))
}
func TestQueueFile_Size(t *testing.T) {
assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueFile{}))
}

View File

@@ -1,73 +0,0 @@
package vhost
import (
"encoding/binary"
"fmt"
"unsafe"
"github.com/slackhq/nebula/overlay/virtqueue"
)
// MemoryRegion describes a region of userspace memory which is being made
// accessible to a vhost device.
//
// Kernel name: vhost_memory_region
type MemoryRegion struct {
// GuestPhysicalAddress is the physical address of the memory region within
// the guest, when virtualization is used. When no virtualization is used,
// this should be the same as UserspaceAddress.
GuestPhysicalAddress uintptr
// Size is the size of the memory region.
Size uint64
// UserspaceAddress is the virtual address in the userspace of the host
// where the memory region can be found.
UserspaceAddress uintptr
// Padding and room for flags. Currently unused.
_ uint64
}
// MemoryLayout is a list of [MemoryRegion]s.
type MemoryLayout []MemoryRegion
// NewMemoryLayoutForQueues returns a new [MemoryLayout] that describes the
// memory pages used by the descriptor tables of the given queues.
func NewMemoryLayoutForQueues(queues []*virtqueue.SplitQueue) MemoryLayout {
regions := make([]MemoryRegion, 0)
for _, queue := range queues {
for address, size := range queue.DescriptorTable().BufferAddresses() {
regions = append(regions, MemoryRegion{
// There is no virtualization in play here, so the guest address
// is the same as in the host's userspace.
GuestPhysicalAddress: address,
Size: uint64(size),
UserspaceAddress: address,
})
}
}
return regions
}
// serializePayload serializes the list of memory regions into a format that is
// compatible to the vhost_memory kernel struct. The returned byte slice can be
// used as a payload for the vhostIoctlSetMemoryLayout ioctl.
func (regions MemoryLayout) serializePayload() []byte {
regionCount := len(regions)
regionSize := int(unsafe.Sizeof(MemoryRegion{}))
payload := make([]byte, 8+regionCount*regionSize)
// The first 32 bits contain the number of memory regions. The following 32
// bits are padding.
binary.LittleEndian.PutUint32(payload[0:4], uint32(regionCount))
if regionCount > 0 {
// The underlying byte array of the slice should already have the correct
// format, so just copy that.
copied := copy(payload[8:], unsafe.Slice((*byte)(unsafe.Pointer(&regions[0])), regionCount*regionSize))
if copied != regionCount*regionSize {
panic(fmt.Sprintf("copied only %d bytes of the memory regions, but expected %d",
copied, regionCount*regionSize))
}
}
return payload
}

View File

@@ -1,42 +0,0 @@
package vhost
import (
"testing"
"unsafe"
"github.com/stretchr/testify/assert"
)
func TestMemoryRegion_Size(t *testing.T) {
assert.EqualValues(t, 32, unsafe.Sizeof(MemoryRegion{}))
}
func TestMemoryLayout_SerializePayload(t *testing.T) {
layout := MemoryLayout([]MemoryRegion{
{
GuestPhysicalAddress: 42,
Size: 100,
UserspaceAddress: 142,
}, {
GuestPhysicalAddress: 99,
Size: 100,
UserspaceAddress: 99,
},
})
payload := layout.serializePayload()
assert.Equal(t, []byte{
0x02, 0x00, 0x00, 0x00, // nregions
0x00, 0x00, 0x00, 0x00, // padding
// region 0
0x2a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr
0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size
0x8e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding
// region 1
0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr
0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size
0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding
}, payload)
}

View File

@@ -1,23 +0,0 @@
Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
MIT License
Copyright (c) 2025 Hetzner Cloud GmbH
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,427 +0,0 @@
package vhostnet
import (
"context"
"errors"
"fmt"
"os"
"runtime"
"github.com/slackhq/nebula/overlay/vhost"
"github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/util/virtio"
"golang.org/x/sys/unix"
)
// ErrDeviceClosed is returned when the [Device] is closed while operations are
// still running.
var ErrDeviceClosed = errors.New("device was closed")
// The indexes for the receive and transmit queues.
const (
receiveQueueIndex = 0
transmitQueueIndex = 1
)
// Device represents a vhost networking device within the kernel-level virtio
// implementation and provides methods to interact with it.
type Device struct {
initialized bool
controlFD int
fullTable bool
ReceiveQueue *virtqueue.SplitQueue
TransmitQueue *virtqueue.SplitQueue
}
// NewDevice initializes a new vhost networking device within the
// kernel-level virtio implementation, sets up the virtqueues and returns a
// [Device] instance that can be used to communicate with that vhost device.
//
// There are multiple options that can be passed to this constructor to
// influence device creation:
// - [WithQueueSize]
// - [WithBackendFD]
// - [WithBackendDevice]
//
// Remember to call [Device.Close] after use to free up resources.
func NewDevice(options ...Option) (*Device, error) {
var err error
opts := optionDefaults
opts.apply(options)
if err = opts.validate(); err != nil {
return nil, fmt.Errorf("invalid options: %w", err)
}
dev := Device{
controlFD: -1,
}
// Clean up a partially initialized device when something fails.
defer func() {
if err != nil {
_ = dev.Close()
}
}()
// Retrieve a new control file descriptor. This will be used to configure
// the vhost networking device in the kernel.
dev.controlFD, err = unix.Open("/dev/vhost-net", os.O_RDWR, 0666)
if err != nil {
return nil, fmt.Errorf("get control file descriptor: %w", err)
}
if err = vhost.OwnControlFD(dev.controlFD); err != nil {
return nil, fmt.Errorf("own control file descriptor: %w", err)
}
// Advertise the supported features. This isn't much for now.
// TODO: Add feature options and implement proper feature negotiation.
kernelFeatures, err := vhost.GetFeatures(dev.controlFD)
if err != nil {
return nil, fmt.Errorf("get features: %w", err)
}
_ = kernelFeatures // TODO: intersect the kernel's feature set with the features requested below
features := virtio.FeatureVersion1 | virtio.FeatureNetMergeRXBuffers
if err = vhost.SetFeatures(dev.controlFD, features); err != nil {
return nil, fmt.Errorf("set features: %w", err)
}
itemSize := os.Getpagesize() * 4 // TODO: make the per-descriptor buffer size configurable
// Initialize and register the queues needed for the networking device.
if dev.ReceiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize, itemSize); err != nil {
return nil, fmt.Errorf("create receive queue: %w", err)
}
if dev.TransmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize, itemSize); err != nil {
return nil, fmt.Errorf("create transmit queue: %w", err)
}
// Set up memory mappings for all buffers used by the queues. This has to
// happen before a backend for the queues can be registered.
memoryLayout := vhost.NewMemoryLayoutForQueues(
[]*virtqueue.SplitQueue{dev.ReceiveQueue, dev.TransmitQueue},
)
if err = vhost.SetMemoryLayout(dev.controlFD, memoryLayout); err != nil {
return nil, fmt.Errorf("setup memory layout: %w", err)
}
// Set the queue backends. This activates the queues within the kernel.
if err = SetQueueBackend(dev.controlFD, receiveQueueIndex, opts.backendFD); err != nil {
return nil, fmt.Errorf("set receive queue backend: %w", err)
}
if err = SetQueueBackend(dev.controlFD, transmitQueueIndex, opts.backendFD); err != nil {
return nil, fmt.Errorf("set transmit queue backend: %w", err)
}
// Fully populate the receive queue with available buffers which the device
// can write new packets into.
if err = dev.refillReceiveQueue(); err != nil {
return nil, fmt.Errorf("refill receive queue: %w", err)
}
if err = dev.refillTransmitQueue(); err != nil {
return nil, fmt.Errorf("refill receive queue: %w", err)
}
dev.initialized = true
// Make sure to clean up even when the device gets garbage collected without
// Close being called first.
devPtr := &dev
runtime.SetFinalizer(devPtr, (*Device).Close)
return devPtr, nil
}
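// A minimal usage sketch (error handling elided; tapFD is assumed to be a
// TAP or RAW socket file descriptor owned by the caller):
//
//	dev, err := NewDevice(
//		WithBackendFD(tapFD),
//		WithQueueSize(256), // must be a power of two between 1 and 32768
//	)
//	if err != nil {
//		return err
//	}
//	defer dev.Close()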
// refillReceiveQueue offers as many new device-writable buffers to the device
// as the queue can fit. The device will then use these to write received
// packets.
func (dev *Device) refillReceiveQueue() error {
for {
_, err := dev.ReceiveQueue.OfferInDescriptorChains()
if err != nil {
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
// Queue is full, job is done.
return nil
}
return fmt.Errorf("offer descriptor chain: %w", err)
}
}
}
func (dev *Device) refillTransmitQueue() error {
//for {
// desc, err := dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
// if err != nil {
// if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
// // Queue is full, job is done.
// return nil
// }
// return fmt.Errorf("offer descriptor chain: %w", err)
// } else {
// dev.TransmitQueue.UsedRing().InitOfferSingle(desc, 0)
// }
//}
return nil
}
// Close cleans up the vhost networking device within the kernel and releases
// all resources used for it.
// The implementation will try to release as many resources as possible and
// collect potential errors before returning them.
func (dev *Device) Close() error {
dev.initialized = false
// Closing the control file descriptor will unregister all queues from the
// kernel.
if dev.controlFD >= 0 {
if err := unix.Close(dev.controlFD); err != nil {
// Return an error and do not continue, because the memory used for
// the queues should not be released before they were unregistered
// from the kernel.
return fmt.Errorf("close control file descriptor: %w", err)
}
dev.controlFD = -1
}
var errs []error
if dev.ReceiveQueue != nil {
if err := dev.ReceiveQueue.Close(); err == nil {
dev.ReceiveQueue = nil
} else {
errs = append(errs, fmt.Errorf("close receive queue: %w", err))
}
}
if dev.TransmitQueue != nil {
if err := dev.TransmitQueue.Close(); err == nil {
dev.TransmitQueue = nil
} else {
errs = append(errs, fmt.Errorf("close transmit queue: %w", err))
}
}
if len(errs) == 0 {
// Everything was cleaned up. No need to run the finalizer anymore.
runtime.SetFinalizer(dev, nil)
}
return errors.Join(errs...)
}
// ensureInitialized is used as a guard to prevent methods to be called on an
// uninitialized instance.
func (dev *Device) ensureInitialized() {
if !dev.initialized {
panic("device is not initialized")
}
}
// createQueue creates a new virtqueue and registers it with the vhost device
// using the given index.
func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {
var (
queue *virtqueue.SplitQueue
err error
)
if queue, err = virtqueue.NewSplitQueue(queueSize, itemSize); err != nil {
return nil, fmt.Errorf("create virtqueue: %w", err)
}
if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {
return nil, fmt.Errorf("register virtqueue with index %d: %w", queueIndex, err)
}
return queue, nil
}
// truncateBuffers returns a new list of buffers whose combined length matches
// exactly the specified length. When the specified length exceeds the combined
// length of all buffers, this panics. When it is smaller, the buffer list will
// be truncated accordingly.
func truncateBuffers(buffers [][]byte, length int) (out [][]byte) {
for _, buffer := range buffers {
if length < len(buffer) {
out = append(out, buffer[:length])
return
}
out = append(out, buffer)
length -= len(buffer)
}
if length > 0 {
panic("length exceeds the combined length of all buffers")
}
return
}
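// For example, truncateBuffers([][]byte{make([]byte, 3), make([]byte, 4)}, 5)
// returns the first buffer whole and the second truncated to two bytes.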
func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
var err error
var idx uint16
if !dev.fullTable {
idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
dev.fullTable = true
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
}
} else {
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
}
if err != nil {
return 0, nil, fmt.Errorf("transmit queue: %w", err)
}
buf, err := dev.TransmitQueue.GetDescriptorItem(idx)
if err != nil {
return 0, nil, fmt.Errorf("get descriptor chain: %w", err)
}
return idx, buf, nil
}
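// Typical TX flow, mirroring how the overlay tun code drives this device (a
// sketch; the OutPacket bookkeeping is done via packet.OutPacket.UseSegment):
//
//	idx, buf, err := dev.GetPacketForTx() // reserve a descriptor and its buffer
//	// ... write the virtio-net header and payload into buf, record idx in an
//	// outgoing packet.OutPacket (pkt) via pkt.UseSegment ...
//	err = dev.TransmitPacket(pkt, true) // offer the chain and kick the device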
func (dev *Device) TransmitPacket(pkt *packet.OutPacket, kick bool) error {
if len(pkt.SegmentIDs) == 0 {
return nil
}
for idx := range pkt.SegmentIDs {
segmentID := pkt.SegmentIDs[idx]
dev.TransmitQueue.SetDescSize(segmentID, len(pkt.Segments[idx]))
}
err := dev.TransmitQueue.OfferDescriptorChains(pkt.SegmentIDs, false)
if err != nil {
return fmt.Errorf("offer descriptor chains: %w", err)
}
pkt.Reset()
if kick {
if err := dev.TransmitQueue.Kick(); err != nil {
return err
}
}
return nil
}
func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
if len(pkts) == 0 {
return nil
}
for i := range pkts {
if err := dev.TransmitPacket(pkts[i], false); err != nil {
return err
}
}
if err := dev.TransmitQueue.Kick(); err != nil {
return err
}
return nil
}
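// Kicking once per batch instead of once per packet keeps the wakeup cost at
// a single eventfd write per burst.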
// TODO: Make above methods cancelable by taking a context.Context argument?
// TODO: Implement zero-copy variants to transmit and receive packets?
// processChains processes as many chains as needed to create one packet. The number of processed chains is returned.
func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
//read first element to see how many descriptors we need:
pkt.Reset()
err := dev.ReceiveQueue.GetDescriptorInbuffers(uint16(chains[0].DescriptorIndex), &pkt.ChainRefs)
if err != nil {
return 0, fmt.Errorf("get descriptor chain: %w", err)
}
if len(pkt.ChainRefs) == 0 {
return 1, nil
}
// The specification requires that the first descriptor chain starts
// with a virtio-net header. It is not clear, whether it is also
// required to be fully contained in the first buffer of that
// descriptor chain, but it is reasonable to assume that this is
// always the case.
// The decode method already does the buffer length check.
if err = pkt.Header.Decode(pkt.ChainRefs[0][0:]); err != nil {
// The device misbehaved. There is no way we can gracefully
// recover from this, because we don't know how many of the
// following descriptor chains belong to this packet.
return 0, fmt.Errorf("decode vnethdr: %w", err)
}
// The header is decoded; now assemble the payload.
if int(pkt.Header.NumBuffers) > len(chains) {
return 0, fmt.Errorf("header wants %d buffers but only %d chains are available", pkt.Header.NumBuffers, len(chains))
}
if int(pkt.Header.NumBuffers) != 1 {
return 0, fmt.Errorf("packets spanning multiple descriptor chains are not yet supported: %d chains", len(chains))
}
if chains[0].Length > 16000 {
// TODO: replace this hard cap with a configurable maximum
return 1, fmt.Errorf("packet length %d exceeds the supported maximum", chains[0].Length)
}
// Slice past the virtio-net header to expose just the payload.
pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
return 1, nil
//cursor := n - virtio.NetHdrSize
//
//if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
// pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
// return 1, nil
//}
//
//i := 1
//// we used chain 0 already
//for i = 1; i < len(chains); i++ {
// n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length))
// if err != nil {
// // When this fails we may miss to free some descriptor chains. We
// // could try to mitigate this by deferring the freeing somehow, but
// // it's not worth the hassle. When this method fails, the queue will
// // be in a broken state anyway.
// return i, fmt.Errorf("get descriptor chain: %w", err)
// }
// cursor += n
//}
////todo this has to be wrong
//pkt.Payload = pkt.Payload[:cursor]
//return i, nil
}
func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
// TODO: optimize this receive path
var chains []virtqueue.UsedElement
var err error
//if len(dev.extraRx) == 0 {
chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), len(out))
if err != nil {
return 0, err
}
if len(chains) == 0 {
return 0, nil
}
//} else {
// chains = dev.extraRx
//}
numPackets := 0
chainsIdx := 0
for numPackets = 0; chainsIdx < len(chains); numPackets++ {
if numPackets >= len(out) {
return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets)
}
numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:])
if err != nil {
return 0, err
}
chainsIdx += numChains
}
// Now that we have copied all buffers, we can recycle the used descriptor chains
//if err = dev.ReceiveQueue.OfferDescriptorChains(chains); err != nil {
// return 0, err
//}
return numPackets, nil
}

View File

@@ -1,3 +0,0 @@
// Package vhostnet implements methods to initialize vhost networking devices
// within the kernel-level virtio implementation and communicate with them.
package vhostnet

View File

@@ -1,31 +0,0 @@
package vhostnet
import (
"fmt"
"unsafe"
"github.com/slackhq/nebula/overlay/vhost"
)
const (
// vhostNetIoctlSetBackend can be used to attach a virtqueue to a RAW socket
// or TAP device.
//
// Request payload: [vhost.QueueFile]
// Kernel name: VHOST_NET_SET_BACKEND
vhostNetIoctlSetBackend = 0x4008af30
)
// SetQueueBackend attaches a virtqueue of the vhost networking device
// described by controlFD to the given backend file descriptor.
// The backend file descriptor can either be a RAW socket or a TAP device. When
// it is -1, the queue will be detached.
func SetQueueBackend(controlFD int, queueIndex uint32, backendFD int) error {
if err := vhost.IoctlPtr(controlFD, vhostNetIoctlSetBackend, unsafe.Pointer(&vhost.QueueFile{
QueueIndex: queueIndex,
FD: int32(backendFD),
})); err != nil {
return fmt.Errorf("set queue backend file descriptor: %w", err)
}
return nil
}
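// For example, SetQueueBackend(controlFD, queueIndex, -1) detaches a queue
// from its backend without closing the control file descriptor.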

View File

@@ -1,69 +0,0 @@
package vhostnet
import (
"errors"
"github.com/slackhq/nebula/overlay/virtqueue"
)
type optionValues struct {
queueSize int
backendFD int
}
func (o *optionValues) apply(options []Option) {
for _, option := range options {
option(o)
}
}
func (o *optionValues) validate() error {
if o.queueSize == -1 {
return errors.New("queue size is required")
}
if err := virtqueue.CheckQueueSize(o.queueSize); err != nil {
return err
}
if o.backendFD == -1 {
return errors.New("backend file descriptor is required")
}
return nil
}
var optionDefaults = optionValues{
// Required.
queueSize: -1,
// Required.
backendFD: -1,
}
// Option can be passed to [NewDevice] to influence device creation.
type Option func(*optionValues)
// WithQueueSize returns an [Option] that sets the size of the TX and RX queues
// that are to be created for the device. It specifies the number of
// entries/buffers each queue can hold. This also affects the memory
// consumption.
// This is required and must be an integer from 1 to 32768 that is also a power
// of 2.
func WithQueueSize(queueSize int) Option {
return func(o *optionValues) { o.queueSize = queueSize }
}
// WithBackendFD returns an [Option] that sets the file descriptor of the
// backend that will be used for the queues of the device. The device will write
// and read packets to/from that backend. The file descriptor can either be of a
// RAW socket or TUN/TAP device.
// Either this or [WithBackendDevice] is required.
func WithBackendFD(backendFD int) Option {
return func(o *optionValues) { o.backendFD = backendFD }
}
//// WithBackendDevice returns an [Option] that sets the given TAP device as the
//// backend that will be used for the queues of the device. The device will
//// write and read packets to/from that backend. The TAP device should have been
//// created with the [tuntap.WithVirtioNetHdr] option enabled.
//// Either this or [WithBackendFD] is required.
//func WithBackendDevice(dev *tuntap.Device) Option {
// return func(o *optionValues) { o.backendFD = int(dev.File().Fd()) }
//}

View File

@@ -1,23 +0,0 @@
Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
MIT License
Copyright (c) 2025 Hetzner Cloud GmbH
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,140 +0,0 @@
package virtqueue
import (
"fmt"
"unsafe"
)
// availableRingFlag is a flag that describes an [AvailableRing].
type availableRingFlag uint16
const (
// availableRingFlagNoInterrupt is used by the guest to advise the host to
// not interrupt it when consuming a buffer. It's unreliable, so it's simply
// an optimization.
availableRingFlagNoInterrupt availableRingFlag = 1 << iota
)
// availableRingSize is the number of bytes needed to store an [AvailableRing]
// with the given queue size in memory.
func availableRingSize(queueSize int) int {
return 6 + 2*queueSize
}
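// The six fixed bytes are the flags (2), the ring index (2), and the trailing
// used_event field (2); the ring itself needs one uint16 entry per queue slot.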
// availableRingAlignment is the minimum alignment of an [AvailableRing]
// in memory, as required by the virtio spec.
const availableRingAlignment = 2
// AvailableRing is used by the driver to offer descriptor chains to the device.
// Each ring entry refers to the head of a descriptor chain. It is only written
// to by the driver and read by the device.
//
// Because the size of the ring depends on the queue size, we cannot define a
// Go struct with a static size that maps to the memory of the ring. Instead,
// this struct only contains pointers to the corresponding memory areas.
type AvailableRing struct {
initialized bool
// flags that describe this ring.
flags *availableRingFlag
// ringIndex indicates where the driver would put the next entry into the
// ring (modulo the queue size).
ringIndex *uint16
// ring references buffers using the index of the head of the descriptor
// chain in the [DescriptorTable]. It wraps around at queue size.
ring []uint16
// usedEvent is not used by this implementation, but we reserve it anyway to
// avoid issues in case a device may try to access it, contrary to the
// virtio specification.
usedEvent *uint16
}
// newAvailableRing creates an available ring that uses the given underlying
// memory. The length of the memory slice must match the size needed for the
// ring (see [availableRingSize]) for the given queue size.
func newAvailableRing(queueSize int, mem []byte) *AvailableRing {
ringSize := availableRingSize(queueSize)
if len(mem) != ringSize {
panic(fmt.Sprintf("memory size (%v) does not match required size "+
"for available ring: %v", len(mem), ringSize))
}
return &AvailableRing{
initialized: true,
flags: (*availableRingFlag)(unsafe.Pointer(&mem[0])),
ringIndex: (*uint16)(unsafe.Pointer(&mem[2])),
ring: unsafe.Slice((*uint16)(unsafe.Pointer(&mem[4])), queueSize),
usedEvent: (*uint16)(unsafe.Pointer(&mem[ringSize-2])),
}
}
// Address returns the pointer to the beginning of the ring in memory.
// Do not modify the memory directly to not interfere with this implementation.
func (r *AvailableRing) Address() uintptr {
if !r.initialized {
panic("available ring is not initialized")
}
return uintptr(unsafe.Pointer(r.flags))
}
// offer adds the given descriptor chain heads to the available ring and
// advances the ring index accordingly to make the device process the new
// descriptor chains.
func (r *AvailableRing) offerElements(chains []UsedElement) {
// Callers must hold the queue lock; this method does no locking of its own.
// Add descriptor chain heads to the ring.
for offset, x := range chains {
// The 16-bit ring index may overflow. This is expected and is not an
// issue because the size of the ring array (which equals the queue
// size) is always a power of 2 and smaller than the highest possible
// 16-bit value.
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
r.ring[insertIndex] = x.GetHead()
}
// Increase the ring index by the number of descriptor chains added to the
// ring.
*r.ringIndex += uint16(len(chains))
}
func (r *AvailableRing) offer(chains []uint16) {
// Callers must hold the queue lock; this method does no locking of its own.
// Add descriptor chain heads to the ring.
for offset, x := range chains {
// The 16-bit ring index may overflow. This is expected and is not an
// issue because the size of the ring array (which equals the queue
// size) is always a power of 2 and smaller than the highest possible
// 16-bit value.
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
r.ring[insertIndex] = x
}
// Increase the ring index by the number of descriptor chains added to the
// ring.
*r.ringIndex += uint16(len(chains))
}
func (r *AvailableRing) offerSingle(x uint16) {
// Callers must hold the queue lock; this method does no locking of its own.
offset := 0
// Add descriptor chain heads to the ring.
// The 16-bit ring index may overflow. This is expected and is not an
// issue because the size of the ring array (which equals the queue
// size) is always a power of 2 and smaller than the highest possible
// 16-bit value.
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
r.ring[insertIndex] = x
// Increase the ring index by the number of descriptor chains added to the ring.
*r.ringIndex += 1
}

View File

@@ -1,71 +0,0 @@
package virtqueue
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAvailableRing_MemoryLayout(t *testing.T) {
const queueSize = 2
memory := make([]byte, availableRingSize(queueSize))
r := newAvailableRing(queueSize, memory)
*r.flags = 0x01ff
*r.ringIndex = 1
r.ring[0] = 0x1234
r.ring[1] = 0x5678
assert.Equal(t, []byte{
0xff, 0x01,
0x01, 0x00,
0x34, 0x12,
0x78, 0x56,
0x00, 0x00,
}, memory)
}
func TestAvailableRing_Offer(t *testing.T) {
const queueSize = 8
chainHeads := []uint16{42, 33, 69}
tests := []struct {
name string
startRingIndex uint16
expectedRingIndex uint16
expectedRing []uint16
}{
{
name: "no overflow",
startRingIndex: 0,
expectedRingIndex: 3,
expectedRing: []uint16{42, 33, 69, 0, 0, 0, 0, 0},
},
{
name: "ring overflow",
startRingIndex: 6,
expectedRingIndex: 9,
expectedRing: []uint16{69, 0, 0, 0, 0, 0, 42, 33},
},
{
name: "index overflow",
startRingIndex: 65535,
expectedRingIndex: 2,
expectedRing: []uint16{33, 69, 0, 0, 0, 0, 0, 42},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
memory := make([]byte, availableRingSize(queueSize))
r := newAvailableRing(queueSize, memory)
*r.ringIndex = tt.startRingIndex
r.offer(chainHeads)
assert.Equal(t, tt.expectedRingIndex, *r.ringIndex)
assert.Equal(t, tt.expectedRing, r.ring)
})
}
}

View File

@@ -1,43 +0,0 @@
package virtqueue
// descriptorFlag is a flag that describes a [Descriptor].
type descriptorFlag uint16
const (
// descriptorFlagHasNext marks a descriptor chain as continuing via the next
// field.
descriptorFlagHasNext descriptorFlag = 1 << iota
// descriptorFlagWritable marks a buffer as device write-only (otherwise
// device read-only).
descriptorFlagWritable
// descriptorFlagIndirect means the buffer contains a list of buffer
// descriptors to provide an additional layer of indirection.
// Only allowed when the [virtio.FeatureIndirectDescriptors] feature was
// negotiated.
descriptorFlagIndirect
)
// descriptorSize is the number of bytes needed to store a [Descriptor] in
// memory.
const descriptorSize = 16
// Descriptor describes (a part of) a buffer which is either read-only for the
// device or write-only for the device (depending on [descriptorFlagWritable]).
// Multiple descriptors can be chained to produce a "descriptor chain" that can
// contain both device-readable and device-writable buffers. Device-readable
// descriptors always come first in a chain. A single, large buffer may be
// split up by chaining multiple similar descriptors that reference different
// memory pages. This is required, because buffers may exceed a single page size
// and the memory accessed by the device is expected to be continuous.
type Descriptor struct {
// address is the address to the continuous memory holding the data for this
// descriptor.
address uintptr
// length is the amount of bytes stored at address.
length uint32
// flags that describe this descriptor.
flags descriptorFlag
// next contains the index of the next descriptor continuing this descriptor
// chain when the [descriptorFlagHasNext] flag is set.
next uint16
}
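// A sketch of a two-descriptor chain, one device-readable buffer followed by
// one device-writable buffer (addresses, lengths, and indices illustrative):
//
//	Descriptor{address: bufA, length: 64, flags: descriptorFlagHasNext, next: 7}
//	Descriptor{address: bufB, length: 4096, flags: descriptorFlagWritable, next: 0}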

View File

@@ -1,12 +0,0 @@
package virtqueue
import (
"testing"
"unsafe"
"github.com/stretchr/testify/assert"
)
func TestDescriptor_Size(t *testing.T) {
assert.EqualValues(t, descriptorSize, unsafe.Sizeof(Descriptor{}))
}

View File

@@ -1,641 +0,0 @@
package virtqueue
import (
"errors"
"fmt"
"math"
"unsafe"
"golang.org/x/sys/unix"
)
var (
// ErrDescriptorChainEmpty is returned when a descriptor chain would contain
// no buffers, which is not allowed.
ErrDescriptorChainEmpty = errors.New("empty descriptor chains are not allowed")
// ErrNotEnoughFreeDescriptors is returned when the free descriptors are
// exhausted, meaning that the queue is full.
ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors, queue is full")
// ErrInvalidDescriptorChain is returned when a descriptor chain is not
// valid for a given operation.
ErrInvalidDescriptorChain = errors.New("invalid descriptor chain")
)
// noFreeHead is used to mark when all descriptors are in use and we have no
// free chain. This value is impossible to occur as an index naturally, because
// it exceeds the maximum queue size.
const noFreeHead = uint16(math.MaxUint16)
// descriptorTableSize is the number of bytes needed to store a
// [DescriptorTable] with the given queue size in memory.
func descriptorTableSize(queueSize int) int {
return descriptorSize * queueSize
}
// descriptorTableAlignment is the minimum alignment of a [DescriptorTable]
// in memory, as required by the virtio spec.
const descriptorTableAlignment = 16
// DescriptorTable is a table that holds [Descriptor]s, addressed via their
// index in the slice.
type DescriptorTable struct {
descriptors []Descriptor
// freeHeadIndex is the index of the head of the descriptor chain which
// contains all currently unused descriptors. When all descriptors are in
// use, this has the special value of noFreeHead.
freeHeadIndex uint16
// freeNum tracks the number of descriptors which are currently not in use.
freeNum uint16
bufferBase uintptr
bufferSize int
itemSize int
}
// newDescriptorTable creates a descriptor table that uses the given underlying
// memory. The Length of the memory slice must match the size needed for the
// descriptor table (see [descriptorTableSize]) for the given queue size.
//
// Before this descriptor table can be used, [initialize] must be called.
func newDescriptorTable(queueSize int, mem []byte, itemSize int) *DescriptorTable {
dtSize := descriptorTableSize(queueSize)
if len(mem) != dtSize {
panic(fmt.Sprintf("memory size (%v) does not match required size "+
"for descriptor table: %v", len(mem), dtSize))
}
return &DescriptorTable{
descriptors: unsafe.Slice((*Descriptor)(unsafe.Pointer(&mem[0])), queueSize),
// We have no free descriptors until they were initialized.
freeHeadIndex: noFreeHead,
freeNum: 0,
itemSize: itemSize, // TODO: make configurable; the size must be page-aligned
}
}
// Address returns the pointer to the beginning of the descriptor table in
// memory. Do not modify the memory directly to not interfere with this
// implementation.
func (dt *DescriptorTable) Address() uintptr {
if dt.descriptors == nil {
panic("descriptor table is not initialized")
}
// Note: this is the descriptor table region itself, distinct from the mmap'd buffer region at dt.bufferBase.
return uintptr(unsafe.Pointer(&dt.descriptors[0]))
}
func (dt *DescriptorTable) Size() uintptr {
if dt.descriptors == nil {
panic("descriptor table is not initialized")
}
return uintptr(dt.bufferSize)
}
// BufferAddresses returns a map of pointer->size for all allocations used by the table
func (dt *DescriptorTable) BufferAddresses() map[uintptr]int {
if dt.descriptors == nil {
panic("descriptor table is not initialized")
}
return map[uintptr]int{dt.bufferBase: dt.bufferSize}
}
// initializeDescriptors allocates buffers with the size of a full memory page
// for each descriptor in the table. While this may be a bit wasteful, it makes
// dealing with descriptors way easier. Without this preallocation, we would
// have to allocate and free memory on demand, increasing complexity.
//
// All descriptors will be marked as free and will form a free chain. The
// addresses of all descriptors will be populated while their length remains
// zero.
func (dt *DescriptorTable) initializeDescriptors() error {
numDescriptors := len(dt.descriptors)
// Allocate ONE large region for all buffers
totalSize := dt.itemSize * numDescriptors
basePtr, err := unix.MmapPtr(-1, 0, nil, uintptr(totalSize),
unix.PROT_READ|unix.PROT_WRITE,
unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
if err != nil {
return fmt.Errorf("allocate buffer memory for descriptors: %w", err)
}
// Store the base for cleanup later
dt.bufferBase = uintptr(basePtr)
dt.bufferSize = totalSize
for i := range dt.descriptors {
dt.descriptors[i] = Descriptor{
address: dt.bufferBase + uintptr(i*dt.itemSize),
length: 0,
// All descriptors should form a free chain that loops around.
flags: descriptorFlagHasNext,
next: uint16((i + 1) % len(dt.descriptors)),
}
}
// All descriptors are free to use now.
dt.freeHeadIndex = 0
dt.freeNum = uint16(len(dt.descriptors))
return nil
}
// releaseBuffers releases all allocated buffers for this descriptor table.
// The implementation will try to release as many buffers as possible and
// collect potential errors before returning them.
// The descriptor table should no longer be used after calling this.
func (dt *DescriptorTable) releaseBuffers() error {
for i := range dt.descriptors {
descriptor := &dt.descriptors[i]
descriptor.address = 0
}
// As a safety measure, make sure no descriptors can be used anymore.
dt.freeHeadIndex = noFreeHead
dt.freeNum = 0
if dt.bufferBase != 0 {
// The pointer points to memory not managed by Go, so this conversion
// is safe. See https://github.com/golang/go/issues/58625
base := dt.bufferBase
dt.bufferBase = 0
//goland:noinspection GoVetUnsafePointer
err := unix.MunmapPtr(unsafe.Pointer(base), uintptr(dt.bufferSize))
if err != nil {
return fmt.Errorf("release buffer memory: %w", err)
}
}
return nil
}
// createDescriptorChain creates a new descriptor chain within the descriptor
// table which contains a number of device-readable buffers (out buffers) and
// device-writable buffers (in buffers).
//
// All buffers in the outBuffers slice will be concatenated by chaining
// descriptors, one for each buffer in the slice. The size of the single buffers
// must not exceed the size of a memory page (see [os.Getpagesize]).
// When numInBuffers is greater than zero, the given number of device-writable
// descriptors will be appended to the end of the chain, each referencing a
// whole memory page.
//
// The index of the head of the new descriptor chain will be returned. Callers
// should make sure to free the descriptor chain using [freeDescriptorChain]
// after it was used by the device.
//
// When there are not enough free descriptors to hold the given number of
// buffers, an [ErrNotEnoughFreeDescriptors] will be returned. In this case, the
// caller should try again after some descriptor chains were used by the device
// and returned back into the free chain.
func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffers int) (uint16, error) {
// Calculate the number of descriptors needed to build the chain.
numDesc := uint16(len(outBuffers) + numInBuffers)
// Descriptor chains must always contain at least one descriptor.
if numDesc < 1 {
return 0, ErrDescriptorChainEmpty
}
// Do we still have enough free descriptors?
if numDesc > dt.freeNum {
return 0, ErrNotEnoughFreeDescriptors
}
// Above validation ensured that there is at least one free descriptor, so
// the free descriptor chain head should be valid.
if dt.freeHeadIndex == noFreeHead {
panic("free descriptor chain head is unset but there should be free descriptors")
}
// To avoid having to iterate over the whole table to find the descriptor
// pointing to the head just to replace the free head, we instead always
// create descriptor chains from the descriptors coming after the head.
// This way we only have to touch the head as a last resort, when all other
// descriptors are already used.
head := dt.descriptors[dt.freeHeadIndex].next
next := head
tail := head
for i, buffer := range outBuffers {
desc := &dt.descriptors[next]
checkUnusedDescriptorLength(next, desc)
if len(buffer) > dt.itemSize {
// The caller should already prevent that from happening.
panic(fmt.Sprintf("out buffer %d has size %d which exceeds desc length %d", i, len(buffer), dt.itemSize))
}
// Copy the buffer to the memory referenced by the descriptor.
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
copy(unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), dt.itemSize), buffer)
desc.length = uint32(len(buffer))
// Clear the flags in case there were any others set.
desc.flags = descriptorFlagHasNext
tail = next
next = desc.next
}
for range numInBuffers {
desc := &dt.descriptors[next]
checkUnusedDescriptorLength(next, desc)
// Give the device the maximum available number of bytes to write into.
desc.length = uint32(dt.itemSize)
// Mark the descriptor as device-writable.
desc.flags = descriptorFlagHasNext | descriptorFlagWritable
tail = next
next = desc.next
}
// The last descriptor should end the chain.
tailDesc := &dt.descriptors[tail]
tailDesc.flags &= ^descriptorFlagHasNext
tailDesc.next = 0 // Not necessary to clear this, it's just for looks.
dt.freeNum -= numDesc
if dt.freeNum == 0 {
// The last descriptor in the chain should be the free chain head
// itself.
if tail != dt.freeHeadIndex {
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
}
// When this new chain takes up all remaining descriptors, we no longer
// have a free chain.
dt.freeHeadIndex = noFreeHead
} else {
// We took some descriptors out of the free chain, so make sure to close
// the circle again.
dt.descriptors[dt.freeHeadIndex].next = next
}
return head, nil
}
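// Illustrative lifecycle sketch (not part of the original source); buffer
// contents are copied into the table's own memory, so callers may reuse the
// input slices as soon as createDescriptorChain returns:
//
//	head, err := dt.createDescriptorChain([][]byte{payload}, 1)
//	// ... offer head to the device, wait for it on the used ring ...
//	_, inBufs, err := dt.getDescriptorChain(head)
//	// ... read device-written data from inBufs ...
//	err = dt.freeDescriptorChain(head)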
// CreateDescriptorForOutputs takes a single descriptor out of the free chain
// and prepares it as a device-readable (out) descriptor, returning its index.
// The caller fills the referenced buffer afterwards.
func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, error) {
// TODO: just fill the whole table up front instead.
// Do we still have enough free descriptors?
if 1 > dt.freeNum {
return 0, ErrNotEnoughFreeDescriptors
}
// Above validation ensured that there is at least one free descriptor, so
// the free descriptor chain head should be valid.
if dt.freeHeadIndex == noFreeHead {
panic("free descriptor chain head is unset but there should be free descriptors")
}
// To avoid having to iterate over the whole table to find the descriptor
// pointing to the head just to replace the free head, we instead always
// create descriptor chains from the descriptors coming after the head.
// This way we only have to touch the head as a last resort, when all other
// descriptors are already used.
head := dt.descriptors[dt.freeHeadIndex].next
desc := &dt.descriptors[head]
next := desc.next
checkUnusedDescriptorLength(head, desc)
// Give the device the maximum available number of bytes to write into.
desc.length = uint32(dt.itemSize)
desc.flags = 0 // Device-readable (out) descriptor: no writable flag, no next.
desc.next = 0 // Not necessary to clear this, it's just for looks.
dt.freeNum -= 1
if dt.freeNum == 0 {
// The last descriptor in the chain should be the free chain head
// itself.
if next != dt.freeHeadIndex {
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
}
// When this new chain takes up all remaining descriptors, we no longer
// have a free chain.
dt.freeHeadIndex = noFreeHead
} else {
// We took some descriptors out of the free chain, so make sure to close
// the circle again.
dt.descriptors[dt.freeHeadIndex].next = next
}
return head, nil
}
// createDescriptorForInputs takes a single descriptor out of the free chain
// and prepares it as a device-writable (in) descriptor spanning a whole buffer
// item, returning its index.
func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
// Do we still have enough free descriptors?
if 1 > dt.freeNum {
return 0, ErrNotEnoughFreeDescriptors
}
// Above validation ensured that there is at least one free descriptor, so
// the free descriptor chain head should be valid.
if dt.freeHeadIndex == noFreeHead {
panic("free descriptor chain head is unset but there should be free descriptors")
}
// To avoid having to iterate over the whole table to find the descriptor
// pointing to the head just to replace the free head, we instead always
// create descriptor chains from the descriptors coming after the head.
// This way we only have to touch the head as a last resort, when all other
// descriptors are already used.
head := dt.descriptors[dt.freeHeadIndex].next
desc := &dt.descriptors[head]
next := desc.next
checkUnusedDescriptorLength(head, desc)
// Give the device the maximum available number of bytes to write into.
desc.length = uint32(dt.itemSize)
desc.flags = descriptorFlagWritable
desc.next = 0 // Not necessary to clear this, it's just for looks.
dt.freeNum -= 1
if dt.freeNum == 0 {
// The last descriptor in the chain should be the free chain head
// itself.
if next != dt.freeHeadIndex {
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
}
// When this new chain takes up all remaining descriptors, we no longer
// have a free chain.
dt.freeHeadIndex = noFreeHead
} else {
// We took some descriptors out of the free chain, so make sure to close
// the circle again.
dt.descriptors[dt.freeHeadIndex].next = next
}
return head, nil
}
// TODO: Implement a zero-copy variant of createDescriptorChain?
// getDescriptorChain returns the device-readable buffers (out buffers) and
// device-writable buffers (in buffers) of the descriptor chain that starts with
// the given head index. The descriptor chain must have been created using
// [createDescriptorChain] and must not have been freed yet (meaning that the
// head index must not be contained in the free chain).
//
// Be careful to only access the returned buffer slices when the device has not
// yet or is no longer using them. They must not be accessed after
// [freeDescriptorChain] has been called.
func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
if int(head) >= len(dt.descriptors) {
return nil, nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
next := head
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return nil, nil, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
if desc.flags&descriptorFlagWritable == 0 {
outBuffers = append(outBuffers, bs)
} else {
inBuffers = append(inBuffers, bs)
}
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// Detect loops.
if desc.next == head {
return nil, nil, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
return
}
// getDescriptorItem returns the buffer referenced by the single descriptor at
// head, without walking the rest of the chain.
func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
if int(head) >= len(dt.descriptors) {
return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// TODO: this is a shortcut with no chain validation; it assumes a
// single-descriptor chain.
desc := &dt.descriptors[head]
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
return bs, nil
}
// getDescriptorInbuffers appends the device-writable (in) buffers of the
// descriptor chain starting at head to inBuffers. It returns an error if the
// chain contains a device-readable buffer.
func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
if int(head) >= len(dt.descriptors) {
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
next := head
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
if desc.flags&descriptorFlagWritable == 0 {
return fmt.Errorf("there should not be an outbuffer in %d", head)
} else {
*inBuffers = append(*inBuffers, bs)
}
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// Detect loops.
if desc.next == head {
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
return nil
}
// getDescriptorChainContents copies the contents of the device-writable
// descriptor chain starting at head into out and returns the number of bytes
// copied. When maxLen is positive, at most maxLen bytes are copied.
func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
if int(head) >= len(dt.descriptors) {
return 0, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
// First pass: compute the total length of the chain.
length := 0
next := head
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return 0, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
if desc.flags&descriptorFlagWritable == 0 {
return 0, fmt.Errorf("receive queue contains device-readable buffer")
}
length += int(desc.length)
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// Detect loops.
if desc.next == head {
return 0, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
if maxLen > 0 {
length = min(length, maxLen)
}
// Clamp the output slice to the computed length.
out = out[:length]
// Second pass: copy the chain contents, restarting from the head of the
// chain (the first pass left next pointing at the tail).
next = head
copied := 0
for range len(dt.descriptors) {
desc := &dt.descriptors[next]
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), min(uint32(length-copied), desc.length))
copied += copy(out[copied:], bs)
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// we did this already, no need to detect loops.
next = desc.next
}
if copied != length {
panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", length, copied))
}
return length, nil
}
// freeDescriptorChain can be used to free a descriptor chain when it is no
// longer in use. The descriptor chain that starts with the given index will be
// put back into the free chain, so the descriptors can be used for later calls
// of [createDescriptorChain].
// The descriptor chain must have been created using [createDescriptorChain] and
// must not have been freed yet (meaning that the head index must not be
// contained in the free chain).
func (dt *DescriptorTable) freeDescriptorChain(head uint16) error {
if int(head) >= len(dt.descriptors) {
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
next := head
var tailDesc *Descriptor
var chainLen uint16
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
chainLen++
// Set the length of all unused descriptors back to zero.
desc.length = 0
// Unset all flags except the next flag.
desc.flags &= descriptorFlagHasNext
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
tailDesc = desc
break
}
// Detect loops.
if desc.next == head {
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
if tailDesc == nil {
// A descriptor chain longer than the queue size but without loops
// should be impossible.
panic(fmt.Sprintf("could not find a tail for descriptor chain starting at %d", head))
}
// The tail descriptor does not have the next flag set, but when it comes
// back into the free chain, it should have.
tailDesc.flags = descriptorFlagHasNext
if dt.freeHeadIndex == noFreeHead {
// The whole free chain was used up, so we turn this returned descriptor
// chain into the new free chain by completing the circle and using its
// head.
tailDesc.next = head
dt.freeHeadIndex = head
} else {
// Attach the returned chain at the beginning of the free chain but
// right after the free chain head.
freeHeadDesc := &dt.descriptors[dt.freeHeadIndex]
tailDesc.next = freeHeadDesc.next
freeHeadDesc.next = head
}
dt.freeNum += chainLen
return nil
}
// checkUnusedDescriptorLength asserts that the length of an unused descriptor
// is zero, as it should be.
// This is not a requirement by the virtio spec but rather a thing we do to
// notice when our algorithm goes sideways.
func checkUnusedDescriptorLength(index uint16, desc *Descriptor) {
if desc.length != 0 {
panic(fmt.Sprintf("descriptor %d should be unused but has a non-zero length", index))
}
}

View File

@@ -1,407 +0,0 @@
package virtqueue
import (
"os"
"testing"
"unsafe"
"github.com/stretchr/testify/assert"
)
func TestDescriptorTable_InitializeDescriptors(t *testing.T) {
const queueSize = 32
dt := DescriptorTable{
descriptors: make([]Descriptor, queueSize),
}
assert.NoError(t, dt.initializeDescriptors())
t.Cleanup(func() {
assert.NoError(t, dt.releaseBuffers())
})
for i, descriptor := range dt.descriptors {
assert.NotZero(t, descriptor.address)
assert.Zero(t, descriptor.length)
assert.EqualValues(t, descriptorFlagHasNext, descriptor.flags)
assert.EqualValues(t, (i+1)%queueSize, descriptor.next)
}
}
func TestDescriptorTable_DescriptorChains(t *testing.T) {
// Use a very short queue size to not make this test overly verbose.
const queueSize = 8
pageSize := os.Getpagesize() * 2
// Initialize descriptor table.
dt := DescriptorTable{
descriptors: make([]Descriptor, queueSize),
}
assert.NoError(t, dt.initializeDescriptors())
t.Cleanup(func() {
assert.NoError(t, dt.releaseBuffers())
})
// Some utilities for easier checking if the descriptor table looks as
// expected.
type desc struct {
buffer []byte
flags descriptorFlag
next uint16
}
assertDescriptorTable := func(expected [queueSize]desc) {
for i := 0; i < queueSize; i++ {
actualDesc := &dt.descriptors[i]
expectedDesc := &expected[i]
assert.Equal(t, uint32(len(expectedDesc.buffer)), actualDesc.length)
if len(expectedDesc.buffer) > 0 {
//goland:noinspection GoVetUnsafePointer
assert.EqualValues(t,
unsafe.Slice((*byte)(unsafe.Pointer(actualDesc.address)), actualDesc.length),
expectedDesc.buffer)
}
assert.Equal(t, expectedDesc.flags, actualDesc.flags)
if expectedDesc.flags&descriptorFlagHasNext != 0 {
assert.Equal(t, expectedDesc.next, actualDesc.next)
}
}
}
// Initial state: All descriptors are in the free chain.
assert.Equal(t, uint16(0), dt.freeHeadIndex)
assert.Equal(t, uint16(8), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
// Free head.
flags: descriptorFlagHasNext,
next: 1,
},
{
flags: descriptorFlagHasNext,
next: 2,
},
{
flags: descriptorFlagHasNext,
next: 3,
},
{
flags: descriptorFlagHasNext,
next: 4,
},
{
flags: descriptorFlagHasNext,
next: 5,
},
{
flags: descriptorFlagHasNext,
next: 6,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
// Create the first chain.
firstChain, err := dt.createDescriptorChain([][]byte{
makeTestBuffer(t, 26),
makeTestBuffer(t, 256),
}, 1)
assert.NoError(t, err)
assert.Equal(t, uint16(1), firstChain)
// Now there should be a new chain next to the free chain.
assert.Equal(t, uint16(0), dt.freeHeadIndex)
assert.Equal(t, uint16(5), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
// Free head.
flags: descriptorFlagHasNext,
next: 4,
},
{
// Head of first chain.
buffer: makeTestBuffer(t, 26),
flags: descriptorFlagHasNext,
next: 2,
},
{
buffer: makeTestBuffer(t, 256),
flags: descriptorFlagHasNext,
next: 3,
},
{
// Tail of first chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
flags: descriptorFlagHasNext,
next: 5,
},
{
flags: descriptorFlagHasNext,
next: 6,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
// Create a second chain with only a single in buffer.
secondChain, err := dt.createDescriptorChain(nil, 1)
assert.NoError(t, err)
assert.Equal(t, uint16(4), secondChain)
// Now there should be two chains next to the free chain.
assert.Equal(t, uint16(0), dt.freeHeadIndex)
assert.Equal(t, uint16(4), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
// Free head.
flags: descriptorFlagHasNext,
next: 5,
},
{
// Head of the first chain.
buffer: makeTestBuffer(t, 26),
flags: descriptorFlagHasNext,
next: 2,
},
{
buffer: makeTestBuffer(t, 256),
flags: descriptorFlagHasNext,
next: 3,
},
{
// Tail of the first chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Head and tail of the second chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
flags: descriptorFlagHasNext,
next: 6,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
// Create a third chain taking up all remaining descriptors.
thirdChain, err := dt.createDescriptorChain([][]byte{
makeTestBuffer(t, 42),
makeTestBuffer(t, 96),
makeTestBuffer(t, 33),
makeTestBuffer(t, 222),
}, 0)
assert.NoError(t, err)
assert.Equal(t, uint16(5), thirdChain)
// Now there should be three chains and no free chain.
assert.Equal(t, noFreeHead, dt.freeHeadIndex)
assert.Equal(t, uint16(0), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
// Tail of the third chain.
buffer: makeTestBuffer(t, 222),
},
{
// Head of the first chain.
buffer: makeTestBuffer(t, 26),
flags: descriptorFlagHasNext,
next: 2,
},
{
buffer: makeTestBuffer(t, 256),
flags: descriptorFlagHasNext,
next: 3,
},
{
// Tail of the first chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Head and tail of the second chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Head of the third chain.
buffer: makeTestBuffer(t, 42),
flags: descriptorFlagHasNext,
next: 6,
},
{
buffer: makeTestBuffer(t, 96),
flags: descriptorFlagHasNext,
next: 7,
},
{
buffer: makeTestBuffer(t, 33),
flags: descriptorFlagHasNext,
next: 0,
},
})
// Free the third chain.
assert.NoError(t, dt.freeDescriptorChain(thirdChain))
// Now there should be two chains and a free chain again.
assert.Equal(t, uint16(5), dt.freeHeadIndex)
assert.Equal(t, uint16(4), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
flags: descriptorFlagHasNext,
next: 5,
},
{
// Head of the first chain.
buffer: makeTestBuffer(t, 26),
flags: descriptorFlagHasNext,
next: 2,
},
{
buffer: makeTestBuffer(t, 256),
flags: descriptorFlagHasNext,
next: 3,
},
{
// Tail of the first chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Head and tail of the second chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Free head.
flags: descriptorFlagHasNext,
next: 6,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
// Free the first chain.
assert.NoError(t, dt.freeDescriptorChain(firstChain))
// Now there should be only a single chain next to the free chain.
assert.Equal(t, uint16(5), dt.freeHeadIndex)
assert.Equal(t, uint16(7), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
flags: descriptorFlagHasNext,
next: 5,
},
{
flags: descriptorFlagHasNext,
next: 2,
},
{
flags: descriptorFlagHasNext,
next: 3,
},
{
flags: descriptorFlagHasNext,
next: 6,
},
{
// Head and tail of the second chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Free head.
flags: descriptorFlagHasNext,
next: 1,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
// Free the second chain.
assert.NoError(t, dt.freeDescriptorChain(secondChain))
// Now all descriptors should be in the free chain again.
assert.Equal(t, uint16(5), dt.freeHeadIndex)
assert.Equal(t, uint16(8), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
flags: descriptorFlagHasNext,
next: 5,
},
{
flags: descriptorFlagHasNext,
next: 2,
},
{
flags: descriptorFlagHasNext,
next: 3,
},
{
flags: descriptorFlagHasNext,
next: 6,
},
{
flags: descriptorFlagHasNext,
next: 1,
},
{
// Free head.
flags: descriptorFlagHasNext,
next: 4,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
}
func makeTestBuffer(t *testing.T, length int) []byte {
t.Helper()
buf := make([]byte, length)
for i := 0; i < length; i++ {
buf[i] = byte(length - i)
}
return buf
}

View File

@@ -1,7 +0,0 @@
// Package virtqueue implements the driver-side for a virtio queue as described
// in the specification:
// https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-270006
// This package does not make assumptions about the device that consumes the
// queue. It rather just allocates the queue structures in memory and provides
// methods to interact with it.
package virtqueue
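// A minimal usage sketch (illustrative, not part of the original package
// comment); names like sq, hdr, payload and ctx are placeholders:
//
//	sq, err := NewSplitQueue(256, os.Getpagesize())
//	if err != nil { /* handle */ }
//	defer sq.Close()
//	// Offer a device-readable frame to the device.
//	chains, err := sq.OfferOutDescriptorChains(hdr, [][]byte{payload})
//	// Block until the device reports chains back on the used ring.
//	used, err := sq.BlockAndGetHeads(ctx)
//	for i := range used {
//	    _ = sq.FreeDescriptorChain(used[i].GetHead())
//	}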

View File

@@ -1,45 +0,0 @@
package virtqueue
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gvisor.dev/gvisor/pkg/eventfd"
)
// Tests how an eventfd and a waiting goroutine can be gracefully closed.
// Extends the eventfd test suite:
// https://github.com/google/gvisor/blob/0799336d64be65eb97d330606c30162dc3440cab/pkg/eventfd/eventfd_test.go
func TestEventFD_CancelWait(t *testing.T) {
efd, err := eventfd.Create()
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, efd.Close())
})
var stop atomic.Bool
done := make(chan struct{})
go func() {
for !stop.Load() {
_ = efd.Wait()
}
close(done)
}()
select {
case <-done:
t.Fatalf("goroutine ended early")
case <-time.After(500 * time.Millisecond):
}
stop.Store(true)
assert.NoError(t, efd.Notify())
select {
case <-done:
break
case <-time.After(5 * time.Second):
t.Error("goroutine did not end")
}
}

View File

@@ -1,33 +0,0 @@
package virtqueue
import (
"errors"
"fmt"
)
// ErrQueueSizeInvalid is returned when a queue size is invalid.
var ErrQueueSizeInvalid = errors.New("queue size is invalid")
// CheckQueueSize checks if the given value would be a valid size for a
// virtqueue and returns an [ErrQueueSizeInvalid], if not.
func CheckQueueSize(queueSize int) error {
if queueSize <= 0 {
return fmt.Errorf("%w: %d is too small", ErrQueueSizeInvalid, queueSize)
}
// The queue size must always be a power of 2.
// This ensures that ring indexes wrap correctly when the 16-bit integers
// overflow.
if queueSize&(queueSize-1) != 0 {
return fmt.Errorf("%w: %d is not a power of 2", ErrQueueSizeInvalid, queueSize)
}
// The largest power of 2 that fits into a 16-bit integer is 32768.
// 2 * 32768 would be 65536 which no longer fits.
if queueSize > 32768 {
return fmt.Errorf("%w: %d is larger than the maximum possible queue size 32768",
ErrQueueSizeInvalid, queueSize)
}
return nil
}
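// Illustrative results (not part of the original source). The power-of-two
// rule matters because ring indexes are free-running uint16 counters: only
// when the queue size divides 65536 does index%queueSize keep addressing the
// correct slot across the overflow.
//
//	CheckQueueSize(256)   // nil
//	CheckQueueSize(24)    // not a power of 2
//	CheckQueueSize(65536) // larger than the maximum possible queue size 32768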

View File

@@ -1,59 +0,0 @@
package virtqueue
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestCheckQueueSize(t *testing.T) {
tests := []struct {
name string
queueSize int
containsErr string
}{
{
name: "negative",
queueSize: -1,
containsErr: "too small",
},
{
name: "zero",
queueSize: 0,
containsErr: "too small",
},
{
name: "not a power of 2",
queueSize: 24,
containsErr: "not a power of 2",
},
{
name: "too large",
queueSize: 65536,
containsErr: "larger than the maximum",
},
{
name: "valid 1",
queueSize: 1,
},
{
name: "valid 256",
queueSize: 256,
},
{
name: "valid 32768",
queueSize: 32768,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckQueueSize(tt.queueSize)
if tt.containsErr != "" {
assert.ErrorContains(t, err, tt.containsErr)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -1,530 +0,0 @@
package virtqueue
import (
"context"
"errors"
"fmt"
"os"
"syscall"
"github.com/slackhq/nebula/overlay/eventfd"
"golang.org/x/sys/unix"
)
// SplitQueue is a virtqueue that consists of several parts, where each part is
// writeable by either the driver or the device, but not both.
type SplitQueue struct {
// size is the size of the queue.
size int
// buf is the underlying memory used for the queue.
buf []byte
descriptorTable *DescriptorTable
availableRing *AvailableRing
usedRing *UsedRing
// kickEventFD is used to signal the device when descriptor chains were
// added to the available ring.
kickEventFD eventfd.EventFD
// callEventFD is used by the device to signal when it has used descriptor
// chains and put them in the used ring.
callEventFD eventfd.EventFD
// stop is used by [SplitQueue.Close] to cancel the goroutine that handles
// used buffer notifications. It blocks until the goroutine ended.
stop func() error
// itemSize is the size of each descriptor buffer in bytes.
itemSize int
// epoll waits for signals on the call event file descriptor.
epoll eventfd.Epoll
// more counts used elements that are known to be pending on the ring but
// were not yet returned by a capped take.
more int
}
// NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size
// specifies the number of entries/buffers the queue can hold. This also affects
// the memory consumption. itemSize is the size of each descriptor buffer in
// bytes and must be a multiple of the system page size.
func NewSplitQueue(queueSize int, itemSize int) (_ *SplitQueue, err error) {
if err = CheckQueueSize(queueSize); err != nil {
return nil, err
}
if itemSize%os.Getpagesize() != 0 {
return nil, errors.New("split queue size must be multiple of os.Getpagesize()")
}
sq := SplitQueue{
size: queueSize,
itemSize: itemSize,
}
// Clean up a partially initialized queue when something fails.
defer func() {
if err != nil {
_ = sq.Close()
}
}()
// There are multiple ways for how the memory for the virtqueue could be
// allocated. We could use Go native structs with arrays inside them, but
// this wouldn't allow us to make the queue size configurable. And including
// a slice in the Go structs wouldn't work, because this would just put the
// Go slice descriptor into the memory region which the virtio device will
// not understand.
// Additionally, Go does not allow us to ensure a correct alignment of the
// parts of the virtqueue, as it is required by the virtio specification.
//
// To resolve this, let's just allocate the memory manually by allocating
// one or more memory pages, depending on the queue size. Making the
// virtqueue start at the beginning of a page is not strictly necessary, as
// the virtio specification does not require it to be continuous in the
// physical memory of the host (e.g. the vhost implementation in the kernel
// always uses copy_from_user to access it), but this makes it very easy to
// guarantee the alignment. Also, it is not required for the virtqueue parts
// to be in the same memory region, as we pass separate pointers to them to
// the device, but this design just makes things easier to implement.
//
// One added benefit of allocating the memory manually is, that we have full
// control over its lifetime and don't risk the garbage collector to collect
// our valuable structures while the device still works with them.
// The descriptor table is at the start of the page, so alignment is not an
// issue here.
descriptorTableStart := 0
descriptorTableEnd := descriptorTableStart + descriptorTableSize(queueSize)
availableRingStart := align(descriptorTableEnd, availableRingAlignment)
availableRingEnd := availableRingStart + availableRingSize(queueSize)
usedRingStart := align(availableRingEnd, usedRingAlignment)
usedRingEnd := usedRingStart + usedRingSize(queueSize)
sq.buf, err = unix.Mmap(-1, 0, usedRingEnd,
unix.PROT_READ|unix.PROT_WRITE,
unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
if err != nil {
return nil, fmt.Errorf("allocate virtqueue buffer: %w", err)
}
sq.descriptorTable = newDescriptorTable(queueSize, sq.buf[descriptorTableStart:descriptorTableEnd], sq.itemSize)
sq.availableRing = newAvailableRing(queueSize, sq.buf[availableRingStart:availableRingEnd])
sq.usedRing = newUsedRing(queueSize, sq.buf[usedRingStart:usedRingEnd])
sq.kickEventFD, err = eventfd.New()
if err != nil {
return nil, fmt.Errorf("create kick event file descriptor: %w", err)
}
sq.callEventFD, err = eventfd.New()
if err != nil {
return nil, fmt.Errorf("create call event file descriptor: %w", err)
}
if err = sq.descriptorTable.initializeDescriptors(); err != nil {
return nil, fmt.Errorf("initialize descriptors: %w", err)
}
sq.epoll, err = eventfd.NewEpoll()
if err != nil {
return nil, err
}
err = sq.epoll.AddEvent(sq.callEventFD.FD())
if err != nil {
return nil, err
}
// Consume used buffer notifications in the background.
sq.stop = sq.startConsumeUsedRing()
return &sq, nil
}
// Size returns the size of this queue, which is the number of entries/buffers
// this queue can hold.
func (sq *SplitQueue) Size() int {
return sq.size
}
// DescriptorTable returns the [DescriptorTable] behind this queue.
func (sq *SplitQueue) DescriptorTable() *DescriptorTable {
return sq.descriptorTable
}
// AvailableRing returns the [AvailableRing] behind this queue.
func (sq *SplitQueue) AvailableRing() *AvailableRing {
return sq.availableRing
}
// UsedRing returns the [UsedRing] behind this queue.
func (sq *SplitQueue) UsedRing() *UsedRing {
return sq.usedRing
}
// KickEventFD returns the kick event file descriptor behind this queue.
// The returned file descriptor should be used with great care to not interfere
// with this implementation.
func (sq *SplitQueue) KickEventFD() int {
return sq.kickEventFD.FD()
}
// CallEventFD returns the call event file descriptor behind this queue.
// The returned file descriptor should be used with great care to not interfere
// with this implementation.
func (sq *SplitQueue) CallEventFD() int {
return sq.callEventFD.FD()
}
// startConsumeUsedRing returns a stop function that [SplitQueue.Close] uses to
// wake up any goroutine still blocked on the call event file descriptor.
// TODO: rename; no background goroutine is started anymore.
func (sq *SplitQueue) startConsumeUsedRing() func() error {
return func() error {
// The goroutine blocks until it receives a signal on the event file
// descriptor, so it will never notice the context being canceled.
// To resolve this, we can just produce a fake-signal ourselves to wake
// it up.
if err := sq.callEventFD.Kick(); err != nil {
return fmt.Errorf("wake up goroutine: %w", err)
}
return nil
}
}
// BlockAndGetHeads waits for the device to signal that it has used descriptor chains and returns all [UsedElement]s
func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, error) {
var n int
var err error
for ctx.Err() == nil {
// Wait for a signal from the device.
if n, err = sq.epoll.Block(); err != nil {
return nil, fmt.Errorf("wait: %w", err)
}
if n > 0 {
stillNeedToTake, out := sq.usedRing.take(-1)
sq.more = stillNeedToTake
if stillNeedToTake == 0 {
_ = sq.epoll.Clear() // Drain the event so the next Block waits for a fresh signal.
}
return out, nil
}
}
return nil, ctx.Err()
}
func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
var n int
var err error
for ctx.Err() == nil {
out, ok := sq.usedRing.takeOne()
if ok {
return out, nil
}
// Wait for a signal from the device.
if n, err = sq.epoll.Block(); err != nil {
return 0, fmt.Errorf("wait: %w", err)
}
if n > 0 {
out, ok = sq.usedRing.takeOne()
if ok {
_ = sq.epoll.Clear() // Drain the event before returning.
return out, nil
} else {
continue // No element despite the signal; wait again.
}
}
}
return 0, ctx.Err()
}
func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
var n int
var err error
for ctx.Err() == nil {
// Serve leftovers from the previous take first.
if sq.more > 0 {
stillNeedToTake, out := sq.usedRing.take(maxToTake)
sq.more = stillNeedToTake
return out, nil
}
// Check the ring for anything new.
stillNeedToTake, out := sq.usedRing.take(maxToTake)
if len(out) > 0 {
sq.more = stillNeedToTake
return out, nil
}
// Nothing buffered and nothing new on the ring.
// Wait for a signal from the device.
if n, err = sq.epoll.Block(); err != nil {
return nil, fmt.Errorf("wait: %w", err)
}
if n > 0 {
_ = sq.epoll.Clear() // Drain the event so the next Block waits for a fresh signal.
stillNeedToTake, out = sq.usedRing.take(maxToTake)
sq.more = stillNeedToTake
return out, nil
}
}
return nil, ctx.Err()
}
// OfferInDescriptorChains offers a single device-writable (in) descriptor to
// the device, spanning one whole buffer item.
//
// When the queue is full and no more descriptors can be added, an
// [ErrNotEnoughFreeDescriptors] is returned. The caller should retry after
// some descriptor chains were freed using [SplitQueue.FreeDescriptorChain].
//
// After defining the descriptor in the [DescriptorTable], the index of its
// head is made available to the device using the [AvailableRing] and is also
// returned by this method. Callers must eventually free used descriptor
// chains again via [SplitQueue.FreeDescriptorChain]; otherwise the queue will
// run full.
func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
// Create a single device-writable descriptor.
head, err := sq.descriptorTable.createDescriptorForInputs()
if err != nil {
// I don't wanna use errors.Is, it's slow
//goland:noinspection GoDirectComparisonOfErrors
if err == ErrNotEnoughFreeDescriptors {
return 0, err
}
return 0, fmt.Errorf("create descriptor chain: %w", err)
}
// Make the descriptor chain available to the device.
sq.availableRing.offerSingle(head)
// Notify the device to make it process the updated available ring.
if err := sq.kickEventFD.Kick(); err != nil {
return head, fmt.Errorf("notify device: %w", err)
}
return head, nil
}
func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte) ([]uint16, error) {
// TODO change this
// Each descriptor can only hold a whole memory page, so split large out
// buffers into multiple smaller ones.
outBuffers = splitBuffers(outBuffers, sq.itemSize)
chains := make([]uint16, len(outBuffers))
// Create a descriptor chain for the given buffers.
var (
head uint16
err error
)
for i := range outBuffers {
for {
bufs := [][]byte{prepend, outBuffers[i]}
head, err = sq.descriptorTable.createDescriptorChain(bufs, 0)
if err == nil {
break
}
// I don't wanna use errors.Is, it's slow
//goland:noinspection GoDirectComparisonOfErrors
if err == ErrNotEnoughFreeDescriptors {
// Wait for more free descriptors to be put back into the queue.
// If the number of free descriptors is still not sufficient, we'll
// land here again.
//todo should never happen
syscall.Syscall(syscall.SYS_SCHED_YIELD, 0, 0, 0) // Cheap barrier
continue
}
return nil, fmt.Errorf("create descriptor chain: %w", err)
}
chains[i] = head
}
// Make the descriptor chain available to the device.
sq.availableRing.offer(chains)
// Notify the device to make it process the updated available ring.
if err := sq.kickEventFD.Kick(); err != nil {
return chains, fmt.Errorf("notify device: %w", err)
}
return chains, nil
}
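// Illustrative caller-side sketch (not part of the original source): a
// virtio-net style transmit path can pass the packet's virtio-net header as
// the prepend so that every chain starts with it, followed by one frame.
// The names sq, hdr and frame below are placeholders for the example.
//
//	chains, err := sq.OfferOutDescriptorChains(hdr, [][]byte{frame})
//	if err != nil { /* handle */ }
//	// Once the device reports the chains back on the used ring:
//	for _, head := range chains {
//	    _ = sq.FreeDescriptorChain(head)
//	}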
// GetDescriptorChain returns the device-readable buffers (out buffers) and
// device-writable buffers (in buffers) of the descriptor chain with the given
// head index.
// The head index must be one that was returned by a previous offer call
// ([SplitQueue.OfferInDescriptorChains] or [SplitQueue.OfferOutDescriptorChains])
// and the descriptor chain must not have been freed yet.
//
// Be careful to only access the returned buffer slices when the device is no
// longer using them. They must not be accessed after
// [SplitQueue.FreeDescriptorChain] has been called.
func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
return sq.descriptorTable.getDescriptorChain(head)
}
func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
return sq.descriptorTable.getDescriptorItem(head)
}
func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
return sq.descriptorTable.getDescriptorChainContents(head, out, maxLen)
}
func (sq *SplitQueue) GetDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
return sq.descriptorTable.getDescriptorInbuffers(head, inBuffers)
}
// FreeDescriptorChain frees the descriptor chain with the given head index.
// The head index must be one that was returned by a previous offer call
// ([SplitQueue.OfferInDescriptorChains] or [SplitQueue.OfferOutDescriptorChains])
// and the descriptor chain must not have been freed yet.
//
// This creates new room in the queue which can be used by subsequent offer
// calls; outstanding offer calls that are waiting for free room in the queue
// may become unblocked by this.
func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
//not called under lock
if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
return fmt.Errorf("free: %w", err)
}
return nil
}
func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
//not called under lock
sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
}
func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error {
//todo not doing this may break eventually?
//not called under lock
//if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
// return fmt.Errorf("free: %w", err)
//}
// Make the descriptor chain available to the device.
sq.availableRing.offer(chains)
// Notify the device to make it process the updated available ring.
if kick {
return sq.Kick()
}
return nil
}
func (sq *SplitQueue) Kick() error {
if err := sq.kickEventFD.Kick(); err != nil {
return fmt.Errorf("notify device: %w", err)
}
return nil
}
// Close releases all resources used for this queue.
// The implementation will try to release as many resources as possible and
// collect potential errors before returning them.
func (sq *SplitQueue) Close() error {
var errs []error
if sq.stop != nil {
// This has to happen before the event file descriptors may be closed.
if err := sq.stop(); err != nil {
errs = append(errs, fmt.Errorf("stop consume used ring: %w", err))
}
// Make sure that this code block is executed only once.
sq.stop = nil
}
if err := sq.kickEventFD.Close(); err != nil {
errs = append(errs, fmt.Errorf("close kick event file descriptor: %w", err))
}
if err := sq.callEventFD.Close(); err != nil {
errs = append(errs, fmt.Errorf("close call event file descriptor: %w", err))
}
if err := sq.descriptorTable.releaseBuffers(); err != nil {
errs = append(errs, fmt.Errorf("release descriptor buffers: %w", err))
}
if sq.buf != nil {
if err := unix.Munmap(sq.buf); err == nil {
sq.buf = nil
} else {
errs = append(errs, fmt.Errorf("unmap virtqueue buffer: %w", err))
}
}
return errors.Join(errs...)
}
// ensureInitialized is used as a guard to prevent methods to be called on an
// uninitialized instance.
func (sq *SplitQueue) ensureInitialized() {
if sq.buf == nil {
panic("used ring is not initialized")
}
}
func align(index, alignment int) int {
remainder := index % alignment
if remainder == 0 {
return index
}
return index + alignment - remainder
}
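// For example (illustrative, not part of the original source), using the
// alignments required by the virtio spec for the ring parts:
//
//	align(4096, 2) == 4096 // available ring may start right after the table
//	align(4098, usedRingAlignment) == 4100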
// splitBuffers processes a list of buffers and splits each buffer that is
// larger than the size limit into multiple smaller buffers.
// When none of the buffers exceeds the limit, the input slice is returned
// unchanged to avoid an allocation.
func splitBuffers(buffers [][]byte, sizeLimit int) [][]byte {
for i := range buffers {
if len(buffers[i]) > sizeLimit {
return reallySplitBuffers(buffers, sizeLimit)
}
}
return buffers
}
func reallySplitBuffers(buffers [][]byte, sizeLimit int) [][]byte {
result := make([][]byte, 0, len(buffers))
for _, buffer := range buffers {
for added := 0; added < len(buffer); added += sizeLimit {
if len(buffer)-added <= sizeLimit {
result = append(result, buffer[added:])
break
}
result = append(result, buffer[added:added+sizeLimit])
}
}
return result
}

View File

@@ -1,105 +0,0 @@
package virtqueue
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSplitQueue_MemoryAlignment(t *testing.T) {
tests := []struct {
name string
queueSize int
}{
{
name: "minimal queue size",
queueSize: 1,
},
{
name: "small queue size",
queueSize: 8,
},
{
name: "large queue size",
queueSize: 256,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sq, err := NewSplitQueue(tt.queueSize, os.Getpagesize())
require.NoError(t, err)
assert.Zero(t, sq.descriptorTable.Address()%descriptorTableAlignment)
assert.Zero(t, sq.availableRing.Address()%availableRingAlignment)
assert.Zero(t, sq.usedRing.Address()%usedRingAlignment)
})
}
}
func TestSplitBuffers(t *testing.T) {
const sizeLimit = 16
tests := []struct {
name string
buffers [][]byte
expected [][]byte
}{
{
name: "no buffers",
buffers: make([][]byte, 0),
expected: make([][]byte, 0),
},
{
name: "small",
buffers: [][]byte{
make([]byte, 11),
},
expected: [][]byte{
make([]byte, 11),
},
},
{
name: "exact size",
buffers: [][]byte{
make([]byte, sizeLimit),
},
expected: [][]byte{
make([]byte, sizeLimit),
},
},
{
name: "large",
buffers: [][]byte{
make([]byte, 42),
},
expected: [][]byte{
make([]byte, 16),
make([]byte, 16),
make([]byte, 10),
},
},
{
name: "mixed",
buffers: [][]byte{
make([]byte, 7),
make([]byte, 30),
make([]byte, 15),
make([]byte, 32),
},
expected: [][]byte{
make([]byte, 7),
make([]byte, 16),
make([]byte, 14),
make([]byte, 15),
make([]byte, 16),
make([]byte, 16),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := splitBuffers(tt.buffers, sizeLimit)
assert.Equal(t, tt.expected, actual)
})
}
}

View File

@@ -1,21 +0,0 @@
package virtqueue
// usedElementSize is the number of bytes needed to store a [UsedElement] in
// memory.
const usedElementSize = 8
// UsedElement is an element of the [UsedRing] and describes a descriptor chain
// that was used by the device.
type UsedElement struct {
// DescriptorIndex is the index of the head of the used descriptor chain in
// the [DescriptorTable].
// The index is 32-bit here for padding reasons.
DescriptorIndex uint32
// Length is the number of bytes written into the device writable portion of
// the buffer described by the descriptor chain.
Length uint32
}
// GetHead returns the head index of the used descriptor chain as the 16-bit
// value used by the [DescriptorTable].
func (u *UsedElement) GetHead() uint16 {
return uint16(u.DescriptorIndex)
}

View File

@@ -1,12 +0,0 @@
package virtqueue
import (
"testing"
"unsafe"
"github.com/stretchr/testify/assert"
)
func TestUsedElement_Size(t *testing.T) {
assert.EqualValues(t, usedElementSize, unsafe.Sizeof(UsedElement{}))
}

View File

@@ -1,184 +0,0 @@
package virtqueue
import (
"fmt"
"unsafe"
)
// usedRingFlag is a flag that describes a [UsedRing].
type usedRingFlag uint16
const (
// usedRingFlagNoNotify is used by the host to advise the guest to not
// kick it when adding a buffer. It's unreliable, so it's simply an
// optimization. Guest will still kick when it's out of buffers.
usedRingFlagNoNotify usedRingFlag = 1 << iota
)
// usedRingSize is the number of bytes needed to store a [UsedRing] with the
// given queue size in memory.
func usedRingSize(queueSize int) int {
return 6 + usedElementSize*queueSize
}
// usedRingAlignment is the minimum alignment of a [UsedRing] in memory, as
// required by the virtio spec.
const usedRingAlignment = 4
// UsedRing is where the device returns descriptor chains once it is done with
// them. Each ring entry is a [UsedElement]. It is only written to by the device
// and read by the driver.
//
// Because the size of the ring depends on the queue size, we cannot define a
// Go struct with a static size that maps to the memory of the ring. Instead,
// this struct only contains pointers to the corresponding memory areas.
type UsedRing struct {
initialized bool
// flags that describe this ring.
flags *usedRingFlag
// ringIndex indicates where the device would put the next entry into the
// ring (modulo the queue size).
ringIndex *uint16
// ring contains the [UsedElement]s. It wraps around at queue size.
ring []UsedElement
// availableEvent is not used by this implementation, but we reserve it
// anyway to avoid issues in case a device may try to write to it, contrary
// to the virtio specification.
availableEvent *uint16
// lastIndex is the internal ringIndex up to which all [UsedElement]s were
// processed.
lastIndex uint16
//mu sync.Mutex
}
// newUsedRing creates a used ring that uses the given underlying memory. The
// length of the memory slice must match the size needed for the ring (see
// [usedRingSize]) for the given queue size.
func newUsedRing(queueSize int, mem []byte) *UsedRing {
ringSize := usedRingSize(queueSize)
if len(mem) != ringSize {
panic(fmt.Sprintf("memory size (%v) does not match required size "+
"for used ring: %v", len(mem), ringSize))
}
r := UsedRing{
initialized: true,
flags: (*usedRingFlag)(unsafe.Pointer(&mem[0])),
ringIndex: (*uint16)(unsafe.Pointer(&mem[2])),
ring: unsafe.Slice((*UsedElement)(unsafe.Pointer(&mem[4])), queueSize),
availableEvent: (*uint16)(unsafe.Pointer(&mem[ringSize-2])),
}
r.lastIndex = *r.ringIndex
return &r
}
// Address returns the pointer to the beginning of the ring in memory.
// Do not modify the memory directly to not interfere with this implementation.
func (r *UsedRing) Address() uintptr {
if !r.initialized {
panic("used ring is not initialized")
}
return uintptr(unsafe.Pointer(r.flags))
}
// take returns new [UsedElement]s that the device put into the ring and that
// weren't already returned by a previous call to this method. When maxToTake
// is positive, at most that many elements are returned; the first return value
// is the number of elements still pending on the ring.
// Note: this method is no longer synchronized; the lock was removed.
func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
//r.mu.Lock()
//defer r.mu.Unlock()
ringIndex := *r.ringIndex
if ringIndex == r.lastIndex {
// Nothing new.
return 0, nil
}
// Calculate the number of new used elements that we can read from the ring.
// Both indexes are free-running uint16 counters, so the modular uint16
// subtraction handles wrap-around; the result can never be negative.
count := int(ringIndex - r.lastIndex)
stillNeedToTake := 0
if maxToTake > 0 {
stillNeedToTake = count - maxToTake
if stillNeedToTake < 0 {
stillNeedToTake = 0
}
count = min(count, maxToTake)
}
// The number of new elements can never exceed the queue size.
if count > len(r.ring) {
panic("used ring contains more new elements than the ring is long")
}
elems := make([]UsedElement, count)
for i := range count {
elems[i] = r.ring[r.lastIndex%uint16(len(r.ring))]
r.lastIndex++
}
return stillNeedToTake, elems
}
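// Worked example of the wrap-around arithmetic above (illustrative, not part
// of the original source):
//
//	r.lastIndex == 65534, *r.ringIndex == 1
//	count = int(uint16(1 - 65534)) = int(uint16(3)) = 3
//
// The three new elements live at ring slots 65534%len, 65535%len and 0%len.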
func (r *UsedRing) takeOne() (uint16, bool) {
//r.mu.Lock()
//defer r.mu.Unlock()
ringIndex := *r.ringIndex
if ringIndex == r.lastIndex {
// Nothing new.
return 0xffff, false
}
// Calculate the number of new used elements that we can read from the ring.
// Modular uint16 subtraction handles wrap-around; the result can never be
// negative.
count := int(ringIndex - r.lastIndex)
// The number of new elements can never exceed the queue size.
if count > len(r.ring) {
panic("used ring contains more new elements than the ring is long")
}
if count == 0 {
return 0xffff, false
}
out := r.ring[r.lastIndex%uint16(len(r.ring))].GetHead()
r.lastIndex++
return out, true
}
// InitOfferSingle is only used to pre-fill the used queue at startup, and should not be used if the device is running!
func (r *UsedRing) InitOfferSingle(x uint16, size int) {
//always called under lock
//r.mu.Lock()
//defer r.mu.Unlock()
// Add the descriptor chain head to the ring.
// The 16-bit ring index may overflow. This is expected and is not an
// issue because the size of the ring array (which equals the queue
// size) is always a power of 2 and smaller than the highest possible
// 16-bit value.
insertIndex := int(*r.ringIndex) % len(r.ring)
r.ring[insertIndex] = UsedElement{
DescriptorIndex: uint32(x),
Length: uint32(size),
}
// Increase the ring index by the number of descriptor chains added to the ring.
*r.ringIndex += 1
}
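// Illustrative sketch (not part of the original source): a driver that hands
// ownership of all buffers to the consumer loop at startup can pre-fill the
// used ring before the device runs, making every descriptor immediately
// recyclable. queueSize, ring and itemSize are placeholders here.
//
//	for i := 0; i < queueSize; i++ {
//	    ring.InitOfferSingle(uint16(i), itemSize)
//	}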

View File

@@ -1,136 +0,0 @@
package virtqueue
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestUsedRing_MemoryLayout(t *testing.T) {
const queueSize = 2
memory := make([]byte, usedRingSize(queueSize))
r := newUsedRing(queueSize, memory)
*r.flags = 0x01ff
*r.ringIndex = 1
r.ring[0] = UsedElement{
DescriptorIndex: 0x0123,
Length: 0x4567,
}
r.ring[1] = UsedElement{
DescriptorIndex: 0x89ab,
Length: 0xcdef,
}
assert.Equal(t, []byte{
0xff, 0x01,
0x01, 0x00,
0x23, 0x01, 0x00, 0x00,
0x67, 0x45, 0x00, 0x00,
0xab, 0x89, 0x00, 0x00,
0xef, 0xcd, 0x00, 0x00,
0x00, 0x00,
}, memory)
}
//func TestUsedRing_Take(t *testing.T) {
// const queueSize = 8
//
// tests := []struct {
// name string
// ring []UsedElement
// ringIndex uint16
// lastIndex uint16
// expected []UsedElement
// }{
// {
// name: "nothing new",
// ring: []UsedElement{
// {DescriptorIndex: 1},
// {DescriptorIndex: 2},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// {},
// {},
// {},
// {},
// },
// ringIndex: 4,
// lastIndex: 4,
// expected: nil,
// },
// {
// name: "no overflow",
// ring: []UsedElement{
// {DescriptorIndex: 1},
// {DescriptorIndex: 2},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// {},
// {},
// {},
// {},
// },
// ringIndex: 4,
// lastIndex: 1,
// expected: []UsedElement{
// {DescriptorIndex: 2},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// },
// },
// {
// name: "ring overflow",
// ring: []UsedElement{
// {DescriptorIndex: 9},
// {DescriptorIndex: 10},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// {DescriptorIndex: 5},
// {DescriptorIndex: 6},
// {DescriptorIndex: 7},
// {DescriptorIndex: 8},
// },
// ringIndex: 10,
// lastIndex: 7,
// expected: []UsedElement{
// {DescriptorIndex: 8},
// {DescriptorIndex: 9},
// {DescriptorIndex: 10},
// },
// },
// {
// name: "index overflow",
// ring: []UsedElement{
// {DescriptorIndex: 9},
// {DescriptorIndex: 10},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// {DescriptorIndex: 5},
// {DescriptorIndex: 6},
// {DescriptorIndex: 7},
// {DescriptorIndex: 8},
// },
// ringIndex: 2,
// lastIndex: 65535,
// expected: []UsedElement{
// {DescriptorIndex: 8},
// {DescriptorIndex: 9},
// {DescriptorIndex: 10},
// },
// },
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// memory := make([]byte, usedRingSize(queueSize))
// r := newUsedRing(queueSize, memory)
//
// copy(r.ring, tt.ring)
// *r.ringIndex = tt.ringIndex
// r.lastIndex = tt.lastIndex
//
// assert.Equal(t, tt.expected, r.take())
// })
// }
//}

View File

@@ -1,70 +0,0 @@
package packet
import (
"github.com/slackhq/nebula/util/virtio"
"golang.org/x/sys/unix"
)
type OutPacket struct {
Segments [][]byte
SegmentPayloads [][]byte
SegmentHeaders [][]byte
SegmentIDs []uint16
//todo virtio header?
SegSize int
SegCounter int
Valid bool
wasSegmented bool
Scratch []byte
}
func NewOut() *OutPacket {
out := new(OutPacket)
out.Segments = make([][]byte, 0, 64)
out.SegmentHeaders = make([][]byte, 0, 64)
out.SegmentPayloads = make([][]byte, 0, 64)
out.SegmentIDs = make([]uint16, 0, 64)
out.Scratch = make([]byte, Size)
return out
}
func (pkt *OutPacket) Reset() {
pkt.Segments = pkt.Segments[:0]
pkt.SegmentPayloads = pkt.SegmentPayloads[:0]
pkt.SegmentHeaders = pkt.SegmentHeaders[:0]
pkt.SegmentIDs = pkt.SegmentIDs[:0]
pkt.SegSize = 0
pkt.Valid = false
pkt.wasSegmented = false
}
func (pkt *OutPacket) UseSegment(segID uint16, seg []byte, isV6 bool) int {
pkt.Valid = true
pkt.SegmentIDs = append(pkt.SegmentIDs, segID)
pkt.Segments = append(pkt.Segments, seg) //todo do we need this?
vhdr := virtio.NetHdr{ //todo
Flags: unix.VIRTIO_NET_HDR_F_DATA_VALID,
GSOType: unix.VIRTIO_NET_HDR_GSO_NONE,
HdrLen: 0,
GSOSize: 0,
CsumStart: 0,
CsumOffset: 0,
NumBuffers: 0,
}
hdr := seg[0 : virtio.NetHdrSize+14] // virtio-net header plus 14-byte Ethernet header
_ = vhdr.Encode(hdr)
if isV6 {
hdr[virtio.NetHdrSize+14-2] = 0x86
hdr[virtio.NetHdrSize+14-1] = 0xdd
} else {
hdr[virtio.NetHdrSize+14-2] = 0x08
hdr[virtio.NetHdrSize+14-1] = 0x00
}
pkt.SegmentHeaders = append(pkt.SegmentHeaders, hdr)
pkt.SegmentPayloads = append(pkt.SegmentPayloads, seg[virtio.NetHdrSize+14:])
return len(pkt.SegmentIDs) - 1
}

View File

@@ -1,119 +0,0 @@
package packet
import (
"encoding/binary"
"iter"
"net/netip"
"slices"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
const Size = 0xffff
type Packet struct {
Payload []byte
Control []byte
Name []byte
SegSize int
//todo should this hold out as well?
OutLen int
wasSegmented bool
isV4 bool
}
func New(isV4 bool) *Packet {
return &Packet{
Payload: make([]byte, Size),
Control: make([]byte, unix.CmsgSpace(2)),
Name: make([]byte, unix.SizeofSockaddrInet6),
isV4: isV4,
}
}
func (p *Packet) AddrPort() netip.AddrPort {
var ip netip.Addr
// It's ok to skip the ok check here: slicing is the only possible error and
// it would panic anyway.
if p.isV4 {
ip, _ = netip.AddrFromSlice(p.Name[4:8])
} else {
ip, _ = netip.AddrFromSlice(p.Name[8:24])
}
return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(p.Name[2:4]))
}
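// The offsets above follow the Linux sockaddr layout that recvmsg writes into
// Name (illustrative summary, not part of the original source):
//
//	sockaddr_in  (16 bytes): family [0:2], port [2:4] big-endian, addr [4:8]
//	sockaddr_in6 (28 bytes): family [0:2], port [2:4] big-endian,
//	                         flowinfo [4:8], addr [8:24], scope_id [24:28]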
func (p *Packet) updateCtrl(ctrlLen int) {
p.SegSize = len(p.Payload)
p.wasSegmented = false
if ctrlLen == 0 {
return
}
if len(p.Control) == 0 {
return
}
cmsgs, err := unix.ParseSocketControlMessage(p.Control)
if err != nil {
return // Ignore malformed control messages.
}
for _, c := range cmsgs {
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
p.wasSegmented = true
p.SegSize = int(binary.LittleEndian.Uint16(c.Data[:2]))
return
}
}
}
// Update sets a Packet into "just received, not processed" state
func (p *Packet) Update(ctrlLen int) {
p.OutLen = -1
p.updateCtrl(ctrlLen)
}
func (p *Packet) SetSegSizeForTX() {
p.SegSize = len(p.Payload)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&p.Control[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
hdr.SetLen(syscall.CmsgLen(2))
binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize))
}
func (p *Packet) CompatibleForSegmentationWith(otherP *Packet, currentTotalSize int) bool {
//same dest
if !slices.Equal(p.Name, otherP.Name) {
return false
}
//don't get too big
if len(p.Payload)+currentTotalSize >= 0xffff {
return false
}
//same body len
//todo allow single different size at end
if len(p.Payload) != len(otherP.Payload) {
return false //todo technically you can cram one extra in
}
return true
}
func (p *Packet) Segments() iter.Seq[[]byte] {
return func(yield func([]byte) bool) {
for offset := 0; offset < len(p.Payload); offset += p.SegSize {
end := offset + p.SegSize
if end > len(p.Payload) {
end = len(p.Payload)
}
if !yield(p.Payload[offset:end]) {
return
}
}
}
}
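// Illustrative use of the iterator (not part of the original source): after a
// GRO-coalesced receive, Payload holds several packets of SegSize bytes each,
// with only the final segment allowed to be shorter.
//
//	for seg := range p.Segments() {
//	    process(seg) // hypothetical per-packet handler
//	}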

View File

@@ -1,37 +0,0 @@
package packet
import (
"github.com/slackhq/nebula/util/virtio"
)
type VirtIOPacket struct {
Payload []byte
Header virtio.NetHdr
Chains []uint16
ChainRefs [][]byte
// OfferDescriptorChains(chains []uint16, kick bool) error
}
func NewVIO() *VirtIOPacket {
out := new(VirtIOPacket)
out.Payload = nil
out.ChainRefs = make([][]byte, 0, 4)
out.Chains = make([]uint16, 0, 8)
return out
}
func (v *VirtIOPacket) Reset() {
v.Payload = nil
v.ChainRefs = v.ChainRefs[:0]
v.Chains = v.Chains[:0]
}
type VirtIOTXPacket struct {
VirtIOPacket
}
func NewVIOTX(isV4 bool) *VirtIOTXPacket {
out := new(VirtIOTXPacket)
out.VirtIOPacket = *NewVIO()
return out
}

View File

@@ -4,13 +4,13 @@ import (
"net/netip"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/packet"
)
const MTU = 9001
type EncReader func(
[]*packet.Packet,
addr netip.AddrPort,
payload []byte,
)
type Conn interface {
@@ -19,8 +19,6 @@ type Conn interface {
ListenOut(r EncReader)
WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C)
Prep(pkt *packet.Packet, addr netip.AddrPort) error
WriteBatch(pkt []*packet.Packet) (int, error)
Close() error
}

View File

@@ -14,22 +14,22 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/packet"
"golang.org/x/sys/unix"
)
const iovMax = 128 // No unix constant for this; see IOV_MAX in limits.h.
// TODO: raise this toward 1024? We currently hit errors around ~130 entries.
type StdConn struct {
sysFd int
isV4 bool
l *logrus.Logger
batch int
enableGRO bool
msgs []rawMessage
iovs [][]iovec
}
func maybeIPV4(ip net.IP) (net.IP, bool) {
ip4 := ip.To4()
if ip4 != nil {
return ip4, true
}
return ip, false
}
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
@@ -69,20 +69,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
return nil, fmt.Errorf("unable to bind to socket: %s", err)
}
const batchSize = 8192
msgs := make([]rawMessage, 0, batchSize) //todo configure
iovs := make([][]iovec, batchSize)
for i := range iovs {
iovs[i] = make([]iovec, iovMax)
}
return &StdConn{
sysFd: fd,
isV4: ip.Is4(),
l: l,
batch: batch,
msgs: msgs,
iovs: iovs,
}, err
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
}
func (u *StdConn) Rebind() error {
@@ -132,7 +119,9 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
}
func (u *StdConn) ListenOut(r EncReader) {
msgs, packets := u.PrepareRawMessages(u.batch, u.isV4)
var ip netip.Addr
msgs, buffers, names := u.PrepareRawMessages(u.batch)
read := u.ReadMulti
if u.batch == 1 {
read = u.ReadSingle
@@ -146,12 +135,13 @@ func (u *StdConn) ListenOut(r EncReader) {
}
for i := 0; i < n; i++ {
packets[i].Payload = packets[i].Payload[:msgs[i].Len]
packets[i].Update(getRawMessageControlLen(&msgs[i]))
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
if u.isV4 {
ip, _ = netip.AddrFromSlice(names[i][4:8])
} else {
ip, _ = netip.AddrFromSlice(names[i][8:24])
}
r(packets[:n])
for i := 0; i < n; i++ { //todo reset this in prev loop, but this makes debug ez
msgs[i].Hdr.Controllen = uint64(unix.CmsgSpace(2))
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
}
}
}
@@ -204,147 +194,6 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
return u.writeTo6(b, ip)
}
func (u *StdConn) WriteToBatch(b []byte, ip netip.AddrPort) error {
if u.isV4 {
return u.writeTo4(b, ip)
}
return u.writeTo6(b, ip)
}
func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
nl, err := u.encodeSockaddr(pkt.Name, addr)
if err != nil {
return err
}
pkt.Name = pkt.Name[:nl]
pkt.OutLen = len(pkt.Payload)
return nil
}
func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) {
if len(pkts) == 0 {
return 0, nil
}
u.msgs = u.msgs[:0]
sent := 0
var mostRecentPkt *packet.Packet
mostRecentPktSize := 0
idx := 0
for _, pkt := range pkts {
if len(pkt.Payload) == 0 || pkt.OutLen == -1 {
sent++
continue
}
lastIdx := idx - 1
if mostRecentPkt != nil && pkt.CompatibleForSegmentationWith(mostRecentPkt, mostRecentPktSize) && u.msgs[lastIdx].Hdr.Iovlen < iovMax {
u.msgs[lastIdx].Hdr.Controllen = uint64(len(mostRecentPkt.Control))
u.msgs[lastIdx].Hdr.Control = &mostRecentPkt.Control[0]
u.iovs[lastIdx][u.msgs[lastIdx].Hdr.Iovlen].Base = &pkt.Payload[0]
u.iovs[lastIdx][u.msgs[lastIdx].Hdr.Iovlen].Len = uint64(len(pkt.Payload))
u.msgs[lastIdx].Hdr.Iovlen++
mostRecentPktSize += len(pkt.Payload)
mostRecentPkt.SetSegSizeForTX()
} else {
u.msgs = append(u.msgs, rawMessage{})
u.iovs[idx][0] = iovec{
Base: &pkt.Payload[0],
Len: uint64(len(pkt.Payload)),
}
msg := &u.msgs[idx]
iov := &u.iovs[idx][0]
idx++
msg.Hdr.Iov = iov
msg.Hdr.Iovlen = 1
setRawMessageControl(msg, nil)
msg.Hdr.Flags = 0
msg.Hdr.Name = &pkt.Name[0]
msg.Hdr.Namelen = uint32(len(pkt.Name))
mostRecentPkt = pkt
mostRecentPktSize = len(pkt.Payload)
}
}
if len(u.msgs) == 0 {
return sent, nil
}
offset := 0
for offset < len(u.msgs) {
n, _, errno := unix.Syscall6(
unix.SYS_SENDMMSG,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&u.msgs[offset])),
uintptr(len(u.msgs)-offset),
0,
0,
0,
)
if errno != 0 {
if errno == unix.EINTR {
continue
}
u.l.WithFields(logrus.Fields{
"errno": errno,
"idx": idx,
"len": len(u.msgs),
"deets": fmt.Sprintf("%+v", u.msgs),
"lastIOV": fmt.Sprintf("%+v", u.iovs[len(u.msgs)-1][u.msgs[len(u.msgs)-1].Hdr.Iovlen-1]),
}).Error("failed to send message")
return sent + offset, &net.OpError{Op: "sendmmsg", Err: errno}
}
if n == 0 {
break
}
offset += int(n)
}
return sent + len(u.msgs), nil
}
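When consecutive packets share a destination, `WriteBatch` folds them into one `msghdr` and relies on `SetSegSizeForTX` to attach a control message, presumably a `UDP_SEGMENT` cmsg so the kernel splits the merged payload back into individual datagrams. A hedged sketch of building such a cmsg, assuming a little-endian host and reusing the `setCmsgLen` helper defined in this package:

```go
// Hedged sketch: build a UDP_SEGMENT control message carrying segSize.
// Offsets follow the kernel cmsg layout; PutUint16 assumes a little-endian host.
func segmentCmsg(segSize uint16) []byte {
	b := make([]byte, unix.CmsgSpace(2))
	h := (*unix.Cmsghdr)(unsafe.Pointer(&b[0]))
	h.Level = unix.SOL_UDP
	h.Type = unix.UDP_SEGMENT
	setCmsgLen(h, unix.CmsgLen(2))
	binary.LittleEndian.PutUint16(b[unix.CmsgLen(0):], segSize)
	return b
}
```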
func (u *StdConn) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
if u.isV4 {
if !addr.Addr().Is4() {
return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
}
var sa unix.RawSockaddrInet4
sa.Family = unix.AF_INET
sa.Addr = addr.Addr().As4()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
size := unix.SizeofSockaddrInet4
copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
return uint32(size), nil
}
var sa unix.RawSockaddrInet6
sa.Family = unix.AF_INET6
sa.Addr = addr.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
size := unix.SizeofSockaddrInet6
copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
return uint32(size), nil
}
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
@@ -445,27 +294,6 @@ func (u *StdConn) ReloadConfig(c *config.C) {
u.l.WithError(err).Error("Failed to set listen.so_mark")
}
}
u.configureGRO(true)
}
func (u *StdConn) configureGRO(enable bool) {
if enable == u.enableGRO {
return
}
if enable {
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
u.l.WithError(err).Warn("Failed to enable UDP GRO")
return
}
u.enableGRO = true
u.l.Info("UDP GRO enabled")
} else {
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT {
u.l.WithError(err).Warn("Failed to disable UDP GRO")
}
u.enableGRO = false
}
}
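`configureGRO` only toggles the socket option; the coalesced segment size arrives per-message in the recvmsg control data. A hedged sketch of extracting it (the helper name and the little-endian assumption are ours, not part of this change):

```go
// Illustrative helper: pull the GRO segment size out of recvmsg control bytes.
func groSegSize(control []byte) (int, bool) {
	cmsgs, err := unix.ParseSocketControlMessage(control)
	if err != nil {
		return 0, false
	}
	for _, c := range cmsgs {
		if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 4 {
			return int(binary.LittleEndian.Uint32(c.Data[:4])), true // assumes a little-endian host
		}
	}
	return 0, false
}
```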
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {

View File

@@ -7,7 +7,6 @@
package udp
import (
"github.com/slackhq/nebula/packet"
"golang.org/x/sys/unix"
)
@@ -34,59 +33,25 @@ type rawMessage struct {
Pad0 [4]byte
}
func setRawMessageControl(msg *rawMessage, buf []byte) {
if len(buf) == 0 {
msg.Hdr.Control = nil
msg.Hdr.Controllen = 0
return
}
msg.Hdr.Control = &buf[0]
msg.Hdr.Controllen = uint64(len(buf))
}
func getRawMessageControlLen(msg *rawMessage) int {
return int(msg.Hdr.Controllen)
}
func setCmsgLen(h *unix.Cmsghdr, l int) {
h.Len = uint64(l)
}
func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.Packet) {
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
msgs := make([]rawMessage, n)
packets := make([]*packet.Packet, n)
buffers := make([][]byte, n)
names := make([][]byte, n)
for i := range msgs {
packets[i] = packet.New(isV4)
buffers[i] = make([]byte, MTU)
names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{
{Base: &packets[i].Payload[0], Len: uint64(packet.Size)},
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
}
msgs[i].Hdr.Iov = &vs[0]
msgs[i].Hdr.Iovlen = uint64(len(vs))
msgs[i].Hdr.Name = &packets[i].Name[0]
msgs[i].Hdr.Namelen = uint32(len(packets[i].Name))
if u.enableGRO {
msgs[i].Hdr.Control = &packets[i].Control[0]
msgs[i].Hdr.Controllen = uint64(len(packets[i].Control))
} else {
msgs[i].Hdr.Control = nil
msgs[i].Hdr.Controllen = 0
}
msgs[i].Hdr.Name = &names[i][0]
msgs[i].Hdr.Namelen = uint32(len(names[i]))
}
return msgs, packets
}
func setIovecSlice(iov *iovec, b []byte) {
if len(b) == 0 {
iov.Base = nil
iov.Len = 0
return
}
iov.Base = &b[0]
iov.Len = uint64(len(b))
return msgs, buffers, names
}
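The prepared `rawMessage` slice is sized for one syscall's worth of datagrams. A sketch of how a reader typically consumes it with `recvmmsg` (the flag choice is an assumption based on common batched-read loops, not code from this change):

```go
// Hedged sketch: fill the prepared msgs with one recvmmsg call.
n, _, errno := unix.Syscall6(unix.SYS_RECVMMSG,
	uintptr(u.sysFd),
	uintptr(unsafe.Pointer(&msgs[0])),
	uintptr(len(msgs)),
	uintptr(unix.MSG_WAITFORONE), // block for the first datagram only
	0, 0)
if errno != 0 {
	// handle EINTR and friends; on success msgs[i].Len holds each datagram's size
}
_ = n
```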

View File

@@ -1,3 +0,0 @@
// Package virtio contains some generic types and concepts related to the virtio
// protocol.
package virtio

View File

@@ -1,136 +0,0 @@
package virtio
// Feature contains feature bits that describe a virtio device or driver.
type Feature uint64
// Device-independent feature bits.
//
// Source: https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-6600006
const (
// FeatureIndirectDescriptors indicates that the driver can use descriptors
// with an additional layer of indirection.
FeatureIndirectDescriptors Feature = 1 << 28
// FeatureVersion1 indicates compliance with version 1.0 of the virtio
// specification.
FeatureVersion1 Feature = 1 << 32
)
// Feature bits for networking devices.
//
// Source: https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-2200003
const (
// FeatureNetDeviceCsum indicates that the device can handle packets with
// partial checksum (checksum offload).
FeatureNetDeviceCsum Feature = 1 << 0
// FeatureNetDriverCsum indicates that the driver can handle packets with
// partial checksum.
FeatureNetDriverCsum Feature = 1 << 1
// FeatureNetCtrlDriverOffloads indicates support for dynamic offload state
// reconfiguration.
FeatureNetCtrlDriverOffloads Feature = 1 << 2
// FeatureNetMTU indicates that the device reports a maximum MTU value.
FeatureNetMTU Feature = 1 << 3
// FeatureNetMAC indicates that the device provides a MAC address.
FeatureNetMAC Feature = 1 << 5
// FeatureNetDriverTSO4 indicates that the driver supports the TCP
// segmentation offload for received IPv4 packets.
FeatureNetDriverTSO4 Feature = 1 << 7
// FeatureNetDriverTSO6 indicates that the driver supports the TCP
// segmentation offload for received IPv6 packets.
FeatureNetDriverTSO6 Feature = 1 << 8
// FeatureNetDriverECN indicates that the driver supports the TCP
// segmentation offload with ECN for received packets.
FeatureNetDriverECN Feature = 1 << 9
// FeatureNetDriverUFO indicates that the driver supports the UDP
// fragmentation offload for received packets.
FeatureNetDriverUFO Feature = 1 << 10
// FeatureNetDeviceTSO4 indicates that the device supports the TCP
// segmentation offload for received IPv4 packets.
FeatureNetDeviceTSO4 Feature = 1 << 11
// FeatureNetDeviceTSO6 indicates that the device supports the TCP
// segmentation offload for received IPv6 packets.
FeatureNetDeviceTSO6 Feature = 1 << 12
// FeatureNetDeviceECN indicates that the device supports the TCP
// segmentation offload with ECN for received packets.
FeatureNetDeviceECN Feature = 1 << 13
// FeatureNetDeviceUFO indicates that the device supports the UDP
// fragmentation offload for received packets.
FeatureNetDeviceUFO Feature = 1 << 14
// FeatureNetMergeRXBuffers indicates that the driver can handle merged
// receive buffers.
// When this feature is negotiated, devices may merge multiple descriptor
// chains together to transport large received packets. [NetHdr.NumBuffers]
// will then contain the number of merged descriptor chains.
FeatureNetMergeRXBuffers Feature = 1 << 15
// FeatureNetStatus indicates that the device configuration status field is
// available.
FeatureNetStatus Feature = 1 << 16
// FeatureNetCtrlVQ indicates that a control channel virtqueue is
// available.
FeatureNetCtrlVQ Feature = 1 << 17
// FeatureNetCtrlRX indicates support for RX mode control (e.g. promiscuous
// or all-multicast) for packet receive filtering.
FeatureNetCtrlRX Feature = 1 << 18
// FeatureNetCtrlVLAN indicates support for VLAN filtering through the
// control channel.
FeatureNetCtrlVLAN Feature = 1 << 19
// FeatureNetDriverAnnounce indicates that the driver can send gratuitous
// packets.
FeatureNetDriverAnnounce Feature = 1 << 21
// FeatureNetMQ indicates that the device supports multiqueue with automatic
// receive steering.
FeatureNetMQ Feature = 1 << 22
// FeatureNetCtrlMACAddr indicates that the MAC address can be set through
// the control channel.
FeatureNetCtrlMACAddr Feature = 1 << 23
// FeatureNetDeviceUSO indicates that the device supports the UDP
// segmentation offload for received packets.
FeatureNetDeviceUSO Feature = 1 << 56
// FeatureNetHashReport indicates that the device can report a per-packet
// hash value and type.
FeatureNetHashReport Feature = 1 << 57
// FeatureNetDriverHdrLen indicates that the driver can provide the exact
// header length value (see [NetHdr.HdrLen]).
// Devices may benefit from knowing the exact header length.
FeatureNetDriverHdrLen Feature = 1 << 59
// FeatureNetRSS indicates that the device supports RSS (receive-side
// scaling) with configurable hash parameters.
FeatureNetRSS Feature = 1 << 60
// FeatureNetRSCExt indicates that the device can process duplicated ACKs
// and report the number of coalesced segments and duplicated ACKs.
FeatureNetRSCExt Feature = 1 << 61
// FeatureNetStandby indicates that the device may act as a standby for a
// primary device with the same MAC address.
FeatureNetStandby Feature = 1 << 62
// FeatureNetSpeedDuplex indicates that the device can report link speed and
// duplex mode.
FeatureNetSpeedDuplex Feature = 1 << 63
)
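Feature negotiation in virtio is a bitwise intersection: the driver masks the device's offered set with its own supported set. A minimal sketch using these constants (the surrounding function is illustrative, not part of the package):

```go
// Illustrative negotiation: keep only the bits both sides understand.
func negotiate(device, driver Feature) (Feature, bool) {
	agreed := device & driver
	// FeatureVersion1 is mandatory for a modern (non-legacy) device.
	return agreed, agreed&FeatureVersion1 != 0
}
```

A caller would then branch on the agreed set, e.g. only trusting `NetHdr.NumBuffers` when `agreed&FeatureNetMergeRXBuffers != 0`.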

View File

@@ -1,77 +0,0 @@
package virtio
import (
"errors"
"unsafe"
"golang.org/x/sys/unix"
)
// Workaround to make Go doc links work.
var _ unix.Errno
// NetHdrSize is the number of bytes needed to store a [NetHdr] in memory.
const NetHdrSize = 12
// ErrNetHdrBufferTooSmall is returned when a buffer is too small to fit a
// virtio_net_hdr.
var ErrNetHdrBufferTooSmall = errors.New("the buffer is too small to fit a virtio_net_hdr")
// NetHdr defines the virtio_net_hdr as described by the virtio specification.
type NetHdr struct {
// Flags that describe the packet.
// Possible values are:
// - [unix.VIRTIO_NET_HDR_F_NEEDS_CSUM]
// - [unix.VIRTIO_NET_HDR_F_DATA_VALID]
// - [unix.VIRTIO_NET_HDR_F_RSC_INFO]
Flags uint8
// GSOType contains the type of segmentation offload that should be used for
// the packet.
// Possible values are:
// - [unix.VIRTIO_NET_HDR_GSO_NONE]
// - [unix.VIRTIO_NET_HDR_GSO_TCPV4]
// - [unix.VIRTIO_NET_HDR_GSO_UDP]
// - [unix.VIRTIO_NET_HDR_GSO_TCPV6]
// - [unix.VIRTIO_NET_HDR_GSO_UDP_L4]
// - [unix.VIRTIO_NET_HDR_GSO_ECN]
GSOType uint8
// HdrLen contains the length of the headers that need to be replicated by
// segmentation offloads. It's the number of bytes from the beginning of the
// packet to the beginning of the transport payload.
// Only used when [FeatureNetDriverHdrLen] is negotiated.
HdrLen uint16
// GSOSize contains the maximum size of each segmented packet beyond the
// header (payload size). In case of TCP, this is the MSS.
GSOSize uint16
// CsumStart contains the offset within the packet from which on the
// checksum should be computed.
CsumStart uint16
// CsumOffset specifies how many bytes after [NetHdr.CsumStart] the computed
// 16-bit checksum should be inserted.
CsumOffset uint16
// NumBuffers contains the number of merged descriptor chains when
// [FeatureNetMergeRXBuffers] is negotiated.
// This field is only used for packets received by the driver and should be
// zero for transmitted packets.
NumBuffers uint16
}
// Decode decodes the [NetHdr] from the given byte slice. The slice must contain
// at least [NetHdrSize] bytes.
func (v *NetHdr) Decode(data []byte) error {
if len(data) < NetHdrSize {
return ErrNetHdrBufferTooSmall
}
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), NetHdrSize), data[:NetHdrSize])
return nil
}
// Encode encodes the [NetHdr] into the given byte slice. The slice must have
// room for at least [NetHdrSize] bytes.
func (v *NetHdr) Encode(data []byte) error {
if len(data) < NetHdrSize {
return ErrNetHdrBufferTooSmall
}
copy(data[:NetHdrSize], unsafe.Slice((*byte)(unsafe.Pointer(v)), NetHdrSize))
return nil
}
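As a usage sketch, a UDP GSO transmit over Ethernet/IPv4 would be described with values like those in the package test below (14-byte Ethernet + 20-byte IPv4 + 8-byte UDP headers):

```go
// Hedged example: describe a UDP_L4 GSO packet to the device.
h := NetHdr{
	Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
	GSOType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
	HdrLen:     42,   // 14 Ethernet + 20 IPv4 + 8 UDP
	GSOSize:    1472, // payload bytes per segment (1500 MTU minus 28 header bytes)
	CsumStart:  34,   // checksum computed from the UDP header onward
	CsumOffset: 6,    // UDP checksum field sits 6 bytes into the UDP header
}
buf := make([]byte, NetHdrSize)
_ = h.Encode(buf) // buf is prepended to the packet handed to the device
```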

View File

@@ -1,43 +0,0 @@
package virtio
import (
"testing"
"unsafe"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
func TestNetHdr_Size(t *testing.T) {
assert.EqualValues(t, NetHdrSize, unsafe.Sizeof(NetHdr{}))
}
func TestNetHdr_Encoding(t *testing.T) {
vnethdr := NetHdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
GSOType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
HdrLen: 42,
GSOSize: 1472,
CsumStart: 34,
CsumOffset: 6,
NumBuffers: 16,
}
buf := make([]byte, NetHdrSize)
require.NoError(t, vnethdr.Encode(buf))
assert.Equal(t, []byte{
0x01, 0x05,
0x2a, 0x00,
0xc0, 0x05,
0x22, 0x00,
0x06, 0x00,
0x10, 0x00,
}, buf)
var decoded NetHdr
require.NoError(t, decoded.Decode(buf))
assert.Equal(t, vnethdr, decoded)
}