mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 16:34:25 +01:00
this is awful, but also it's about 20% better
This commit is contained in:
45
interface.go
45
interface.go
@@ -86,7 +86,7 @@ type Interface struct {
|
|||||||
conntrackCacheTimeout time.Duration
|
conntrackCacheTimeout time.Duration
|
||||||
|
|
||||||
writers []udp.Conn
|
writers []udp.Conn
|
||||||
readers []io.ReadWriteCloser
|
readers []overlay.TunDev
|
||||||
|
|
||||||
metricHandshakes metrics.Histogram
|
metricHandshakes metrics.Histogram
|
||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
@@ -177,7 +177,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
routines: c.routines,
|
routines: c.routines,
|
||||||
version: c.version,
|
version: c.version,
|
||||||
writers: make([]udp.Conn, c.routines),
|
writers: make([]udp.Conn, c.routines),
|
||||||
readers: make([]io.ReadWriteCloser, c.routines),
|
readers: make([]overlay.TunDev, c.routines),
|
||||||
myVpnNetworks: cs.myVpnNetworks,
|
myVpnNetworks: cs.myVpnNetworks,
|
||||||
myVpnNetworksTable: cs.myVpnNetworksTable,
|
myVpnNetworksTable: cs.myVpnNetworksTable,
|
||||||
myVpnAddrs: cs.myVpnAddrs,
|
myVpnAddrs: cs.myVpnAddrs,
|
||||||
@@ -225,7 +225,7 @@ func (f *Interface) activate() {
|
|||||||
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
||||||
|
|
||||||
// Prepare n tun queues
|
// Prepare n tun queues
|
||||||
var reader io.ReadWriteCloser = f.inside
|
var reader overlay.TunDev = f.inside
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
reader, err = f.inside.NewMultiQueueReader()
|
reader, err = f.inside.NewMultiQueueReader()
|
||||||
@@ -254,25 +254,52 @@ func (f *Interface) run() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenOut(i int) {
|
func (f *Interface) listenOut(q int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
var li udp.Conn
|
var li udp.Conn
|
||||||
if i > 0 {
|
if q > 0 {
|
||||||
li = f.writers[i]
|
li = f.writers[q]
|
||||||
} else {
|
} else {
|
||||||
li = f.outside
|
li = f.outside
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const batch = 64 //todo
|
||||||
|
|
||||||
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
lhh := f.lightHouse.NewRequestHandler()
|
lhh := f.lightHouse.NewRequestHandler()
|
||||||
plaintext := make([]byte, udp.MTU)
|
plaintexts := make([][]byte, batch)
|
||||||
|
outNeedsTun := make([]*int, batch)
|
||||||
|
for i := 0; i < batch; i++ {
|
||||||
|
plaintexts[i] = make([]byte, udp.MTU)
|
||||||
|
outNeedsTun[i] = new(int)
|
||||||
|
*outNeedsTun[i] = -1
|
||||||
|
}
|
||||||
|
|
||||||
h := &header.H{}
|
h := &header.H{}
|
||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
toSend := make([][]byte, batch)
|
||||||
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
|
||||||
|
li.ListenOut(func(fromUdpAddrs []netip.AddrPort, payloads [][]byte) {
|
||||||
|
toSend = toSend[:0]
|
||||||
|
for i := range plaintexts {
|
||||||
|
plaintexts[i] = plaintexts[i][:0]
|
||||||
|
}
|
||||||
|
f.readOutsidePacketsMany(fromUdpAddrs, plaintexts, outNeedsTun, payloads, h, fwPacket, lhh, nb, q, ctCache.Get(f.l))
|
||||||
|
for i := range plaintexts {
|
||||||
|
if *outNeedsTun[i] != -1 {
|
||||||
|
toSend = append(toSend, plaintexts[i][:*outNeedsTun[i]])
|
||||||
|
*outNeedsTun[i] = -1
|
||||||
|
//toSendCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//toSend = toSend[:toSendCount]
|
||||||
|
_, err := f.readers[q].WriteMany(toSend)
|
||||||
|
if err != nil {
|
||||||
|
f.l.WithError(err).Error("Failed to write messages")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
241
outside.go
241
outside.go
@@ -216,6 +216,207 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
f.connectionManager.In(hostinfo)
|
f.connectionManager.In(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Interface) readOutsidePacketsMany(ip []netip.AddrPort, out [][]byte, outNeedsTun []*int, packets [][]byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
|
for i, packet := range packets {
|
||||||
|
|
||||||
|
err := h.Parse(packet)
|
||||||
|
if err != nil {
|
||||||
|
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||||
|
if len(packet) > 1 {
|
||||||
|
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||||
|
if ip[i].IsValid() {
|
||||||
|
if f.myVpnNetworksTable.Contains(ip[i].Addr()) {
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var hostinfo *HostInfo
|
||||||
|
// verify if we've seen this index before, otherwise respond to the handshake initiation
|
||||||
|
if h.Type == header.Message && h.Subtype == header.MessageRelay {
|
||||||
|
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
|
||||||
|
} else {
|
||||||
|
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ci *ConnectionState
|
||||||
|
if hostinfo != nil {
|
||||||
|
ci = hostinfo.ConnectionState
|
||||||
|
}
|
||||||
|
|
||||||
|
switch h.Type {
|
||||||
|
case header.Message:
|
||||||
|
// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
|
||||||
|
if !f.handleEncrypted(ci, ip[i], h) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch h.Subtype {
|
||||||
|
case header.MessageNone:
|
||||||
|
out[i] = out[i][:0]
|
||||||
|
if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i][:0], outNeedsTun[i], packet, fwPacket, nb, q, localCache) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case header.MessageRelay:
|
||||||
|
// The entire body is sent as AD, not encrypted.
|
||||||
|
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
|
||||||
|
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
|
||||||
|
// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
|
||||||
|
// which will gracefully fail in the DecryptDanger call.
|
||||||
|
signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()]
|
||||||
|
signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():]
|
||||||
|
out[i], err = hostinfo.ConnectionState.dKey.DecryptDanger(out[i], signedPayload, signatureValue, h.MessageCounter, nb)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Successfully validated the thing. Get rid of the Relay header.
|
||||||
|
signedPayload = signedPayload[header.Len:]
|
||||||
|
// Pull the Roaming parts up here, and return in all call paths.
|
||||||
|
f.handleHostRoaming(hostinfo, ip[i])
|
||||||
|
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
|
||||||
|
f.connectionManager.In(hostinfo)
|
||||||
|
f.connectionManager.RelayUsed(h.RemoteIndex)
|
||||||
|
|
||||||
|
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
|
||||||
|
if !ok {
|
||||||
|
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
|
||||||
|
// its internal mapping. This should never happen.
|
||||||
|
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch relay.Type {
|
||||||
|
case TerminalType:
|
||||||
|
// If I am the target of this relay, process the unwrapped packet
|
||||||
|
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
|
||||||
|
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i][:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
||||||
|
return
|
||||||
|
case ForwardingType:
|
||||||
|
// Find the target HostInfo relay object
|
||||||
|
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If that relay is Established, forward the payload through it
|
||||||
|
if targetRelay.State == Established {
|
||||||
|
switch targetRelay.Type {
|
||||||
|
case ForwardingType:
|
||||||
|
// Forward this packet through the relay tunnel
|
||||||
|
// Find the target HostInfo
|
||||||
|
f.SendVia(targetHI, targetRelay, signedPayload, nb, out[i], false)
|
||||||
|
return
|
||||||
|
case TerminalType:
|
||||||
|
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case header.LightHouse:
|
||||||
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
|
if !f.handleEncrypted(ci, ip[i], h) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i], packet, h, nb)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
||||||
|
WithField("packet", packet).
|
||||||
|
Error("Failed to decrypt lighthouse packet")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lhf.HandleRequest(ip[i], hostinfo.vpnAddrs, d, f)
|
||||||
|
|
||||||
|
// Fallthrough to the bottom to record incoming traffic
|
||||||
|
|
||||||
|
case header.Test:
|
||||||
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
|
if !f.handleEncrypted(ci, ip[i], h) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i], packet, h, nb)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
||||||
|
WithField("packet", packet).
|
||||||
|
Error("Failed to decrypt test packet")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.Subtype == header.TestRequest {
|
||||||
|
// This testRequest might be from TryPromoteBest, so we should roam
|
||||||
|
// to the new IP address before responding
|
||||||
|
f.handleHostRoaming(hostinfo, ip[i])
|
||||||
|
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallthrough to the bottom to record incoming traffic
|
||||||
|
|
||||||
|
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
|
||||||
|
// are unauthenticated
|
||||||
|
|
||||||
|
case header.Handshake:
|
||||||
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
|
f.handshakeManager.HandleIncoming(ip[i], nil, packet, h)
|
||||||
|
return
|
||||||
|
|
||||||
|
case header.RecvError:
|
||||||
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
|
f.handleRecvError(ip[i], h)
|
||||||
|
return
|
||||||
|
|
||||||
|
case header.CloseTunnel:
|
||||||
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
|
if !f.handleEncrypted(ci, ip[i], h) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hostinfo.logger(f.l).WithField("udpAddr", ip).
|
||||||
|
Info("Close tunnel received, tearing down.")
|
||||||
|
|
||||||
|
f.closeTunnel(hostinfo)
|
||||||
|
return
|
||||||
|
|
||||||
|
case header.Control:
|
||||||
|
if !f.handleEncrypted(ci, ip[i], h) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i], packet, h, nb)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
||||||
|
WithField("packet", packet).
|
||||||
|
Error("Failed to decrypt Control packet")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f.relayManager.HandleControlMsg(hostinfo, d, f)
|
||||||
|
|
||||||
|
default:
|
||||||
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
|
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f.handleHostRoaming(hostinfo, ip[i])
|
||||||
|
|
||||||
|
f.connectionManager.In(hostinfo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
|
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
|
||||||
func (f *Interface) closeTunnel(hostInfo *HostInfo) {
|
func (f *Interface) closeTunnel(hostInfo *HostInfo) {
|
||||||
final := f.hostMap.DeleteHostInfo(hostInfo)
|
final := f.hostMap.DeleteHostInfo(hostInfo)
|
||||||
@@ -465,6 +666,46 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out []byte, outNeedsTun *int, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
err = newPacket(out, true, fwPacket)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
||||||
|
Warnf("Error while validating inbound packet")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
|
||||||
|
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
||||||
|
Debugln("dropping out of window packet")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||||
|
if dropReason != nil {
|
||||||
|
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||||
|
// This gives us a buffer to build the reject packet in
|
||||||
|
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
||||||
|
WithField("reason", dropReason).
|
||||||
|
Debugln("dropping inbound packet")
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
f.connectionManager.In(hostinfo)
|
||||||
|
*outNeedsTun = len(out)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
|
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
|||||||
@@ -1,17 +1,16 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Device interface {
|
type Device interface {
|
||||||
io.ReadWriteCloser
|
TunDev
|
||||||
Activate() error
|
Activate() error
|
||||||
Networks() []netip.Prefix
|
Networks() []netip.Prefix
|
||||||
Name() string
|
Name() string
|
||||||
RoutesFor(netip.Addr) routing.Gateways
|
RoutesFor(netip.Addr) routing.Gateways
|
||||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
NewMultiQueueReader() (TunDev, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package overlay
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
@@ -12,6 +13,11 @@ import (
|
|||||||
|
|
||||||
const DefaultMTU = 1300
|
const DefaultMTU = 1300
|
||||||
|
|
||||||
|
type TunDev interface {
|
||||||
|
io.ReadWriteCloser
|
||||||
|
WriteMany([][]byte) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: We may be able to remove routines
|
// TODO: We may be able to remove routines
|
||||||
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
||||||
|
|
||||||
|
|||||||
@@ -105,7 +105,19 @@ func (t *disabledTun) Write(b []byte) (int, error) {
|
|||||||
return len(b), nil
|
return len(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *disabledTun) WriteMany(b [][]byte) (int, error) {
|
||||||
|
out := 0
|
||||||
|
for i := range b {
|
||||||
|
x, err := t.Write(b[i])
|
||||||
|
if err != nil {
|
||||||
|
return out, err
|
||||||
|
}
|
||||||
|
out += x
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -257,7 +257,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (TunDev, error) {
|
||||||
//fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
//fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
//if err != nil {
|
//if err != nil {
|
||||||
// return nil, err
|
// return nil, err
|
||||||
@@ -741,3 +741,24 @@ func (t *tun) Write(b []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
return maximum, nil
|
return maximum, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) WriteMany(b [][]byte) (int, error) {
|
||||||
|
maximum := len(b) //we are RXing
|
||||||
|
|
||||||
|
hdr := virtio.NetHdr{ //todo
|
||||||
|
Flags: unix.VIRTIO_NET_HDR_F_DATA_VALID,
|
||||||
|
GSOType: unix.VIRTIO_NET_HDR_GSO_NONE,
|
||||||
|
HdrLen: 0,
|
||||||
|
GSOSize: 0,
|
||||||
|
CsumStart: 0,
|
||||||
|
CsumOffset: 0,
|
||||||
|
NumBuffers: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := t.vdev.TransmitPackets(hdr, b)
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithError(err).Error("Transmitting packet")
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return maximum, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
|
|||||||
return routing.Gateways{routing.NewGateway(ip, 1)}
|
return routing.Gateways{routing.NewGateway(ip, 1)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (d *UserDevice) NewMultiQueueReader() (TunDev, error) {
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,3 +65,15 @@ func (d *UserDevice) Close() error {
|
|||||||
d.outboundWriter.Close()
|
d.outboundWriter.Close()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *UserDevice) WriteMany(b [][]byte) (int, error) {
|
||||||
|
out := 0
|
||||||
|
for i := range b {
|
||||||
|
x, err := d.Write(b[i])
|
||||||
|
if err != nil {
|
||||||
|
return out, err
|
||||||
|
}
|
||||||
|
out += x
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -311,6 +311,33 @@ func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) error {
|
||||||
|
// Prepend the packet with its virtio-net header.
|
||||||
|
vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY
|
||||||
|
if err := vnethdr.Encode(vnethdrBuf); err != nil {
|
||||||
|
return fmt.Errorf("encode vnethdr: %w", err)
|
||||||
|
}
|
||||||
|
vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
|
||||||
|
vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
|
||||||
|
|
||||||
|
chainIndexes, err := dev.transmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("offer descriptor chain: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//todo blocking here suxxxx
|
||||||
|
// Wait for the packet to have been transmitted.
|
||||||
|
for i := range chainIndexes {
|
||||||
|
<-dev.transmitted[chainIndexes[i]]
|
||||||
|
|
||||||
|
if err = dev.transmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
|
||||||
|
return fmt.Errorf("free descriptor chain: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// ReceivePacket reads the next available packet from the receive queue of this
|
// ReceivePacket reads the next available packet from the receive queue of this
|
||||||
// device and returns its [virtio.NetHdr] and packet data separately.
|
// device and returns its [virtio.NetHdr] and packet data separately.
|
||||||
//
|
//
|
||||||
|
|||||||
@@ -345,6 +345,66 @@ func (sq *SplitQueue) OfferDescriptorChain(outBuffers [][]byte, numInBuffers int
|
|||||||
return head, nil
|
return head, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte, waitFree bool) ([]uint16, error) {
|
||||||
|
sq.ensureInitialized()
|
||||||
|
|
||||||
|
// TODO change this
|
||||||
|
// Each descriptor can only hold a whole memory page, so split large out
|
||||||
|
// buffers into multiple smaller ones.
|
||||||
|
outBuffers = splitBuffers(outBuffers, sq.pageSize)
|
||||||
|
|
||||||
|
// Synchronize the offering of descriptor chains. While the descriptor table
|
||||||
|
// and available ring are synchronized on their own as well, this does not
|
||||||
|
// protect us from interleaved calls which could cause reordering.
|
||||||
|
// By locking here, we can ensure that all descriptor chains are made
|
||||||
|
// available to the device in the same order as this method was called.
|
||||||
|
sq.offerMutex.Lock()
|
||||||
|
defer sq.offerMutex.Unlock()
|
||||||
|
|
||||||
|
chains := make([]uint16, len(outBuffers))
|
||||||
|
|
||||||
|
// Create a descriptor chain for the given buffers.
|
||||||
|
var (
|
||||||
|
head uint16
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
for i := range outBuffers {
|
||||||
|
for {
|
||||||
|
bufs := [][]byte{prepend, outBuffers[i]}
|
||||||
|
head, err = sq.descriptorTable.createDescriptorChain(bufs, 0)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// I don't wanna use errors.Is, it's slow
|
||||||
|
//goland:noinspection GoDirectComparisonOfErrors
|
||||||
|
if err == ErrNotEnoughFreeDescriptors {
|
||||||
|
if waitFree {
|
||||||
|
// Wait for more free descriptors to be put back into the queue.
|
||||||
|
// If the number of free descriptors is still not sufficient, we'll
|
||||||
|
// land here again.
|
||||||
|
sq.blockForMoreDescriptors()
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("create descriptor chain: %w", err)
|
||||||
|
}
|
||||||
|
chains[i] = head
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make the descriptor chain available to the device.
|
||||||
|
sq.availableRing.offer(chains)
|
||||||
|
|
||||||
|
// Notify the device to make it process the updated available ring.
|
||||||
|
if err := sq.kickEventFD.Kick(); err != nil {
|
||||||
|
return chains, fmt.Errorf("notify device: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return chains, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetDescriptorChain returns the device-readable buffers (out buffers) and
|
// GetDescriptorChain returns the device-readable buffers (out buffers) and
|
||||||
// device-writable buffers (in buffers) of the descriptor chain with the given
|
// device-writable buffers (in buffers) of the descriptor chain with the given
|
||||||
// head index.
|
// head index.
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
const MTU = 9001
|
const MTU = 9001
|
||||||
|
|
||||||
type EncReader func(
|
type EncReader func(
|
||||||
addr netip.AddrPort,
|
addrs []netip.AddrPort,
|
||||||
payload []byte,
|
payload [][]byte,
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conn interface {
|
type Conn interface {
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader) {
|
func (u *StdConn) ListenOut(r EncReader) {
|
||||||
var ip netip.Addr
|
var ip netip.Addr
|
||||||
|
addrPorts := make([]netip.AddrPort, u.batch)
|
||||||
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
||||||
read := u.ReadMulti
|
read := u.ReadMulti
|
||||||
if u.batch == 1 {
|
if u.batch == 1 {
|
||||||
@@ -141,8 +141,11 @@ func (u *StdConn) ListenOut(r EncReader) {
|
|||||||
} else {
|
} else {
|
||||||
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
||||||
}
|
}
|
||||||
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
|
addrPorts[i] = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
|
||||||
|
buffers[i] = buffers[i][:msgs[i].Len]
|
||||||
|
|
||||||
}
|
}
|
||||||
|
r(addrPorts, buffers)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user