Add a way to set the network type on windows + tests (#1710)
Some checks failed
gofmt / Run gofmt (push) Failing after 2s
smoke-extra / freebsd-amd64 (push) Failing after 2s
smoke-extra / linux-amd64-ipv6disable (push) Failing after 3s
smoke-extra / netbsd-amd64 (push) Failing after 3s
smoke-extra / openbsd-amd64 (push) Failing after 3s
smoke-extra / linux-386 (push) Failing after 3s
smoke / Run multi node smoke test (push) Failing after 2s
Build and test / Build all and test on ubuntu-linux (push) Failing after 3s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2s
smoke-extra / Run windows smoke test (push) Has been cancelled
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled

This commit is contained in:
Nate Brown
2026-05-07 20:17:38 -05:00
committed by GitHub
parent c82db210ef
commit 696903d6d9
15 changed files with 1349 additions and 20 deletions

View File

@@ -0,0 +1,358 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import (
"errors"
"fmt"
"log/slog"
"runtime"
"strings"
"syscall"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
// networkCategory mirrors NLM_NETWORK_CATEGORY from netlistmgr.h.
type networkCategory int32
const (
networkCategoryPublic networkCategory = 0
networkCategoryPrivate networkCategory = 1
networkCategoryDomainAuthenticated networkCategory = 2
)
func (c networkCategory) String() string {
switch c {
case networkCategoryPublic:
return "public"
case networkCategoryPrivate:
return "private"
case networkCategoryDomainAuthenticated:
return "domain"
}
return fmt.Sprintf("unknown(%d)", c)
}
// parseNetworkCategory accepts the user-supplied tun.network_category. A
// second return of false means "leave the category alone".
func parseNetworkCategory(s string) (networkCategory, bool, error) {
switch strings.ToLower(strings.TrimSpace(s)) {
case "", "unset":
return 0, false, nil
case "public":
return networkCategoryPublic, true, nil
case "private":
return networkCategoryPrivate, true, nil
case "domain", "domainauthenticated":
return networkCategoryDomainAuthenticated, true, nil
}
return 0, false, fmt.Errorf("unknown tun.network_category %q (expected public, private, domain, or unset)", s)
}
// CLSID_NetworkListManager {DCB00C01-570F-4A9B-8D69-199FDBA5723B}
var clsidNetworkListManager = windows.GUID{
Data1: 0xDCB00C01, Data2: 0x570F, Data3: 0x4A9B,
Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B},
}
// IID_INetworkListManager {DCB00000-570F-4A9B-8D69-199FDBA5723B}
var iidINetworkListManager = windows.GUID{
Data1: 0xDCB00000, Data2: 0x570F, Data3: 0x4A9B,
Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B},
}
// x/sys/windows doesn't expose CoCreateInstance, so we bind it ourselves.
var procCoCreateInstance = windows.NewLazySystemDLL("ole32.dll").NewProc("CoCreateInstance")
const clsCtxAll = windows.CLSCTX_INPROC_SERVER | windows.CLSCTX_INPROC_HANDLER |
windows.CLSCTX_LOCAL_SERVER | windows.CLSCTX_REMOTE_SERVER
const (
hrSFALSE = 0x00000001
hrRPCEChangedMode = 0x80010106
)
type hresult uint32
func (h hresult) failed() bool { return int32(h) < 0 }
func (h hresult) String() string {
return fmt.Sprintf("HRESULT 0x%08x", uint32(h))
}
var errAdapterNotFound = errors.New("adapter not present in network connections enumeration")
// Vtable layouts. Slot order must match the declaration order in netlistmgr.h.
// All NLM interfaces here derive from IDispatch, which derives from IUnknown.
type iUnknownVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
}
type iDispatchVtbl struct {
iUnknownVtbl
GetTypeInfoCount uintptr
GetTypeInfo uintptr
GetIDsOfNames uintptr
Invoke uintptr
}
type iNetworkListManagerVtbl struct {
iDispatchVtbl
GetNetworks uintptr
GetNetwork uintptr
GetNetworkConnections uintptr
GetNetworkConnection uintptr
IsConnectedToInternet uintptr
IsConnected uintptr
GetConnectivity uintptr
}
type iNetworkListManager struct{ Vtbl *iNetworkListManagerVtbl }
func (n *iNetworkListManager) Release() {
syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n)))
}
func (n *iNetworkListManager) GetNetworkConnections() (*iEnumNetworkConnections, error) {
var enum *iEnumNetworkConnections
r1, _, _ := syscall.SyscallN(n.Vtbl.GetNetworkConnections,
uintptr(unsafe.Pointer(n)), uintptr(unsafe.Pointer(&enum)),
)
if hr := hresult(r1); hr.failed() {
return nil, fmt.Errorf("INetworkListManager.GetNetworkConnections: %s", hr)
}
return enum, nil
}
type iEnumNetworkConnectionsVtbl struct {
iDispatchVtbl
NewEnum uintptr
Next uintptr
Skip uintptr
Reset uintptr
Clone uintptr
}
type iEnumNetworkConnections struct{ Vtbl *iEnumNetworkConnectionsVtbl }
func (e *iEnumNetworkConnections) Release() {
syscall.SyscallN(e.Vtbl.Release, uintptr(unsafe.Pointer(e)))
}
// Next returns the next connection, or (nil, nil) at the end of the enumeration.
func (e *iEnumNetworkConnections) Next() (*iNetworkConnection, error) {
var conn *iNetworkConnection
var fetched uint32
r1, _, _ := syscall.SyscallN(e.Vtbl.Next,
uintptr(unsafe.Pointer(e)), 1,
uintptr(unsafe.Pointer(&conn)), uintptr(unsafe.Pointer(&fetched)),
)
if hr := hresult(r1); hr.failed() {
return nil, fmt.Errorf("IEnumNetworkConnections.Next: %s", hr)
}
if fetched == 0 {
return nil, nil
}
return conn, nil
}
type iNetworkConnectionVtbl struct {
iDispatchVtbl
GetNetwork uintptr
IsConnectedToInternet uintptr
IsConnected uintptr
GetConnectivity uintptr
GetConnectionId uintptr
GetAdapterId uintptr
GetDomainType uintptr
}
type iNetworkConnection struct{ Vtbl *iNetworkConnectionVtbl }
func (c *iNetworkConnection) Release() {
syscall.SyscallN(c.Vtbl.Release, uintptr(unsafe.Pointer(c)))
}
func (c *iNetworkConnection) GetAdapterId() (windows.GUID, error) {
var g windows.GUID
r1, _, _ := syscall.SyscallN(c.Vtbl.GetAdapterId,
uintptr(unsafe.Pointer(c)), uintptr(unsafe.Pointer(&g)),
)
if hr := hresult(r1); hr.failed() {
return windows.GUID{}, fmt.Errorf("INetworkConnection.GetAdapterId: %s", hr)
}
return g, nil
}
func (c *iNetworkConnection) GetNetwork() (*iNetwork, error) {
var net *iNetwork
r1, _, _ := syscall.SyscallN(c.Vtbl.GetNetwork,
uintptr(unsafe.Pointer(c)), uintptr(unsafe.Pointer(&net)),
)
if hr := hresult(r1); hr.failed() {
return nil, fmt.Errorf("INetworkConnection.GetNetwork: %s", hr)
}
return net, nil
}
type iNetworkVtbl struct {
iDispatchVtbl
GetName uintptr
SetName uintptr
GetDescription uintptr
SetDescription uintptr
GetNetworkId uintptr
GetDomainType uintptr
GetNetworkConnections uintptr
GetTimeCreatedAndConnected uintptr
IsConnectedToInternet uintptr
IsConnected uintptr
GetConnectivity uintptr
GetCategory uintptr
SetCategory uintptr
}
type iNetwork struct{ Vtbl *iNetworkVtbl }
func (n *iNetwork) Release() {
syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n)))
}
func (n *iNetwork) GetCategory() (networkCategory, error) {
var c networkCategory
r1, _, _ := syscall.SyscallN(n.Vtbl.GetCategory,
uintptr(unsafe.Pointer(n)), uintptr(unsafe.Pointer(&c)),
)
if hr := hresult(r1); hr.failed() {
return 0, fmt.Errorf("INetwork.GetCategory: %s", hr)
}
return c, nil
}
func (n *iNetwork) SetCategory(c networkCategory) error {
r1, _, _ := syscall.SyscallN(n.Vtbl.SetCategory,
uintptr(unsafe.Pointer(n)), uintptr(int32(c)),
)
if hr := hresult(r1); hr.failed() {
return fmt.Errorf("INetwork.SetCategory: %s", hr)
}
return nil
}
// coInit initializes COM for the current OS thread. The returned function must
// be deferred to balance a successful init. RPC_E_CHANGED_MODE means COM is
// already initialized in a different mode on this thread, which is still fine
// for our calls but we must not Uninitialize in that case.
func coInit() (func(), error) {
err := windows.CoInitializeEx(0, windows.COINIT_MULTITHREADED)
if err == nil {
return windows.CoUninitialize, nil
}
if e, ok := err.(syscall.Errno); ok {
switch uint32(e) {
case hrSFALSE:
return windows.CoUninitialize, nil
case hrRPCEChangedMode:
return func() {}, nil
}
}
return nil, fmt.Errorf("CoInitializeEx: %w", err)
}
func createNetworkListManager() (*iNetworkListManager, error) {
var nlm *iNetworkListManager
r1, _, _ := procCoCreateInstance.Call(
uintptr(unsafe.Pointer(&clsidNetworkListManager)),
0,
uintptr(clsCtxAll),
uintptr(unsafe.Pointer(&iidINetworkListManager)),
uintptr(unsafe.Pointer(&nlm)),
)
if hr := hresult(r1); hr.failed() {
return nil, fmt.Errorf("CoCreateInstance(NetworkListManager): %s", hr)
}
return nlm, nil
}
// setNetworkCategory locates the network connection bound to adapterGUID and
// sets the category of its parent network. Returns errAdapterNotFound if the
// adapter is not yet visible in the NLM enumeration.
func setNetworkCategory(adapterGUID windows.GUID, cat networkCategory) error {
deinit, err := coInit()
if err != nil {
return err
}
defer deinit()
nlm, err := createNetworkListManager()
if err != nil {
return err
}
defer nlm.Release()
enum, err := nlm.GetNetworkConnections()
if err != nil {
return err
}
defer enum.Release()
for {
conn, err := enum.Next()
if err != nil {
return err
}
if conn == nil {
return errAdapterNotFound
}
guid, err := conn.GetAdapterId()
if err != nil || guid != adapterGUID {
conn.Release()
continue
}
net, err := conn.GetNetwork()
conn.Release()
if err != nil {
return err
}
err = net.SetCategory(cat)
net.Release()
return err
}
}
// applyNetworkCategory polls until the wintun adapter shows up in the NLM
// enumeration, then sets the category. Intended to run in its own goroutine.
func applyNetworkCategory(l *slog.Logger, adapterGUID windows.GUID, cat networkCategory) {
// COM Init/Uninit must be paired on the same OS thread.
runtime.LockOSThread()
defer runtime.UnlockOSThread()
const (
attempts = 30
interval = 500 * time.Millisecond
)
for i := 0; i < attempts; i++ {
err := setNetworkCategory(adapterGUID, cat)
if err == nil {
l.Info("Set Windows network category", "category", cat.String())
return
}
if !errors.Is(err, errAdapterNotFound) {
l.Warn("Failed to set Windows network category", "error", err, "category", cat.String())
return
}
time.Sleep(interval)
}
l.Warn("Gave up waiting for adapter to appear in NLM enumeration; network category not set",
"category", cat.String(),
"waited", time.Duration(attempts)*interval,
)
}

View File

@@ -0,0 +1,109 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import (
"testing"
)
func Test_parseNetworkCategory(t *testing.T) {
cases := []struct {
in string
wantCat networkCategory
wantApply bool
wantErr bool
}{
{"", 0, false, false},
{"unset", 0, false, false},
{" UNSET ", 0, false, false},
{"private", networkCategoryPrivate, true, false},
{"Private", networkCategoryPrivate, true, false},
{" PRIVATE ", networkCategoryPrivate, true, false},
{"public", networkCategoryPublic, true, false},
{"PUBLIC", networkCategoryPublic, true, false},
{"domain", networkCategoryDomainAuthenticated, true, false},
{"DomainAuthenticated", networkCategoryDomainAuthenticated, true, false},
{"garbage", 0, false, true},
{"privates", 0, false, true},
}
for _, tc := range cases {
cat, apply, err := parseNetworkCategory(tc.in)
if (err != nil) != tc.wantErr {
t.Errorf("parseNetworkCategory(%q) err=%v, wantErr=%v", tc.in, err, tc.wantErr)
continue
}
if cat != tc.wantCat || apply != tc.wantApply {
t.Errorf("parseNetworkCategory(%q) = (%v, %v), want (%v, %v)", tc.in, cat, apply, tc.wantCat, tc.wantApply)
}
}
}
// Test_NLM_round_trip exercises every COM call path used by setNetworkCategory
// without mutating the host's network state. It validates the CLSID/IID
// constants and every vtable index by enumerating connections, fetching the
// adapter id and parent network, reading the current category, and writing it
// back unchanged.
//
// Requires Windows but does not require admin or the wintun driver. Skips if
// no network connections are available (unlikely outside of an isolated
// container).
func Test_NLM_round_trip(t *testing.T) {
deinit, err := coInit()
if err != nil {
t.Fatalf("coInit: %v", err)
}
defer deinit()
nlm, err := createNetworkListManager()
if err != nil {
t.Fatalf("createNetworkListManager: %v", err)
}
defer nlm.Release()
enum, err := nlm.GetNetworkConnections()
if err != nil {
t.Fatalf("GetNetworkConnections: %v", err)
}
defer enum.Release()
saw := 0
for {
conn, err := enum.Next()
if err != nil {
t.Fatalf("EnumNetworkConnections.Next: %v", err)
}
if conn == nil {
break
}
saw++
if _, err := conn.GetAdapterId(); err != nil {
conn.Release()
t.Fatalf("INetworkConnection.GetAdapterId: %v", err)
}
net, err := conn.GetNetwork()
conn.Release()
if err != nil {
t.Fatalf("INetworkConnection.GetNetwork: %v", err)
}
cat, err := net.GetCategory()
if err != nil {
net.Release()
t.Fatalf("INetwork.GetCategory: %v", err)
}
// Set to the current value so the host's NLM state is unchanged but
// SetCategory's vtable slot is still validated end-to-end.
if err := net.SetCategory(cat); err != nil {
net.Release()
t.Fatalf("INetwork.SetCategory(%v): %v", cat, err)
}
net.Release()
}
if saw == 0 {
t.Skip("no NLM network connections available; skipping round-trip")
}
}

View File

@@ -0,0 +1,23 @@
//go:build (amd64 || arm64) && !e2e_testing
// +build amd64 arm64
// +build !e2e_testing
package overlay
import (
"log/slog"
"github.com/slackhq/nebula/wfp"
)
// installInterfaceBypass installs a WFP PERMIT filter scoped to the wintun interface LUID so inbound traffic on the
// nebula adapter bypasses Windows Defender Firewall.
func installInterfaceBypass(l *slog.Logger, luid uint64) closer {
s, err := wfp.PermitInterface(luid)
if err != nil {
l.Warn("Failed to install WFP bypass filters on nebula interface", "error", err)
return nil
}
l.Info("Installed WFP filters bypassing Windows Defender Firewall on nebula interface")
return s
}

View File

@@ -0,0 +1,11 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import "log/slog"
// installInterfaceBypass is a no-op on windows-386 because we don't currently build for it.
func installInterfaceBypass(_ *slog.Logger, _ uint64) closer {
return nil
}

View File

@@ -25,15 +25,24 @@ import (
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
type closer interface {
Close()
}
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type winTun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *slog.Logger
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
guid windows.GUID
networkCategory networkCategory
setCategory bool
bypassWDF bool
wdfBypass closer
l *slog.Logger
tun *wintun.NativeTun
}
@@ -54,11 +63,20 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*w
return nil, fmt.Errorf("generate GUID failed: %w", err)
}
cat, setCat, err := parseNetworkCategory(c.GetString("tun.network_category", "private"))
if err != nil {
return nil, err
}
t := &winTun{
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
guid: *guid,
networkCategory: cat,
setCategory: setCat,
bypassWDF: c.GetBool("tun.windows_bypass_wdf", true),
l: l,
}
err = t.reload(c, true)
@@ -142,6 +160,17 @@ func (t *winTun) Activate() error {
return err
}
if t.setCategory {
// The wintun adapter takes a moment to register with the Network List
// Manager, so we apply the category in the background and retry until
// it shows up.
go applyNetworkCategory(t.l, t.guid, t.networkCategory)
}
if t.bypassWDF {
t.wdfBypass = installInterfaceBypass(t.l, uint64(t.tun.LUID()))
}
return nil
}
@@ -255,6 +284,11 @@ func (t *winTun) Close() error {
_ = luid.FlushDNS(windows.AF_INET)
_ = luid.FlushDNS(windows.AF_INET6)
if t.wdfBypass != nil {
t.wdfBypass.Close()
t.wdfBypass = nil
}
return t.tun.Close()
}