mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 08:24:25 +01:00
36
service/listener.go
Normal file
36
service/listener.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
type tcpListener struct {
|
||||
port uint16
|
||||
s *Service
|
||||
addr *net.TCPAddr
|
||||
accept chan net.Conn
|
||||
}
|
||||
|
||||
func (l *tcpListener) Accept() (net.Conn, error) {
|
||||
conn, ok := <-l.accept
|
||||
if !ok {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (l *tcpListener) Close() error {
|
||||
l.s.mu.Lock()
|
||||
defer l.s.mu.Unlock()
|
||||
delete(l.s.mu.listeners, uint16(l.addr.Port))
|
||||
|
||||
close(l.accept)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Addr returns the listener's network address.
|
||||
func (l *tcpListener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
248
service/service.go
Normal file
248
service/service.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gvisor.dev/gvisor/pkg/bufferv2"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
const nicID = 1
|
||||
|
||||
type Service struct {
|
||||
eg *errgroup.Group
|
||||
control *nebula.Control
|
||||
ipstack *stack.Stack
|
||||
|
||||
mu struct {
|
||||
sync.Mutex
|
||||
|
||||
listeners map[uint16]*tcpListener
|
||||
}
|
||||
}
|
||||
|
||||
func New(config *config.C) (*Service, error) {
|
||||
logger := logrus.New()
|
||||
logger.Out = os.Stdout
|
||||
|
||||
control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
control.Start()
|
||||
|
||||
ctx := control.Context()
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
s := Service{
|
||||
eg: eg,
|
||||
control: control,
|
||||
}
|
||||
s.mu.listeners = map[uint16]*tcpListener{}
|
||||
|
||||
device, ok := control.Device().(*overlay.UserDevice)
|
||||
if !ok {
|
||||
return nil, errors.New("must be using user device")
|
||||
}
|
||||
|
||||
s.ipstack = stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
|
||||
})
|
||||
sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
|
||||
tcpipErr := s.ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
|
||||
if tcpipErr != nil {
|
||||
return nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
|
||||
}
|
||||
linkEP := channel.New( /*size*/ 512 /*mtu*/, 1280, "")
|
||||
if tcpipProblem := s.ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
|
||||
return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
|
||||
}
|
||||
ipv4Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4)))
|
||||
s.ipstack.SetRouteTable([]tcpip.Route{
|
||||
{
|
||||
Destination: ipv4Subnet,
|
||||
NIC: nicID,
|
||||
},
|
||||
})
|
||||
|
||||
ipNet := device.Cidr()
|
||||
pa := tcpip.ProtocolAddress{
|
||||
AddressWithPrefix: tcpip.Address(ipNet.IP).WithPrefix(),
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
}
|
||||
if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
|
||||
PEB: stack.CanBePrimaryEndpoint, // zero value default
|
||||
ConfigType: stack.AddressConfigStatic, // zero value default
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("error creating IP: %s", err)
|
||||
}
|
||||
|
||||
const tcpReceiveBufferSize = 0
|
||||
const maxInFlightConnectionAttempts = 1024
|
||||
tcpFwd := tcp.NewForwarder(s.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.tcpHandler)
|
||||
s.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
|
||||
|
||||
reader, writer := device.Pipe()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
reader.Close()
|
||||
writer.Close()
|
||||
}()
|
||||
|
||||
// create Goroutines to forward packets between Nebula and Gvisor
|
||||
eg.Go(func() error {
|
||||
buf := make([]byte, header.IPv4MaximumHeaderSize+header.IPv4MaximumPayloadSize)
|
||||
for {
|
||||
// this will read exactly one packet
|
||||
n, err := reader.Read(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: bufferv2.MakeWithData(bytes.Clone(buf[:n])),
|
||||
})
|
||||
linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf)
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
})
|
||||
eg.Go(func() error {
|
||||
for {
|
||||
packet := linkEP.ReadContext(ctx)
|
||||
if packet.IsNil() {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
bufView := packet.ToView()
|
||||
if _, err := bufView.WriteTo(writer); err != nil {
|
||||
return err
|
||||
}
|
||||
bufView.Release()
|
||||
}
|
||||
})
|
||||
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
// DialContext dials the provided address. Currently only TCP is supported.
|
||||
func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if network != "tcp" && network != "tcp4" {
|
||||
return nil, errors.New("only tcp is supported")
|
||||
}
|
||||
|
||||
addr, err := net.ResolveTCPAddr(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fullAddr := tcpip.FullAddress{
|
||||
NIC: nicID,
|
||||
Addr: tcpip.Address(addr.IP),
|
||||
Port: uint16(addr.Port),
|
||||
}
|
||||
|
||||
return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber)
|
||||
}
|
||||
|
||||
// Listen listens on the provided address. Currently only TCP with wildcard
|
||||
// addresses are supported.
|
||||
func (s *Service) Listen(network, address string) (net.Listener, error) {
|
||||
if network != "tcp" && network != "tcp4" {
|
||||
return nil, errors.New("only tcp is supported")
|
||||
}
|
||||
addr, err := net.ResolveTCPAddr(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if addr.IP != nil && !bytes.Equal(addr.IP, []byte{0, 0, 0, 0}) {
|
||||
return nil, fmt.Errorf("only wildcard address supported, got %q %v", address, addr.IP)
|
||||
}
|
||||
if addr.Port == 0 {
|
||||
return nil, errors.New("specific port required, got 0")
|
||||
}
|
||||
if addr.Port < 0 || addr.Port >= math.MaxUint16 {
|
||||
return nil, fmt.Errorf("invalid port %d", addr.Port)
|
||||
}
|
||||
port := uint16(addr.Port)
|
||||
|
||||
l := &tcpListener{
|
||||
port: port,
|
||||
s: s,
|
||||
addr: addr,
|
||||
accept: make(chan net.Conn),
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, ok := s.mu.listeners[port]; ok {
|
||||
return nil, fmt.Errorf("already listening on port %d", port)
|
||||
}
|
||||
s.mu.listeners[port] = l
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func (s *Service) Wait() error {
|
||||
return s.eg.Wait()
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
s.control.Stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) tcpHandler(r *tcp.ForwarderRequest) {
|
||||
endpointID := r.ID()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
l, ok := s.mu.listeners[endpointID.LocalPort]
|
||||
if !ok {
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
|
||||
var wq waiter.Queue
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
log.Printf("got error creating endpoint %q", err)
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
r.Complete(false)
|
||||
ep.SocketOptions().SetKeepAlive(true)
|
||||
|
||||
conn := gonet.NewTCPConn(&wq, ep)
|
||||
l.accept <- conn
|
||||
}
|
||||
165
service/service_test.go
Normal file
165
service/service_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"dario.cat/mergo"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/e2e"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
type m map[string]interface{}
|
||||
|
||||
func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service {
|
||||
|
||||
vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
|
||||
copy(vpnIpNet.IP, udpIp)
|
||||
|
||||
_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
|
||||
caB, err := caCrt.MarshalToPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
mc := m{
|
||||
"pki": m{
|
||||
"ca": string(caB),
|
||||
"cert": string(myPEM),
|
||||
"key": string(myPrivKey),
|
||||
},
|
||||
//"tun": m{"disabled": true},
|
||||
"firewall": m{
|
||||
"outbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}},
|
||||
"inbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}},
|
||||
},
|
||||
"timers": m{
|
||||
"pending_deletion_interval": 2,
|
||||
"connection_alive_interval": 2,
|
||||
},
|
||||
"handshakes": m{
|
||||
"try_interval": "200ms",
|
||||
},
|
||||
}
|
||||
|
||||
if overrides != nil {
|
||||
err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mc = overrides
|
||||
}
|
||||
|
||||
cb, err := yaml.Marshal(mc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var c config.C
|
||||
if err := c.LoadString(string(cb)); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
s, err := New(&c)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func TestService(t *testing.T) {
|
||||
ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{
|
||||
"static_host_map": m{},
|
||||
"lighthouse": m{
|
||||
"am_lighthouse": true,
|
||||
},
|
||||
"listen": m{
|
||||
"host": "0.0.0.0",
|
||||
"port": 4243,
|
||||
},
|
||||
})
|
||||
b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{
|
||||
"static_host_map": m{
|
||||
"10.0.0.1": []string{"localhost:4243"},
|
||||
},
|
||||
"lighthouse": m{
|
||||
"hosts": []string{"10.0.0.1"},
|
||||
"interval": 1,
|
||||
},
|
||||
})
|
||||
|
||||
ln, err := a.Listen("tcp", ":1234")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var eg errgroup.Group
|
||||
eg.Go(func() error {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
t.Log("accepted connection")
|
||||
|
||||
if _, err := conn.Write([]byte("server msg")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.Log("server: wrote message")
|
||||
|
||||
data := make([]byte, 100)
|
||||
n, err := conn.Read(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data = data[:n]
|
||||
if !bytes.Equal(data, []byte("client msg")) {
|
||||
return errors.New("got invalid message from client")
|
||||
}
|
||||
t.Log("server: read message")
|
||||
return conn.Close()
|
||||
})
|
||||
|
||||
c, err := b.DialContext(context.Background(), "tcp", "10.0.0.1:1234")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := c.Write([]byte("client msg")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data := make([]byte, 100)
|
||||
n, err := c.Read(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data = data[:n]
|
||||
if !bytes.Equal(data, []byte("server msg")) {
|
||||
t.Fatal("got invalid message from client")
|
||||
}
|
||||
|
||||
if err := c.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := eg.Wait(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user