diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 7e4aa418..d6941997 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -238,6 +238,22 @@ func (t *tun) SupportsMultiqueue() bool { return true } +type MultiQueueReader struct { + fd int +} + +func (m *MultiQueueReader) Read(p []byte) (int, error) { + return unix.Read(m.fd, p) +} + +func (m *MultiQueueReader) Close() error { + return unix.Close(m.fd) +} + +func (m *MultiQueueReader) Write(p []byte) (int, error) { + return write(m.fd, p) +} + func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { @@ -248,12 +264,11 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) copy(req.Name[:], t.Device) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + _ = unix.Close(fd) return nil, err } - file := os.NewFile(uintptr(fd), "/dev/net/tun") - - return file, nil + return &MultiQueueReader{fd: fd}, nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { @@ -261,12 +276,12 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { return r } -func (t *tun) Write(b []byte) (int, error) { +func write(fd int, b []byte) (int, error) { var nn int maximum := len(b) for { - n, err := unix.Write(t.fd, b[nn:maximum]) + n, err := unix.Write(fd, b[nn:maximum]) if n > 0 { nn += n } @@ -284,6 +299,14 @@ func (t *tun) Write(b []byte) (int, error) { } } +func (t *tun) Read(p []byte) (int, error) { + return unix.Read(t.fd, p) +} + +func (t *tun) Write(b []byte) (int, error) { + return write(t.fd, b) +} + func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c)