pax_global_header00006660000000000000000000000064146630600340014514gustar00rootroot0000000000000052 comment=b5a83883d72dfb07411aa6a9e087a48bdbbe0fdd wireguard-go-0.0.20231211/000077500000000000000000000000001466306003400146415ustar00rootroot00000000000000wireguard-go-0.0.20231211/.gitignore000066400000000000000000000000151466306003400166250ustar00rootroot00000000000000wireguard-go wireguard-go-0.0.20231211/LICENSE000066400000000000000000000017771466306003400156620ustar00rootroot00000000000000Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
wireguard-go-0.0.20231211/Makefile000066400000000000000000000014731466306003400163060ustar00rootroot00000000000000PREFIX ?= /usr DESTDIR ?= BINDIR ?= $(PREFIX)/bin export GO111MODULE := on all: generate-version-and-build MAKEFLAGS += --no-print-directory generate-version-and-build: @export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \ tag="$$(git describe --dirty 2>/dev/null)" && \ ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \ [ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \ echo "$$ver" > version.go && \ git update-index --assume-unchanged version.go || true @$(MAKE) wireguard-go wireguard-go: $(wildcard *.go) $(wildcard */*.go) go build -v -o "$@" install: wireguard-go @install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go" test: go test ./... clean: rm -f wireguard-go .PHONY: all clean test install generate-version-and-build wireguard-go-0.0.20231211/README.md000066400000000000000000000075111466306003400161240ustar00rootroot00000000000000# Go Implementation of [WireGuard](https://www.wireguard.com/) This is an implementation of WireGuard in Go. ## Usage Most Linux kernel WireGuard users are used to adding an interface with `ip link add wg0 type wireguard`. With wireguard-go, instead simply run: ``` $ wireguard-go wg0 ``` This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/wireguard/wg0.sock`, which will result in wireguard-go shutting down. To run wireguard-go without forking to the background, pass `-f` or `--foreground`: ``` $ wireguard-go -f wg0 ``` When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands. 
To run with more logging you may set the environment variable `LOG_LEVEL=debug`. ## Platforms ### Linux This will run on Linux; however you should instead use the kernel module, which is faster and better integrated into the OS. See the [installation page](https://www.wireguard.com/install/) for instructions. ### macOS This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable. ### Windows This runs on Windows, but you should instead use it from the more [fully featured Windows app](https://git.zx2c4.com/wireguard-windows/about/), which uses this as a module. ### FreeBSD This will run on FreeBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_USER_COOKIE`. ### OpenBSD This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_RTABLE`. Since the tun driver cannot have arbitrary interface names, you must either use `tun[0-9]+` for an explicit interface name or `tun` to have the program select one for you. If you choose `tun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable. ## Building This requires an installation of the latest version of [Go](https://go.dev/). ``` $ git clone https://git.zx2c4.com/wireguard-go $ cd wireguard-go $ make ``` ## License Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. wireguard-go-0.0.20231211/conn/000077500000000000000000000000001466306003400155765ustar00rootroot00000000000000wireguard-go-0.0.20231211/conn/bind_std.go000066400000000000000000000314331466306003400177170ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn import ( "context" "errors" "fmt" "net" "net/netip" "runtime" "strconv" "sync" "syscall" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) var ( _ Bind = (*StdNetBind)(nil) ) // StdNetBind implements Bind for all platforms. While Windows has its own Bind // (see bind_windows.go), it may fall back to StdNetBind. // TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable // methods for sending and receiving multiple datagrams per-syscall. See the // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. 
type StdNetBind struct { mu sync.Mutex // protects all fields except as specified ipv4 *net.UDPConn ipv6 *net.UDPConn ipv4PC *ipv4.PacketConn // will be nil on non-Linux ipv6PC *ipv6.PacketConn // will be nil on non-Linux ipv4TxOffload bool ipv4RxOffload bool ipv6TxOffload bool ipv6RxOffload bool // these two fields are not guarded by mu udpAddrPool sync.Pool msgsPool sync.Pool blackhole4 bool blackhole6 bool } func NewStdNetBind() Bind { return &StdNetBind{ udpAddrPool: sync.Pool{ New: func() any { return &net.UDPAddr{ IP: make([]byte, 16), } }, }, msgsPool: sync.Pool{ New: func() any { // ipv6.Message and ipv4.Message are interchangeable as they are // both aliases for x/net/internal/socket.Message. msgs := make([]ipv6.Message, IdealBatchSize) for i := range msgs { msgs[i].Buffers = make(net.Buffers, 1) msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize) } return &msgs }, }, } } type StdNetEndpoint struct { // AddrPort is the endpoint destination. netip.AddrPort // src is the current sticky source address and interface index, if // supported. Typically this is a PKTINFO structure from/for control // messages, see unix.PKTINFO for an example. src []byte } var ( _ Bind = (*StdNetBind)(nil) _ Endpoint = &StdNetEndpoint{} ) func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { e, err := netip.ParseAddrPort(s) if err != nil { return nil, err } return &StdNetEndpoint{ AddrPort: e, }, nil } func (e *StdNetEndpoint) ClearSrc() { if e.src != nil { // Truncate src, no need to reallocate. e.src = e.src[:0] } } func (e *StdNetEndpoint) DstIP() netip.Addr { return e.AddrPort.Addr() } // See control_default,linux, etc for implementations of SrcIP and SrcIfidx. 
func (e *StdNetEndpoint) DstToBytes() []byte { b, _ := e.AddrPort.MarshalBinary() return b } func (e *StdNetEndpoint) DstToString() string { return e.AddrPort.String() } func listenNet(network string, port int) (*net.UDPConn, int, error) { conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) if err != nil { return nil, 0, err } // Retrieve port. laddr := conn.LocalAddr() uaddr, err := net.ResolveUDPAddr( laddr.Network(), laddr.String(), ) if err != nil { return nil, 0, err } return conn.(*net.UDPConn), uaddr.Port, nil } func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { s.mu.Lock() defer s.mu.Unlock() var err error var tries int if s.ipv4 != nil || s.ipv6 != nil { return nil, 0, ErrBindAlreadyOpen } // Attempt to open ipv4 and ipv6 listeners on the same port. // If uport is 0, we can retry on failure. again: port := int(uport) var v4conn, v6conn *net.UDPConn var v4pc *ipv4.PacketConn var v6pc *ipv6.PacketConn v4conn, port, err = listenNet("udp4", port) if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return nil, 0, err } // Listen on the same port as we're using for ipv4. 
v6conn, port, err = listenNet("udp6", port) if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { v4conn.Close() tries++ goto again } if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { v4conn.Close() return nil, 0, err } var fns []ReceiveFunc if v4conn != nil { s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn) if runtime.GOOS == "linux" || runtime.GOOS == "android" { v4pc = ipv4.NewPacketConn(v4conn) s.ipv4PC = v4pc } fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload)) s.ipv4 = v4conn } if v6conn != nil { s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn) if runtime.GOOS == "linux" || runtime.GOOS == "android" { v6pc = ipv6.NewPacketConn(v6conn) s.ipv6PC = v6pc } fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload)) s.ipv6 = v6conn } if len(fns) == 0 { return nil, 0, syscall.EAFNOSUPPORT } return fns, uint16(port), nil } func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) { for i := range *msgs { (*msgs)[i].OOB = (*msgs)[i].OOB[:0] (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB} } s.msgsPool.Put(msgs) } func (s *StdNetBind) getMessages() *[]ipv6.Message { return s.msgsPool.Get().(*[]ipv6.Message) } var ( // If compilation fails here these are no longer the same underlying type. 
_ ipv6.Message = ipv4.Message{} ) type batchReader interface { ReadBatch([]ipv6.Message, int) (int, error) } type batchWriter interface { WriteBatch([]ipv6.Message, int) (int, error) } func (s *StdNetBind) receiveIP( br batchReader, conn *net.UDPConn, rxOffload bool, bufs [][]byte, sizes []int, eps []Endpoint, ) (n int, err error) { msgs := s.getMessages() for i := range bufs { (*msgs)[i].Buffers[0] = bufs[i] (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] } defer s.putMessages(msgs) var numMsgs int if runtime.GOOS == "linux" || runtime.GOOS == "android" { if rxOffload { readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams) numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0) if err != nil { return 0, err } numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize) if err != nil { return 0, err } } else { numMsgs, err = br.ReadBatch(*msgs, 0) if err != nil { return 0, err } } } else { msg := &(*msgs)[0] msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) if err != nil { return 0, err } numMsgs = 1 } for i := 0; i < numMsgs; i++ { msg := &(*msgs)[i] sizes[i] = msg.N if sizes[i] == 0 { continue } addrPort := msg.Addr.(*net.UDPAddr).AddrPort() ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation getSrcFromControl(msg.OOB[:msg.NN], ep) eps[i] = ep } return numMsgs, nil } func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) } } func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) } } // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and // rename the IdealBatchSize constant to BatchSize. 
func (s *StdNetBind) BatchSize() int { if runtime.GOOS == "linux" || runtime.GOOS == "android" { return IdealBatchSize } return 1 } func (s *StdNetBind) Close() error { s.mu.Lock() defer s.mu.Unlock() var err1, err2 error if s.ipv4 != nil { err1 = s.ipv4.Close() s.ipv4 = nil s.ipv4PC = nil } if s.ipv6 != nil { err2 = s.ipv6.Close() s.ipv6 = nil s.ipv6PC = nil } s.blackhole4 = false s.blackhole6 = false s.ipv4TxOffload = false s.ipv4RxOffload = false s.ipv6TxOffload = false s.ipv6RxOffload = false if err1 != nil { return err1 } return err2 } type ErrUDPGSODisabled struct { onLaddr string RetryErr error } func (e ErrUDPGSODisabled) Error() string { return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr) } func (e ErrUDPGSODisabled) Unwrap() error { return e.RetryErr } func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { s.mu.Lock() blackhole := s.blackhole4 conn := s.ipv4 offload := s.ipv4TxOffload br := batchWriter(s.ipv4PC) is6 := false if endpoint.DstIP().Is6() { blackhole = s.blackhole6 conn = s.ipv6 br = s.ipv6PC is6 = true offload = s.ipv6TxOffload } s.mu.Unlock() if blackhole { return nil } if conn == nil { return syscall.EAFNOSUPPORT } msgs := s.getMessages() defer s.putMessages(msgs) ua := s.udpAddrPool.Get().(*net.UDPAddr) defer s.udpAddrPool.Put(ua) if is6 { as16 := endpoint.DstIP().As16() copy(ua.IP, as16[:]) ua.IP = ua.IP[:16] } else { as4 := endpoint.DstIP().As4() copy(ua.IP, as4[:]) ua.IP = ua.IP[:4] } ua.Port = int(endpoint.(*StdNetEndpoint).Port()) var ( retried bool err error ) retry: if offload { n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize) err = s.send(conn, br, (*msgs)[:n]) if err != nil && offload && errShouldDisableUDPGSO(err) { offload = false s.mu.Lock() if is6 { s.ipv6TxOffload = false } else { s.ipv4TxOffload = false } s.mu.Unlock() retried = true goto retry } } else { for i := range bufs { (*msgs)[i].Addr = ua (*msgs)[i].Buffers[0] = bufs[i] 
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint)) } err = s.send(conn, br, (*msgs)[:len(bufs)]) } if retried { return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err} } return err } func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error { var ( n int err error start int ) if runtime.GOOS == "linux" || runtime.GOOS == "android" { for { n, err = pc.WriteBatch(msgs[start:], 0) if err != nil || n == len(msgs[start:]) { break } start += n } } else { for _, msg := range msgs { _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr)) if err != nil { break } } } return err } const ( // Exceeding these values results in EMSGSIZE. They account for layer3 and // layer4 headers. IPv6 does not need to account for itself as the payload // length field is self excluding. maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 maxIPv6PayloadLen = 1<<16 - 1 - 8 // This is a hard limit imposed by the kernel. udpSegmentMaxDatagrams = 64 ) type setGSOFunc func(control *[]byte, gsoSize uint16) func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int { var ( base = -1 // index of msg we are currently coalescing into gsoSize int // segmentation size of msgs[base] dgramCnt int // number of dgrams coalesced into msgs[base] endBatch bool // tracking flag to start a new batch on next iteration of bufs ) maxPayloadLen := maxIPv4PayloadLen if ep.DstIP().Is6() { maxPayloadLen = maxIPv6PayloadLen } for i, buf := range bufs { if i > 0 { msgLen := len(buf) baseLenBefore := len(msgs[base].Buffers[0]) freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore if msgLen+baseLenBefore <= maxPayloadLen && msgLen <= gsoSize && msgLen <= freeBaseCap && dgramCnt < udpSegmentMaxDatagrams && !endBatch { msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...) 
if i == len(bufs)-1 { setGSO(&msgs[base].OOB, uint16(gsoSize)) } dgramCnt++ if msgLen < gsoSize { // A smaller than gsoSize packet on the tail is legal, but // it must end the batch. endBatch = true } continue } } if dgramCnt > 1 { setGSO(&msgs[base].OOB, uint16(gsoSize)) } // Reset prior to incrementing base since we are preparing to start a // new potential batch. endBatch = false base++ gsoSize = len(buf) setSrcControl(&msgs[base].OOB, ep) msgs[base].Buffers[0] = buf msgs[base].Addr = addr dgramCnt = 1 } return base + 1 } type getGSOFunc func(control []byte) (int, error) func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) { for i := firstMsgAt; i < len(msgs); i++ { msg := &msgs[i] if msg.N == 0 { return n, err } var ( gsoSize int start int end = msg.N numToSplit = 1 ) gsoSize, err = getGSO(msg.OOB[:msg.NN]) if err != nil { return n, err } if gsoSize > 0 { numToSplit = (msg.N + gsoSize - 1) / gsoSize end = gsoSize } for j := 0; j < numToSplit; j++ { if n > i { return n, errors.New("splitting coalesced packet resulted in overflow") } copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) msgs[n].N = copied msgs[n].Addr = msg.Addr start = end end += gsoSize if end > msg.N { end = msg.N } n++ } if i != n-1 { // It is legal for bytes to move within msg.Buffers[0] as a result // of splitting, so we only zero the source msg len when it is not // the destination of the last split operation above. 
msg.N = 0 } } return n, nil } wireguard-go-0.0.20231211/conn/bind_std_test.go000066400000000000000000000124601466306003400207550ustar00rootroot00000000000000package conn import ( "encoding/binary" "net" "testing" "golang.org/x/net/ipv6" ) func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { bind := NewStdNetBind().(*StdNetBind) fns, _, err := bind.Open(0) if err != nil { t.Fatal(err) } bind.Close() bufs := make([][]byte, 1) bufs[0] = make([]byte, 1) sizes := make([]int, 1) eps := make([]Endpoint, 1) for _, fn := range fns { // The ReceiveFuncs must not access conn-related fields on StdNetBind // unguarded. Close() nils the conn-related fields resulting in a panic // if they violate the mutex. fn(bufs, sizes, eps) } } func mockSetGSOSize(control *[]byte, gsoSize uint16) { *control = (*control)[:cap(*control)] binary.LittleEndian.PutUint16(*control, gsoSize) } func Test_coalesceMessages(t *testing.T) { cases := []struct { name string buffs [][]byte wantLens []int wantGSO []int }{ { name: "one message no coalesce", buffs: [][]byte{ make([]byte, 1, 1), }, wantLens: []int{1}, wantGSO: []int{0}, }, { name: "two messages equal len coalesce", buffs: [][]byte{ make([]byte, 1, 2), make([]byte, 1, 1), }, wantLens: []int{2}, wantGSO: []int{1}, }, { name: "two messages unequal len coalesce", buffs: [][]byte{ make([]byte, 2, 3), make([]byte, 1, 1), }, wantLens: []int{3}, wantGSO: []int{2}, }, { name: "three messages second unequal len coalesce", buffs: [][]byte{ make([]byte, 2, 3), make([]byte, 1, 1), make([]byte, 2, 2), }, wantLens: []int{3, 2}, wantGSO: []int{2, 0}, }, { name: "three messages limited cap coalesce", buffs: [][]byte{ make([]byte, 2, 4), make([]byte, 2, 2), make([]byte, 2, 2), }, wantLens: []int{4, 2}, wantGSO: []int{2, 0}, }, } for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { addr := &net.UDPAddr{ IP: net.ParseIP("127.0.0.1").To4(), Port: 1, } msgs := make([]ipv6.Message, len(tt.buffs)) for i := range msgs { msgs[i].Buffers = make([][]byte, 1) 
msgs[i].OOB = make([]byte, 0, 2) } got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize) if got != len(tt.wantLens) { t.Fatalf("got len %d want: %d", got, len(tt.wantLens)) } for i := 0; i < got; i++ { if msgs[i].Addr != addr { t.Errorf("msgs[%d].Addr != passed addr", i) } gotLen := len(msgs[i].Buffers[0]) if gotLen != tt.wantLens[i] { t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i]) } gotGSO, err := mockGetGSOSize(msgs[i].OOB) if err != nil { t.Fatalf("msgs[%d] getGSOSize err: %v", i, err) } if gotGSO != tt.wantGSO[i] { t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i]) } } }) } } func mockGetGSOSize(control []byte) (int, error) { if len(control) < 2 { return 0, nil } return int(binary.LittleEndian.Uint16(control)), nil } func Test_splitCoalescedMessages(t *testing.T) { newMsg := func(n, gso int) ipv6.Message { msg := ipv6.Message{ Buffers: [][]byte{make([]byte, 1<<16-1)}, N: n, OOB: make([]byte, 2), } binary.LittleEndian.PutUint16(msg.OOB, uint16(gso)) if gso > 0 { msg.NN = 2 } return msg } cases := []struct { name string msgs []ipv6.Message firstMsgAt int wantNumEval int wantMsgLens []int wantErr bool }{ { name: "second last split last empty", msgs: []ipv6.Message{ newMsg(0, 0), newMsg(0, 0), newMsg(3, 1), newMsg(0, 0), }, firstMsgAt: 2, wantNumEval: 3, wantMsgLens: []int{1, 1, 1, 0}, wantErr: false, }, { name: "second last no split last empty", msgs: []ipv6.Message{ newMsg(0, 0), newMsg(0, 0), newMsg(1, 0), newMsg(0, 0), }, firstMsgAt: 2, wantNumEval: 1, wantMsgLens: []int{1, 0, 0, 0}, wantErr: false, }, { name: "second last no split last no split", msgs: []ipv6.Message{ newMsg(0, 0), newMsg(0, 0), newMsg(1, 0), newMsg(1, 0), }, firstMsgAt: 2, wantNumEval: 2, wantMsgLens: []int{1, 1, 0, 0}, wantErr: false, }, { name: "second last no split last split", msgs: []ipv6.Message{ newMsg(0, 0), newMsg(0, 0), newMsg(1, 0), newMsg(3, 1), }, firstMsgAt: 2, wantNumEval: 4, wantMsgLens: 
[]int{1, 1, 1, 1}, wantErr: false, }, { name: "second last split last split", msgs: []ipv6.Message{ newMsg(0, 0), newMsg(0, 0), newMsg(2, 1), newMsg(2, 1), }, firstMsgAt: 2, wantNumEval: 4, wantMsgLens: []int{1, 1, 1, 1}, wantErr: false, }, { name: "second last no split last split overflow", msgs: []ipv6.Message{ newMsg(0, 0), newMsg(0, 0), newMsg(1, 0), newMsg(4, 1), }, firstMsgAt: 2, wantNumEval: 4, wantMsgLens: []int{1, 1, 1, 1}, wantErr: true, }, } for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize) if err != nil && !tt.wantErr { t.Fatalf("err: %v", err) } if got != tt.wantNumEval { t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval) } for i, msg := range tt.msgs { if msg.N != tt.wantMsgLens[i] { t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i]) } } }) } } wireguard-go-0.0.20231211/conn/bind_windows.go000066400000000000000000000353031466306003400206170ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package conn import ( "encoding/binary" "io" "net" "net/netip" "strconv" "sync" "sync/atomic" "unsafe" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/conn/winrio" ) const ( packetsPerRing = 1024 bytesPerPacket = 2048 - 32 receiveSpins = 15 ) type ringPacket struct { addr WinRingEndpoint data [bytesPerPacket]byte } type ringBuffer struct { packets uintptr head, tail uint32 id winrio.BufferId iocp windows.Handle isFull bool cq winrio.Cq mu sync.Mutex overlapped windows.Overlapped } func (rb *ringBuffer) Push() *ringPacket { for rb.isFull { panic("ring is full") } ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{})))) rb.tail += 1 if rb.tail%packetsPerRing == rb.head%packetsPerRing { rb.isFull = true } return ret } func (rb *ringBuffer) Return(count uint32) { if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull { return } rb.head += count rb.isFull = false } type afWinRingBind struct { sock windows.Handle rx, tx ringBuffer rq winrio.Rq mu sync.Mutex blackhole bool } // WinRingBind uses Windows registered I/O for fast ring buffered networking. 
type WinRingBind struct { v4, v6 afWinRingBind mu sync.RWMutex isOpen atomic.Uint32 // 0, 1, or 2 } func NewDefaultBind() Bind { return NewWinRingBind() } func NewWinRingBind() Bind { if !winrio.Initialize() { return NewStdNetBind() } return new(WinRingBind) } type WinRingEndpoint struct { family uint16 data [30]byte } var ( _ Bind = (*WinRingBind)(nil) _ Endpoint = (*WinRingEndpoint)(nil) ) func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) { host, port, err := net.SplitHostPort(s) if err != nil { return nil, err } host16, err := windows.UTF16PtrFromString(host) if err != nil { return nil, err } port16, err := windows.UTF16PtrFromString(port) if err != nil { return nil, err } hints := windows.AddrinfoW{ Flags: windows.AI_NUMERICHOST, Family: windows.AF_UNSPEC, Socktype: windows.SOCK_DGRAM, Protocol: windows.IPPROTO_UDP, } var addrinfo *windows.AddrinfoW err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo) if err != nil { return nil, err } defer windows.FreeAddrInfoW(addrinfo) if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) { return nil, windows.ERROR_INVALID_ADDRESS } var dst [unsafe.Sizeof(WinRingEndpoint{})]byte copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen)) return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil } func (*WinRingEndpoint) ClearSrc() {} func (e *WinRingEndpoint) DstIP() netip.Addr { switch e.family { case windows.AF_INET: return netip.AddrFrom4(*(*[4]byte)(e.data[2:6])) case windows.AF_INET6: return netip.AddrFrom16(*(*[16]byte)(e.data[6:22])) } return netip.Addr{} } func (e *WinRingEndpoint) SrcIP() netip.Addr { return netip.Addr{} // not supported } func (e *WinRingEndpoint) DstToBytes() []byte { switch e.family { case windows.AF_INET: b := make([]byte, 0, 6) b = append(b, e.data[2:6]...) 
b = append(b, e.data[1], e.data[0]) return b case windows.AF_INET6: b := make([]byte, 0, 18) b = append(b, e.data[6:22]...) b = append(b, e.data[1], e.data[0]) return b } return nil } func (e *WinRingEndpoint) DstToString() string { switch e.family { case windows.AF_INET: return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String() case windows.AF_INET6: var zone string if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { zone = strconv.FormatUint(uint64(scope), 10) } return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String() } return "" } func (e *WinRingEndpoint) SrcToString() string { return "" } func (ring *ringBuffer) CloseAndZero() { if ring.cq != 0 { winrio.CloseCompletionQueue(ring.cq) ring.cq = 0 } if ring.iocp != 0 { windows.CloseHandle(ring.iocp) ring.iocp = 0 } if ring.id != 0 { winrio.DeregisterBuffer(ring.id) ring.id = 0 } if ring.packets != 0 { windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE) ring.packets = 0 } ring.head = 0 ring.tail = 0 ring.isFull = false } func (bind *afWinRingBind) CloseAndZero() { bind.rx.CloseAndZero() bind.tx.CloseAndZero() if bind.sock != 0 { windows.CloseHandle(bind.sock) bind.sock = 0 } bind.blackhole = false } func (bind *WinRingBind) closeAndZero() { bind.isOpen.Store(0) bind.v4.CloseAndZero() bind.v6.CloseAndZero() } func (ring *ringBuffer) Open() error { var err error packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) if err != nil { return err } ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen)) if err != nil { return err } ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) if err != nil { return err } ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, 
&ring.overlapped) if err != nil { return err } return nil } func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) { var err error bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP) if err != nil { return nil, err } err = bind.rx.Open() if err != nil { return nil, err } err = bind.tx.Open() if err != nil { return nil, err } bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0) if err != nil { return nil, err } err = windows.Bind(bind.sock, sa) if err != nil { return nil, err } sa, err = windows.Getsockname(bind.sock) if err != nil { return nil, err } return sa, nil } func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) { bind.mu.Lock() defer bind.mu.Unlock() defer func() { if err != nil { bind.closeAndZero() } }() if bind.isOpen.Load() != 0 { return nil, 0, ErrBindAlreadyOpen } var sa windows.Sockaddr sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)}) if err != nil { return nil, 0, err } sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port}) if err != nil { return nil, 0, err } selectedPort = uint16(sa.(*windows.SockaddrInet6).Port) for i := 0; i < packetsPerRing; i++ { err = bind.v4.InsertReceiveRequest() if err != nil { return nil, 0, err } err = bind.v6.InsertReceiveRequest() if err != nil { return nil, 0, err } } bind.isOpen.Store(1) return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err } func (bind *WinRingBind) Close() error { bind.mu.RLock() if bind.isOpen.Load() != 1 { bind.mu.RUnlock() return nil } bind.isOpen.Store(2) windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil) windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil) windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil) windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil) bind.mu.RUnlock() 
bind.mu.Lock() defer bind.mu.Unlock() bind.closeAndZero() return nil } // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and // rename the IdealBatchSize constant to BatchSize. func (bind *WinRingBind) BatchSize() int { // TODO: implement batching in and out of the ring return 1 } func (bind *WinRingBind) SetMark(mark uint32) error { return nil } func (bind *afWinRingBind) InsertReceiveRequest() error { packet := bind.rx.Push() dataBuffer := &winrio.Buffer{ Id: bind.rx.id, Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets), Length: uint32(len(packet.data)), } addressBuffer := &winrio.Buffer{ Id: bind.rx.id, Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets), Length: uint32(unsafe.Sizeof(packet.addr)), } bind.mu.Lock() defer bind.mu.Unlock() return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet))) } //go:linkname procyield runtime.procyield func procyield(cycles uint32) func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) { if isOpen.Load() != 1 { return 0, nil, net.ErrClosed } bind.rx.mu.Lock() defer bind.rx.mu.Unlock() var err error var count uint32 var results [1]winrio.Result retry: count = 0 for tries := 0; count == 0 && tries < receiveSpins; tries++ { if tries > 0 { if isOpen.Load() != 1 { return 0, nil, net.ErrClosed } procyield(1) } count = winrio.DequeueCompletion(bind.rx.cq, results[:]) } if count == 0 { err = winrio.Notify(bind.rx.cq) if err != nil { return 0, nil, err } var bytes uint32 var key uintptr var overlapped *windows.Overlapped err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE) if err != nil { return 0, nil, err } if isOpen.Load() != 1 { return 0, nil, net.ErrClosed } count = winrio.DequeueCompletion(bind.rx.cq, results[:]) if count == 0 { return 0, nil, io.ErrNoProgress } } bind.rx.Return(1) err = bind.InsertReceiveRequest() if err 
!= nil { return 0, nil, err } // We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to // attacker bandwidth, just like the rest of the receive path. if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE { if isOpen.Load() != 1 { return 0, nil, net.ErrClosed } goto retry } if results[0].Status != 0 { return 0, nil, windows.Errno(results[0].Status) } packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext))) ep := packet.addr n := copy(buf, packet.data[:results[0].BytesTransferred]) return n, &ep, nil } func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen) sizes[0] = n eps[0] = ep return 1, err } func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen) sizes[0] = n eps[0] = ep return 1, err } func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error { if isOpen.Load() != 1 { return net.ErrClosed } if len(buf) > bytesPerPacket { return io.ErrShortBuffer } bind.tx.mu.Lock() defer bind.tx.mu.Unlock() var results [packetsPerRing]winrio.Result count := winrio.DequeueCompletion(bind.tx.cq, results[:]) if count == 0 && bind.tx.isFull { err := winrio.Notify(bind.tx.cq) if err != nil { return err } var bytes uint32 var key uintptr var overlapped *windows.Overlapped err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE) if err != nil { return err } if isOpen.Load() != 1 { return net.ErrClosed } count = winrio.DequeueCompletion(bind.tx.cq, results[:]) if count == 0 { return io.ErrNoProgress } } if count > 0 { bind.tx.Return(count) } packet := bind.tx.Push() 
packet.addr = *nend copy(packet.data[:], buf) dataBuffer := &winrio.Buffer{ Id: bind.tx.id, Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets), Length: uint32(len(buf)), } addressBuffer := &winrio.Buffer{ Id: bind.tx.id, Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets), Length: uint32(unsafe.Sizeof(packet.addr)), } bind.mu.Lock() defer bind.mu.Unlock() return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error { nend, ok := endpoint.(*WinRingEndpoint) if !ok { return ErrWrongEndpointType } bind.mu.RLock() defer bind.mu.RUnlock() for _, buf := range bufs { switch nend.family { case windows.AF_INET: if bind.v4.blackhole { continue } if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil { return err } case windows.AF_INET6: if bind.v6.blackhole { continue } if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil { return err } } } return nil } func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { s.mu.Lock() defer s.mu.Unlock() sysconn, err := s.ipv4.SyscallConn() if err != nil { return err } err2 := sysconn.Control(func(fd uintptr) { err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex) }) if err2 != nil { return err2 } if err != nil { return err } s.blackhole4 = blackhole return nil } func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { s.mu.Lock() defer s.mu.Unlock() sysconn, err := s.ipv6.SyscallConn() if err != nil { return err } err2 := sysconn.Control(func(fd uintptr) { err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex) }) if err2 != nil { return err2 } if err != nil { return err } s.blackhole6 = blackhole return nil } func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { bind.mu.RLock() defer bind.mu.RUnlock() if bind.isOpen.Load() != 1 { return net.ErrClosed } err := 
bindSocketToInterface4(bind.v4.sock, interfaceIndex) if err != nil { return err } bind.v4.blackhole = blackhole return nil } func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { bind.mu.RLock() defer bind.mu.RUnlock() if bind.isOpen.Load() != 1 { return net.ErrClosed } err := bindSocketToInterface6(bind.v6.sock, interfaceIndex) if err != nil { return err } bind.v6.blackhole = blackhole return nil } func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error { const IP_UNICAST_IF = 31 /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ var bytes [4]byte binary.BigEndian.PutUint32(bytes[:], interfaceIndex) interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex)) if err != nil { return err } return nil } func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error { const IPV6_UNICAST_IF = 31 return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex)) } wireguard-go-0.0.20231211/conn/bindtest/000077500000000000000000000000001466306003400174125ustar00rootroot00000000000000wireguard-go-0.0.20231211/conn/bindtest/bindtest.go000066400000000000000000000062121466306003400215560ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package bindtest import ( "fmt" "math/rand" "net" "net/netip" "os" "golang.zx2c4.com/wireguard/conn" ) type ChannelBind struct { rx4, tx4 *chan []byte rx6, tx6 *chan []byte closeSignal chan bool source4, source6 ChannelEndpoint target4, target6 ChannelEndpoint } type ChannelEndpoint uint16 var ( _ conn.Bind = (*ChannelBind)(nil) _ conn.Endpoint = (*ChannelEndpoint)(nil) ) func NewChannelBinds() [2]conn.Bind { arx4 := make(chan []byte, 8192) brx4 := make(chan []byte, 8192) arx6 := make(chan []byte, 8192) brx6 := make(chan []byte, 8192) var binds [2]ChannelBind binds[0].rx4 = &arx4 binds[0].tx4 = &brx4 binds[1].rx4 = &brx4 binds[1].tx4 = &arx4 binds[0].rx6 = &arx6 binds[0].tx6 = &brx6 binds[1].rx6 = &brx6 binds[1].tx6 = &arx6 binds[0].target4 = ChannelEndpoint(1) binds[1].target4 = ChannelEndpoint(2) binds[0].target6 = ChannelEndpoint(3) binds[1].target6 = ChannelEndpoint(4) binds[0].source4 = binds[1].target4 binds[0].source6 = binds[1].target6 binds[1].source4 = binds[0].target4 binds[1].source6 = binds[0].target6 return [2]conn.Bind{&binds[0], &binds[1]} } func (c ChannelEndpoint) ClearSrc() {} func (c ChannelEndpoint) SrcToString() string { return "" } func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) } func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} } func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { c.closeSignal = make(chan bool) fns = append(fns, c.makeReceiveFunc(*c.rx4)) fns = append(fns, c.makeReceiveFunc(*c.rx6)) if rand.Uint32()&1 == 0 { return fns, uint16(c.source4), nil } else { return fns, uint16(c.source6), nil } } func (c *ChannelBind) Close() error { if c.closeSignal != nil { select { case <-c.closeSignal: default: close(c.closeSignal) } } return nil } func (c *ChannelBind) 
BatchSize() int { return 1 } func (c *ChannelBind) SetMark(mark uint32) error { return nil } func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { select { case <-c.closeSignal: return 0, net.ErrClosed case rx := <-ch: copied := copy(bufs[0], rx) sizes[0] = copied eps[0] = c.target6 return 1, nil } } } func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error { for _, b := range bufs { select { case <-c.closeSignal: return net.ErrClosed default: bc := make([]byte, len(b)) copy(bc, b) if ep.(ChannelEndpoint) == c.target4 { *c.tx4 <- bc } else if ep.(ChannelEndpoint) == c.target6 { *c.tx6 <- bc } else { return os.ErrInvalid } } } return nil } func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { addr, err := netip.ParseAddrPort(s) if err != nil { return nil, err } return ChannelEndpoint(addr.Port()), nil } wireguard-go-0.0.20231211/conn/boundif_android.go000066400000000000000000000011421466306003400212510ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { sysconn, err := s.ipv4.SyscallConn() if err != nil { return -1, err } err = sysconn.Control(func(f uintptr) { fd = int(f) }) if err != nil { return -1, err } return } func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) { sysconn, err := s.ipv6.SyscallConn() if err != nil { return -1, err } err = sysconn.Control(func(f uintptr) { fd = int(f) }) if err != nil { return -1, err } return } wireguard-go-0.0.20231211/conn/conn.go000066400000000000000000000104311466306003400170610ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ // Package conn implements WireGuard's network connections. 
package conn import ( "errors" "fmt" "net/netip" "reflect" "runtime" "strings" ) const ( IdealBatchSize = 128 // maximum number of packets handled per read and write ) // A ReceiveFunc receives at least one packet from the network and writes them // into packets. On a successful read it returns the number of elements of // sizes, packets, and endpoints that should be evaluated. Some elements of // sizes may be zero, and callers should ignore them. Callers must pass a sizes // and eps slice with a length greater than or equal to the length of packets. // These lengths must not exceed the length of the associated Bind.BatchSize(). type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error) // A Bind listens on a port for both IPv6 and IPv4 UDP traffic. // // A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface, // depending on the platform-specific implementation. type Bind interface { // Open puts the Bind into a listening state on a given port and reports the actual // port that it bound to. Passing zero results in a random selection. // fns is the set of functions that will be called to receive packets. Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error) // Close closes the Bind listener. // All fns returned by Open must return net.ErrClosed after a call to Close. Close() error // SetMark sets the mark for each packet sent through this Bind. // This mark is passed to the kernel as the socket option SO_MARK. SetMark(mark uint32) error // Send writes one or more packets in bufs to address ep. The length of // bufs must not exceed BatchSize(). Send(bufs [][]byte, ep Endpoint) error // ParseEndpoint creates a new endpoint from a string. ParseEndpoint(s string) (Endpoint, error) // BatchSize is the number of buffers expected to be passed to // the ReceiveFuncs, and the maximum expected to be passed to SendBatch. 
BatchSize() int } // BindSocketToInterface is implemented by Bind objects that support being // tied to a single network interface. Used by wireguard-windows. type BindSocketToInterface interface { BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error } // PeekLookAtSocketFd is implemented by Bind objects that support having their // file descriptor peeked at. Used by wireguard-android. type PeekLookAtSocketFd interface { PeekLookAtSocketFd4() (fd int, err error) PeekLookAtSocketFd6() (fd int, err error) } // An Endpoint maintains the source/destination caching for a peer. // // dst: the remote address of a peer ("endpoint" in uapi terminology) // src: the local address from which datagrams originate going to the peer type Endpoint interface { ClearSrc() // clears the source address SrcToString() string // returns the local source address (ip:port) DstToString() string // returns the destination address (ip:port) DstToBytes() []byte // used for mac2 cookie calculations DstIP() netip.Addr SrcIP() netip.Addr } var ( ErrBindAlreadyOpen = errors.New("bind is already open") ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type") ) func (fn ReceiveFunc) PrettyName() string { name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() // 0. cheese/taco.beansIPv6.func12.func21218-fm name = strings.TrimSuffix(name, "-fm") // 1. cheese/taco.beansIPv6.func12.func21218 if idx := strings.LastIndexByte(name, '/'); idx != -1 { name = name[idx+1:] // 2. taco.beansIPv6.func12.func21218 } for { var idx int for idx = len(name) - 1; idx >= 0; idx-- { if name[idx] < '0' || name[idx] > '9' { break } } if idx == len(name)-1 { break } const dotFunc = ".func" if !strings.HasSuffix(name[:idx+1], dotFunc) { break } name = name[:idx+1-len(dotFunc)] // 3. taco.beansIPv6.func12 // 4. taco.beansIPv6 } if idx := strings.LastIndexByte(name, '.'); idx != -1 { name = name[idx+1:] // 5. 
beansIPv6 } if name == "" { return fmt.Sprintf("%p", fn) } if strings.HasSuffix(name, "IPv4") { return "v4" } if strings.HasSuffix(name, "IPv6") { return "v6" } return name } wireguard-go-0.0.20231211/conn/conn_test.go000066400000000000000000000007511466306003400201240ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn import ( "testing" ) func TestPrettyName(t *testing.T) { var ( recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return } ) const want = "TestPrettyName" t.Run("ReceiveFunc.PrettyName", func(t *testing.T) { if got := recvFunc.PrettyName(); got != want { t.Errorf("PrettyName() = %v, want %v", got, want) } }) } wireguard-go-0.0.20231211/conn/controlfns.go000066400000000000000000000025461466306003400203230ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn import ( "net" "syscall" ) // UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is // the max supported by a default configuration of macOS. Some platforms will // silently clamp the value to other maximums, such as linux clamping to // net.core.{r,w}mem_max (see _linux.go for additional implementation that works // around this limitation) const socketBufferSize = 7 << 20 // controlFn is the callback function signature from net.ListenConfig.Control. // It is used to apply platform specific configuration to the socket prior to // bind. type controlFn func(network, address string, c syscall.RawConn) error // controlFns is a list of functions that are called from the listen config // that can apply socket options. var controlFns = []controlFn{} // listenConfig returns a net.ListenConfig that applies the controlFns to the // socket prior to bind. This is used to apply socket buffer sizing and packet // information OOB configuration for sticky sockets. 
func listenConfig() *net.ListenConfig { return &net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { for _, fn := range controlFns { if err := fn(network, address, c); err != nil { return err } } return nil }, } } wireguard-go-0.0.20231211/conn/controlfns_linux.go000066400000000000000000000040531466306003400215350ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn import ( "fmt" "runtime" "syscall" "golang.org/x/sys/unix" ) func init() { controlFns = append(controlFns, // Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by // using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to // fail silently - the result of failure is lower performance on very fast // links or high latency links. func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { // Set up to *mem_max _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize) _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize) // Set beyond *mem_max if CAP_NET_ADMIN _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize) _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize) }) }, // Enable receiving of the packet information (IP_PKTINFO for IPv4, // IPV6_PKTINFO for IPv6) that is used to implement sticky socket support. 
func(network, address string, c syscall.RawConn) error { var err error switch network { case "udp4": if runtime.GOOS != "android" { c.Control(func(fd uintptr) { err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1) }) } case "udp6": c.Control(func(fd uintptr) { if runtime.GOOS != "android" { err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1) if err != nil { return } } err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1) }) default: err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL) } return err }, // Attempt to enable UDP_GRO func(network, address string, c syscall.RawConn) error { c.Control(func(fd uintptr) { _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1) }) return nil }, ) } wireguard-go-0.0.20231211/conn/controlfns_unix.go000066400000000000000000000014231466306003400213570ustar00rootroot00000000000000//go:build !windows && !linux && !wasm /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn import ( "syscall" "golang.org/x/sys/unix" ) func init() { controlFns = append(controlFns, func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize) _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize) }) }, func(network, address string, c syscall.RawConn) error { var err error if network == "udp6" { c.Control(func(fd uintptr) { err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1) }) } return err }, ) } wireguard-go-0.0.20231211/conn/controlfns_windows.go000066400000000000000000000010401466306003400220610ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package conn import ( "syscall" "golang.org/x/sys/windows" ) func init() { controlFns = append(controlFns, func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize) _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize) }) }, ) } wireguard-go-0.0.20231211/conn/default.go000066400000000000000000000003001466306003400175420ustar00rootroot00000000000000//go:build !windows /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn func NewDefaultBind() Bind { return NewStdNetBind() } wireguard-go-0.0.20231211/conn/errors_default.go000066400000000000000000000003061466306003400211440ustar00rootroot00000000000000//go:build !linux /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn func errShouldDisableUDPGSO(err error) bool { return false } wireguard-go-0.0.20231211/conn/errors_linux.go000066400000000000000000000013651466306003400206650ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn import ( "errors" "os" "golang.org/x/sys/unix" ) func errShouldDisableUDPGSO(err error) bool { var serr *os.SyscallError if errors.As(err, &serr) { // EIO is returned by udp_send_skb() if the device driver does not have // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT. 
// See: // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 return serr.Err == unix.EIO } return false } wireguard-go-0.0.20231211/conn/features_default.go000066400000000000000000000003721466306003400214510ustar00rootroot00000000000000//go:build !linux // +build !linux /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn import "net" func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { return } wireguard-go-0.0.20231211/conn/features_linux.go000066400000000000000000000012111466306003400211550ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn import ( "net" "golang.org/x/sys/unix" ) func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { rc, err := conn.SyscallConn() if err != nil { return } err = rc.Control(func(fd uintptr) { _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) txOffload = errSyscall == nil opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO) rxOffload = errSyscall == nil && opt == 1 }) if err != nil { return false, false } return txOffload, rxOffload } wireguard-go-0.0.20231211/conn/gso_default.go000066400000000000000000000010241466306003400204160ustar00rootroot00000000000000//go:build !linux /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn // getGSOSize parses control for UDP_GRO and if found returns its GSO size data. func getGSOSize(control []byte) (int, error) { return 0, nil } // setGSOSize sets a UDP_SEGMENT in control based on gsoSize. 
func setGSOSize(control *[]byte, gsoSize uint16) { } // gsoControlSize returns the recommended buffer size for pooling sticky and UDP // offloading control data. const gsoControlSize = 0 wireguard-go-0.0.20231211/conn/gso_linux.go000066400000000000000000000032421466306003400201350ustar00rootroot00000000000000//go:build linux /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn import ( "fmt" "unsafe" "golang.org/x/sys/unix" ) const ( sizeOfGSOData = 2 ) // getGSOSize parses control for UDP_GRO and if found returns its GSO size data. func getGSOSize(control []byte) (int, error) { var ( hdr unix.Cmsghdr data []byte rem = control err error ) for len(rem) > unix.SizeofCmsghdr { hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) if err != nil { return 0, fmt.Errorf("error parsing socket control message: %w", err) } if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData { var gso uint16 copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData]) return int(gso), nil } } return 0, nil } // setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing // data in control untouched. func setGSOSize(control *[]byte, gsoSize uint16) { existingLen := len(*control) avail := cap(*control) - existingLen space := unix.CmsgSpace(sizeOfGSOData) if avail < space { return } *control = (*control)[:cap(*control)] gsoControl := (*control)[existingLen:] hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0])) hdr.Level = unix.SOL_UDP hdr.Type = unix.UDP_SEGMENT hdr.SetLen(unix.CmsgLen(sizeOfGSOData)) copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) *control = (*control)[:existingLen+space] } // gsoControlSize returns the recommended buffer size for pooling UDP // offloading control data. 
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData) wireguard-go-0.0.20231211/conn/mark_default.go000066400000000000000000000003401466306003400205600ustar00rootroot00000000000000//go:build !linux && !openbsd && !freebsd /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn func (s *StdNetBind) SetMark(mark uint32) error { return nil } wireguard-go-0.0.20231211/conn/mark_unix.go000066400000000000000000000022071466306003400201230ustar00rootroot00000000000000//go:build linux || openbsd || freebsd /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn import ( "runtime" "golang.org/x/sys/unix" ) var fwmarkIoctl int func init() { switch runtime.GOOS { case "linux", "android": fwmarkIoctl = 36 /* unix.SO_MARK */ case "freebsd": fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */ case "openbsd": fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */ } } func (s *StdNetBind) SetMark(mark uint32) error { var operr error if fwmarkIoctl == 0 { return nil } if s.ipv4 != nil { fd, err := s.ipv4.SyscallConn() if err != nil { return err } err = fd.Control(func(fd uintptr) { operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) }) if err == nil { err = operr } if err != nil { return err } } if s.ipv6 != nil { fd, err := s.ipv6.SyscallConn() if err != nil { return err } err = fd.Control(func(fd uintptr) { operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) }) if err == nil { err = operr } if err != nil { return err } } return nil } wireguard-go-0.0.20231211/conn/sticky_default.go000066400000000000000000000020521466306003400211360ustar00rootroot00000000000000//go:build !linux || android /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package conn import "net/netip" func (e *StdNetEndpoint) SrcIP() netip.Addr { return netip.Addr{} } func (e *StdNetEndpoint) SrcIfidx() int32 { return 0 } func (e *StdNetEndpoint) SrcToString() string { return "" } // TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets // {get,set}srcControl feature set, but use alternatively named flags and need // ports and require testing. // getSrcFromControl parses the control for PKTINFO and if found updates ep with // the source information found. func getSrcFromControl(control []byte, ep *StdNetEndpoint) { } // setSrcControl parses the control for PKTINFO and if found updates ep with // the source information found. func setSrcControl(control *[]byte, ep *StdNetEndpoint) { } // stickyControlSize returns the recommended buffer size for pooling sticky // offloading control data. const stickyControlSize = 0 const StdNetSupportsStickySockets = false wireguard-go-0.0.20231211/conn/sticky_linux.go000066400000000000000000000061031466306003400206520ustar00rootroot00000000000000//go:build linux && !android /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn import ( "net/netip" "unsafe" "golang.org/x/sys/unix" ) func (e *StdNetEndpoint) SrcIP() netip.Addr { switch len(e.src) { case unix.CmsgSpace(unix.SizeofInet4Pktinfo): info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) return netip.AddrFrom4(info.Spec_dst) case unix.CmsgSpace(unix.SizeofInet6Pktinfo): info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) // TODO: set zone. in order to do so we need to check if the address is // link local, and if it is perform a syscall to turn the ifindex into a // zone string because netip uses string zones. 
return netip.AddrFrom16(info.Addr) } return netip.Addr{} } func (e *StdNetEndpoint) SrcIfidx() int32 { switch len(e.src) { case unix.CmsgSpace(unix.SizeofInet4Pktinfo): info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) return info.Ifindex case unix.CmsgSpace(unix.SizeofInet6Pktinfo): info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) return int32(info.Ifindex) } return 0 } func (e *StdNetEndpoint) SrcToString() string { return e.SrcIP().String() } // getSrcFromControl parses the control for PKTINFO and if found updates ep with // the source information found. func getSrcFromControl(control []byte, ep *StdNetEndpoint) { ep.ClearSrc() var ( hdr unix.Cmsghdr data []byte rem []byte = control err error ) for len(rem) > unix.SizeofCmsghdr { hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) if err != nil { return } if hdr.Level == unix.IPPROTO_IP && hdr.Type == unix.IP_PKTINFO { if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) { ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) } ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) copy(ep.src, hdrBuf) copy(ep.src[unix.CmsgLen(0):], data) return } if hdr.Level == unix.IPPROTO_IPV6 && hdr.Type == unix.IPV6_PKTINFO { if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) { ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) } ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) copy(ep.src, hdrBuf) copy(ep.src[unix.CmsgLen(0):], data) return } } } // setSrcControl sets an IP{V6}_PKTINFO in control based on the source address // and source ifindex found in ep. control's len will be set to 0 in the event // that ep is a default value. 
func setSrcControl(control *[]byte, ep *StdNetEndpoint) { if cap(*control) < len(ep.src) { return } *control = (*control)[:0] *control = append(*control, ep.src...) } // stickyControlSize returns the recommended buffer size for pooling sticky // offloading control data. var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) const StdNetSupportsStickySockets = true wireguard-go-0.0.20231211/conn/sticky_linux_test.go000066400000000000000000000162721466306003400217210ustar00rootroot00000000000000//go:build linux && !android /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn import ( "context" "net" "net/netip" "runtime" "testing" "unsafe" "golang.org/x/sys/unix" ) func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) { var buf []byte if addr.Is4() { buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) hdr := unix.Cmsghdr{ Level: unix.IPPROTO_IP, Type: unix.IP_PKTINFO, } hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) info := unix.Inet4Pktinfo{ Ifindex: ifidx, Spec_dst: addr.As4(), } copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo)) } else { buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) hdr := unix.Cmsghdr{ Level: unix.IPPROTO_IPV6, Type: unix.IPV6_PKTINFO, } hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) info := unix.Inet6Pktinfo{ Ifindex: uint32(ifidx), Addr: addr.As16(), } copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo)) } ep.src = buf } func Test_setSrcControl(t *testing.T) { t.Run("IPv4", func(t *testing.T) { ep := &StdNetEndpoint{ AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), } setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5) control := make([]byte, stickyControlSize) setSrcControl(&control, ep) hdr 
:= (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) if hdr.Level != unix.IPPROTO_IP { t.Errorf("unexpected level: %d", hdr.Level) } if hdr.Type != unix.IP_PKTINFO { t.Errorf("unexpected type: %d", hdr.Type) } if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) { t.Errorf("unexpected length: %d", hdr.Len) } info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 { t.Errorf("unexpected address: %v", info.Spec_dst) } if info.Ifindex != 5 { t.Errorf("unexpected ifindex: %d", info.Ifindex) } }) t.Run("IPv6", func(t *testing.T) { ep := &StdNetEndpoint{ AddrPort: netip.MustParseAddrPort("[::1]:1234"), } setSrc(ep, netip.MustParseAddr("::1"), 5) control := make([]byte, stickyControlSize) setSrcControl(&control, ep) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) if hdr.Level != unix.IPPROTO_IPV6 { t.Errorf("unexpected level: %d", hdr.Level) } if hdr.Type != unix.IPV6_PKTINFO { t.Errorf("unexpected type: %d", hdr.Type) } if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) { t.Errorf("unexpected length: %d", hdr.Len) } info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) if info.Addr != ep.SrcIP().As16() { t.Errorf("unexpected address: %v", info.Addr) } if info.Ifindex != 5 { t.Errorf("unexpected ifindex: %d", info.Ifindex) } }) t.Run("ClearOnNoSrc", func(t *testing.T) { control := make([]byte, stickyControlSize) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = 1 hdr.Type = 2 hdr.Len = 3 setSrcControl(&control, &StdNetEndpoint{}) if len(control) != 0 { t.Errorf("unexpected control: %v", control) } }) } func Test_getSrcFromControl(t *testing.T) { t.Run("IPv4", func(t *testing.T) { control := make([]byte, stickyControlSize) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IP hdr.Type = unix.IP_PKTINFO 
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) info.Spec_dst = [4]byte{127, 0, 0, 1} info.Ifindex = 5 ep := &StdNetEndpoint{} getSrcFromControl(control, ep) if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { t.Errorf("unexpected address: %v", ep.SrcIP()) } if ep.SrcIfidx() != 5 { t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) t.Run("IPv6", func(t *testing.T) { control := make([]byte, stickyControlSize) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IPV6 hdr.Type = unix.IPV6_PKTINFO hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} info.Ifindex = 5 ep := &StdNetEndpoint{} getSrcFromControl(control, ep) if ep.SrcIP() != netip.MustParseAddr("::1") { t.Errorf("unexpected address: %v", ep.SrcIP()) } if ep.SrcIfidx() != 5 { t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) t.Run("ClearOnEmpty", func(t *testing.T) { var control []byte ep := &StdNetEndpoint{} setSrc(ep, netip.MustParseAddr("::1"), 5) getSrcFromControl(control, ep) if ep.SrcIP().IsValid() { t.Errorf("unexpected address: %v", ep.SrcIP()) } if ep.SrcIfidx() != 0 { t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) t.Run("Multiple", func(t *testing.T) { zeroControl := make([]byte, unix.CmsgSpace(0)) zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0])) zeroHdr.SetLen(unix.CmsgLen(0)) control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IP hdr.Type = unix.IP_PKTINFO hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) info.Spec_dst = [4]byte{127, 0, 0, 1} info.Ifindex = 5 combined := make([]byte, 0) combined = append(combined, zeroControl...) 
combined = append(combined, control...) ep := &StdNetEndpoint{} getSrcFromControl(combined, ep) if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { t.Errorf("unexpected address: %v", ep.SrcIP()) } if ep.SrcIfidx() != 5 { t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) } func Test_listenConfig(t *testing.T) { t.Run("IPv4", func(t *testing.T) { conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0") if err != nil { t.Fatal(err) } defer conn.Close() sc, err := conn.(*net.UDPConn).SyscallConn() if err != nil { t.Fatal(err) } if runtime.GOOS == "linux" { var i int sc.Control(func(fd uintptr) { i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO) }) if err != nil { t.Fatal(err) } if i != 1 { t.Error("IP_PKTINFO not set!") } } else { t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) } }) t.Run("IPv6", func(t *testing.T) { conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0") if err != nil { t.Fatal(err) } sc, err := conn.(*net.UDPConn).SyscallConn() if err != nil { t.Fatal(err) } if runtime.GOOS == "linux" { var i int sc.Control(func(fd uintptr) { i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO) }) if err != nil { t.Fatal(err) } if i != 1 { t.Error("IPV6_PKTINFO not set!") } } else { t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) } }) } wireguard-go-0.0.20231211/conn/winrio/000077500000000000000000000000001466306003400171055ustar00rootroot00000000000000wireguard-go-0.0.20231211/conn/winrio/rio_windows.go000066400000000000000000000166731466306003400220140ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package winrio import ( "log" "sync" "syscall" "unsafe" "golang.org/x/sys/windows" ) const ( MsgDontNotify = 1 MsgDefer = 2 MsgWaitAll = 4 MsgCommitOnly = 8 MaxCqSize = 0x8000000 invalidBufferId = 0xFFFFFFFF invalidCq = 0 invalidRq = 0 corruptCq = 0xFFFFFFFF ) var extensionFunctionTable struct { cbSize uint32 rioReceive uintptr rioReceiveEx uintptr rioSend uintptr rioSendEx uintptr rioCloseCompletionQueue uintptr rioCreateCompletionQueue uintptr rioCreateRequestQueue uintptr rioDequeueCompletion uintptr rioDeregisterBuffer uintptr rioNotify uintptr rioRegisterBuffer uintptr rioResizeCompletionQueue uintptr rioResizeRequestQueue uintptr } type Cq uintptr type Rq uintptr type BufferId uintptr type Buffer struct { Id BufferId Offset uint32 Length uint32 } type Result struct { Status int32 BytesTransferred uint32 SocketContext uint64 RequestContext uint64 } type notificationCompletionType uint32 const ( eventCompletion notificationCompletionType = 1 iocpCompletion notificationCompletionType = 2 ) type eventNotificationCompletion struct { completionType notificationCompletionType event windows.Handle notifyReset uint32 } type iocpNotificationCompletion struct { completionType notificationCompletionType iocp windows.Handle key uintptr overlapped *windows.Overlapped } var ( initialized sync.Once available bool ) func Initialize() bool { initialized.Do(func() { var ( err error socket windows.Handle cq Cq ) defer func() { if err == nil { return } if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 { return } log.Printf("Registered I/O is unavailable: %v", err) }() socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP) if err != nil { return } defer windows.CloseHandle(socket) WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}} const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024 ob := uint32(0) err = windows.WSAIoctl(socket, 
SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER, (*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)), (*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)), &ob, nil, 0) if err != nil { return } // While we should be able to stop here, after getting the function pointers, some anti-virus actually causes // failures in RIOCreateRequestQueue, so keep going to be certain this is supported. var iocp windows.Handle iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) if err != nil { return } defer windows.CloseHandle(iocp) var overlapped windows.Overlapped cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped) if err != nil { return } defer CloseCompletionQueue(cq) _, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0) if err != nil { return } available = true }) return available } func Socket(af, typ, proto int32) (windows.Handle, error) { return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO) } func CloseCompletionQueue(cq Cq) { _, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0) } func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) { notificationCompletion := &eventNotificationCompletion{ completionType: eventCompletion, event: event, } if notifyReset { notificationCompletion.notifyReset = 1 } ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) if ret == invalidCq { return 0, err } return Cq(ret), nil } func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) { notificationCompletion := &iocpNotificationCompletion{ completionType: iocpCompletion, iocp: iocp, key: key, overlapped: overlapped, } ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, 
uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) if ret == invalidCq { return 0, err } return Cq(ret), nil } func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) { ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0) if ret == invalidCq { return 0, err } return Cq(ret), nil } func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) { ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0) if ret == invalidRq { return 0, err } return Rq(ret), nil } func DequeueCompletion(cq Cq, results []Result) uint32 { var array uintptr if len(results) > 0 { array = uintptr(unsafe.Pointer(&results[0])) } ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results))) if ret == corruptCq { panic("cq is corrupt") } return uint32(ret) } func DeregisterBuffer(id BufferId) { _, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0) } func RegisterBuffer(buffer []byte) (BufferId, error) { var buf unsafe.Pointer if len(buffer) > 0 { buf = unsafe.Pointer(&buffer[0]) } return RegisterPointer(buf, uint32(len(buffer))) } func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) { ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0) if ret == invalidBufferId { return 0, err } return BufferId(ret), nil } func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error { ret, _, err := 
syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext) if ret == 0 { return err } return nil } func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error { ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext) if ret == 0 { return err } return nil } func Notify(cq Cq) error { ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0) if ret != 0 { return windows.Errno(ret) } return nil } wireguard-go-0.0.20231211/device/000077500000000000000000000000001466306003400161005ustar00rootroot00000000000000wireguard-go-0.0.20231211/device/allowedips.go000066400000000000000000000151311466306003400205730ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package device import ( "container/list" "encoding/binary" "errors" "math/bits" "net" "net/netip" "sync" "unsafe" ) type parentIndirection struct { parentBit **trieEntry parentBitType uint8 } type trieEntry struct { peer *Peer child [2]*trieEntry parent parentIndirection cidr uint8 bitAtByte uint8 bitAtShift uint8 bits []byte perPeerElem *list.Element } func commonBits(ip1, ip2 []byte) uint8 { size := len(ip1) if size == net.IPv4len { a := binary.BigEndian.Uint32(ip1) b := binary.BigEndian.Uint32(ip2) x := a ^ b return uint8(bits.LeadingZeros32(x)) } else if size == net.IPv6len { a := binary.BigEndian.Uint64(ip1) b := binary.BigEndian.Uint64(ip2) x := a ^ b if x != 0 { return uint8(bits.LeadingZeros64(x)) } a = binary.BigEndian.Uint64(ip1[8:]) b = binary.BigEndian.Uint64(ip2[8:]) x = a ^ b return 64 + uint8(bits.LeadingZeros64(x)) } else { panic("Wrong size bit string") } } func (node *trieEntry) addToPeerEntries() { node.perPeerElem = node.peer.trieEntries.PushBack(node) } func (node *trieEntry) removeFromPeerEntries() { if node.perPeerElem != nil { node.peer.trieEntries.Remove(node.perPeerElem) node.perPeerElem = nil } } func (node *trieEntry) choose(ip []byte) byte { return (ip[node.bitAtByte] >> node.bitAtShift) & 1 } func (node *trieEntry) maskSelf() { mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) for i := 0; i < len(mask); i++ { node.bits[i] &= mask[i] } } func (node *trieEntry) zeroizePointers() { // Make the garbage collector's life slightly easier node.peer = nil node.child[0] = nil node.child[1] = nil node.parent.parentBit = nil } func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) { for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { parent = node if parent.cidr == cidr { exact = true return } bit := node.choose(ip) node = node.child[bit] } return } func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { if *trie.parentBit == nil { node := &trieEntry{ peer: 
peer, parent: trie, bits: ip, cidr: cidr, bitAtByte: cidr / 8, bitAtShift: 7 - (cidr % 8), } node.maskSelf() node.addToPeerEntries() *trie.parentBit = node return } node, exact := (*trie.parentBit).nodePlacement(ip, cidr) if exact { node.removeFromPeerEntries() node.peer = peer node.addToPeerEntries() return } newNode := &trieEntry{ peer: peer, bits: ip, cidr: cidr, bitAtByte: cidr / 8, bitAtShift: 7 - (cidr % 8), } newNode.maskSelf() newNode.addToPeerEntries() var down *trieEntry if node == nil { down = *trie.parentBit } else { bit := node.choose(ip) down = node.child[bit] if down == nil { newNode.parent = parentIndirection{&node.child[bit], bit} node.child[bit] = newNode return } } common := commonBits(down.bits, ip) if common < cidr { cidr = common } parent := node if newNode.cidr == cidr { bit := newNode.choose(down.bits) down.parent = parentIndirection{&newNode.child[bit], bit} newNode.child[bit] = down if parent == nil { newNode.parent = trie *trie.parentBit = newNode } else { bit := parent.choose(newNode.bits) newNode.parent = parentIndirection{&parent.child[bit], bit} parent.child[bit] = newNode } return } node = &trieEntry{ bits: append([]byte{}, newNode.bits...), cidr: cidr, bitAtByte: cidr / 8, bitAtShift: 7 - (cidr % 8), } node.maskSelf() bit := node.choose(down.bits) down.parent = parentIndirection{&node.child[bit], bit} node.child[bit] = down bit = node.choose(newNode.bits) newNode.parent = parentIndirection{&node.child[bit], bit} node.child[bit] = newNode if parent == nil { node.parent = trie *trie.parentBit = node } else { bit := parent.choose(node.bits) node.parent = parentIndirection{&parent.child[bit], bit} parent.child[bit] = node } } func (node *trieEntry) lookup(ip []byte) *Peer { var found *Peer size := uint8(len(ip)) for node != nil && commonBits(node.bits, ip) >= node.cidr { if node.peer != nil { found = node.peer } if node.bitAtByte == size { break } bit := node.choose(ip) node = node.child[bit] } return found } type AllowedIPs struct { 
IPv4 *trieEntry IPv6 *trieEntry mutex sync.RWMutex } func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) { table.mutex.RLock() defer table.mutex.RUnlock() for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { node := elem.Value.(*trieEntry) a, _ := netip.AddrFromSlice(node.bits) if !cb(netip.PrefixFrom(a, int(node.cidr))) { return } } } func (table *AllowedIPs) RemoveByPeer(peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() var next *list.Element for elem := peer.trieEntries.Front(); elem != nil; elem = next { next = elem.Next() node := elem.Value.(*trieEntry) node.removeFromPeerEntries() node.peer = nil if node.child[0] != nil && node.child[1] != nil { continue } bit := 0 if node.child[0] == nil { bit = 1 } child := node.child[bit] if child != nil { child.parent = node.parent } *node.parent.parentBit = child if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { node.zeroizePointers() continue } parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) if parent.peer != nil { node.zeroizePointers() continue } child = parent.child[node.parent.parentBitType^1] if child != nil { child.parent = parent.parent } *parent.parent.parentBit = child node.zeroizePointers() parent.zeroizePointers() } } func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() if prefix.Addr().Is6() { ip := prefix.Addr().As16() parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) } else if prefix.Addr().Is4() { ip := prefix.Addr().As4() parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) } else { panic(errors.New("inserting unknown address type")) } } func (table *AllowedIPs) Lookup(ip []byte) *Peer { table.mutex.RLock() defer table.mutex.RUnlock() switch len(ip) { case net.IPv6len: 
return table.IPv6.lookup(ip) case net.IPv4len: return table.IPv4.lookup(ip) default: panic(errors.New("looking up unknown address type")) } } wireguard-go-0.0.20231211/device/allowedips_rand_test.go000066400000000000000000000055761466306003400226520ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "math/rand" "net" "net/netip" "sort" "testing" ) const ( NumberOfPeers = 100 NumberOfPeerRemovals = 4 NumberOfAddresses = 250 NumberOfTests = 10000 ) type SlowNode struct { peer *Peer cidr uint8 bits []byte } type SlowRouter []*SlowNode func (r SlowRouter) Len() int { return len(r) } func (r SlowRouter) Less(i, j int) bool { return r[i].cidr > r[j].cidr } func (r SlowRouter) Swap(i, j int) { r[i], r[j] = r[j], r[i] } func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { for _, t := range r { if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { t.peer = peer t.bits = addr return r } } r = append(r, &SlowNode{ cidr: cidr, bits: addr, peer: peer, }) sort.Sort(r) return r } func (r SlowRouter) Lookup(addr []byte) *Peer { for _, t := range r { common := commonBits(t.bits, addr) if common >= t.cidr { return t.peer } } return nil } func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter { n := 0 for _, x := range r { if x.peer != peer { r[n] = x n++ } } return r[:n] } func TestTrieRandom(t *testing.T) { var slow4, slow6 SlowRouter var peers []*Peer var allowedIPs AllowedIPs rand.Seed(1) for n := 0; n < NumberOfPeers; n++ { peers = append(peers, &Peer{}) } for n := 0; n < NumberOfAddresses; n++ { var addr4 [4]byte rand.Read(addr4[:]) cidr := uint8(rand.Intn(32) + 1) index := rand.Intn(NumberOfPeers) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index]) slow4 = slow4.Insert(addr4[:], cidr, peers[index]) var addr6 [16]byte rand.Read(addr6[:]) cidr = uint8(rand.Intn(128) + 1) index = rand.Intn(NumberOfPeers) 
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index]) slow6 = slow6.Insert(addr6[:], cidr, peers[index]) } var p int for p = 0; ; p++ { for n := 0; n < NumberOfTests; n++ { var addr4 [4]byte rand.Read(addr4[:]) peer1 := slow4.Lookup(addr4[:]) peer2 := allowedIPs.Lookup(addr4[:]) if peer1 != peer2 { t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2) } var addr6 [16]byte rand.Read(addr6[:]) peer1 = slow6.Lookup(addr6[:]) peer2 = allowedIPs.Lookup(addr6[:]) if peer1 != peer2 { t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2) } } if p >= len(peers) || p >= NumberOfPeerRemovals { break } allowedIPs.RemoveByPeer(peers[p]) slow4 = slow4.RemoveByPeer(peers[p]) slow6 = slow6.RemoveByPeer(peers[p]) } for ; p < len(peers); p++ { allowedIPs.RemoveByPeer(peers[p]) } if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { t.Error("Failed to remove all nodes from trie by peer") } } wireguard-go-0.0.20231211/device/allowedips_test.go000066400000000000000000000140361466306003400216350ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package device import ( "math/rand" "net" "net/netip" "testing" ) type testPairCommonBits struct { s1 []byte s2 []byte match uint8 } func TestCommonBits(t *testing.T) { tests := []testPairCommonBits{ {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31}, {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15}, {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0}, } for _, p := range tests { v := commonBits(p.s1, p.s2) if v != p.match { t.Error( "For slice", p.s1, p.s2, "expected match", p.match, ",but got", v, ) } } } func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { var trie *trieEntry var peers []*Peer root := parentIndirection{&trie, 2} rand.Seed(1) const AddressLength = 4 for n := 0; n < peerNumber; n++ { peers = append(peers, &Peer{}) } for n := 0; n < addressNumber; n++ { var addr [AddressLength]byte rand.Read(addr[:]) cidr := uint8(rand.Uint32() % (AddressLength * 8)) index := rand.Int() % peerNumber root.insert(addr[:], cidr, peers[index]) } for n := 0; n < b.N; n++ { var addr [AddressLength]byte rand.Read(addr[:]) trie.lookup(addr[:]) } } func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) { benchmarkTrie(100, 1000, net.IPv4len, b) } func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) { benchmarkTrie(10, 10, net.IPv4len, b) } func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) { benchmarkTrie(100, 1000, net.IPv6len, b) } func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) { benchmarkTrie(10, 10, net.IPv6len, b) } /* Test ported from kernel implementation: * selftest/allowedips.h */ func TestTrieIPv4(t *testing.T) { a := &Peer{} b := &Peer{} c := &Peer{} d := &Peer{} e := &Peer{} g := &Peer{} h := &Peer{} var allowedIPs AllowedIPs insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { 
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d byte) { p := allowedIPs.Lookup([]byte{a, b, c, d}) if p != peer { t.Error("Assert EQ failed") } } assertNEQ := func(peer *Peer, a, b, c, d byte) { p := allowedIPs.Lookup([]byte{a, b, c, d}) if p == peer { t.Error("Assert NEQ failed") } } insert(a, 192, 168, 4, 0, 24) insert(b, 192, 168, 4, 4, 32) insert(c, 192, 168, 0, 0, 16) insert(d, 192, 95, 5, 64, 27) insert(c, 192, 95, 5, 65, 27) insert(e, 0, 0, 0, 0, 0) insert(g, 64, 15, 112, 0, 20) insert(h, 64, 15, 123, 211, 25) insert(a, 10, 0, 0, 0, 25) insert(b, 10, 0, 0, 128, 25) insert(a, 10, 1, 0, 0, 30) insert(b, 10, 1, 0, 4, 30) insert(c, 10, 1, 0, 8, 29) insert(d, 10, 1, 0, 16, 29) assertEQ(a, 192, 168, 4, 20) assertEQ(a, 192, 168, 4, 0) assertEQ(b, 192, 168, 4, 4) assertEQ(c, 192, 168, 200, 182) assertEQ(c, 192, 95, 5, 68) assertEQ(e, 192, 95, 5, 96) assertEQ(g, 64, 15, 116, 26) assertEQ(g, 64, 15, 127, 3) insert(a, 1, 0, 0, 0, 32) insert(a, 64, 0, 0, 0, 32) insert(a, 128, 0, 0, 0, 32) insert(a, 192, 0, 0, 0, 32) insert(a, 255, 0, 0, 0, 32) assertEQ(a, 1, 0, 0, 0) assertEQ(a, 64, 0, 0, 0) assertEQ(a, 128, 0, 0, 0) assertEQ(a, 192, 0, 0, 0) assertEQ(a, 255, 0, 0, 0) allowedIPs.RemoveByPeer(a) assertNEQ(a, 1, 0, 0, 0) assertNEQ(a, 64, 0, 0, 0) assertNEQ(a, 128, 0, 0, 0) assertNEQ(a, 192, 0, 0, 0) assertNEQ(a, 255, 0, 0, 0) allowedIPs.RemoveByPeer(a) allowedIPs.RemoveByPeer(b) allowedIPs.RemoveByPeer(c) allowedIPs.RemoveByPeer(d) allowedIPs.RemoveByPeer(e) allowedIPs.RemoveByPeer(g) allowedIPs.RemoveByPeer(h) if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { t.Error("Expected removing all the peers to empty trie, but it did not") } insert(a, 192, 168, 0, 0, 16) insert(a, 192, 168, 0, 0, 24) allowedIPs.RemoveByPeer(a) assertNEQ(a, 192, 168, 0, 1) } /* Test ported from kernel implementation: * selftest/allowedips.h */ func TestTrieIPv6(t *testing.T) { a := &Peer{} b := &Peer{} c := 
&Peer{} d := &Peer{} e := &Peer{} f := &Peer{} g := &Peer{} h := &Peer{} var allowedIPs AllowedIPs expand := func(a uint32) []byte { var out [4]byte out[0] = byte(a >> 24 & 0xff) out[1] = byte(a >> 16 & 0xff) out[2] = byte(a >> 8 & 0xff) out[3] = byte(a & 0xff) return out[:] } insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) { var addr []byte addr = append(addr, expand(a)...) addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d uint32) { var addr []byte addr = append(addr, expand(a)...) addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) p := allowedIPs.Lookup(addr) if p != peer { t.Error("Assert EQ failed") } } insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) insert(c, 0x26075300, 0x60006b00, 0, 0, 64) insert(e, 0, 0, 0, 0, 0) insert(f, 0, 0, 0, 0, 0) insert(g, 0x24046800, 0, 0, 0, 32) insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64) insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128) insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543) assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee) assertEQ(f, 0x26075300, 0x60006b01, 0, 0) assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006) assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678) assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678) assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678) assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678) assertEQ(h, 0x24046800, 0x40040800, 0, 0) assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) } wireguard-go-0.0.20231211/device/bind_test.go000066400000000000000000000020121466306003400203750ustar00rootroot00000000000000/* SPDX-License-Identifier: 
MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "errors" "golang.zx2c4.com/wireguard/conn" ) type DummyDatagram struct { msg []byte endpoint conn.Endpoint } type DummyBind struct { in6 chan DummyDatagram in4 chan DummyDatagram closed bool } func (b *DummyBind) SetMark(v uint32) error { return nil } func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) { datagram, ok := <-b.in6 if !ok { return 0, nil, errors.New("closed") } copy(buf, datagram.msg) return len(datagram.msg), datagram.endpoint, nil } func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) { datagram, ok := <-b.in4 if !ok { return 0, nil, errors.New("closed") } copy(buf, datagram.msg) return len(datagram.msg), datagram.endpoint, nil } func (b *DummyBind) Close() error { close(b.in6) close(b.in4) b.closed = true return nil } func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error { return nil } wireguard-go-0.0.20231211/device/channels.go000066400000000000000000000071511466306003400202260ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "runtime" "sync" ) // An outboundQueue is a channel of QueueOutboundElements awaiting encryption. // An outboundQueue is ref-counted using its wg field. // An outboundQueue created with newOutboundQueue has one reference. // Every additional writer must call wg.Add(1). // Every completed writer must call wg.Done(). // When no further writers will be added, // call wg.Done to remove the initial reference. // When the refcount hits 0, the queue's channel is closed. 
type outboundQueue struct { c chan *QueueOutboundElementsContainer wg sync.WaitGroup } func newOutboundQueue() *outboundQueue { q := &outboundQueue{ c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), } q.wg.Add(1) go func() { q.wg.Wait() close(q.c) }() return q } // A inboundQueue is similar to an outboundQueue; see those docs. type inboundQueue struct { c chan *QueueInboundElementsContainer wg sync.WaitGroup } func newInboundQueue() *inboundQueue { q := &inboundQueue{ c: make(chan *QueueInboundElementsContainer, QueueInboundSize), } q.wg.Add(1) go func() { q.wg.Wait() close(q.c) }() return q } // A handshakeQueue is similar to an outboundQueue; see those docs. type handshakeQueue struct { c chan QueueHandshakeElement wg sync.WaitGroup } func newHandshakeQueue() *handshakeQueue { q := &handshakeQueue{ c: make(chan QueueHandshakeElement, QueueHandshakeSize), } q.wg.Add(1) go func() { q.wg.Wait() close(q.c) }() return q } type autodrainingInboundQueue struct { c chan *QueueInboundElementsContainer } // newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd. // It is useful in cases in which is it hard to manage the lifetime of the channel. // The returned channel must not be closed. Senders should signal shutdown using // some other means, such as sending a sentinel nil values. 
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { q := &autodrainingInboundQueue{ c: make(chan *QueueInboundElementsContainer, QueueInboundSize), } runtime.SetFinalizer(q, device.flushInboundQueue) return q } func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { for { select { case elemsContainer := <-q.c: elemsContainer.Lock() for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } device.PutInboundElementsContainer(elemsContainer) default: return } } } type autodrainingOutboundQueue struct { c chan *QueueOutboundElementsContainer } // newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd. // It is useful in cases in which is it hard to manage the lifetime of the channel. // The returned channel must not be closed. Senders should signal shutdown using // some other means, such as sending a sentinel nil values. // All sends to the channel must be best-effort, because there may be no receivers. func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { q := &autodrainingOutboundQueue{ c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), } runtime.SetFinalizer(q, device.flushOutboundQueue) return q } func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { for { select { case elemsContainer := <-q.c: elemsContainer.Lock() for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsContainer) default: return } } } wireguard-go-0.0.20231211/device/constants.go000066400000000000000000000023211466306003400204410ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package device import ( "time" ) /* Specification constants */ const ( RekeyAfterMessages = (1 << 60) RejectAfterMessages = (1 << 64) - (1 << 13) - 1 RekeyAfterTime = time.Second * 120 RekeyAttemptTime = time.Second * 90 RekeyTimeout = time.Second * 5 MaxTimerHandshakes = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */ RekeyTimeoutJitterMaxMs = 334 RejectAfterTime = time.Second * 180 KeepaliveTimeout = time.Second * 10 CookieRefreshTime = time.Second * 120 HandshakeInitationRate = time.Second / 50 PaddingMultiple = 16 ) const ( MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) MaxMessageSize = MaxSegmentSize // maximum size of transport message MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content ) /* Implementation constants */ const ( UnderLoadAfterTime = time.Second // how long does the device remain under load after detected MaxPeers = 1 << 16 // maximum number of configured peers ) wireguard-go-0.0.20231211/device/cookie.go000066400000000000000000000106671466306003400177120ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package device import ( "crypto/hmac" "crypto/rand" "sync" "time" "golang.org/x/crypto/blake2s" "golang.org/x/crypto/chacha20poly1305" ) type CookieChecker struct { sync.RWMutex mac1 struct { key [blake2s.Size]byte } mac2 struct { secret [blake2s.Size]byte secretSet time.Time encryptionKey [chacha20poly1305.KeySize]byte } } type CookieGenerator struct { sync.RWMutex mac1 struct { key [blake2s.Size]byte } mac2 struct { cookie [blake2s.Size128]byte cookieSet time.Time hasLastMAC1 bool lastMAC1 [blake2s.Size128]byte encryptionKey [chacha20poly1305.KeySize]byte } } func (st *CookieChecker) Init(pk NoisePublicKey) { st.Lock() defer st.Unlock() // mac1 state func() { hash, _ := blake2s.New256(nil) hash.Write([]byte(WGLabelMAC1)) hash.Write(pk[:]) hash.Sum(st.mac1.key[:0]) }() // mac2 state func() { hash, _ := blake2s.New256(nil) hash.Write([]byte(WGLabelCookie)) hash.Write(pk[:]) hash.Sum(st.mac2.encryptionKey[:0]) }() st.mac2.secretSet = time.Time{} } func (st *CookieChecker) CheckMAC1(msg []byte) bool { st.RLock() defer st.RUnlock() size := len(msg) smac2 := size - blake2s.Size128 smac1 := smac2 - blake2s.Size128 var mac1 [blake2s.Size128]byte mac, _ := blake2s.New128(st.mac1.key[:]) mac.Write(msg[:smac1]) mac.Sum(mac1[:0]) return hmac.Equal(mac1[:], msg[smac1:smac2]) } func (st *CookieChecker) CheckMAC2(msg, src []byte) bool { st.RLock() defer st.RUnlock() if time.Since(st.mac2.secretSet) > CookieRefreshTime { return false } // derive cookie key var cookie [blake2s.Size128]byte func() { mac, _ := blake2s.New128(st.mac2.secret[:]) mac.Write(src) mac.Sum(cookie[:0]) }() // calculate mac of packet (including mac1) smac2 := len(msg) - blake2s.Size128 var mac2 [blake2s.Size128]byte func() { mac, _ := blake2s.New128(cookie[:]) mac.Write(msg[:smac2]) mac.Sum(mac2[:0]) }() return hmac.Equal(mac2[:], msg[smac2:]) } func (st *CookieChecker) CreateReply( msg []byte, recv uint32, src []byte, ) (*MessageCookieReply, error) { st.RLock() // refresh cookie secret if 
time.Since(st.mac2.secretSet) > CookieRefreshTime { st.RUnlock() st.Lock() _, err := rand.Read(st.mac2.secret[:]) if err != nil { st.Unlock() return nil, err } st.mac2.secretSet = time.Now() st.Unlock() st.RLock() } // derive cookie var cookie [blake2s.Size128]byte func() { mac, _ := blake2s.New128(st.mac2.secret[:]) mac.Write(src) mac.Sum(cookie[:0]) }() // encrypt cookie size := len(msg) smac2 := size - blake2s.Size128 smac1 := smac2 - blake2s.Size128 reply := new(MessageCookieReply) reply.Type = MessageCookieReplyType reply.Receiver = recv _, err := rand.Read(reply.Nonce[:]) if err != nil { st.RUnlock() return nil, err } xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2]) st.RUnlock() return reply, nil } func (st *CookieGenerator) Init(pk NoisePublicKey) { st.Lock() defer st.Unlock() func() { hash, _ := blake2s.New256(nil) hash.Write([]byte(WGLabelMAC1)) hash.Write(pk[:]) hash.Sum(st.mac1.key[:0]) }() func() { hash, _ := blake2s.New256(nil) hash.Write([]byte(WGLabelCookie)) hash.Write(pk[:]) hash.Sum(st.mac2.encryptionKey[:0]) }() st.mac2.cookieSet = time.Time{} } func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { st.Lock() defer st.Unlock() if !st.mac2.hasLastMAC1 { return false } var cookie [blake2s.Size128]byte xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:]) if err != nil { return false } st.mac2.cookieSet = time.Now() st.mac2.cookie = cookie return true } func (st *CookieGenerator) AddMacs(msg []byte) { size := len(msg) smac2 := size - blake2s.Size128 smac1 := smac2 - blake2s.Size128 mac1 := msg[smac1:smac2] mac2 := msg[smac2:] st.Lock() defer st.Unlock() // set mac1 func() { mac, _ := blake2s.New128(st.mac1.key[:]) mac.Write(msg[:smac1]) mac.Sum(mac1[:0]) }() copy(st.mac2.lastMAC1[:], mac1) st.mac2.hasLastMAC1 = true // set mac2 if 
time.Since(st.mac2.cookieSet) > CookieRefreshTime { return } func() { mac, _ := blake2s.New128(st.mac2.cookie[:]) mac.Write(msg[:smac2]) mac.Sum(mac2[:0]) }() } wireguard-go-0.0.20231211/device/cookie_test.go000066400000000000000000000136741466306003400207520ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "testing" ) func TestCookieMAC1(t *testing.T) { // setup generator / checker var ( generator CookieGenerator checker CookieChecker ) sk, err := newPrivateKey() if err != nil { t.Fatal(err) } pk := sk.publicKey() generator.Init(pk) checker.Init(pk) // check mac1 src := []byte{192, 168, 13, 37, 10, 10, 10} checkMAC1 := func(msg []byte) { generator.AddMacs(msg) if !checker.CheckMAC1(msg) { t.Fatal("MAC1 generation/verification failed") } if checker.CheckMAC2(msg, src) { t.Fatal("MAC2 generation/verification failed") } } checkMAC1([]byte{ 0x99, 0xbb, 0xa5, 0xfc, 0x99, 0xaa, 0x83, 0xbd, 0x7b, 0x00, 0xc5, 0x9a, 0x4c, 0xb9, 0xcf, 0x62, 0x40, 0x23, 0xf3, 0x8e, 0xd8, 0xd0, 0x62, 0x64, 0x5d, 0xb2, 0x80, 0x13, 0xda, 0xce, 0xc6, 0x91, 0x61, 0xd6, 0x30, 0xf1, 0x32, 0xb3, 0xa2, 0xf4, 0x7b, 0x43, 0xb5, 0xa7, 0xe2, 0xb1, 0xf5, 0x6c, 0x74, 0x6b, 0xb0, 0xcd, 0x1f, 0x94, 0x86, 0x7b, 0xc8, 0xfb, 0x92, 0xed, 0x54, 0x9b, 0x44, 0xf5, 0xc8, 0x7d, 0xb7, 0x8e, 0xff, 0x49, 0xc4, 0xe8, 0x39, 0x7c, 0x19, 0xe0, 0x60, 0x19, 0x51, 0xf8, 0xe4, 0x8e, 0x02, 0xf1, 0x7f, 0x1d, 0xcc, 0x8e, 0xb0, 0x07, 0xff, 0xf8, 0xaf, 0x7f, 0x66, 0x82, 0x83, 0xcc, 0x7c, 0xfa, 0x80, 0xdb, 0x81, 0x53, 0xad, 0xf7, 0xd8, 0x0c, 0x10, 0xe0, 0x20, 0xfd, 0xe8, 0x0b, 0x3f, 0x90, 0x15, 0xcd, 0x93, 0xad, 0x0b, 0xd5, 0x0c, 0xcc, 0x88, 0x56, 0xe4, 0x3f, }) checkMAC1([]byte{ 0x33, 0xe7, 0x2a, 0x84, 0x9f, 0xff, 0x57, 0x6c, 0x2d, 0xc3, 0x2d, 0xe1, 0xf5, 0x5c, 0x97, 0x56, 0xb8, 0x93, 0xc2, 0x7d, 0xd4, 0x41, 0xdd, 0x7a, 0x4a, 0x59, 0x3b, 0x50, 0xdd, 0x7a, 0x7a, 0x8c, 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, 0x13, 0x0b, 0x97, 
0x98, 0xb3, 0x40, 0xc3, 0xcc, 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b, 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, }) checkMAC1([]byte{ 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc, 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b, 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, }) // exchange cookie reply func() { msg := []byte{ 0x6d, 0xd7, 0xc3, 0x2e, 0xb0, 0x76, 0xd8, 0xdf, 0x30, 0x65, 0x7d, 0x62, 0x3e, 0xf8, 0x9a, 0xe8, 0xe7, 0x3c, 0x64, 0xa3, 0x78, 0x48, 0xda, 0xf5, 0x25, 0x61, 0x28, 0x53, 0x79, 0x32, 0x86, 0x9f, 0xa0, 0x27, 0x95, 0x69, 0xb6, 0xba, 0xd0, 0xa2, 0xf8, 0x68, 0xea, 0xa8, 0x62, 0xf2, 0xfd, 0x1b, 0xe0, 0xb4, 0x80, 0xe5, 0x6b, 0x3a, 0x16, 0x9e, 0x35, 0xf6, 0xa8, 0xf2, 0x4f, 0x9a, 0x7b, 0xe9, 0x77, 0x0b, 0xc2, 0xb4, 0xed, 0xba, 0xf9, 0x22, 0xc3, 0x03, 0x97, 0x42, 0x9f, 0x79, 0x74, 0x27, 0xfe, 0xf9, 0x06, 0x6e, 0x97, 0x3a, 0xa6, 0x8f, 0xc9, 0x57, 0x0a, 0x54, 0x4c, 0x64, 0x4a, 0xe2, 0x4f, 0xa1, 0xce, 0x95, 0x9b, 0x23, 0xa9, 0x2b, 0x85, 0x93, 0x42, 0xb0, 0xa5, 0x53, 0xed, 0xeb, 0x63, 0x2a, 0xf1, 0x6d, 0x46, 0xcb, 0x2f, 0x61, 0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d, } generator.AddMacs(msg) reply, err := checker.CreateReply(msg, 1377, src) if err != nil { t.Fatal("Failed to create cookie reply:", err) } if !generator.ConsumeReply(reply) { t.Fatal("Failed to consume cookie reply") } }() // check mac2 checkMAC2 := func(msg []byte) { generator.AddMacs(msg) if !checker.CheckMAC1(msg) { t.Fatal("MAC1 generation/verification failed") } if !checker.CheckMAC2(msg, src) { t.Fatal("MAC2 generation/verification failed") } msg[5] ^= 0x20 if checker.CheckMAC1(msg) { t.Fatal("MAC1 generation/verification failed") } if checker.CheckMAC2(msg, src) { t.Fatal("MAC2 generation/verification failed") } msg[5] ^= 0x20 srcBad1 := []byte{192, 168, 13, 37, 40, 1} if checker.CheckMAC2(msg, srcBad1) { t.Fatal("MAC2 generation/verification failed") } srcBad2 := []byte{192, 168, 13, 38, 40, 1} if checker.CheckMAC2(msg, 
srcBad2) { t.Fatal("MAC2 generation/verification failed") } } checkMAC2([]byte{ 0x03, 0x31, 0xb9, 0x9e, 0xb0, 0x2a, 0x54, 0xa3, 0xc1, 0x3f, 0xb4, 0x96, 0x16, 0xb9, 0x25, 0x15, 0x3d, 0x3a, 0x82, 0xf9, 0x58, 0x36, 0x86, 0x3f, 0x13, 0x2f, 0xfe, 0xb2, 0x53, 0x20, 0x8c, 0x3f, 0xba, 0xeb, 0xfb, 0x4b, 0x1b, 0x22, 0x02, 0x69, 0x2c, 0x90, 0xbc, 0xdc, 0xcf, 0xcf, 0x85, 0xeb, 0x62, 0x66, 0x6f, 0xe8, 0xe1, 0xa6, 0xa8, 0x4c, 0xa0, 0x04, 0x23, 0x15, 0x42, 0xac, 0xfa, 0x38, }) checkMAC2([]byte{ 0x0e, 0x2f, 0x0e, 0xa9, 0x29, 0x03, 0xe1, 0xf3, 0x24, 0x01, 0x75, 0xad, 0x16, 0xa5, 0x66, 0x85, 0xca, 0x66, 0xe0, 0xbd, 0xc6, 0x34, 0xd8, 0x84, 0x09, 0x9a, 0x58, 0x14, 0xfb, 0x05, 0xda, 0xf5, 0x90, 0xf5, 0x0c, 0x4e, 0x22, 0x10, 0xc9, 0x85, 0x0f, 0xe3, 0x77, 0x35, 0xe9, 0x6b, 0xc2, 0x55, 0x32, 0x46, 0xae, 0x25, 0xe0, 0xe3, 0x37, 0x7a, 0x4b, 0x71, 0xcc, 0xfc, 0x91, 0xdf, 0xd6, 0xca, 0xfe, 0xee, 0xce, 0x3f, 0x77, 0xa2, 0xfd, 0x59, 0x8e, 0x73, 0x0a, 0x8d, 0x5c, 0x24, 0x14, 0xca, 0x38, 0x91, 0xb8, 0x2c, 0x8c, 0xa2, 0x65, 0x7b, 0xbc, 0x49, 0xbc, 0xb5, 0x58, 0xfc, 0xe3, 0xd7, 0x02, 0xcf, 0xf7, 0x4c, 0x60, 0x91, 0xed, 0x55, 0xe9, 0xf9, 0xfe, 0xd1, 0x44, 0x2c, 0x75, 0xf2, 0xb3, 0x5d, 0x7b, 0x27, 0x56, 0xc0, 0x48, 0x4f, 0xb0, 0xba, 0xe4, 0x7d, 0xd0, 0xaa, 0xcd, 0x3d, 0xe3, 0x50, 0xd2, 0xcf, 0xb9, 0xfa, 0x4b, 0x2d, 0xc6, 0xdf, 0x3b, 0x32, 0x98, 0x45, 0xe6, 0x8f, 0x1c, 0x5c, 0xa2, 0x20, 0x7d, 0x1c, 0x28, 0xc2, 0xd4, 0xa1, 0xe0, 0x21, 0x52, 0x8f, 0x1c, 0xd0, 0x62, 0x97, 0x48, 0xbb, 0xf4, 0xa9, 0xcb, 0x35, 0xf2, 0x07, 0xd3, 0x50, 0xd8, 0xa9, 0xc5, 0x9a, 0x0f, 0xbd, 0x37, 0xaf, 0xe1, 0x45, 0x19, 0xee, 0x41, 0xf3, 0xf7, 0xe5, 0xe0, 0x30, 0x3f, 0xbe, 0x3d, 0x39, 0x64, 0x00, 0x7a, 0x1a, 0x51, 0x5e, 0xe1, 0x70, 0x0b, 0xb9, 0x77, 0x5a, 0xf0, 0xc4, 0x8a, 0xa1, 0x3a, 0x77, 0x1a, 0xe0, 0xc2, 0x06, 0x91, 0xd5, 0xe9, 0x1c, 0xd3, 0xfe, 0xab, 0x93, 0x1a, 0x0a, 0x4c, 0xbb, 0xf0, 0xff, 0xdc, 0xaa, 0x61, 0x73, 0xcb, 0x03, 0x4b, 0x71, 0x68, 0x64, 0x3d, 0x82, 0x31, 0x41, 0xd7, 0x8b, 0x22, 0x7b, 0x7d, 0xa1, 0xd5, 0x85, 
0x6d, 0xf0, 0x1b, 0xaa, }) } wireguard-go-0.0.20231211/device/device.go000066400000000000000000000326411466306003400176740ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "runtime" "sync" "sync/atomic" "time" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/tun" ) type Device struct { state struct { // state holds the device's state. It is accessed atomically. // Use the device.deviceState method to read it. // device.deviceState does not acquire the mutex, so it captures only a snapshot. // During state transitions, the state variable is updated before the device itself. // The state is thus either the current state of the device or // the intended future state of the device. // For example, while executing a call to Up, state will be deviceStateUp. // There is no guarantee that that intended future state of the device // will become the actual state; Up can fail. // The device can also change state multiple times between time of check and time of use. // Unsynchronized uses of state must therefore be advisory/best-effort only. state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience // stopping blocks until all inputs to Device have been closed. stopping sync.WaitGroup // mu protects state changes. 
sync.Mutex } net struct { stopping sync.WaitGroup sync.RWMutex bind conn.Bind // bind interface netlinkCancel *rwcancel.RWCancel port uint16 // listening port fwmark uint32 // mark value (0 = disabled) brokenRoaming bool } staticIdentity struct { sync.RWMutex privateKey NoisePrivateKey publicKey NoisePublicKey } peers struct { sync.RWMutex // protects keyMap keyMap map[NoisePublicKey]*Peer } rate struct { underLoadUntil atomic.Int64 limiter ratelimiter.Ratelimiter } allowedips AllowedIPs indexTable IndexTable cookieChecker CookieChecker pool struct { inboundElementsContainer *WaitPool outboundElementsContainer *WaitPool messageBuffers *WaitPool inboundElements *WaitPool outboundElements *WaitPool } queue struct { encryption *outboundQueue decryption *inboundQueue handshake *handshakeQueue } tun struct { device tun.Device mtu atomic.Int32 } ipcMutex sync.RWMutex closed chan struct{} log *Logger } // deviceState represents the state of a Device. // There are three states: down, up, closed. // Transitions: // // down -----+ // ↑↓ ↓ // up -> closed type deviceState uint32 //go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState const ( deviceStateDown deviceState = iota deviceStateUp deviceStateClosed ) // deviceState returns device.state.state as a deviceState // See those docs for how to interpret this value. func (device *Device) deviceState() deviceState { return deviceState(device.state.state.Load()) } // isClosed reports whether the device is closed (or is closing). // See device.state.state comments for how to interpret this value. func (device *Device) isClosed() bool { return device.deviceState() == deviceStateClosed } // isUp reports whether the device is up (or is attempting to come up). // See device.state.state comments for how to interpret this value. 
func (device *Device) isUp() bool { return device.deviceState() == deviceStateUp } // Must hold device.peers.Lock() func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) { // stop routing and processing of packets device.allowedips.RemoveByPeer(peer) peer.Stop() // remove from peer map delete(device.peers.keyMap, key) } // changeState attempts to change the device state to match want. func (device *Device) changeState(want deviceState) (err error) { device.state.Lock() defer device.state.Unlock() old := device.deviceState() if old == deviceStateClosed { // once closed, always closed device.log.Verbosef("Interface closed, ignored requested state %s", want) return nil } switch want { case old: return nil case deviceStateUp: device.state.state.Store(uint32(deviceStateUp)) err = device.upLocked() if err == nil { break } fallthrough // up failed; bring the device all the way back down case deviceStateDown: device.state.state.Store(uint32(deviceStateDown)) errDown := device.downLocked() if err == nil { err = errDown } } device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState()) return } // upLocked attempts to bring the device up and reports whether it succeeded. // The caller must hold device.state.mu and is responsible for updating device.state.state. func (device *Device) upLocked() error { if err := device.BindUpdate(); err != nil { device.log.Errorf("Unable to update bind: %v", err) return err } // The IPC set operation waits for peers to be created before calling Start() on them, // so if there's a concurrent IPC set request happening, we should wait for it to complete. device.ipcMutex.Lock() defer device.ipcMutex.Unlock() device.peers.RLock() for _, peer := range device.peers.keyMap { peer.Start() if peer.persistentKeepaliveInterval.Load() > 0 { peer.SendKeepalive() } } device.peers.RUnlock() return nil } // downLocked attempts to bring the device down. 
// The caller must hold device.state.mu and is responsible for updating device.state.state. func (device *Device) downLocked() error { err := device.BindClose() if err != nil { device.log.Errorf("Bind close failed: %v", err) } device.peers.RLock() for _, peer := range device.peers.keyMap { peer.Stop() } device.peers.RUnlock() return err } func (device *Device) Up() error { return device.changeState(deviceStateUp) } func (device *Device) Down() error { return device.changeState(deviceStateDown) } func (device *Device) IsUnderLoad() bool { // check if currently under load now := time.Now() underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8 if underLoad { device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano()) return true } // check if recently under load return device.rate.underLoadUntil.Load() > now.UnixNano() } func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { // lock required resources device.staticIdentity.Lock() defer device.staticIdentity.Unlock() if sk.Equals(device.staticIdentity.privateKey) { return nil } device.peers.Lock() defer device.peers.Unlock() lockedPeers := make([]*Peer, 0, len(device.peers.keyMap)) for _, peer := range device.peers.keyMap { peer.handshake.mutex.RLock() lockedPeers = append(lockedPeers, peer) } // remove peers with matching public keys publicKey := sk.publicKey() for key, peer := range device.peers.keyMap { if peer.handshake.remoteStatic.Equals(publicKey) { peer.handshake.mutex.RUnlock() removePeerLocked(device, peer, key) peer.handshake.mutex.RLock() } } // update key material device.staticIdentity.privateKey = sk device.staticIdentity.publicKey = publicKey device.cookieChecker.Init(publicKey) // do static-static DH pre-computations expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) for _, peer := range device.peers.keyMap { handshake := &peer.handshake handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) expiredPeers = 
append(expiredPeers, peer) } for _, peer := range lockedPeers { peer.handshake.mutex.RUnlock() } for _, peer := range expiredPeers { peer.ExpireCurrentKeypairs() } return nil } func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { device := new(Device) device.state.state.Store(uint32(deviceStateDown)) device.closed = make(chan struct{}) device.log = logger device.net.bind = bind device.tun.device = tunDevice mtu, err := device.tun.device.MTU() if err != nil { device.log.Errorf("Trouble determining MTU, assuming default: %v", err) mtu = DefaultMTU } device.tun.mtu.Store(int32(mtu)) device.peers.keyMap = make(map[NoisePublicKey]*Peer) device.rate.limiter.Init() device.indexTable.Init() device.PopulatePools() // create queues device.queue.handshake = newHandshakeQueue() device.queue.encryption = newOutboundQueue() device.queue.decryption = newInboundQueue() // start workers cpus := runtime.NumCPU() device.state.stopping.Wait() device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake for i := 0; i < cpus; i++ { go device.RoutineEncryption(i + 1) go device.RoutineDecryption(i + 1) go device.RoutineHandshake(i + 1) } device.state.stopping.Add(1) // RoutineReadFromTUN device.queue.encryption.wg.Add(1) // RoutineReadFromTUN go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() return device } // BatchSize returns the BatchSize for the device as a whole which is the max of // the bind batch size and the tun batch size. The batch size reported by device // is the size used to construct memory pools, and is the allowed batch size for // the lifetime of the device. 
func (device *Device) BatchSize() int { size := device.net.bind.BatchSize() dSize := device.tun.device.BatchSize() if size < dSize { size = dSize } return size } func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { device.peers.RLock() defer device.peers.RUnlock() return device.peers.keyMap[pk] } func (device *Device) RemovePeer(key NoisePublicKey) { device.peers.Lock() defer device.peers.Unlock() // stop peer and remove from routing peer, ok := device.peers.keyMap[key] if ok { removePeerLocked(device, peer, key) } } func (device *Device) RemoveAllPeers() { device.peers.Lock() defer device.peers.Unlock() for key, peer := range device.peers.keyMap { removePeerLocked(device, peer, key) } device.peers.keyMap = make(map[NoisePublicKey]*Peer) } func (device *Device) Close() { device.state.Lock() defer device.state.Unlock() device.ipcMutex.Lock() defer device.ipcMutex.Unlock() if device.isClosed() { return } device.state.state.Store(uint32(deviceStateClosed)) device.log.Verbosef("Device closing") device.tun.device.Close() device.downLocked() // Remove peers before closing queues, // because peers assume that queues are active. device.RemoveAllPeers() // We kept a reference to the encryption and decryption queues, // in case we started any new peers that might write to them. // No new peers are coming; we are done with these queues. 
device.queue.encryption.wg.Done() device.queue.decryption.wg.Done() device.queue.handshake.wg.Done() device.state.stopping.Wait() device.rate.limiter.Close() device.log.Verbosef("Device closed") close(device.closed) } func (device *Device) Wait() chan struct{} { return device.closed } func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { if !device.isUp() { return } device.peers.RLock() for _, peer := range device.peers.keyMap { peer.keypairs.RLock() sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now()) peer.keypairs.RUnlock() if sendKeepalive { peer.SendKeepalive() } } device.peers.RUnlock() } // closeBindLocked closes the device's net.bind. // The caller must hold the net mutex. func closeBindLocked(device *Device) error { var err error netc := &device.net if netc.netlinkCancel != nil { netc.netlinkCancel.Cancel() } if netc.bind != nil { err = netc.bind.Close() } netc.stopping.Wait() return err } func (device *Device) Bind() conn.Bind { device.net.Lock() defer device.net.Unlock() return device.net.bind } func (device *Device) BindSetMark(mark uint32) error { device.net.Lock() defer device.net.Unlock() // check if modified if device.net.fwmark == mark { return nil } // update fwmark on existing bind device.net.fwmark = mark if device.isUp() && device.net.bind != nil { if err := device.net.bind.SetMark(mark); err != nil { return err } } // clear cached source addresses device.peers.RLock() for _, peer := range device.peers.keyMap { peer.markEndpointSrcForClearing() } device.peers.RUnlock() return nil } func (device *Device) BindUpdate() error { device.net.Lock() defer device.net.Unlock() // close existing sockets if err := closeBindLocked(device); err != nil { return err } // open new sockets if !device.isUp() { return nil } // bind to new port var err error var recvFns []conn.ReceiveFunc netc := &device.net recvFns, netc.port, err = netc.bind.Open(netc.port) if err != nil { netc.port = 0 
return err } netc.netlinkCancel, err = device.startRouteListener(netc.bind) if err != nil { netc.bind.Close() netc.port = 0 return err } // set fwmark if netc.fwmark != 0 { err = netc.bind.SetMark(netc.fwmark) if err != nil { return err } } // clear cached source addresses device.peers.RLock() for _, peer := range device.peers.keyMap { peer.markEndpointSrcForClearing() } device.peers.RUnlock() // start receiving routines device.net.stopping.Add(len(recvFns)) device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake batchSize := netc.bind.BatchSize() for _, fn := range recvFns { go device.RoutineReceiveIncoming(batchSize, fn) } device.log.Verbosef("UDP bind has been updated") return nil } func (device *Device) BindClose() error { device.net.Lock() err := closeBindLocked(device) device.net.Unlock() return err } wireguard-go-0.0.20231211/device/device_test.go000066400000000000000000000303221466306003400207250ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "bytes" "encoding/hex" "fmt" "io" "math/rand" "net/netip" "os" "runtime" "runtime/pprof" "sync" "sync/atomic" "testing" "time" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn/bindtest" "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun/tuntest" ) // uapiCfg returns a string that contains cfg formatted use with IpcSet. // cfg is a series of alternating key/value strings. // uapiCfg exists because editors and humans like to insert // whitespace into configs, which can cause failures, some of which are silent. // For example, a leading blank newline causes the remainder // of the config to be silently ignored. 
func uapiCfg(cfg ...string) string { if len(cfg)%2 != 0 { panic("odd number of args to uapiReader") } buf := new(bytes.Buffer) for i, s := range cfg { buf.WriteString(s) sep := byte('\n') if i%2 == 0 { sep = '=' } buf.WriteByte(sep) } return buf.String() } // genConfigs generates a pair of configs that connect to each other. // The configs use distinct, probably-usable ports. func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { var key1, key2 NoisePrivateKey _, err := rand.Read(key1[:]) if err != nil { tb.Errorf("unable to generate private key random bytes: %v", err) } _, err = rand.Read(key2[:]) if err != nil { tb.Errorf("unable to generate private key random bytes: %v", err) } pub1, pub2 := key1.publicKey(), key2.publicKey() cfgs[0] = uapiCfg( "private_key", hex.EncodeToString(key1[:]), "listen_port", "0", "replace_peers", "true", "public_key", hex.EncodeToString(pub2[:]), "protocol_version", "1", "replace_allowed_ips", "true", "allowed_ip", "1.0.0.2/32", ) endpointCfgs[0] = uapiCfg( "public_key", hex.EncodeToString(pub2[:]), "endpoint", "127.0.0.1:%d", ) cfgs[1] = uapiCfg( "private_key", hex.EncodeToString(key2[:]), "listen_port", "0", "replace_peers", "true", "public_key", hex.EncodeToString(pub1[:]), "protocol_version", "1", "replace_allowed_ips", "true", "allowed_ip", "1.0.0.1/32", ) endpointCfgs[1] = uapiCfg( "public_key", hex.EncodeToString(pub1[:]), "endpoint", "127.0.0.1:%d", ) return } // A testPair is a pair of testPeers. type testPair [2]testPeer // A testPeer is a peer used for testing. 
type testPeer struct { tun *tuntest.ChannelTUN dev *Device ip netip.Addr } type SendDirection bool const ( Ping SendDirection = true Pong SendDirection = false ) func (d SendDirection) String() string { if d == Ping { return "ping" } return "pong" } func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) { tb.Helper() p0, p1 := pair[0], pair[1] if !ping { // pong is the new ping p0, p1 = p1, p0 } msg := tuntest.Ping(p0.ip, p1.ip) p1.tun.Outbound <- msg timer := time.NewTimer(5 * time.Second) defer timer.Stop() var err error select { case msgRecv := <-p0.tun.Inbound: if !bytes.Equal(msg, msgRecv) { err = fmt.Errorf("%s did not transit correctly", ping) } case <-timer.C: err = fmt.Errorf("%s did not transit", ping) case <-done: } if err != nil { // The error may have occurred because the test is done. select { case <-done: return default: } // Real error. tb.Error(err) } } // genTestPair creates a testPair. func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { cfg, endpointCfg := genConfigs(tb) var binds [2]conn.Bind if realSocket { binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind() } else { binds = bindtest.NewChannelBinds() } // Bring up a ChannelTun for each config. 
for i := range pair { p := &pair[i] p.tun = tuntest.NewChannelTUN() p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)}) level := LogLevelVerbose if _, ok := tb.(*testing.B); ok && !testing.Verbose() { level = LogLevelError } p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i))) if err := p.dev.IpcSet(cfg[i]); err != nil { tb.Errorf("failed to configure device %d: %v", i, err) p.dev.Close() continue } if err := p.dev.Up(); err != nil { tb.Errorf("failed to bring up device %d: %v", i, err) p.dev.Close() continue } endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port) } for i := range pair { p := &pair[i] if err := p.dev.IpcSet(endpointCfg[i]); err != nil { tb.Errorf("failed to configure device endpoint %d: %v", i, err) p.dev.Close() continue } // The device is ready. Close it when the test completes. tb.Cleanup(p.dev.Close) } return } func TestTwoDevicePing(t *testing.T) { goroutineLeakCheck(t) pair := genTestPair(t, true) t.Run("ping 1.0.0.1", func(t *testing.T) { pair.Send(t, Ping, nil) }) t.Run("ping 1.0.0.2", func(t *testing.T) { pair.Send(t, Pong, nil) }) } func TestUpDown(t *testing.T) { goroutineLeakCheck(t) const itrials = 50 const otrials = 10 for n := 0; n < otrials; n++ { pair := genTestPair(t, false) for i := range pair { for k := range pair[i].dev.peers.keyMap { pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:]))) } } var wg sync.WaitGroup wg.Add(len(pair)) for i := range pair { go func(d *Device) { defer wg.Done() for i := 0; i < itrials; i++ { if err := d.Up(); err != nil { t.Errorf("failed up bring up device: %v", err) } time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) if err := d.Down(); err != nil { t.Errorf("failed to bring down device: %v", err) } time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) } }(pair[i].dev) } wg.Wait() for i := range pair { pair[i].dev.Up() pair[i].dev.Close() } } } // 
TestConcurrencySafety does other things concurrently with tunnel use. // It is intended to be used with the race detector to catch data races. func TestConcurrencySafety(t *testing.T) { pair := genTestPair(t, true) done := make(chan struct{}) const warmupIters = 10 var warmup sync.WaitGroup warmup.Add(warmupIters) go func() { // Send data continuously back and forth until we're done. // Note that we may continue to attempt to send data // even after done is closed. i := warmupIters for ping := Ping; ; ping = !ping { pair.Send(t, ping, done) select { case <-done: return default: } if i > 0 { warmup.Done() i-- } } }() warmup.Wait() applyCfg := func(cfg string) { err := pair[0].dev.IpcSet(cfg) if err != nil { t.Fatal(err) } } // Change persistent_keepalive_interval concurrently with tunnel use. t.Run("persistentKeepaliveInterval", func(t *testing.T) { var pub NoisePublicKey for key := range pair[0].dev.peers.keyMap { pub = key break } cfg := uapiCfg( "public_key", hex.EncodeToString(pub[:]), "persistent_keepalive_interval", "1", ) for i := 0; i < 1000; i++ { applyCfg(cfg) } }) // Change private keys concurrently with tunnel use. t.Run("privateKey", func(t *testing.T) { bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777") good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:])) // Set iters to a large number like 1000 to flush out data races quickly. // Don't leave it large. That can cause logical races // in which the handshake is interleaved with key changes // such that the private key appears to be unchanging but // other state gets reset, which can cause handshake failures like // "Received packet with invalid mac1". const iters = 1 for i := 0; i < iters; i++ { applyCfg(bad) applyCfg(good) } }) // Perform bind updates and keepalive sends concurrently with tunnel use. 
t.Run("bindUpdate and keepalive", func(t *testing.T) { const iters = 10 for i := 0; i < iters; i++ { for _, peer := range pair { peer.dev.BindUpdate() peer.dev.SendKeepalivesToPeersWithCurrentKeypair() } } }) close(done) } func BenchmarkLatency(b *testing.B) { pair := genTestPair(b, true) // Establish a connection. pair.Send(b, Ping, nil) pair.Send(b, Pong, nil) b.ResetTimer() for i := 0; i < b.N; i++ { pair.Send(b, Ping, nil) pair.Send(b, Pong, nil) } } func BenchmarkThroughput(b *testing.B) { pair := genTestPair(b, true) // Establish a connection. pair.Send(b, Ping, nil) pair.Send(b, Pong, nil) // Measure how long it takes to receive b.N packets, // starting when we receive the first packet. var recv atomic.Uint64 var elapsed time.Duration var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() var start time.Time for { <-pair[0].tun.Inbound new := recv.Add(1) if new == 1 { start = time.Now() } // Careful! Don't change this to else if; b.N can be equal to 1. if new == uint64(b.N) { elapsed = time.Since(start) return } } }() // Send packets as fast as we can until we've received enough. ping := tuntest.Ping(pair[0].ip, pair[1].ip) pingc := pair[1].tun.Outbound var sent uint64 for recv.Load() != uint64(b.N) { sent++ pingc <- ping } wg.Wait() b.ReportMetric(float64(elapsed)/float64(b.N), "ns/op") b.ReportMetric(1-float64(b.N)/float64(sent), "packet-loss") } func BenchmarkUAPIGet(b *testing.B) { pair := genTestPair(b, true) pair.Send(b, Ping, nil) pair.Send(b, Pong, nil) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { pair[0].dev.IpcGetOperation(io.Discard) } } func goroutineLeakCheck(t *testing.T) { goroutines := func() (int, []byte) { p := pprof.Lookup("goroutine") b := new(bytes.Buffer) p.WriteTo(b, 1) return p.Count(), b.Bytes() } startGoroutines, startStacks := goroutines() t.Cleanup(func() { if t.Failed() { return } // Give goroutines time to exit, if they need it. 
for i := 0; i < 10000; i++ { if runtime.NumGoroutine() <= startGoroutines { return } time.Sleep(1 * time.Millisecond) } endGoroutines, endStacks := goroutines() t.Logf("starting stacks:\n%s\n", startStacks) t.Logf("ending stacks:\n%s\n", endStacks) t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines) }) } type fakeBindSized struct { size int } func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { return nil, 0, nil } func (b *fakeBindSized) Close() error { return nil } func (b *fakeBindSized) SetMark(mark uint32) error { return nil } func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil } func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } func (b *fakeBindSized) BatchSize() int { return b.size } type fakeTUNDeviceSized struct { size int } func (t *fakeTUNDeviceSized) File() *os.File { return nil } func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { return 0, nil } func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil } func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil } func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil } func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil } func (t *fakeTUNDeviceSized) Close() error { return nil } func (t *fakeTUNDeviceSized) BatchSize() int { return t.size } func TestBatchSize(t *testing.T) { d := Device{} d.net.bind = &fakeBindSized{1} d.tun.device = &fakeTUNDeviceSized{1} if want, got := 1, d.BatchSize(); got != want { t.Errorf("expected batch size %d, got %d", want, got) } d.net.bind = &fakeBindSized{1} d.tun.device = &fakeTUNDeviceSized{128} if want, got := 128, d.BatchSize(); got != want { t.Errorf("expected batch size %d, got %d", want, got) } d.net.bind = &fakeBindSized{128} d.tun.device = &fakeTUNDeviceSized{1} if want, got := 128, d.BatchSize(); got != want 
{ t.Errorf("expected batch size %d, got %d", want, got) } d.net.bind = &fakeBindSized{128} d.tun.device = &fakeTUNDeviceSized{128} if want, got := 128, d.BatchSize(); got != want { t.Errorf("expected batch size %d, got %d", want, got) } } wireguard-go-0.0.20231211/device/devicestate_string.go000066400000000000000000000006761466306003400223260ustar00rootroot00000000000000// Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT. package device import "strconv" const _deviceState_name = "DownUpClosed" var _deviceState_index = [...]uint8{0, 4, 6, 12} func (i deviceState) String() string { if i >= deviceState(len(_deviceState_index)-1) { return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")" } return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]] } wireguard-go-0.0.20231211/device/endpoint_test.go000066400000000000000000000017401466306003400213100ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package device import ( "math/rand" "net/netip" ) type DummyEndpoint struct { src, dst netip.Addr } func CreateDummyEndpoint() (*DummyEndpoint, error) { var src, dst [16]byte if _, err := rand.Read(src[:]); err != nil { return nil, err } _, err := rand.Read(dst[:]) return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err } func (e *DummyEndpoint) ClearSrc() {} func (e *DummyEndpoint) SrcToString() string { return netip.AddrPortFrom(e.SrcIP(), 1000).String() } func (e *DummyEndpoint) DstToString() string { return netip.AddrPortFrom(e.DstIP(), 1000).String() } func (e *DummyEndpoint) DstToBytes() []byte { out := e.DstIP().AsSlice() out = append(out, byte(1000&0xff)) out = append(out, byte((1000>>8)&0xff)) return out } func (e *DummyEndpoint) DstIP() netip.Addr { return e.dst } func (e *DummyEndpoint) SrcIP() netip.Addr { return e.src } wireguard-go-0.0.20231211/device/indextable.go000066400000000000000000000034031466306003400205460ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "crypto/rand" "encoding/binary" "sync" ) type IndexTableEntry struct { peer *Peer handshake *Handshake keypair *Keypair } type IndexTable struct { sync.RWMutex table map[uint32]IndexTableEntry } func randUint32() (uint32, error) { var integer [4]byte _, err := rand.Read(integer[:]) // Arbitrary endianness; both are intrinsified by the Go compiler. 
return binary.LittleEndian.Uint32(integer[:]), err } func (table *IndexTable) Init() { table.Lock() defer table.Unlock() table.table = make(map[uint32]IndexTableEntry) } func (table *IndexTable) Delete(index uint32) { table.Lock() defer table.Unlock() delete(table.table, index) } func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) { table.Lock() defer table.Unlock() entry, ok := table.table[index] if !ok { return } table.table[index] = IndexTableEntry{ peer: entry.peer, keypair: keypair, handshake: nil, } } func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) { for { // generate random index index, err := randUint32() if err != nil { return index, err } // check if index used table.RLock() _, ok := table.table[index] table.RUnlock() if ok { continue } // check again while locked table.Lock() _, found := table.table[index] if found { table.Unlock() continue } table.table[index] = IndexTableEntry{ peer: peer, handshake: handshake, keypair: nil, } table.Unlock() return index, nil } } func (table *IndexTable) Lookup(id uint32) IndexTableEntry { table.RLock() defer table.RUnlock() return table.table[id] } wireguard-go-0.0.20231211/device/ip.go000066400000000000000000000005741466306003400170450ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "net" ) const ( IPv4offsetTotalLength = 2 IPv4offsetSrc = 12 IPv4offsetDst = IPv4offsetSrc + net.IPv4len ) const ( IPv6offsetPayloadLength = 4 IPv6offsetSrc = 8 IPv6offsetDst = IPv6offsetSrc + net.IPv6len ) wireguard-go-0.0.20231211/device/kdf_test.go000066400000000000000000000042251466306003400202350ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package device import ( "encoding/hex" "testing" "golang.org/x/crypto/blake2s" ) type KDFTest struct { key string input string t0 string t1 string t2 string } func assertEquals(t *testing.T, a, b string) { if a != b { t.Fatal("expected", a, "=", b) } } func TestKDF(t *testing.T) { tests := []KDFTest{ { key: "746573742d6b6579", input: "746573742d696e707574", t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633", t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a", t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24", }, { key: "776972656775617264", input: "776972656775617264", t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8", t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f", t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160", }, { key: "", input: "", t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0", t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e", t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e", }, } var t0, t1, t2 [blake2s.Size]byte for _, test := range tests { key, _ := hex.DecodeString(test.key) input, _ := hex.DecodeString(test.input) KDF3(&t0, &t1, &t2, key, input) t0s := hex.EncodeToString(t0[:]) t1s := hex.EncodeToString(t1[:]) t2s := hex.EncodeToString(t2[:]) assertEquals(t, t0s, test.t0) assertEquals(t, t1s, test.t1) assertEquals(t, t2s, test.t2) } for _, test := range tests { key, _ := hex.DecodeString(test.key) input, _ := hex.DecodeString(test.input) KDF2(&t0, &t1, key, input) t0s := hex.EncodeToString(t0[:]) t1s := hex.EncodeToString(t1[:]) assertEquals(t, t0s, test.t0) assertEquals(t, t1s, test.t1) } for _, test := range tests { key, _ := hex.DecodeString(test.key) input, _ := hex.DecodeString(test.input) KDF1(&t0, key, input) t0s := hex.EncodeToString(t0[:]) assertEquals(t, t0s, test.t0) } } 
wireguard-go-0.0.20231211/device/keypair.go000066400000000000000000000017701466306003400201000ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "crypto/cipher" "sync" "sync/atomic" "time" "golang.zx2c4.com/wireguard/replay" ) /* Due to limitations in Go and /x/crypto there is currently * no way to ensure that key material is securely ereased in memory. * * Since this may harm the forward secrecy property, * we plan to resolve this issue; whenever Go allows us to do so. */ type Keypair struct { sendNonce atomic.Uint64 send cipher.AEAD receive cipher.AEAD replayFilter replay.Filter isInitiator bool created time.Time localIndex uint32 remoteIndex uint32 } type Keypairs struct { sync.RWMutex current *Keypair previous *Keypair next atomic.Pointer[Keypair] } func (kp *Keypairs) Current() *Keypair { kp.RLock() defer kp.RUnlock() return kp.current } func (device *Device) DeleteKeypair(key *Keypair) { if key != nil { device.indexTable.Delete(key.localIndex) } } wireguard-go-0.0.20231211/device/logger.go000066400000000000000000000023471466306003400177140ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "log" "os" ) // A Logger provides logging for a Device. // The functions are Printf-style functions. // They must be safe for concurrent use. // They do not require a trailing newline in the format. // If nil, that level of logging will be silent. type Logger struct { Verbosef func(format string, args ...any) Errorf func(format string, args ...any) } // Log levels for use with NewLogger. const ( LogLevelSilent = iota LogLevelError LogLevelVerbose ) // Function for use in Logger for discarding logged lines. func DiscardLogf(format string, args ...any) {} // NewLogger constructs a Logger that writes to stdout. // It logs at the specified log level and above. 
// It decorates log lines with the log level, date, time, and prepend. func NewLogger(level int, prepend string) *Logger { logger := &Logger{DiscardLogf, DiscardLogf} logf := func(prefix string) func(string, ...any) { return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf } if level >= LogLevelVerbose { logger.Verbosef = logf("DEBUG") } if level >= LogLevelError { logger.Errorf = logf("ERROR") } return logger } wireguard-go-0.0.20231211/device/mobilequirks.go000066400000000000000000000011171466306003400211350ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device // DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created, // though it will try to deal with it, and race maybe, if called after. func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() { device.net.brokenRoaming = true device.peers.RLock() for _, peer := range device.peers.keyMap { peer.endpoint.Lock() peer.endpoint.disableRoaming = peer.endpoint.val != nil peer.endpoint.Unlock() } device.peers.RUnlock() } wireguard-go-0.0.20231211/device/noise-helpers.go000066400000000000000000000044731466306003400212140ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "crypto/hmac" "crypto/rand" "crypto/subtle" "errors" "hash" "golang.org/x/crypto/blake2s" "golang.org/x/crypto/curve25519" ) /* KDF related functions. 
* HMAC-based Key Derivation Function (HKDF) * https://tools.ietf.org/html/rfc5869 */ func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) { mac := hmac.New(func() hash.Hash { h, _ := blake2s.New256(nil) return h }, key) mac.Write(in0) mac.Sum(sum[:0]) } func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) { mac := hmac.New(func() hash.Hash { h, _ := blake2s.New256(nil) return h }, key) mac.Write(in0) mac.Write(in1) mac.Sum(sum[:0]) } func KDF1(t0 *[blake2s.Size]byte, key, input []byte) { HMAC1(t0, key, input) HMAC1(t0, t0[:], []byte{0x1}) } func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) { var prk [blake2s.Size]byte HMAC1(&prk, key, input) HMAC1(t0, prk[:], []byte{0x1}) HMAC2(t1, prk[:], t0[:], []byte{0x2}) setZero(prk[:]) } func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) { var prk [blake2s.Size]byte HMAC1(&prk, key, input) HMAC1(t0, prk[:], []byte{0x1}) HMAC2(t1, prk[:], t0[:], []byte{0x2}) HMAC2(t2, prk[:], t1[:], []byte{0x3}) setZero(prk[:]) } func isZero(val []byte) bool { acc := 1 for _, b := range val { acc &= subtle.ConstantTimeByteEq(b, 0) } return acc == 1 } /* This function is not used as pervasively as it should because this is mostly impossible in Go at the moment */ func setZero(arr []byte) { for i := range arr { arr[i] = 0 } } func (sk *NoisePrivateKey) clamp() { sk[0] &= 248 sk[31] = (sk[31] & 127) | 64 } func newPrivateKey() (sk NoisePrivateKey, err error) { _, err = rand.Read(sk[:]) sk.clamp() return } func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { apk := (*[NoisePublicKeySize]byte)(&pk) ask := (*[NoisePrivateKeySize]byte)(sk) curve25519.ScalarBaseMult(apk, ask) return } var errInvalidPublicKey = errors.New("invalid public key") func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) { apk := (*[NoisePublicKeySize]byte)(&pk) ask := (*[NoisePrivateKeySize]byte)(sk) curve25519.ScalarMult(&ss, ask, apk) if isZero(ss[:]) { return ss, errInvalidPublicKey } return ss, 
nil } wireguard-go-0.0.20231211/device/noise-protocol.go000066400000000000000000000353441466306003400214140ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "errors" "fmt" "sync" "time" "golang.org/x/crypto/blake2s" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/poly1305" "golang.zx2c4.com/wireguard/tai64n" ) type handshakeState int const ( handshakeZeroed = handshakeState(iota) handshakeInitiationCreated handshakeInitiationConsumed handshakeResponseCreated handshakeResponseConsumed ) func (hs handshakeState) String() string { switch hs { case handshakeZeroed: return "handshakeZeroed" case handshakeInitiationCreated: return "handshakeInitiationCreated" case handshakeInitiationConsumed: return "handshakeInitiationConsumed" case handshakeResponseCreated: return "handshakeResponseCreated" case handshakeResponseConsumed: return "handshakeResponseConsumed" default: return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs)) } } const ( NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" WGLabelMAC1 = "mac1----" WGLabelCookie = "cookie--" ) const ( MessageInitiationType = 1 MessageResponseType = 2 MessageCookieReplyType = 3 MessageTransportType = 4 ) const ( MessageInitiationSize = 148 // size of handshake initiation message MessageResponseSize = 92 // size of response message MessageCookieReplySize = 64 // size of cookie reply message MessageTransportHeaderSize = 16 // size of data preceding content in transport message MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport MessageKeepaliveSize = MessageTransportSize // size of keepalive MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message ) const ( MessageTransportOffsetReceiver = 4 MessageTransportOffsetCounter = 8 MessageTransportOffsetContent = 16 ) /* Type is an 8-bit 
field, followed by 3 nul bytes, * by marshalling the messages in little-endian byteorder * we can treat these as a 32-bit unsigned int (for now) * */ type MessageInitiation struct { Type uint32 Sender uint32 Ephemeral NoisePublicKey Static [NoisePublicKeySize + poly1305.TagSize]byte Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte MAC1 [blake2s.Size128]byte MAC2 [blake2s.Size128]byte } type MessageResponse struct { Type uint32 Sender uint32 Receiver uint32 Ephemeral NoisePublicKey Empty [poly1305.TagSize]byte MAC1 [blake2s.Size128]byte MAC2 [blake2s.Size128]byte } type MessageTransport struct { Type uint32 Receiver uint32 Counter uint64 Content []byte } type MessageCookieReply struct { Type uint32 Receiver uint32 Nonce [chacha20poly1305.NonceSizeX]byte Cookie [blake2s.Size128 + poly1305.TagSize]byte } type Handshake struct { state handshakeState mutex sync.RWMutex hash [blake2s.Size]byte // hash value chainKey [blake2s.Size]byte // chain key presharedKey NoisePresharedKey // psk localEphemeral NoisePrivateKey // ephemeral secret key localIndex uint32 // used to clear hash-table remoteIndex uint32 // index for sending remoteStatic NoisePublicKey // long term key remoteEphemeral NoisePublicKey // ephemeral public key precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret lastTimestamp tai64n.Timestamp lastInitiationConsumption time.Time lastSentHandshake time.Time } var ( InitialChainKey [blake2s.Size]byte InitialHash [blake2s.Size]byte ZeroNonce [chacha20poly1305.NonceSize]byte ) func mixKey(dst, c *[blake2s.Size]byte, data []byte) { KDF1(dst, c[:], data) } func mixHash(dst, h *[blake2s.Size]byte, data []byte) { hash, _ := blake2s.New256(nil) hash.Write(h[:]) hash.Write(data) hash.Sum(dst[:0]) hash.Reset() } func (h *Handshake) Clear() { setZero(h.localEphemeral[:]) setZero(h.remoteEphemeral[:]) setZero(h.chainKey[:]) setZero(h.hash[:]) h.localIndex = 0 h.state = handshakeZeroed } func (h *Handshake) mixHash(data []byte) { 
mixHash(&h.hash, &h.hash, data) } func (h *Handshake) mixKey(data []byte) { mixKey(&h.chainKey, &h.chainKey, data) } /* Do basic precomputations */ func init() { InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier)) } func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() handshake := &peer.handshake handshake.mutex.Lock() defer handshake.mutex.Unlock() // create ephemeral key var err error handshake.hash = InitialHash handshake.chainKey = InitialChainKey handshake.localEphemeral, err = newPrivateKey() if err != nil { return nil, err } handshake.mixHash(handshake.remoteStatic[:]) msg := MessageInitiation{ Type: MessageInitiationType, Ephemeral: handshake.localEphemeral.publicKey(), } handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) // encrypt static key ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) if err != nil { return nil, err } var key [chacha20poly1305.KeySize]byte KDF2( &handshake.chainKey, &key, handshake.chainKey[:], ss[:], ) aead, _ := chacha20poly1305.New(key[:]) aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) handshake.mixHash(msg.Static[:]) // encrypt timestamp if isZero(handshake.precomputedStaticStatic[:]) { return nil, errInvalidPublicKey } KDF2( &handshake.chainKey, &key, handshake.chainKey[:], handshake.precomputedStaticStatic[:], ) timestamp := tai64n.Now() aead, _ = chacha20poly1305.New(key[:]) aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) // assign index device.indexTable.Delete(handshake.localIndex) msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake) if err != nil { return nil, err } handshake.localIndex = msg.Sender handshake.mixHash(msg.Timestamp[:]) handshake.state = handshakeInitiationCreated return &msg, nil } func (device 
*Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { var ( hash [blake2s.Size]byte chainKey [blake2s.Size]byte ) if msg.Type != MessageInitiationType { return nil } device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:]) mixHash(&hash, &hash, msg.Ephemeral[:]) mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) // decrypt static key var peerPK NoisePublicKey var key [chacha20poly1305.KeySize]byte ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) if err != nil { return nil } KDF2(&chainKey, &key, chainKey[:], ss[:]) aead, _ := chacha20poly1305.New(key[:]) _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) if err != nil { return nil } mixHash(&hash, &hash, msg.Static[:]) // lookup peer peer := device.LookupPeer(peerPK) if peer == nil || !peer.isRunning.Load() { return nil } handshake := &peer.handshake // verify identity var timestamp tai64n.Timestamp handshake.mutex.RLock() if isZero(handshake.precomputedStaticStatic[:]) { handshake.mutex.RUnlock() return nil } KDF2( &chainKey, &key, chainKey[:], handshake.precomputedStaticStatic[:], ) aead, _ = chacha20poly1305.New(key[:]) _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) if err != nil { handshake.mutex.RUnlock() return nil } mixHash(&hash, &hash, msg.Timestamp[:]) // protect against replay & flood replay := !timestamp.After(handshake.lastTimestamp) flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate handshake.mutex.RUnlock() if replay { device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp) return nil } if flood { device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer) return nil } // update handshake state handshake.mutex.Lock() handshake.hash = hash handshake.chainKey = chainKey handshake.remoteIndex = msg.Sender handshake.remoteEphemeral = msg.Ephemeral if 
timestamp.After(handshake.lastTimestamp) { handshake.lastTimestamp = timestamp } now := time.Now() if now.After(handshake.lastInitiationConsumption) { handshake.lastInitiationConsumption = now } handshake.state = handshakeInitiationConsumed handshake.mutex.Unlock() setZero(hash[:]) setZero(chainKey[:]) return peer } func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) { handshake := &peer.handshake handshake.mutex.Lock() defer handshake.mutex.Unlock() if handshake.state != handshakeInitiationConsumed { return nil, errors.New("handshake initiation must be consumed first") } // assign index var err error device.indexTable.Delete(handshake.localIndex) handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) if err != nil { return nil, err } var msg MessageResponse msg.Type = MessageResponseType msg.Sender = handshake.localIndex msg.Receiver = handshake.remoteIndex // create ephemeral key handshake.localEphemeral, err = newPrivateKey() if err != nil { return nil, err } msg.Ephemeral = handshake.localEphemeral.publicKey() handshake.mixHash(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:]) ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) if err != nil { return nil, err } handshake.mixKey(ss[:]) ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) if err != nil { return nil, err } handshake.mixKey(ss[:]) // add preshared key var tau [blake2s.Size]byte var key [chacha20poly1305.KeySize]byte KDF3( &handshake.chainKey, &tau, &key, handshake.chainKey[:], handshake.presharedKey[:], ) handshake.mixHash(tau[:]) aead, _ := chacha20poly1305.New(key[:]) aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) handshake.mixHash(msg.Empty[:]) handshake.state = handshakeResponseCreated return &msg, nil } func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { if msg.Type != MessageResponseType { return nil } // lookup handshake by receiver lookup := 
device.indexTable.Lookup(msg.Receiver) handshake := lookup.handshake if handshake == nil { return nil } var ( hash [blake2s.Size]byte chainKey [blake2s.Size]byte ) ok := func() bool { // lock handshake state handshake.mutex.RLock() defer handshake.mutex.RUnlock() if handshake.state != handshakeInitiationCreated { return false } // lock private key for reading device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() // finish 3-way DH mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral) if err != nil { return false } mixKey(&chainKey, &chainKey, ss[:]) setZero(ss[:]) ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) if err != nil { return false } mixKey(&chainKey, &chainKey, ss[:]) setZero(ss[:]) // add preshared key (psk) var tau [blake2s.Size]byte var key [chacha20poly1305.KeySize]byte KDF3( &chainKey, &tau, &key, chainKey[:], handshake.presharedKey[:], ) mixHash(&hash, &hash, tau[:]) // authenticate transcript aead, _ := chacha20poly1305.New(key[:]) _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) if err != nil { return false } mixHash(&hash, &hash, msg.Empty[:]) return true }() if !ok { return nil } // update handshake state handshake.mutex.Lock() handshake.hash = hash handshake.chainKey = chainKey handshake.remoteIndex = msg.Sender handshake.state = handshakeResponseConsumed handshake.mutex.Unlock() setZero(hash[:]) setZero(chainKey[:]) return lookup.peer } /* Derives a new keypair from the current handshake state * */ func (peer *Peer) BeginSymmetricSession() error { device := peer.device handshake := &peer.handshake handshake.mutex.Lock() defer handshake.mutex.Unlock() // derive keys var isInitiator bool var sendKey [chacha20poly1305.KeySize]byte var recvKey [chacha20poly1305.KeySize]byte if handshake.state == handshakeResponseConsumed { KDF2( &sendKey, &recvKey, handshake.chainKey[:], nil, ) 
isInitiator = true } else if handshake.state == handshakeResponseCreated { KDF2( &recvKey, &sendKey, handshake.chainKey[:], nil, ) isInitiator = false } else { return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state) } // zero handshake setZero(handshake.chainKey[:]) setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. setZero(handshake.localEphemeral[:]) peer.handshake.state = handshakeZeroed // create AEAD instances keypair := new(Keypair) keypair.send, _ = chacha20poly1305.New(sendKey[:]) keypair.receive, _ = chacha20poly1305.New(recvKey[:]) setZero(sendKey[:]) setZero(recvKey[:]) keypair.created = time.Now() keypair.replayFilter.Reset() keypair.isInitiator = isInitiator keypair.localIndex = peer.handshake.localIndex keypair.remoteIndex = peer.handshake.remoteIndex // remap index device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair) handshake.localIndex = 0 // rotate key pairs keypairs := &peer.keypairs keypairs.Lock() defer keypairs.Unlock() previous := keypairs.previous next := keypairs.next.Load() current := keypairs.current if isInitiator { if next != nil { keypairs.next.Store(nil) keypairs.previous = next device.DeleteKeypair(current) } else { keypairs.previous = current } device.DeleteKeypair(previous) keypairs.current = keypair } else { keypairs.next.Store(keypair) device.DeleteKeypair(next) keypairs.previous = nil device.DeleteKeypair(previous) } return nil } func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { keypairs := &peer.keypairs if keypairs.next.Load() != receivedKeypair { return false } keypairs.Lock() defer keypairs.Unlock() if keypairs.next.Load() != receivedKeypair { return false } old := keypairs.previous keypairs.previous = keypairs.current peer.device.DeleteKeypair(old) keypairs.current = keypairs.next.Load() keypairs.next.Store(nil) return true } 
wireguard-go-0.0.20231211/device/noise-types.go000066400000000000000000000031011466306003400207010ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "crypto/subtle" "encoding/hex" "errors" ) const ( NoisePublicKeySize = 32 NoisePrivateKeySize = 32 NoisePresharedKeySize = 32 ) type ( NoisePublicKey [NoisePublicKeySize]byte NoisePrivateKey [NoisePrivateKeySize]byte NoisePresharedKey [NoisePresharedKeySize]byte NoiseNonce uint64 // padded to 12-bytes ) func loadExactHex(dst []byte, src string) error { slice, err := hex.DecodeString(src) if err != nil { return err } if len(slice) != len(dst) { return errors.New("hex string does not fit the slice") } copy(dst, slice) return nil } func (key NoisePrivateKey) IsZero() bool { var zero NoisePrivateKey return key.Equals(zero) } func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool { return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 } func (key *NoisePrivateKey) FromHex(src string) (err error) { err = loadExactHex(key[:], src) key.clamp() return } func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) { err = loadExactHex(key[:], src) if key.IsZero() { return } key.clamp() return } func (key *NoisePublicKey) FromHex(src string) error { return loadExactHex(key[:], src) } func (key NoisePublicKey) IsZero() bool { var zero NoisePublicKey return key.Equals(zero) } func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 } func (key *NoisePresharedKey) FromHex(src string) error { return loadExactHex(key[:], src) } wireguard-go-0.0.20231211/device/noise_test.go000066400000000000000000000066461466306003400206170ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package device import ( "bytes" "encoding/binary" "testing" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/tun/tuntest" ) func TestCurveWrappers(t *testing.T) { sk1, err := newPrivateKey() assertNil(t, err) sk2, err := newPrivateKey() assertNil(t, err) pk1 := sk1.publicKey() pk2 := sk2.publicKey() ss1, err1 := sk1.sharedSecret(pk2) ss2, err2 := sk2.sharedSecret(pk1) if ss1 != ss2 || err1 != nil || err2 != nil { t.Fatal("Failed to compute shared secet") } } func randDevice(t *testing.T) *Device { sk, err := newPrivateKey() if err != nil { t.Fatal(err) } tun := tuntest.NewChannelTUN() logger := NewLogger(LogLevelError, "") device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger) device.SetPrivateKey(sk) return device } func assertNil(t *testing.T, err error) { if err != nil { t.Fatal(err) } } func assertEqual(t *testing.T, a, b []byte) { if !bytes.Equal(a, b) { t.Fatal(a, "!=", b) } } func TestNoiseHandshake(t *testing.T) { dev1 := randDevice(t) dev2 := randDevice(t) defer dev1.Close() defer dev2.Close() peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) if err != nil { t.Fatal(err) } peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) if err != nil { t.Fatal(err) } peer1.Start() peer2.Start() assertEqual( t, peer1.handshake.precomputedStaticStatic[:], peer2.handshake.precomputedStaticStatic[:], ) /* simulate handshake */ // initiation message t.Log("exchange initiation message") msg1, err := dev1.CreateMessageInitiation(peer2) assertNil(t, err) packet := make([]byte, 0, 256) writer := bytes.NewBuffer(packet) err = binary.Write(writer, binary.LittleEndian, msg1) assertNil(t, err) peer := dev2.ConsumeMessageInitiation(msg1) if peer == nil { t.Fatal("handshake failed at initiation message") } assertEqual( t, peer1.handshake.chainKey[:], peer2.handshake.chainKey[:], ) assertEqual( t, peer1.handshake.hash[:], peer2.handshake.hash[:], ) // response message t.Log("exchange response message") msg2, err := 
dev2.CreateMessageResponse(peer1) assertNil(t, err) peer = dev1.ConsumeMessageResponse(msg2) if peer == nil { t.Fatal("handshake failed at response message") } assertEqual( t, peer1.handshake.chainKey[:], peer2.handshake.chainKey[:], ) assertEqual( t, peer1.handshake.hash[:], peer2.handshake.hash[:], ) // key pairs t.Log("deriving keys") err = peer1.BeginSymmetricSession() if err != nil { t.Fatal("failed to derive keypair for peer 1", err) } err = peer2.BeginSymmetricSession() if err != nil { t.Fatal("failed to derive keypair for peer 2", err) } key1 := peer1.keypairs.next.Load() key2 := peer2.keypairs.current // encrypting / decryption test t.Log("test key pairs") func() { testMsg := []byte("wireguard test message 1") var err error var out []byte var nonce [12]byte out = key1.send.Seal(out, nonce[:], testMsg, nil) out, err = key2.receive.Open(out[:0], nonce[:], out, nil) assertNil(t, err) assertEqual(t, out, testMsg) }() func() { testMsg := []byte("wireguard test message 2") var err error var out []byte var nonce [12]byte out = key2.send.Seal(out, nonce[:], testMsg, nil) out, err = key1.receive.Open(out[:0], nonce[:], out, nil) assertNil(t, err) assertEqual(t, out, testMsg) }() } wireguard-go-0.0.20231211/device/peer.go000066400000000000000000000165261466306003400173740ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package device import ( "container/list" "errors" "sync" "sync/atomic" "time" "golang.zx2c4.com/wireguard/conn" ) type Peer struct { isRunning atomic.Bool keypairs Keypairs handshake Handshake device *Device stopping sync.WaitGroup // routines pending stop txBytes atomic.Uint64 // bytes send to peer (endpoint) rxBytes atomic.Uint64 // bytes received from peer lastHandshakeNano atomic.Int64 // nano seconds since epoch endpoint struct { sync.Mutex val conn.Endpoint clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission disableRoaming bool } timers struct { retransmitHandshake *Timer sendKeepalive *Timer newHandshake *Timer zeroKeyMaterial *Timer persistentKeepalive *Timer handshakeAttempts atomic.Uint32 needAnotherKeepalive atomic.Bool sentLastMinuteHandshake atomic.Bool } state struct { sync.Mutex // protects against concurrent Start/Stop } queue struct { staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available outbound *autodrainingOutboundQueue // sequential ordering of udp transmission inbound *autodrainingInboundQueue // sequential ordering of tun writing } cookieGenerator CookieGenerator trieEntries list.List persistentKeepaliveInterval atomic.Uint32 } func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { if device.isClosed() { return nil, errors.New("device closed") } // lock resources device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() device.peers.Lock() defer device.peers.Unlock() // check if over limit if len(device.peers.keyMap) >= MaxPeers { return nil, errors.New("too many peers") } // create peer peer := new(Peer) peer.cookieGenerator.Init(pk) peer.device = device peer.queue.outbound = newAutodrainingOutboundQueue(device) peer.queue.inbound = newAutodrainingInboundQueue(device) peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize) // map public key _, ok := device.peers.keyMap[pk] if ok { return nil, errors.New("adding existing peer") 
} // pre-compute DH handshake := &peer.handshake handshake.mutex.Lock() handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk) handshake.remoteStatic = pk handshake.mutex.Unlock() // reset endpoint peer.endpoint.Lock() peer.endpoint.val = nil peer.endpoint.disableRoaming = false peer.endpoint.clearSrcOnTx = false peer.endpoint.Unlock() // init timers peer.timersInit() // add device.peers.keyMap[pk] = peer return peer, nil } func (peer *Peer) SendBuffers(buffers [][]byte) error { peer.device.net.RLock() defer peer.device.net.RUnlock() if peer.device.isClosed() { return nil } peer.endpoint.Lock() endpoint := peer.endpoint.val if endpoint == nil { peer.endpoint.Unlock() return errors.New("no known endpoint for peer") } if peer.endpoint.clearSrcOnTx { endpoint.ClearSrc() peer.endpoint.clearSrcOnTx = false } peer.endpoint.Unlock() err := peer.device.net.bind.Send(buffers, endpoint) if err == nil { var totalLen uint64 for _, b := range buffers { totalLen += uint64(len(b)) } peer.txBytes.Add(totalLen) } return err } func (peer *Peer) String() string { // The awful goo that follows is identical to: // // base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) // abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43] // return fmt.Sprintf("peer(%s)", abbreviatedKey) // // except that it is considerably more efficient. 
src := peer.handshake.remoteStatic b64 := func(input byte) byte { return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3) } b := []byte("peer(____…____)") const first = len("peer(") const second = len("peer(____…") b[first+0] = b64((src[0] >> 2) & 63) b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63) b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63) b[first+3] = b64(src[2] & 63) b[second+0] = b64(src[29] & 63) b[second+1] = b64((src[30] >> 2) & 63) b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63) b[second+3] = b64((src[31] << 2) & 63) return string(b) } func (peer *Peer) Start() { // should never start a peer on a closed device if peer.device.isClosed() { return } // prevent simultaneous start/stop operations peer.state.Lock() defer peer.state.Unlock() if peer.isRunning.Load() { return } device := peer.device device.log.Verbosef("%v - Starting", peer) // reset routine state peer.stopping.Wait() peer.stopping.Add(2) peer.handshake.mutex.Lock() peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) peer.handshake.mutex.Unlock() peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes peer.timersStart() device.flushInboundQueue(peer.queue.inbound) device.flushOutboundQueue(peer.queue.outbound) // Use the device batch size, not the bind batch size, as the device size is // the size of the batch pools. 
batchSize := peer.device.BatchSize() go peer.RoutineSequentialSender(batchSize) go peer.RoutineSequentialReceiver(batchSize) peer.isRunning.Store(true) } func (peer *Peer) ZeroAndFlushAll() { device := peer.device // clear key pairs keypairs := &peer.keypairs keypairs.Lock() device.DeleteKeypair(keypairs.previous) device.DeleteKeypair(keypairs.current) device.DeleteKeypair(keypairs.next.Load()) keypairs.previous = nil keypairs.current = nil keypairs.next.Store(nil) keypairs.Unlock() // clear handshake state handshake := &peer.handshake handshake.mutex.Lock() device.indexTable.Delete(handshake.localIndex) handshake.Clear() handshake.mutex.Unlock() peer.FlushStagedPackets() } func (peer *Peer) ExpireCurrentKeypairs() { handshake := &peer.handshake handshake.mutex.Lock() peer.device.indexTable.Delete(handshake.localIndex) handshake.Clear() peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) handshake.mutex.Unlock() keypairs := &peer.keypairs keypairs.Lock() if keypairs.current != nil { keypairs.current.sendNonce.Store(RejectAfterMessages) } if next := keypairs.next.Load(); next != nil { next.sendNonce.Store(RejectAfterMessages) } keypairs.Unlock() } func (peer *Peer) Stop() { peer.state.Lock() defer peer.state.Unlock() if !peer.isRunning.Swap(false) { return } peer.device.log.Verbosef("%v - Stopping", peer) peer.timersStop() // Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit. 
peer.queue.inbound.c <- nil peer.queue.outbound.c <- nil peer.stopping.Wait() peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us peer.ZeroAndFlushAll() } func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { peer.endpoint.Lock() defer peer.endpoint.Unlock() if peer.endpoint.disableRoaming { return } peer.endpoint.clearSrcOnTx = false peer.endpoint.val = endpoint } func (peer *Peer) markEndpointSrcForClearing() { peer.endpoint.Lock() defer peer.endpoint.Unlock() if peer.endpoint.val == nil { return } peer.endpoint.clearSrcOnTx = true } wireguard-go-0.0.20231211/device/pools.go000066400000000000000000000061201466306003400175620ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "sync" "sync/atomic" ) type WaitPool struct { pool sync.Pool cond sync.Cond lock sync.Mutex count atomic.Uint32 max uint32 } func NewWaitPool(max uint32, new func() any) *WaitPool { p := &WaitPool{pool: sync.Pool{New: new}, max: max} p.cond = sync.Cond{L: &p.lock} return p } func (p *WaitPool) Get() any { if p.max != 0 { p.lock.Lock() for p.count.Load() >= p.max { p.cond.Wait() } p.count.Add(1) p.lock.Unlock() } return p.pool.Get() } func (p *WaitPool) Put(x any) { p.pool.Put(x) if p.max == 0 { return } p.count.Add(^uint32(0)) p.cond.Signal() } func (device *Device) PopulatePools() { device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { s := make([]*QueueInboundElement, 0, device.BatchSize()) return &QueueInboundElementsContainer{elems: s} }) device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { s := make([]*QueueOutboundElement, 0, device.BatchSize()) return &QueueOutboundElementsContainer{elems: s} }) device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new([MaxMessageSize]byte) }) device.pool.inboundElements = 
NewWaitPool(PreallocatedBuffersPerPool, func() any { return new(QueueInboundElement) }) device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new(QueueOutboundElement) }) } func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer) c.Mutex = sync.Mutex{} return c } func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) { for i := range c.elems { c.elems[i] = nil } c.elems = c.elems[:0] device.pool.inboundElementsContainer.Put(c) } func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer { c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer) c.Mutex = sync.Mutex{} return c } func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) { for i := range c.elems { c.elems[i] = nil } c.elems = c.elems[:0] device.pool.outboundElementsContainer.Put(c) } func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) } func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { device.pool.messageBuffers.Put(msg) } func (device *Device) GetInboundElement() *QueueInboundElement { return device.pool.inboundElements.Get().(*QueueInboundElement) } func (device *Device) PutInboundElement(elem *QueueInboundElement) { elem.clearPointers() device.pool.inboundElements.Put(elem) } func (device *Device) GetOutboundElement() *QueueOutboundElement { return device.pool.outboundElements.Get().(*QueueOutboundElement) } func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { elem.clearPointers() device.pool.outboundElements.Put(elem) } wireguard-go-0.0.20231211/device/pools_test.go000066400000000000000000000054311466306003400206250ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. 
All Rights Reserved. */ package device import ( "math/rand" "runtime" "sync" "sync/atomic" "testing" "time" ) func TestWaitPool(t *testing.T) { t.Skip("Currently disabled") var wg sync.WaitGroup var trials atomic.Int32 startTrials := int32(100000) if raceEnabled { // This test can be very slow with -race. startTrials /= 10 } trials.Store(startTrials) workers := runtime.NumCPU() + 2 if workers-4 <= 0 { t.Skip("Not enough cores") } p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) wg.Add(workers) var max atomic.Uint32 updateMax := func() { count := p.count.Load() if count > p.max { t.Errorf("count (%d) > max (%d)", count, p.max) } for { old := max.Load() if count <= old { break } if max.CompareAndSwap(old, count) { break } } } for i := 0; i < workers; i++ { go func() { defer wg.Done() for trials.Add(-1) > 0 { updateMax() x := p.Get() updateMax() time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) updateMax() p.Put(x) updateMax() } }() } wg.Wait() if max.Load() != p.max { t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max) } } func BenchmarkWaitPool(b *testing.B) { var wg sync.WaitGroup var trials atomic.Int32 trials.Store(int32(b.N)) workers := runtime.NumCPU() + 2 if workers-4 <= 0 { b.Skip("Not enough cores") } p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) wg.Add(workers) b.ResetTimer() for i := 0; i < workers; i++ { go func() { defer wg.Done() for trials.Add(-1) > 0 { x := p.Get() time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) p.Put(x) } }() } wg.Wait() } func BenchmarkWaitPoolEmpty(b *testing.B) { var wg sync.WaitGroup var trials atomic.Int32 trials.Store(int32(b.N)) workers := runtime.NumCPU() + 2 if workers-4 <= 0 { b.Skip("Not enough cores") } p := NewWaitPool(0, func() any { return make([]byte, 16) }) wg.Add(workers) b.ResetTimer() for i := 0; i < workers; i++ { go func() { defer wg.Done() for trials.Add(-1) > 0 { x := p.Get() 
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) p.Put(x) } }() } wg.Wait() } func BenchmarkSyncPool(b *testing.B) { var wg sync.WaitGroup var trials atomic.Int32 trials.Store(int32(b.N)) workers := runtime.NumCPU() + 2 if workers-4 <= 0 { b.Skip("Not enough cores") } p := sync.Pool{New: func() any { return make([]byte, 16) }} wg.Add(workers) b.ResetTimer() for i := 0; i < workers; i++ { go func() { defer wg.Done() for trials.Add(-1) > 0 { x := p.Get() time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) p.Put(x) } }() } wg.Wait() } wireguard-go-0.0.20231211/device/queueconstants_android.go000066400000000000000000000007431466306003400232140ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import "golang.zx2c4.com/wireguard/conn" /* Reduce memory consumption for Android */ const ( QueueStagedSize = conn.IdealBatchSize QueueOutboundSize = 1024 QueueInboundSize = 1024 QueueHandshakeSize = 1024 MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram PreallocatedBuffersPerPool = 4096 ) wireguard-go-0.0.20231211/device/queueconstants_default.go000066400000000000000000000010301466306003400232060ustar00rootroot00000000000000//go:build !android && !ios && !windows /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import "golang.zx2c4.com/wireguard/conn" const ( QueueStagedSize = conn.IdealBatchSize QueueOutboundSize = 1024 QueueInboundSize = 1024 QueueHandshakeSize = 1024 MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth ) wireguard-go-0.0.20231211/device/queueconstants_ios.go000066400000000000000000000011111466306003400223540ustar00rootroot00000000000000//go:build ios /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package device // Fit within memory limits for iOS's Network Extension API, which has stricter requirements. // These are vars instead of consts, because heavier network extensions might want to reduce // them further. var ( QueueStagedSize = 128 QueueOutboundSize = 1024 QueueInboundSize = 1024 QueueHandshakeSize = 1024 PreallocatedBuffersPerPool uint32 = 1024 ) const MaxSegmentSize = 1700 wireguard-go-0.0.20231211/device/queueconstants_windows.go000066400000000000000000000006551466306003400232700ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device const ( QueueStagedSize = 128 QueueOutboundSize = 1024 QueueInboundSize = 1024 QueueHandshakeSize = 1024 MaxSegmentSize = 2048 - 32 // largest possible UDP datagram PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth ) wireguard-go-0.0.20231211/device/race_disabled_test.go000066400000000000000000000002431466306003400222260ustar00rootroot00000000000000//go:build !race /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device const raceEnabled = false wireguard-go-0.0.20231211/device/race_enabled_test.go000066400000000000000000000002411466306003400220470ustar00rootroot00000000000000//go:build race /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device const raceEnabled = true wireguard-go-0.0.20231211/device/receive.go000066400000000000000000000323221466306003400200530ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package device import ( "bytes" "encoding/binary" "errors" "net" "sync" "time" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" "golang.zx2c4.com/wireguard/conn" ) type QueueHandshakeElement struct { msgType uint32 packet []byte endpoint conn.Endpoint buffer *[MaxMessageSize]byte } type QueueInboundElement struct { buffer *[MaxMessageSize]byte packet []byte counter uint64 keypair *Keypair endpoint conn.Endpoint } type QueueInboundElementsContainer struct { sync.Mutex elems []*QueueInboundElement } // clearPointers clears elem fields that contain pointers. // This makes the garbage collector's life easier and // avoids accidentally keeping other objects around unnecessarily. // It also reduces the possible collateral damage from use-after-free bugs. func (elem *QueueInboundElement) clearPointers() { elem.buffer = nil elem.packet = nil elem.keypair = nil elem.endpoint = nil } /* Called when a new authenticated message has been received * * NOTE: Not thread safe, but called by sequential receiver! 
*/ func (peer *Peer) keepKeyFreshReceiving() { if peer.timers.sentLastMinuteHandshake.Load() { return } keypair := peer.keypairs.Current() if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { peer.timers.sentLastMinuteHandshake.Store(true) peer.SendHandshakeInitiation(false) } } /* Receives incoming datagrams for the device * * Every time the bind is updated a new routine is started for * IPv4 and IPv6 (separately) */ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) { recvName := recv.PrettyName() defer func() { device.log.Verbosef("Routine: receive incoming %s - stopped", recvName) device.queue.decryption.wg.Done() device.queue.handshake.wg.Done() device.net.stopping.Done() }() device.log.Verbosef("Routine: receive incoming %s - started", recvName) // receive datagrams until conn is closed var ( bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize) bufs = make([][]byte, maxBatchSize) err error sizes = make([]int, maxBatchSize) count int endpoints = make([]conn.Endpoint, maxBatchSize) deathSpiral int elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize) ) for i := range bufsArrs { bufsArrs[i] = device.GetMessageBuffer() bufs[i] = bufsArrs[i][:] } defer func() { for i := 0; i < maxBatchSize; i++ { if bufsArrs[i] != nil { device.PutMessageBuffer(bufsArrs[i]) } } }() for { count, err = recv(bufs, sizes, endpoints) if err != nil { if errors.Is(err, net.ErrClosed) { return } device.log.Verbosef("Failed to receive %s packet: %v", recvName, err) if neterr, ok := err.(net.Error); ok && !neterr.Temporary() { return } if deathSpiral < 10 { deathSpiral++ time.Sleep(time.Second / 3) continue } return } deathSpiral = 0 // handle each packet in the batch for i, size := range sizes[:count] { if size < MinMessageSize { continue } // check size of packet packet := bufsArrs[i][:size] msgType := binary.LittleEndian.Uint32(packet[:4]) switch msgType { // 
check if transport case MessageTransportType: // check size if len(packet) < MessageTransportSize { continue } // lookup key pair receiver := binary.LittleEndian.Uint32( packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], ) value := device.indexTable.Lookup(receiver) keypair := value.keypair if keypair == nil { continue } // check keypair expiry if keypair.created.Add(RejectAfterTime).Before(time.Now()) { continue } // create work element peer := value.peer elem := device.GetInboundElement() elem.packet = packet elem.buffer = bufsArrs[i] elem.keypair = keypair elem.endpoint = endpoints[i] elem.counter = 0 elemsForPeer, ok := elemsByPeer[peer] if !ok { elemsForPeer = device.GetInboundElementsContainer() elemsForPeer.Lock() elemsByPeer[peer] = elemsForPeer } elemsForPeer.elems = append(elemsForPeer.elems, elem) bufsArrs[i] = device.GetMessageBuffer() bufs[i] = bufsArrs[i][:] continue // otherwise it is a fixed size & handshake related packet case MessageInitiationType: if len(packet) != MessageInitiationSize { continue } case MessageResponseType: if len(packet) != MessageResponseSize { continue } case MessageCookieReplyType: if len(packet) != MessageCookieReplySize { continue } default: device.log.Verbosef("Received message with unknown type") continue } select { case device.queue.handshake.c <- QueueHandshakeElement{ msgType: msgType, buffer: bufsArrs[i], packet: packet, endpoint: endpoints[i], }: bufsArrs[i] = device.GetMessageBuffer() bufs[i] = bufsArrs[i][:] default: } } for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elemsContainer device.queue.decryption.c <- elemsContainer } else { for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } device.PutInboundElementsContainer(elemsContainer) } delete(elemsByPeer, peer) } } } func (device *Device) RoutineDecryption(id int) { var nonce [chacha20poly1305.NonceSize]byte defer 
device.log.Verbosef("Routine: decryption worker %d - stopped", id) device.log.Verbosef("Routine: decryption worker %d - started", id) for elemsContainer := range device.queue.decryption.c { for _, elem := range elemsContainer.elems { // split message into fields counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] content := elem.packet[MessageTransportOffsetContent:] // decrypt and release to consumer var err error elem.counter = binary.LittleEndian.Uint64(counter) // copy counter to nonce binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) elem.packet, err = elem.keypair.receive.Open( content[:0], nonce[:], content, nil, ) if err != nil { elem.packet = nil } } elemsContainer.Unlock() } } /* Handles incoming packets related to handshake */ func (device *Device) RoutineHandshake(id int) { defer func() { device.log.Verbosef("Routine: handshake worker %d - stopped", id) device.queue.encryption.wg.Done() }() device.log.Verbosef("Routine: handshake worker %d - started", id) for elem := range device.queue.handshake.c { // handle cookie fields and ratelimiting switch elem.msgType { case MessageCookieReplyType: // unmarshal packet var reply MessageCookieReply reader := bytes.NewReader(elem.packet) err := binary.Read(reader, binary.LittleEndian, &reply) if err != nil { device.log.Verbosef("Failed to decode cookie reply") goto skip } // lookup peer from index entry := device.indexTable.Lookup(reply.Receiver) if entry.peer == nil { goto skip } // consume reply if peer := entry.peer; peer.isRunning.Load() { device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString()) if !peer.cookieGenerator.ConsumeReply(&reply) { device.log.Verbosef("Could not decrypt invalid cookie response") } } goto skip case MessageInitiationType, MessageResponseType: // check mac fields and maybe ratelimit if !device.cookieChecker.CheckMAC1(elem.packet) { device.log.Verbosef("Received packet with invalid mac1") goto skip } // endpoints 
destination address is the source of the datagram if device.IsUnderLoad() { // verify MAC2 field if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { device.SendHandshakeCookie(&elem) goto skip } // check ratelimiter if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { goto skip } } default: device.log.Errorf("Invalid packet ended up in the handshake queue") goto skip } // handle handshake initiation/response content switch elem.msgType { case MessageInitiationType: // unmarshal var msg MessageInitiation reader := bytes.NewReader(elem.packet) err := binary.Read(reader, binary.LittleEndian, &msg) if err != nil { device.log.Errorf("Failed to decode initiation message") goto skip } // consume initiation peer := device.ConsumeMessageInitiation(&msg) if peer == nil { device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString()) goto skip } // update timers peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketReceived() // update endpoint peer.SetEndpointFromPacket(elem.endpoint) device.log.Verbosef("%v - Received handshake initiation", peer) peer.rxBytes.Add(uint64(len(elem.packet))) peer.SendHandshakeResponse() case MessageResponseType: // unmarshal var msg MessageResponse reader := bytes.NewReader(elem.packet) err := binary.Read(reader, binary.LittleEndian, &msg) if err != nil { device.log.Errorf("Failed to decode response message") goto skip } // consume response peer := device.ConsumeMessageResponse(&msg) if peer == nil { device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString()) goto skip } // update endpoint peer.SetEndpointFromPacket(elem.endpoint) device.log.Verbosef("%v - Received handshake response", peer) peer.rxBytes.Add(uint64(len(elem.packet))) // update timers peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketReceived() // derive keypair err = peer.BeginSymmetricSession() if err != nil { device.log.Errorf("%v 
- Failed to derive keypair: %v", peer, err) goto skip } peer.timersSessionDerived() peer.timersHandshakeComplete() peer.SendKeepalive() } skip: device.PutMessageBuffer(elem.buffer) } } func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { device := peer.device defer func() { device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) peer.stopping.Done() }() device.log.Verbosef("%v - Routine: sequential receiver - started", peer) bufs := make([][]byte, 0, maxBatchSize) for elemsContainer := range peer.queue.inbound.c { if elemsContainer == nil { return } elemsContainer.Lock() validTailPacket := -1 dataPacketReceived := false rxBytesLen := uint64(0) for i, elem := range elemsContainer.elems { if elem.packet == nil { // decryption failed continue } if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { continue } validTailPacket = i if peer.ReceivedWithKeypair(elem.keypair) { peer.SetEndpointFromPacket(elem.endpoint) peer.timersHandshakeComplete() peer.SendStagedPackets() } rxBytesLen += uint64(len(elem.packet) + MinMessageSize) if len(elem.packet) == 0 { device.log.Verbosef("%v - Receiving keepalive packet", peer) continue } dataPacketReceived = true switch elem.packet[0] >> 4 { case 4: if len(elem.packet) < ipv4.HeaderLen { continue } field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] length := binary.BigEndian.Uint16(field) if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { continue } elem.packet = elem.packet[:length] src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] if device.allowedips.Lookup(src) != peer { device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) continue } case 6: if len(elem.packet) < ipv6.HeaderLen { continue } field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] length := binary.BigEndian.Uint16(field) length += ipv6.HeaderLen if int(length) > len(elem.packet) { continue } elem.packet = 
elem.packet[:length] src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] if device.allowedips.Lookup(src) != peer { device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) continue } default: device.log.Verbosef("Packet with invalid IP version from %v", peer) continue } bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) } peer.rxBytes.Add(rxBytesLen) if validTailPacket >= 0 { peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint) peer.keepKeyFreshReceiving() peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketReceived() } if dataPacketReceived { peer.timersDataReceived() } if len(bufs) > 0 { _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent) if err != nil && !device.isClosed() { device.log.Errorf("Failed to write packets to TUN device: %v", err) } } for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } bufs = bufs[:0] device.PutInboundElementsContainer(elemsContainer) } } wireguard-go-0.0.20231211/device/send.go000066400000000000000000000352701466306003400173670ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "bytes" "encoding/binary" "errors" "net" "os" "sync" "time" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/tun" ) /* Outbound flow * * 1. TUN queue * 2. Routing (sequential) * 3. Nonce assignment (sequential) * 4. Encryption (parallel) * 5. Transmission (sequential) * * The functions in this file occur (roughly) in the order in * which the packets are processed. 
* * Locking, Producers and Consumers * * The order of packets (per peer) must be maintained, * but encryption of packets happen out-of-order: * * The sequential consumers will attempt to take the lock, * workers release lock when they have completed work (encryption) on the packet. * * If the element is inserted into the "encryption queue", * the content is preceded by enough "junk" to contain the transport header * (to allow the construction of transport messages in-place) */ type QueueOutboundElement struct { buffer *[MaxMessageSize]byte // slice holding the packet data packet []byte // slice of "buffer" (always!) nonce uint64 // nonce for encryption keypair *Keypair // keypair for encryption peer *Peer // related peer } type QueueOutboundElementsContainer struct { sync.Mutex elems []*QueueOutboundElement } func (device *Device) NewOutboundElement() *QueueOutboundElement { elem := device.GetOutboundElement() elem.buffer = device.GetMessageBuffer() elem.nonce = 0 // keypair and peer were cleared (if necessary) by clearPointers. return elem } // clearPointers clears elem fields that contain pointers. // This makes the garbage collector's life easier and // avoids accidentally keeping other objects around unnecessarily. // It also reduces the possible collateral damage from use-after-free bugs. 
func (elem *QueueOutboundElement) clearPointers() { elem.buffer = nil elem.packet = nil elem.keypair = nil elem.peer = nil } /* Queues a keepalive if no packets are queued for peer */ func (peer *Peer) SendKeepalive() { if len(peer.queue.staged) == 0 && peer.isRunning.Load() { elem := peer.device.NewOutboundElement() elemsContainer := peer.device.GetOutboundElementsContainer() elemsContainer.elems = append(elemsContainer.elems, elem) select { case peer.queue.staged <- elemsContainer: peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) peer.device.PutOutboundElementsContainer(elemsContainer) } } peer.SendStagedPackets() } func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { if !isRetry { peer.timers.handshakeAttempts.Store(0) } peer.handshake.mutex.RLock() if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout { peer.handshake.mutex.RUnlock() return nil } peer.handshake.mutex.RUnlock() peer.handshake.mutex.Lock() if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout { peer.handshake.mutex.Unlock() return nil } peer.handshake.lastSentHandshake = time.Now() peer.handshake.mutex.Unlock() peer.device.log.Verbosef("%v - Sending handshake initiation", peer) msg, err := peer.device.CreateMessageInitiation(peer) if err != nil { peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err) return err } var buf [MessageInitiationSize]byte writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, msg) packet := writer.Bytes() peer.cookieGenerator.AddMacs(packet) peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() err = peer.SendBuffers([][]byte{packet}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } peer.timersHandshakeInitiated() return err } func (peer *Peer) SendHandshakeResponse() error { peer.handshake.mutex.Lock() 
peer.handshake.lastSentHandshake = time.Now() peer.handshake.mutex.Unlock() peer.device.log.Verbosef("%v - Sending handshake response", peer) response, err := peer.device.CreateMessageResponse(peer) if err != nil { peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err) return err } var buf [MessageResponseSize]byte writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, response) packet := writer.Bytes() peer.cookieGenerator.AddMacs(packet) err = peer.BeginSymmetricSession() if err != nil { peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) return err } peer.timersSessionDerived() peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() // TODO: allocation could be avoided err = peer.SendBuffers([][]byte{packet}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } return err } func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error { device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString()) sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes()) if err != nil { device.log.Errorf("Failed to create cookie reply: %v", err) return err } var buf [MessageCookieReplySize]byte writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, reply) // TODO: allocation could be avoided device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) return nil } func (peer *Peer) keepKeyFreshSending() { keypair := peer.keypairs.Current() if keypair == nil { return } nonce := keypair.sendNonce.Load() if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) { peer.SendHandshakeInitiation(false) } } func (device *Device) RoutineReadFromTUN() { defer func() { 
device.log.Verbosef("Routine: TUN reader - stopped") device.state.stopping.Done() device.queue.encryption.wg.Done() }() device.log.Verbosef("Routine: TUN reader - started") var ( batchSize = device.BatchSize() readErr error elems = make([]*QueueOutboundElement, batchSize) bufs = make([][]byte, batchSize) elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) count = 0 sizes = make([]int, batchSize) offset = MessageTransportHeaderSize ) for i := range elems { elems[i] = device.NewOutboundElement() bufs[i] = elems[i].buffer[:] } defer func() { for _, elem := range elems { if elem != nil { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } } }() for { // read packets count, readErr = device.tun.device.Read(bufs, sizes, offset) for i := 0; i < count; i++ { if sizes[i] < 1 { continue } elem := elems[i] elem.packet = bufs[i][offset : offset+sizes[i]] // lookup peer var peer *Peer switch elem.packet[0] >> 4 { case 4: if len(elem.packet) < ipv4.HeaderLen { continue } dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] peer = device.allowedips.Lookup(dst) case 6: if len(elem.packet) < ipv6.HeaderLen { continue } dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] peer = device.allowedips.Lookup(dst) default: device.log.Verbosef("Received packet with unknown IP version") } if peer == nil { continue } elemsForPeer, ok := elemsByPeer[peer] if !ok { elemsForPeer = device.GetOutboundElementsContainer() elemsByPeer[peer] = elemsForPeer } elemsForPeer.elems = append(elemsForPeer.elems, elem) elems[i] = device.NewOutboundElement() bufs[i] = elems[i].buffer[:] } for peer, elemsForPeer := range elemsByPeer { if peer.isRunning.Load() { peer.StagePackets(elemsForPeer) peer.SendStagedPackets() } else { for _, elem := range elemsForPeer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsForPeer) } delete(elemsByPeer, peer) } if readErr != nil { if 
errors.Is(readErr, tun.ErrTooManySegments) { // TODO: record stat for this // This will happen if MSS is surprisingly small (< 576) // coincident with reasonably high throughput. device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr) continue } if !device.isClosed() { if !errors.Is(readErr, os.ErrClosed) { device.log.Errorf("Failed to read packet from TUN device: %v", readErr) } go device.Close() } return } } } func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { for { select { case peer.queue.staged <- elems: return default: } select { case tooOld := <-peer.queue.staged: for _, elem := range tooOld.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(tooOld) default: } } } func (peer *Peer) SendStagedPackets() { top: if len(peer.queue.staged) == 0 || !peer.device.isUp() { return } keypair := peer.keypairs.Current() if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime { peer.SendHandshakeInitiation(false) return } for { var elemsContainerOOO *QueueOutboundElementsContainer select { case elemsContainer := <-peer.queue.staged: i := 0 for _, elem := range elemsContainer.elems { elem.peer = peer elem.nonce = keypair.sendNonce.Add(1) - 1 if elem.nonce >= RejectAfterMessages { keypair.sendNonce.Store(RejectAfterMessages) if elemsContainerOOO == nil { elemsContainerOOO = peer.device.GetOutboundElementsContainer() } elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem) continue } else { elemsContainer.elems[i] = elem i++ } elem.keypair = keypair } elemsContainer.Lock() elemsContainer.elems = elemsContainer.elems[:i] if elemsContainerOOO != nil { peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans } if len(elemsContainer.elems) == 0 { peer.device.PutOutboundElementsContainer(elemsContainer) goto top } // add to parallel and sequential queue 
if peer.isRunning.Load() { peer.queue.outbound.c <- elemsContainer peer.device.queue.encryption.c <- elemsContainer } else { for _, elem := range elemsContainer.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(elemsContainer) } if elemsContainerOOO != nil { goto top } default: return } } } func (peer *Peer) FlushStagedPackets() { for { select { case elemsContainer := <-peer.queue.staged: for _, elem := range elemsContainer.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(elemsContainer) default: return } } } func calculatePaddingSize(packetSize, mtu int) int { lastUnit := packetSize if mtu == 0 { return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit } if lastUnit > mtu { lastUnit %= mtu } paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) if paddedSize > mtu { paddedSize = mtu } return paddedSize - lastUnit } /* Encrypts the elements in the queue * and marks them for sequential consumption (by releasing the mutex) * * Obs. 
One instance per core */ func (device *Device) RoutineEncryption(id int) { var paddingZeros [PaddingMultiple]byte var nonce [chacha20poly1305.NonceSize]byte defer device.log.Verbosef("Routine: encryption worker %d - stopped", id) device.log.Verbosef("Routine: encryption worker %d - started", id) for elemsContainer := range device.queue.encryption.c { for _, elem := range elemsContainer.elems { // populate header fields header := elem.buffer[:MessageTransportHeaderSize] fieldType := header[0:4] fieldReceiver := header[4:8] fieldNonce := header[8:16] binary.LittleEndian.PutUint32(fieldType, MessageTransportType) binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) // pad content to multiple of 16 paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) // encrypt content and release to consumer binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) elem.packet = elem.keypair.send.Seal( header, nonce[:], elem.packet, nil, ) } elemsContainer.Unlock() } } func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { device := peer.device defer func() { defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer) peer.stopping.Done() }() device.log.Verbosef("%v - Routine: sequential sender - started", peer) bufs := make([][]byte, 0, maxBatchSize) for elemsContainer := range peer.queue.outbound.c { bufs = bufs[:0] if elemsContainer == nil { return } if !peer.isRunning.Load() { // peer has been stopped; return re-usable elems to the shared pool. // This is an optimization only. It is possible for the peer to be stopped // immediately after this check, in which case, elem will get processed. // The timers and SendBuffers code are resilient to a few stragglers. // TODO: rework peer shutdown order to ensure // that we never accidentally keep timers alive longer than necessary. 
elemsContainer.Lock() for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } continue } dataSent := false elemsContainer.Lock() for _, elem := range elemsContainer.elems { if len(elem.packet) != MessageKeepaliveSize { dataSent = true } bufs = append(bufs, elem.packet) } peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() err := peer.SendBuffers(bufs) if dataSent { peer.timersDataSent() } for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsContainer) if err != nil { var errGSO conn.ErrUDPGSODisabled if errors.As(err, &errGSO) { device.log.Verbosef(err.Error()) err = errGSO.RetryErr } } if err != nil { device.log.Errorf("%v - Failed to send data packets: %v", peer, err) continue } peer.keepKeyFreshSending() } } wireguard-go-0.0.20231211/device/sticky_default.go000066400000000000000000000003431466306003400214410ustar00rootroot00000000000000//go:build !linux package device import ( "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/rwcancel" ) func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { return nil, nil } wireguard-go-0.0.20231211/device/sticky_linux.go000066400000000000000000000127711466306003400211640ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * * This implements userspace semantics of "sticky sockets", modeled after * WireGuard's kernelspace implementation. This is more or less a straight port * of the sticky-sockets.c example code: * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c * * Currently there is no way to achieve this within the net package: * See e.g. https://github.com/golang/go/issues/17930 * So this code is remains platform dependent. 
*/ package device import ( "sync" "unsafe" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/rwcancel" ) func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { if !conn.StdNetSupportsStickySockets { return nil, nil } if _, ok := bind.(*conn.StdNetBind); !ok { return nil, nil } netlinkSock, err := createNetlinkRouteSocket() if err != nil { return nil, err } netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock) if err != nil { unix.Close(netlinkSock) return nil, err } go device.routineRouteListener(bind, netlinkSock, netlinkCancel) return netlinkCancel, nil } func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { type peerEndpointPtr struct { peer *Peer endpoint *conn.Endpoint } var reqPeer map[uint32]peerEndpointPtr var reqPeerLock sync.Mutex defer netlinkCancel.Close() defer unix.Close(netlinkSock) for msg := make([]byte, 1<<16); ; { var err error var msgn int for { msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0) if err == nil || !rwcancel.RetryAfterError(err) { break } if !netlinkCancel.ReadyRead() { return } } if err != nil { return } for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) if uint(hdr.Len) > uint(len(remain)) { break } switch hdr.Type { case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: if hdr.Seq <= MaxPeers && hdr.Seq > 0 { if uint(len(remain)) < uint(hdr.Len) { break } if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] for { if uint(len(attr)) < uint(unix.SizeofRtAttr) { break } attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { break } if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) reqPeerLock.Lock() if reqPeer == nil { 
reqPeerLock.Unlock() break } pePtr, ok := reqPeer[hdr.Seq] reqPeerLock.Unlock() if !ok { break } pePtr.peer.endpoint.Lock() if &pePtr.peer.endpoint.val != pePtr.endpoint { pePtr.peer.endpoint.Unlock() break } if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { pePtr.peer.endpoint.Unlock() break } pePtr.peer.endpoint.clearSrcOnTx = true pePtr.peer.endpoint.Unlock() } attr = attr[attrhdr.Len:] } } break } reqPeerLock.Lock() reqPeer = make(map[uint32]peerEndpointPtr) reqPeerLock.Unlock() go func() { device.peers.RLock() i := uint32(1) for _, peer := range device.peers.keyMap { peer.endpoint.Lock() if peer.endpoint.val == nil { peer.endpoint.Unlock() continue } nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint) if nativeEP == nil { peer.endpoint.Unlock() continue } if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 { peer.endpoint.Unlock() break } nlmsg := struct { hdr unix.NlMsghdr msg unix.RtMsg dsthdr unix.RtAttr dst [4]byte srchdr unix.RtAttr src [4]byte markhdr unix.RtAttr mark uint32 }{ unix.NlMsghdr{ Type: uint16(unix.RTM_GETROUTE), Flags: unix.NLM_F_REQUEST, Seq: i, }, unix.RtMsg{ Family: unix.AF_INET, Dst_len: 32, Src_len: 32, }, unix.RtAttr{ Len: 8, Type: unix.RTA_DST, }, nativeEP.DstIP().As4(), unix.RtAttr{ Len: 8, Type: unix.RTA_SRC, }, nativeEP.SrcIP().As4(), unix.RtAttr{ Len: 8, Type: unix.RTA_MARK, }, device.net.fwmark, } nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) reqPeerLock.Lock() reqPeer[i] = peerEndpointPtr{ peer: peer, endpoint: &peer.endpoint.val, } reqPeerLock.Unlock() peer.endpoint.Unlock() i++ _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) if err != nil { break } } device.peers.RUnlock() }() } remain = remain[hdr.Len:] } } } func createNetlinkRouteSocket() (int, error) { sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE) if err != nil { return -1, err } saddr := &unix.SockaddrNetlink{ Family: unix.AF_NETLINK, Groups: 
unix.RTMGRP_IPV4_ROUTE, } err = unix.Bind(sock, saddr) if err != nil { unix.Close(sock) return -1, err } return sock, nil } wireguard-go-0.0.20231211/device/timers.go000066400000000000000000000152541466306003400177410ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * * This is based heavily on timers.c from the kernel implementation. */ package device import ( "sync" "time" _ "unsafe" ) //go:linkname fastrandn runtime.fastrandn func fastrandn(n uint32) uint32 // A Timer manages time-based aspects of the WireGuard protocol. // Timer roughly copies the interface of the Linux kernel's struct timer_list. type Timer struct { *time.Timer modifyingLock sync.RWMutex runningLock sync.Mutex isPending bool } func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer { timer := &Timer{} timer.Timer = time.AfterFunc(time.Hour, func() { timer.runningLock.Lock() defer timer.runningLock.Unlock() timer.modifyingLock.Lock() if !timer.isPending { timer.modifyingLock.Unlock() return } timer.isPending = false timer.modifyingLock.Unlock() expirationFunction(peer) }) timer.Stop() return timer } func (timer *Timer) Mod(d time.Duration) { timer.modifyingLock.Lock() timer.isPending = true timer.Reset(d) timer.modifyingLock.Unlock() } func (timer *Timer) Del() { timer.modifyingLock.Lock() timer.isPending = false timer.Stop() timer.modifyingLock.Unlock() } func (timer *Timer) DelSync() { timer.Del() timer.runningLock.Lock() timer.Del() timer.runningLock.Unlock() } func (timer *Timer) IsPending() bool { timer.modifyingLock.RLock() defer timer.modifyingLock.RUnlock() return timer.isPending } func (peer *Peer) timersActive() bool { return peer.isRunning.Load() && peer.device != nil && peer.device.isUp() } func expiredRetransmitHandshake(peer *Peer) { if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes { peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, 
MaxTimerHandshakes+2) if peer.timersActive() { peer.timers.sendKeepalive.Del() } /* We drop all packets without a keypair and don't try again, * if we try unsuccessfully for too long to make a handshake. */ peer.FlushStagedPackets() /* We set a timer for destroying any residue that might be left * of a partial exchange. */ if peer.timersActive() && !peer.timers.zeroKeyMaterial.IsPending() { peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) } } else { peer.timers.handshakeAttempts.Add(1) peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1) /* We clear the endpoint address src address, in case this is the cause of trouble. */ peer.markEndpointSrcForClearing() peer.SendHandshakeInitiation(true) } } func expiredSendKeepalive(peer *Peer) { peer.SendKeepalive() if peer.timers.needAnotherKeepalive.Load() { peer.timers.needAnotherKeepalive.Store(false) if peer.timersActive() { peer.timers.sendKeepalive.Mod(KeepaliveTimeout) } } } func expiredNewHandshake(peer *Peer) { peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) /* We clear the endpoint address src address, in case this is the cause of trouble. */ peer.markEndpointSrcForClearing() peer.SendHandshakeInitiation(false) } func expiredZeroKeyMaterial(peer *Peer) { peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds())) peer.ZeroAndFlushAll() } func expiredPersistentKeepalive(peer *Peer) { if peer.persistentKeepaliveInterval.Load() > 0 { peer.SendKeepalive() } } /* Should be called after an authenticated data packet is sent. 
*/ func (peer *Peer) timersDataSent() { if peer.timersActive() && !peer.timers.newHandshake.IsPending() { peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs))) } } /* Should be called after an authenticated data packet is received. */ func (peer *Peer) timersDataReceived() { if peer.timersActive() { if !peer.timers.sendKeepalive.IsPending() { peer.timers.sendKeepalive.Mod(KeepaliveTimeout) } else { peer.timers.needAnotherKeepalive.Store(true) } } } /* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */ func (peer *Peer) timersAnyAuthenticatedPacketSent() { if peer.timersActive() { peer.timers.sendKeepalive.Del() } } /* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */ func (peer *Peer) timersAnyAuthenticatedPacketReceived() { if peer.timersActive() { peer.timers.newHandshake.Del() } } /* Should be called after a handshake initiation message is sent. */ func (peer *Peer) timersHandshakeInitiated() { if peer.timersActive() { peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs))) } } /* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */ func (peer *Peer) timersHandshakeComplete() { if peer.timersActive() { peer.timers.retransmitHandshake.Del() } peer.timers.handshakeAttempts.Store(0) peer.timers.sentLastMinuteHandshake.Store(false) peer.lastHandshakeNano.Store(time.Now().UnixNano()) } /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. 
*/ func (peer *Peer) timersSessionDerived() { if peer.timersActive() { peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) } } /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { keepalive := peer.persistentKeepaliveInterval.Load() if keepalive > 0 && peer.timersActive() { peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second) } } func (peer *Peer) timersInit() { peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake) peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive) peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake) peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial) peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive) } func (peer *Peer) timersStart() { peer.timers.handshakeAttempts.Store(0) peer.timers.sentLastMinuteHandshake.Store(false) peer.timers.needAnotherKeepalive.Store(false) } func (peer *Peer) timersStop() { peer.timers.retransmitHandshake.DelSync() peer.timers.sendKeepalive.DelSync() peer.timers.newHandshake.DelSync() peer.timers.zeroKeyMaterial.DelSync() peer.timers.persistentKeepalive.DelSync() } wireguard-go-0.0.20231211/device/tun.go000066400000000000000000000022521466306003400172360ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package device import ( "fmt" "golang.zx2c4.com/wireguard/tun" ) const DefaultMTU = 1420 func (device *Device) RoutineTUNEventReader() { device.log.Verbosef("Routine: event worker - started") for event := range device.tun.device.Events() { if event&tun.EventMTUUpdate != 0 { mtu, err := device.tun.device.MTU() if err != nil { device.log.Errorf("Failed to load updated MTU of device: %v", err) continue } if mtu < 0 { device.log.Errorf("MTU not updated to negative value: %v", mtu) continue } var tooLarge string if mtu > MaxContentSize { tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize) mtu = MaxContentSize } old := device.tun.mtu.Swap(int32(mtu)) if int(old) != mtu { device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge) } } if event&tun.EventUp != 0 { device.log.Verbosef("Interface up requested") device.Up() } if event&tun.EventDown != 0 { device.log.Verbosef("Interface down requested") device.Down() } } device.log.Verbosef("Routine: event worker - stopped") } wireguard-go-0.0.20231211/device/uapi.go000066400000000000000000000271061466306003400173730ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "bufio" "bytes" "errors" "fmt" "io" "net" "net/netip" "strconv" "strings" "sync" "time" "golang.zx2c4.com/wireguard/ipc" ) type IPCError struct { code int64 // error code err error // underlying/wrapped error } func (s IPCError) Error() string { return fmt.Sprintf("IPC error %d: %v", s.code, s.err) } func (s IPCError) Unwrap() error { return s.err } func (s IPCError) ErrorCode() int64 { return s.code } func ipcErrorf(code int64, msg string, args ...any) *IPCError { return &IPCError{code: code, err: fmt.Errorf(msg, args...)} } var byteBufferPool = &sync.Pool{ New: func() any { return new(bytes.Buffer) }, } // IpcGetOperation implements the WireGuard configuration protocol "get" operation. 
// See https://www.wireguard.com/xplatform/#configuration-protocol for details. func (device *Device) IpcGetOperation(w io.Writer) error { device.ipcMutex.RLock() defer device.ipcMutex.RUnlock() buf := byteBufferPool.Get().(*bytes.Buffer) buf.Reset() defer byteBufferPool.Put(buf) sendf := func(format string, args ...any) { fmt.Fprintf(buf, format, args...) buf.WriteByte('\n') } keyf := func(prefix string, key *[32]byte) { buf.Grow(len(key)*2 + 2 + len(prefix)) buf.WriteString(prefix) buf.WriteByte('=') const hex = "0123456789abcdef" for i := 0; i < len(key); i++ { buf.WriteByte(hex[key[i]>>4]) buf.WriteByte(hex[key[i]&0xf]) } buf.WriteByte('\n') } func() { // lock required resources device.net.RLock() defer device.net.RUnlock() device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() device.peers.RLock() defer device.peers.RUnlock() // serialize device related values if !device.staticIdentity.privateKey.IsZero() { keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey)) } if device.net.port != 0 { sendf("listen_port=%d", device.net.port) } if device.net.fwmark != 0 { sendf("fwmark=%d", device.net.fwmark) } for _, peer := range device.peers.keyMap { // Serialize peer state. 
peer.handshake.mutex.RLock() keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) peer.handshake.mutex.RUnlock() sendf("protocol_version=1") peer.endpoint.Lock() if peer.endpoint.val != nil { sendf("endpoint=%s", peer.endpoint.val.DstToString()) } peer.endpoint.Unlock() nano := peer.lastHandshakeNano.Load() secs := nano / time.Second.Nanoseconds() nano %= time.Second.Nanoseconds() sendf("last_handshake_time_sec=%d", secs) sendf("last_handshake_time_nsec=%d", nano) sendf("tx_bytes=%d", peer.txBytes.Load()) sendf("rx_bytes=%d", peer.rxBytes.Load()) sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { sendf("allowed_ip=%s", prefix.String()) return true }) } }() // send lines (does not require resource locks) if _, err := w.Write(buf.Bytes()); err != nil { return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err) } return nil } // IpcSetOperation implements the WireGuard configuration protocol "set" operation. // See https://www.wireguard.com/xplatform/#configuration-protocol for details. func (device *Device) IpcSetOperation(r io.Reader) (err error) { device.ipcMutex.Lock() defer device.ipcMutex.Unlock() defer func() { if err != nil { device.log.Errorf("%v", err) } }() peer := new(ipcSetPeer) deviceConfig := true scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() if line == "" { // Blank line means terminate operation. peer.handlePostConfig() return nil } key, value, ok := strings.Cut(line, "=") if !ok { return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line) } if key == "public_key" { if deviceConfig { deviceConfig = false } peer.handlePostConfig() // Load/create the peer we are now configuring. 
err := device.handlePublicKeyLine(peer, value) if err != nil { return err } continue } var err error if deviceConfig { err = device.handleDeviceLine(key, value) } else { err = device.handlePeerLine(peer, key, value) } if err != nil { return err } } peer.handlePostConfig() if err := scanner.Err(); err != nil { return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err) } return nil } func (device *Device) handleDeviceLine(key, value string) error { switch key { case "private_key": var sk NoisePrivateKey err := sk.FromMaybeZeroHex(value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err) } device.log.Verbosef("UAPI: Updating private key") device.SetPrivateKey(sk) case "listen_port": port, err := strconv.ParseUint(value, 10, 16) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err) } // update port and rebind device.log.Verbosef("UAPI: Updating listen port") device.net.Lock() device.net.port = uint16(port) device.net.Unlock() if err := device.BindUpdate(); err != nil { return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err) } case "fwmark": mark, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err) } device.log.Verbosef("UAPI: Updating fwmark") if err := device.BindSetMark(uint32(mark)); err != nil { return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err) } case "replace_peers": if value != "true" { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) } device.log.Verbosef("UAPI: Removing all peers") device.RemoveAllPeers() default: return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) } return nil } // An ipcSetPeer is the current state of an IPC set operation on a peer. 
type ipcSetPeer struct { *Peer // Peer is the current peer being operated on dummy bool // dummy reports whether this peer is a temporary, placeholder peer created bool // new reports whether this is a newly created peer pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on } func (peer *ipcSetPeer) handlePostConfig() { if peer.Peer == nil || peer.dummy { return } if peer.created { peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil } if peer.device.isUp() { peer.Start() if peer.pkaOn { peer.SendKeepalive() } peer.SendStagedPackets() } } func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error { // Load/create the peer we are configuring. var publicKey NoisePublicKey err := publicKey.FromHex(value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err) } // Ignore peer with the same public key as this device. device.staticIdentity.RLock() peer.dummy = device.staticIdentity.publicKey.Equals(publicKey) device.staticIdentity.RUnlock() if peer.dummy { peer.Peer = &Peer{} } else { peer.Peer = device.LookupPeer(publicKey) } peer.created = peer.Peer == nil if peer.created { peer.Peer, err = device.NewPeer(publicKey) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err) } device.log.Verbosef("%v - UAPI: Created", peer.Peer) } return nil } func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error { switch key { case "update_only": // allow disabling of creation if value != "true" { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) } if peer.created && !peer.dummy { device.RemovePeer(peer.handshake.remoteStatic) peer.Peer = &Peer{} peer.dummy = true } case "remove": // remove currently selected peer from device if value != "true" { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value) } if !peer.dummy { 
device.log.Verbosef("%v - UAPI: Removing", peer.Peer) device.RemovePeer(peer.handshake.remoteStatic) } peer.Peer = &Peer{} peer.dummy = true case "preshared_key": device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer) peer.handshake.mutex.Lock() err := peer.handshake.presharedKey.FromHex(value) peer.handshake.mutex.Unlock() if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err) } case "endpoint": device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer) endpoint, err := device.net.bind.ParseEndpoint(value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) } peer.endpoint.Lock() defer peer.endpoint.Unlock() peer.endpoint.val = endpoint case "persistent_keepalive_interval": device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer) secs, err := strconv.ParseUint(value, 10, 16) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) } old := peer.persistentKeepaliveInterval.Swap(uint32(secs)) // Send immediate keepalive if we're turning it on and before it wasn't on. 
peer.pkaOn = old == 0 && secs != 0 case "replace_allowed_ips": device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer) if value != "true" { return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) } if peer.dummy { return nil } device.allowedips.RemoveByPeer(peer.Peer) case "allowed_ip": device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) prefix, err := netip.ParsePrefix(value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) } if peer.dummy { return nil } device.allowedips.Insert(prefix, peer.Peer) case "protocol_version": if value != "1" { return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value) } default: return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key) } return nil } func (device *Device) IpcGet() (string, error) { buf := new(strings.Builder) if err := device.IpcGetOperation(buf); err != nil { return "", err } return buf.String(), nil } func (device *Device) IpcSet(uapiConf string) error { return device.IpcSetOperation(strings.NewReader(uapiConf)) } func (device *Device) IpcHandle(socket net.Conn) { defer socket.Close() buffered := func(s io.ReadWriter) *bufio.ReadWriter { reader := bufio.NewReader(s) writer := bufio.NewWriter(s) return bufio.NewReadWriter(reader, writer) }(socket) for { op, err := buffered.ReadString('\n') if err != nil { return } // handle operation switch op { case "set=1\n": err = device.IpcSetOperation(buffered.Reader) case "get=1\n": var nextByte byte nextByte, err = buffered.ReadByte() if err != nil { return } if nextByte != '\n' { err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte) break } err = device.IpcGetOperation(buffered.Writer) default: device.log.Errorf("invalid UAPI operation: %v", op) return } // write status var status *IPCError if err != nil && !errors.As(err, &status) { // shouldn't happen status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: 
%w", err) } if status != nil { device.log.Errorf("%v", status) fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) } else { fmt.Fprintf(buffered, "errno=0\n\n") } buffered.Flush() } } wireguard-go-0.0.20231211/format_test.go000066400000000000000000000020011466306003400175100ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package main import ( "bytes" "go/format" "io/fs" "os" "path/filepath" "runtime" "sync" "testing" ) func TestFormatting(t *testing.T) { var wg sync.WaitGroup filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error { if err != nil { t.Errorf("unable to walk %s: %v", path, err) return nil } if d.IsDir() || filepath.Ext(path) != ".go" { return nil } wg.Add(1) go func(path string) { defer wg.Done() src, err := os.ReadFile(path) if err != nil { t.Errorf("unable to read %s: %v", path, err) return } if runtime.GOOS == "windows" { src = bytes.ReplaceAll(src, []byte{'\r', '\n'}, []byte{'\n'}) } formatted, err := format.Source(src) if err != nil { t.Errorf("unable to format %s: %v", path, err) return } if !bytes.Equal(src, formatted) { t.Errorf("unformatted code: %s", path) } }(path) return nil }) wg.Wait() } wireguard-go-0.0.20231211/go.mod000066400000000000000000000005661466306003400157560ustar00rootroot00000000000000module golang.zx2c4.com/wireguard go 1.20 require ( golang.org/x/crypto v0.13.0 golang.org/x/net v0.15.0 golang.org/x/sys v0.12.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 ) require ( github.com/google/btree v1.0.1 // indirect golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect ) wireguard-go-0.0.20231211/go.sum000066400000000000000000000023651466306003400160020ustar00rootroot00000000000000github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= 
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= wireguard-go-0.0.20231211/ipc/000077500000000000000000000000001466306003400154145ustar00rootroot00000000000000wireguard-go-0.0.20231211/ipc/namedpipe/000077500000000000000000000000001466306003400173565ustar00rootroot00000000000000wireguard-go-0.0.20231211/ipc/namedpipe/file.go000066400000000000000000000145511466306003400206320ustar00rootroot00000000000000// Copyright 2021 The Go Authors. All rights reserved. // Copyright 2015 Microsoft // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
//go:build windows package namedpipe import ( "io" "os" "runtime" "sync" "sync/atomic" "time" "unsafe" "golang.org/x/sys/windows" ) type timeoutChan chan struct{} var ( ioInitOnce sync.Once ioCompletionPort windows.Handle ) // ioResult contains the result of an asynchronous IO operation type ioResult struct { bytes uint32 err error } // ioOperation represents an outstanding asynchronous Win32 IO type ioOperation struct { o windows.Overlapped ch chan ioResult } func initIo() { h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) if err != nil { panic(err) } ioCompletionPort = h go ioCompletionProcessor(h) } // file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. // It takes ownership of this handle and will close it if it is garbage collected. type file struct { handle windows.Handle wg sync.WaitGroup wgLock sync.RWMutex closing atomic.Bool socket bool readDeadline deadlineHandler writeDeadline deadlineHandler } type deadlineHandler struct { setLock sync.Mutex channel timeoutChan channelLock sync.RWMutex timer *time.Timer timedout atomic.Bool } // makeFile makes a new file from an existing file handle func makeFile(h windows.Handle) (*file, error) { f := &file{handle: h} ioInitOnce.Do(initIo) _, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0) if err != nil { return nil, err } err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE) if err != nil { return nil, err } f.readDeadline.channel = make(timeoutChan) f.writeDeadline.channel = make(timeoutChan) return f, nil } // closeHandle closes the resources associated with a Win32 handle func (f *file) closeHandle() { f.wgLock.Lock() // Atomically set that we are closing, releasing the resources only once. 
if f.closing.Swap(true) == false { f.wgLock.Unlock() // cancel all IO and wait for it to complete windows.CancelIoEx(f.handle, nil) f.wg.Wait() // at this point, no new IO can start windows.Close(f.handle) f.handle = 0 } else { f.wgLock.Unlock() } } // Close closes a file. func (f *file) Close() error { f.closeHandle() return nil } // prepareIo prepares for a new IO operation. // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning. func (f *file) prepareIo() (*ioOperation, error) { f.wgLock.RLock() if f.closing.Load() { f.wgLock.RUnlock() return nil, os.ErrClosed } f.wg.Add(1) f.wgLock.RUnlock() c := &ioOperation{} c.ch = make(chan ioResult) return c, nil } // ioCompletionProcessor processes completed async IOs forever func ioCompletionProcessor(h windows.Handle) { for { var bytes uint32 var key uintptr var op *ioOperation err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE) if op == nil { panic(err) } op.ch <- ioResult{bytes, err} } } // asyncIo processes the return value from ReadFile or WriteFile, blocking until // the operation has actually completed. func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { if err != windows.ERROR_IO_PENDING { return int(bytes), err } if f.closing.Load() { windows.CancelIoEx(f.handle, &c.o) } var timeout timeoutChan if d != nil { d.channelLock.Lock() timeout = d.channel d.channelLock.Unlock() } var r ioResult select { case r = <-c.ch: err = r.err if err == windows.ERROR_OPERATION_ABORTED { if f.closing.Load() { err = os.ErrClosed } } else if err != nil && f.socket { // err is from Win32. Query the overlapped structure to get the winsock error. 
var bytes, flags uint32 err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) } case <-timeout: windows.CancelIoEx(f.handle, &c.o) r = <-c.ch err = r.err if err == windows.ERROR_OPERATION_ABORTED { err = os.ErrDeadlineExceeded } } // runtime.KeepAlive is needed, as c is passed via native // code to ioCompletionProcessor, c must remain alive // until the channel read is complete. runtime.KeepAlive(c) return int(r.bytes), err } // Read reads from a file handle. func (f *file) Read(b []byte) (int, error) { c, err := f.prepareIo() if err != nil { return 0, err } defer f.wg.Done() if f.readDeadline.timedout.Load() { return 0, os.ErrDeadlineExceeded } var bytes uint32 err = windows.ReadFile(f.handle, b, &bytes, &c.o) n, err := f.asyncIo(c, &f.readDeadline, bytes, err) runtime.KeepAlive(b) // Handle EOF conditions. if err == nil && n == 0 && len(b) != 0 { return 0, io.EOF } else if err == windows.ERROR_BROKEN_PIPE { return 0, io.EOF } else { return n, err } } // Write writes to a file handle. 
func (f *file) Write(b []byte) (int, error) { c, err := f.prepareIo() if err != nil { return 0, err } defer f.wg.Done() if f.writeDeadline.timedout.Load() { return 0, os.ErrDeadlineExceeded } var bytes uint32 err = windows.WriteFile(f.handle, b, &bytes, &c.o) n, err := f.asyncIo(c, &f.writeDeadline, bytes, err) runtime.KeepAlive(b) return n, err } func (f *file) SetReadDeadline(deadline time.Time) error { return f.readDeadline.set(deadline) } func (f *file) SetWriteDeadline(deadline time.Time) error { return f.writeDeadline.set(deadline) } func (f *file) Flush() error { return windows.FlushFileBuffers(f.handle) } func (f *file) Fd() uintptr { return uintptr(f.handle) } func (d *deadlineHandler) set(deadline time.Time) error { d.setLock.Lock() defer d.setLock.Unlock() if d.timer != nil { if !d.timer.Stop() { <-d.channel } d.timer = nil } d.timedout.Store(false) select { case <-d.channel: d.channelLock.Lock() d.channel = make(chan struct{}) d.channelLock.Unlock() default: } if deadline.IsZero() { return nil } timeoutIO := func() { d.timedout.Store(true) close(d.channel) } now := time.Now() duration := deadline.Sub(now) if deadline.After(now) { // Deadline is in the future, set a timer to wait d.timer = time.AfterFunc(duration, timeoutIO) } else { // Deadline is in the past. Cancel all pending IO now. timeoutIO() } return nil } wireguard-go-0.0.20231211/ipc/namedpipe/namedpipe.go000066400000000000000000000307131466306003400216530ustar00rootroot00000000000000// Copyright 2021 The Go Authors. All rights reserved. // Copyright 2015 Microsoft // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. //go:build windows // Package namedpipe implements a net.Conn and net.Listener around Windows named pipes. 
package namedpipe import ( "context" "io" "net" "os" "runtime" "sync/atomic" "time" "unsafe" "golang.org/x/sys/windows" ) type pipe struct { *file path string } type messageBytePipe struct { pipe writeClosed atomic.Bool readEOF bool } type pipeAddress string func (f *pipe) LocalAddr() net.Addr { return pipeAddress(f.path) } func (f *pipe) RemoteAddr() net.Addr { return pipeAddress(f.path) } func (f *pipe) SetDeadline(t time.Time) error { f.SetReadDeadline(t) f.SetWriteDeadline(t) return nil } // CloseWrite closes the write side of a message pipe in byte mode. func (f *messageBytePipe) CloseWrite() error { if !f.writeClosed.CompareAndSwap(false, true) { return io.ErrClosedPipe } err := f.file.Flush() if err != nil { f.writeClosed.Store(false) return err } _, err = f.file.Write(nil) if err != nil { f.writeClosed.Store(false) return err } return nil } // Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since // they are used to implement CloseWrite. func (f *messageBytePipe) Write(b []byte) (int, error) { if f.writeClosed.Load() { return 0, io.ErrClosedPipe } if len(b) == 0 { return 0, nil } return f.file.Write(b) } // Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message // mode pipe will return io.EOF, as will all subsequent reads. func (f *messageBytePipe) Read(b []byte) (int, error) { if f.readEOF { return 0, io.EOF } n, err := f.file.Read(b) if err == io.EOF { // If this was the result of a zero-byte read, then // it is possible that the read was due to a zero-size // message. Since we are simulating CloseWrite with a // zero-byte message, ensure that all future Read calls // also return EOF. f.readEOF = true } else if err == windows.ERROR_MORE_DATA { // ERROR_MORE_DATA indicates that the pipe's read mode is message mode // and the message still has more bytes. Treat this as a success, since // this package presents all named pipes as byte streams. 
err = nil } return n, err } func (f *pipe) Handle() windows.Handle { return f.handle } func (s pipeAddress) Network() string { return "pipe" } func (s pipeAddress) String() string { return string(s) } // tryDialPipe attempts to dial the specified pipe until cancellation or timeout. func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) { for { select { case <-ctx.Done(): return 0, ctx.Err() default: path16, err := windows.UTF16PtrFromString(*path) if err != nil { return 0, err } h, err := windows.CreateFile(path16, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) if err == nil { return h, nil } if err != windows.ERROR_PIPE_BUSY { return h, &os.PathError{Err: err, Op: "open", Path: *path} } // Wait 10 msec and try again. This is a rather simplistic // view, as we always try each 10 milliseconds. time.Sleep(10 * time.Millisecond) } } } // DialConfig exposes various options for use in Dial and DialContext. type DialConfig struct { ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID. } // DialTimeout connects to the specified named pipe by path, timing out if the // connection takes longer than the specified duration. If timeout is zero, then // we use a default timeout of 2 seconds. func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) { if timeout == 0 { timeout = time.Second * 2 } absTimeout := time.Now().Add(timeout) ctx, _ := context.WithDeadline(context.Background(), absTimeout) conn, err := config.DialContext(ctx, path) if err == context.DeadlineExceeded { return nil, os.ErrDeadlineExceeded } return conn, err } // DialContext attempts to connect to the specified named pipe by path. 
func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) { var err error var h windows.Handle h, err = tryDialPipe(ctx, &path) if err != nil { return nil, err } if config.ExpectedOwner != nil { sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION) if err != nil { windows.Close(h) return nil, err } realOwner, _, err := sd.Owner() if err != nil { windows.Close(h) return nil, err } if !realOwner.Equals(config.ExpectedOwner) { windows.Close(h) return nil, windows.ERROR_ACCESS_DENIED } } var flags uint32 err = windows.GetNamedPipeInfo(h, &flags, nil, nil, nil) if err != nil { windows.Close(h) return nil, err } f, err := makeFile(h) if err != nil { windows.Close(h) return nil, err } // If the pipe is in message mode, return a message byte pipe, which // supports CloseWrite. if flags&windows.PIPE_TYPE_MESSAGE != 0 { return &messageBytePipe{ pipe: pipe{file: f, path: path}, }, nil } return &pipe{file: f, path: path}, nil } var defaultDialer DialConfig // DialTimeout calls DialConfig.DialTimeout using an empty configuration. func DialTimeout(path string, timeout time.Duration) (net.Conn, error) { return defaultDialer.DialTimeout(path, timeout) } // DialContext calls DialConfig.DialContext using an empty configuration. 
func DialContext(ctx context.Context, path string) (net.Conn, error) { return defaultDialer.DialContext(ctx, path) } type acceptResponse struct { f *file err error } type pipeListener struct { firstHandle windows.Handle path string config ListenConfig acceptCh chan chan acceptResponse closeCh chan int doneCh chan int } func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) { path16, err := windows.UTF16PtrFromString(path) if err != nil { return 0, &os.PathError{Op: "open", Path: path, Err: err} } var oa windows.OBJECT_ATTRIBUTES oa.Length = uint32(unsafe.Sizeof(oa)) var ntPath windows.NTUnicodeString if err := windows.RtlDosPathNameToNtPathName(path16, &ntPath, nil, nil); err != nil { if ntstatus, ok := err.(windows.NTStatus); ok { err = ntstatus.Errno() } return 0, &os.PathError{Op: "open", Path: path, Err: err} } defer windows.LocalFree(windows.Handle(unsafe.Pointer(ntPath.Buffer))) oa.ObjectName = &ntPath // The security descriptor is only needed for the first pipe. if isFirstPipe { if sd != nil { oa.SecurityDescriptor = sd } else { // Construct the default named pipe security descriptor. 
var acl *windows.ACL if err := windows.RtlDefaultNpAcl(&acl); err != nil { return 0, err } defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl))) sd, err = windows.NewSecurityDescriptor() if err != nil { return 0, err } if err = sd.SetDACL(acl, true, false); err != nil { return 0, err } oa.SecurityDescriptor = sd } } typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS) if c.MessageMode { typ |= windows.FILE_PIPE_MESSAGE_TYPE } disposition := uint32(windows.FILE_OPEN) access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE) if isFirstPipe { disposition = windows.FILE_CREATE // By not asking for read or write access, the named pipe file system // will put this pipe into an initially disconnected state, blocking // client connections until the next call with isFirstPipe == false. access = windows.SYNCHRONIZE } timeout := int64(-50 * 10000) // 50ms var ( h windows.Handle iosb windows.IO_STATUS_BLOCK ) err = windows.NtCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout) if err != nil { if ntstatus, ok := err.(windows.NTStatus); ok { err = ntstatus.Errno() } return 0, &os.PathError{Op: "open", Path: path, Err: err} } runtime.KeepAlive(ntPath) return h, nil } func (l *pipeListener) makeServerPipe() (*file, error) { h, err := makeServerPipeHandle(l.path, nil, &l.config, false) if err != nil { return nil, err } f, err := makeFile(h) if err != nil { windows.Close(h) return nil, err } return f, nil } func (l *pipeListener) makeConnectedServerPipe() (*file, error) { p, err := l.makeServerPipe() if err != nil { return nil, err } // Wait for the client to connect. ch := make(chan error) go func(p *file) { ch <- connectPipe(p) }(p) select { case err = <-ch: if err != nil { p.Close() p = nil } case <-l.closeCh: // Abort the connect request by closing the handle. 
p.Close() p = nil err = <-ch if err == nil || err == os.ErrClosed { err = net.ErrClosed } } return p, err } func (l *pipeListener) listenerRoutine() { closed := false for !closed { select { case <-l.closeCh: closed = true case responseCh := <-l.acceptCh: var ( p *file err error ) for { p, err = l.makeConnectedServerPipe() // If the connection was immediately closed by the client, try // again. if err != windows.ERROR_NO_DATA { break } } responseCh <- acceptResponse{p, err} closed = err == net.ErrClosed } } windows.Close(l.firstHandle) l.firstHandle = 0 // Notify Close and Accept callers that the handle has been closed. close(l.doneCh) } // ListenConfig contains configuration for the pipe listener. type ListenConfig struct { // SecurityDescriptor contains a Windows security descriptor. If nil, the default from RtlDefaultNpAcl is used. SecurityDescriptor *windows.SECURITY_DESCRIPTOR // MessageMode determines whether the pipe is in byte or message mode. In either // case the pipe is read in byte mode by default. The only practical difference in // this implementation is that CloseWrite is only supported for message mode pipes; // CloseWrite is implemented as a zero-byte write, but zero-byte writes are only // transferred to the reader (and returned as io.EOF in this implementation) // when the pipe is in message mode. MessageMode bool // InputBufferSize specifies the initial size of the input buffer, in bytes, which the OS will grow as needed. InputBufferSize int32 // OutputBufferSize specifies the initial size of the output buffer, in bytes, which the OS will grow as needed. OutputBufferSize int32 } // Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe. // The pipe must not already exist. 
func (c *ListenConfig) Listen(path string) (net.Listener, error) { h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true) if err != nil { return nil, err } l := &pipeListener{ firstHandle: h, path: path, config: *c, acceptCh: make(chan chan acceptResponse), closeCh: make(chan int), doneCh: make(chan int), } // The first connection is swallowed on Windows 7 & 8, so synthesize it. if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) { path16, err := windows.UTF16PtrFromString(path) if err == nil { h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) if err == nil { windows.CloseHandle(h) } } } go l.listenerRoutine() return l, nil } var defaultListener ListenConfig // Listen calls ListenConfig.Listen using an empty configuration. func Listen(path string) (net.Listener, error) { return defaultListener.Listen(path) } func connectPipe(p *file) error { c, err := p.prepareIo() if err != nil { return err } defer p.wg.Done() err = windows.ConnectNamedPipe(p.handle, &c.o) _, err = p.asyncIo(c, nil, 0, err) if err != nil && err != windows.ERROR_PIPE_CONNECTED { return err } return nil } func (l *pipeListener) Accept() (net.Conn, error) { ch := make(chan acceptResponse) select { case l.acceptCh <- ch: response := <-ch err := response.err if err != nil { return nil, err } if l.config.MessageMode { return &messageBytePipe{ pipe: pipe{file: response.f, path: l.path}, }, nil } return &pipe{file: response.f, path: l.path}, nil case <-l.doneCh: return nil, net.ErrClosed } } func (l *pipeListener) Close() error { select { case l.closeCh <- 1: <-l.doneCh case <-l.doneCh: } return nil } func (l *pipeListener) Addr() net.Addr { return pipeAddress(l.path) } wireguard-go-0.0.20231211/ipc/namedpipe/namedpipe_test.go000066400000000000000000000324101466306003400227060ustar00rootroot00000000000000// Copyright 2021 The Go Authors. All rights reserved. 
// Copyright 2015 Microsoft // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. //go:build windows package namedpipe_test import ( "bufio" "bytes" "context" "errors" "io" "net" "os" "sync" "syscall" "testing" "time" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/ipc/namedpipe" ) func randomPipePath() string { guid, err := windows.GenerateGUID() if err != nil { panic(err) } return `\\.\PIPE\go-namedpipe-test-` + guid.String() } func TestPingPong(t *testing.T) { const ( ping = 42 pong = 24 ) pipePath := randomPipePath() listener, err := namedpipe.Listen(pipePath) if err != nil { t.Fatalf("unable to listen on pipe: %v", err) } defer listener.Close() go func() { incoming, err := listener.Accept() if err != nil { t.Fatalf("unable to accept pipe connection: %v", err) } defer incoming.Close() var data [1]byte _, err = incoming.Read(data[:]) if err != nil { t.Fatalf("unable to read ping from pipe: %v", err) } if data[0] != ping { t.Fatalf("expected ping, got %d", data[0]) } data[0] = pong _, err = incoming.Write(data[:]) if err != nil { t.Fatalf("unable to write pong to pipe: %v", err) } }() client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatalf("unable to dial pipe: %v", err) } defer client.Close() client.SetDeadline(time.Now().Add(time.Second * 5)) var data [1]byte data[0] = ping _, err = client.Write(data[:]) if err != nil { t.Fatalf("unable to write ping to pipe: %v", err) } _, err = client.Read(data[:]) if err != nil { t.Fatalf("unable to read pong from pipe: %v", err) } if data[0] != pong { t.Fatalf("expected pong, got %d", data[0]) } } func TestDialUnknownFailsImmediately(t *testing.T) { _, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0)) if !errors.Is(err, syscall.ENOENT) { t.Fatalf("expected ENOENT got %v", err) } } func TestDialListenerTimesOut(t *testing.T) { pipePath := randomPipePath() l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) 
} defer l.Close() pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond) if err == nil { pipe.Close() } if err != os.ErrDeadlineExceeded { t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) } } func TestDialContextListenerTimesOut(t *testing.T) { pipePath := randomPipePath() l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() d := 10 * time.Millisecond ctx, _ := context.WithTimeout(context.Background(), d) pipe, err := namedpipe.DialContext(ctx, pipePath) if err == nil { pipe.Close() } if err != context.DeadlineExceeded { t.Fatalf("expected context.DeadlineExceeded, got %v", err) } } func TestDialListenerGetsCancelled(t *testing.T) { pipePath := randomPipePath() ctx, cancel := context.WithCancel(context.Background()) l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() ch := make(chan error) go func(ctx context.Context, ch chan error) { _, err := namedpipe.DialContext(ctx, pipePath) ch <- err }(ctx, ch) time.Sleep(time.Millisecond * 30) cancel() err = <-ch if err != context.Canceled { t.Fatalf("expected context.Canceled, got %v", err) } } func TestDialAccessDeniedWithRestrictedSD(t *testing.T) { if windows.NewLazySystemDLL("ntdll.dll").NewProc("wine_get_version").Find() == nil { t.Skip("dacls on named pipes are broken on wine") } pipePath := randomPipePath() sd, _ := windows.SecurityDescriptorFromString("D:") l, err := (&namedpipe.ListenConfig{ SecurityDescriptor: sd, }).Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err == nil { pipe.Close() } if !errors.Is(err, windows.ERROR_ACCESS_DENIED) { t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err) } } func getConnection(cfg *namedpipe.ListenConfig) (client, server net.Conn, err error) { pipePath := randomPipePath() if cfg == nil { cfg = &namedpipe.ListenConfig{} } l, err := cfg.Listen(pipePath) if err != nil { return } defer l.Close() type response struct 
{ c net.Conn err error } ch := make(chan response) go func() { c, err := l.Accept() ch <- response{c, err} }() c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { return } r := <-ch if err = r.err; err != nil { c.Close() return } client = c server = r.c return } func TestReadTimeout(t *testing.T) { c, s, err := getConnection(nil) if err != nil { t.Fatal(err) } defer c.Close() defer s.Close() c.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) buf := make([]byte, 10) _, err = c.Read(buf) if err != os.ErrDeadlineExceeded { t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) } } func server(l net.Listener, ch chan int) { c, err := l.Accept() if err != nil { panic(err) } rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) s, err := rw.ReadString('\n') if err != nil { panic(err) } _, err = rw.WriteString("got " + s) if err != nil { panic(err) } err = rw.Flush() if err != nil { panic(err) } c.Close() ch <- 1 } func TestFullListenDialReadWrite(t *testing.T) { pipePath := randomPipePath() l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() ch := make(chan int) go server(l, ch) c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } defer c.Close() rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) _, err = rw.WriteString("hello world\n") if err != nil { t.Fatal(err) } err = rw.Flush() if err != nil { t.Fatal(err) } s, err := rw.ReadString('\n') if err != nil { t.Fatal(err) } ms := "got hello world\n" if s != ms { t.Errorf("expected '%s', got '%s'", ms, s) } <-ch } func TestCloseAbortsListen(t *testing.T) { pipePath := randomPipePath() l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } ch := make(chan error) go func() { _, err := l.Accept() ch <- err }() time.Sleep(30 * time.Millisecond) l.Close() err = <-ch if err != net.ErrClosed { t.Fatalf("expected net.ErrClosed, got %v", err) } } func ensureEOFOnClose(t *testing.T, r 
io.Reader, w io.Closer) { b := make([]byte, 10) w.Close() n, err := r.Read(b) if n > 0 { t.Errorf("unexpected byte count %d", n) } if err != io.EOF { t.Errorf("expected EOF: %v", err) } } func TestCloseClientEOFServer(t *testing.T) { c, s, err := getConnection(nil) if err != nil { t.Fatal(err) } defer c.Close() defer s.Close() ensureEOFOnClose(t, c, s) } func TestCloseServerEOFClient(t *testing.T) { c, s, err := getConnection(nil) if err != nil { t.Fatal(err) } defer c.Close() defer s.Close() ensureEOFOnClose(t, s, c) } func TestCloseWriteEOF(t *testing.T) { cfg := &namedpipe.ListenConfig{ MessageMode: true, } c, s, err := getConnection(cfg) if err != nil { t.Fatal(err) } defer c.Close() defer s.Close() type closeWriter interface { CloseWrite() error } err = c.(closeWriter).CloseWrite() if err != nil { t.Fatal(err) } b := make([]byte, 10) _, err = s.Read(b) if err != io.EOF { t.Fatal(err) } } func TestAcceptAfterCloseFails(t *testing.T) { pipePath := randomPipePath() l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } l.Close() _, err = l.Accept() if err != net.ErrClosed { t.Fatalf("expected net.ErrClosed, got %v", err) } } func TestDialTimesOutByDefault(t *testing.T) { pipePath := randomPipePath() l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds. 
if err == nil { pipe.Close() } if err != os.ErrDeadlineExceeded { t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) } } func TestTimeoutPendingRead(t *testing.T) { pipePath := randomPipePath() l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() serverDone := make(chan struct{}) go func() { s, err := l.Accept() if err != nil { t.Fatal(err) } time.Sleep(1 * time.Second) s.Close() close(serverDone) }() client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } defer client.Close() clientErr := make(chan error) go func() { buf := make([]byte, 10) _, err = client.Read(buf) clientErr <- err }() time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline client.SetReadDeadline(time.Unix(1, 0)) select { case err = <-clientErr: if err != os.ErrDeadlineExceeded { t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timed out while waiting for read to cancel") <-clientErr } <-serverDone } func TestTimeoutPendingWrite(t *testing.T) { pipePath := randomPipePath() l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() serverDone := make(chan struct{}) go func() { s, err := l.Accept() if err != nil { t.Fatal(err) } time.Sleep(1 * time.Second) s.Close() close(serverDone) }() client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } defer client.Close() clientErr := make(chan error) go func() { _, err = client.Write([]byte("this should timeout")) clientErr <- err }() time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline client.SetWriteDeadline(time.Unix(1, 0)) select { case err = <-clientErr: if err != os.ErrDeadlineExceeded { t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timed out while waiting for write to cancel") <-clientErr } 
<-serverDone } type CloseWriter interface { CloseWrite() error } func TestEchoWithMessaging(t *testing.T) { pipePath := randomPipePath() l, err := (&namedpipe.ListenConfig{ MessageMode: true, // Use message mode so that CloseWrite() is supported InputBufferSize: 65536, // Use 64KB buffers to improve performance OutputBufferSize: 65536, }).Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() listenerDone := make(chan bool) clientDone := make(chan bool) go func() { // server echo conn, err := l.Accept() if err != nil { t.Fatal(err) } defer conn.Close() time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent _, err = io.Copy(conn, conn) if err != nil { t.Fatal(err) } conn.(CloseWriter).CloseWrite() close(listenerDone) }() client, err := namedpipe.DialTimeout(pipePath, time.Second) if err != nil { t.Fatal(err) } defer client.Close() go func() { // client read back bytes := make([]byte, 2) n, e := client.Read(bytes) if e != nil { t.Fatal(e) } if n != 2 || bytes[0] != 0 || bytes[1] != 1 { t.Fatalf("expected 2 bytes, got %v", n) } close(clientDone) }() payload := make([]byte, 2) payload[0] = 0 payload[1] = 1 n, err := client.Write(payload) if err != nil { t.Fatal(err) } if n != 2 { t.Fatalf("expected 2 bytes, got %v", n) } client.(CloseWriter).CloseWrite() <-listenerDone <-clientDone } func TestConnectRace(t *testing.T) { pipePath := randomPipePath() l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() go func() { for { s, err := l.Accept() if err == net.ErrClosed { return } if err != nil { t.Fatal(err) } s.Close() } }() for i := 0; i < 1000; i++ { c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } c.Close() } } func TestMessageReadMode(t *testing.T) { if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { t.Skipf("Skipping on Windows %d", maj) } var wg sync.WaitGroup defer wg.Wait() pipePath := randomPipePath() l, err := 
(&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() msg := ([]byte)("hello world") wg.Add(1) go func() { defer wg.Done() s, err := l.Accept() if err != nil { t.Fatal(err) } _, err = s.Write(msg) if err != nil { t.Fatal(err) } s.Close() }() c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } defer c.Close() mode := uint32(windows.PIPE_READMODE_MESSAGE) err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil) if err != nil { t.Fatal(err) } ch := make([]byte, 1) var vmsg []byte for { n, err := c.Read(ch) if err == io.EOF { break } if err != nil { t.Fatal(err) } if n != 1 { t.Fatalf("expected 1, got %d", n) } vmsg = append(vmsg, ch[0]) } if !bytes.Equal(msg, vmsg) { t.Fatalf("expected %s, got %s", msg, vmsg) } } func TestListenConnectRace(t *testing.T) { if testing.Short() { t.Skip("Skipping long race test") } pipePath := randomPipePath() for i := 0; i < 50 && !t.Failed(); i++ { var wg sync.WaitGroup wg.Add(1) go func() { c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err == nil { c.Close() } wg.Done() }() s, err := namedpipe.Listen(pipePath) if err != nil { t.Error(i, err) } else { s.Close() } wg.Wait() } } wireguard-go-0.0.20231211/ipc/uapi_bsd.go000066400000000000000000000050561466306003400175370ustar00rootroot00000000000000//go:build darwin || freebsd || openbsd /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package ipc import ( "errors" "net" "os" "unsafe" "golang.org/x/sys/unix" ) type UAPIListener struct { listener net.Listener // unix socket listener connNew chan net.Conn connErr chan error kqueueFd int keventFd int } func (l *UAPIListener) Accept() (net.Conn, error) { for { select { case conn := <-l.connNew: return conn, nil case err := <-l.connErr: return nil, err } } } func (l *UAPIListener) Close() error { err1 := unix.Close(l.kqueueFd) err2 := unix.Close(l.keventFd) err3 := l.listener.Close() if err1 != nil { return err1 } if err2 != nil { return err2 } return err3 } func (l *UAPIListener) Addr() net.Addr { return l.listener.Addr() } func UAPIListen(name string, file *os.File) (net.Listener, error) { // wrap file in listener listener, err := net.FileListener(file) if err != nil { return nil, err } uapi := &UAPIListener{ listener: listener, connNew: make(chan net.Conn, 1), connErr: make(chan error, 1), } if unixListener, ok := listener.(*net.UnixListener); ok { unixListener.SetUnlinkOnClose(true) } socketPath := sockPath(name) // watch for deletion of socket uapi.kqueueFd, err = unix.Kqueue() if err != nil { return nil, err } uapi.keventFd, err = unix.Open(socketDirectory, unix.O_RDONLY, 0) if err != nil { unix.Close(uapi.kqueueFd) return nil, err } go func(l *UAPIListener) { event := unix.Kevent_t{ Filter: unix.EVFILT_VNODE, Flags: unix.EV_ADD | unix.EV_ENABLE | unix.EV_ONESHOT, Fflags: unix.NOTE_WRITE, } // Allow this assignment to work with both the 32-bit and 64-bit version // of the above struct. If you know another way, please submit a patch. 
*(*uintptr)(unsafe.Pointer(&event.Ident)) = uintptr(uapi.keventFd) events := make([]unix.Kevent_t, 1) n := 1 var kerr error for { // start with lstat to avoid race condition if _, err := os.Lstat(socketPath); os.IsNotExist(err) { l.connErr <- err return } if (kerr != nil || n != 1) && kerr != unix.EINTR { if kerr != nil { l.connErr <- kerr } else { l.connErr <- errors.New("kqueue returned empty") } return } n, kerr = unix.Kevent(uapi.kqueueFd, []unix.Kevent_t{event}, events, nil) } }(uapi) // watch for new connections go func(l *UAPIListener) { for { conn, err := l.listener.Accept() if err != nil { l.connErr <- err break } l.connNew <- conn } }(uapi) return uapi, nil } wireguard-go-0.0.20231211/ipc/uapi_linux.go000066400000000000000000000043511466306003400201230ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package ipc import ( "net" "os" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/rwcancel" ) type UAPIListener struct { listener net.Listener // unix socket listener connNew chan net.Conn connErr chan error inotifyFd int inotifyRWCancel *rwcancel.RWCancel } func (l *UAPIListener) Accept() (net.Conn, error) { for { select { case conn := <-l.connNew: return conn, nil case err := <-l.connErr: return nil, err } } } func (l *UAPIListener) Close() error { err1 := unix.Close(l.inotifyFd) err2 := l.inotifyRWCancel.Cancel() err3 := l.listener.Close() if err1 != nil { return err1 } if err2 != nil { return err2 } return err3 } func (l *UAPIListener) Addr() net.Addr { return l.listener.Addr() } func UAPIListen(name string, file *os.File) (net.Listener, error) { // wrap file in listener listener, err := net.FileListener(file) if err != nil { return nil, err } if unixListener, ok := listener.(*net.UnixListener); ok { unixListener.SetUnlinkOnClose(true) } uapi := &UAPIListener{ listener: listener, connNew: make(chan net.Conn, 1), connErr: make(chan error, 1), } // watch for deletion of socket 
socketPath := sockPath(name) uapi.inotifyFd, err = unix.InotifyInit() if err != nil { return nil, err } _, err = unix.InotifyAddWatch( uapi.inotifyFd, socketPath, unix.IN_ATTRIB| unix.IN_DELETE| unix.IN_DELETE_SELF, ) if err != nil { return nil, err } uapi.inotifyRWCancel, err = rwcancel.NewRWCancel(uapi.inotifyFd) if err != nil { unix.Close(uapi.inotifyFd) return nil, err } go func(l *UAPIListener) { var buf [0]byte for { defer uapi.inotifyRWCancel.Close() // start with lstat to avoid race condition if _, err := os.Lstat(socketPath); os.IsNotExist(err) { l.connErr <- err return } _, err := uapi.inotifyRWCancel.Read(buf[:]) if err != nil { l.connErr <- err return } } }(uapi) // watch for new connections go func(l *UAPIListener) { for { conn, err := l.listener.Accept() if err != nil { l.connErr <- err break } l.connNew <- conn } }(uapi) return uapi, nil } wireguard-go-0.0.20231211/ipc/uapi_unix.go000066400000000000000000000026241466306003400177500ustar00rootroot00000000000000//go:build linux || darwin || freebsd || openbsd /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package ipc import ( "errors" "fmt" "net" "os" "golang.org/x/sys/unix" ) const ( IpcErrorIO = -int64(unix.EIO) IpcErrorProtocol = -int64(unix.EPROTO) IpcErrorInvalid = -int64(unix.EINVAL) IpcErrorPortInUse = -int64(unix.EADDRINUSE) IpcErrorUnknown = -55 // ENOANO ) // socketDirectory is variable because it is modified by a linker // flag in wireguard-android. 
var socketDirectory = "/var/run/wireguard" func sockPath(iface string) string { return fmt.Sprintf("%s/%s.sock", socketDirectory, iface) } func UAPIOpen(name string) (*os.File, error) { if err := os.MkdirAll(socketDirectory, 0o755); err != nil { return nil, err } socketPath := sockPath(name) addr, err := net.ResolveUnixAddr("unix", socketPath) if err != nil { return nil, err } oldUmask := unix.Umask(0o077) defer unix.Umask(oldUmask) listener, err := net.ListenUnix("unix", addr) if err == nil { return listener.File() } // Test socket, if not in use cleanup and try again. if _, err := net.Dial("unix", socketPath); err == nil { return nil, errors.New("unix socket in use") } if err := os.Remove(socketPath); err != nil { return nil, err } listener, err = net.ListenUnix("unix", addr) if err != nil { return nil, err } return listener.File() } wireguard-go-0.0.20231211/ipc/uapi_wasm.go000066400000000000000000000004471466306003400177350ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package ipc // Made up sentinel error codes for {js,wasip1}/wasm. const ( IpcErrorIO = 1 IpcErrorInvalid = 2 IpcErrorPortInUse = 3 IpcErrorUnknown = 4 IpcErrorProtocol = 5 ) wireguard-go-0.0.20231211/ipc/uapi_windows.go000066400000000000000000000033001466306003400204470ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package ipc import ( "net" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/ipc/namedpipe" ) // TODO: replace these with actual standard windows error numbers from the win package const ( IpcErrorIO = -int64(5) IpcErrorProtocol = -int64(71) IpcErrorInvalid = -int64(22) IpcErrorPortInUse = -int64(98) IpcErrorUnknown = -int64(55) ) type UAPIListener struct { listener net.Listener // unix socket listener connNew chan net.Conn connErr chan error kqueueFd int keventFd int } func (l *UAPIListener) Accept() (net.Conn, error) { for { select { case conn := <-l.connNew: return conn, nil case err := <-l.connErr: return nil, err } } } func (l *UAPIListener) Close() error { return l.listener.Close() } func (l *UAPIListener) Addr() net.Addr { return l.listener.Addr() } var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR func init() { var err error UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)(A;;GA;;;BA)S:(ML;;NWNRNX;;;HI)") if err != nil { panic(err) } } func UAPIListen(name string) (net.Listener, error) { listener, err := (&namedpipe.ListenConfig{ SecurityDescriptor: UAPISecurityDescriptor, }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name) if err != nil { return nil, err } uapi := &UAPIListener{ listener: listener, connNew: make(chan net.Conn, 1), connErr: make(chan error, 1), } go func(l *UAPIListener) { for { conn, err := l.listener.Accept() if err != nil { l.connErr <- err break } l.connNew <- conn } }(uapi) return uapi, nil } wireguard-go-0.0.20231211/main.go000066400000000000000000000137531466306003400161250ustar00rootroot00000000000000//go:build !windows /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package main import ( "fmt" "os" "os/signal" "runtime" "strconv" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/tun" ) const ( ExitSetupSuccess = 0 ExitSetupFailed = 1 ) const ( ENV_WG_TUN_FD = "WG_TUN_FD" ENV_WG_UAPI_FD = "WG_UAPI_FD" ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND" ) func printUsage() { fmt.Printf("Usage: %s [-f/--foreground] INTERFACE-NAME\n", os.Args[0]) } func warning() { switch runtime.GOOS { case "linux", "freebsd", "openbsd": if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" { return } default: return } fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐") fmt.Fprintln(os.Stderr, "│ │") fmt.Fprintln(os.Stderr, "│ Running wireguard-go is not required because this │") fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. For │") fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │") fmt.Fprintln(os.Stderr, "│ please visit: │") fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │") fmt.Fprintln(os.Stderr, "│ │") fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘") } func main() { if len(os.Args) == 2 && os.Args[1] == "--version" { fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. 
Donenfeld .\n", Version, runtime.GOOS, runtime.GOARCH) return } warning() var foreground bool var interfaceName string if len(os.Args) < 2 || len(os.Args) > 3 { printUsage() return } switch os.Args[1] { case "-f", "--foreground": foreground = true if len(os.Args) != 3 { printUsage() return } interfaceName = os.Args[2] default: foreground = false if len(os.Args) != 2 { printUsage() return } interfaceName = os.Args[1] } if !foreground { foreground = os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" } // get log level (default: info) logLevel := func() int { switch os.Getenv("LOG_LEVEL") { case "verbose", "debug": return device.LogLevelVerbose case "error": return device.LogLevelError case "silent": return device.LogLevelSilent } return device.LogLevelError }() // open TUN device (or use supplied fd) tdev, err := func() (tun.Device, error) { tunFdStr := os.Getenv(ENV_WG_TUN_FD) if tunFdStr == "" { return tun.CreateTUN(interfaceName, device.DefaultMTU) } // construct tun device from supplied fd fd, err := strconv.ParseUint(tunFdStr, 10, 32) if err != nil { return nil, err } err = unix.SetNonblock(int(fd), true) if err != nil { return nil, err } file := os.NewFile(uintptr(fd), "") return tun.CreateTUNFromFile(file, device.DefaultMTU) }() if err == nil { realInterfaceName, err2 := tdev.Name() if err2 == nil { interfaceName = realInterfaceName } } logger := device.NewLogger( logLevel, fmt.Sprintf("(%s) ", interfaceName), ) logger.Verbosef("Starting wireguard-go version %s", Version) if err != nil { logger.Errorf("Failed to create TUN device: %v", err) os.Exit(ExitSetupFailed) } // open UAPI file (or use supplied fd) fileUAPI, err := func() (*os.File, error) { uapiFdStr := os.Getenv(ENV_WG_UAPI_FD) if uapiFdStr == "" { return ipc.UAPIOpen(interfaceName) } // use supplied fd fd, err := strconv.ParseUint(uapiFdStr, 10, 32) if err != nil { return nil, err } return os.NewFile(uintptr(fd), ""), nil }() if err != nil { logger.Errorf("UAPI listen error: %v", err) 
os.Exit(ExitSetupFailed) return } // daemonize the process if !foreground { env := os.Environ() env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD)) env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD)) env = append(env, fmt.Sprintf("%s=1", ENV_WG_PROCESS_FOREGROUND)) files := [3]*os.File{} if os.Getenv("LOG_LEVEL") != "" && logLevel != device.LogLevelSilent { files[0], _ = os.Open(os.DevNull) files[1] = os.Stdout files[2] = os.Stderr } else { files[0], _ = os.Open(os.DevNull) files[1], _ = os.Open(os.DevNull) files[2], _ = os.Open(os.DevNull) } attr := &os.ProcAttr{ Files: []*os.File{ files[0], // stdin files[1], // stdout files[2], // stderr tdev.File(), fileUAPI, }, Dir: ".", Env: env, } path, err := os.Executable() if err != nil { logger.Errorf("Failed to determine executable: %v", err) os.Exit(ExitSetupFailed) } process, err := os.StartProcess( path, os.Args, attr, ) if err != nil { logger.Errorf("Failed to daemonize: %v", err) os.Exit(ExitSetupFailed) } process.Release() return } device := device.NewDevice(tdev, conn.NewDefaultBind(), logger) logger.Verbosef("Device started") errs := make(chan error) term := make(chan os.Signal, 1) uapi, err := ipc.UAPIListen(interfaceName, fileUAPI) if err != nil { logger.Errorf("Failed to listen on uapi socket: %v", err) os.Exit(ExitSetupFailed) } go func() { for { conn, err := uapi.Accept() if err != nil { errs <- err return } go device.IpcHandle(conn) } }() logger.Verbosef("UAPI listener started") // wait for program to terminate signal.Notify(term, unix.SIGTERM) signal.Notify(term, os.Interrupt) select { case <-term: case <-errs: case <-device.Wait(): } // clean up uapi.Close() device.Close() logger.Verbosef("Shutting down") } wireguard-go-0.0.20231211/main_windows.go000066400000000000000000000040131466306003400176640ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package main import ( "fmt" "os" "os/signal" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/tun" ) const ( ExitSetupSuccess = 0 ExitSetupFailed = 1 ) func main() { if len(os.Args) != 2 { os.Exit(ExitSetupFailed) } interfaceName := os.Args[1] fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is , which includes this code as a module.") logger := device.NewLogger( device.LogLevelVerbose, fmt.Sprintf("(%s) ", interfaceName), ) logger.Verbosef("Starting wireguard-go version %s", Version) tun, err := tun.CreateTUN(interfaceName, 0) if err == nil { realInterfaceName, err2 := tun.Name() if err2 == nil { interfaceName = realInterfaceName } } else { logger.Errorf("Failed to create TUN device: %v", err) os.Exit(ExitSetupFailed) } device := device.NewDevice(tun, conn.NewDefaultBind(), logger) err = device.Up() if err != nil { logger.Errorf("Failed to bring up device: %v", err) os.Exit(ExitSetupFailed) } logger.Verbosef("Device started") uapi, err := ipc.UAPIListen(interfaceName) if err != nil { logger.Errorf("Failed to listen on uapi socket: %v", err) os.Exit(ExitSetupFailed) } errs := make(chan error) term := make(chan os.Signal, 1) go func() { for { conn, err := uapi.Accept() if err != nil { errs <- err return } go device.IpcHandle(conn) } }() logger.Verbosef("UAPI listener started") // wait for program to terminate signal.Notify(term, os.Interrupt) signal.Notify(term, os.Kill) signal.Notify(term, windows.SIGTERM) select { case <-term: case <-errs: case <-device.Wait(): } // clean up uapi.Close() device.Close() logger.Verbosef("Shutting down") } 
wireguard-go-0.0.20231211/ratelimiter/000077500000000000000000000000001466306003400171625ustar00rootroot00000000000000wireguard-go-0.0.20231211/ratelimiter/ratelimiter.go000066400000000000000000000050711466306003400220350ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package ratelimiter import ( "net/netip" "sync" "time" ) const ( packetsPerSecond = 20 packetsBurstable = 5 garbageCollectTime = time.Second packetCost = 1000000000 / packetsPerSecond maxTokens = packetCost * packetsBurstable ) type RatelimiterEntry struct { mu sync.Mutex lastTime time.Time tokens int64 } type Ratelimiter struct { mu sync.RWMutex timeNow func() time.Time stopReset chan struct{} // send to reset, close to stop table map[netip.Addr]*RatelimiterEntry } func (rate *Ratelimiter) Close() { rate.mu.Lock() defer rate.mu.Unlock() if rate.stopReset != nil { close(rate.stopReset) } } func (rate *Ratelimiter) Init() { rate.mu.Lock() defer rate.mu.Unlock() if rate.timeNow == nil { rate.timeNow = time.Now } // stop any ongoing garbage collection routine if rate.stopReset != nil { close(rate.stopReset) } rate.stopReset = make(chan struct{}) rate.table = make(map[netip.Addr]*RatelimiterEntry) stopReset := rate.stopReset // store in case Init is called again. // Start garbage collection routine. 
go func() { ticker := time.NewTicker(time.Second) ticker.Stop() for { select { case _, ok := <-stopReset: ticker.Stop() if !ok { return } ticker = time.NewTicker(time.Second) case <-ticker.C: if rate.cleanup() { ticker.Stop() } } } }() } func (rate *Ratelimiter) cleanup() (empty bool) { rate.mu.Lock() defer rate.mu.Unlock() for key, entry := range rate.table { entry.mu.Lock() if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { delete(rate.table, key) } entry.mu.Unlock() } return len(rate.table) == 0 } func (rate *Ratelimiter) Allow(ip netip.Addr) bool { var entry *RatelimiterEntry // lookup entry rate.mu.RLock() entry = rate.table[ip] rate.mu.RUnlock() // make new entry if not found if entry == nil { entry = new(RatelimiterEntry) entry.tokens = maxTokens - packetCost entry.lastTime = rate.timeNow() rate.mu.Lock() rate.table[ip] = entry if len(rate.table) == 1 { rate.stopReset <- struct{}{} } rate.mu.Unlock() return true } // add tokens to entry entry.mu.Lock() now := rate.timeNow() entry.tokens += now.Sub(entry.lastTime).Nanoseconds() entry.lastTime = now if entry.tokens > maxTokens { entry.tokens = maxTokens } // subtract cost of packet if entry.tokens > packetCost { entry.tokens -= packetCost entry.mu.Unlock() return true } entry.mu.Unlock() return false } wireguard-go-0.0.20231211/ratelimiter/ratelimiter_test.go000066400000000000000000000047761466306003400231070ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package ratelimiter import ( "net/netip" "testing" "time" ) type result struct { allowed bool text string wait time.Duration } func TestRatelimiter(t *testing.T) { var rate Ratelimiter var expectedResults []result nano := func(nano int64) time.Duration { return time.Nanosecond * time.Duration(nano) } add := func(res result) { expectedResults = append( expectedResults, res, ) } for i := 0; i < packetsBurstable; i++ { add(result{ allowed: true, text: "initial burst", }) } add(result{ allowed: false, text: "after burst", }) add(result{ allowed: true, wait: nano(time.Second.Nanoseconds() / packetsPerSecond), text: "filling tokens for single packet", }) add(result{ allowed: false, text: "not having refilled enough", }) add(result{ allowed: true, wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)), text: "filling tokens for two packet burst", }) add(result{ allowed: true, text: "second packet in 2 packet burst", }) add(result{ allowed: false, text: "packet following 2 packet burst", }) ips := []netip.Addr{ netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("192.168.1.1"), netip.MustParseAddr("172.167.2.3"), netip.MustParseAddr("97.231.252.215"), netip.MustParseAddr("248.97.91.167"), netip.MustParseAddr("188.208.233.47"), netip.MustParseAddr("104.2.183.179"), netip.MustParseAddr("72.129.46.120"), netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"), netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), } now := time.Now() rate.timeNow = func() time.Time { return now } defer func() { // Lock to avoid data race with cleanup goroutine from Init. 
rate.mu.Lock() defer rate.mu.Unlock() rate.timeNow = time.Now }() timeSleep := func(d time.Duration) { now = now.Add(d + 1) rate.cleanup() } rate.Init() defer rate.Close() for i, res := range expectedResults { timeSleep(res.wait) for _, ip := range ips { allowed := rate.Allow(ip) if allowed != res.allowed { t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed) } } } } wireguard-go-0.0.20231211/replay/000077500000000000000000000000001466306003400161355ustar00rootroot00000000000000wireguard-go-0.0.20231211/replay/replay.go000066400000000000000000000032701466306003400177620ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ // Package replay implements an efficient anti-replay algorithm as specified in RFC 6479. package replay type block uint64 const ( blockBitLog = 6 // 1<<6 == 64 bits blockBits = 1 << blockBitLog // must be power of 2 ringBlocks = 1 << 7 // must be power of 2 windowSize = (ringBlocks - 1) * blockBits blockMask = ringBlocks - 1 bitMask = blockBits - 1 ) // A Filter rejects replayed messages by checking if message counter value is // within a sliding window of previously received messages. // The zero value for Filter is an empty filter ready to use. // Filters are unsafe for concurrent use. type Filter struct { last uint64 ring [ringBlocks]block } // Reset resets the filter to empty state. func (f *Filter) Reset() { f.last = 0 f.ring[0] = 0 } // ValidateCounter checks if the counter should be accepted. // Overlimit counters (>= limit) are always rejected. 
func (f *Filter) ValidateCounter(counter, limit uint64) bool { if counter >= limit { return false } indexBlock := counter >> blockBitLog if counter > f.last { // move window forward current := f.last >> blockBitLog diff := indexBlock - current if diff > ringBlocks { diff = ringBlocks // cap diff to clear the whole ring } for i := current + 1; i <= current+diff; i++ { f.ring[i&blockMask] = 0 } f.last = counter } else if f.last-counter > windowSize { // behind current window return false } // check and set bit indexBlock &= blockMask indexBit := counter & bitMask old := f.ring[indexBlock] new := old | 1< 0; i-- { T(i, true) } t.Log("Bulk test 4") filter.Reset() testNumber = 0 for i := uint64(windowSize + 2); i > 1; i-- { T(i, true) } T(0, false) t.Log("Bulk test 5") filter.Reset() testNumber = 0 for i := uint64(windowSize); i > 0; i-- { T(i, true) } T(windowSize+1, true) T(0, false) t.Log("Bulk test 6") filter.Reset() testNumber = 0 for i := uint64(windowSize); i > 0; i-- { T(i, true) } T(0, true) T(windowSize+1, true) } wireguard-go-0.0.20231211/rwcancel/000077500000000000000000000000001466306003400164375ustar00rootroot00000000000000wireguard-go-0.0.20231211/rwcancel/rwcancel.go000066400000000000000000000043121466306003400205640ustar00rootroot00000000000000//go:build !windows && !wasm /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ // Package rwcancel implements cancelable read/write operations on // a file descriptor. 
package rwcancel import ( "errors" "os" "syscall" "golang.org/x/sys/unix" ) type RWCancel struct { fd int closingReader *os.File closingWriter *os.File } func NewRWCancel(fd int) (*RWCancel, error) { err := unix.SetNonblock(fd, true) if err != nil { return nil, err } rwcancel := RWCancel{fd: fd} rwcancel.closingReader, rwcancel.closingWriter, err = os.Pipe() if err != nil { return nil, err } return &rwcancel, nil } func RetryAfterError(err error) bool { return errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EINTR) } func (rw *RWCancel) ReadyRead() bool { closeFd := int32(rw.closingReader.Fd()) pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLIN}, {Fd: closeFd, Events: unix.POLLIN}} var err error for { _, err = unix.Poll(pollFds, -1) if err == nil || !RetryAfterError(err) { break } } if err != nil { return false } if pollFds[1].Revents != 0 { return false } return pollFds[0].Revents != 0 } func (rw *RWCancel) ReadyWrite() bool { closeFd := int32(rw.closingReader.Fd()) pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}} var err error for { _, err = unix.Poll(pollFds, -1) if err == nil || !RetryAfterError(err) { break } } if err != nil { return false } if pollFds[1].Revents != 0 { return false } return pollFds[0].Revents != 0 } func (rw *RWCancel) Read(p []byte) (n int, err error) { for { n, err := unix.Read(rw.fd, p) if err == nil || !RetryAfterError(err) { return n, err } if !rw.ReadyRead() { return 0, os.ErrClosed } } } func (rw *RWCancel) Write(p []byte) (n int, err error) { for { n, err := unix.Write(rw.fd, p) if err == nil || !RetryAfterError(err) { return n, err } if !rw.ReadyWrite() { return 0, os.ErrClosed } } } func (rw *RWCancel) Cancel() (err error) { _, err = rw.closingWriter.Write([]byte{0}) return } func (rw *RWCancel) Close() { rw.closingReader.Close() rw.closingWriter.Close() } 
wireguard-go-0.0.20231211/rwcancel/rwcancel_stub.go000066400000000000000000000002041466306003400216150ustar00rootroot00000000000000//go:build windows || wasm // SPDX-License-Identifier: MIT package rwcancel type RWCancel struct{} func (*RWCancel) Cancel() {} wireguard-go-0.0.20231211/tai64n/000077500000000000000000000000001466306003400157465ustar00rootroot00000000000000wireguard-go-0.0.20231211/tai64n/tai64n.go000066400000000000000000000015321466306003400174030ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tai64n import ( "bytes" "encoding/binary" "time" ) const ( TimestampSize = 12 base = uint64(0x400000000000000a) whitenerMask = uint32(0x1000000 - 1) ) type Timestamp [TimestampSize]byte func stamp(t time.Time) Timestamp { var tai64n Timestamp secs := base + uint64(t.Unix()) nano := uint32(t.Nanosecond()) &^ whitenerMask binary.BigEndian.PutUint64(tai64n[:], secs) binary.BigEndian.PutUint32(tai64n[8:], nano) return tai64n } func Now() Timestamp { return stamp(time.Now()) } func (t1 Timestamp) After(t2 Timestamp) bool { return bytes.Compare(t1[:], t2[:]) > 0 } func (t Timestamp) String() string { return time.Unix(int64(binary.BigEndian.Uint64(t[:8])-base), int64(binary.BigEndian.Uint32(t[8:12]))).String() } wireguard-go-0.0.20231211/tai64n/tai64n_test.go000066400000000000000000000022431466306003400204420ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tai64n import ( "testing" "time" ) // Test that timestamps are monotonic as required by Wireguard and that // nanosecond-level information is whitened to prevent side channel attacks. func TestMonotonic(t *testing.T) { startTime := time.Unix(0, 123456789) // a nontrivial bit pattern // Whitening should reduce timestamp granularity // to more than 10 but fewer than 20 milliseconds. 
tests := []struct { name string t1, t2 time.Time wantAfter bool }{ {"after_10_ns", startTime, startTime.Add(10 * time.Nanosecond), false}, {"after_10_us", startTime, startTime.Add(10 * time.Microsecond), false}, {"after_1_ms", startTime, startTime.Add(time.Millisecond), false}, {"after_10_ms", startTime, startTime.Add(10 * time.Millisecond), false}, {"after_20_ms", startTime, startTime.Add(20 * time.Millisecond), true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ts1, ts2 := stamp(tt.t1), stamp(tt.t2) got := ts2.After(ts1) if got != tt.wantAfter { t.Errorf("after = %v; want %v", got, tt.wantAfter) } }) } } wireguard-go-0.0.20231211/tests/000077500000000000000000000000001466306003400160035ustar00rootroot00000000000000wireguard-go-0.0.20231211/tests/netns.sh000077500000000000000000000444201466306003400174750ustar00rootroot00000000000000#!/bin/bash # Copyright (C) 2015-2017 Jason A. Donenfeld . All Rights Reserved. # This script tests the below topology: # # ┌─────────────────────┐ ┌──────────────────────────────────┐ ┌─────────────────────┐ # │ $ns1 namespace │ │ $ns0 namespace │ │ $ns2 namespace │ # │ │ │ │ │ │ # │┌────────┐ │ │ ┌────────┐ │ │ ┌────────┐│ # ││ wg1 │───────────┼───┼────────────│ lo │────────────┼───┼───────────│ wg2 ││ # │├────────┴──────────┐│ │ ┌───────┴────────┴────────┐ │ │┌──────────┴────────┤│ # ││192.168.241.1/24 ││ │ │(ns1) (ns2) │ │ ││192.168.241.2/24 ││ # ││fd00::1/24 ││ │ │127.0.0.1:1 127.0.0.1:2│ │ ││fd00::2/24 ││ # │└───────────────────┘│ │ │[::]:1 [::]:2 │ │ │└───────────────────┘│ # └─────────────────────┘ │ └─────────────────────────┘ │ └─────────────────────┘ # └──────────────────────────────────┘ # # After the topology is prepared we run a series of TCP/UDP iperf3 tests between the # wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg1 # interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further # details on how this is accomplished. 
# This code is ported to the WireGuard-Go directly from the kernel project. # # Please ensure that you have installed the newest version of the WireGuard # tools from the WireGuard project and before running these tests as: # # ./netns.sh set -e exec 3>&1 export WG_HIDE_KEYS=never netns0="wg-test-$$-0" netns1="wg-test-$$-1" netns2="wg-test-$$-2" program=$1 export LOG_LEVEL="verbose" pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; } pp() { pretty "" "$*"; "$@"; } maybe_exec() { if [[ $BASHPID -eq $$ ]]; then "$@"; else exec "$@"; fi; } n0() { pretty 0 "$*"; maybe_exec ip netns exec $netns0 "$@"; } n1() { pretty 1 "$*"; maybe_exec ip netns exec $netns1 "$@"; } n2() { pretty 2 "$*"; maybe_exec ip netns exec $netns2 "$@"; } ip0() { pretty 0 "ip $*"; ip -n $netns0 "$@"; } ip1() { pretty 1 "ip $*"; ip -n $netns1 "$@"; } ip2() { pretty 2 "ip $*"; ip -n $netns2 "$@"; } sleep() { read -t "$1" -N 0 || true; } waitiperf() { pretty "${1//*-}" "wait for iperf:5201"; while [[ $(ss -N "$1" -tlp 'sport = 5201') != *iperf3* ]]; do sleep 0.1; done; } waitncatudp() { pretty "${1//*-}" "wait for udp:1111"; while [[ $(ss -N "$1" -ulp 'sport = 1111') != *ncat* ]]; do sleep 0.1; done; } waitiface() { pretty "${1//*-}" "wait for $2 to come up"; ip netns exec "$1" bash -c "while [[ \$(< \"/sys/class/net/$2/operstate\") != up ]]; do read -t .1 -N 0 || true; done;"; } cleanup() { set +e exec 2>/dev/null printf "$orig_message_cost" > /proc/sys/net/core/message_cost ip0 link del dev wg1 ip1 link del dev wg1 ip2 link del dev wg1 local to_kill="$(ip netns pids $netns0) $(ip netns pids $netns1) $(ip netns pids $netns2)" [[ -n $to_kill ]] && kill $to_kill pp ip netns del $netns1 pp ip netns del $netns2 pp ip netns del $netns0 exit } orig_message_cost="$(< /proc/sys/net/core/message_cost)" trap cleanup EXIT printf 0 > /proc/sys/net/core/message_cost ip netns del $netns0 2>/dev/null || true ip netns del $netns1 2>/dev/null || true ip netns del $netns2 2>/dev/null || true pp ip 
netns add $netns0 pp ip netns add $netns1 pp ip netns add $netns2 ip0 link set up dev lo # ip0 link add dev wg1 type wireguard n0 $program wg1 ip0 link set wg1 netns $netns1 # ip0 link add dev wg1 type wireguard n0 $program wg2 ip0 link set wg2 netns $netns2 key1="$(pp wg genkey)" key2="$(pp wg genkey)" pub1="$(pp wg pubkey <<<"$key1")" pub2="$(pp wg pubkey <<<"$key2")" psk="$(pp wg genpsk)" [[ -n $key1 && -n $key2 && -n $psk ]] configure_peers() { ip1 addr add 192.168.241.1/24 dev wg1 ip1 addr add fd00::1/24 dev wg1 ip2 addr add 192.168.241.2/24 dev wg2 ip2 addr add fd00::2/24 dev wg2 n0 wg set wg1 \ private-key <(echo "$key1") \ listen-port 10000 \ peer "$pub2" \ preshared-key <(echo "$psk") \ allowed-ips 192.168.241.2/32,fd00::2/128 n0 wg set wg2 \ private-key <(echo "$key2") \ listen-port 20000 \ peer "$pub1" \ preshared-key <(echo "$psk") \ allowed-ips 192.168.241.1/32,fd00::1/128 n0 wg showconf wg1 n0 wg showconf wg2 ip1 link set up dev wg1 ip2 link set up dev wg2 sleep 1 } configure_peers tests() { # Ping over IPv4 n2 ping -c 10 -f -W 1 192.168.241.1 n1 ping -c 10 -f -W 1 192.168.241.2 # Ping over IPv6 n2 ping6 -c 10 -f -W 1 fd00::1 n1 ping6 -c 10 -f -W 1 fd00::2 # TCP over IPv4 n2 iperf3 -s -1 -B 192.168.241.2 & waitiperf $netns2 n1 iperf3 -Z -n 1G -c 192.168.241.2 # TCP over IPv6 n1 iperf3 -s -1 -B fd00::1 & waitiperf $netns1 n2 iperf3 -Z -n 1G -c fd00::1 # UDP over IPv4 n1 iperf3 -s -1 -B 192.168.241.1 & waitiperf $netns1 n2 iperf3 -Z -n 1G -b 0 -u -c 192.168.241.1 # UDP over IPv6 n2 iperf3 -s -1 -B fd00::2 & waitiperf $netns2 n1 iperf3 -Z -n 1G -b 0 -u -c fd00::2 } [[ $(ip1 link show dev wg1) =~ mtu\ ([0-9]+) ]] && orig_mtu="${BASH_REMATCH[1]}" big_mtu=$(( 34816 - 1500 + $orig_mtu )) # Test using IPv4 as outer transport n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000 n0 wg set wg2 peer "$pub1" endpoint 127.0.0.1:10000 # Before calling tests, we first make sure that the stats counters are working n2 ping -c 10 -f -W 1 192.168.241.1 { read _; read _; 
read _; read rx_bytes _; read _; read tx_bytes _; } < <(ip2 -stats link show dev wg2) ip2 -stats link show dev wg2 n0 wg show [[ $rx_bytes -ge 840 && $tx_bytes -ge 880 && $rx_bytes -lt 2500 && $rx_bytes -lt 2500 ]] echo "counters working" tests ip1 link set wg1 mtu $big_mtu ip2 link set wg2 mtu $big_mtu tests ip1 link set wg1 mtu $orig_mtu ip2 link set wg2 mtu $orig_mtu # Test using IPv6 as outer transport n0 wg set wg1 peer "$pub2" endpoint [::1]:20000 n0 wg set wg2 peer "$pub1" endpoint [::1]:10000 tests ip1 link set wg1 mtu $big_mtu ip2 link set wg2 mtu $big_mtu tests ip1 link set wg1 mtu $orig_mtu ip2 link set wg2 mtu $orig_mtu # Test using IPv4 that roaming works ip0 -4 addr del 127.0.0.1/8 dev lo ip0 -4 addr add 127.212.121.99/8 dev lo n0 wg set wg1 listen-port 9999 n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000 n1 ping6 -W 1 -c 1 fd00::2 [[ $(n2 wg show wg2 endpoints) == "$pub1 127.212.121.99:9999" ]] # Test using IPv6 that roaming works n1 wg set wg1 listen-port 9998 n1 wg set wg1 peer "$pub2" endpoint [::1]:20000 n1 ping -W 1 -c 1 192.168.241.2 [[ $(n2 wg show wg2 endpoints) == "$pub1 [::1]:9998" ]] # Test that crypto-RP filter works n1 wg set wg1 peer "$pub2" allowed-ips 192.168.241.0/24 exec 4< <(n1 ncat -l -u -p 1111) nmap_pid=$! waitncatudp $netns1 n2 ncat -u 192.168.241.1 1111 <<<"X" read -r -N 1 -t 1 out <&4 && [[ $out == "X" ]] kill $nmap_pid more_specific_key="$(pp wg genkey | pp wg pubkey)" n0 wg set wg1 peer "$more_specific_key" allowed-ips 192.168.241.2/32 n0 wg set wg2 listen-port 9997 exec 4< <(n1 ncat -l -u -p 1111) nmap_pid=$! waitncatudp $netns1 n2 ncat -u 192.168.241.1 1111 <<<"X" ! read -r -N 1 -t 1 out <&4 kill $nmap_pid n0 wg set wg1 peer "$more_specific_key" remove [[ $(n1 wg show wg1 endpoints) == "$pub2 [::1]:9997" ]] ip1 link del wg1 ip2 link del wg2 # Test using NAT. 
We now change the topology to this: # ┌────────────────────────────────────────┐ ┌────────────────────────────────────────────────┐ ┌────────────────────────────────────────┐ # │ $ns1 namespace │ │ $ns0 namespace │ │ $ns2 namespace │ # │ │ │ │ │ │ # │ ┌─────┐ ┌─────┐ │ │ ┌──────┐ ┌──────┐ │ │ ┌─────┐ ┌─────┐ │ # │ │ wg1 │─────────────│vethc│───────────┼────┼────│vethrc│ │vethrs│──────────────┼─────┼──│veths│────────────│ wg2 │ │ # │ ├─────┴──────────┐ ├─────┴──────────┐│ │ ├──────┴─────────┐ ├──────┴────────────┐ │ │ ├─────┴──────────┐ ├─────┴──────────┐ │ # │ │192.168.241.1/24│ │192.168.1.100/24││ │ │192.168.1.100/24│ │10.0.0.1/24 │ │ │ │10.0.0.100/24 │ │192.168.241.2/24│ │ # │ │fd00::1/24 │ │ ││ │ │ │ │SNAT:192.168.1.0/24│ │ │ │ │ │fd00::2/24 │ │ # │ └────────────────┘ └────────────────┘│ │ └────────────────┘ └───────────────────┘ │ │ └────────────────┘ └────────────────┘ │ # └────────────────────────────────────────┘ └────────────────────────────────────────────────┘ └────────────────────────────────────────┘ # ip1 link add dev wg1 type wireguard # ip2 link add dev wg1 type wireguard n1 $program wg1 n2 $program wg2 configure_peers ip0 link add vethrc type veth peer name vethc ip0 link add vethrs type veth peer name veths ip0 link set vethc netns $netns1 ip0 link set veths netns $netns2 ip0 link set vethrc up ip0 link set vethrs up ip0 addr add 192.168.1.1/24 dev vethrc ip0 addr add 10.0.0.1/24 dev vethrs ip1 addr add 192.168.1.100/24 dev vethc ip1 link set vethc up ip1 route add default via 192.168.1.1 ip2 addr add 10.0.0.100/24 dev veths ip2 link set veths up waitiface $netns0 vethrc waitiface $netns0 vethrs waitiface $netns1 vethc waitiface $netns2 veths n0 bash -c 'printf 1 > /proc/sys/net/ipv4/ip_forward' n0 bash -c 'printf 2 > /proc/sys/net/netfilter/nf_conntrack_udp_timeout' n0 bash -c 'printf 2 > /proc/sys/net/netfilter/nf_conntrack_udp_timeout_stream' n0 iptables -t nat -A POSTROUTING -s 192.168.1.0/24 -d 10.0.0.0/24 -j SNAT --to 10.0.0.1 n0 wg set wg1 
peer "$pub2" endpoint 10.0.0.100:20000 persistent-keepalive 1 n1 ping -W 1 -c 1 192.168.241.2 n2 ping -W 1 -c 1 192.168.241.1 [[ $(n2 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]] # Demonstrate n2 can still send packets to n1, since persistent-keepalive will prevent connection tracking entry from expiring (to see entries: `n0 conntrack -L`). pp sleep 3 n2 ping -W 1 -c 1 192.168.241.1 n0 iptables -t nat -F ip0 link del vethrc ip0 link del vethrs ip1 link del wg1 ip2 link del wg2 # Test that saddr routing is sticky but not too sticky, changing to this topology: # ┌────────────────────────────────────────┐ ┌────────────────────────────────────────┐ # │ $ns1 namespace │ │ $ns2 namespace │ # │ │ │ │ # │ ┌─────┐ ┌─────┐ │ │ ┌─────┐ ┌─────┐ │ # │ │ wg1 │─────────────│veth1│───────────┼────┼──│veth2│────────────│ wg2 │ │ # │ ├─────┴──────────┐ ├─────┴──────────┐│ │ ├─────┴──────────┐ ├─────┴──────────┐ │ # │ │192.168.241.1/24│ │10.0.0.1/24 ││ │ │10.0.0.2/24 │ │192.168.241.2/24│ │ # │ │fd00::1/24 │ │fd00:aa::1/96 ││ │ │fd00:aa::2/96 │ │fd00::2/24 │ │ # │ └────────────────┘ └────────────────┘│ │ └────────────────┘ └────────────────┘ │ # └────────────────────────────────────────┘ └────────────────────────────────────────┘ # ip1 link add dev wg1 type wireguard # ip2 link add dev wg1 type wireguard n1 $program wg1 n2 $program wg2 configure_peers ip1 link add veth1 type veth peer name veth2 ip1 link set veth2 netns $netns2 n1 bash -c 'printf 0 > /proc/sys/net/ipv6/conf/veth1/accept_dad' n2 bash -c 'printf 0 > /proc/sys/net/ipv6/conf/veth2/accept_dad' n1 bash -c 'printf 1 > /proc/sys/net/ipv4/conf/veth1/promote_secondaries' # First we check that we aren't overly sticky and can fall over to new IPs when old ones are removed ip1 addr add 10.0.0.1/24 dev veth1 ip1 addr add fd00:aa::1/96 dev veth1 ip2 addr add 10.0.0.2/24 dev veth2 ip2 addr add fd00:aa::2/96 dev veth2 ip1 link set veth1 up ip2 link set veth2 up waitiface $netns1 veth1 waitiface $netns2 veth2 n0 wg set wg1 peer 
"$pub2" endpoint 10.0.0.2:20000 n1 ping -W 1 -c 1 192.168.241.2 ip1 addr add 10.0.0.10/24 dev veth1 ip1 addr del 10.0.0.1/24 dev veth1 n1 ping -W 1 -c 1 192.168.241.2 n0 wg set wg1 peer "$pub2" endpoint [fd00:aa::2]:20000 n1 ping -W 1 -c 1 192.168.241.2 ip1 addr add fd00:aa::10/96 dev veth1 ip1 addr del fd00:aa::1/96 dev veth1 n1 ping -W 1 -c 1 192.168.241.2 # Now we show that we can successfully do reply to sender routing ip1 link set veth1 down ip2 link set veth2 down ip1 addr flush dev veth1 ip2 addr flush dev veth2 ip1 addr add 10.0.0.1/24 dev veth1 ip1 addr add 10.0.0.2/24 dev veth1 ip1 addr add fd00:aa::1/96 dev veth1 ip1 addr add fd00:aa::2/96 dev veth1 ip2 addr add 10.0.0.3/24 dev veth2 ip2 addr add fd00:aa::3/96 dev veth2 ip1 link set veth1 up ip2 link set veth2 up waitiface $netns1 veth1 waitiface $netns2 veth2 n0 wg set wg2 peer "$pub1" endpoint 10.0.0.1:10000 n2 ping -W 1 -c 1 192.168.241.1 [[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]] n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::1]:10000 n2 ping -W 1 -c 1 192.168.241.1 [[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::1]:10000" ]] n0 wg set wg2 peer "$pub1" endpoint 10.0.0.2:10000 n2 ping -W 1 -c 1 192.168.241.1 [[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.2:10000" ]] n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::2]:10000 n2 ping -W 1 -c 1 192.168.241.1 [[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::2]:10000" ]] ip1 link del veth1 ip1 link del wg1 ip2 link del wg2 # Test that Netlink/IPC is working properly by doing things that usually cause split responses n0 $program wg0 sleep 5 config=( "[Interface]" "PrivateKey=$(wg genkey)" "[Peer]" "PublicKey=$(wg genkey)" ) for a in {1..255}; do for b in {0..255}; do config+=( "AllowedIPs=$a.$b.0.0/16,$a::$b/128" ) done done n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") i=0 for ip in $(n0 wg show wg0 allowed-ips); do ((++i)) done ((i == 255*256*2+1)) ip0 link del wg0 n0 $program wg0 config=( "[Interface]" "PrivateKey=$(wg genkey)" ) 
for a in {1..40}; do config+=( "[Peer]" "PublicKey=$(wg genkey)" ) for b in {1..52}; do config+=( "AllowedIPs=$a.$b.0.0/16" ) done done n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") i=0 while read -r line; do j=0 for ip in $line; do ((++j)) done ((j == 53)) ((++i)) done < <(n0 wg show wg0 allowed-ips) ((i == 40)) ip0 link del wg0 n0 $program wg0 config=( ) for i in {1..29}; do config+=( "[Peer]" "PublicKey=$(wg genkey)" ) done config+=( "[Peer]" "PublicKey=$(wg genkey)" "AllowedIPs=255.2.3.4/32,abcd::255/128" ) n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") n0 wg showconf wg0 > /dev/null ip0 link del wg0 ! n0 wg show doesnotexist || false declare -A objects while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do [[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}" done < /dev/kmsg alldeleted=1 for object in "${!objects[@]}"; do if [[ ${objects["$object"]} != *createddestroyed ]]; then echo "Error: $object: merely ${objects["$object"]}" >&3 alldeleted=0 fi done [[ $alldeleted -eq 1 ]] pretty "" "Objects that were created were also destroyed." wireguard-go-0.0.20231211/tun/000077500000000000000000000000001466306003400154475ustar00rootroot00000000000000wireguard-go-0.0.20231211/tun/alignment_windows_test.go000066400000000000000000000036611466306003400225730ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun import ( "reflect" "testing" "unsafe" ) func checkAlignment(t *testing.T, name string, offset uintptr) { t.Helper() if offset%8 != 0 { t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8)) } } // TestRateJugglerAlignment checks that atomically-accessed fields are // aligned to 64-bit boundaries, as required by the atomic package. 
// // Unfortunately, violating this rule on 32-bit platforms results in a // hard segfault at runtime. func TestRateJugglerAlignment(t *testing.T) { var r rateJuggler typ := reflect.TypeOf(&r).Elem() t.Logf("Peer type size: %d, with fields:", typ.Size()) for i := 0; i < typ.NumField(); i++ { field := typ.Field(i) t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", field.Name, field.Offset, field.Type.Size(), field.Type.Align(), ) } checkAlignment(t, "rateJuggler.current", unsafe.Offsetof(r.current)) checkAlignment(t, "rateJuggler.nextByteCount", unsafe.Offsetof(r.nextByteCount)) checkAlignment(t, "rateJuggler.nextStartTime", unsafe.Offsetof(r.nextStartTime)) } // TestNativeTunAlignment checks that atomically-accessed fields are // aligned to 64-bit boundaries, as required by the atomic package. // // Unfortunately, violating this rule on 32-bit platforms results in a // hard segfault at runtime. func TestNativeTunAlignment(t *testing.T) { var tun NativeTun typ := reflect.TypeOf(&tun).Elem() t.Logf("Peer type size: %d, with fields:", typ.Size()) for i := 0; i < typ.NumField(); i++ { field := typ.Field(i) t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", field.Name, field.Offset, field.Type.Size(), field.Type.Align(), ) } checkAlignment(t, "NativeTun.rate", unsafe.Offsetof(tun.rate)) } wireguard-go-0.0.20231211/tun/checksum.go000066400000000000000000000102541466306003400176020ustar00rootroot00000000000000package tun import "encoding/binary" // TODO: Explore SIMD and/or other assembly optimizations. // TODO: Test native endian loads. See RFC 1071 section 2 part B. 
func checksumNoFold(b []byte, initial uint64) uint64 { ac := initial for len(b) >= 128 { ac += uint64(binary.BigEndian.Uint32(b[:4])) ac += uint64(binary.BigEndian.Uint32(b[4:8])) ac += uint64(binary.BigEndian.Uint32(b[8:12])) ac += uint64(binary.BigEndian.Uint32(b[12:16])) ac += uint64(binary.BigEndian.Uint32(b[16:20])) ac += uint64(binary.BigEndian.Uint32(b[20:24])) ac += uint64(binary.BigEndian.Uint32(b[24:28])) ac += uint64(binary.BigEndian.Uint32(b[28:32])) ac += uint64(binary.BigEndian.Uint32(b[32:36])) ac += uint64(binary.BigEndian.Uint32(b[36:40])) ac += uint64(binary.BigEndian.Uint32(b[40:44])) ac += uint64(binary.BigEndian.Uint32(b[44:48])) ac += uint64(binary.BigEndian.Uint32(b[48:52])) ac += uint64(binary.BigEndian.Uint32(b[52:56])) ac += uint64(binary.BigEndian.Uint32(b[56:60])) ac += uint64(binary.BigEndian.Uint32(b[60:64])) ac += uint64(binary.BigEndian.Uint32(b[64:68])) ac += uint64(binary.BigEndian.Uint32(b[68:72])) ac += uint64(binary.BigEndian.Uint32(b[72:76])) ac += uint64(binary.BigEndian.Uint32(b[76:80])) ac += uint64(binary.BigEndian.Uint32(b[80:84])) ac += uint64(binary.BigEndian.Uint32(b[84:88])) ac += uint64(binary.BigEndian.Uint32(b[88:92])) ac += uint64(binary.BigEndian.Uint32(b[92:96])) ac += uint64(binary.BigEndian.Uint32(b[96:100])) ac += uint64(binary.BigEndian.Uint32(b[100:104])) ac += uint64(binary.BigEndian.Uint32(b[104:108])) ac += uint64(binary.BigEndian.Uint32(b[108:112])) ac += uint64(binary.BigEndian.Uint32(b[112:116])) ac += uint64(binary.BigEndian.Uint32(b[116:120])) ac += uint64(binary.BigEndian.Uint32(b[120:124])) ac += uint64(binary.BigEndian.Uint32(b[124:128])) b = b[128:] } if len(b) >= 64 { ac += uint64(binary.BigEndian.Uint32(b[:4])) ac += uint64(binary.BigEndian.Uint32(b[4:8])) ac += uint64(binary.BigEndian.Uint32(b[8:12])) ac += uint64(binary.BigEndian.Uint32(b[12:16])) ac += uint64(binary.BigEndian.Uint32(b[16:20])) ac += uint64(binary.BigEndian.Uint32(b[20:24])) ac += uint64(binary.BigEndian.Uint32(b[24:28])) ac 
+= uint64(binary.BigEndian.Uint32(b[28:32])) ac += uint64(binary.BigEndian.Uint32(b[32:36])) ac += uint64(binary.BigEndian.Uint32(b[36:40])) ac += uint64(binary.BigEndian.Uint32(b[40:44])) ac += uint64(binary.BigEndian.Uint32(b[44:48])) ac += uint64(binary.BigEndian.Uint32(b[48:52])) ac += uint64(binary.BigEndian.Uint32(b[52:56])) ac += uint64(binary.BigEndian.Uint32(b[56:60])) ac += uint64(binary.BigEndian.Uint32(b[60:64])) b = b[64:] } if len(b) >= 32 { ac += uint64(binary.BigEndian.Uint32(b[:4])) ac += uint64(binary.BigEndian.Uint32(b[4:8])) ac += uint64(binary.BigEndian.Uint32(b[8:12])) ac += uint64(binary.BigEndian.Uint32(b[12:16])) ac += uint64(binary.BigEndian.Uint32(b[16:20])) ac += uint64(binary.BigEndian.Uint32(b[20:24])) ac += uint64(binary.BigEndian.Uint32(b[24:28])) ac += uint64(binary.BigEndian.Uint32(b[28:32])) b = b[32:] } if len(b) >= 16 { ac += uint64(binary.BigEndian.Uint32(b[:4])) ac += uint64(binary.BigEndian.Uint32(b[4:8])) ac += uint64(binary.BigEndian.Uint32(b[8:12])) ac += uint64(binary.BigEndian.Uint32(b[12:16])) b = b[16:] } if len(b) >= 8 { ac += uint64(binary.BigEndian.Uint32(b[:4])) ac += uint64(binary.BigEndian.Uint32(b[4:8])) b = b[8:] } if len(b) >= 4 { ac += uint64(binary.BigEndian.Uint32(b)) b = b[4:] } if len(b) >= 2 { ac += uint64(binary.BigEndian.Uint16(b)) b = b[2:] } if len(b) == 1 { ac += uint64(b[0]) << 8 } return ac } func checksum(b []byte, initial uint64) uint16 { ac := checksumNoFold(b, initial) ac = (ac >> 16) + (ac & 0xffff) ac = (ac >> 16) + (ac & 0xffff) ac = (ac >> 16) + (ac & 0xffff) ac = (ac >> 16) + (ac & 0xffff) return uint16(ac) } func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 { sum := checksumNoFold(srcAddr, 0) sum = checksumNoFold(dstAddr, sum) sum = checksumNoFold([]byte{0, protocol}, sum) tmp := make([]byte, 2) binary.BigEndian.PutUint16(tmp, totalLen) return checksumNoFold(tmp, sum) } 
wireguard-go-0.0.20231211/tun/checksum_test.go000066400000000000000000000007151466306003400206420ustar00rootroot00000000000000package tun import ( "fmt" "math/rand" "testing" ) func BenchmarkChecksum(b *testing.B) { lengths := []int{ 64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 9000, 9001, } for _, length := range lengths { b.Run(fmt.Sprintf("%d", length), func(b *testing.B) { buf := make([]byte, length) rng := rand.New(rand.NewSource(1)) rng.Read(buf) b.ResetTimer() for i := 0; i < b.N; i++ { checksum(buf, 0) } }) } } wireguard-go-0.0.20231211/tun/errors.go000066400000000000000000000004051466306003400173110ustar00rootroot00000000000000package tun import ( "errors" ) var ( // ErrTooManySegments is returned by Device.Read() when segmentation // overflows the length of supplied buffers. This error should not cause // reads to cease. ErrTooManySegments = errors.New("too many segments") ) wireguard-go-0.0.20231211/tun/netstack/000077500000000000000000000000001466306003400172635ustar00rootroot00000000000000wireguard-go-0.0.20231211/tun/netstack/examples/000077500000000000000000000000001466306003400211015ustar00rootroot00000000000000wireguard-go-0.0.20231211/tun/netstack/examples/http_client.go000066400000000000000000000022121466306003400237420ustar00rootroot00000000000000//go:build ignore /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package main import ( "io" "log" "net/http" "net/netip" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" ) func main() { tun, tnet, err := netstack.CreateNetTUN( []netip.Addr{netip.MustParseAddr("192.168.4.28")}, []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 1420) if err != nil { log.Panic(err) } dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) err = dev.IpcSet(`private_key=087ec6e14bbed210e7215cdc73468dfa23f080a1bfb8665b2fd809bd99d28379 public_key=c4c8e984c5322c8184c72265b92b250fdb63688705f504ba003c88f03393cf28 allowed_ip=0.0.0.0/0 endpoint=127.0.0.1:58120 `) err = dev.Up() if err != nil { log.Panic(err) } client := http.Client{ Transport: &http.Transport{ DialContext: tnet.DialContext, }, } resp, err := client.Get("http://192.168.4.29/") if err != nil { log.Panic(err) } body, err := io.ReadAll(resp.Body) if err != nil { log.Panic(err) } log.Println(string(body)) } wireguard-go-0.0.20231211/tun/netstack/examples/http_server.go000066400000000000000000000024301466306003400237740ustar00rootroot00000000000000//go:build ignore /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package main import ( "io" "log" "net" "net/http" "net/netip" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" ) func main() { tun, tnet, err := netstack.CreateNetTUN( []netip.Addr{netip.MustParseAddr("192.168.4.29")}, []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")}, 1420, ) if err != nil { log.Panic(err) } dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) dev.IpcSet(`private_key=003ed5d73b55806c30de3f8a7bdab38af13539220533055e635690b8b87ad641 listen_port=58120 public_key=f928d4f6c1b86c12f2562c10b07c555c5c57fd00f59e90c8d8d88767271cbf7c allowed_ip=192.168.4.28/32 persistent_keepalive_interval=25 `) dev.Up() listener, err := tnet.ListenTCP(&net.TCPAddr{Port: 80}) if err != nil { log.Panicln(err) } http.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { log.Printf("> %s - %s - %s", request.RemoteAddr, request.URL.String(), request.UserAgent()) io.WriteString(writer, "Hello from userspace TCP!") }) err = http.Serve(listener, nil) if err != nil { log.Panicln(err) } } wireguard-go-0.0.20231211/tun/netstack/examples/ping_client.go000066400000000000000000000034611466306003400237270ustar00rootroot00000000000000//go:build ignore /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package main import ( "bytes" "log" "math/rand" "net/netip" "time" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" ) func main() { tun, tnet, err := netstack.CreateNetTUN( []netip.Addr{netip.MustParseAddr("192.168.4.29")}, []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 1420) if err != nil { log.Panic(err) } dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b endpoint=163.172.161.0:12912 allowed_ip=0.0.0.0/0 `) err = dev.Up() if err != nil { log.Panic(err) } socket, err := tnet.Dial("ping4", "zx2c4.com") if err != nil { log.Panic(err) } requestPing := icmp.Echo{ Seq: rand.Intn(1 << 16), Data: []byte("gopher burrow"), } icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) socket.SetReadDeadline(time.Now().Add(time.Second * 10)) start := time.Now() _, err = socket.Write(icmpBytes) if err != nil { log.Panic(err) } n, err := socket.Read(icmpBytes[:]) if err != nil { log.Panic(err) } replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n]) if err != nil { log.Panic(err) } replyPing, ok := replyPacket.Body.(*icmp.Echo) if !ok { log.Panicf("invalid reply type: %v", replyPacket) } if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { log.Panicf("invalid ping reply: %v", replyPing) } log.Printf("Ping latency: %v", time.Since(start)) } wireguard-go-0.0.20231211/tun/netstack/tun.go000066400000000000000000000631711466306003400204300ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package netstack import ( "bytes" "context" "crypto/rand" "encoding/binary" "errors" "fmt" "io" "net" "net/netip" "os" "regexp" "strconv" "strings" "syscall" "time" "golang.zx2c4.com/wireguard/tun" "golang.org/x/net/dns/dnsmessage" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) type netTun struct { ep *channel.Endpoint stack *stack.Stack events chan tun.Event incomingPacket chan *buffer.View mtu int dnsServers []netip.Addr hasV4, hasV6 bool } type Net netTun func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, HandleLocal: true, } dev := &netTun{ ep: channel.New(1024, uint32(mtu), ""), stack: stack.New(opts), events: make(chan tun.Event, 10), incomingPacket: make(chan *buffer.View), dnsServers: dnsServers, mtu: mtu, } sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) if tcpipErr != nil { return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) } dev.ep.AddNotify(dev) tcpipErr = dev.stack.CreateNIC(1, dev.ep) if tcpipErr != nil { return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) } for _, ip := range localAddresses { var protoNumber tcpip.NetworkProtocolNumber if ip.Is4() { protoNumber = ipv4.ProtocolNumber } else if ip.Is6() { protoNumber 
= ipv6.ProtocolNumber } protoAddr := tcpip.ProtocolAddress{ Protocol: protoNumber, AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), } tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) if tcpipErr != nil { return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) } if ip.Is4() { dev.hasV4 = true } else if ip.Is6() { dev.hasV6 = true } } if dev.hasV4 { dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) } if dev.hasV6 { dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) } dev.events <- tun.EventUp return dev, (*Net)(dev), nil } func (tun *netTun) Name() (string, error) { return "go", nil } func (tun *netTun) File() *os.File { return nil } func (tun *netTun) Events() <-chan tun.Event { return tun.events } func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { view, ok := <-tun.incomingPacket if !ok { return 0, os.ErrClosed } n, err := view.Read(buf[0][offset:]) if err != nil { return 0, err } sizes[0] = n return 1, nil } func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { for _, buf := range buf { packet := buf[offset:] if len(packet) == 0 { continue } pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) switch packet[0] >> 4 { case 4: tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) case 6: tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) default: return 0, syscall.EAFNOSUPPORT } } return len(buf), nil } func (tun *netTun) WriteNotify() { pkt := tun.ep.Read() if pkt.IsNil() { return } view := pkt.ToView() pkt.DecRef() tun.incomingPacket <- view } func (tun *netTun) Close() error { tun.stack.RemoveNIC(1) if tun.events != nil { close(tun.events) } tun.ep.Close() if tun.incomingPacket != nil { close(tun.incomingPacket) } return nil } func (tun *netTun) MTU() (int, error) { return tun.mtu, nil } func (tun *netTun) BatchSize() int { return 1 } func convertToFullAddr(endpoint 
netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { var protoNumber tcpip.NetworkProtocolNumber if endpoint.Addr().Is4() { protoNumber = ipv4.ProtocolNumber } else { protoNumber = ipv6.ProtocolNumber } return tcpip.FullAddress{ NIC: 1, Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), Port: endpoint.Port(), }, protoNumber } func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { fa, pn := convertToFullAddr(addr) return gonet.DialContextTCP(ctx, net.stack, fa, pn) } func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { if addr == nil { return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) } ip, _ := netip.AddrFromSlice(addr.IP) return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port))) } func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { fa, pn := convertToFullAddr(addr) return gonet.DialTCP(net.stack, fa, pn) } func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { if addr == nil { return net.DialTCPAddrPort(netip.AddrPort{}) } ip, _ := netip.AddrFromSlice(addr.IP) return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) } func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { fa, pn := convertToFullAddr(addr) return gonet.ListenTCP(net.stack, fa, pn) } func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { if addr == nil { return net.ListenTCPAddrPort(netip.AddrPort{}) } ip, _ := netip.AddrFromSlice(addr.IP) return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) } func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { var lfa, rfa *tcpip.FullAddress var pn tcpip.NetworkProtocolNumber if laddr.IsValid() || laddr.Port() > 0 { var addr tcpip.FullAddress addr, pn = convertToFullAddr(laddr) lfa = &addr } if raddr.IsValid() || raddr.Port() > 0 { var addr tcpip.FullAddress addr, pn = 
convertToFullAddr(raddr) rfa = &addr } return gonet.DialUDP(net.stack, lfa, rfa, pn) } func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) { return net.DialUDPAddrPort(laddr, netip.AddrPort{}) } func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { var la, ra netip.AddrPort if laddr != nil { ip, _ := netip.AddrFromSlice(laddr.IP) la = netip.AddrPortFrom(ip, uint16(laddr.Port)) } if raddr != nil { ip, _ := netip.AddrFromSlice(raddr.IP) ra = netip.AddrPortFrom(ip, uint16(raddr.Port)) } return net.DialUDPAddrPort(la, ra) } func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { return net.DialUDP(laddr, nil) } type PingConn struct { laddr PingAddr raddr PingAddr wq waiter.Queue ep tcpip.Endpoint deadline *time.Timer } type PingAddr struct{ addr netip.Addr } func (ia PingAddr) String() string { return ia.addr.String() } func (ia PingAddr) Network() string { if ia.addr.Is4() { return "ping4" } else if ia.addr.Is6() { return "ping6" } return "ping" } func (ia PingAddr) Addr() netip.Addr { return ia.addr } func PingAddrFromAddr(addr netip.Addr) *PingAddr { return &PingAddr{addr} } func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) { if !laddr.IsValid() && !raddr.IsValid() { return nil, errors.New("ping dial: invalid address") } v6 := laddr.Is6() || raddr.Is6() bind := laddr.IsValid() if !bind { if v6 { laddr = netip.IPv6Unspecified() } else { laddr = netip.IPv4Unspecified() } } tn := icmp.ProtocolNumber4 pn := ipv4.ProtocolNumber if v6 { tn = icmp.ProtocolNumber6 pn = ipv6.ProtocolNumber } pc := &PingConn{ laddr: PingAddr{laddr}, deadline: time.NewTimer(time.Hour << 10), } pc.deadline.Stop() ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq) if tcpipErr != nil { return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr) } pc.ep = ep if bind { fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0)) if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil { return nil, fmt.Errorf("ping 
bind: %s", tcpipErr) } } if raddr.IsValid() { pc.raddr = PingAddr{raddr} fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0)) if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil { return nil, fmt.Errorf("ping connect: %s", tcpipErr) } } return pc, nil } func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) { return net.DialPingAddr(laddr, netip.Addr{}) } func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) { var la, ra netip.Addr if laddr != nil { la = laddr.addr } if raddr != nil { ra = raddr.addr } return net.DialPingAddr(la, ra) } func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) { var la netip.Addr if laddr != nil { la = laddr.addr } return net.ListenPingAddr(la) } func (pc *PingConn) LocalAddr() net.Addr { return pc.laddr } func (pc *PingConn) RemoteAddr() net.Addr { return pc.raddr } func (pc *PingConn) Close() error { pc.deadline.Reset(0) pc.ep.Close() return nil } func (pc *PingConn) SetWriteDeadline(t time.Time) error { return errors.New("not implemented") } func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { var na netip.Addr switch v := addr.(type) { case *PingAddr: na = v.addr case *net.IPAddr: na, _ = netip.AddrFromSlice(v.IP) default: return 0, fmt.Errorf("ping write: wrong net.Addr type") } if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) { return 0, fmt.Errorf("ping write: mismatched protocols") } buf := bytes.NewReader(p) rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0)) // won't block, no deadlines n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{ To: &rfa, }) if tcpipErr != nil { return int(n64), fmt.Errorf("ping write: %s", tcpipErr) } return int(n64), nil } func (pc *PingConn) Write(p []byte) (n int, err error) { return pc.WriteTo(p, &pc.raddr) } func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { e, notifyCh := waiter.NewChannelEntry(waiter.EventIn) pc.wq.EventRegister(&e) defer pc.wq.EventUnregister(&e) select { case 
<-pc.deadline.C: return 0, nil, os.ErrDeadlineExceeded case <-notifyCh: } w := tcpip.SliceWriter(p) res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{ NeedRemoteAddr: true, }) if tcpipErr != nil { return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) } remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice()) return res.Count, &PingAddr{remoteAddr}, nil } func (pc *PingConn) Read(p []byte) (n int, err error) { n, _, err = pc.ReadFrom(p) return } func (pc *PingConn) SetDeadline(t time.Time) error { // pc.SetWriteDeadline is unimplemented return pc.SetReadDeadline(t) } func (pc *PingConn) SetReadDeadline(t time.Time) error { pc.deadline.Reset(time.Until(t)) return nil } var ( errNoSuchHost = errors.New("no such host") errLameReferral = errors.New("lame referral") errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message") errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message") errServerMisbehaving = errors.New("server misbehaving") errInvalidDNSResponse = errors.New("invalid DNS response") errNoAnswerFromDNSServer = errors.New("no answer from DNS server") errServerTemporarilyMisbehaving = errors.New("server misbehaving") errCanceled = errors.New("operation was canceled") errTimeout = errors.New("i/o timeout") errNumericPort = errors.New("port must be numeric") errNoSuitableAddress = errors.New("no suitable address found") errMissingAddress = errors.New("missing address") ) func (net *Net) LookupHost(host string) (addrs []string, err error) { return net.LookupContextHost(context.Background(), host) } func isDomainName(s string) bool { l := len(s) if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { return false } last := byte('.') nonNumeric := false partlen := 0 for i := 0; i < len(s); i++ { c := s[i] switch { default: return false case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': nonNumeric = true partlen++ case '0' <= c && c <= '9': partlen++ case c == '-': if last == '.' 
{ return false } partlen++ nonNumeric = true case c == '.': if last == '.' || last == '-' { return false } if partlen > 63 || partlen == 0 { return false } partlen = 0 } last = c } if last == '-' || partlen > 63 { return false } return nonNumeric } func randU16() uint16 { var b [2]byte _, err := rand.Read(b[:]) if err != nil { panic(err) } return binary.LittleEndian.Uint16(b[:]) } func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) { id = randU16() b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true}) b.EnableCompression() if err := b.StartQuestions(); err != nil { return 0, nil, nil, err } if err := b.Question(q); err != nil { return 0, nil, nil, err } tcpReq, err = b.Finish() udpReq = tcpReq[2:] l := len(tcpReq) - 2 tcpReq[0] = byte(l >> 8) tcpReq[1] = byte(l) return id, udpReq, tcpReq, err } func equalASCIIName(x, y dnsmessage.Name) bool { if x.Length != y.Length { return false } for i := 0; i < int(x.Length); i++ { a := x.Data[i] b := y.Data[i] if 'A' <= a && a <= 'Z' { a += 0x20 } if 'A' <= b && b <= 'Z' { b += 0x20 } if a != b { return false } } return true } func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool { if !respHdr.Response { return false } if reqID != respHdr.ID { return false } if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) { return false } return true } func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { if _, err := c.Write(b); err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, err } b = make([]byte, 512) for { n, err := c.Read(b) if err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, err } var p dnsmessage.Parser h, err := p.Start(b[:n]) if err != nil { continue } q, err := p.Question() if err != nil || !checkResponse(id, query, 
h, q) { continue } return p, h, nil } } func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { if _, err := c.Write(b); err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, err } b = make([]byte, 1280) if _, err := io.ReadFull(c, b[:2]); err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, err } l := int(b[0])<<8 | int(b[1]) if l > len(b) { b = make([]byte, l) } n, err := io.ReadFull(c, b[:l]) if err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, err } var p dnsmessage.Parser h, err := p.Start(b[:n]) if err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage } q, err := p.Question() if err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage } if !checkResponse(id, query, h, q) { return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse } return p, h, nil } func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { q.Class = dnsmessage.ClassINET id, udpReq, tcpReq, err := newRequest(q) if err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage } for _, useUDP := range []bool{true, false} { ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) defer cancel() var c net.Conn var err error if useUDP { c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53)) } else { c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53)) } if err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, err } if d, ok := ctx.Deadline(); ok && !d.IsZero() { err := c.SetDeadline(d) if err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, err } } var p dnsmessage.Parser var h dnsmessage.Header if useUDP { p, h, err = dnsPacketRoundTrip(c, id, q, udpReq) } else { p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) } 
c.Close() if err != nil { if err == context.Canceled { err = errCanceled } else if err == context.DeadlineExceeded { err = errTimeout } return dnsmessage.Parser{}, dnsmessage.Header{}, err } if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone { return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse } if h.Truncated { continue } return p, h, nil } return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer } func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error { if h.RCode == dnsmessage.RCodeNameError { return errNoSuchHost } _, err := p.AnswerHeader() if err != nil && err != dnsmessage.ErrSectionDone { return errCannotUnmarshalDNSMessage } if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone { return errLameReferral } if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError { if h.RCode == dnsmessage.RCodeServerFailure { return errServerTemporarilyMisbehaving } return errServerMisbehaving } return nil } func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error { for { h, err := p.AnswerHeader() if err == dnsmessage.ErrSectionDone { return errNoSuchHost } if err != nil { return errCannotUnmarshalDNSMessage } if h.Type == qtype { return nil } if err := p.SkipAnswer(); err != nil { return errCannotUnmarshalDNSMessage } } } func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { var lastErr error n, err := dnsmessage.NewName(name) if err != nil { return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage } q := dnsmessage.Question{ Name: n, Type: qtype, Class: dnsmessage.ClassINET, } for i := 0; i < 2; i++ { for _, server := range tnet.dnsServers { p, h, err := tnet.exchange(ctx, server, q, time.Second*5) if err != nil { dnsErr := &net.DNSError{ Err: err.Error(), Name: name, Server: server.String(), } if nerr, ok := err.(net.Error); ok && nerr.Timeout() { 
dnsErr.IsTimeout = true } if _, ok := err.(*net.OpError); ok { dnsErr.IsTemporary = true } lastErr = dnsErr continue } if err := checkHeader(&p, h); err != nil { dnsErr := &net.DNSError{ Err: err.Error(), Name: name, Server: server.String(), } if err == errServerTemporarilyMisbehaving { dnsErr.IsTemporary = true } if err == errNoSuchHost { dnsErr.IsNotFound = true return p, server.String(), dnsErr } lastErr = dnsErr continue } err = skipToAnswer(&p, qtype) if err == nil { return p, server.String(), nil } lastErr = &net.DNSError{ Err: err.Error(), Name: name, Server: server.String(), } if err == errNoSuchHost { lastErr.(*net.DNSError).IsNotFound = true return p, server.String(), lastErr } } } return dnsmessage.Parser{}, "", lastErr } func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) { if host == "" || (!tnet.hasV6 && !tnet.hasV4) { return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} } zlen := len(host) if strings.IndexByte(host, ':') != -1 { if zidx := strings.LastIndexByte(host, '%'); zidx != -1 { zlen = zidx } } if ip, err := netip.ParseAddr(host[:zlen]); err == nil { return []string{ip.String()}, nil } if !isDomainName(host) { return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} } type result struct { p dnsmessage.Parser server string error } var addrsV4, addrsV6 []netip.Addr lanes := 0 if tnet.hasV4 { lanes++ } if tnet.hasV6 { lanes++ } lane := make(chan result, lanes) var lastErr error if tnet.hasV4 { go func() { p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA) lane <- result{p, server, err} }() } if tnet.hasV6 { go func() { p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA) lane <- result{p, server, err} }() } for l := 0; l < lanes; l++ { result := <-lane if result.error != nil { if lastErr == nil { lastErr = result.error } continue } loop: for { h, err := result.p.AnswerHeader() if err != nil && err != 
dnsmessage.ErrSectionDone { lastErr = &net.DNSError{ Err: errCannotMarshalDNSMessage.Error(), Name: host, Server: result.server, } } if err != nil { break } switch h.Type { case dnsmessage.TypeA: a, err := result.p.AResource() if err != nil { lastErr = &net.DNSError{ Err: errCannotMarshalDNSMessage.Error(), Name: host, Server: result.server, } break loop } addrsV4 = append(addrsV4, netip.AddrFrom4(a.A)) case dnsmessage.TypeAAAA: aaaa, err := result.p.AAAAResource() if err != nil { lastErr = &net.DNSError{ Err: errCannotMarshalDNSMessage.Error(), Name: host, Server: result.server, } break loop } addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA)) default: if err := result.p.SkipAnswer(); err != nil { lastErr = &net.DNSError{ Err: errCannotMarshalDNSMessage.Error(), Name: host, Server: result.server, } break loop } continue } } } // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled var addrs []netip.Addr if tnet.hasV6 { addrs = append(addrsV6, addrsV4...) } else { addrs = append(addrsV4, addrsV6...) 
} if len(addrs) == 0 && lastErr != nil { return nil, lastErr } saddrs := make([]string, 0, len(addrs)) for _, ip := range addrs { saddrs = append(saddrs, ip.String()) } return saddrs, nil } func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) { if deadline.IsZero() { return deadline, nil } timeRemaining := deadline.Sub(now) if timeRemaining <= 0 { return time.Time{}, errTimeout } timeout := timeRemaining / time.Duration(addrsRemaining) const saneMinimum = 2 * time.Second if timeout < saneMinimum { if timeRemaining < saneMinimum { timeout = timeRemaining } else { timeout = saneMinimum } } return now.Add(timeout), nil } var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`) func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) { if ctx == nil { panic("nil context") } var acceptV4, acceptV6 bool matches := protoSplitter.FindStringSubmatch(network) if matches == nil { return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)} } else if len(matches[2]) == 0 { acceptV4 = true acceptV6 = true } else { acceptV4 = matches[2][0] == '4' acceptV6 = !acceptV4 } var host string var port int if matches[1] == "ping" { host = address } else { var sport string var err error host, sport, err = net.SplitHostPort(address) if err != nil { return nil, &net.OpError{Op: "dial", Err: err} } port, err = strconv.Atoi(sport) if err != nil || port < 0 || port > 65535 { return nil, &net.OpError{Op: "dial", Err: errNumericPort} } } allAddr, err := tnet.LookupContextHost(ctx, host) if err != nil { return nil, &net.OpError{Op: "dial", Err: err} } var addrs []netip.AddrPort for _, addr := range allAddr { ip, err := netip.ParseAddr(addr) if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) { addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port))) } } if len(addrs) == 0 && len(allAddr) != 0 { return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress} } var firstErr error for i, 
addr := range addrs { select { case <-ctx.Done(): err := ctx.Err() if err == context.Canceled { err = errCanceled } else if err == context.DeadlineExceeded { err = errTimeout } return nil, &net.OpError{Op: "dial", Err: err} default: } dialCtx := ctx if deadline, hasDeadline := ctx.Deadline(); hasDeadline { partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i) if err != nil { if firstErr == nil { firstErr = &net.OpError{Op: "dial", Err: err} } break } if partialDeadline.Before(deadline) { var cancel context.CancelFunc dialCtx, cancel = context.WithDeadline(ctx, partialDeadline) defer cancel() } } var c net.Conn switch matches[1] { case "tcp": c, err = tnet.DialContextTCPAddrPort(dialCtx, addr) case "udp": c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr) case "ping": c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr()) } if err == nil { return c, nil } if firstErr == nil { firstErr = err } } if firstErr == nil { firstErr = &net.OpError{Op: "dial", Err: errMissingAddress} } return nil, firstErr } func (tnet *Net) Dial(network, address string) (net.Conn, error) { return tnet.DialContext(context.Background(), network, address) } wireguard-go-0.0.20231211/tun/offload_linux.go000066400000000000000000001013771466306003400206400ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun import ( "bytes" "encoding/binary" "errors" "io" "unsafe" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/conn" ) const tcpFlagsOffset = 13 const ( tcpFlagFIN uint8 = 0x01 tcpFlagPSH uint8 = 0x08 tcpFlagACK uint8 = 0x10 ) // virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The // kernel symbol is virtio_net_hdr. 
type virtioNetHdr struct { flags uint8 gsoType uint8 hdrLen uint16 gsoSize uint16 csumStart uint16 csumOffset uint16 } func (v *virtioNetHdr) decode(b []byte) error { if len(b) < virtioNetHdrLen { return io.ErrShortBuffer } copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen]) return nil } func (v *virtioNetHdr) encode(b []byte) error { if len(b) < virtioNetHdrLen { return io.ErrShortBuffer } copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen)) return nil } const ( // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) ) // tcpFlowKey represents the key for a TCP flow. type tcpFlowKey struct { srcAddr, dstAddr [16]byte srcPort, dstPort uint16 rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. isV6 bool } // tcpGROTable holds flow and coalescing information for the purposes of TCP GRO. 
type tcpGROTable struct { itemsByFlow map[tcpFlowKey][]tcpGROItem itemsPool [][]tcpGROItem } func newTCPGROTable() *tcpGROTable { t := &tcpGROTable{ itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize), itemsPool: make([][]tcpGROItem, conn.IdealBatchSize), } for i := range t.itemsPool { t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize) } return t } func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey { key := tcpFlowKey{} addrSize := dstAddrOffset - srcAddrOffset copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) key.isV6 = addrSize == 16 return key } // lookupOrInsert looks up a flow for the provided packet and metadata, // returning the packets found for the flow, or inserting a new one if none // is found. func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) items, ok := t.itemsByFlow[key] if ok { return items, ok } // TODO: insert() performs another map lookup. This could be rearranged to avoid. t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex) return nil, false } // insert an item in the table for the provided packet and packet metadata. 
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) item := tcpGROItem{ key: key, bufsIndex: uint16(bufsIndex), gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), iphLen: uint8(tcphOffset), tcphLen: uint8(tcphLen), sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, } items, ok := t.itemsByFlow[key] if !ok { items = t.newItems() } items = append(items, item) t.itemsByFlow[key] = items } func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { items, _ := t.itemsByFlow[item.key] items[i] = item } func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) { items, _ := t.itemsByFlow[key] items = append(items[:i], items[i+1:]...) t.itemsByFlow[key] = items } // tcpGROItem represents bookkeeping data for a TCP packet during the lifetime // of a GRO evaluation across a vector of packets. type tcpGROItem struct { key tcpFlowKey sentSeq uint32 // the sequence number bufsIndex uint16 // the index into the original bufs slice numMerged uint16 // the number of packets merged into this item gsoSize uint16 // payload size iphLen uint8 // ip header len tcphLen uint8 // tcp header len pshSet bool // psh flag is set } func (t *tcpGROTable) newItems() []tcpGROItem { var items []tcpGROItem items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1] return items } func (t *tcpGROTable) reset() { for k, items := range t.itemsByFlow { items = items[:0] t.itemsPool = append(t.itemsPool, items) delete(t.itemsByFlow, k) } } // udpFlowKey represents the key for a UDP flow. type udpFlowKey struct { srcAddr, dstAddr [16]byte srcPort, dstPort uint16 isV6 bool } // udpGROTable holds flow and coalescing information for the purposes of UDP GRO. 
type udpGROTable struct { itemsByFlow map[udpFlowKey][]udpGROItem itemsPool [][]udpGROItem } func newUDPGROTable() *udpGROTable { u := &udpGROTable{ itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize), itemsPool: make([][]udpGROItem, conn.IdealBatchSize), } for i := range u.itemsPool { u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize) } return u } func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey { key := udpFlowKey{} addrSize := dstAddrOffset - srcAddrOffset copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:]) key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:]) key.isV6 = addrSize == 16 return key } // lookupOrInsert looks up a flow for the provided packet and metadata, // returning the packets found for the flow, or inserting a new one if none // is found. func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) { key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) items, ok := u.itemsByFlow[key] if ok { return items, ok } // TODO: insert() performs another map lookup. This could be rearranged to avoid. u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false) return nil, false } // insert an item in the table for the provided packet and packet metadata. 
func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) { key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) item := udpGROItem{ key: key, bufsIndex: uint16(bufsIndex), gsoSize: uint16(len(pkt[udphOffset+udphLen:])), iphLen: uint8(udphOffset), cSumKnownInvalid: cSumKnownInvalid, } items, ok := u.itemsByFlow[key] if !ok { items = u.newItems() } items = append(items, item) u.itemsByFlow[key] = items } func (u *udpGROTable) updateAt(item udpGROItem, i int) { items, _ := u.itemsByFlow[item.key] items[i] = item } // udpGROItem represents bookkeeping data for a UDP packet during the lifetime // of a GRO evaluation across a vector of packets. type udpGROItem struct { key udpFlowKey bufsIndex uint16 // the index into the original bufs slice numMerged uint16 // the number of packets merged into this item gsoSize uint16 // payload size iphLen uint8 // ip header len cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown. } func (u *udpGROTable) newItems() []udpGROItem { var items []udpGROItem items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1] return items } func (u *udpGROTable) reset() { for k, items := range u.itemsByFlow { items = items[:0] u.itemsPool = append(u.itemsPool, items) delete(u.itemsByFlow, k) } } // canCoalesce represents the outcome of checking if two TCP packets are // candidates for coalescing. type canCoalesce int const ( coalescePrepend canCoalesce = -1 coalesceUnavailable canCoalesce = 0 coalesceAppend canCoalesce = 1 ) // ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB // meet all requirements to be merged as part of a GRO operation, otherwise it // returns false. 
func ipHeadersCanCoalesce(pktA, pktB []byte) bool {
	if len(pktA) < 9 || len(pktB) < 9 {
		// Too short to hold the header fields compared below.
		return false
	}
	if pktA[0]>>4 == 6 {
		// IPv6: compare version/traffic-class byte and the upper nibble of
		// the second byte (rest of traffic class), plus hop limit.
		if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 {
			// cannot coalesce with unequal Traffic class values
			return false
		}
		if pktA[7] != pktB[7] {
			// cannot coalesce with unequal Hop limit values
			return false
		}
	} else {
		// IPv4: compare ToS, the DF/reserved flag bits, and TTL.
		if pktA[1] != pktB[1] {
			// cannot coalesce with unequal ToS values
			return false
		}
		if pktA[6]>>5 != pktB[6]>>5 {
			// cannot coalesce with unequal DF or reserved bits. MF is checked
			// further up the stack.
			return false
		}
		if pktA[8] != pktB[8] {
			// cannot coalesce with unequal TTL values
			return false
		}
	}
	return true
}

// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. iphLen and gsoSize describe pkt. bufs is the vector of
// packets involved in the current GRO evaluation. bufsOffset is the offset at
// which packet data begins within bufs.
func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
	pktTarget := bufs[item.bufsIndex][bufsOffset:]
	if !ipHeadersCanCoalesce(pkt, pktTarget) {
		return coalesceUnavailable
	}
	if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 {
		// A smaller than gsoSize packet has been appended previously.
		// Nothing can come after a smaller packet on the end.
		return coalesceUnavailable
	}
	if gsoSize > item.gsoSize {
		// We cannot have a larger packet following a smaller one.
		return coalesceUnavailable
	}
	// UDP GRO only ever appends; prepend is a TCP-only concept here.
	return coalesceAppend
}

// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. This function makes considerations that match the kernel's
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
	pktTarget := bufs[item.bufsIndex][bufsOffset:]
	if tcphLen != item.tcphLen {
		// cannot coalesce with unequal tcp options len
		return coalesceUnavailable
	}
	if tcphLen > 20 {
		if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
			// cannot coalesce with unequal tcp options
			return coalesceUnavailable
		}
	}
	if !ipHeadersCanCoalesce(pkt, pktTarget) {
		return coalesceUnavailable
	}
	// seq adjacency
	lhsLen := item.gsoSize
	lhsLen += item.numMerged * item.gsoSize
	if seq == item.sentSeq+uint32(lhsLen) {
		// pkt aligns following item from a seq num perspective
		if item.pshSet {
			// We cannot append to a segment that has the PSH flag set, PSH
			// can only be set on the final segment in a reassembled group.
			return coalesceUnavailable
		}
		if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
			// A smaller than gsoSize packet has been appended previously.
			// Nothing can come after a smaller packet on the end.
			return coalesceUnavailable
		}
		if gsoSize > item.gsoSize {
			// We cannot have a larger packet following a smaller one.
			return coalesceUnavailable
		}
		return coalesceAppend
	} else if seq+uint32(gsoSize) == item.sentSeq {
		// pkt aligns in front of item from a seq num perspective
		if pshSet {
			// We cannot prepend with a segment that has the PSH flag set, PSH
			// can only be set on the final segment in a reassembled group.
			return coalesceUnavailable
		}
		if gsoSize < item.gsoSize {
			// We cannot have a larger packet following a smaller one.
			return coalesceUnavailable
		}
		if gsoSize > item.gsoSize && item.numMerged > 0 {
			// There's at least one previous merge, and we're larger than all
			// previous. This would put multiple smaller packets on the end.
			return coalesceUnavailable
		}
		return coalescePrepend
	}
	return coalesceUnavailable
}

// checksumValid reports whether the transport (proto) checksum of pkt is
// valid: the one's-complement sum of the pseudo header and the transport
// header + payload must fold to zero.
func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool {
	srcAddrAt := ipv4SrcAddrOffset
	addrSize := 4
	if isV6 {
		srcAddrAt = ipv6SrcAddrOffset
		addrSize = 16
	}
	lenForPseudo := uint16(len(pkt) - int(iphLen))
	cSum := pseudoHeaderChecksumNoFold(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo)
	return ^checksum(pkt[iphLen:], cSum) == 0
}

// coalesceResult represents the result of attempting to coalesce two TCP
// packets.
type coalesceResult int

const (
	coalesceInsufficientCap coalesceResult = iota
	coalescePSHEnding
	coalesceItemInvalidCSum
	coalescePktInvalidCSum
	coalesceSuccess
)

// coalesceUDPPackets attempts to coalesce pkt with the packet described by
// item, and returns the outcome.
func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
	pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front
	headersLen := item.iphLen + udphLen
	coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
	if cap(pktHead)-bufsOffset < coalescedLen {
		// We don't want to allocate a new underlying array if capacity is
		// too small.
		return coalesceInsufficientCap
	}
	if item.numMerged == 0 {
		// Validate the existing item's checksum once, before the first merge.
		if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) {
			return coalesceItemInvalidCSum
		}
	}
	if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) {
		return coalescePktInvalidCSum
	}
	// Grow the head packet and append pkt's payload after it.
	extendBy := len(pkt) - int(headersLen)
	bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
	copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])

	item.numMerged++
	return coalesceSuccess
}

// coalesceTCPPackets attempts to coalesce pkt with the packet described by
// item, and returns the outcome. This function may swap bufs elements in the
// event of a prepend as item's bufs index is already being tracked for writing
// to a Device.
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
	var pktHead []byte // the packet that will end up at the front
	headersLen := item.iphLen + item.tcphLen
	coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)

	// Copy data
	if mode == coalescePrepend {
		pktHead = pkt
		if cap(pkt)-bufsOffset < coalescedLen {
			// We don't want to allocate a new underlying array if capacity is
			// too small.
			return coalesceInsufficientCap
		}
		if pshSet {
			return coalescePSHEnding
		}
		if item.numMerged == 0 {
			if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
				return coalesceItemInvalidCSum
			}
		}
		if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
			return coalescePktInvalidCSum
		}
		// pkt becomes the new head; the old head's payload is copied after it.
		item.sentSeq = seq
		extendBy := coalescedLen - len(pktHead)
		bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
		copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])

		// Flip the slice headers in bufs as part of prepend. The index of item
		// is already being tracked for writing.
		bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
	} else {
		pktHead = bufs[item.bufsIndex][bufsOffset:]
		if cap(pktHead)-bufsOffset < coalescedLen {
			// We don't want to allocate a new underlying array if capacity is
			// too small.
			return coalesceInsufficientCap
		}
		if item.numMerged == 0 {
			if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
				return coalesceItemInvalidCSum
			}
		}
		if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
			return coalescePktInvalidCSum
		}
		if pshSet {
			// We are appending a segment with PSH set.
			item.pshSet = pshSet
			pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
		}
		extendBy := len(pkt) - int(headersLen)
		bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
		copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
	}

	if gsoSize > item.gsoSize {
		item.gsoSize = gsoSize
	}

	item.numMerged++
	return coalesceSuccess
}

const (
	// MF bit within the IPv4 flags/fragment-offset high byte (byte 6).
	ipv4FlagMoreFragments uint8 = 0x20
)

const (
	ipv4SrcAddrOffset = 12
	ipv6SrcAddrOffset = 8
	maxUint16         = 1<<16 - 1
)

// groResult describes the action taken for one packet by tcpGRO/udpGRO.
type groResult int

const (
	groResultNoop groResult = iota
	groResultTableInsert
	groResultCoalesced
)

// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a groResultNoop when no
// action was taken, groResultTableInsert when the evaluated packet was
// inserted into table, and groResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult {
	pkt := bufs[pktI][offset:]
	if len(pkt) > maxUint16 {
		// A valid IPv4 or IPv6 packet will never exceed this.
		return groResultNoop
	}
	iphLen := int((pkt[0] & 0x0F) * 4)
	if isV6 {
		iphLen = 40
		ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
		if ipv6HPayloadLen != len(pkt)-iphLen {
			return groResultNoop
		}
	} else {
		totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
		if totalLen != len(pkt) {
			return groResultNoop
		}
	}
	if len(pkt) < iphLen {
		return groResultNoop
	}
	// Data offset lives in the upper nibble of byte 12 of the TCP header.
	tcphLen := int((pkt[iphLen+12] >> 4) * 4)
	if tcphLen < 20 || tcphLen > 60 {
		return groResultNoop
	}
	if len(pkt) < iphLen+tcphLen {
		return groResultNoop
	}
	if !isV6 {
		if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
			// no GRO support for fragmented segments for now
			return groResultNoop
		}
	}
	tcpFlags := pkt[iphLen+tcpFlagsOffset]
	var pshSet bool
	// not a candidate if any non-ACK flags (except PSH+ACK) are set
	if tcpFlags != tcpFlagACK {
		if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
			return groResultNoop
		}
		pshSet = true
	}
	gsoSize := uint16(len(pkt) - tcphLen - iphLen)
	// not a candidate if payload len is 0
	if gsoSize < 1 {
		return groResultNoop
	}
	seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
	srcAddrOffset := ipv4SrcAddrOffset
	addrLen := 4
	if isV6 {
		srcAddrOffset = ipv6SrcAddrOffset
		addrLen = 16
	}
	items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
	if !existing {
		return groResultTableInsert
	}
	for i := len(items) - 1; i >= 0; i-- {
		// In the best case of packets arriving in order iterating in reverse is
		// more efficient if there are multiple items for a given flow. This
		// also enables a natural table.deleteAt() in the
		// coalesceItemInvalidCSum case without the need for index tracking.
		// This algorithm makes a best effort to coalesce in the event of
		// unordered packets, where pkt may land anywhere in items from a
		// sequence number perspective, however once an item is inserted into
		// the table it is never compared across other items later.
		item := items[i]
		can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
		if can != coalesceUnavailable {
			result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
			switch result {
			case coalesceSuccess:
				table.updateAt(item, i)
				return groResultCoalesced
			case coalesceItemInvalidCSum:
				// delete the item with an invalid csum
				table.deleteAt(item.key, i)
			case coalescePktInvalidCSum:
				// no point in inserting an item that we can't coalesce
				return groResultNoop
			default:
			}
		}
	}
	// failed to coalesce with any other packets; store the item in the flow
	table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
	return groResultTableInsert
}

// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error {
	for _, items := range table.itemsByFlow {
		for _, item := range items {
			if item.numMerged > 0 {
				hdr := virtioNetHdr{
					flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
					hdrLen:     uint16(item.iphLen + item.tcphLen),
					gsoSize:    item.gsoSize,
					csumStart:  uint16(item.iphLen),
					csumOffset: 16, // TCP checksum field offset within the TCP header
				}
				pkt := bufs[item.bufsIndex][offset:]

				// Recalculate the total len (IPv4) or payload len (IPv6).
				// Recalculate the (IPv4) header checksum.
				if item.key.isV6 {
					hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
					binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
				} else {
					hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
					pkt[10], pkt[11] = 0, 0
					binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
					iphCSum := ^checksum(pkt[:item.iphLen], 0)            // compute IPv4 header checksum
					binary.BigEndian.PutUint16(pkt[10:], iphCSum)         // set IPv4 header checksum field
				}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}

				// Calculate the pseudo header checksum and place it at the TCP
				// checksum offset. Downstream checksum offloading will combine
				// this with computation of the tcp header and payload checksum.
				addrLen := 4
				addrOffset := ipv4SrcAddrOffset
				if item.key.isV6 {
					addrLen = 16
					addrOffset = ipv6SrcAddrOffset
				}
				srcAddrAt := offset + addrOffset
				srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
				dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
				psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
				binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
			} else {
				// Not coalesced: emit an all-zero virtio header.
				hdr := virtioNetHdr{}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}
			}
		}
	}
	return nil
}

// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error {
	for _, items := range table.itemsByFlow {
		for _, item := range items {
			if item.numMerged > 0 {
				hdr := virtioNetHdr{
					flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
					hdrLen:     uint16(item.iphLen + udphLen),
					gsoSize:    item.gsoSize,
					csumStart:  uint16(item.iphLen),
					csumOffset: 6, // UDP checksum field offset within the UDP header
				}
				pkt := bufs[item.bufsIndex][offset:]

				// Recalculate the total len (IPv4) or payload len (IPv6).
				// Recalculate the (IPv4) header checksum.
				hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
				if item.key.isV6 {
					binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
				} else {
					pkt[10], pkt[11] = 0, 0
					binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
					iphCSum := ^checksum(pkt[:item.iphLen], 0)            // compute IPv4 header checksum
					binary.BigEndian.PutUint16(pkt[10:], iphCSum)         // set IPv4 header checksum field
				}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}

				// Recalculate the UDP len field value
				binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:])))

				// Calculate the pseudo header checksum and place it at the UDP
				// checksum offset. Downstream checksum offloading will combine
				// this with computation of the udp header and payload checksum.
				addrLen := 4
				addrOffset := ipv4SrcAddrOffset
				if item.key.isV6 {
					addrLen = 16
					addrOffset = ipv6SrcAddrOffset
				}
				srcAddrAt := offset + addrOffset
				srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
				dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
				psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
				binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
			} else {
				// Not coalesced: emit an all-zero virtio header.
				hdr := virtioNetHdr{}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}
			}
		}
	}
	return nil
}

// groCandidateType classifies a packet for GRO evaluation.
type groCandidateType uint8

const (
	notGROCandidate groCandidateType = iota
	tcp4GROCandidate
	tcp6GROCandidate
	udp4GROCandidate
	udp6GROCandidate
)

// packetIsGROCandidate classifies b by IP version and transport protocol,
// applying minimum-length checks, and reports whether it is worth evaluating
// for GRO. canUDPGRO gates the UDP candidate types.
func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType {
	if len(b) < 28 {
		return notGROCandidate
	}
	if b[0]>>4 == 4 {
		if b[0]&0x0F != 5 {
			// IPv4 packets w/IP options do not coalesce
			return notGROCandidate
		}
		if b[9] == unix.IPPROTO_TCP && len(b) >= 40 {
			return tcp4GROCandidate
		}
		if b[9] == unix.IPPROTO_UDP && canUDPGRO {
			return udp4GROCandidate
		}
	} else if b[0]>>4 == 6 {
		if b[6] == unix.IPPROTO_TCP && len(b) >= 60 {
			return tcp6GROCandidate
		}
		if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO {
			return udp6GROCandidate
		}
	}
	return notGROCandidate
}

const (
	udphLen = 8
)

// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a groResultNoop when no
// action was taken, groResultTableInsert when the evaluated packet was
// inserted into table, and groResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult {
	pkt := bufs[pktI][offset:]
	if len(pkt) > maxUint16 {
		// A valid IPv4 or IPv6 packet will never exceed this.
		return groResultNoop
	}
	iphLen := int((pkt[0] & 0x0F) * 4)
	if isV6 {
		iphLen = 40
		ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
		if ipv6HPayloadLen != len(pkt)-iphLen {
			return groResultNoop
		}
	} else {
		totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
		if totalLen != len(pkt) {
			return groResultNoop
		}
	}
	if len(pkt) < iphLen {
		return groResultNoop
	}
	if len(pkt) < iphLen+udphLen {
		return groResultNoop
	}
	if !isV6 {
		if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
			// no GRO support for fragmented segments for now
			return groResultNoop
		}
	}
	gsoSize := uint16(len(pkt) - udphLen - iphLen)
	// not a candidate if payload len is 0
	if gsoSize < 1 {
		return groResultNoop
	}
	srcAddrOffset := ipv4SrcAddrOffset
	addrLen := 4
	if isV6 {
		srcAddrOffset = ipv6SrcAddrOffset
		addrLen = 16
	}
	items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI)
	if !existing {
		return groResultTableInsert
	}
	// With UDP we only check the last item, otherwise we could reorder packets
	// for a given flow. We must also always insert a new item, or successfully
	// coalesce with an existing item, for the same reason.
	item := items[len(items)-1]
	can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset)
	var pktCSumKnownInvalid bool
	if can == coalesceAppend {
		result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6)
		switch result {
		case coalesceSuccess:
			table.updateAt(item, len(items)-1)
			return groResultCoalesced
		case coalesceItemInvalidCSum:
			// If the existing item has an invalid csum we take no action. A new
			// item will be stored after it, and the existing item will never be
			// revisited as part of future coalescing candidacy checks.
		case coalescePktInvalidCSum:
			// We must insert a new item, but we also mark it as invalid csum
			// to prevent a repeat checksum validation.
			pktCSumKnownInvalid = true
		default:
		}
	}
	// failed to coalesce with any other packets; store the item in the flow
	table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid)
	return groResultTableInsert
}

// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
// packets into toWrite. toWrite, tcpTable, and udpTable should initially be
// empty (but non-nil), and are passed in to save allocs as the caller may reset
// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is
// supported.
func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error {
	for i := range bufs {
		// offset must leave room in front for the virtio-net header and fall
		// inside the buffer.
		if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
			return errors.New("invalid offset")
		}
		var result groResult
		switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) {
		case tcp4GROCandidate:
			result = tcpGRO(bufs, offset, i, tcpTable, false)
		case tcp6GROCandidate:
			result = tcpGRO(bufs, offset, i, tcpTable, true)
		case udp4GROCandidate:
			result = udpGRO(bufs, offset, i, udpTable, false)
		case udp6GROCandidate:
			result = udpGRO(bufs, offset, i, udpTable, true)
		}
		switch result {
		case groResultNoop:
			// No coalescing happened; write a zeroed virtio-net header in
			// front of the packet, then fall through to queue it for writing.
			hdr := virtioNetHdr{}
			err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
			if err != nil {
				return err
			}
			fallthrough
		case groResultTableInsert:
			*toWrite = append(*toWrite, i)
		}
		// groResultCoalesced: the packet was merged into another buffer and
		// is intentionally not appended to toWrite.
	}
	// Fix up lengths, checksums, and virtio headers for all coalesced packets.
	errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable)
	errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable)
	return errors.Join(errTCP, errUDP)
}

// gsoSplit splits packets from in into outBuffs, writing the size of each
// element into sizes. It returns the number of buffers populated, and/or an
// error.
func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) {
	// hdr.csumStart doubles as the IP header length for the GSO types
	// handled here (TCP/UDP over IPv4/IPv6).
	iphLen := int(hdr.csumStart)
	srcAddrOffset := ipv6SrcAddrOffset
	addrLen := 16
	if !isV6 {
		in[10], in[11] = 0, 0 // clear ipv4 header checksum
		srcAddrOffset = ipv4SrcAddrOffset
		addrLen = 4
	}
	transportCsumAt := int(hdr.csumStart + hdr.csumOffset)
	in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum
	var firstTCPSeqNum uint32
	var protocol uint8
	if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 {
		protocol = unix.IPPROTO_TCP
		firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:])
	} else {
		protocol = unix.IPPROTO_UDP
	}
	nextSegmentDataAt := int(hdr.hdrLen)
	i := 0
	for ; nextSegmentDataAt < len(in); i++ {
		if i == len(outBuffs) {
			// Out of output buffers; i-1 segments were fully written.
			return i - 1, ErrTooManySegments
		}
		nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
		if nextSegmentEnd > len(in) {
			nextSegmentEnd = len(in)
		}
		segmentDataLen := nextSegmentEnd - nextSegmentDataAt
		totalLen := int(hdr.hdrLen) + segmentDataLen
		sizes[i] = totalLen
		out := outBuffs[i][outOffset:]

		// Copy the IP header into this segment.
		copy(out, in[:iphLen])
		if !isV6 {
			// For IPv4 we are responsible for incrementing the ID field,
			// updating the total len field, and recalculating the header
			// checksum.
			if i > 0 {
				id := binary.BigEndian.Uint16(out[4:])
				id += uint16(i)
				binary.BigEndian.PutUint16(out[4:], id)
			}
			binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
			ipv4CSum := ^checksum(out[:iphLen], 0)
			binary.BigEndian.PutUint16(out[10:], ipv4CSum)
		} else {
			// For IPv6 we are responsible for updating the payload length field.
			binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
		}

		// copy transport header
		copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])

		if protocol == unix.IPPROTO_TCP {
			// set TCP seq and adjust TCP flags
			tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
			binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
			if nextSegmentEnd != len(in) {
				// FIN and PSH should only be set on last segment
				clearFlags := tcpFlagFIN | tcpFlagPSH
				out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
			}
		} else {
			// set UDP header len
			binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart))
		}

		// payload
		copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])

		// transport checksum
		transportHeaderLen := int(hdr.hdrLen - hdr.csumStart)
		lenForPseudo := uint16(transportHeaderLen + segmentDataLen)
		transportCSumNoFold := pseudoHeaderChecksumNoFold(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo)
		transportCSum := ^checksum(out[hdr.csumStart:totalLen], transportCSumNoFold)
		binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum)

		nextSegmentDataAt += int(hdr.gsoSize)
	}
	return i, nil
}

// gsoNoneChecksum computes and writes the transport checksum for a packet
// that needs checksum offload but no segmentation (VIRTIO_NET_HDR_GSO_NONE).
func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
	cSumAt := cSumStart + cSumOffset
	// The initial value at the checksum offset should be summed with the
	// checksum we compute. This is typically the pseudo-header checksum.
	initial := binary.BigEndian.Uint16(in[cSumAt:])
	in[cSumAt], in[cSumAt+1] = 0, 0
	binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
	return nil
}
wireguard-go-0.0.20231211/tun/offload_linux_test.go000066400000000000000000000517541466306003400217020ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ package tun import ( "net/netip" "testing" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/conn" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) const ( offset = virtioNetHdrLen ) var ( ip4PortA = netip.MustParseAddrPort("192.0.2.1:1") ip4PortB = netip.MustParseAddrPort("192.0.2.2:1") ip4PortC = netip.MustParseAddrPort("192.0.2.3:1") ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1") ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1") ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1") ) func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte { totalLen := 28 + payloadLen b := make([]byte, offset+int(totalLen), 65535) ipv4H := header.IPv4(b[offset:]) srcAs4 := srcIPPort.Addr().As4() dstAs4 := dstIPPort.Addr().As4() ipFields := &header.IPv4Fields{ SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), DstAddr: tcpip.AddrFromSlice(dstAs4[:]), Protocol: unix.IPPROTO_UDP, TTL: 64, TotalLength: uint16(totalLen), } if ipFn != nil { ipFn(ipFields) } ipv4H.Encode(ipFields) udpH := header.UDP(b[offset+20:]) udpH.Encode(&header.UDPFields{ SrcPort: srcIPPort.Port(), DstPort: dstIPPort.Port(), Length: uint16(payloadLen + udphLen), }) ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen)) udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum)) return b } func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte { return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) } func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte { totalLen := 48 + payloadLen b := make([]byte, offset+int(totalLen), 65535) ipv6H := header.IPv6(b[offset:]) srcAs16 := srcIPPort.Addr().As16() dstAs16 := dstIPPort.Addr().As16() ipFields := &header.IPv6Fields{ SrcAddr: 
tcpip.AddrFromSlice(srcAs16[:]), DstAddr: tcpip.AddrFromSlice(dstAs16[:]), TransportProtocol: unix.IPPROTO_UDP, HopLimit: 64, PayloadLength: uint16(payloadLen + udphLen), } if ipFn != nil { ipFn(ipFields) } ipv6H.Encode(ipFields) udpH := header.UDP(b[offset+40:]) udpH.Encode(&header.UDPFields{ SrcPort: srcIPPort.Port(), DstPort: dstIPPort.Port(), Length: uint16(payloadLen + udphLen), }) pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen)) udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum)) return b } func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte { return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) } func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte { totalLen := 40 + segmentSize b := make([]byte, offset+int(totalLen), 65535) ipv4H := header.IPv4(b[offset:]) srcAs4 := srcIPPort.Addr().As4() dstAs4 := dstIPPort.Addr().As4() ipFields := &header.IPv4Fields{ SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), DstAddr: tcpip.AddrFromSlice(dstAs4[:]), Protocol: unix.IPPROTO_TCP, TTL: 64, TotalLength: uint16(totalLen), } if ipFn != nil { ipFn(ipFields) } ipv4H.Encode(ipFields) tcpH := header.TCP(b[offset+20:]) tcpH.Encode(&header.TCPFields{ SrcPort: srcIPPort.Port(), DstPort: dstIPPort.Port(), SeqNum: seq, AckNum: 1, DataOffset: 20, Flags: flags, WindowSize: 3000, }) ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize)) tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) return b } func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) } func tcp6PacketMutateIPFields(srcIPPort, 
dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte { totalLen := 60 + segmentSize b := make([]byte, offset+int(totalLen), 65535) ipv6H := header.IPv6(b[offset:]) srcAs16 := srcIPPort.Addr().As16() dstAs16 := dstIPPort.Addr().As16() ipFields := &header.IPv6Fields{ SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), DstAddr: tcpip.AddrFromSlice(dstAs16[:]), TransportProtocol: unix.IPPROTO_TCP, HopLimit: 64, PayloadLength: uint16(segmentSize + 20), } if ipFn != nil { ipFn(ipFields) } ipv6H.Encode(ipFields) tcpH := header.TCP(b[offset+40:]) tcpH.Encode(&header.TCPFields{ SrcPort: srcIPPort.Port(), DstPort: dstIPPort.Port(), SeqNum: seq, AckNum: 1, DataOffset: 20, Flags: flags, WindowSize: 3000, }) pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize)) tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) return b } func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) } func Test_handleVirtioRead(t *testing.T) { tests := []struct { name string hdr virtioNetHdr pktIn []byte wantLens []int wantErr bool }{ { "tcp4", virtioNetHdr{ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4, gsoSize: 100, hdrLen: 40, csumStart: 20, csumOffset: 16, }, tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), []int{140, 140}, false, }, { "tcp6", virtioNetHdr{ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6, gsoSize: 100, hdrLen: 60, csumStart: 40, csumOffset: 16, }, tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), []int{160, 160}, false, }, { "udp4", virtioNetHdr{ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, gsoSize: 100, hdrLen: 28, csumStart: 20, csumOffset: 6, }, 
udp4Packet(ip4PortA, ip4PortB, 200), []int{128, 128}, false, }, { "udp6", virtioNetHdr{ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, gsoSize: 100, hdrLen: 48, csumStart: 40, csumOffset: 6, }, udp6Packet(ip6PortA, ip6PortB, 200), []int{148, 148}, false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { out := make([][]byte, conn.IdealBatchSize) sizes := make([]int, conn.IdealBatchSize) for i := range out { out[i] = make([]byte, 65535) } tt.hdr.encode(tt.pktIn) n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) if err != nil { if tt.wantErr { return } t.Fatalf("got err: %v", err) } if n != len(tt.wantLens) { t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) } for i := range tt.wantLens { if tt.wantLens[i] != sizes[i] { t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) } } }) } } func flipTCP4Checksum(b []byte) []byte { at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16 b[at] ^= 0xFF b[at+1] ^= 0xFF return b } func flipUDP4Checksum(b []byte) []byte { at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6 b[at] ^= 0xFF b[at+1] ^= 0xFF return b } func Fuzz_handleGRO(f *testing.F) { pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1) pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101) pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201) pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1) pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101) pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201) pkt6 := udp4Packet(ip4PortA, ip4PortB, 100) pkt7 := udp4Packet(ip4PortA, ip4PortB, 100) pkt8 := udp4Packet(ip4PortA, ip4PortC, 100) pkt9 := udp6Packet(ip6PortA, ip6PortB, 100) pkt10 := udp6Packet(ip6PortA, ip6PortB, 100) pkt11 := udp6Packet(ip6PortA, ip6PortC, 100) f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset) 
f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) { pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11} toWrite := make([]int, 0, len(pkts)) handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite) if len(toWrite) > len(pkts) { t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) } seenWriteI := make(map[int]bool) for _, writeI := range toWrite { if writeI < 0 || writeI > len(pkts)-1 { t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) } if seenWriteI[writeI] { t.Errorf("duplicate toWrite value: %d", writeI) } seenWriteI[writeI] = true } }) } func Test_handleGRO(t *testing.T) { tests := []struct { name string pktsIn [][]byte canUDPGRO bool wantToWrite []int wantLens []int wantErr bool }{ { "multiple protocols and flows", [][]byte{ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 }, true, []int{0, 1, 2, 4, 5, 7, 9}, []int{240, 228, 128, 140, 260, 160, 248}, false, }, { "multiple protocols and flows no UDP GRO", [][]byte{ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 tcp4Packet(ip4PortA, 
ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 }, false, []int{0, 1, 2, 4, 5, 7, 8, 9, 10}, []int{240, 128, 128, 140, 260, 160, 128, 148, 148}, false, }, { "PSH interleaved", [][]byte{ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 }, true, []int{0, 2, 4, 6}, []int{240, 240, 260, 260}, false, }, { "coalesceItemInvalidCSum", [][]byte{ flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)), udp4Packet(ip4PortA, ip4PortB, 100), udp4Packet(ip4PortA, ip4PortB, 100), }, true, []int{0, 1, 3, 4}, []int{140, 240, 128, 228}, false, }, { "out of order", [][]byte{ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 tcp4Packet(ip4PortA, 
ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 }, true, []int{0}, []int{340}, false, }, { "unequal TTL", [][]byte{ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { fields.TTL++ }), udp4Packet(ip4PortA, ip4PortB, 100), udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { fields.TTL++ }), }, true, []int{0, 1, 2, 3}, []int{140, 140, 128, 128}, false, }, { "unequal ToS", [][]byte{ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { fields.TOS++ }), udp4Packet(ip4PortA, ip4PortB, 100), udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { fields.TOS++ }), }, true, []int{0, 1, 2, 3}, []int{140, 140, 128, 128}, false, }, { "unequal flags more fragments set", [][]byte{ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { fields.Flags = 1 }), udp4Packet(ip4PortA, ip4PortB, 100), udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { fields.Flags = 1 }), }, true, []int{0, 1, 2, 3}, []int{140, 140, 128, 128}, false, }, { "unequal flags DF set", [][]byte{ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { fields.Flags = 2 }), udp4Packet(ip4PortA, ip4PortB, 100), udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { fields.Flags = 2 }), }, true, []int{0, 1, 2, 3}, []int{140, 140, 128, 128}, false, }, { "ipv6 unequal hop limit", [][]byte{ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), 
tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { fields.HopLimit++ }), udp6Packet(ip6PortA, ip6PortB, 100), udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { fields.HopLimit++ }), }, true, []int{0, 1, 2, 3}, []int{160, 160, 148, 148}, false, }, { "ipv6 unequal traffic class", [][]byte{ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { fields.TrafficClass++ }), udp6Packet(ip6PortA, ip6PortB, 100), udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { fields.TrafficClass++ }), }, true, []int{0, 1, 2, 3}, []int{160, 160, 148, 148}, false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { toWrite := make([]int, 0, len(tt.pktsIn)) err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite) if err != nil { if tt.wantErr { return } t.Fatalf("got err: %v", err) } if len(toWrite) != len(tt.wantToWrite) { t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) } for i, pktI := range tt.wantToWrite { if tt.wantToWrite[i] != toWrite[i] { t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) } if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) } } }) } } func Test_packetIsGROCandidate(t *testing.T) { tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] tcp4TooShort := tcp4[:39] ip4InvalidHeaderLen := make([]byte, len(tcp4)) copy(ip4InvalidHeaderLen, tcp4) ip4InvalidHeaderLen[0] = 0x46 ip4InvalidProtocol := make([]byte, len(tcp4)) copy(ip4InvalidProtocol, tcp4) ip4InvalidProtocol[9] = unix.IPPROTO_GRE tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] tcp6TooShort := tcp6[:59] 
ip6InvalidProtocol := make([]byte, len(tcp6)) copy(ip6InvalidProtocol, tcp6) ip6InvalidProtocol[6] = unix.IPPROTO_GRE udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:] udp4TooShort := udp4[:27] udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:] udp6TooShort := udp6[:47] tests := []struct { name string b []byte canUDPGRO bool want groCandidateType }{ { "tcp4", tcp4, true, tcp4GROCandidate, }, { "tcp6", tcp6, true, tcp6GROCandidate, }, { "udp4", udp4, true, udp4GROCandidate, }, { "udp4 no support", udp4, false, notGROCandidate, }, { "udp6", udp6, true, udp6GROCandidate, }, { "udp6 no support", udp6, false, notGROCandidate, }, { "udp4 too short", udp4TooShort, true, notGROCandidate, }, { "udp6 too short", udp6TooShort, true, notGROCandidate, }, { "tcp4 too short", tcp4TooShort, true, notGROCandidate, }, { "tcp6 too short", tcp6TooShort, true, notGROCandidate, }, { "invalid IP version", []byte{0x00}, true, notGROCandidate, }, { "invalid IP header len", ip4InvalidHeaderLen, true, notGROCandidate, }, { "ip4 invalid protocol", ip4InvalidProtocol, true, notGROCandidate, }, { "ip6 invalid protocol", ip6InvalidProtocol, true, notGROCandidate, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want { t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want) } }) } } func Test_udpPacketsCanCoalesce(t *testing.T) { udp4a := udp4Packet(ip4PortA, ip4PortB, 100) udp4b := udp4Packet(ip4PortA, ip4PortB, 100) udp4c := udp4Packet(ip4PortA, ip4PortB, 110) type args struct { pkt []byte iphLen uint8 gsoSize uint16 item udpGROItem bufs [][]byte bufsOffset int } tests := []struct { name string args args want canCoalesce }{ { "coalesceAppend equal gso", args{ pkt: udp4a[offset:], iphLen: 20, gsoSize: 100, item: udpGROItem{ gsoSize: 100, iphLen: 20, }, bufs: [][]byte{ udp4a, udp4b, }, bufsOffset: offset, }, coalesceAppend, }, { "coalesceAppend smaller gso", args{ pkt: udp4a[offset : 
len(udp4a)-90], iphLen: 20, gsoSize: 10, item: udpGROItem{ gsoSize: 100, iphLen: 20, }, bufs: [][]byte{ udp4a, udp4b, }, bufsOffset: offset, }, coalesceAppend, }, { "coalesceUnavailable smaller gso previously appended", args{ pkt: udp4a[offset:], iphLen: 20, gsoSize: 100, item: udpGROItem{ gsoSize: 100, iphLen: 20, }, bufs: [][]byte{ udp4c, udp4b, }, bufsOffset: offset, }, coalesceUnavailable, }, { "coalesceUnavailable larger following smaller", args{ pkt: udp4c[offset:], iphLen: 20, gsoSize: 110, item: udpGROItem{ gsoSize: 100, iphLen: 20, }, bufs: [][]byte{ udp4a, udp4c, }, bufsOffset: offset, }, coalesceUnavailable, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want { t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want) } }) } } wireguard-go-0.0.20231211/tun/operateonfd.go000066400000000000000000000007761466306003400203160ustar00rootroot00000000000000//go:build darwin || freebsd /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun import ( "fmt" ) func (tun *NativeTun) operateOnFd(fn func(fd uintptr)) { sysconn, err := tun.tunFile.SyscallConn() if err != nil { tun.errors <- fmt.Errorf("unable to find sysconn for tunfile: %s", err.Error()) return } err = sysconn.Control(fn) if err != nil { tun.errors <- fmt.Errorf("unable to control sysconn for tunfile: %s", err.Error()) } } wireguard-go-0.0.20231211/tun/tun.go000066400000000000000000000030201466306003400165770ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun import ( "os" ) type Event int const ( EventUp = 1 << iota EventDown EventMTUUpdate ) type Device interface { // File returns the file descriptor of the device. 
File() *os.File // Read one or more packets from the Device (without any additional headers). // On a successful read it returns the number of packets read, and sets // packet lengths within the sizes slice. len(sizes) must be >= len(bufs). // A nonzero offset can be used to instruct the Device on where to begin // reading into each element of the bufs slice. Read(bufs [][]byte, sizes []int, offset int) (n int, err error) // Write one or more packets to the device (without any additional headers). // On a successful write it returns the number of packets written. A nonzero // offset can be used to instruct the Device on where to begin writing from // each packet contained within the bufs slice. Write(bufs [][]byte, offset int) (int, error) // MTU returns the MTU of the Device. MTU() (int, error) // Name returns the current name of the Device. Name() (string, error) // Events returns a channel of type Event, which is fed Device events. Events() <-chan Event // Close stops the Device and closes the Event channel. Close() error // BatchSize returns the preferred/max number of packets that can be read or // written in a single read/write call. BatchSize must not change over the // lifetime of a Device. BatchSize() int } wireguard-go-0.0.20231211/tun/tun_darwin.go000066400000000000000000000143221466306003400201520ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package tun import ( "errors" "fmt" "io" "net" "os" "sync" "syscall" "time" "unsafe" "golang.org/x/sys/unix" ) const utunControlName = "com.apple.net.utun_control" type NativeTun struct { name string tunFile *os.File events chan Event errors chan error routeSocket int closeOnce sync.Once } func retryInterfaceByIndex(index int) (iface *net.Interface, err error) { for i := 0; i < 20; i++ { iface, err = net.InterfaceByIndex(index) if err != nil && errors.Is(err, unix.ENOMEM) { time.Sleep(time.Duration(i) * time.Second / 3) continue } return iface, err } return nil, err } func (tun *NativeTun) routineRouteListener(tunIfindex int) { var ( statusUp bool statusMTU int ) defer close(tun.events) data := make([]byte, os.Getpagesize()) for { retry: n, err := unix.Read(tun.routeSocket, data) if err != nil { if errno, ok := err.(unix.Errno); ok && errno == unix.EINTR { goto retry } tun.errors <- err return } if n < 14 { continue } if data[3 /* type */] != unix.RTM_IFINFO { continue } ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifindex */]))) if ifindex != tunIfindex { continue } iface, err := retryInterfaceByIndex(ifindex) if err != nil { tun.errors <- err return } // Up / Down event up := (iface.Flags & net.FlagUp) != 0 if up != statusUp && up { tun.events <- EventUp } if up != statusUp && !up { tun.events <- EventDown } statusUp = up // MTU changes if iface.MTU != statusMTU { tun.events <- EventMTUUpdate } statusMTU = iface.MTU } } func CreateTUN(name string, mtu int) (Device, error) { ifIndex := -1 if name != "utun" { _, err := fmt.Sscanf(name, "utun%d", &ifIndex) if err != nil || ifIndex < 0 { return nil, fmt.Errorf("Interface name must be utun[0-9]*") } } fd, err := socketCloexec(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2) if err != nil { return nil, err } ctlInfo := &unix.CtlInfo{} copy(ctlInfo.Name[:], []byte(utunControlName)) err = unix.IoctlCtlInfo(fd, ctlInfo) if err != nil { unix.Close(fd) return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err) } sc := 
&unix.SockaddrCtl{ ID: ctlInfo.Id, Unit: uint32(ifIndex) + 1, } err = unix.Connect(fd, sc) if err != nil { unix.Close(fd) return nil, err } err = unix.SetNonblock(fd, true) if err != nil { unix.Close(fd) return nil, err } tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu) if err == nil && name == "utun" { fname := os.Getenv("WG_TUN_NAME_FILE") if fname != "" { os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400) } } return tun, err } func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { tun := &NativeTun{ tunFile: file, events: make(chan Event, 10), errors: make(chan error, 5), } name, err := tun.Name() if err != nil { tun.tunFile.Close() return nil, err } tunIfindex, err := func() (int, error) { iface, err := net.InterfaceByName(name) if err != nil { return -1, err } return iface.Index, nil }() if err != nil { tun.tunFile.Close() return nil, err } tun.routeSocket, err = socketCloexec(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) if err != nil { tun.tunFile.Close() return nil, err } go tun.routineRouteListener(tunIfindex) if mtu > 0 { err = tun.setMTU(mtu) if err != nil { tun.Close() return nil, err } } return tun, nil } func (tun *NativeTun) Name() (string, error) { var err error tun.operateOnFd(func(fd uintptr) { tun.name, err = unix.GetsockoptString( int(fd), 2, /* #define SYSPROTO_CONTROL 2 */ 2, /* #define UTUN_OPT_IFNAME 2 */ ) }) if err != nil { return "", fmt.Errorf("GetSockoptString: %w", err) } return tun.name, nil } func (tun *NativeTun) File() *os.File { return tun.tunFile } func (tun *NativeTun) Events() <-chan Event { return tun.events } func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { // TODO: the BSDs look very similar in Read() and Write(). They should be // collapsed, with platform-specific files containing the varying parts of // their implementations. 
select { case err := <-tun.errors: return 0, err default: buf := bufs[0][offset-4:] n, err := tun.tunFile.Read(buf[:]) if n < 4 { return 0, err } sizes[0] = n - 4 return 1, err } } func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { if offset < 4 { return 0, io.ErrShortBuffer } for i, buf := range bufs { buf = buf[offset-4:] buf[0] = 0x00 buf[1] = 0x00 buf[2] = 0x00 switch buf[4] >> 4 { case 4: buf[3] = unix.AF_INET case 6: buf[3] = unix.AF_INET6 default: return i, unix.EAFNOSUPPORT } if _, err := tun.tunFile.Write(buf); err != nil { return i, err } } return len(bufs), nil } func (tun *NativeTun) Close() error { var err1, err2 error tun.closeOnce.Do(func() { err1 = tun.tunFile.Close() if tun.routeSocket != -1 { unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) err2 = unix.Close(tun.routeSocket) } else if tun.events != nil { close(tun.events) } }) if err1 != nil { return err1 } return err2 } func (tun *NativeTun) setMTU(n int) error { fd, err := socketCloexec( unix.AF_INET, unix.SOCK_DGRAM, 0, ) if err != nil { return err } defer unix.Close(fd) var ifr unix.IfreqMTU copy(ifr.Name[:], tun.name) ifr.MTU = int32(n) err = unix.IoctlSetIfreqMTU(fd, &ifr) if err != nil { return fmt.Errorf("failed to set MTU on %s: %w", tun.name, err) } return nil } func (tun *NativeTun) MTU() (int, error) { fd, err := socketCloexec( unix.AF_INET, unix.SOCK_DGRAM, 0, ) if err != nil { return 0, err } defer unix.Close(fd) ifr, err := unix.IoctlGetIfreqMTU(fd, tun.name) if err != nil { return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, err) } return int(ifr.MTU), nil } func (tun *NativeTun) BatchSize() int { return 1 } func socketCloexec(family, sotype, proto int) (fd int, err error) { // See go/src/net/sys_cloexec.go for background. 
syscall.ForkLock.RLock() defer syscall.ForkLock.RUnlock() fd, err = unix.Socket(family, sotype, proto) if err == nil { unix.CloseOnExec(fd) } return } wireguard-go-0.0.20231211/tun/tun_freebsd.go000066400000000000000000000231471466306003400203050ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun import ( "errors" "fmt" "io" "net" "os" "sync" "syscall" "unsafe" "golang.org/x/sys/unix" ) const ( _TUNSIFHEAD = 0x80047460 _TUNSIFMODE = 0x8004745e _TUNGIFNAME = 0x4020745d _TUNSIFPID = 0x2000745f _SIOCGIFINFO_IN6 = 0xc048696c _SIOCSIFINFO_IN6 = 0xc048696d _ND6_IFF_AUTO_LINKLOCAL = 0x20 _ND6_IFF_NO_DAD = 0x100 ) // Iface requests with just the name type ifreqName struct { Name [unix.IFNAMSIZ]byte _ [16]byte } // Iface requests with a pointer type ifreqPtr struct { Name [unix.IFNAMSIZ]byte Data uintptr _ [16 - unsafe.Sizeof(uintptr(0))]byte } // Iface requests with MTU type ifreqMtu struct { Name [unix.IFNAMSIZ]byte MTU uint32 _ [12]byte } // ND6 flag manipulation type nd6Req struct { Name [unix.IFNAMSIZ]byte Linkmtu uint32 Maxmtu uint32 Basereachable uint32 Reachable uint32 Retrans uint32 Flags uint32 Recalctm int Chlim uint8 Initialized uint8 Randomseed0 [8]byte Randomseed1 [8]byte Randomid [8]byte } type NativeTun struct { name string tunFile *os.File events chan Event errors chan error routeSocket int closeOnce sync.Once } func (tun *NativeTun) routineRouteListener(tunIfindex int) { var ( statusUp bool statusMTU int ) defer close(tun.events) data := make([]byte, os.Getpagesize()) for { retry: n, err := unix.Read(tun.routeSocket, data) if err != nil { if errors.Is(err, syscall.EINTR) { goto retry } tun.errors <- err return } if n < 14 { continue } if data[3 /* type */] != unix.RTM_IFINFO { continue } ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifindex */]))) if ifindex != tunIfindex { continue } iface, err := net.InterfaceByIndex(ifindex) if err != nil { tun.errors <- err 
return } // Up / Down event up := (iface.Flags & net.FlagUp) != 0 if up != statusUp && up { tun.events <- EventUp } if up != statusUp && !up { tun.events <- EventDown } statusUp = up // MTU changes if iface.MTU != statusMTU { tun.events <- EventMTUUpdate } statusMTU = iface.MTU } } func tunName(fd uintptr) (string, error) { var ifreq ifreqName _, _, err := unix.Syscall(unix.SYS_IOCTL, fd, _TUNGIFNAME, uintptr(unsafe.Pointer(&ifreq))) if err != 0 { return "", err } return unix.ByteSliceToString(ifreq.Name[:]), nil } // Destroy a named system interface func tunDestroy(name string) error { fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) if err != nil { return err } defer unix.Close(fd) var ifr [32]byte copy(ifr[:], name) _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCIFDESTROY), uintptr(unsafe.Pointer(&ifr[0]))) if errno != 0 { return fmt.Errorf("failed to destroy interface %s: %w", name, errno) } return nil } func CreateTUN(name string, mtu int) (Device, error) { if len(name) > unix.IFNAMSIZ-1 { return nil, errors.New("interface name too long") } // See if interface already exists iface, _ := net.InterfaceByName(name) if iface != nil { return nil, fmt.Errorf("interface %s already exists", name) } tunFile, err := os.OpenFile("/dev/tun", unix.O_RDWR|unix.O_CLOEXEC, 0) if err != nil { return nil, err } tun := NativeTun{tunFile: tunFile} var assignedName string tun.operateOnFd(func(fd uintptr) { assignedName, err = tunName(fd) }) if err != nil { tunFile.Close() return nil, err } // Enable ifhead mode, otherwise tun will complain if it gets a non-AF_INET packet ifheadmode := 1 var errno syscall.Errno tun.operateOnFd(func(fd uintptr) { _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFHEAD, uintptr(unsafe.Pointer(&ifheadmode))) }) if errno != 0 { tunFile.Close() tunDestroy(assignedName) return nil, fmt.Errorf("unable to put into IFHEAD mode: %w", errno) } // Get out of PTP mode. 
ifflags := syscall.IFF_BROADCAST | syscall.IFF_MULTICAST tun.operateOnFd(func(fd uintptr) { _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, uintptr(_TUNSIFMODE), uintptr(unsafe.Pointer(&ifflags))) }) if errno != 0 { tunFile.Close() tunDestroy(assignedName) return nil, fmt.Errorf("unable to put into IFF_BROADCAST mode: %w", errno) } // Disable link-local v6, not just because WireGuard doesn't do that anyway, but // also because there are serious races with attaching and detaching LLv6 addresses // in relation to interface lifetime within the FreeBSD kernel. confd6, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) if err != nil { tunFile.Close() tunDestroy(assignedName) return nil, err } defer unix.Close(confd6) var ndireq nd6Req copy(ndireq.Name[:], assignedName) _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCGIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq))) if errno != 0 { tunFile.Close() tunDestroy(assignedName) return nil, fmt.Errorf("unable to get nd6 flags for %s: %w", assignedName, errno) } ndireq.Flags = ndireq.Flags &^ _ND6_IFF_AUTO_LINKLOCAL ndireq.Flags = ndireq.Flags | _ND6_IFF_NO_DAD _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCSIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq))) if errno != 0 { tunFile.Close() tunDestroy(assignedName) return nil, fmt.Errorf("unable to set nd6 flags for %s: %w", assignedName, errno) } if name != "" { confd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) if err != nil { tunFile.Close() tunDestroy(assignedName) return nil, err } defer unix.Close(confd) var newnp [unix.IFNAMSIZ]byte copy(newnp[:], name) var ifr ifreqPtr copy(ifr.Name[:], assignedName) ifr.Data = uintptr(unsafe.Pointer(&newnp[0])) _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd), uintptr(unix.SIOCSIFNAME), uintptr(unsafe.Pointer(&ifr))) if errno != 0 { tunFile.Close() tunDestroy(assignedName) return nil, fmt.Errorf("Failed to rename %s to %s: %w", 
assignedName, name, errno) } } return CreateTUNFromFile(tunFile, mtu) } func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { tun := &NativeTun{ tunFile: file, events: make(chan Event, 10), errors: make(chan error, 1), } var errno syscall.Errno tun.operateOnFd(func(fd uintptr) { _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFPID, uintptr(0)) }) if errno != 0 { tun.tunFile.Close() return nil, fmt.Errorf("unable to become controlling TUN process: %w", errno) } name, err := tun.Name() if err != nil { tun.tunFile.Close() return nil, err } tunIfindex, err := func() (int, error) { iface, err := net.InterfaceByName(name) if err != nil { return -1, err } return iface.Index, nil }() if err != nil { tun.tunFile.Close() return nil, err } tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC) if err != nil { tun.tunFile.Close() return nil, err } go tun.routineRouteListener(tunIfindex) err = tun.setMTU(mtu) if err != nil { tun.Close() return nil, err } return tun, nil } func (tun *NativeTun) Name() (string, error) { var name string var err error tun.operateOnFd(func(fd uintptr) { name, err = tunName(fd) }) if err != nil { return "", err } tun.name = name return name, nil } func (tun *NativeTun) File() *os.File { return tun.tunFile } func (tun *NativeTun) Events() <-chan Event { return tun.events } func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { select { case err := <-tun.errors: return 0, err default: buf := bufs[0][offset-4:] n, err := tun.tunFile.Read(buf[:]) if n < 4 { return 0, err } sizes[0] = n - 4 return 1, err } } func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { if offset < 4 { return 0, io.ErrShortBuffer } for i, buf := range bufs { buf = buf[offset-4:] if len(buf) < 5 { return i, io.ErrShortBuffer } buf[0] = 0x00 buf[1] = 0x00 buf[2] = 0x00 switch buf[4] >> 4 { case 4: buf[3] = unix.AF_INET case 6: buf[3] = unix.AF_INET6 default: return i, 
unix.EAFNOSUPPORT } if _, err := tun.tunFile.Write(buf); err != nil { return i, err } } return len(bufs), nil } func (tun *NativeTun) Close() error { var err1, err2, err3 error tun.closeOnce.Do(func() { err1 = tun.tunFile.Close() err2 = tunDestroy(tun.name) if tun.routeSocket != -1 { unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) err3 = unix.Close(tun.routeSocket) tun.routeSocket = -1 } else if tun.events != nil { close(tun.events) } }) if err1 != nil { return err1 } if err2 != nil { return err2 } return err3 } func (tun *NativeTun) setMTU(n int) error { fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) if err != nil { return err } defer unix.Close(fd) var ifr ifreqMtu copy(ifr.Name[:], tun.name) ifr.MTU = uint32(n) _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCSIFMTU), uintptr(unsafe.Pointer(&ifr))) if errno != 0 { return fmt.Errorf("failed to set MTU on %s: %w", tun.name, errno) } return nil } func (tun *NativeTun) MTU() (int, error) { fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) if err != nil { return 0, err } defer unix.Close(fd) var ifr ifreqMtu copy(ifr.Name[:], tun.name) _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCGIFMTU), uintptr(unsafe.Pointer(&ifr))) if errno != 0 { return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, errno) } return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil } func (tun *NativeTun) BatchSize() int { return 1 } wireguard-go-0.0.20231211/tun/tun_linux.go000066400000000000000000000367721466306003400200420ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package tun /* Implementation of the TUN device interface for linux */ import ( "errors" "fmt" "os" "sync" "syscall" "time" "unsafe" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/rwcancel" ) const ( cloneDevicePath = "/dev/net/tun" ifReqSize = unix.IFNAMSIZ + 64 ) type NativeTun struct { tunFile *os.File index int32 // if index errors chan error // async error handling events chan Event // device related events netlinkSock int netlinkCancel *rwcancel.RWCancel hackListenerClosed sync.Mutex statusListenersShutdown chan struct{} batchSize int vnetHdr bool udpGSO bool closeOnce sync.Once nameOnce sync.Once // guards calling initNameCache, which sets following fields nameCache string // name of interface nameErr error readOpMu sync.Mutex // readOpMu guards readBuff readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable toWrite []int tcpGROTable *tcpGROTable udpGROTable *udpGROTable } func (tun *NativeTun) File() *os.File { return tun.tunFile } func (tun *NativeTun) routineHackListener() { defer tun.hackListenerClosed.Unlock() /* This is needed for the detection to work across network namespaces * If you are reading this and know a better method, please get in touch. */ last := 0 const ( up = 1 down = 2 ) for { sysconn, err := tun.tunFile.SyscallConn() if err != nil { return } err2 := sysconn.Control(func(fd uintptr) { _, err = unix.Write(int(fd), nil) }) if err2 != nil { return } switch err { case unix.EINVAL: if last != up { // If the tunnel is up, it reports that write() is // allowed but we provided invalid data. tun.events <- EventUp last = up } case unix.EIO: if last != down { // If the tunnel is down, it reports that no I/O // is possible, without checking our provided data. 
tun.events <- EventDown last = down } default: return } select { case <-time.After(time.Second): // nothing case <-tun.statusListenersShutdown: return } } } func createNetlinkSocket() (int, error) { sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE) if err != nil { return -1, err } saddr := &unix.SockaddrNetlink{ Family: unix.AF_NETLINK, Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR, } err = unix.Bind(sock, saddr) if err != nil { return -1, err } return sock, nil } func (tun *NativeTun) routineNetlinkListener() { defer func() { unix.Close(tun.netlinkSock) tun.hackListenerClosed.Lock() close(tun.events) tun.netlinkCancel.Close() }() for msg := make([]byte, 1<<16); ; { var err error var msgn int for { msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0) if err == nil || !rwcancel.RetryAfterError(err) { break } if !tun.netlinkCancel.ReadyRead() { tun.errors <- fmt.Errorf("netlink socket closed: %w", err) return } } if err != nil { tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err) return } select { case <-tun.statusListenersShutdown: return default: } wasEverUp := false for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) if int(hdr.Len) > len(remain) { break } switch hdr.Type { case unix.NLMSG_DONE: remain = []byte{} case unix.RTM_NEWLINK: info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr])) remain = remain[hdr.Len:] if info.Index != tun.index { // not our interface continue } if info.Flags&unix.IFF_RUNNING != 0 { tun.events <- EventUp wasEverUp = true } if info.Flags&unix.IFF_RUNNING == 0 { // Don't emit EventDown before we've ever emitted EventUp. // This avoids a startup race with HackListener, which // might detect Up before we have finished reporting Down. 
if wasEverUp { tun.events <- EventDown } } tun.events <- EventMTUUpdate default: remain = remain[hdr.Len:] } } } } func getIFIndex(name string) (int32, error) { fd, err := unix.Socket( unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) if err != nil { return 0, err } defer unix.Close(fd) var ifr [ifReqSize]byte copy(ifr[:], name) _, _, errno := unix.Syscall( unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCGIFINDEX), uintptr(unsafe.Pointer(&ifr[0])), ) if errno != 0 { return 0, errno } return *(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])), nil } func (tun *NativeTun) setMTU(n int) error { name, err := tun.Name() if err != nil { return err } // open datagram socket fd, err := unix.Socket( unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) if err != nil { return err } defer unix.Close(fd) // do ioctl call var ifr [ifReqSize]byte copy(ifr[:], name) *(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n) _, _, errno := unix.Syscall( unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCSIFMTU), uintptr(unsafe.Pointer(&ifr[0])), ) if errno != 0 { return fmt.Errorf("failed to set MTU of TUN device: %w", errno) } return nil } func (tun *NativeTun) MTU() (int, error) { name, err := tun.Name() if err != nil { return 0, err } // open datagram socket fd, err := unix.Socket( unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) if err != nil { return 0, err } defer unix.Close(fd) // do ioctl call var ifr [ifReqSize]byte copy(ifr[:], name) _, _, errno := unix.Syscall( unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCGIFMTU), uintptr(unsafe.Pointer(&ifr[0])), ) if errno != 0 { return 0, fmt.Errorf("failed to get MTU of TUN device: %w", errno) } return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil } func (tun *NativeTun) Name() (string, error) { tun.nameOnce.Do(tun.initNameCache) return tun.nameCache, tun.nameErr } func (tun *NativeTun) initNameCache() { tun.nameCache, tun.nameErr = tun.nameSlow() } func (tun *NativeTun) nameSlow() (string, error) { sysconn, err := 
tun.tunFile.SyscallConn() if err != nil { return "", err } var ifr [ifReqSize]byte var errno syscall.Errno err = sysconn.Control(func(fd uintptr) { _, _, errno = unix.Syscall( unix.SYS_IOCTL, fd, uintptr(unix.TUNGETIFF), uintptr(unsafe.Pointer(&ifr[0])), ) }) if err != nil { return "", fmt.Errorf("failed to get name of TUN device: %w", err) } if errno != 0 { return "", fmt.Errorf("failed to get name of TUN device: %w", errno) } return unix.ByteSliceToString(ifr[:]), nil } func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { tun.writeOpMu.Lock() defer func() { tun.tcpGROTable.reset() tun.udpGROTable.reset() tun.writeOpMu.Unlock() }() var ( errs error total int ) tun.toWrite = tun.toWrite[:0] if tun.vnetHdr { err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite) if err != nil { return 0, err } offset -= virtioNetHdrLen } else { for i := range bufs { tun.toWrite = append(tun.toWrite, i) } } for _, bufsI := range tun.toWrite { n, err := tun.tunFile.Write(bufs[bufsI][offset:]) if errors.Is(err, syscall.EBADFD) { return total, os.ErrClosed } if err != nil { errs = errors.Join(errs, err) } else { total += n } } return total, errs } // handleVirtioRead splits in into bufs, leaving offset bytes at the front of // each buffer. It mutates sizes to reflect the size of each element of bufs, // and returns the number of packets read. func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) { var hdr virtioNetHdr err := hdr.decode(in) if err != nil { return 0, err } in = in[virtioNetHdrLen:] if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE { if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { // This means CHECKSUM_PARTIAL in skb context. We are responsible // for computing the checksum starting at hdr.csumStart and placing // at hdr.csumOffset. 
err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset) if err != nil { return 0, err } } if len(in) > len(bufs[0][offset:]) { return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:])) } n := copy(bufs[0][offset:], in) sizes[0] = n return 1, nil } if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType) } ipVersion := in[0] >> 4 switch ipVersion { case 4: if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) } case 6: if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) } default: return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) } // Don't trust hdr.hdrLen from the kernel as it can be equal to the length // of the entire first packet when the kernel is handling it as part of a // FORWARD path. Instead, parse the transport header length and add it onto // csumStart, which is synonymous for IP header length. if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 { hdr.hdrLen = hdr.csumStart + 8 } else { if len(in) <= int(hdr.csumStart+12) { return 0, errors.New("packet is too short") } tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) if tcpHLen < 20 || tcpHLen > 60 { // A TCP header must be between 20 and 60 bytes in length. 
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) } hdr.hdrLen = hdr.csumStart + tcpHLen } if len(in) < int(hdr.hdrLen) { return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen) } if hdr.hdrLen < hdr.csumStart { return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart) } cSumAt := int(hdr.csumStart + hdr.csumOffset) if cSumAt+1 >= len(in) { return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) } return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6) } func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { tun.readOpMu.Lock() defer tun.readOpMu.Unlock() select { case err := <-tun.errors: return 0, err default: readInto := bufs[0][offset:] if tun.vnetHdr { readInto = tun.readBuff[:] } n, err := tun.tunFile.Read(readInto) if errors.Is(err, syscall.EBADFD) { err = os.ErrClosed } if err != nil { return 0, err } if tun.vnetHdr { return handleVirtioRead(readInto[:n], bufs, sizes, offset) } else { sizes[0] = n return 1, nil } } } func (tun *NativeTun) Events() <-chan Event { return tun.events } func (tun *NativeTun) Close() error { var err1, err2 error tun.closeOnce.Do(func() { if tun.statusListenersShutdown != nil { close(tun.statusListenersShutdown) if tun.netlinkCancel != nil { err1 = tun.netlinkCancel.Cancel() } } else if tun.events != nil { close(tun.events) } err2 = tun.tunFile.Close() }) if err1 != nil { return err1 } return err2 } func (tun *NativeTun) BatchSize() int { return tun.batchSize } const ( // TODO: support TSO with ECN bits tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6 ) func (tun *NativeTun) initFromFlags(name string) error { sc, err := tun.tunFile.SyscallConn() if err != nil { return err } if e := sc.Control(func(fd uintptr) { var ( ifr *unix.Ifreq ) ifr, err = unix.NewIfreq(name) if err != nil { return } 
err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr) if err != nil { return } got := ifr.Uint16() if got&unix.IFF_VNET_HDR != 0 { // tunTCPOffloads were added in Linux v2.6. We require their support // if IFF_VNET_HDR is set. err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads) if err != nil { return } tun.vnetHdr = true tun.batchSize = conn.IdealBatchSize // tunUDPOffloads were added in Linux v6.2. We do not return an // error if they are unsupported at runtime. tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil } else { tun.batchSize = 1 } }); e != nil { return e } return err } // CreateTUN creates a Device with the provided name and MTU. func CreateTUN(name string, mtu int) (Device, error) { nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0) if err != nil { if os.IsNotExist(err) { return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath) } return nil, err } ifr, err := unix.NewIfreq(name) if err != nil { return nil, err } // IFF_VNET_HDR enables the "tun status hack" via routineHackListener() // where a null write will return EINVAL indicating the TUN is up. ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR) err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr) if err != nil { return nil, err } err = unix.SetNonblock(nfd, true) if err != nil { unix.Close(nfd) return nil, err } // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line. fd := os.NewFile(uintptr(nfd), cloneDevicePath) return CreateTUNFromFile(fd, mtu) } // CreateTUNFromFile creates a Device from an os.File with the provided MTU. 
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { tun := &NativeTun{ tunFile: file, events: make(chan Event, 5), errors: make(chan error, 5), statusListenersShutdown: make(chan struct{}), tcpGROTable: newTCPGROTable(), udpGROTable: newUDPGROTable(), toWrite: make([]int, 0, conn.IdealBatchSize), } name, err := tun.Name() if err != nil { return nil, err } err = tun.initFromFlags(name) if err != nil { return nil, err } // start event listener tun.index, err = getIFIndex(name) if err != nil { return nil, err } tun.netlinkSock, err = createNetlinkSocket() if err != nil { return nil, err } tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock) if err != nil { unix.Close(tun.netlinkSock) return nil, err } tun.hackListenerClosed.Lock() go tun.routineNetlinkListener() go tun.routineHackListener() // cross namespace err = tun.setMTU(mtu) if err != nil { unix.Close(tun.netlinkSock) return nil, err } return tun, nil } // CreateUnmonitoredTUNFromFD creates a Device from the provided file // descriptor. func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { err := unix.SetNonblock(fd, true) if err != nil { return nil, "", err } file := os.NewFile(uintptr(fd), "/dev/tun") tun := &NativeTun{ tunFile: file, events: make(chan Event, 5), errors: make(chan error, 5), tcpGROTable: newTCPGROTable(), udpGROTable: newUDPGROTable(), toWrite: make([]int, 0, conn.IdealBatchSize), } name, err := tun.Name() if err != nil { return nil, "", err } err = tun.initFromFlags(name) if err != nil { return nil, "", err } return tun, name, err } wireguard-go-0.0.20231211/tun/tun_openbsd.go000066400000000000000000000135471466306003400203300ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package tun import ( "errors" "fmt" "io" "net" "os" "sync" "syscall" "unsafe" "golang.org/x/sys/unix" ) // Structure for iface mtu get/set ioctls type ifreq_mtu struct { Name [unix.IFNAMSIZ]byte MTU uint32 Pad0 [12]byte } const _TUNSIFMODE = 0x8004745d type NativeTun struct { name string tunFile *os.File events chan Event errors chan error routeSocket int closeOnce sync.Once } func (tun *NativeTun) routineRouteListener(tunIfindex int) { var ( statusUp bool statusMTU int ) defer close(tun.events) check := func() bool { iface, err := net.InterfaceByIndex(tunIfindex) if err != nil { tun.errors <- err return true } // Up / Down event up := (iface.Flags & net.FlagUp) != 0 if up != statusUp && up { tun.events <- EventUp } if up != statusUp && !up { tun.events <- EventDown } statusUp = up // MTU changes if iface.MTU != statusMTU { tun.events <- EventMTUUpdate } statusMTU = iface.MTU return false } if check() { return } data := make([]byte, os.Getpagesize()) for { n, err := unix.Read(tun.routeSocket, data) if err != nil { if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { continue } tun.errors <- err return } if n < 8 { continue } if data[3 /* type */] != unix.RTM_IFINFO { continue } ifindex := int(*(*uint16)(unsafe.Pointer(&data[6 /* ifindex */]))) if ifindex != tunIfindex { continue } if check() { return } } } func CreateTUN(name string, mtu int) (Device, error) { ifIndex := -1 if name != "tun" { _, err := fmt.Sscanf(name, "tun%d", &ifIndex) if err != nil || ifIndex < 0 { return nil, fmt.Errorf("Interface name must be tun[0-9]*") } } var tunfile *os.File var err error if ifIndex != -1 { tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0) } else { for ifIndex = 0; ifIndex < 256; ifIndex++ { tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0) if err == nil || !errors.Is(err, syscall.EBUSY) { break } } } if err != nil { return nil, err } tun, err := 
CreateTUNFromFile(tunfile, mtu) if err == nil && name == "tun" { fname := os.Getenv("WG_TUN_NAME_FILE") if fname != "" { os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400) } } return tun, err } func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { tun := &NativeTun{ tunFile: file, events: make(chan Event, 10), errors: make(chan error, 1), } name, err := tun.Name() if err != nil { tun.tunFile.Close() return nil, err } tunIfindex, err := func() (int, error) { iface, err := net.InterfaceByName(name) if err != nil { return -1, err } return iface.Index, nil }() if err != nil { tun.tunFile.Close() return nil, err } tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC) if err != nil { tun.tunFile.Close() return nil, err } go tun.routineRouteListener(tunIfindex) currentMTU, err := tun.MTU() if err != nil || currentMTU != mtu { err = tun.setMTU(mtu) if err != nil { tun.Close() return nil, err } } return tun, nil } func (tun *NativeTun) Name() (string, error) { gostat, err := tun.tunFile.Stat() if err != nil { tun.name = "" return "", err } stat := gostat.Sys().(*syscall.Stat_t) tun.name = fmt.Sprintf("tun%d", stat.Rdev%256) return tun.name, nil } func (tun *NativeTun) File() *os.File { return tun.tunFile } func (tun *NativeTun) Events() <-chan Event { return tun.events } func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { select { case err := <-tun.errors: return 0, err default: buf := bufs[0][offset-4:] n, err := tun.tunFile.Read(buf[:]) if n < 4 { return 0, err } sizes[0] = n - 4 return 1, err } } func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { if offset < 4 { return 0, io.ErrShortBuffer } for i, buf := range bufs { buf = buf[offset-4:] buf[0] = 0x00 buf[1] = 0x00 buf[2] = 0x00 switch buf[4] >> 4 { case 4: buf[3] = unix.AF_INET case 6: buf[3] = unix.AF_INET6 default: return i, unix.EAFNOSUPPORT } if _, err := tun.tunFile.Write(buf); err != nil { return 
i, err } } return len(bufs), nil } func (tun *NativeTun) Close() error { var err1, err2 error tun.closeOnce.Do(func() { err1 = tun.tunFile.Close() if tun.routeSocket != -1 { unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) err2 = unix.Close(tun.routeSocket) tun.routeSocket = -1 } else if tun.events != nil { close(tun.events) } }) if err1 != nil { return err1 } return err2 } func (tun *NativeTun) setMTU(n int) error { // open datagram socket var fd int fd, err := unix.Socket( unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) if err != nil { return err } defer unix.Close(fd) // do ioctl call var ifr ifreq_mtu copy(ifr.Name[:], tun.name) ifr.MTU = uint32(n) _, _, errno := unix.Syscall( unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCSIFMTU), uintptr(unsafe.Pointer(&ifr)), ) if errno != 0 { return fmt.Errorf("failed to set MTU on %s", tun.name) } return nil } func (tun *NativeTun) MTU() (int, error) { // open datagram socket fd, err := unix.Socket( unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) if err != nil { return 0, err } defer unix.Close(fd) // do ioctl call var ifr ifreq_mtu copy(ifr.Name[:], tun.name) _, _, errno := unix.Syscall( unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCGIFMTU), uintptr(unsafe.Pointer(&ifr)), ) if errno != 0 { return 0, fmt.Errorf("failed to get MTU on %s", tun.name) } return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil } func (tun *NativeTun) BatchSize() int { return 1 } wireguard-go-0.0.20231211/tun/tun_windows.go000066400000000000000000000134501466306003400203610ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ package tun import ( "errors" "fmt" "os" "sync" "sync/atomic" "time" _ "unsafe" "golang.org/x/sys/windows" "golang.zx2c4.com/wintun" ) const ( rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) spinloopRateThreshold = 800000000 / 8 // 800mbps spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s ) type rateJuggler struct { current atomic.Uint64 nextByteCount atomic.Uint64 nextStartTime atomic.Int64 changing atomic.Bool } type NativeTun struct { wt *wintun.Adapter name string handle windows.Handle rate rateJuggler session wintun.Session readWait windows.Handle events chan Event running sync.WaitGroup closeOnce sync.Once close atomic.Bool forcedMTU int outSizes []int } var ( WintunTunnelType = "WireGuard" WintunStaticRequestedGUID *windows.GUID ) //go:linkname procyield runtime.procyield func procyield(cycles uint32) //go:linkname nanotime runtime.nanotime func nanotime() int64 // CreateTUN creates a Wintun interface with the given name. Should a Wintun // interface with the same name exist, it is reused. func CreateTUN(ifname string, mtu int) (Device, error) { return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) } // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and // a requested GUID. Should a Wintun interface with the same name exist, it is reused. 
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) if err != nil { return nil, fmt.Errorf("Error creating interface: %w", err) } forcedMTU := 1420 if mtu > 0 { forcedMTU = mtu } tun := &NativeTun{ wt: wt, name: ifname, handle: windows.InvalidHandle, events: make(chan Event, 10), forcedMTU: forcedMTU, } tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB if err != nil { tun.wt.Close() close(tun.events) return nil, fmt.Errorf("Error starting session: %w", err) } tun.readWait = tun.session.ReadWaitEvent() return tun, nil } func (tun *NativeTun) Name() (string, error) { return tun.name, nil } func (tun *NativeTun) File() *os.File { return nil } func (tun *NativeTun) Events() <-chan Event { return tun.events } func (tun *NativeTun) Close() error { var err error tun.closeOnce.Do(func() { tun.close.Store(true) windows.SetEvent(tun.readWait) tun.running.Wait() tun.session.End() if tun.wt != nil { tun.wt.Close() } close(tun.events) }) return err } func (tun *NativeTun) MTU() (int, error) { return tun.forcedMTU, nil } // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. func (tun *NativeTun) ForceMTU(mtu int) { if tun.close.Load() { return } update := tun.forcedMTU != mtu tun.forcedMTU = mtu if update { tun.events <- EventMTUUpdate } } func (tun *NativeTun) BatchSize() int { // TODO: implement batching with wintun return 1 } // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. 
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { tun.running.Add(1) defer tun.running.Done() retry: if tun.close.Load() { return 0, os.ErrClosed } start := nanotime() shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 for { if tun.close.Load() { return 0, os.ErrClosed } packet, err := tun.session.ReceivePacket() switch err { case nil: n := copy(bufs[0][offset:], packet) sizes[0] = n tun.session.ReleaseReceivePacket(packet) tun.rate.update(uint64(n)) return 1, nil case windows.ERROR_NO_MORE_ITEMS: if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { windows.WaitForSingleObject(tun.readWait, windows.INFINITE) goto retry } procyield(1) continue case windows.ERROR_HANDLE_EOF: return 0, os.ErrClosed case windows.ERROR_INVALID_DATA: return 0, errors.New("Send ring corrupt") } return 0, fmt.Errorf("Read failed: %w", err) } } func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { tun.running.Add(1) defer tun.running.Done() if tun.close.Load() { return 0, os.ErrClosed } for i, buf := range bufs { packetSize := len(buf) - offset tun.rate.update(uint64(packetSize)) packet, err := tun.session.AllocateSendPacket(packetSize) switch err { case nil: // TODO: Explore options to eliminate this copy. copy(packet, buf[offset:]) tun.session.SendPacket(packet) continue case windows.ERROR_HANDLE_EOF: return i, os.ErrClosed case windows.ERROR_BUFFER_OVERFLOW: continue // Dropping when ring is full. default: return i, fmt.Errorf("Write failed: %w", err) } } return len(bufs), nil } // LUID returns Windows interface instance ID. func (tun *NativeTun) LUID() uint64 { tun.running.Add(1) defer tun.running.Done() if tun.close.Load() { return 0 } return tun.wt.LUID() } // RunningVersion returns the running version of the Wintun driver. 
func (tun *NativeTun) RunningVersion() (version uint32, err error) { return wintun.RunningVersion() } func (rate *rateJuggler) update(packetLen uint64) { now := nanotime() total := rate.nextByteCount.Add(packetLen) period := uint64(now - rate.nextStartTime.Load()) if period >= rateMeasurementGranularity { if !rate.changing.CompareAndSwap(false, true) { return } rate.nextStartTime.Store(now) rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period) rate.nextByteCount.Store(0) rate.changing.Store(false) } } wireguard-go-0.0.20231211/tun/tuntest/000077500000000000000000000000001466306003400171555ustar00rootroot00000000000000wireguard-go-0.0.20231211/tun/tuntest/tuntest.go000066400000000000000000000067621466306003400212250ustar00rootroot00000000000000/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tuntest import ( "encoding/binary" "io" "net/netip" "os" "golang.zx2c4.com/wireguard/tun" ) func Ping(dst, src netip.Addr) []byte { localPort := uint16(1337) seq := uint16(0) payload := make([]byte, 4) binary.BigEndian.PutUint16(payload[0:], localPort) binary.BigEndian.PutUint16(payload[2:], seq) return genICMPv4(payload, dst, src) } // Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071. 
func checksum(buf []byte, initial uint16) uint16 { v := uint32(initial) for i := 0; i < len(buf)-1; i += 2 { v += uint32(binary.BigEndian.Uint16(buf[i:])) } if len(buf)%2 == 1 { v += uint32(buf[len(buf)-1]) << 8 } for v > 0xffff { v = (v >> 16) + (v & 0xffff) } return ^uint16(v) } func genICMPv4(payload []byte, dst, src netip.Addr) []byte { const ( icmpv4ProtocolNumber = 1 icmpv4Echo = 8 icmpv4ChecksumOffset = 2 icmpv4Size = 8 ipv4Size = 20 ipv4TotalLenOffset = 2 ipv4ChecksumOffset = 10 ttl = 65 headerSize = ipv4Size + icmpv4Size ) pkt := make([]byte, headerSize+len(payload)) ip := pkt[0:ipv4Size] icmpv4 := pkt[ipv4Size : ipv4Size+icmpv4Size] // https://tools.ietf.org/html/rfc792 icmpv4[0] = icmpv4Echo // type icmpv4[1] = 0 // code chksum := ^checksum(icmpv4, checksum(payload, 0)) binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum) // https://tools.ietf.org/html/rfc760 section 3.1 length := uint16(len(pkt)) ip[0] = (4 << 4) | (ipv4Size / 4) binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) ip[8] = ttl ip[9] = icmpv4ProtocolNumber copy(ip[12:], src.AsSlice()) copy(ip[16:], dst.AsSlice()) chksum = ^checksum(ip[:], 0) binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) copy(pkt[headerSize:], payload) return pkt } type ChannelTUN struct { Inbound chan []byte // incoming packets, closed on TUN close Outbound chan []byte // outbound packets, blocks forever on TUN close closed chan struct{} events chan tun.Event tun chTun } func NewChannelTUN() *ChannelTUN { c := &ChannelTUN{ Inbound: make(chan []byte), Outbound: make(chan []byte), closed: make(chan struct{}), events: make(chan tun.Event, 1), } c.tun.c = c c.events <- tun.EventUp return c } func (c *ChannelTUN) TUN() tun.Device { return &c.tun } type chTun struct { c *ChannelTUN } func (t *chTun) File() *os.File { return nil } func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) { select { case <-t.c.closed: return 0, os.ErrClosed case msg := <-t.c.Outbound: n := 
copy(packets[0][offset:], msg) sizes[0] = n return 1, nil } } // Write is called by the wireguard device to deliver a packet for routing. func (t *chTun) Write(packets [][]byte, offset int) (int, error) { if offset == -1 { close(t.c.closed) close(t.c.events) return 0, io.EOF } for i, data := range packets { msg := make([]byte, len(data)-offset) copy(msg, data[offset:]) select { case <-t.c.closed: return i, os.ErrClosed case t.c.Inbound <- msg: } } return len(packets), nil } func (t *chTun) BatchSize() int { return 1 } const DefaultMTU = 1420 func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } func (t *chTun) Events() <-chan tun.Event { return t.c.events } func (t *chTun) Close() error { t.Write(nil, -1) return nil } wireguard-go-0.0.20231211/version.go000066400000000000000000000000551466306003400166550ustar00rootroot00000000000000package main const Version = "0.0.20230223"