pax_global_header00006660000000000000000000000064151170447020014513gustar00rootroot0000000000000052 comment=a42e726bb32f0e7e6ebef8200913ddab5eebd2b7 ice-4.1.0/000077500000000000000000000000001511704470200122555ustar00rootroot00000000000000ice-4.1.0/.github/000077500000000000000000000000001511704470200136155ustar00rootroot00000000000000ice-4.1.0/.github/.gitignore000066400000000000000000000001561511704470200156070ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT .goassets ice-4.1.0/.github/fetch-scripts.sh000077500000000000000000000016001511704470200167270ustar00rootroot00000000000000#!/bin/sh # # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT set -eu SCRIPT_PATH="$(realpath "$(dirname "$0")")" GOASSETS_PATH="${SCRIPT_PATH}/.goassets" GOASSETS_REF=${GOASSETS_REF:-master} if [ -d "${GOASSETS_PATH}" ]; then if ! git -C "${GOASSETS_PATH}" diff --exit-code; then echo "${GOASSETS_PATH} has uncommitted changes" >&2 exit 1 fi git -C "${GOASSETS_PATH}" fetch origin git -C "${GOASSETS_PATH}" checkout ${GOASSETS_REF} git -C "${GOASSETS_PATH}" reset --hard origin/${GOASSETS_REF} else git clone -b ${GOASSETS_REF} https://github.com/pion/.goassets.git "${GOASSETS_PATH}" fi ice-4.1.0/.github/install-hooks.sh000077500000000000000000000012421511704470200167420ustar00rootroot00000000000000#!/bin/sh # # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT SCRIPT_PATH="$(realpath "$(dirname "$0")")" . 
${SCRIPT_PATH}/fetch-scripts.sh cp "${GOASSETS_PATH}/hooks/commit-msg.sh" "${SCRIPT_PATH}/../.git/hooks/commit-msg" cp "${GOASSETS_PATH}/hooks/pre-commit.sh" "${SCRIPT_PATH}/../.git/hooks/pre-commit" cp "${GOASSETS_PATH}/hooks/pre-push.sh" "${SCRIPT_PATH}/../.git/hooks/pre-push" ice-4.1.0/.github/workflows/000077500000000000000000000000001511704470200156525ustar00rootroot00000000000000ice-4.1.0/.github/workflows/api.yaml000066400000000000000000000011141511704470200173040ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: API on: pull_request: jobs: check: uses: pion/.goassets/.github/workflows/api.reusable.yml@master ice-4.1.0/.github/workflows/codeql-analysis.yml000066400000000000000000000013201511704470200214610ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. 
# # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: CodeQL on: workflow_dispatch: schedule: - cron: '23 5 * * 0' pull_request: branches: - master paths: - '**.go' jobs: analyze: uses: pion/.goassets/.github/workflows/codeql-analysis.reusable.yml@master ice-4.1.0/.github/workflows/fuzz.yaml000066400000000000000000000013421511704470200175340ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: Fuzz on: push: branches: - master schedule: - cron: "0 */8 * * *" jobs: fuzz: uses: pion/.goassets/.github/workflows/fuzz.reusable.yml@master with: go-version: "1.25" # auto-update/latest-go-version fuzz-time: "60s" ice-4.1.0/.github/workflows/lint.yaml000066400000000000000000000011151511704470200175020ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: Lint on: pull_request: jobs: lint: uses: pion/.goassets/.github/workflows/lint.reusable.yml@master ice-4.1.0/.github/workflows/release.yml000066400000000000000000000012501511704470200200130ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. 
# If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: Release on: push: tags: - 'v*' jobs: release: uses: pion/.goassets/.github/workflows/release.reusable.yml@master with: go-version: "1.25" # auto-update/latest-go-version ice-4.1.0/.github/workflows/renovate-go-sum-fix.yaml000066400000000000000000000012671511704470200223600ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: Fix go.sum on: push: branches: - renovate/* jobs: fix: uses: pion/.goassets/.github/workflows/renovate-go-sum-fix.reusable.yml@master secrets: token: ${{ secrets.PIONBOT_PRIVATE_KEY }} ice-4.1.0/.github/workflows/reuse.yml000066400000000000000000000011511511704470200175160ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. 
# # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: REUSE Compliance Check on: push: pull_request: jobs: lint: uses: pion/.goassets/.github/workflows/reuse.reusable.yml@master ice-4.1.0/.github/workflows/test.yaml000066400000000000000000000033271511704470200175220ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: Test on: push: branches: - master pull_request: jobs: test: uses: pion/.goassets/.github/workflows/test.reusable.yml@master strategy: matrix: go: ["1.25", "1.24"] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} secrets: inherit test-i386: uses: pion/.goassets/.github/workflows/test-i386.reusable.yml@master strategy: matrix: go: ["1.25", "1.24"] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} test-windows: uses: pion/.goassets/.github/workflows/test-windows.reusable.yml@master strategy: matrix: go: ["1.25", "1.24"] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} test-macos: uses: pion/.goassets/.github/workflows/test-macos.reusable.yml@master strategy: matrix: go: ["1.25", "1.24"] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} test-wasm: uses: pion/.goassets/.github/workflows/test-wasm.reusable.yml@master with: go-version: "1.25" # auto-update/latest-go-version secrets: inherit ice-4.1.0/.github/workflows/tidy-check.yaml000066400000000000000000000013021511704470200205560ustar00rootroot00000000000000# # DO NOT EDIT 
THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: Go mod tidy on: pull_request: push: branches: - master jobs: tidy: uses: pion/.goassets/.github/workflows/tidy-check.reusable.yml@master with: go-version: "1.25" # auto-update/latest-go-version ice-4.1.0/.gitignore000066400000000000000000000006321511704470200142460ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT ### JetBrains IDE ### ##################### .idea/ ### Emacs Temporary Files ### ############################# *~ ### Folders ### ############### bin/ vendor/ node_modules/ ### Files ### ############# *.ivf *.ogg tags cover.out *.sw[poe] *.wasm examples/sfu-ws/cert.pem examples/sfu-ws/key.pem wasm_exec.js ice-4.1.0/.golangci.yml000066400000000000000000000202661511704470200146470ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT version: "2" linters: enable: - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - bidichk # Checks for dangerous unicode character sequences - bodyclose # checks whether HTTP response body is closed successfully - containedctx # containedctx is a linter that detects struct contained context.Context field - contextcheck # check the function whether use a non-inherited context - cyclop # checks function and package cyclomatic complexity - decorder # check declaration order and count of types, constants, variables and functions - dogsled # Checks assignments with too many blank identifiers (e.g. 
x, _, _, _, := f()) - dupl # Tool for code clone detection - durationcheck # check for two durations multiplied together - err113 # Golang linter to check the errors handling expressions - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - exhaustive # check exhaustiveness of enum switch statements - forbidigo # Forbids identifiers - forcetypeassert # finds forced type assertions - gochecknoglobals # Checks that no globals are present in Go code - gocognit # Computes and checks the cognitive complexity of functions - goconst # Finds repeated strings that could be replaced by a constant - gocritic # The most opinionated Go source code linter - gocyclo # Computes and checks the cyclomatic complexity of functions - godot # Check if comments end in a period - godox # Tool for detection of FIXME, TODO and other comment keywords - goheader # Checks is file header matches to pattern - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - goprintffuncname # Checks that printf-like functions are named with `f` at the end - gosec # Inspects source code for security problems - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - grouper # An analyzer to analyze expression groups. 
- importas # Enforces consistent import aliases - ineffassign # Detects when assignments to existing variables are not used - lll # Reports long lines - maintidx # maintidx measures the maintainability index of each function. - makezero # Finds slice declarations with non-zero initial length - misspell # Finds commonly misspelled English words in comments - nakedret # Finds naked returns in functions greater than a specified function length - nestif # Reports deeply nested if statements - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - noctx # noctx finds sending http request without context.Context - predeclared # find code that shadows one of Go's predeclared identifiers - revive # golint replacement, finds style mistakes - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - tagliatelle # Checks the struct tags. - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters - unused # Checks Go code for unused constants, variables, functions and types - varnamelen # checks that the length of a variable's name matches its scope - wastedassign # wastedassign finds wasted assignment statements - whitespace # Tool for detection of leading and trailing whitespace disable: - depguard # Go linter that checks if package imports are in a list of acceptable packages - funlen # Tool for detection of long functions - gochecknoinits # Checks that no init functions are present in Go code - gomodguard # Allow and block list linter for direct Go module dependencies. 
This is different from depguard where there are different block types for example version constraints and module recommendations. - interfacebloat # A linter that checks length of interface. - ireturn # Accept Interfaces, Return Concrete Types - mnd # An analyzer to detect magic numbers - nolintlint # Reports ill-formed or insufficient nolint directives - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test - prealloc # Finds slice declarations that could potentially be preallocated - promlinter # Check Prometheus metrics naming via promlint - rowserrcheck # checks whether Err of rows is checked successfully - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - testpackage # linter that makes you use a separate _test package - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - wrapcheck # Checks that errors returned from external packages are wrapped - wsl # Whitespace Linter - Forces you to use empty lines! settings: staticcheck: checks: - all - -QF1008 # "could remove embedded field", to keep it explicit! - -QF1003 # "could use tagged switch on enum", Cases conflicts with exhaustive! 
exhaustive: default-signifies-exhaustive: true forbidigo: forbid: - pattern: ^fmt.Print(f|ln)?$ - pattern: ^log.(Panic|Fatal|Print)(f|ln)?$ - pattern: ^os.Exit$ - pattern: ^panic$ - pattern: ^print(ln)?$ - pattern: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ pkg: ^testing$ msg: use testify/assert instead analyze-types: true gomodguard: blocked: modules: - github.com/pkg/errors: recommendations: - errors govet: enable: - shadow revive: rules: # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility - name: use-any severity: warning disabled: false misspell: locale: US varnamelen: max-distance: 12 min-name-length: 2 ignore-type-assert-ok: true ignore-map-index-ok: true ignore-chan-recv-ok: true ignore-decls: - i int - n int - w io.Writer - r io.Reader - b []byte exclusions: generated: lax rules: - linters: - forbidigo - gocognit path: (examples|main\.go) - linters: - gocognit path: _test\.go - linters: - forbidigo path: cmd formatters: enable: - gci # Gci control golang package import order and make it always deterministic. - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - gofumpt # Gofumpt checks whether code was gofumpt-ed. - goimports # Goimports does everything that gofmt does. 
Additionally it checks unused imports exclusions: generated: lax ice-4.1.0/.goreleaser.yml000066400000000000000000000001711511704470200152050ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT builds: - skip: true ice-4.1.0/.reuse/000077500000000000000000000000001511704470200134565ustar00rootroot00000000000000ice-4.1.0/.reuse/dep5000066400000000000000000000011141511704470200142330ustar00rootroot00000000000000Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ Upstream-Name: Pion Source: https://github.com/pion/ Files: README.md DESIGN.md **/README.md AUTHORS.txt renovate.json go.mod go.sum **/go.mod **/go.sum .eslintrc.json package.json examples.json sfu-ws/flutter/.gitignore sfu-ws/flutter/pubspec.yaml c-data-channels/webrtc.h examples/examples.json yarn.lock Copyright: 2023 The Pion community License: MIT Files: testdata/seed/* testdata/fuzz/* **/testdata/fuzz/* api/*.txt Copyright: 2023 The Pion community License: CC0-1.0 ice-4.1.0/LICENSE000066400000000000000000000021051511704470200132600ustar00rootroot00000000000000MIT License Copyright (c) 2023 The Pion community Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ice-4.1.0/LICENSES/000077500000000000000000000000001511704470200134625ustar00rootroot00000000000000ice-4.1.0/LICENSES/MIT.txt000066400000000000000000000020661511704470200146600ustar00rootroot00000000000000MIT License Copyright (c) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ice-4.1.0/README.md000066400000000000000000000042311511704470200135340ustar00rootroot00000000000000


Pion ICE

A Go implementation of ICE

Pion ICE join us on Discord Follow us on Bluesky
GitHub Workflow Status Go Reference Coverage Status Go Report Card License: MIT


### Roadmap The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. ### Community Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. We are always looking to support **your projects**. Please reach out if you have something to build! If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) ### Contributing Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible ### License MIT License - see [LICENSE](LICENSE) for full text ice-4.1.0/active_tcp.go000066400000000000000000000073061511704470200147330ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "context" "io" "net" "net/netip" "sync/atomic" "time" "github.com/pion/logging" "github.com/pion/transport/v3/packetio" ) type activeTCPConn struct { readBuffer, writeBuffer *packetio.Buffer localAddr, remoteAddr atomic.Value closed atomic.Bool } func newActiveTCPConn( ctx context.Context, localAddress string, remoteAddress netip.AddrPort, log logging.LeveledLogger, ) (a *activeTCPConn) { a = &activeTCPConn{ readBuffer: packetio.NewBuffer(), writeBuffer: packetio.NewBuffer(), } laddr, err := getTCPAddrOnInterface(localAddress) if err != nil { a.closed.Store(true) log.Infof("Failed to dial TCP address %s: %v", remoteAddress, err) return a } a.localAddr.Store(laddr) go func() { defer func() { a.closed.Store(true) }() dialer := &net.Dialer{ LocalAddr: laddr, } conn, err := dialer.DialContext(ctx, "tcp", remoteAddress.String()) if err != nil { log.Infof("Failed to dial TCP address %s: %v", remoteAddress, err) 
return } a.remoteAddr.Store(conn.RemoteAddr()) go func() { buff := make([]byte, receiveMTU) for !a.closed.Load() { n, err := readStreamingPacket(conn, buff) if err != nil { log.Infof("Failed to read streaming packet: %s", err) break } if _, err := a.readBuffer.Write(buff[:n]); err != nil { log.Infof("Failed to write to buffer: %s", err) break } } }() buff := make([]byte, receiveMTU) for !a.closed.Load() { n, err := a.writeBuffer.Read(buff) if err != nil { log.Infof("Failed to read from buffer: %s", err) break } if _, err = writeStreamingPacket(conn, buff[:n]); err != nil { log.Infof("Failed to write streaming packet: %s", err) break } } if err := conn.Close(); err != nil { log.Infof("Failed to close connection: %s", err) } }() return a } func (a *activeTCPConn) ReadFrom(buff []byte) (n int, srcAddr net.Addr, err error) { if a.closed.Load() { return 0, nil, io.ErrClosedPipe } n, err = a.readBuffer.Read(buff) // RemoteAddr is assuredly set *after* we can read from the buffer srcAddr = a.RemoteAddr() return } func (a *activeTCPConn) WriteTo(buff []byte, _ net.Addr) (n int, err error) { if a.closed.Load() { return 0, io.ErrClosedPipe } return a.writeBuffer.Write(buff) } func (a *activeTCPConn) Close() error { a.closed.Store(true) _ = a.readBuffer.Close() _ = a.writeBuffer.Close() return nil } func (a *activeTCPConn) LocalAddr() net.Addr { if v, ok := a.localAddr.Load().(*net.TCPAddr); ok { return v } return &net.TCPAddr{} } // RemoteAddr returns the remote address of the connection which is only // set once a background goroutine has successfully dialed. That means // this may return ":0" for the address prior to that happening. If this // becomes an issue, we can introduce a synchronization point between Dial // and these methods. 
func (a *activeTCPConn) RemoteAddr() net.Addr { if v, ok := a.remoteAddr.Load().(*net.TCPAddr); ok { return v } return &net.TCPAddr{} } func (a *activeTCPConn) SetDeadline(time.Time) error { return io.EOF } func (a *activeTCPConn) SetReadDeadline(time.Time) error { return io.EOF } func (a *activeTCPConn) SetWriteDeadline(time.Time) error { return io.EOF } func getTCPAddrOnInterface(address string) (*net.TCPAddr, error) { addr, err := net.ResolveTCPAddr("tcp", address) if err != nil { return nil, err } l, err := net.ListenTCP("tcp", addr) if err != nil { return nil, err } defer func() { _ = l.Close() }() tcpAddr, ok := l.Addr().(*net.TCPAddr) if !ok { return nil, errInvalidAddress } return tcpAddr, nil } ice-4.1.0/active_tcp_test.go000066400000000000000000000272021511704470200157670ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "context" "fmt" "io" "net" "net/netip" "runtime" "sync/atomic" "testing" "time" "github.com/pion/logging" "github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/require" ) func getLocalIPAddress(t *testing.T, networkType NetworkType) netip.Addr { t.Helper() net, err := stdnet.NewNet() require.NoError(t, err) _, localAddrs, err := localInterfaces(net, problematicNetworkInterfaces, nil, []NetworkType{networkType}, false) require.NoError(t, err) require.NotEmpty(t, localAddrs) if networkType.IsIPv6() && runtime.GOOS == "darwin" { for _, addr := range localAddrs { if !addr.addr.IsLinkLocalUnicast() { return addr.addr } } t.Skip("no non-link-local IPv6 address available") } return localAddrs[0].addr } func ipv6Available(t *testing.T) bool { t.Helper() net, err := stdnet.NewNet() require.NoError(t, err) _, localAddrs, err := localInterfaces(net, problematicNetworkInterfaces, nil, []NetworkType{NetworkTypeTCP6}, false) require.NoError(t, err) if runtime.GOOS == "darwin" { for _, addr := 
range localAddrs { if !addr.addr.IsLinkLocalUnicast() { return true } } return false } return len(localAddrs) > 0 } func TestActiveTCP(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 5).Stop() const listenPort = 7686 type testCase struct { name string networkTypes []NetworkType listenIPAddress netip.Addr selectedPairNetworkType string useMDNS bool } testCases := []testCase{ { name: "TCP4 connection", networkTypes: []NetworkType{NetworkTypeTCP4}, listenIPAddress: getLocalIPAddress(t, NetworkTypeTCP4), selectedPairNetworkType: tcp, }, { name: "UDP is preferred over TCP4", // This fails some time networkTypes: supportedNetworkTypes(), listenIPAddress: getLocalIPAddress(t, NetworkTypeTCP4), selectedPairNetworkType: udp, }, } if ipv6Available(t) { testCases = append(testCases, testCase{ name: "TCP6 connection", networkTypes: []NetworkType{NetworkTypeTCP6}, listenIPAddress: getLocalIPAddress(t, NetworkTypeTCP6), selectedPairNetworkType: tcp, // if we don't use mDNS, we will very likely be filtering out location tracked ips. useMDNS: true, }, testCase{ name: "UDP is preferred over TCP6", networkTypes: supportedNetworkTypes(), listenIPAddress: getLocalIPAddress(t, NetworkTypeTCP6), selectedPairNetworkType: udp, // if we don't use mDNS, we will very likely be filtering out location tracked ips. 
useMDNS: true, }, ) } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { req := require.New(t) listener, err := net.ListenTCP("tcp", &net.TCPAddr{ IP: testCase.listenIPAddress.AsSlice(), Port: listenPort, Zone: testCase.listenIPAddress.Zone(), }) req.NoError(err) defer func() { _ = listener.Close() }() loggerFactory := logging.NewDefaultLoggerFactory() tcpMux := NewTCPMuxDefault(TCPMuxParams{ Listener: listener, Logger: loggerFactory.NewLogger("passive-ice-tcp-mux"), ReadBufferSize: 20, }) defer func() { _ = tcpMux.Close() }() req.NotNil(tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil") hostAcceptanceMinWait := 100 * time.Millisecond cfg := &AgentConfig{ TCPMux: tcpMux, CandidateTypes: []CandidateType{CandidateTypeHost}, NetworkTypes: testCase.networkTypes, LoggerFactory: loggerFactory, HostAcceptanceMinWait: &hostAcceptanceMinWait, InterfaceFilter: problematicNetworkInterfaces, IncludeLoopback: true, } if testCase.useMDNS { cfg.MulticastDNSMode = MulticastDNSModeQueryAndGather } passiveAgent, err := NewAgent(cfg) req.NoError(err) req.NotNil(passiveAgent) defer func() { req.NoError(passiveAgent.Close()) }() activeAgent, err := NewAgent(&AgentConfig{ CandidateTypes: []CandidateType{CandidateTypeHost}, NetworkTypes: testCase.networkTypes, LoggerFactory: loggerFactory, HostAcceptanceMinWait: &hostAcceptanceMinWait, InterfaceFilter: problematicNetworkInterfaces, IncludeLoopback: true, }) req.NoError(err) req.NotNil(activeAgent) defer func() { req.NoError(activeAgent.Close()) }() passiveAgentConn, activeAgenConn := connect(t, passiveAgent, activeAgent) req.NotNil(passiveAgentConn) req.NotNil(activeAgenConn) defer func() { req.NoError(activeAgenConn.Close()) req.NoError(passiveAgentConn.Close()) }() pair := passiveAgent.getSelectedPair() req.NotNil(pair) req.Equal(testCase.selectedPairNetworkType, pair.Local.NetworkType().NetworkShort()) foo := []byte("foo") _, err = passiveAgentConn.Write(foo) req.NoError(err) buffer := make([]byte, 1024) n, 
err := activeAgenConn.Read(buffer) req.NoError(err) req.Equal(foo, buffer[:n]) bar := []byte("bar") _, err = activeAgenConn.Write(bar) req.NoError(err) n, err = passiveAgentConn.Read(buffer) req.NoError(err) req.Equal(bar, buffer[:n]) }) } } // Assert that Active TCP connectivity isn't established inside. // the main thread of the Agent. func TestActiveTCP_NonBlocking(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 5).Stop() cfg := &AgentConfig{ NetworkTypes: supportedNetworkTypes(), InterfaceFilter: problematicNetworkInterfaces, } aAgent, err := NewAgent(cfg) require.NoError(t, err) defer func() { require.NoError(t, aAgent.Close()) }() bAgent, err := NewAgent(cfg) require.NoError(t, err) defer func() { require.NoError(t, bAgent.Close()) }() isConnected := make(chan any) err = aAgent.OnConnectionStateChange(func(c ConnectionState) { if c == ConnectionStateConnected { close(isConnected) } }) require.NoError(t, err) // Add a invalid ice-tcp candidate to each invalidCandidate, err := UnmarshalCandidate("1052353102 1 tcp 1675624447 192.0.2.1 8080 typ host tcptype passive") require.NoError(t, err) require.NoError(t, aAgent.AddRemoteCandidate(invalidCandidate)) require.NoError(t, bAgent.AddRemoteCandidate(invalidCandidate)) connect(t, aAgent, bAgent) <-isConnected } // Assert that we ignore remote TCP candidates when running a UDP Only Agent. 
func TestActiveTCP_Respect_NetworkTypes(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 5).Stop() tcpListener, err := net.Listen("tcp", "127.0.0.1:0") // nolint: noctx require.NoError(t, err) _, port, err := net.SplitHostPort(tcpListener.Addr().String()) require.NoError(t, err) var incomingTCPCount uint64 go func() { for { conn, listenErr := tcpListener.Accept() if listenErr != nil { return } require.NoError(t, conn.Close()) atomic.AddUint64(&incomingTCPCount, ^uint64(0)) } }() cfg := &AgentConfig{ NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6, NetworkTypeTCP6}, InterfaceFilter: problematicNetworkInterfaces, IncludeLoopback: true, } aAgent, err := NewAgent(cfg) require.NoError(t, err) defer func() { require.NoError(t, aAgent.Close()) }() bAgent, err := NewAgent(cfg) require.NoError(t, err) defer func() { require.NoError(t, bAgent.Close()) }() isConnected := make(chan any) err = aAgent.OnConnectionStateChange(func(c ConnectionState) { if c == ConnectionStateConnected { close(isConnected) } }) require.NoError(t, err) invalidCandidate, err := UnmarshalCandidate( fmt.Sprintf("1052353102 1 tcp 1675624447 127.0.0.1 %s typ host tcptype passive", port), ) require.NoError(t, err) require.NoError(t, aAgent.AddRemoteCandidate(invalidCandidate)) require.NoError(t, bAgent.AddRemoteCandidate(invalidCandidate)) connect(t, aAgent, bAgent) <-isConnected require.NoError(t, tcpListener.Close()) require.Equal(t, uint64(0), atomic.LoadUint64(&incomingTCPCount)) } func TestNewActiveTCPConn_LocalAddrError_EarlyReturn(t *testing.T) { defer test.CheckRoutines(t)() logger := logging.NewDefaultLoggerFactory().NewLogger("ice") // an invalid local address so getTCPAddrOnInterface fails at ResolveTCPAddr. 
ctx, cancel := context.WithCancel(context.Background()) defer cancel() ra := netip.MustParseAddrPort("127.0.0.1:1") a := newActiveTCPConn(ctx, "this_is_not_a_valid_addr", ra, logger) require.NotNil(t, a) require.True(t, a.closed.Load(), "should be closed on early return error") la := a.LocalAddr() require.NotNil(t, la) } func TestActiveTCPConn_ReadLoop_BufferWriteError(t *testing.T) { defer test.CheckRoutines(t)() tcpListener, err := net.Listen("tcp", "127.0.0.1:0") // nolint: noctx require.NoError(t, err) defer func() { _ = tcpListener.Close() }() ra := netip.MustParseAddrPort(tcpListener.Addr().String()) logger := logging.NewDefaultLoggerFactory().NewLogger("ice") ctx, cancel := context.WithCancel(context.Background()) defer cancel() a := newActiveTCPConn(ctx, "127.0.0.1:0", ra, logger) require.NotNil(t, a) srvConn, err := tcpListener.Accept() require.NoError(t, err) require.NoError(t, a.readBuffer.Close()) _, err = writeStreamingPacket(srvConn, []byte("ping")) require.NoError(t, err) require.NoError(t, a.Close()) require.NoError(t, srvConn.Close()) } func TestActiveTCPConn_WriteLoop_WriteStreamingError(t *testing.T) { defer test.CheckRoutines(t)() tcpListener, err := net.Listen("tcp", "127.0.0.1:0") // nolint: noctx require.NoError(t, err) defer func() { _ = tcpListener.Close() }() ra := netip.MustParseAddrPort(tcpListener.Addr().String()) logger := logging.NewDefaultLoggerFactory().NewLogger("ice") ctx, cancel := context.WithCancel(context.Background()) defer cancel() a := newActiveTCPConn(ctx, "127.0.0.1:0", ra, logger) require.NotNil(t, a) srvConn, err := tcpListener.Accept() require.NoError(t, err) require.NoError(t, srvConn.Close()) n, err := a.WriteTo([]byte("data"), nil) require.NoError(t, err) require.Equal(t, len("data"), n) require.NoError(t, a.Close()) } func TestActiveTCPConn_LocalAddr_DefaultWhenUnset(t *testing.T) { defer test.CheckRoutines(t)() ctx, cancel := context.WithCancel(context.Background()) defer cancel() invalidLocal := "127.0.0.1:65536" 
remote := netip.MustParseAddrPort("127.0.0.1:1") log := logging.NewDefaultLoggerFactory().NewLogger("ice") a := newActiveTCPConn(ctx, invalidLocal, remote, log) require.NotNil(t, a) require.True(t, a.closed.Load(), "expected early-return closed state") la := a.LocalAddr() ta, ok := la.(*net.TCPAddr) require.True(t, ok, "LocalAddr() should return *net.TCPAddr") require.Nil(t, ta.IP, "fallback *net.TCPAddr should be zero value (nil IP)") require.Equal(t, 0, ta.Port, "fallback *net.TCPAddr should be zero value (port 0)") require.Equal(t, "", ta.Zone, "fallback *net.TCPAddr should be zero value (empty zone)") } func TestActiveTCPConn_SetDeadlines_ReturnEOF(t *testing.T) { defer test.CheckRoutines(t)() ctx, cancel := context.WithCancel(context.Background()) defer cancel() invalidLocal := "127.0.0.1:65536" remote := netip.MustParseAddrPort("127.0.0.1:1") log := logging.NewDefaultLoggerFactory().NewLogger("ice") a := newActiveTCPConn(ctx, invalidLocal, remote, log) require.NotNil(t, a) require.True(t, a.closed.Load(), "expected early-return closed state") err := a.SetReadDeadline(time.Now()) require.ErrorIs(t, err, io.EOF) err = a.SetWriteDeadline(time.Now()) require.ErrorIs(t, err, io.EOF) } ice-4.1.0/addr.go000066400000000000000000000071071511704470200135230ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "fmt" "net" "net/netip" ) func addrWithOptionalZone(addr netip.Addr, zone string) netip.Addr { if zone == "" { return addr } if addr.Is6() && (addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast()) { return addr.WithZone(zone) } return addr } // parseAddrFromIface should only be used when it's known the address belongs to that interface. // e.g. it's LocalAddress on a listener. 
func parseAddrFromIface(in net.Addr, ifcName string) (netip.Addr, int, NetworkType, error) { addr, port, nt, err := parseAddr(in) if err != nil { return netip.Addr{}, 0, 0, err } if _, ok := in.(*net.IPNet); ok { // net.IPNet does not have a Zone but we provide it from the interface addr = addrWithOptionalZone(addr, ifcName) } return addr, port, nt, nil } func parseAddr(in net.Addr) (netip.Addr, int, NetworkType, error) { //nolint:cyclop switch addr := in.(type) { case *net.IPNet: ipAddr, err := ipAddrToNetIP(addr.IP, "") if err != nil { return netip.Addr{}, 0, 0, err } return ipAddr, 0, 0, nil case *net.IPAddr: ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone) if err != nil { return netip.Addr{}, 0, 0, err } return ipAddr, 0, 0, nil case *net.UDPAddr: ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone) if err != nil { return netip.Addr{}, 0, 0, err } var nt NetworkType if ipAddr.Is4() { nt = NetworkTypeUDP4 } else { nt = NetworkTypeUDP6 } return ipAddr, addr.Port, nt, nil case *net.TCPAddr: ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone) if err != nil { return netip.Addr{}, 0, 0, err } var nt NetworkType if ipAddr.Is4() { nt = NetworkTypeTCP4 } else { nt = NetworkTypeTCP6 } return ipAddr, addr.Port, nt, nil default: return netip.Addr{}, 0, 0, addrParseError{in} } } type addrParseError struct { addr net.Addr } func (e addrParseError) Error() string { return fmt.Sprintf("do not know how to parse address type %T", e.addr) } type ipConvertError struct { ip []byte } func (e ipConvertError) Error() string { return fmt.Sprintf("failed to convert IP '%s' to netip.Addr", e.ip) } func ipAddrToNetIP(ip []byte, zone string) (netip.Addr, error) { netIPAddr, ok := netip.AddrFromSlice(ip) if !ok { return netip.Addr{}, ipConvertError{ip} } // we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable. 
netIPAddr = netIPAddr.Unmap() netIPAddr = addrWithOptionalZone(netIPAddr, zone) return netIPAddr, nil } func createAddr(network NetworkType, ip netip.Addr, port int) net.Addr { switch { case network.IsTCP(): return &net.TCPAddr{IP: ip.AsSlice(), Port: port, Zone: ip.Zone()} default: return &net.UDPAddr{IP: ip.AsSlice(), Port: port, Zone: ip.Zone()} } } func addrEqual(a, b net.Addr) bool { aIP, aPort, aType, aErr := parseAddr(a) if aErr != nil { return false } bIP, bPort, bType, bErr := parseAddr(b) if bErr != nil { return false } return aType == bType && aIP.Compare(bIP) == 0 && aPort == bPort } // AddrPort is an IP and a port number. type AddrPort [18]byte func toAddrPort(addr net.Addr) AddrPort { var ap AddrPort switch addr := addr.(type) { case *net.UDPAddr: copy(ap[:16], addr.IP.To16()) ap[16] = uint8(addr.Port >> 8) //nolint:gosec // G115 false positive ap[17] = uint8(addr.Port) //nolint:gosec // G115 false positive case *net.TCPAddr: copy(ap[:16], addr.IP.To16()) ap[16] = uint8(addr.Port >> 8) //nolint:gosec // G115 false positive ap[17] = uint8(addr.Port) //nolint:gosec // G115 false positive } return ap } ice-4.1.0/addr_test.go000066400000000000000000000072251511704470200145630ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "net" "net/netip" "testing" "github.com/stretchr/testify/require" ) // A net.Addr type that parseAddr doesn't handle. 
type unknownAddr struct{} func (unknownAddr) Network() string { return "unknown" } func (unknownAddr) String() string { return "unknown-addr" } func TestParseAddrFromIface_ErrFromParseAddr(t *testing.T) { in := unknownAddr{} ip, port, nt, err := parseAddrFromIface(in, "eth0") require.Error(t, err, "expected error from parseAddr for unknown net.Addr type") require.Zero(t, port) require.Zero(t, nt) require.True(t, !ip.IsValid(), "ip should be zero value when error is returned") } func TestParseAddr_ErrorBranches(t *testing.T) { t.Run("IPNet invalid IP -> error", func(t *testing.T) { // length 1 slice -> ipAddrToNetIP fails _, _, _, err := parseAddr(&net.IPNet{IP: net.IP{1}}) var convErr ipConvertError require.ErrorAs(t, err, &convErr) }) t.Run("IPAddr invalid IP -> error", func(t *testing.T) { _, _, _, err := parseAddr(&net.IPAddr{IP: net.IP{1}, Zone: "eth0"}) var convErr ipConvertError require.ErrorAs(t, err, &convErr) }) t.Run("UDPAddr invalid IP -> error", func(t *testing.T) { _, _, _, err := parseAddr(&net.UDPAddr{IP: net.IP{1}, Port: 3478}) var convErr ipConvertError require.ErrorAs(t, err, &convErr) }) t.Run("TCPAddr invalid IP -> error", func(t *testing.T) { _, _, _, err := parseAddr(&net.TCPAddr{IP: net.IP{1}, Port: 3478}) var convErr ipConvertError require.ErrorAs(t, err, &convErr) }) t.Run("Unknown net.Addr type -> addrParseError", func(t *testing.T) { _, _, _, err := parseAddr(unknownAddr{}) var ap addrParseError require.ErrorAs(t, err, &ap) }) } func TestParseAddr_IPAddr_Success(t *testing.T) { ip := net.ParseIP("fe80::1") require.NotNil(t, ip) gotIP, port, nt, err := parseAddr(&net.IPAddr{IP: ip, Zone: "lo0"}) require.NoError(t, err) require.Equal(t, 0, port) require.Equal(t, NetworkType(0), nt) require.True(t, gotIP.Is6()) require.Equal(t, "lo0", gotIP.Zone()) require.Equal(t, 0, gotIP.Compare(netip.MustParseAddr("fe80::1%lo0").Unmap())) } func TestAddrParseError_Error(t *testing.T) { e := addrParseError{addr: &net.TCPAddr{}} require.Equal(t, "do not 
know how to parse address type *net.TCPAddr", e.Error(), ) } func TestIPConvertError_Error(t *testing.T) { e := ipConvertError{ip: []byte("bad-ip")} require.Equal(t, "failed to convert IP 'bad-ip' to netip.Addr", e.Error(), ) } func TestIPAddrToNetIP_Error_InvalidBytes(t *testing.T) { bad := []byte{1} // invalid length -> AddrFromSlice returns ok=false got, err := ipAddrToNetIP(bad, "") require.Equal(t, netip.Addr{}, got, "should return zero addr on error") require.Error(t, err) require.IsType(t, ipConvertError{}, err) require.Contains(t, err.Error(), "failed to convert IP") } func TestIPAddrToNetIP_OK_IPv4(t *testing.T) { ipv4 := []byte{1, 2, 3, 4} got, err := ipAddrToNetIP(ipv4, "") require.NoError(t, err) require.True(t, got.Is4()) want := netip.AddrFrom4([4]byte{1, 2, 3, 4}) require.Equal(t, want, got) } func TestAddrEqual_FirstParseError(t *testing.T) { a := unknownAddr{} b := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 9999} require.False(t, addrEqual(a, b)) } func TestAddrEqual_SecondParseError(t *testing.T) { a := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 9999} b := unknownAddr{} require.False(t, addrEqual(a, b)) } func TestAddrEqual_SameTypeIPPort(t *testing.T) { a := &net.UDPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 4242} b := &net.UDPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 4242} require.True(t, addrEqual(a, b)) } ice-4.1.0/agent.go000066400000000000000000001452111511704470200137060ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package ice implements the Interactive Connectivity Establishment (ICE) // protocol defined in rfc5245. 
package ice import ( "context" "fmt" "math" "net" "net/netip" "strings" "sync" "sync/atomic" "time" stunx "github.com/pion/ice/v4/internal/stun" "github.com/pion/ice/v4/internal/taskloop" "github.com/pion/logging" "github.com/pion/mdns/v2" "github.com/pion/stun/v3" "github.com/pion/transport/v3" "github.com/pion/transport/v3/packetio" "github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/vnet" "github.com/pion/turn/v4" "golang.org/x/net/proxy" ) type bindingRequest struct { timestamp time.Time transactionID [stun.TransactionIDSize]byte destination net.Addr isUseCandidate bool nominationValue *uint32 // Tracks nomination value for renomination requests } // Agent represents the ICE agent. type Agent struct { loop *taskloop.Loop onConnectionStateChangeHdlr atomic.Value // func(ConnectionState) onSelectedCandidatePairChangeHdlr atomic.Value // func(Candidate, Candidate) onCandidateHdlr atomic.Value // func(Candidate) onConnected chan struct{} onConnectedOnce sync.Once // Force candidate to be contacted immediately (instead of waiting for task ticker) forceCandidateContact chan bool tieBreaker uint64 lite bool connectionState ConnectionState gatheringState GatheringState mDNSMode MulticastDNSMode mDNSName string mDNSConn *mdns.Conn muHaveStarted sync.Mutex startedCh <-chan struct{} startedFn func() isControlling atomic.Bool maxBindingRequests uint16 hostAcceptanceMinWait time.Duration srflxAcceptanceMinWait time.Duration prflxAcceptanceMinWait time.Duration relayAcceptanceMinWait time.Duration stunGatherTimeout time.Duration tcpPriorityOffset uint16 disableActiveTCP bool portMin uint16 portMax uint16 candidateTypes []CandidateType // How long connectivity checks can fail before the ICE Agent // goes to disconnected disconnectedTimeout time.Duration // How long connectivity checks can fail before the ICE Agent // goes to failed failedTimeout time.Duration // How often should we send keepalive packets? 
// 0 means never keepaliveInterval time.Duration // How often should we run our internal taskLoop to check for state changes when connecting checkInterval time.Duration localUfrag string localPwd string localCandidates map[NetworkType][]Candidate remoteUfrag string remotePwd string remoteCandidates map[NetworkType][]Candidate checklist []*CandidatePair selectorLock sync.RWMutex selector pairCandidateSelector selectedPair atomic.Value // *CandidatePair urls []*stun.URI networkTypes []NetworkType addressRewriteRules []AddressRewriteRule buf *packetio.Buffer // LRU of outbound Binding request Transaction IDs pendingBindingRequests []bindingRequest // Address rewrite (1:1) IP mapping addressRewriteMapper *addressRewriteMapper // Callback that allows user to implement custom behavior // for STUN Binding Requests userBindingRequestHandler func(m *stun.Message, local, remote Candidate, pair *CandidatePair) bool gatherCandidateCancel func() gatherCandidateDone chan struct{} connectionStateNotifier *handlerNotifier candidateNotifier *handlerNotifier selectedCandidatePairNotifier *handlerNotifier loggerFactory logging.LoggerFactory log logging.LeveledLogger net transport.Net tcpMux TCPMux udpMux UDPMux udpMuxSrflx UniversalUDPMux interfaceFilter func(string) (keep bool) ipFilter func(net.IP) (keep bool) includeLoopback bool insecureSkipVerify bool proxyDialer proxy.Dialer enableUseCandidateCheckPriority bool // Renomination support enableRenomination bool nominationValueGenerator func() uint32 nominationAttribute stun.AttrType // Continual gathering support continualGatheringPolicy ContinualGatheringPolicy networkMonitorInterval time.Duration lastKnownInterfaces map[string]netip.Addr // map[iface+ip] for deduplication // Automatic renomination automaticRenomination bool renominationInterval time.Duration lastRenominationTime time.Time turnClientFactory func(*turn.ClientConfig) (turnClient, error) } // NewAgent creates a new Agent. 
// Deprecated: use NewAgentWithOptions instead. func NewAgent(config *AgentConfig) (*Agent, error) { return newAgentFromConfig(config) } // NewAgentWithOptions creates a new Agent with options only. func NewAgentWithOptions(opts ...AgentOption) (*Agent, error) { return newAgentFromConfig(&AgentConfig{}, opts...) } func newAgentFromConfig(config *AgentConfig, opts ...AgentOption) (*Agent, error) { if config == nil { config = &AgentConfig{} } agent, err := createAgentBase(config) if err != nil { return nil, err } agent.localUfrag = config.LocalUfrag agent.localPwd = config.LocalPwd if config.NAT1To1IPs != nil { if err := validateLegacyNAT1To1IPs(config.NAT1To1IPs); err != nil { return nil, err } typ := CandidateTypeHost if config.NAT1To1IPCandidateType != CandidateTypeUnspecified { typ = config.NAT1To1IPCandidateType } rules, err := legacyNAT1To1Rules(config.NAT1To1IPs, typ) if err != nil { return nil, err } agent.addressRewriteRules = rules } return newAgentWithConfig(agent, opts...) } func validateLegacyNAT1To1IPs(ips []string) error { var hasIPv4CatchAll, hasIPv6CatchAll bool for _, mapping := range ips { trimmed := strings.TrimSpace(mapping) var err error hasIPv4CatchAll, hasIPv6CatchAll, err = validateLegacyNAT1To1Entry(trimmed, hasIPv4CatchAll, hasIPv6CatchAll) if err != nil { return err } } return nil } func validateLegacyNAT1To1Entry(mapping string, hasIPv4CatchAll, hasIPv6CatchAll bool) (bool, bool, error) { if mapping == "" { return hasIPv4CatchAll, hasIPv6CatchAll, nil } parts := strings.Split(mapping, "/") if len(parts) == 0 || len(parts) > 2 { return hasIPv4CatchAll, hasIPv6CatchAll, ErrInvalidNAT1To1IPMapping } _, isIPv4, err := validateIPString(parts[0]) if err != nil { return hasIPv4CatchAll, hasIPv6CatchAll, err } if len(parts) == 2 { if _, _, err := validateIPString(strings.TrimSpace(parts[1])); err != nil { return hasIPv4CatchAll, hasIPv6CatchAll, err } return hasIPv4CatchAll, hasIPv6CatchAll, nil } if isIPv4 { if hasIPv4CatchAll { return 
hasIPv4CatchAll, hasIPv6CatchAll, ErrInvalidNAT1To1IPMapping } return true, hasIPv6CatchAll, nil } if hasIPv6CatchAll { return hasIPv4CatchAll, hasIPv6CatchAll, ErrInvalidNAT1To1IPMapping } return hasIPv4CatchAll, true, nil } func legacyNAT1To1Rules(ips []string, candidateType CandidateType) ([]AddressRewriteRule, error) { var rules []AddressRewriteRule for _, mapping := range ips { trimmed := strings.TrimSpace(mapping) if trimmed == "" { continue } parts := strings.Split(trimmed, "/") switch len(parts) { case 1: rules = append(rules, AddressRewriteRule{ External: []string{parts[0]}, AsCandidateType: candidateType, }) case 2: ext := strings.TrimSpace(parts[0]) local := strings.TrimSpace(parts[1]) if ext == "" || local == "" { return nil, ErrInvalidNAT1To1IPMapping } if _, _, err := validateIPString(ext); err != nil { return nil, err } if _, _, err := validateIPString(local); err != nil { return nil, err } rules = append(rules, AddressRewriteRule{ External: []string{ext}, Local: local, AsCandidateType: candidateType, }) default: return nil, ErrInvalidNAT1To1IPMapping } } return rules, nil } func createAgentBase(config *AgentConfig) (*Agent, error) { if config.PortMax < config.PortMin { return nil, ErrPort } mDNSName, mDNSMode, err := setupMDNSConfig(config) if err != nil { return nil, err } loggerFactory := config.LoggerFactory if loggerFactory == nil { loggerFactory = logging.NewDefaultLoggerFactory() } log := loggerFactory.NewLogger("ice") startedCtx, startedFn := context.WithCancel(context.Background()) agent := &Agent{ tieBreaker: globalMathRandomGenerator.Uint64(), lite: config.Lite, gatheringState: GatheringStateNew, connectionState: ConnectionStateNew, localCandidates: make(map[NetworkType][]Candidate), remoteCandidates: make(map[NetworkType][]Candidate), urls: config.Urls, networkTypes: config.NetworkTypes, onConnected: make(chan struct{}), buf: packetio.NewBuffer(), startedCh: startedCtx.Done(), startedFn: startedFn, portMin: config.PortMin, portMax: 
config.PortMax, loggerFactory: loggerFactory, log: log, net: config.Net, proxyDialer: config.ProxyDialer, tcpMux: config.TCPMux, udpMux: config.UDPMux, udpMuxSrflx: config.UDPMuxSrflx, mDNSMode: mDNSMode, mDNSName: mDNSName, gatherCandidateCancel: func() {}, forceCandidateContact: make(chan bool, 1), interfaceFilter: config.InterfaceFilter, ipFilter: config.IPFilter, insecureSkipVerify: config.InsecureSkipVerify, includeLoopback: config.IncludeLoopback, disableActiveTCP: config.DisableActiveTCP, userBindingRequestHandler: config.BindingRequestHandler, enableUseCandidateCheckPriority: config.EnableUseCandidateCheckPriority, enableRenomination: false, nominationValueGenerator: nil, nominationAttribute: stun.AttrType(0x0030), // Default value continualGatheringPolicy: GatherOnce, // Default to GatherOnce networkMonitorInterval: 2 * time.Second, lastKnownInterfaces: make(map[string]netip.Addr), automaticRenomination: false, renominationInterval: 3 * time.Second, // Default matching libwebrtc turnClientFactory: defaultTurnClient, } config.initWithDefaults(agent) return agent, nil } func applyAddressRewriteMapping(agent *Agent) error { mapper, err := newAddressRewriteMapper(agent.addressRewriteRules) if err != nil { return err } agent.addressRewriteMapper = mapper if agent.addressRewriteMapper == nil { return nil } if agent.addressRewriteMapper.hasCandidateType(CandidateTypeHost) { // for mDNS QueryAndGather we never advertise rewritten host IPs to avoid // leaking local addresses, this matches the legacy NAT1:1 behavior. if agent.mDNSMode == MulticastDNSModeQueryAndGather { return ErrMulticastDNSWithNAT1To1IPMapping } // surface misconfiguration when host candidates are disabled but a host // rewrite rule was provided. 
if !containsCandidateType(CandidateTypeHost, agent.candidateTypes) { return ErrIneffectiveNAT1To1IPMappingHost } } if agent.addressRewriteMapper.hasCandidateType(CandidateTypeServerReflexive) { // surface misconfiguration when srflx candidates are disabled but a srflx // rewrite rule was provided. if !containsCandidateType(CandidateTypeServerReflexive, agent.candidateTypes) { return ErrIneffectiveNAT1To1IPMappingSrflx } } return nil } // setupMDNSConfig validates and returns mDNS configuration. func setupMDNSConfig(config *AgentConfig) (string, MulticastDNSMode, error) { mDNSName := config.MulticastDNSHostName if mDNSName == "" { var err error if mDNSName, err = generateMulticastDNSName(); err != nil { return "", 0, err } } if !strings.HasSuffix(mDNSName, ".local") || len(strings.Split(mDNSName, ".")) != 2 { return "", 0, ErrInvalidMulticastDNSHostName } mDNSMode := config.MulticastDNSMode if mDNSMode == 0 { mDNSMode = MulticastDNSModeQueryOnly } return mDNSName, mDNSMode, nil } // newAgentWithConfig finalizes a pre-configured agent with optional overrides. 
// //nolint:gocognit,cyclop func newAgentWithConfig(agent *Agent, opts ...AgentOption) (*Agent, error) { var err error for _, opt := range opts { if err = opt(agent); err != nil { return nil, err } } agent.connectionStateNotifier = &handlerNotifier{ connectionStateFunc: agent.onConnectionStateChange, done: make(chan struct{}), } agent.candidateNotifier = &handlerNotifier{candidateFunc: agent.onCandidate, done: make(chan struct{})} agent.selectedCandidatePairNotifier = &handlerNotifier{ candidatePairFunc: agent.onSelectedCandidatePairChange, done: make(chan struct{}), } if agent.net == nil { agent.net, err = stdnet.NewNet() if err != nil { return nil, fmt.Errorf("failed to create network: %w", err) } } else if _, isVirtual := agent.net.(*vnet.Net); isVirtual { agent.log.Warn("Virtual network is enabled") if agent.mDNSMode != MulticastDNSModeDisabled { agent.log.Warn("Virtual network does not support mDNS yet") } } localIfcs, _, err := localInterfaces( agent.net, agent.interfaceFilter, agent.ipFilter, agent.networkTypes, agent.includeLoopback, ) if err != nil { return nil, fmt.Errorf("error getting local interfaces: %w", err) } // Opportunistic mDNS: If we can't open the connection, that's ok: we // can continue without it. if agent.mDNSConn, agent.mDNSMode, err = createMulticastDNS( agent.net, agent.networkTypes, localIfcs, agent.includeLoopback, agent.mDNSMode, agent.mDNSName, agent.log, agent.loggerFactory, ); err != nil { agent.log.Warnf("Failed to initialize mDNS %s: %v", agent.mDNSName, err) } // Make sure the buffer doesn't grow indefinitely. // NOTE: We actually won't get anywhere close to this limit. // SRTP will constantly read from the endpoint and drop packets if it's full. 
agent.buf.SetLimitSize(maxBufferSize) if agent.lite && (len(agent.candidateTypes) != 1 || agent.candidateTypes[0] != CandidateTypeHost) { agent.closeMulticastConn() return nil, ErrLiteUsingNonHostCandidates } if len(agent.urls) > 0 && !containsCandidateType(CandidateTypeServerReflexive, agent.candidateTypes) && !containsCandidateType(CandidateTypeRelay, agent.candidateTypes) { agent.closeMulticastConn() return nil, ErrUselessUrlsProvided } if err = applyAddressRewriteMapping(agent); err != nil { agent.closeMulticastConn() return nil, err } agent.loop = taskloop.New(func() { agent.gatherCandidateCancel() if agent.gatherCandidateDone != nil { <-agent.gatherCandidateDone } agent.removeUfragFromMux() agent.deleteAllCandidates() agent.startedFn() if err := agent.buf.Close(); err != nil { agent.log.Warnf("Failed to close buffer: %v", err) } agent.closeMulticastConn() agent.updateConnectionState(ConnectionStateClosed) }) // Restart is also used to initialize the agent for the first time if err := agent.Restart(agent.localUfrag, agent.localPwd); err != nil { agent.closeMulticastConn() _ = agent.Close() return nil, err } return agent, nil } func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remotePwd string) error { a.muHaveStarted.Lock() defer a.muHaveStarted.Unlock() select { case <-a.startedCh: return ErrMultipleStart default: } if err := a.SetRemoteCredentials(remoteUfrag, remotePwd); err != nil { //nolint:contextcheck return err } a.log.Debugf("Started agent: isControlling? 
%t, remoteUfrag: %q, remotePwd: %q", isControlling, remoteUfrag, remotePwd) return a.loop.Run(a.loop, func(_ context.Context) { a.isControlling.Store(isControlling) a.remoteUfrag = remoteUfrag a.remotePwd = remotePwd a.setSelector() a.startedFn() a.updateConnectionState(ConnectionStateChecking) a.requestConnectivityCheck() go a.connectivityChecks() //nolint:contextcheck }) } func (a *Agent) connectivityChecks() { //nolint:cyclop lastConnectionState := ConnectionState(0) checkingDuration := time.Time{} contact := func() { if err := a.loop.Run(a.loop, func(_ context.Context) { defer func() { lastConnectionState = a.connectionState }() switch a.connectionState { case ConnectionStateFailed: // The connection is currently failed so don't send any checks // In the future it may be restarted though return case ConnectionStateChecking: // We have just entered checking for the first time so update our checking timer if lastConnectionState != a.connectionState { checkingDuration = time.Now() } // We have been in checking longer then Disconnect+Failed timeout, set the connection to Failed if time.Since(checkingDuration) > a.disconnectedTimeout+a.failedTimeout { a.updateConnectionState(ConnectionStateFailed) return } default: } a.getSelector().ContactCandidates() }); err != nil { a.log.Warnf("Failed to start connectivity checks: %v", err) } } timer := time.NewTimer(math.MaxInt64) timer.Stop() for { interval := defaultKeepaliveInterval updateInterval := func(x time.Duration) { if x != 0 && (interval == 0 || interval > x) { interval = x } } switch lastConnectionState { case ConnectionStateNew, ConnectionStateChecking: // While connecting, check candidates more frequently updateInterval(a.checkInterval) case ConnectionStateConnected, ConnectionStateDisconnected: updateInterval(a.keepaliveInterval) default: } // Ensure we run our task loop as quickly as the minimum of our various configured timeouts updateInterval(a.disconnectedTimeout) updateInterval(a.failedTimeout) 
timer.Reset(interval) select { case <-a.forceCandidateContact: if !timer.Stop() { <-timer.C } contact() case <-timer.C: contact() case <-a.loop.Done(): timer.Stop() return } } } func (a *Agent) updateConnectionState(newState ConnectionState) { if a.connectionState != newState { // Connection has gone to failed, release all gathered candidates if newState == ConnectionStateFailed { a.removeUfragFromMux() a.checklist = make([]*CandidatePair, 0) a.pendingBindingRequests = make([]bindingRequest, 0) a.setSelectedPair(nil) a.deleteAllCandidates() } a.log.Infof("Setting new connection state: %s", newState) a.connectionState = newState a.connectionStateNotifier.EnqueueConnectionState(newState) } } func (a *Agent) setSelectedPair(pair *CandidatePair) { if pair == nil { var nilPair *CandidatePair a.selectedPair.Store(nilPair) a.log.Tracef("Unset selected candidate pair") return } pair.nominated = true a.selectedPair.Store(pair) a.log.Tracef("Set selected candidate pair: %s", pair) // Signal connected: notify any Connect() calls waiting on onConnected a.onConnectedOnce.Do(func() { close(a.onConnected) }) // Update connection state to Connected and notify state change handlers a.updateConnectionState(ConnectionStateConnected) // Notify when the selected candidate pair changes a.selectedCandidatePairNotifier.EnqueueSelectedCandidatePair(pair) } func (a *Agent) pingAllCandidates() { a.log.Trace("Pinging all candidates") if len(a.checklist) == 0 { a.log.Warn("Failed to ping without candidate pairs. 
Connection is not possible yet.") } for _, p := range a.checklist { if p.state == CandidatePairStateWaiting { p.state = CandidatePairStateInProgress } else if p.state != CandidatePairStateInProgress { continue } if p.bindingRequestCount > a.maxBindingRequests { a.log.Tracef("Maximum requests reached for pair %s, marking it as failed", p) p.state = CandidatePairStateFailed } else { a.getSelector().PingCandidate(p.Local, p.Remote) p.bindingRequestCount++ } } } // keepAliveCandidatesForRenomination pings all candidate pairs to keep them tested // and ready for automatic renomination. Unlike pingAllCandidates, this: // - Pings pairs in succeeded state to keep RTT measurements fresh // - Ignores maxBindingRequests limit (we want to keep testing alternate paths) // - Only pings pairs that are not failed. func (a *Agent) keepAliveCandidatesForRenomination() { a.log.Trace("Keep alive candidates for automatic renomination") if len(a.checklist) == 0 { return } for _, pair := range a.checklist { switch pair.state { case CandidatePairStateFailed: // Skip failed pairs continue case CandidatePairStateWaiting: // Transition waiting pairs to in-progress pair.state = CandidatePairStateInProgress case CandidatePairStateInProgress, CandidatePairStateSucceeded: // Continue pinging in-progress and succeeded pairs } // Ping all non-failed pairs (including succeeded ones) // to keep RTT measurements fresh for renomination decisions a.getSelector().PingCandidate(pair.Local, pair.Remote) } } func (a *Agent) getBestAvailableCandidatePair() *CandidatePair { var best *CandidatePair for _, p := range a.checklist { if p.state == CandidatePairStateFailed { continue } if best == nil { best = p } else if best.priority() < p.priority() { best = p } } return best } func (a *Agent) getBestValidCandidatePair() *CandidatePair { var best *CandidatePair for _, p := range a.checklist { if p.state != CandidatePairStateSucceeded { continue } if best == nil { best = p } else if best.priority() < p.priority() 
{ best = p } } return best } func (a *Agent) addPair(local, remote Candidate) *CandidatePair { p := newCandidatePair(local, remote, a.isControlling.Load()) a.checklist = append(a.checklist, p) return p } func (a *Agent) findPair(local, remote Candidate) *CandidatePair { for _, p := range a.checklist { if p.Local.Equal(local) && p.Remote.Equal(remote) { return p } } return nil } // validateSelectedPair checks if the selected pair is (still) valid // Note: the caller should hold the agent lock. func (a *Agent) validateSelectedPair() bool { selectedPair := a.getSelectedPair() if selectedPair == nil { return false } disconnectedTime := time.Since(selectedPair.Remote.LastReceived()) // Only allow transitions to failed if a.failedTimeout is non-zero totalTimeToFailure := a.failedTimeout if totalTimeToFailure != 0 { totalTimeToFailure += a.disconnectedTimeout } a.updateConnectionState(a.connectionStateForDisconnection(disconnectedTime, totalTimeToFailure)) return true } func (a *Agent) connectionStateForDisconnection( disconnectedTime time.Duration, totalTimeToFailure time.Duration, ) ConnectionState { disconnected := a.disconnectedTimeout != 0 && disconnectedTime > a.disconnectedTimeout failed := totalTimeToFailure != 0 && disconnectedTime > totalTimeToFailure switch { case failed: if disconnected && a.connectionState != ConnectionStateDisconnected && a.connectionState != ConnectionStateFailed { // If we never reported disconnected but both thresholds are already exceeded, // emit disconnected first so callers can observe both transitions. return ConnectionStateDisconnected } return ConnectionStateFailed case disconnected: return ConnectionStateDisconnected default: return ConnectionStateConnected } } // checkKeepalive sends STUN Binding Indications to the selected pair // if no packet has been sent on that pair in the last keepaliveInterval // Note: the caller should hold the agent lock. 
func (a *Agent) checkKeepalive() { selectedPair := a.getSelectedPair() if selectedPair == nil { return } if a.keepaliveInterval != 0 { // We use binding request instead of indication to support refresh consent schemas // see https://tools.ietf.org/html/rfc7675 a.getSelector().PingCandidate(selectedPair.Local, selectedPair.Remote) } } // AddRemoteCandidate adds a new remote candidate. func (a *Agent) AddRemoteCandidate(cand Candidate) error { if cand == nil { return nil } // TCP Candidates with TCP type active will probe server passive ones, so // no need to do anything with them. if cand.TCPType() == TCPTypeActive { a.log.Infof("Ignoring remote candidate with tcpType active: %s", cand) return nil } // If we have a mDNS Candidate lets fully resolve it before adding it locally if cand.Type() == CandidateTypeHost && strings.HasSuffix(cand.Address(), ".local") { if a.mDNSMode == MulticastDNSModeDisabled { a.log.Warnf("Remote mDNS candidate added, but mDNS is disabled: (%s)", cand.Address()) return nil } hostCandidate, ok := cand.(*CandidateHost) if !ok { return ErrAddressParseFailed } go a.resolveAndAddMulticastCandidate(hostCandidate) return nil } go func() { if err := a.loop.Run(a.loop, func(_ context.Context) { // nolint: contextcheck a.addRemoteCandidate(cand) }); err != nil { a.log.Warnf("Failed to add remote candidate %s: %v", cand.Address(), err) return } }() return nil } func (a *Agent) resolveAndAddMulticastCandidate(cand *CandidateHost) { if a.mDNSConn == nil { return } _, src, err := a.mDNSConn.QueryAddr(cand.context(), cand.Address()) if err != nil { a.log.Warnf("Failed to discover mDNS candidate %s: %v", cand.Address(), err) return } if err = cand.setIPAddr(src); err != nil { a.log.Warnf("Failed to discover mDNS candidate %s: %v", cand.Address(), err) return } if err = a.loop.Run(a.loop, func(_ context.Context) { // nolint: contextcheck a.addRemoteCandidate(cand) }); err != nil { a.log.Warnf("Failed to add mDNS candidate %s: %v", cand.Address(), err) 
return } } func (a *Agent) requestConnectivityCheck() { select { case a.forceCandidateContact <- true: default: } } func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) { _, localIPs, err := localInterfaces( a.net, a.interfaceFilter, a.ipFilter, []NetworkType{remoteCandidate.NetworkType()}, a.includeLoopback, ) if err != nil { a.log.Warnf("Failed to iterate local interfaces, host candidates will not be gathered %s", err) return } for i := range localIPs { ip, _, _, err := parseAddr(remoteCandidate.addr()) if err != nil { a.log.Warnf("Failed to parse address: %s; error: %s", remoteCandidate.addr(), err) continue } conn := newActiveTCPConn( a.loop, net.JoinHostPort(localIPs[i].addr.String(), "0"), netip.AddrPortFrom(ip, uint16(remoteCandidate.Port())), //nolint:gosec // G115, no overflow, a port a.log, ) tcpAddr, ok := conn.LocalAddr().(*net.TCPAddr) if !ok { closeConnAndLog(conn, a.log, "Failed to create Active ICE-TCP Candidate: %v", errInvalidAddress) continue } localCandidate, err := NewCandidateHost(&CandidateHostConfig{ Network: remoteCandidate.NetworkType().String(), Address: localIPs[i].addr.String(), Port: tcpAddr.Port, Component: ComponentRTP, TCPType: TCPTypeActive, }) if err != nil { closeConnAndLog(conn, a.log, "Failed to create Active ICE-TCP Candidate: %v", err) continue } localCandidate.start(a, conn, a.startedCh) a.localCandidates[localCandidate.NetworkType()] = append( a.localCandidates[localCandidate.NetworkType()], localCandidate, ) a.candidateNotifier.EnqueueCandidate(localCandidate) a.addPair(localCandidate, remoteCandidate) } } // addRemoteCandidate assumes you are holding the lock (must be execute using a.run). 
func (a *Agent) addRemoteCandidate(cand Candidate) { //nolint:cyclop
	set := a.remoteCandidates[cand.NetworkType()]

	// De-duplicate: the same remote candidate may be signaled more than once.
	for _, candidate := range set {
		if candidate.Equal(cand) {
			return
		}
	}

	acceptRemotePassiveTCPCandidate := false
	// Assert that TCP4 or TCP6 is a enabled NetworkType locally
	if !a.disableActiveTCP && cand.TCPType() == TCPTypePassive {
		for _, networkType := range a.networkTypes {
			if cand.NetworkType() == networkType {
				acceptRemotePassiveTCPCandidate = true
			}
		}
	}

	// Passive remote TCP candidates get dedicated active local candidates
	// created for them instead of being paired with existing locals below.
	if acceptRemotePassiveTCPCandidate {
		a.addRemotePassiveTCPCandidate(cand)
	}

	set = append(set, cand)
	a.remoteCandidates[cand.NetworkType()] = set

	if cand.TCPType() != TCPTypePassive {
		if localCandidates, ok := a.localCandidates[cand.NetworkType()]; ok {
			for _, localCandidate := range localCandidates {
				a.addPair(localCandidate, cand)
			}
		}
	}

	a.requestConnectivityCheck()
}

// addCandidate registers a gathered local candidate on the agent loop,
// dropping duplicates (and closing their connections), pairing the candidate
// with every known remote of the same network type, and notifying listeners.
func (a *Agent) addCandidate(ctx context.Context, cand Candidate, candidateConn net.PacketConn) error {
	if err := ctx.Err(); err != nil {
		return err
	}

	return a.loop.Run(ctx, func(context.Context) {
		set := a.localCandidates[cand.NetworkType()]
		for _, candidate := range set {
			if candidate.Equal(cand) {
				// Duplicate: close both the new candidate and its connection.
				a.log.Debugf("Ignore duplicate candidate: %s", cand)
				if err := cand.close(); err != nil {
					a.log.Warnf("Failed to close duplicate candidate: %v", err)
				}
				if err := candidateConn.Close(); err != nil {
					a.log.Warnf("Failed to close duplicate candidate connection: %v", err)
				}

				return
			}
		}

		a.setCandidateExtensions(cand)
		cand.start(a, candidateConn, a.startedCh)

		set = append(set, cand)
		a.localCandidates[cand.NetworkType()] = set

		if remoteCandidates, ok := a.remoteCandidates[cand.NetworkType()]; ok {
			for _, remoteCandidate := range remoteCandidates {
				a.addPair(cand, remoteCandidate)
			}
		}

		a.requestConnectivityCheck()

		// Candidates filtered for location tracking are kept internal and
		// never surfaced to the application.
		if !cand.filterForLocationTracking() {
			a.candidateNotifier.EnqueueCandidate(cand)
		}
	})
}

// setCandidateExtensions attaches the local ufrag as a candidate extension;
// failures are logged but not fatal.
func (a *Agent) setCandidateExtensions(cand Candidate) {
	err := cand.AddExtension(CandidateExtension{
		Key:   "ufrag",
		Value: a.localUfrag,
	})
	if err != nil {
		a.log.Errorf("Failed to add ufrag extension to candidate: %v", err)
	}
}

// GetRemoteCandidates returns the remote candidates.
func (a *Agent) GetRemoteCandidates() ([]Candidate, error) {
	var res []Candidate

	err := a.loop.Run(a.loop, func(_ context.Context) {
		var candidates []Candidate
		for _, set := range a.remoteCandidates {
			candidates = append(candidates, set...)
		}
		res = candidates
	})
	if err != nil {
		return nil, err
	}

	return res, nil
}

// GetLocalCandidates returns the local candidates.
func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
	var res []Candidate

	err := a.loop.Run(a.loop, func(_ context.Context) {
		var candidates []Candidate
		for _, set := range a.localCandidates {
			for _, c := range set {
				// Location-tracking candidates stay internal.
				if c.filterForLocationTracking() {
					continue
				}
				candidates = append(candidates, c)
			}
		}
		res = candidates
	})
	if err != nil {
		return nil, err
	}

	return res, nil
}

// GetGatheringState returns the current gathering state of the Agent.
func (a *Agent) GetGatheringState() (GatheringState, error) {
	var state GatheringState
	err := a.loop.Run(a.loop, func(_ context.Context) {
		state = a.gatheringState
	})
	if err != nil {
		return GatheringStateUnknown, err
	}

	return state, nil
}

// GetLocalUserCredentials returns the local user credentials.
func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
	valSet := make(chan struct{})
	err = a.loop.Run(a.loop, func(_ context.Context) {
		frag = a.localUfrag
		pwd = a.localPwd
		close(valSet)
	})

	// Wait until the closure has actually populated the values.
	if err == nil {
		<-valSet
	}

	return
}

// GetRemoteUserCredentials returns the remote user credentials.
func (a *Agent) GetRemoteUserCredentials() (frag string, pwd string, err error) {
	valSet := make(chan struct{})
	err = a.loop.Run(a.loop, func(_ context.Context) {
		frag = a.remoteUfrag
		pwd = a.remotePwd
		close(valSet)
	})

	// Wait until the closure has actually populated the values.
	if err == nil {
		<-valSet
	}

	return
}

// removeUfragFromMux unregisters the local ufrag from every mux this agent
// may be attached to (TCP, UDP host, UDP srflx).
func (a *Agent) removeUfragFromMux() {
	if a.tcpMux != nil {
		a.tcpMux.RemoveConnByUfrag(a.localUfrag)
	}
	if a.udpMux != nil {
		a.udpMux.RemoveConnByUfrag(a.localUfrag)
	}
	if a.udpMuxSrflx != nil {
		a.udpMuxSrflx.RemoveConnByUfrag(a.localUfrag)
	}
}

// Close cleans up the Agent.
func (a *Agent) Close() error {
	return a.close(false)
}

// GracefulClose cleans up the Agent and waits for any goroutines it started
// to complete. This is only safe to call outside of Agent callbacks or if in a callback,
// in its own goroutine.
func (a *Agent) GracefulClose() error {
	return a.close(true)
}

// close shuts the agent loop down and closes the notifiers; graceful controls
// whether notifier shutdown waits for in-flight callbacks.
func (a *Agent) close(graceful bool) error {
	// the loop is safe to wait on no matter what
	a.loop.Close()

	// but we are in less control of the notifiers, so we will
	// pass through `graceful`.
	a.connectionStateNotifier.Close(graceful)
	a.candidateNotifier.Close(graceful)
	a.selectedCandidatePairNotifier.Close(graceful)

	return nil
}

// Remove all candidates. This closes any listening sockets
// and removes both the local and remote candidate lists.
//
// This is used for restarts, failures and on close.
func (a *Agent) deleteAllCandidates() {
	for net, cs := range a.localCandidates {
		for _, c := range cs {
			if err := c.close(); err != nil {
				a.log.Warnf("Failed to close candidate %s: %v", c, err)
			}
		}
		delete(a.localCandidates, net)
	}
	for net, cs := range a.remoteCandidates {
		for _, c := range cs {
			if err := c.close(); err != nil {
				a.log.Warnf("Failed to close candidate %s: %v", c, err)
			}
		}
		delete(a.remoteCandidates, net)
	}
}

// findRemoteCandidate returns the known remote candidate matching addr on the
// given network type, or nil when addr is unknown or unparsable.
func (a *Agent) findRemoteCandidate(networkType NetworkType, addr net.Addr) Candidate {
	ip, port, _, err := parseAddr(addr)
	if err != nil {
		a.log.Warnf("Failed to parse address: %s; error: %s", addr, err)

		return nil
	}

	set := a.remoteCandidates[networkType]
	for _, c := range set {
		if c.Address() == ip.String() && c.Port() == port {
			return c
		}
	}

	return nil
}

// sendBindingRequest records the outgoing transaction (so the matching
// success response can be validated later) and sends the STUN request.
func (a *Agent) sendBindingRequest(msg *stun.Message, local, remote Candidate) {
	a.log.Tracef("Ping STUN from %s to %s", local, remote)

	// Extract nomination value if present
	var nominationValue *uint32
	var nomination NominationAttribute
	if err := nomination.GetFromWithType(msg, a.nominationAttribute); err == nil {
		nominationValue = &nomination.Value
	}

	// Prune expired transactions before appending the new one.
	a.invalidatePendingBindingRequests(time.Now())
	a.pendingBindingRequests = append(a.pendingBindingRequests, bindingRequest{
		timestamp:       time.Now(),
		transactionID:   msg.TransactionID,
		destination:     remote.addr(),
		isUseCandidate:  msg.Contains(stun.AttrUseCandidate),
		nominationValue: nominationValue,
	})

	if pair := a.findPair(local, remote); pair != nil {
		pair.UpdateRequestSent()
	} else {
		a.log.Warnf("Failed to find pair for add binding request from %s to %s", local, remote)
	}

	a.sendSTUN(msg, local, remote)
}

// sendBindingSuccess answers an inbound binding request with a success
// response carrying the XOR-mapped address of the remote, integrity-protected
// with the local password.
func (a *Agent) sendBindingSuccess(m *stun.Message, local, remote Candidate) {
	base := remote

	ip, port, _, err := parseAddr(base.addr())
	if err != nil {
		a.log.Warnf("Failed to parse address: %s; error: %s", base.addr(), err)

		return
	}

	if out, err := stun.Build(m, stun.BindingSuccess,
		&stun.XORMappedAddress{
			IP:   ip.AsSlice(),
			Port: port,
		},
		stun.NewShortTermIntegrity(a.localPwd),
		stun.Fingerprint,
	); err != nil {
		a.log.Warnf("Failed to handle inbound ICE from: %s to: %s error: %s", local, remote, err)
	} else {
		if pair := a.findPair(local, remote); pair != nil {
			pair.UpdateResponseSent()
		} else {
			a.log.Warnf("Failed to find pair for add binding response from %s to %s", local, remote)
		}
		a.sendSTUN(out, local, remote)
	}
}

// Removes pending binding requests that are over maxBindingRequestTimeout old
//
// Let HTO be the transaction timeout, which SHOULD be 2*RTT if
// RTT is known or 500 ms otherwise.
// https://tools.ietf.org/html/rfc8445#appendix-B.1
func (a *Agent) invalidatePendingBindingRequests(filterTime time.Time) {
	initialSize := len(a.pendingBindingRequests)

	// In-place filter: reuse the backing array, keep only fresh requests.
	temp := a.pendingBindingRequests[:0]
	for _, bindingRequest := range a.pendingBindingRequests {
		if filterTime.Sub(bindingRequest.timestamp) < maxBindingRequestTimeout {
			temp = append(temp, bindingRequest)
		}
	}

	a.pendingBindingRequests = temp
	if bindRequestsRemoved := initialSize - len(a.pendingBindingRequests); bindRequestsRemoved > 0 {
		a.log.Tracef("Discarded %d binding requests because they expired", bindRequestsRemoved)
	}
}

// Assert that the passed TransactionID is in our pendingBindingRequests and returns the destination
// If the bindingRequest was valid remove it from our pending cache.
func (a *Agent) handleInboundBindingSuccess(id [stun.TransactionIDSize]byte) (bool, *bindingRequest, time.Duration) {
	a.invalidatePendingBindingRequests(time.Now())
	for i := range a.pendingBindingRequests {
		if a.pendingBindingRequests[i].transactionID == id {
			validBindingRequest := a.pendingBindingRequests[i]
			// Remove the matched request from the pending cache.
			a.pendingBindingRequests = append(a.pendingBindingRequests[:i], a.pendingBindingRequests[i+1:]...)

			return true, &validBindingRequest, time.Since(validBindingRequest.timestamp)
		}
	}

	return false, nil, 0
}

// handleRoleConflict resolves an ICE role conflict per RFC 8445: either this
// agent answers with a 487 (Role Conflict) error, or it switches its own role.
func (a *Agent) handleRoleConflict(msg *stun.Message, local, remote Candidate, remoteTieBreaker *AttrControl) {
	localIsGreaterOrEqual := a.tieBreaker >= remoteTieBreaker.Tiebreaker
	a.log.Warnf("Role conflict local and remote same role(%s), localIsGreaterOrEqual(%t)", a.role(), localIsGreaterOrEqual)

	// https://datatracker.ietf.org/doc/html/rfc8445#section-7.3.1.1
	// An agent MUST examine the Binding request for either the ICE-
	// CONTROLLING or ICE-CONTROLLED attribute. It MUST follow these
	// procedures:
	// If the agent's tiebreaker value is larger than or equal to the contents of the ICE-CONTROLLING attribute
	// If the agent's tiebreaker value is less than the contents of the ICE-CONTROLLED attribute
	// the agent generates a Binding error response
	if (a.isControlling.Load() && localIsGreaterOrEqual) || (!a.isControlling.Load() && !localIsGreaterOrEqual) {
		if roleConflictMsg, err := stun.Build(msg, stun.BindingError,
			stun.ErrorCodeAttribute{
				Code:   stun.CodeRoleConflict,
				Reason: []byte("Role Conflict"),
			},
			stun.NewShortTermIntegrity(a.localPwd),
			stun.Fingerprint,
		); err != nil {
			a.log.Warnf("Failed to generate Role Conflict message from: %s to: %s error: %s", local, remote, err)
		} else {
			a.sendSTUN(roleConflictMsg, local, remote)
		}
	} else {
		// We lose the tie-break: flip our role and rebuild the selector.
		a.isControlling.Store(!a.isControlling.Load())
		a.setSelector()
	}
}

// handleInbound processes STUN traffic from a remote candidate.
func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Addr) { //nolint:gocognit,cyclop
	if msg == nil || local == nil {
		return
	}

	// Only Binding success responses, requests and indications are handled.
	if msg.Type.Method != stun.MethodBinding ||
		(msg.Type.Class != stun.ClassSuccessResponse &&
			msg.Type.Class != stun.ClassRequest &&
			msg.Type.Class != stun.ClassIndication) {
		a.log.Tracef("Unhandled STUN from %s to %s class(%s) method(%s)",
			remote, local, msg.Type.Class, msg.Type.Method)

		return
	}

	remoteCandidate := a.findRemoteCandidate(local.NetworkType(), remote)
	if msg.Type.Class == stun.ClassSuccessResponse { //nolint:nestif
		// Responses are authenticated with the REMOTE password.
		if err := stun.MessageIntegrity([]byte(a.remotePwd)).Check(msg); err != nil {
			a.log.Warnf("Discard success response with broken integrity from (%s), %v", remote, err)

			return
		}

		if remoteCandidate == nil {
			a.log.Warnf("Discard success message from (%s), no such remote", remote)

			return
		}

		a.getSelector().HandleSuccessResponse(msg, local, remoteCandidate, remote)
	} else if msg.Type.Class == stun.ClassRequest {
		a.log.Tracef(
			"Inbound STUN (Request) from %s to %s, useCandidate: %v",
			remote,
			local,
			msg.Contains(stun.AttrUseCandidate),
		)

		// Requests are authenticated with the LOCAL password and the
		// localUfrag:remoteUfrag username.
		if err := stunx.AssertUsername(msg, a.localUfrag+":"+a.remoteUfrag); err != nil {
			a.log.Warnf("Discard request with wrong username from (%s), %v", remote, err)

			return
		} else if err := stun.MessageIntegrity([]byte(a.localPwd)).Check(msg); err != nil {
			a.log.Warnf("Discard request with broken integrity from (%s), %v", remote, err)

			return
		}

		// An authenticated request from an unknown address yields a
		// peer-reflexive candidate (RFC 8445 §7.3.1.3).
		if remoteCandidate == nil {
			ip, port, networkType, err := parseAddr(remote)
			if err != nil {
				a.log.Errorf("Failed to create parse remote net.Addr when creating remote prflx candidate: %s", err)

				return
			}

			prflxCandidateConfig := CandidatePeerReflexiveConfig{
				Network:   networkType.String(),
				Address:   ip.String(),
				Port:      port,
				Component: local.Component(),
				RelAddr:   "",
				RelPort:   0,
			}

			prflxCandidate, err := NewCandidatePeerReflexive(&prflxCandidateConfig)
			if err != nil {
				a.log.Errorf("Failed to create new remote prflx candidate (%s)", err)

				return
			}
			remoteCandidate = prflxCandidate

			a.log.Debugf("Adding a new peer-reflexive candidate: %s ", remote)
			a.addRemoteCandidate(remoteCandidate)
		}

		// Support Remotes that don't set a TIE-BREAKER. Not standards compliant, but
		// keeping to maintain backwards compat
		remoteTieBreaker := &AttrControl{}
		if err := remoteTieBreaker.GetFrom(msg); err == nil && remoteTieBreaker.Role == a.role() {
			a.handleRoleConflict(msg, local, remoteCandidate, remoteTieBreaker)

			return
		}

		a.getSelector().HandleBindingRequest(msg, local, remoteCandidate)
	}

	// Any valid traffic (including indications) refreshes liveness.
	if remoteCandidate != nil {
		remoteCandidate.seen(false)
	}
}

// validateNonSTUNTraffic processes non STUN traffic from a remote candidate,
// and returns true if it is an actual remote candidate.
func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candidate, bool) {
	var remoteCandidate Candidate
	if err := a.loop.Run(local.context(), func(context.Context) {
		remoteCandidate = a.findRemoteCandidate(local.NetworkType(), remote)
		if remoteCandidate != nil {
			remoteCandidate.seen(false)
		}
	}); err != nil {
		a.log.Warnf("Failed to validate remote candidate: %v", err)
	}

	return remoteCandidate, remoteCandidate != nil
}

// GetSelectedCandidatePair returns the selected pair or nil if there is none.
func (a *Agent) GetSelectedCandidatePair() (*CandidatePair, error) {
	selectedPair := a.getSelectedPair()
	if selectedPair == nil {
		return nil, nil //nolint:nilnil
	}

	// Return copies so callers cannot mutate the agent's internal pair.
	local, err := selectedPair.Local.copy()
	if err != nil {
		return nil, err
	}

	remote, err := selectedPair.Remote.copy()
	if err != nil {
		return nil, err
	}

	return &CandidatePair{Local: local, Remote: remote}, nil
}

// getSelectedPair returns the currently selected pair (atomically loaded), or
// nil when none has been selected yet.
func (a *Agent) getSelectedPair() *CandidatePair {
	if selectedPair, ok := a.selectedPair.Load().(*CandidatePair); ok {
		return selectedPair
	}

	return nil
}

// closeMulticastConn closes the mDNS connection if one was opened.
func (a *Agent) closeMulticastConn() {
	if a.mDNSConn != nil {
		if err := a.mDNSConn.Close(); err != nil {
			a.log.Warnf("Failed to close mDNS Conn: %v", err)
		}
	}
}

// SetRemoteCredentials sets the credentials of the remote agent.
func (a *Agent) SetRemoteCredentials(remoteUfrag, remotePwd string) error {
	switch {
	case remoteUfrag == "":
		return ErrRemoteUfragEmpty
	case remotePwd == "":
		return ErrRemotePwdEmpty
	}

	return a.loop.Run(a.loop, func(_ context.Context) {
		a.remoteUfrag = remoteUfrag
		a.remotePwd = remotePwd
	})
}

// Restart restarts the ICE Agent with the provided ufrag/pwd
// If no ufrag/pwd is provided the Agent will generate one itself
//
// If there is a gatherer routine currently running, Restart will
// cancel it.
// After a Restart, the user must then call GatherCandidates explicitly
// to start generating new ones.
func (a *Agent) Restart(ufrag, pwd string) error { //nolint:cyclop
	if ufrag == "" {
		var err error
		ufrag, err = generateUFrag()
		if err != nil {
			return err
		}
	}
	if pwd == "" {
		var err error
		pwd, err = generatePwd()
		if err != nil {
			return err
		}
	}

	// RFC 8445 requires at least 24 bits of ufrag and 128 bits of password.
	if len([]rune(ufrag))*8 < 24 {
		return ErrLocalUfragInsufficientBits
	}
	if len([]rune(pwd))*8 < 128 {
		return ErrLocalPwdInsufficientBits
	}

	// NOTE(review): err is declared but never assigned inside the closure, so
	// the final `return err` is always nil — candidate for cleanup; verify no
	// out-of-view code relies on it before removing.
	var err error
	if runErr := a.loop.Run(a.loop, func(_ context.Context) {
		if a.gatheringState == GatheringStateGathering {
			a.gatherCandidateCancel()
		}

		// Clear all agent needed to take back to fresh state
		a.removeUfragFromMux()
		a.localUfrag = ufrag
		a.localPwd = pwd
		a.remoteUfrag = ""
		a.remotePwd = ""
		a.gatheringState = GatheringStateNew
		a.checklist = make([]*CandidatePair, 0)
		a.pendingBindingRequests = make([]bindingRequest, 0)
		a.setSelectedPair(nil)
		a.deleteAllCandidates()
		a.setSelector()

		// Restart is used by NewAgent. Accept/Connect should be used to move to checking
		// for new Agents
		if a.connectionState != ConnectionStateNew {
			a.updateConnectionState(ConnectionStateChecking)
		}
	}); runErr != nil {
		return runErr
	}

	return err
}

// setGatheringState updates the gathering state on the agent loop, emitting
// the nil "gathering complete" sentinel to candidate listeners on the
// transition to GatheringStateComplete.
func (a *Agent) setGatheringState(newState GatheringState) error {
	done := make(chan struct{})
	if err := a.loop.Run(a.loop, func(context.Context) {
		if a.gatheringState != newState && newState == GatheringStateComplete {
			a.candidateNotifier.EnqueueCandidate(nil)
		}

		a.gatheringState = newState
		close(done)
	}); err != nil {
		return err
	}

	<-done

	return nil
}

// needsToCheckPriorityOnNominated reports whether nominated pairs must pass a
// priority check; lite agents may opt in via enableUseCandidateCheckPriority.
func (a *Agent) needsToCheckPriorityOnNominated() bool {
	return !a.lite || a.enableUseCandidateCheckPriority
}

// role returns the agent's current ICE role (controlling or controlled).
func (a *Agent) role() Role {
	if a.isControlling.Load() {
		return Controlling
	}

	return Controlled
}

// setSelector (re)builds the pair selector matching the current role, wrapping
// it for lite mode if needed, and starts it.
func (a *Agent) setSelector() {
	a.selectorLock.Lock()
	defer a.selectorLock.Unlock()

	var s pairCandidateSelector
	if a.isControlling.Load() {
		s = &controllingSelector{agent: a, log: a.log}
	} else {
		s = &controlledSelector{agent: a, log: a.log}
	}

	if a.lite {
		s = &liteSelector{pairCandidateSelector: s}
	}

	s.Start()
	a.selector = s
}

// getSelector returns the current pair selector under the selector lock.
func (a *Agent) getSelector() pairCandidateSelector {
	a.selectorLock.Lock()
	defer a.selectorLock.Unlock()

	return a.selector
}

// getNominationValue returns a nomination value if generator is available, otherwise 0.
func (a *Agent) getNominationValue() uint32 {
	if a.nominationValueGenerator != nil {
		return a.nominationValueGenerator()
	}

	return 0
}

// RenominateCandidate allows the controlling ICE agent to nominate a new candidate pair.
// This implements the continuous renomination feature from draft-thatcher-ice-renomination-01.
func (a *Agent) RenominateCandidate(local, remote Candidate) error { if !a.isControlling.Load() { return ErrOnlyControllingAgentCanRenominate } if !a.enableRenomination { return ErrRenominationNotEnabled } // Find the candidate pair pair := a.findPair(local, remote) if pair == nil { return ErrCandidatePairNotFound } // Send nomination with custom attribute return a.sendNominationRequest(pair, a.getNominationValue()) } // sendNominationRequest sends a nomination request with custom nomination value. func (a *Agent) sendNominationRequest(pair *CandidatePair, nominationValue uint32) error { attributes := []stun.Setter{ stun.TransactionID, stun.NewUsername(a.remoteUfrag + ":" + a.localUfrag), UseCandidate(), AttrControlling(a.tieBreaker), PriorityAttr(pair.Local.Priority()), stun.NewShortTermIntegrity(a.remotePwd), stun.Fingerprint, } // Add nomination attribute if renomination is enabled and value > 0 if a.enableRenomination && nominationValue > 0 { attributes = append(attributes, NominationSetter{ Value: nominationValue, AttrType: a.nominationAttribute, }) a.log.Tracef("Sending renomination request from %s to %s with nomination value %d", pair.Local, pair.Remote, nominationValue) } msg, err := stun.Build(append([]stun.Setter{stun.BindingRequest}, attributes...)...) if err != nil { return fmt.Errorf("failed to build nomination request: %w", err) } a.sendBindingRequest(msg, pair.Local, pair.Remote) return nil } // evaluateCandidatePairQuality calculates a quality score for a candidate pair. // Higher scores indicate better quality. The score considers: // - Candidate types (host > srflx > relay) // - RTT (lower is better) // - Connection stability. 
func (a *Agent) evaluateCandidatePairQuality(pair *CandidatePair) float64 { //nolint:cyclop if pair == nil || pair.state != CandidatePairStateSucceeded { return 0 } score := float64(0) // Type preference scoring (host=100, srflx=50, prflx=30, relay=10) localTypeScore := float64(0) switch pair.Local.Type() { case CandidateTypeHost: localTypeScore = 100 case CandidateTypeServerReflexive: localTypeScore = 50 case CandidateTypePeerReflexive: localTypeScore = 30 case CandidateTypeRelay: localTypeScore = 10 case CandidateTypeUnspecified: localTypeScore = 0 } remoteTypeScore := float64(0) switch pair.Remote.Type() { case CandidateTypeHost: remoteTypeScore = 100 case CandidateTypeServerReflexive: remoteTypeScore = 50 case CandidateTypePeerReflexive: remoteTypeScore = 30 case CandidateTypeRelay: remoteTypeScore = 10 case CandidateTypeUnspecified: remoteTypeScore = 0 } // Combined type score (average of local and remote) score += (localTypeScore + remoteTypeScore) / 2 // RTT scoring (convert to penalty, lower RTT = higher score) // Use current RTT if available, otherwise assume high latency rtt := pair.CurrentRoundTripTime() if rtt > 0 { // Convert RTT to Duration for cleaner calculation rttDuration := time.Duration(rtt * float64(time.Second)) rttMs := float64(rttDuration / time.Millisecond) if rttMs < 1 { rttMs = 1 // Minimum 1ms to avoid log(0) } // Subtract RTT penalty (logarithmic to reduce impact of very high RTTs) score -= math.Log10(rttMs) * 10 } else { // No RTT data available, apply moderate penalty score -= 30 } // Boost score if pair has been stable (received responses recently) if pair.ResponsesReceived() > 0 { lastResponse := pair.LastResponseReceivedAt() if !lastResponse.IsZero() && time.Since(lastResponse) < 5*time.Second { score += 20 // Stability bonus } } return score } // shouldRenominate determines if automatic renomination should occur. 
// It compares the current selected pair with a candidate pair and decides // if switching would provide significant benefit. func (a *Agent) shouldRenominate(current, candidate *CandidatePair) bool { //nolint:cyclop if current == nil || candidate == nil || current.equal(candidate) || candidate.state != CandidatePairStateSucceeded { return false } // Type-based switching (always prefer direct over relay) currentIsRelay := current.Local.Type() == CandidateTypeRelay || current.Remote.Type() == CandidateTypeRelay candidateIsDirect := candidate.Local.Type() == CandidateTypeHost && candidate.Remote.Type() == CandidateTypeHost if currentIsRelay && candidateIsDirect { a.log.Debugf("Should renominate: relay -> direct connection available") return true } // RTT-based switching (must improve by at least 10ms) currentRTT := current.CurrentRoundTripTime() candidateRTT := candidate.CurrentRoundTripTime() // Only compare RTT if both values are valid if currentRTT > 0 && candidateRTT > 0 { currentRTTDuration := time.Duration(currentRTT * float64(time.Second)) candidateRTTDuration := time.Duration(candidateRTT * float64(time.Second)) rttImprovement := currentRTTDuration - candidateRTTDuration if rttImprovement > 10*time.Millisecond { a.log.Debugf("Should renominate: RTT improvement of %v", rttImprovement) return true } } // Quality score comparison (must improve by at least 15%) currentScore := a.evaluateCandidatePairQuality(current) candidateScore := a.evaluateCandidatePairQuality(candidate) if candidateScore > currentScore*1.15 { a.log.Debugf("Should renominate: quality score improved from %.2f to %.2f", currentScore, candidateScore) return true } return false } // findBestCandidatePair finds the best available candidate pair based on quality assessment. 
func (a *Agent) findBestCandidatePair() *CandidatePair { var best *CandidatePair bestScore := float64(-math.MaxFloat64) for _, pair := range a.checklist { if pair.state != CandidatePairStateSucceeded { continue } score := a.evaluateCandidatePairQuality(pair) if score > bestScore { bestScore = score best = pair } } return best } ice-4.1.0/agent_config.go000066400000000000000000000275071511704470200152420ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "net" "time" "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3" "golang.org/x/net/proxy" ) const ( // defaultCheckInterval is the interval at which the agent performs candidate checks in the connecting phase. defaultCheckInterval = 200 * time.Millisecond // keepaliveInterval used to keep candidates alive. defaultKeepaliveInterval = 2 * time.Second // defaultDisconnectedTimeout is the default time till an Agent transitions disconnected. defaultDisconnectedTimeout = 5 * time.Second // defaultFailedTimeout is the default time till an Agent transitions to failed after disconnected. defaultFailedTimeout = 25 * time.Second // defaultHostAcceptanceMinWait is the wait time before nominating a host candidate. defaultHostAcceptanceMinWait = 0 // defaultSrflxAcceptanceMinWait is the wait time before nominating a srflx candidate. defaultSrflxAcceptanceMinWait = 500 * time.Millisecond // defaultPrflxAcceptanceMinWait is the wait time before nominating a prflx candidate. defaultPrflxAcceptanceMinWait = 1000 * time.Millisecond // defaultRelayAcceptanceMinWait is the wait time before nominating a relay candidate. defaultRelayAcceptanceMinWait = 2000 * time.Millisecond // defaultRelayOnlyAcceptanceMinWait is the wait time before nominating with a relay only candidate. defaultRelayOnlyAcceptanceMinWait = time.Duration(0) // defaultSTUNGatherTimeout is the wait time for STUN responses. 
defaultSTUNGatherTimeout = 5 * time.Second

	// defaultMaxBindingRequests is the maximum number of binding requests before considering a pair failed.
	defaultMaxBindingRequests = 7

	// TCPPriorityOffset is a number which is subtracted from the default (UDP) candidate type preference
	// for host, srflx and prfx candidate types.
	defaultTCPPriorityOffset = 27

	// maxBufferSize is the number of bytes that can be buffered before we start to error.
	maxBufferSize = 1000 * 1000 // 1MB

	// maxBindingRequestTimeout is the wait time before binding requests can be deleted.
	maxBindingRequestTimeout = 4000 * time.Millisecond
)

// defaultCandidateTypes returns the candidate types gathered when the user
// does not configure any: host, server-reflexive and relay.
func defaultCandidateTypes() []CandidateType {
	return []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive, CandidateTypeRelay}
}

// defaultRelayAcceptanceMinWaitFor returns the relay acceptance wait matching
// the configured candidate types: relay-only configurations skip the wait.
func defaultRelayAcceptanceMinWaitFor(candidateTypes []CandidateType) time.Duration {
	if len(candidateTypes) == 1 && candidateTypes[0] == CandidateTypeRelay {
		return defaultRelayOnlyAcceptanceMinWait
	}

	return defaultRelayAcceptanceMinWait
}

// AgentConfig collects the arguments to ice.Agent construction into
// a single structure, for future-proofness of the interface.
// Deprecated: use NewAgentWithOptions instead.
type AgentConfig struct {
	// Urls lists the STUN/TURN servers used for candidate gathering.
	Urls []*stun.URI

	// PortMin and PortMax are optional. Leave them 0 for the default UDP port allocation strategy.
	PortMin uint16
	PortMax uint16

	// LocalUfrag and LocalPwd values used to perform connectivity
	// checks.  The values MUST be unguessable, with at least 128 bits of
	// random number generator output used to generate the password, and
	// at least 24 bits of output to generate the username fragment.
	LocalUfrag string
	LocalPwd   string

	// MulticastDNSMode controls mDNS behavior for the ICE agent
	MulticastDNSMode MulticastDNSMode

	// MulticastDNSHostName controls the hostname for this agent. If none is specified a random one will be generated
	MulticastDNSHostName string

	// DisconnectedTimeout defaults to 5 seconds when this property is nil.
	// If the duration is 0, the ICE Agent will never go to disconnected
	DisconnectedTimeout *time.Duration

	// FailedTimeout defaults to 25 seconds when this property is nil.
	// If the duration is 0, we will never go to failed.
	FailedTimeout *time.Duration

	// KeepaliveInterval determines how often should we send ICE
	// keepalives (should be less then connectiontimeout above)
	// when this is nil, it defaults to 2 seconds.
	// A keepalive interval of 0 means we never send keepalive packets
	KeepaliveInterval *time.Duration

	// CheckInterval controls how often our task loop runs when in the
	// connecting state.
	CheckInterval *time.Duration

	// NetworkTypes is an optional configuration for disabling or enabling
	// support for specific network types.
	NetworkTypes []NetworkType

	// CandidateTypes is an optional configuration for disabling or enabling
	// support for specific candidate types.
	CandidateTypes []CandidateType

	// LoggerFactory produces the loggers used by the agent.
	LoggerFactory logging.LoggerFactory

	// MaxBindingRequests is the max amount of binding requests the agent will send
	// over a candidate pair for validation or nomination, if after MaxBindingRequests
	// the candidate is yet to answer a binding request or a nomination we set the pair as failed
	MaxBindingRequests *uint16

	// Lite agents do not perform connectivity check and only provide host candidates.
	Lite bool

	// NAT1To1IPCandidateType is used along with NAT1To1IPs to specify which candidate type
	// the 1:1 NAT IP addresses should be mapped to.
	// If unspecified or CandidateTypeHost, NAT1To1IPs are used to replace host candidate IPs.
	// If CandidateTypeServerReflexive, it will insert a srflx candidate (as if it was derived
	// from a STUN server) with its port number being the one for the actual host candidate.
	// Other values will result in an error.
	//
	// Deprecated: use WithAddressRewriteRules with an explicit host or srflx rule instead.
	// This field will be removed in a future major release.
	NAT1To1IPCandidateType CandidateType

	// NAT1To1IPs contains a list of public IP addresses that are to be used as a host
	// candidate or srflx candidate. This is used typically for servers that are behind
	// 1:1 D-NAT (e.g. AWS EC2 instances) and to eliminate the need of server reflexive
	// candidate gathering.
	//
	// Deprecated: use WithAddressRewriteRules with an explicit host or srflx rule instead.
	// This field will be removed in a future major release.
	NAT1To1IPs []string

	// HostAcceptanceMinWait specify a minimum wait time before selecting host candidates
	HostAcceptanceMinWait *time.Duration
	// SrflxAcceptanceMinWait specify a minimum wait time before selecting srflx candidates
	SrflxAcceptanceMinWait *time.Duration
	// PrflxAcceptanceMinWait specify a minimum wait time before selecting prflx candidates
	PrflxAcceptanceMinWait *time.Duration
	// RelayAcceptanceMinWait specify a minimum wait time before selecting relay candidates
	RelayAcceptanceMinWait *time.Duration

	// STUNGatherTimeout specify a minimum wait time for STUN responses
	STUNGatherTimeout *time.Duration

	// Net is the our abstracted network interface for internal development purpose only
	// (see https://github.com/pion/transport)
	Net transport.Net

	// InterfaceFilter is a function that you can use in order to whitelist or blacklist
	// the interfaces which are used to gather ICE candidates.
	InterfaceFilter func(string) (keep bool)

	// IPFilter is a function that you can use in order to whitelist or blacklist
	// the ips which are used to gather ICE candidates.
	IPFilter func(net.IP) (keep bool)

	// InsecureSkipVerify controls if self-signed certificates are accepted when connecting
	// to TURN servers via TLS or DTLS
	InsecureSkipVerify bool

	// TCPMux will be used for multiplexing incoming TCP connections for ICE TCP.
	// Currently only passive candidates are supported. This functionality is
	// experimental and the API might change in the future.
TCPMux TCPMux // UDPMux is used for multiplexing multiple incoming UDP connections on a single port // when this is set, the agent ignores PortMin and PortMax configurations and will // defer to UDPMux for incoming connections UDPMux UDPMux // UDPMuxSrflx is used for multiplexing multiple incoming UDP connections of server reflexive candidates // on a single port when this is set, the agent ignores PortMin and PortMax configurations and will // defer to UDPMuxSrflx for incoming connections // It embeds UDPMux to do the actual connection multiplexing UDPMuxSrflx UniversalUDPMux // Proxy Dialer is a dialer that should be implemented by the user based on golang.org/x/net/proxy // dial interface in order to support corporate proxies ProxyDialer proxy.Dialer // Deprecated: AcceptAggressiveNomination always enabled. AcceptAggressiveNomination bool // Include loopback addresses in the candidate list. IncludeLoopback bool // TCPPriorityOffset is a number which is subtracted from the default (UDP) candidate type preference // for host, srflx and prfx candidate types. It helps to configure relative preference of UDP candidates // against TCP ones. Relay candidates for TCP and UDP are always 0 and not affected by this setting. // When this is nil, defaultTCPPriorityOffset is used. TCPPriorityOffset *uint16 // DisableActiveTCP can be used to disable Active TCP candidates. Otherwise when TCP is enabled // Active TCP candidates will be created when a new passive TCP remote candidate is added. 
DisableActiveTCP bool // BindingRequestHandler allows applications to perform logic on incoming STUN Binding Requests // This was implemented to allow users to // * Log incoming Binding Requests for debugging // * Implement draft-thatcher-ice-renomination // * Implement custom CandidatePair switching logic BindingRequestHandler func(m *stun.Message, local, remote Candidate, pair *CandidatePair) bool // EnableUseCandidateCheckPriority can be used to enable checking for equal or higher priority to // switch selected candidate pair if the peer requests USE-CANDIDATE and agent is a lite agent. // This is disabled by default, i. e. when peer requests USE-CANDIDATE, the selected pair will be // switched to that irrespective of relative priority between current selected pair // and priority of the pair being switched to. EnableUseCandidateCheckPriority bool } // initWithDefaults populates an agent and falls back to defaults if fields are unset. func (config *AgentConfig) initWithDefaults(agent *Agent) { //nolint:cyclop if config.MaxBindingRequests == nil { agent.maxBindingRequests = defaultMaxBindingRequests } else { agent.maxBindingRequests = *config.MaxBindingRequests } if config.HostAcceptanceMinWait == nil { agent.hostAcceptanceMinWait = defaultHostAcceptanceMinWait } else { agent.hostAcceptanceMinWait = *config.HostAcceptanceMinWait } if config.SrflxAcceptanceMinWait == nil { agent.srflxAcceptanceMinWait = defaultSrflxAcceptanceMinWait } else { agent.srflxAcceptanceMinWait = *config.SrflxAcceptanceMinWait } if config.PrflxAcceptanceMinWait == nil { agent.prflxAcceptanceMinWait = defaultPrflxAcceptanceMinWait } else { agent.prflxAcceptanceMinWait = *config.PrflxAcceptanceMinWait } if config.RelayAcceptanceMinWait == nil { agent.relayAcceptanceMinWait = defaultRelayAcceptanceMinWaitFor(config.CandidateTypes) } else { agent.relayAcceptanceMinWait = *config.RelayAcceptanceMinWait } if config.STUNGatherTimeout == nil { agent.stunGatherTimeout = defaultSTUNGatherTimeout } 
else { agent.stunGatherTimeout = *config.STUNGatherTimeout } if config.TCPPriorityOffset == nil { agent.tcpPriorityOffset = defaultTCPPriorityOffset } else { agent.tcpPriorityOffset = *config.TCPPriorityOffset } if config.DisconnectedTimeout == nil { agent.disconnectedTimeout = defaultDisconnectedTimeout } else { agent.disconnectedTimeout = *config.DisconnectedTimeout } if config.FailedTimeout == nil { agent.failedTimeout = defaultFailedTimeout } else { agent.failedTimeout = *config.FailedTimeout } if config.KeepaliveInterval == nil { agent.keepaliveInterval = defaultKeepaliveInterval } else { agent.keepaliveInterval = *config.KeepaliveInterval } if config.CheckInterval == nil { agent.checkInterval = defaultCheckInterval } else { agent.checkInterval = *config.CheckInterval } if len(config.CandidateTypes) == 0 { agent.candidateTypes = defaultCandidateTypes() } else { agent.candidateTypes = config.CandidateTypes } } ice-4.1.0/agent_config_test.go000066400000000000000000000050121511704470200162640ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "testing" "time" "github.com/stretchr/testify/assert" ) func TestAgentConfig_initWithDefaults(t *testing.T) { relayAcceptanceMinWait := 5 * time.Second tests := []struct { name string config *AgentConfig fn func(*testing.T, *Agent) }{ { "default config", &AgentConfig{}, func(t *testing.T, result *Agent) { t.Helper() assert.Equal(t, result.relayAcceptanceMinWait, defaultRelayAcceptanceMinWait) }, }, { "multiple relay candidate types", &AgentConfig{CandidateTypes: []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive}}, func(t *testing.T, result *Agent) { t.Helper() assert.Equal(t, result.relayAcceptanceMinWait, defaultRelayAcceptanceMinWait) }, }, { "host only candidate type", &AgentConfig{CandidateTypes: []CandidateType{CandidateTypeHost}}, func(t *testing.T, result *Agent) { t.Helper() assert.Equal(t, result.relayAcceptanceMinWait, 
defaultRelayAcceptanceMinWait) }, }, { "relay only candidate type", &AgentConfig{CandidateTypes: []CandidateType{CandidateTypeRelay}}, func(t *testing.T, result *Agent) { t.Helper() assert.Equal(t, result.relayAcceptanceMinWait, defaultRelayOnlyAcceptanceMinWait) }, }, { "relay only with relayAcceptanceMinWait set", &AgentConfig{CandidateTypes: []CandidateType{CandidateTypeRelay}, RelayAcceptanceMinWait: &relayAcceptanceMinWait}, func(t *testing.T, result *Agent) { t.Helper() assert.Equal(t, result.relayAcceptanceMinWait, relayAcceptanceMinWait) }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { agent, err := NewAgent(test.config) if !assert.NoError(t, err) { return } defer func() { _ = agent.Close() }() test.fn(t, agent) }) } } func TestDefaultRelayAcceptanceMinWaitForCandidates(t *testing.T) { tests := []struct { name string candidateType []CandidateType expectedWait time.Duration }{ { name: "relay only", candidateType: []CandidateType{CandidateTypeRelay}, expectedWait: defaultRelayOnlyAcceptanceMinWait, }, { name: "mixed types", candidateType: []CandidateType{CandidateTypeHost, CandidateTypeRelay}, expectedWait: defaultRelayAcceptanceMinWait, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { assert.Equal(t, tc.expectedWait, defaultRelayAcceptanceMinWaitFor(tc.candidateType)) }) } } ice-4.1.0/agent_get_best_available_candidate_pair_test.go000066400000000000000000000007211511704470200236240ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "testing" "github.com/stretchr/testify/require" ) func TestNoBestAvailableCandidatePairAfterAgentConstruction(t *testing.T) { agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() require.Nil(t, agent.getBestAvailableCandidatePair()) } 
ice-4.1.0/agent_get_best_valid_candidate_pair_test.go000066400000000000000000000031041511704470200230010ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "testing" "github.com/stretchr/testify/require" ) func TestAgentGetBestValidCandidatePair(t *testing.T) { f := setupTestAgentGetBestValidCandidatePair(t) defer func() { require.NoError(t, f.sut.Close()) }() remoteCandidatesFromLowestPriorityToHighest := []Candidate{f.relayRemote, f.srflxRemote, f.prflxRemote, f.hostRemote} for _, remoteCandidate := range remoteCandidatesFromLowestPriorityToHighest { candidatePair := f.sut.addPair(f.hostLocal, remoteCandidate) candidatePair.state = CandidatePairStateSucceeded actualBestPair := f.sut.getBestValidCandidatePair() expectedBestPair := &CandidatePair{Remote: remoteCandidate, Local: f.hostLocal, state: CandidatePairStateSucceeded} require.Equal(t, actualBestPair.String(), expectedBestPair.String()) } } func setupTestAgentGetBestValidCandidatePair(t *testing.T) *TestAgentGetBestValidCandidatePairFixture { t.Helper() fixture := new(TestAgentGetBestValidCandidatePairFixture) fixture.hostLocal = newHostLocal(t) fixture.relayRemote = newRelayRemote(t) fixture.srflxRemote = newSrflxRemote(t) fixture.prflxRemote = newPrflxRemote(t) fixture.hostRemote = newHostRemote(t) agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) fixture.sut = agent return fixture } type TestAgentGetBestValidCandidatePairFixture struct { sut *Agent hostLocal Candidate relayRemote Candidate srflxRemote Candidate prflxRemote Candidate hostRemote Candidate } ice-4.1.0/agent_handlers.go000066400000000000000000000072741511704470200155740ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import "sync" // OnConnectionStateChange sets a handler that is fired when the connection state changes. 
func (a *Agent) OnConnectionStateChange(f func(ConnectionState)) error { a.onConnectionStateChangeHdlr.Store(f) return nil } // OnSelectedCandidatePairChange sets a handler that is fired when the final candidate. // pair is selected. func (a *Agent) OnSelectedCandidatePairChange(f func(Candidate, Candidate)) error { a.onSelectedCandidatePairChangeHdlr.Store(f) return nil } // OnCandidate sets a handler that is fired when new candidates gathered. When // the gathering process complete the last candidate is nil. func (a *Agent) OnCandidate(f func(Candidate)) error { a.onCandidateHdlr.Store(f) return nil } func (a *Agent) onSelectedCandidatePairChange(p *CandidatePair) { if h, ok := a.onSelectedCandidatePairChangeHdlr.Load().(func(Candidate, Candidate)); ok && h != nil { h(p.Local, p.Remote) } } func (a *Agent) onCandidate(c Candidate) { if onCandidateHdlr, ok := a.onCandidateHdlr.Load().(func(Candidate)); ok && onCandidateHdlr != nil { onCandidateHdlr(c) } } func (a *Agent) onConnectionStateChange(s ConnectionState) { if hdlr, ok := a.onConnectionStateChangeHdlr.Load().(func(ConnectionState)); ok && hdlr != nil { hdlr(s) } } type handlerNotifier struct { sync.Mutex running bool notifiers sync.WaitGroup connectionStates []ConnectionState connectionStateFunc func(ConnectionState) candidates []Candidate candidateFunc func(Candidate) selectedCandidatePairs []*CandidatePair candidatePairFunc func(*CandidatePair) // State for closing done chan struct{} } func (h *handlerNotifier) Close(graceful bool) { if graceful { // if we were closed ungracefully before, we now // want ot wait. 
defer h.notifiers.Wait() } h.Lock() select { case <-h.done: h.Unlock() return default: } close(h.done) h.Unlock() } func (h *handlerNotifier) EnqueueConnectionState(state ConnectionState) { h.Lock() defer h.Unlock() select { case <-h.done: return default: } notify := func() { defer h.notifiers.Done() for { h.Lock() if len(h.connectionStates) == 0 { h.running = false h.Unlock() return } notification := h.connectionStates[0] h.connectionStates = h.connectionStates[1:] h.Unlock() h.connectionStateFunc(notification) } } h.connectionStates = append(h.connectionStates, state) if !h.running { h.running = true h.notifiers.Add(1) go notify() } } func (h *handlerNotifier) EnqueueCandidate(cand Candidate) { h.Lock() defer h.Unlock() select { case <-h.done: return default: } notify := func() { defer h.notifiers.Done() for { h.Lock() if len(h.candidates) == 0 { h.running = false h.Unlock() return } notification := h.candidates[0] h.candidates = h.candidates[1:] h.Unlock() h.candidateFunc(notification) } } h.candidates = append(h.candidates, cand) if !h.running { h.running = true h.notifiers.Add(1) go notify() } } func (h *handlerNotifier) EnqueueSelectedCandidatePair(pair *CandidatePair) { h.Lock() defer h.Unlock() select { case <-h.done: return default: } notify := func() { defer h.notifiers.Done() for { h.Lock() if len(h.selectedCandidatePairs) == 0 { h.running = false h.Unlock() return } notification := h.selectedCandidatePairs[0] h.selectedCandidatePairs = h.selectedCandidatePairs[1:] h.Unlock() h.candidatePairFunc(notification) } } h.selectedCandidatePairs = append(h.selectedCandidatePairs, pair) if !h.running { h.running = true h.notifiers.Add(1) go notify() } } ice-4.1.0/agent_handlers_test.go000066400000000000000000000111261511704470200166220ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "testing" "time" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" 
"github.com/stretchr/testify/require" ) func TestConnectionStateNotifier(t *testing.T) { t.Run("TestManyUpdates", func(t *testing.T) { defer test.CheckRoutines(t)() updates := make(chan struct{}, 1) notifier := &handlerNotifier{ connectionStateFunc: func(_ ConnectionState) { updates <- struct{}{} }, done: make(chan struct{}), } // Enqueue all updates upfront to ensure that it // doesn't block for i := 0; i < 10000; i++ { notifier.EnqueueConnectionState(ConnectionStateNew) } done := make(chan struct{}) go func() { for i := 0; i < 10000; i++ { <-updates } select { case <-updates: t.Errorf("received more updates than expected") // nolint case <-time.After(1 * time.Second): } close(done) }() <-done notifier.Close(true) }) t.Run("TestUpdateOrdering", func(t *testing.T) { defer test.CheckRoutines(t)() updates := make(chan ConnectionState) notifer := &handlerNotifier{ connectionStateFunc: func(cs ConnectionState) { updates <- cs }, done: make(chan struct{}), } done := make(chan struct{}) go func() { for i := 0; i < 10000; i++ { assert.Equal(t, ConnectionState(i), <-updates) } select { case <-updates: t.Errorf("received more updates than expected") // nolint case <-time.After(1 * time.Second): } close(done) }() for i := 0; i < 10000; i++ { notifer.EnqueueConnectionState(ConnectionState(i)) } <-done notifer.Close(true) }) } func TestHandlerNotifier_Close_AlreadyClosed(t *testing.T) { defer test.CheckRoutines(t)() notifier := &handlerNotifier{ connectionStateFunc: func(ConnectionState) {}, candidateFunc: func(Candidate) {}, candidatePairFunc: func(*CandidatePair) {}, done: make(chan struct{}), } // first close notifier.Close(false) isClosed := func(ch <-chan struct{}) bool { select { case <-ch: return true default: return false } } assert.True(t, isClosed(notifier.done), "expected h.done to be closed after first Close") // second close should hit `case <-h.done` and return immediately // without blocking on the WaitGroup. 
finished := make(chan struct{}, 1) go func() { notifier.Close(true) close(finished) }() assert.Eventually(t, func() bool { select { case <-finished: return true default: return false } }, 250*time.Millisecond, 10*time.Millisecond, "second Close(true) did not return promptly") // ensure still closed afterwards assert.True(t, isClosed(notifier.done), "expected h.done to remain closed after second Close") // sanity: no enqueues should start after close. require.False(t, notifier.running) require.Zero(t, len(notifier.connectionStates)) require.Zero(t, len(notifier.candidates)) require.Zero(t, len(notifier.selectedCandidatePairs)) } func TestHandlerNotifier_EnqueueConnectionState_AfterClose(t *testing.T) { defer test.CheckRoutines(t)() connCh := make(chan struct{}, 1) notifier := &handlerNotifier{ connectionStateFunc: func(ConnectionState) { connCh <- struct{}{} }, done: make(chan struct{}), } notifier.Close(false) notifier.EnqueueConnectionState(ConnectionStateConnected) assert.Never(t, func() bool { select { case <-connCh: return true default: return false } }, 250*time.Millisecond, 10*time.Millisecond, "connectionStateFunc should not be called after close") } func TestHandlerNotifier_EnqueueCandidate_AfterClose(t *testing.T) { defer test.CheckRoutines(t)() candidateCh := make(chan struct{}, 1) h := &handlerNotifier{ candidateFunc: func(Candidate) { candidateCh <- struct{}{} }, done: make(chan struct{}), } h.Close(false) h.EnqueueCandidate(nil) assert.Never(t, func() bool { select { case <-candidateCh: return true default: return false } }, 250*time.Millisecond, 10*time.Millisecond, "candidateFunc should not be called after close") } func TestHandlerNotifier_EnqueueSelectedCandidatePair_AfterClose(t *testing.T) { defer test.CheckRoutines(t)() pairCh := make(chan struct{}, 1) h := &handlerNotifier{ candidatePairFunc: func(*CandidatePair) { pairCh <- struct{}{} }, done: make(chan struct{}), } h.Close(false) h.EnqueueSelectedCandidatePair(nil) assert.Never(t, func() bool 
{ select { case <-pairCh: return true default: return false } }, 250*time.Millisecond, 10*time.Millisecond, "candidatePairFunc should not be called after close") } ice-4.1.0/agent_on_selected_candidate_pair_change_test.go000066400000000000000000000022411511704470200236200ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "context" "testing" "github.com/stretchr/testify/require" ) func TestOnSelectedCandidatePairChange(t *testing.T) { agent, candidatePair := fixtureTestOnSelectedCandidatePairChange(t) defer func() { require.NoError(t, agent.Close()) }() callbackCalled := make(chan struct{}, 1) err := agent.OnSelectedCandidatePairChange(func(_, _ Candidate) { close(callbackCalled) }) require.NoError(t, err) err = agent.loop.Run(context.Background(), func(_ context.Context) { agent.setSelectedPair(candidatePair) }) require.NoError(t, err) <-callbackCalled } func fixtureTestOnSelectedCandidatePairChange(t *testing.T) (*Agent, *CandidatePair) { t.Helper() agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) candidatePair := makeCandidatePair(t) return agent, candidatePair } func makeCandidatePair(t *testing.T) *CandidatePair { t.Helper() hostLocal := newHostLocal(t) relayRemote := newRelayRemote(t) candidatePair := newCandidatePair(hostLocal, relayRemote, false) return candidatePair } ice-4.1.0/agent_options.go000066400000000000000000000537201511704470200154640ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT package ice import ( "fmt" "net" "sort" "strings" "sync/atomic" "time" "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3" "golang.org/x/net/proxy" ) // AgentOption represents a function that can be used to configure an Agent. type AgentOption func(*Agent) error // NominationValueGenerator is a function that generates nomination values for renomination. 
type NominationValueGenerator func() uint32 // DefaultNominationValueGenerator returns a generator that starts at 1 and increments for each call. // This provides a simple, monotonically increasing sequence suitable for renomination. func DefaultNominationValueGenerator() NominationValueGenerator { var counter atomic.Uint32 return func() uint32 { return counter.Add(1) } } // WithAddressRewriteRules appends the provided address rewrite (1:1) rules to the agent's // existing configuration. Each `AddressRewriteRule` can limit the mapping to a specific // interface (`Iface`), local address (`Local`), CIDR block (`CIDR`), or subset // of network types (`Networks`), allowing fine-grained control over which local // addresses are replaced with the supplied external IPs. // Use `Mode` to control whether a rule replaces the original candidate (default for // host) or appends additional candidates (default for other types). // // Rules are evaluated in the order they are added; for each candidate type + // local address, explicit `Local` matches win immediately. Otherwise, the most // specific catch-all is chosen (iface+CIDR > iface-only > CIDR-only > global), // with declaration order breaking ties at the same specificity. `Iface` (when // set) must also match. This lets you layer specificity (e.g., iface+CIDR, then // iface-only, then global) while still keeping rule order meaningful. // Overlapping rules in the same scope are logged as warnings. func WithAddressRewriteRules(rules ...AddressRewriteRule) AgentOption { return func(agent *Agent) error { return appendAddressRewriteRules(agent, rules...) 
} } func warnOnAddressRewriteConflicts(agent *Agent) { if agent == nil || agent.log == nil { return } for _, conflict := range findAddressRewriteRuleConflicts(agent.addressRewriteRules) { scope := conflict.scope scopeSummary := fmt.Sprintf( "candidate=%s iface=%s cidr=%s networks=%s local=%s", scope.candidateType.String(), emptyScopeValue(scope.iface), emptyScopeValue(scope.cidr), emptyScopeValue(scope.networksKey), scope.localKey, ) message := fmt.Sprintf( "detected overlapping address rewrite rule (%s): existing external IPs [%s], additional external IP %s", scopeSummary, strings.Join(conflict.existingExternalIPs, ", "), conflict.conflictingExternal, ) agent.log.Warn(message) } } func emptyScopeValue(v string) string { if v == "" { return "*" } return v } func appendAddressRewriteRules(agent *Agent, rules ...AddressRewriteRule) error { if len(rules) == 0 { return nil } sanitized := make([]AddressRewriteRule, 0, len(rules)) for _, rule := range rules { normalized, err := sanitizeAddressRewriteRule(rule) if err != nil { return err } sanitized = append(sanitized, normalized) } agent.addressRewriteRules = append(agent.addressRewriteRules, sanitized...) warnOnAddressRewriteConflicts(agent) return nil } func sanitizeAddressRewriteRule(rule AddressRewriteRule) (AddressRewriteRule, error) { cleaned, err := sanitizeExternalIPs(rule.External) if err != nil { return AddressRewriteRule{}, err } normalized := rule normalized.External = cleaned normalized.Local = strings.TrimSpace(rule.Local) if normalized.Local != "" { if _, _, err := validateIPString(normalized.Local); err != nil { return AddressRewriteRule{}, err } } switch normalized.Mode { case addressRewriteModeUnspecified: normalized.Mode = defaultAddressRewriteMode(normalized.AsCandidateType) case AddressRewriteReplace, AddressRewriteAppend: default: return AddressRewriteRule{}, ErrInvalidNAT1To1IPMapping } if len(rule.Networks) > 0 { normalized.Networks = append([]NetworkType(nil), rule.Networks...) 
} return normalized, nil } func defaultAddressRewriteMode(candidateType CandidateType) AddressRewriteMode { if candidateType == CandidateTypeUnspecified || candidateType == CandidateTypeHost { return AddressRewriteReplace } return AddressRewriteAppend } func sanitizeExternalIPs(ips []string) ([]string, error) { seen := make(map[string]struct{}, len(ips)) sanitized := make([]string, 0, len(ips)) for _, raw := range ips { trimmed := strings.TrimSpace(raw) if trimmed == "" { continue } if _, ok := seen[trimmed]; ok { continue } if strings.Contains(trimmed, "/") { return nil, ErrInvalidNAT1To1IPMapping } if _, _, err := validateIPString(trimmed); err != nil { return nil, err } seen[trimmed] = struct{}{} sanitized = append(sanitized, trimmed) } if len(sanitized) == 0 { return nil, ErrInvalidNAT1To1IPMapping } return sanitized, nil } type addressRewriteScopeKey struct { candidateType CandidateType iface string cidr string networksKey string localKey string } type addressRewriteConflict struct { scope addressRewriteScopeKey existingExternalIPs []string conflictingExternal string } func findAddressRewriteRuleConflicts(rules []AddressRewriteRule) []addressRewriteConflict { conflicts := make([]addressRewriteConflict, 0) scopeState := make(map[addressRewriteScopeKey]map[string]struct{}) for _, rule := range rules { candidateType := rule.AsCandidateType if candidateType == CandidateTypeUnspecified { candidateType = CandidateTypeHost } networksKey := "*" if len(rule.Networks) > 0 { names := make([]string, len(rule.Networks)) for i, network := range rule.Networks { names[i] = network.String() } sort.Strings(names) networksKey = strings.Join(names, ",") } externalEntries := enumerateAddressRewriteExternalEntries(rule) for _, entry := range externalEntries { key := addressRewriteScopeKey{ candidateType: candidateType, iface: rule.Iface, cidr: rule.CIDR, networksKey: networksKey, localKey: entry.localScopeKey, } existing := scopeState[key] if existing == nil { existing = 
make(map[string]struct{}) scopeState[key] = existing } if len(existing) > 0 { if _, ok := existing[entry.externalIP]; !ok { conflicts = append(conflicts, addressRewriteConflict{ scope: key, existingExternalIPs: mapKeys(existing), conflictingExternal: entry.externalIP, }) } } existing[entry.externalIP] = struct{}{} } } return conflicts } type addressRewriteExternalEntry struct { externalIP string localScopeKey string } func enumerateAddressRewriteExternalEntries(rule AddressRewriteRule) []addressRewriteExternalEntry { if len(rule.External) == 0 { return nil } entries := make([]addressRewriteExternalEntry, 0, len(rule.External)) localScope := deriveAddressRewriteLocalScopeKey(rule.Local) for _, mapping := range rule.External { if mapping == "" { continue } external := strings.TrimSpace(mapping) if external == "" { continue } scopeKey := localScope if scopeKey == "" { scopeKey = deriveAddressRewriteFamilyScopeKey(external) } entries = append(entries, addressRewriteExternalEntry{ externalIP: external, localScopeKey: scopeKey, }) } return entries } func deriveAddressRewriteLocalScopeKey(local string) string { local = strings.TrimSpace(local) if local == "" { return "" } ip, _, err := validateIPString(local) if err != nil { return "family:unknown" } if ip.To4() != nil { return "family:ipv4" } return "family:ipv6" } func deriveAddressRewriteFamilyScopeKey(ipStr string) string { ip, _, err := validateIPString(ipStr) if err != nil { return "family:unknown" } if ip.To4() != nil { return "family:ipv4" } return "family:ipv6" } func mapKeys(m map[string]struct{}) []string { if len(m) == 0 { return nil } keys := make([]string, 0, len(m)) for k := range m { keys = append(keys, k) } sort.Strings(keys) return keys } // WithICELite configures whether the agent operates in lite mode. // Lite agents do not perform connectivity checks and only provide host candidates. 
func WithICELite(lite bool) AgentOption { return func(a *Agent) error { a.lite = lite return nil } } // WithUrls sets the STUN/TURN server URLs used by the agent. func WithUrls(urls []*stun.URI) AgentOption { return func(a *Agent) error { if len(urls) == 0 { a.urls = nil return nil } cloned := make([]*stun.URI, len(urls)) copy(cloned, urls) a.urls = cloned return nil } } // WithPortRange sets the UDP port range for host candidates. func WithPortRange(portMin, portMax uint16) AgentOption { return func(a *Agent) error { a.portMin = portMin a.portMax = portMax return nil } } // WithDisconnectedTimeout sets the duration before the agent transitions to disconnected state. // A timeout of 0 disables the transition. func WithDisconnectedTimeout(timeout time.Duration) AgentOption { return func(a *Agent) error { a.disconnectedTimeout = timeout return nil } } // WithFailedTimeout sets the duration before the agent transitions to failed state after disconnected. // A timeout of 0 disables the transition. func WithFailedTimeout(timeout time.Duration) AgentOption { return func(a *Agent) error { a.failedTimeout = timeout return nil } } // WithKeepaliveInterval sets how often ICE keepalive packets are sent. // An interval of 0 disables keepalives. func WithKeepaliveInterval(interval time.Duration) AgentOption { return func(a *Agent) error { a.keepaliveInterval = interval return nil } } // WithHostAcceptanceMinWait sets the minimum wait before selecting host candidates. func WithHostAcceptanceMinWait(wait time.Duration) AgentOption { return func(a *Agent) error { a.hostAcceptanceMinWait = wait return nil } } // WithSrflxAcceptanceMinWait sets the minimum wait before selecting srflx candidates. func WithSrflxAcceptanceMinWait(wait time.Duration) AgentOption { return func(a *Agent) error { a.srflxAcceptanceMinWait = wait return nil } } // WithPrflxAcceptanceMinWait sets the minimum wait before selecting prflx candidates. 
func WithPrflxAcceptanceMinWait(wait time.Duration) AgentOption { return func(a *Agent) error { a.prflxAcceptanceMinWait = wait return nil } } // WithRelayAcceptanceMinWait sets the minimum wait before selecting relay candidates. func WithRelayAcceptanceMinWait(wait time.Duration) AgentOption { return func(a *Agent) error { a.relayAcceptanceMinWait = wait return nil } } // WithSTUNGatherTimeout sets the STUN gather timeout. func WithSTUNGatherTimeout(timeout time.Duration) AgentOption { return func(a *Agent) error { a.stunGatherTimeout = timeout return nil } } // WithIPFilter sets a filter for IP addresses used during candidate gathering. func WithIPFilter(filter func(net.IP) bool) AgentOption { return func(a *Agent) error { a.ipFilter = filter return nil } } // WithNet sets the underlying network implementation for the agent. func WithNet(net transport.Net) AgentOption { return func(a *Agent) error { a.net = net return nil } } // WithMulticastDNSMode configures mDNS behavior for the agent. func WithMulticastDNSMode(mode MulticastDNSMode) AgentOption { return func(a *Agent) error { a.mDNSMode = mode return nil } } // WithMulticastDNSHostName sets the mDNS host name used by the agent. func WithMulticastDNSHostName(hostName string) AgentOption { return func(a *Agent) error { if !strings.HasSuffix(hostName, ".local") || len(strings.Split(hostName, ".")) != 2 { return ErrInvalidMulticastDNSHostName } a.mDNSName = hostName return nil } } // WithLocalCredentials sets the local ICE username fragment and password used during Restart. // If empty strings are provided, the agent will generate values during Restart. func WithLocalCredentials(ufrag, pwd string) AgentOption { return func(a *Agent) error { if ufrag != "" && len([]rune(ufrag))*8 < 24 { return ErrLocalUfragInsufficientBits } if pwd != "" && len([]rune(pwd))*8 < 128 { return ErrLocalPwdInsufficientBits } a.localUfrag = ufrag a.localPwd = pwd return nil } } // WithTCPMux sets the TCP mux for ICE TCP multiplexing. 
func WithTCPMux(tcpMux TCPMux) AgentOption { return func(a *Agent) error { a.tcpMux = tcpMux return nil } } // WithUDPMux sets the UDP mux used for multiplexing host candidates. func WithUDPMux(udpMux UDPMux) AgentOption { return func(a *Agent) error { a.udpMux = udpMux return nil } } // WithUDPMuxSrflx sets the UDP mux for server reflexive candidates. func WithUDPMuxSrflx(udpMuxSrflx UniversalUDPMux) AgentOption { return func(a *Agent) error { a.udpMuxSrflx = udpMuxSrflx return nil } } // WithProxyDialer sets the proxy dialer used for TURN over TCP/TLS/DTLS connections. func WithProxyDialer(dialer proxy.Dialer) AgentOption { return func(a *Agent) error { a.proxyDialer = dialer return nil } } // WithMaxBindingRequests sets the maximum number of binding requests before considering a pair failed. func WithMaxBindingRequests(limit uint16) AgentOption { return func(a *Agent) error { a.maxBindingRequests = limit return nil } } // WithCheckInterval sets how often the agent runs connectivity checks while connecting. func WithCheckInterval(interval time.Duration) AgentOption { return func(a *Agent) error { a.checkInterval = interval return nil } } // WithRenomination enables ICE renomination as described in draft-thatcher-ice-renomination-01. // When enabled, the controlling agent can renominate candidate pairs multiple times // and the controlled agent follows "last nomination wins" rule. // // The generator parameter specifies how nomination values are generated. // Use DefaultNominationValueGenerator() for a simple incrementing counter, // or provide a custom generator for more complex scenarios. 
// // Example: // // agent, err := NewAgentWithOptions(config, WithRenomination(DefaultNominationValueGenerator())) func WithRenomination(generator NominationValueGenerator) AgentOption { return func(a *Agent) error { if generator == nil { return ErrInvalidNominationValueGenerator } a.enableRenomination = true a.nominationValueGenerator = generator return nil } } // WithNominationAttribute sets the STUN attribute type to use for ICE renomination. // The default value is 0x0030. This can be configured until the attribute is officially // assigned by IANA for draft-thatcher-ice-renomination. // // This option returns an error if the provided attribute type is invalid. // Currently, validation ensures the attribute is not 0x0000 (reserved). // Additional validation may be added in the future. func WithNominationAttribute(attrType uint16) AgentOption { return func(a *Agent) error { // Basic validation: ensure it's not the reserved 0x0000 if attrType == 0x0000 { return ErrInvalidNominationAttribute } a.nominationAttribute = stun.AttrType(attrType) return nil } } // WithIncludeLoopback includes loopback addresses in the candidate list. // By default, loopback addresses are excluded. // // Example: // // agent, err := NewAgentWithOptions(WithIncludeLoopback()) func WithIncludeLoopback() AgentOption { return func(a *Agent) error { a.includeLoopback = true return nil } } // WithTCPPriorityOffset sets a number which is subtracted from the default (UDP) candidate type preference // for host, srflx and prfx candidate types. It helps to configure relative preference of UDP candidates // against TCP ones. Relay candidates for TCP and UDP are always 0 and not affected by this setting. // When not set, defaultTCPPriorityOffset (27) is used. 
// // Example: // // agent, err := NewAgentWithOptions(WithTCPPriorityOffset(50)) func WithTCPPriorityOffset(offset uint16) AgentOption { return func(a *Agent) error { a.tcpPriorityOffset = offset return nil } } // WithDisableActiveTCP disables Active TCP candidates. // When TCP is enabled, Active TCP candidates will be created when a new passive TCP remote candidate is added // unless this option is used. // // Example: // // agent, err := NewAgentWithOptions(WithDisableActiveTCP()) func WithDisableActiveTCP() AgentOption { return func(a *Agent) error { a.disableActiveTCP = true return nil } } // WithBindingRequestHandler sets a handler to allow applications to perform logic on incoming STUN Binding Requests. // This was implemented to allow users to: // - Log incoming Binding Requests for debugging // - Implement draft-thatcher-ice-renomination // - Implement custom CandidatePair switching logic // // Example: // // handler := func(m *stun.Message, local, remote Candidate, pair *CandidatePair) bool { // log.Printf("Binding request from %s to %s", remote.Address(), local.Address()) // return true // Accept the request // } // agent, err := NewAgentWithOptions(WithBindingRequestHandler(handler)) func WithBindingRequestHandler( handler func(m *stun.Message, local, remote Candidate, pair *CandidatePair) bool, ) AgentOption { return func(a *Agent) error { a.userBindingRequestHandler = handler return nil } } // WithEnableUseCandidateCheckPriority enables checking for equal or higher priority when // switching selected candidate pair if the peer requests USE-CANDIDATE and agent is a lite agent. // This is disabled by default, i.e. when peer requests USE-CANDIDATE, the selected pair will be // switched to that irrespective of relative priority between current selected pair // and priority of the pair being switched to. 
// // Example: // // agent, err := NewAgentWithOptions(WithEnableUseCandidateCheckPriority()) func WithEnableUseCandidateCheckPriority() AgentOption { return func(a *Agent) error { a.enableUseCandidateCheckPriority = true return nil } } // WithContinualGatheringPolicy sets the continual gathering policy for the agent. // When set to GatherContinually, the agent will continuously monitor network interfaces // and gather new candidates as they become available. // When set to GatherOnce (default), gathering completes after the initial phase. // // Example: // // agent, err := NewAgentWithOptions(WithContinualGatheringPolicy(GatherContinually)) func WithContinualGatheringPolicy(policy ContinualGatheringPolicy) AgentOption { return func(a *Agent) error { a.continualGatheringPolicy = policy return nil } } // WithNetworkMonitorInterval sets the interval at which the agent checks for network interface changes // when using GatherContinually policy. This option only has effect when used with // WithContinualGatheringPolicy(GatherContinually). // Default is 2 seconds if not specified. // // Example: // // agent, err := NewAgentWithOptions( // WithContinualGatheringPolicy(GatherContinually), // WithNetworkMonitorInterval(5 * time.Second), // ) func WithNetworkMonitorInterval(interval time.Duration) AgentOption { return func(a *Agent) error { if interval <= 0 { return ErrInvalidNetworkMonitorInterval } a.networkMonitorInterval = interval return nil } } // WithNetworkTypes sets the enabled network types for candidate gathering. // By default, all network types are enabled. // // Example: // // agent, err := NewAgentWithOptions( // WithNetworkTypes([]NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}), // ) func WithNetworkTypes(networkTypes []NetworkType) AgentOption { return func(a *Agent) error { a.networkTypes = networkTypes return nil } } // WithCandidateTypes sets the enabled candidate types for gathering. 
// By default, host, server reflexive, and relay candidates are enabled. // // Example: // // agent, err := NewAgentWithOptions( // WithCandidateTypes([]CandidateType{CandidateTypeHost, CandidateTypeServerReflexive}), // ) func WithCandidateTypes(candidateTypes []CandidateType) AgentOption { return func(a *Agent) error { a.candidateTypes = candidateTypes return nil } } // WithAutomaticRenomination enables automatic renomination of candidate pairs // when better pairs become available after initial connection establishment. // This feature requires renomination to be enabled and both agents to support it. // // When enabled, the controlling agent will periodically evaluate candidate pairs // and renominate if a significantly better pair is found (e.g., switching from // relay to direct connection, or when RTT improves significantly). // // The interval parameter specifies the minimum time to wait after connection // before considering automatic renomination. If set to 0, it defaults to 3 seconds. // // Example: // // agent, err := NewAgentWithOptions( // WithRenomination(DefaultNominationValueGenerator()), // WithAutomaticRenomination(3*time.Second), // ) func WithAutomaticRenomination(interval time.Duration) AgentOption { return func(a *Agent) error { a.automaticRenomination = true if interval > 0 { a.renominationInterval = interval } // Note: renomination must be enabled separately via WithRenomination return nil } } // WithInterfaceFilter sets a filter function to whitelist or blacklist network interfaces // for ICE candidate gathering. // // The filter function receives the interface name and should return true to keep the interface, // or false to exclude it. 
//
// Example:
//
//	// Only use interfaces starting with "eth"
//	agent, err := NewAgentWithOptions(
//		WithInterfaceFilter(func(interfaceName string) bool {
//			return len(interfaceName) >= 3 && interfaceName[:3] == "eth"
//		}),
//	)
func WithInterfaceFilter(filter func(string) bool) AgentOption {
	return func(a *Agent) error {
		a.interfaceFilter = filter
		return nil
	}
}

// WithLoggerFactory sets the logger factory for the agent.
//
// Example:
//
//	import "github.com/pion/logging"
//
//	loggerFactory := logging.NewDefaultLoggerFactory()
//	loggerFactory.DefaultLogLevel = logging.LogLevelDebug
//	agent, err := NewAgentWithOptions(WithLoggerFactory(loggerFactory))
func WithLoggerFactory(loggerFactory logging.LoggerFactory) AgentOption {
	return func(a *Agent) error {
		// All agent log lines are emitted under the "ice" scope.
		a.log = loggerFactory.NewLogger("ice")
		return nil
	}
}
ice-4.1.0/agent_options_test.go000066400000000000000000001326761511704470200165220ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2025 The Pion community
// SPDX-License-Identifier: MIT

package ice

import (
	"errors"
	"fmt"
	"net"
	"net/netip"
	"strings"
	"testing"
	"time"

	"github.com/pion/logging"
	"github.com/pion/stun/v3"
	"github.com/pion/transport/v3"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"golang.org/x/net/proxy"
)

// testBooleanOption is a helper function to test boolean agent options.
// booleanOptionTest describes one boolean agent option: how to construct the
// option, how to read the resulting Agent field, and how to set the equivalent
// AgentConfig field.
type booleanOptionTest struct {
	optionFunc   func() AgentOption
	getValue     func(*Agent) bool
	configSetter func(*AgentConfig, bool)
}

// testBooleanOption exercises a boolean option three ways: via the option, via
// the default (false), and via the legacy AgentConfig path.
func testBooleanOption(t *testing.T, test booleanOptionTest, optionName string) {
	t.Helper()
	t.Run("enables "+optionName, func(t *testing.T) {
		agent, err := NewAgentWithOptions(test.optionFunc())
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.True(t, test.getValue(agent))
	})
	t.Run("default is false", func(t *testing.T) {
		agent, err := NewAgentWithOptions()
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.False(t, test.getValue(agent))
	})
	t.Run("works with config", func(t *testing.T) {
		config := &AgentConfig{
			NetworkTypes: []NetworkType{NetworkTypeUDP4},
		}
		test.configSetter(config, true)
		agent, err := NewAgent(config)
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.True(t, test.getValue(agent))
	})
}

func TestDefaultNominationValueGenerator(t *testing.T) {
	t.Run("generates incrementing values", func(t *testing.T) {
		generator := DefaultNominationValueGenerator()
		// Should generate incrementing values starting from 1
		assert.Equal(t, uint32(1), generator())
		assert.Equal(t, uint32(2), generator())
		assert.Equal(t, uint32(3), generator())
	})
	t.Run("each generator has independent counter", func(t *testing.T) {
		gen1 := DefaultNominationValueGenerator()
		gen2 := DefaultNominationValueGenerator()
		assert.Equal(t, uint32(1), gen1())
		assert.Equal(t, uint32(1), gen2()) // Should also start at 1
		assert.Equal(t, uint32(2), gen1())
		assert.Equal(t, uint32(2), gen2())
	})
}

func TestWithLite(t *testing.T) {
	t.Run("enables lite with host candidates", func(t *testing.T) {
		agent, err := NewAgentWithOptions(
			WithCandidateTypes([]CandidateType{CandidateTypeHost}),
			WithICELite(true),
		)
		require.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.True(t, agent.lite)
	})
	t.Run("default is not lite", func(t *testing.T) {
		agent, err := NewAgentWithOptions()
		require.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.False(t, agent.lite)
	})
	t.Run("config sets lite", func(t *testing.T) {
		config := &AgentConfig{
			Lite:           true,
			CandidateTypes: []CandidateType{CandidateTypeHost},
			NetworkTypes:   []NetworkType{NetworkTypeUDP4},
		}
		agent, err := NewAgent(config)
		require.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.True(t, agent.lite)
	})
	t.Run("errors when candidate types include non-host", func(t *testing.T) {
		_, err := NewAgentWithOptions(WithICELite(true))
		assert.ErrorIs(t, err, ErrLiteUsingNonHostCandidates)
	})
}

// TestWithUrls also verifies that the agent keeps its own copy of the slice:
// mutating the caller's slice after construction must not affect the agent.
func TestWithUrls(t *testing.T) {
	stunURL, err := stun.ParseURI("stun:example.com:3478")
	require.NoError(t, err)
	input := []*stun.URI{stunURL}
	agent, err := NewAgentWithOptions(WithUrls(input))
	require.NoError(t, err)
	defer agent.Close() //nolint:errcheck
	require.Len(t, agent.urls, 1)
	assert.Equal(t, stunURL.String(), agent.urls[0].String())
	input[0] = nil
	require.Len(t, agent.urls, 1)
	assert.NotNil(t, agent.urls[0])
}

func TestWithPortRange(t *testing.T) {
	agent, err := NewAgentWithOptions(WithPortRange(1000, 2000))
	require.NoError(t, err)
	assert.Equal(t, uint16(1000), agent.portMin)
	assert.Equal(t, uint16(2000), agent.portMax)
	agent.Close() //nolint:gosec,errcheck
	agent, err = NewAgentWithOptions(WithPortRange(2000, 0))
	assert.NoError(t, err)
	defer agent.Close() //nolint:gosec,errcheck
	assert.Equal(t, uint16(2000), agent.portMin)
	assert.Equal(t, uint16(0), agent.portMax)
}

func TestWithTimeoutOptions(t *testing.T) {
	agent, err := NewAgentWithOptions(
		WithDisconnectedTimeout(10*time.Second),
		WithFailedTimeout(20*time.Second),
		WithKeepaliveInterval(3*time.Second),
		WithCheckInterval(150*time.Millisecond),
	)
	require.NoError(t, err)
	defer agent.Close() //nolint:errcheck
	assert.Equal(t, 10*time.Second, agent.disconnectedTimeout)
	assert.Equal(t, 20*time.Second, agent.failedTimeout)
	assert.Equal(t, 3*time.Second, agent.keepaliveInterval)
	assert.Equal(t, 150*time.Millisecond, agent.checkInterval)
}

func TestWithAcceptanceWaitOptions(t *testing.T) {
	agent, err := NewAgentWithOptions(
		WithHostAcceptanceMinWait(1*time.Second),
		WithSrflxAcceptanceMinWait(2*time.Second),
		WithPrflxAcceptanceMinWait(3*time.Second),
		WithRelayAcceptanceMinWait(4*time.Second),
	)
	require.NoError(t, err)
	defer agent.Close() //nolint:errcheck
	assert.Equal(t, 1*time.Second, agent.hostAcceptanceMinWait)
	assert.Equal(t, 2*time.Second, agent.srflxAcceptanceMinWait)
	assert.Equal(t, 3*time.Second, agent.prflxAcceptanceMinWait)
	assert.Equal(t, 4*time.Second, agent.relayAcceptanceMinWait)
}

func TestWithSTUNGatherTimeout(t *testing.T) {
	agent, err := NewAgentWithOptions(WithSTUNGatherTimeout(7 * time.Second))
	require.NoError(t, err)
	defer agent.Close() //nolint:errcheck
	assert.Equal(t, 7*time.Second, agent.stunGatherTimeout)
}

func TestWithIPFilterOption(t *testing.T) {
	filter := func(ip net.IP) bool { return ip.IsLoopback() }
	agent, err := NewAgentWithOptions(WithIPFilter(filter))
	require.NoError(t, err)
	defer agent.Close() //nolint:errcheck
	require.NotNil(t, agent.ipFilter)
	assert.True(t, agent.ipFilter(net.IPv4(127, 0, 0, 1)))
	assert.False(t, agent.ipFilter(net.IPv4(192, 0, 2, 1)))
}

func TestWithNetOption(t *testing.T) {
	stub := newStubNet(t)
	agent, err := NewAgentWithOptions(WithNet(stub))
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()
	assert.Equal(t, stub, agent.net)
}

func TestWithMulticastDNSOptions(t *testing.T) {
	agent, err := NewAgentWithOptions(
		WithMulticastDNSMode(MulticastDNSModeDisabled),
		WithMulticastDNSHostName("pion-test.local"),
	)
	require.NoError(t, err)
	defer agent.Close() //nolint:errcheck
	assert.Equal(t, MulticastDNSModeDisabled, agent.mDNSMode)
	assert.Equal(t, "pion-test.local", agent.mDNSName)
	_, err = NewAgentWithOptions(WithMulticastDNSHostName("invalid-host"))
	assert.ErrorIs(t, err, ErrInvalidMulticastDNSHostName)
}

func TestWithLocalCredentials(t *testing.T) {
	// 16 chars * 8 bits = 128 bits, the minimum accepted password length.
	password := strings.Repeat("p", 16)
	agent, err := NewAgentWithOptions(WithLocalCredentials("abcd", password))
	require.NoError(t, err)
	defer agent.Close() //nolint:errcheck
	assert.Equal(t, "abcd", agent.localUfrag)
	assert.Equal(t, password, agent.localPwd)
	_, err = NewAgentWithOptions(WithLocalCredentials("ab", password))
	assert.ErrorIs(t, err, ErrLocalUfragInsufficientBits)
	shortPassword := strings.Repeat("p", 10)
	_, err = NewAgentWithOptions(WithLocalCredentials("abcd", shortPassword))
	assert.ErrorIs(t, err, ErrLocalPwdInsufficientBits)
}

func TestWithMuxOptions(t *testing.T) {
	tcpMux := &stubTCPMux{}
	udpMux := &stubUDPMux{}
	udpMuxSrflx := &stubUniversalUDPMux{}
	agent, err := NewAgentWithOptions(
		WithTCPMux(tcpMux),
		WithUDPMux(udpMux),
		WithUDPMuxSrflx(udpMuxSrflx),
	)
	require.NoError(t, err)
	defer agent.Close() //nolint:errcheck
	assert.Equal(t, tcpMux, agent.tcpMux)
	assert.Equal(t, udpMux, agent.udpMux)
	assert.Equal(t, udpMuxSrflx, agent.udpMuxSrflx)
}

func TestWithProxyDialer(t *testing.T) {
	agent, err := NewAgentWithOptions(WithProxyDialer(proxy.Direct))
	require.NoError(t, err)
	defer agent.Close() //nolint:errcheck
	assert.Equal(t, proxy.Direct, agent.proxyDialer)
}

func TestWithMaxBindingRequests(t *testing.T) {
	agent, err := NewAgentWithOptions(WithMaxBindingRequests(3))
	require.NoError(t, err)
	defer agent.Close() //nolint:errcheck
	assert.Equal(t, uint16(3), agent.maxBindingRequests)
}

func TestWithRenomination(t *testing.T) {
	t.Run("enables renomination with custom generator", func(t *testing.T) {
		counter := uint32(0)
		customGen := func() uint32 {
			counter++

			return counter * 10
		}
		agent, err := NewAgentWithOptions(WithRenomination(customGen))
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.True(t, agent.enableRenomination)
		assert.NotNil(t, agent.nominationValueGenerator)
		assert.Equal(t, uint32(10), agent.getNominationValue())
		assert.Equal(t, uint32(20), agent.getNominationValue())
	})
	t.Run("enables renomination with default generator", func(t *testing.T) {
		agent, err := NewAgentWithOptions(WithRenomination(DefaultNominationValueGenerator()))
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.True(t, agent.enableRenomination)
		assert.NotNil(t, agent.nominationValueGenerator)
		assert.Equal(t, uint32(1), agent.getNominationValue())
		assert.Equal(t, uint32(2), agent.getNominationValue())
	})
	t.Run("rejects nil generator", func(t *testing.T) {
		_, err := NewAgentWithOptions(WithRenomination(nil))
		assert.ErrorIs(t, err, ErrInvalidNominationValueGenerator)
	})
	t.Run("default agent has renomination disabled", func(t *testing.T) {
		config := &AgentConfig{
			NetworkTypes: []NetworkType{NetworkTypeUDP4},
		}
		agent, err := NewAgent(config)
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.False(t, agent.enableRenomination)
		assert.Nil(t, agent.nominationValueGenerator)
		assert.Equal(t, uint32(0), agent.getNominationValue())
	})
}

func TestWithNominationAttribute(t *testing.T) {
	t.Run("sets custom nomination attribute", func(t *testing.T) {
		agent, err := NewAgentWithOptions(WithNominationAttribute(0x0045))
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.Equal(t, stun.AttrType(0x0045), agent.nominationAttribute)
	})
	t.Run("rejects invalid attribute 0x0000", func(t *testing.T) {
		_, err := NewAgentWithOptions(WithNominationAttribute(0x0000))
		assert.ErrorIs(t, err, ErrInvalidNominationAttribute)
	})
	t.Run("default value when no option", func(t *testing.T) {
		config := &AgentConfig{
			NetworkTypes: []NetworkType{NetworkTypeUDP4},
		}
		agent, err := NewAgent(config)
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		// Should use default value 0x0030
		assert.Equal(t, stun.AttrType(0x0030), agent.nominationAttribute)
	})
}

func TestWithIncludeLoopback(t *testing.T) {
	testBooleanOption(t, booleanOptionTest{
		optionFunc:   WithIncludeLoopback,
		getValue:     func(a *Agent) bool { return a.includeLoopback },
		configSetter: func(c *AgentConfig, v bool) { c.IncludeLoopback = v },
	}, "loopback addresses")
}

func TestWithTCPPriorityOffset(t *testing.T) {
	t.Run("sets custom TCP priority offset",
		func(t *testing.T) {
			customOffset := uint16(50)
			agent, err := NewAgentWithOptions(WithTCPPriorityOffset(customOffset))
			assert.NoError(t, err)
			defer agent.Close() //nolint:errcheck
			assert.Equal(t, customOffset, agent.tcpPriorityOffset)
		})
	t.Run("default is 27", func(t *testing.T) {
		agent, err := NewAgentWithOptions()
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.Equal(t, uint16(27), agent.tcpPriorityOffset)
	})
	t.Run("works with config", func(t *testing.T) {
		customOffset := uint16(100)
		config := &AgentConfig{
			NetworkTypes:      []NetworkType{NetworkTypeUDP4},
			TCPPriorityOffset: &customOffset,
		}
		agent, err := NewAgent(config)
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.Equal(t, customOffset, agent.tcpPriorityOffset)
	})
}

func TestWithDisableActiveTCP(t *testing.T) {
	testBooleanOption(t, booleanOptionTest{
		optionFunc:   WithDisableActiveTCP,
		getValue:     func(a *Agent) bool { return a.disableActiveTCP },
		configSetter: func(c *AgentConfig, v bool) { c.DisableActiveTCP = v },
	}, "active TCP disabling")
}

func TestWithBindingRequestHandler(t *testing.T) {
	t.Run("sets binding request handler", func(t *testing.T) {
		handlerCalled := false
		handler := func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool {
			handlerCalled = true

			return true
		}
		agent, err := NewAgentWithOptions(WithBindingRequestHandler(handler))
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.NotNil(t, agent.userBindingRequestHandler)
		// Test that the handler is actually the one we set
		// We can't directly compare functions, but we can call it
		if agent.userBindingRequestHandler != nil {
			agent.userBindingRequestHandler(nil, nil, nil, nil)
			assert.True(t, handlerCalled)
		}
	})
	t.Run("default is nil", func(t *testing.T) {
		agent, err := NewAgentWithOptions()
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.Nil(t, agent.userBindingRequestHandler)
	})
	t.Run("works with config", func(t *testing.T) {
		handlerCalled := false
		handler := func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool {
			handlerCalled = true

			return true
		}
		config := &AgentConfig{
			NetworkTypes:          []NetworkType{NetworkTypeUDP4},
			BindingRequestHandler: handler,
		}
		agent, err := NewAgent(config)
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.NotNil(t, agent.userBindingRequestHandler)
		if agent.userBindingRequestHandler != nil {
			agent.userBindingRequestHandler(nil, nil, nil, nil)
			assert.True(t, handlerCalled)
		}
	})
}

func TestWithEnableUseCandidateCheckPriority(t *testing.T) {
	testBooleanOption(t, booleanOptionTest{
		optionFunc:   WithEnableUseCandidateCheckPriority,
		getValue:     func(a *Agent) bool { return a.enableUseCandidateCheckPriority },
		configSetter: func(c *AgentConfig, v bool) { c.EnableUseCandidateCheckPriority = v },
	}, "use candidate check priority")
}

func TestMultipleConfigOptions(t *testing.T) {
	t.Run("can apply multiple options", func(t *testing.T) {
		customOffset := uint16(100)
		handlerCalled := false
		handler := func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool {
			handlerCalled = true

			return true
		}
		agent, err := NewAgentWithOptions(
			WithIncludeLoopback(),
			WithTCPPriorityOffset(customOffset),
			WithDisableActiveTCP(),
			WithBindingRequestHandler(handler),
			WithEnableUseCandidateCheckPriority(),
		)
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.True(t, agent.includeLoopback)
		assert.Equal(t, customOffset, agent.tcpPriorityOffset)
		assert.True(t, agent.disableActiveTCP)
		assert.NotNil(t, agent.userBindingRequestHandler)
		assert.True(t, agent.enableUseCandidateCheckPriority)
		if agent.userBindingRequestHandler != nil {
			agent.userBindingRequestHandler(nil, nil, nil, nil)
			assert.True(t, handlerCalled)
		}
	})
}

func TestWithInterfaceFilter(t *testing.T) {
	t.Run("sets interface filter", func(t *testing.T) {
		filter := func(interfaceName string) bool { return interfaceName == "eth0" }
		agent, err := NewAgentWithOptions(WithInterfaceFilter(filter))
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.NotNil(t, agent.interfaceFilter)
		assert.True(t, agent.interfaceFilter("eth0"))
		assert.False(t, agent.interfaceFilter("wlan0"))
	})
	t.Run("default is nil", func(t *testing.T) {
		agent, err := NewAgentWithOptions()
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.Nil(t, agent.interfaceFilter)
	})
	t.Run("works with config", func(t *testing.T) {
		filter := func(interfaceName string) bool { return interfaceName == "lo" }
		config := &AgentConfig{
			NetworkTypes:    []NetworkType{NetworkTypeUDP4},
			InterfaceFilter: filter,
		}
		agent, err := NewAgent(config)
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.NotNil(t, agent.interfaceFilter)
		assert.True(t, agent.interfaceFilter("lo"))
		assert.False(t, agent.interfaceFilter("eth0"))
	})
}

func TestWithLoggerFactory(t *testing.T) {
	t.Run("sets logger factory", func(t *testing.T) {
		loggerFactory := logging.NewDefaultLoggerFactory()
		loggerFactory.DefaultLogLevel = logging.LogLevelDebug
		agent, err := NewAgentWithOptions(WithLoggerFactory(loggerFactory))
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.NotNil(t, agent.log)
	})
	t.Run("default uses default logger", func(t *testing.T) {
		agent, err := NewAgentWithOptions()
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.NotNil(t, agent.log)
	})
	t.Run("works with config", func(t *testing.T) {
		loggerFactory := logging.NewDefaultLoggerFactory()
		config := &AgentConfig{
			NetworkTypes:  []NetworkType{NetworkTypeUDP4},
			LoggerFactory: loggerFactory,
		}
		agent, err := NewAgent(config)
		assert.NoError(t, err)
		defer agent.Close() //nolint:errcheck
		assert.NotNil(t, agent.log)
	})
}

func TestWithNetworkTypesAppliedBeforeRestart(t *testing.T) {
	t.Run("ipv6 listen skipped when network types option restricts to ipv4", func(t *testing.T) {
		stub := newStubNet(t)
		agent, err := newAgentFromConfig(&AgentConfig{
			Net: stub,
		}, WithNetworkTypes([]NetworkType{NetworkTypeUDP4}))
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()
		assert.Zero(t, stub.udp6ListenCount, "unexpected ipv6 listen before restart")
	})
}

func TestWithCandidateTypesAffectsURLValidation(t *testing.T) {
	stunURL, err := stun.ParseURI("stun:example.com:3478")
	require.NoError(t, err)
	t.Run("default candidate types accept urls", func(t *testing.T) {
		stub := newStubNet(t)
		agent, err := newAgentFromConfig(&AgentConfig{
			Urls: []*stun.URI{stunURL},
			Net:  stub,
		})
		require.NoError(t, err)
		require.NoError(t, agent.Close())
	})
	t.Run("host only candidate types reject urls", func(t *testing.T) {
		stub := newStubNet(t)
		_, err := newAgentFromConfig(&AgentConfig{
			Urls: []*stun.URI{stunURL},
			Net:  stub,
		}, WithCandidateTypes([]CandidateType{CandidateTypeHost}))
		require.ErrorIs(t, err, ErrUselessUrlsProvided)
	})
}

func TestWithCandidateTypesNAT1To1Validation(t *testing.T) {
	t.Run("host mapping requires host candidates", func(t *testing.T) {
		stub := newStubNet(t)
		_, err := newAgentFromConfig(&AgentConfig{
			NAT1To1IPs:             []string{"1.2.3.4"},
			NAT1To1IPCandidateType: CandidateTypeHost,
			Net:                    stub,
		}, WithCandidateTypes([]CandidateType{CandidateTypeRelay}))
		require.ErrorIs(t, err, ErrIneffectiveNAT1To1IPMappingHost)
	})
	t.Run("srflx mapping requires srflx candidates", func(t *testing.T) {
		stub := newStubNet(t)
		_, err := newAgentFromConfig(&AgentConfig{
			NAT1To1IPs:             []string{"1.2.3.4"},
			NAT1To1IPCandidateType: CandidateTypeServerReflexive,
			Net:                    stub,
		}, WithCandidateTypes([]CandidateType{CandidateTypeHost}))
		require.ErrorIs(t, err, ErrIneffectiveNAT1To1IPMappingSrflx)
	})
}

func TestWith1To1CandidateIPOptions(t *testing.T) {
	testCases := []struct {
		name             string
		rules            []AddressRewriteRule
		candidateType    CandidateType
		expectedFirstIP  string
		expectedSecondIP string
		lookupLocalIP    string
	}{
		{
			name: "host candidates",
			rules: []AddressRewriteRule{
				{
					External:        []string{"1.2.3.4"},
					AsCandidateType: CandidateTypeHost,
				},
				{
					External:        []string{"5.6.7.8"},
					AsCandidateType: CandidateTypeHost,
				},
			},
			candidateType: CandidateTypeHost,
expectedFirstIP: "1.2.3.4", expectedSecondIP: "5.6.7.8", lookupLocalIP: "10.0.0.1", }, { name: "srflx candidates", rules: []AddressRewriteRule{ { External: []string{"5.6.7.8"}, AsCandidateType: CandidateTypeServerReflexive, }, { External: []string{"9.9.9.9"}, AsCandidateType: CandidateTypeServerReflexive, }, }, candidateType: CandidateTypeServerReflexive, expectedFirstIP: "5.6.7.8", expectedSecondIP: "9.9.9.9", lookupLocalIP: "0.0.0.0", }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { assertAddressRewriteOption( t, tc.rules, tc.candidateType, tc.expectedFirstIP, tc.expectedSecondIP, tc.lookupLocalIP, ) }) } } func assertAddressRewriteOption( t *testing.T, rules []AddressRewriteRule, candidateType CandidateType, expectedFirstIP string, expectedSecondIP string, lookupLocalIP string, ) { t.Helper() stub := newStubNet(t) agent, err := NewAgentWithOptions( WithNet(stub), WithAddressRewriteRules(rules...), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.Len(t, agent.addressRewriteRules, len(rules)) firstRule := agent.addressRewriteRules[0] require.Equal(t, candidateType, firstRule.AsCandidateType) require.Equal(t, []string{expectedFirstIP}, firstRule.External) secondRule := agent.addressRewriteRules[1] require.Equal(t, candidateType, secondRule.AsCandidateType) require.Equal(t, []string{expectedSecondIP}, secondRule.External) require.NotNil(t, agent.addressRewriteMapper) extIP := requireFirstExternalIP(t, agent.addressRewriteMapper, candidateType, lookupLocalIP) require.Equal(t, expectedFirstIP, extIP.String()) } func requireFirstExternalIP( t *testing.T, mapper *addressRewriteMapper, candidateType CandidateType, localIP string, ) net.IP { t.Helper() ips, matched, _, err := mapper.findExternalIPs(candidateType, localIP, "") require.NoError(t, err) require.True(t, matched) require.NotEmpty(t, ips) return ips[0] } func requireFirstMappingIP(t *testing.T, mapping *ipMapping, localIP net.IP) net.IP { 
t.Helper() ips := mapping.findExternalIPs(localIP) require.NotEmpty(t, ips) return ips[0] } func TestWith1To1RulesOption(t *testing.T) { stub := newStubNet(t) originalRules := []AddressRewriteRule{ { External: []string{"9.9.9.9"}, AsCandidateType: CandidateTypeHost, }, } // With append semantics the option stacks, so call twice and ensure accumulation. agent, err := NewAgentWithOptions( WithNet(stub), WithAddressRewriteRules(originalRules...), WithAddressRewriteRules(AddressRewriteRule{ External: []string{"4.4.4.4"}, AsCandidateType: CandidateTypeServerReflexive, }), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.Len(t, agent.addressRewriteRules, 2) require.Equal(t, []string{"9.9.9.9"}, agent.addressRewriteRules[0].External) require.Equal(t, []string{"4.4.4.4"}, agent.addressRewriteRules[1].External) // mutate the original rules after option applied, agent copy should remain unchanged. originalRules[0].External[0] = "0.0.0.0" require.Equal(t, "9.9.9.9", agent.addressRewriteRules[0].External[0]) } func TestWith1To1RulesEmptyNoop(t *testing.T) { stub := newStubNet(t) agent, err := NewAgentWithOptions( WithNet(stub), WithAddressRewriteRules(AddressRewriteRule{ External: []string{"1.2.3.4"}, AsCandidateType: CandidateTypeHost, }), WithAddressRewriteRules(), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.Len(t, agent.addressRewriteRules, 1) require.Equal(t, []string{"1.2.3.4"}, agent.addressRewriteRules[0].External) require.Equal(t, CandidateTypeHost, agent.addressRewriteRules[0].AsCandidateType) require.NotNil(t, agent.addressRewriteMapper) } func TestWithAddressRewriteRulesWarnOnConflicts(t *testing.T) { stub := newStubNet(t) logger := &recordingLogger{} factory := &recordingLoggerFactory{logger: logger} agent, err := NewAgentWithOptions( WithNet(stub), WithLoggerFactory(factory), WithAddressRewriteRules(AddressRewriteRule{ External: []string{"203.0.113.10"}, AsCandidateType: 
CandidateTypeHost, }), WithAddressRewriteRules(AddressRewriteRule{ External: []string{"198.51.100.50"}, AsCandidateType: CandidateTypeHost, }), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.Len(t, logger.warnings, 1) require.Contains(t, logger.warnings[0], "overlapping address rewrite rule") require.Contains(t, logger.warnings[0], "candidate=host") require.Contains(t, logger.warnings[0], "iface=*") require.Contains(t, logger.warnings[0], "networks=*") require.Contains(t, logger.warnings[0], "local=family:ipv4") require.Contains(t, logger.warnings[0], "203.0.113.10") require.Contains(t, logger.warnings[0], "198.51.100.50") ips, matched, mode, err := agent.addressRewriteMapper.findExternalIPs(CandidateTypeHost, "10.0.0.1", "") require.NoError(t, err) require.True(t, matched) require.Equal(t, AddressRewriteReplace, mode) require.NotEmpty(t, ips) require.Equal(t, "203.0.113.10", ips[0].String()) } func TestWithAddressRewriteRulesConflictingModesWarningAndPrecedence(t *testing.T) { stub := newStubNet(t) logger := &recordingLogger{} factory := &recordingLoggerFactory{logger: logger} agent, err := NewAgentWithOptions( WithNet(stub), WithLoggerFactory(factory), WithAddressRewriteRules( AddressRewriteRule{ External: []string{"203.0.113.10"}, AsCandidateType: CandidateTypeHost, Mode: AddressRewriteReplace, }, AddressRewriteRule{ External: []string{"198.51.100.50"}, AsCandidateType: CandidateTypeHost, Mode: AddressRewriteAppend, }, ), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.NotEmpty(t, logger.warnings) require.Contains(t, logger.warnings[0], "overlapping address rewrite rule") ips, matched, mode, err := agent.addressRewriteMapper.findExternalIPs(CandidateTypeHost, "10.0.0.1", "") require.NoError(t, err) require.True(t, matched) require.NotEmpty(t, ips) require.Equal(t, "203.0.113.10", ips[0].String()) require.Equal(t, AddressRewriteReplace, mode) } func 
TestWithAddressRewriteRulesNoFalsePositiveConflicts(t *testing.T) { stub := newStubNet(t) logger := &recordingLogger{} factory := &recordingLoggerFactory{logger: logger} agent, err := NewAgentWithOptions( WithNet(stub), WithLoggerFactory(factory), WithAddressRewriteRules( AddressRewriteRule{ External: []string{"203.0.113.10"}, AsCandidateType: CandidateTypeHost, Networks: []NetworkType{NetworkTypeUDP4}, }, AddressRewriteRule{ External: []string{"2001:db8::10"}, AsCandidateType: CandidateTypeHost, Networks: []NetworkType{NetworkTypeUDP6}, }, ), WithAddressRewriteRules( AddressRewriteRule{ External: []string{"198.51.100.10"}, AsCandidateType: CandidateTypeServerReflexive, }, ), ) assert.NoError(t, err) if agent != nil { t.Cleanup(func() { assert.NoError(t, agent.Close()) }) } assert.Empty(t, logger.warnings) } func TestLegacyAndNewAddressRewriteOrdering(t *testing.T) { stub := newStubNet(t) agent, err := newAgentFromConfig( &AgentConfig{ Net: stub, NAT1To1IPs: []string{"203.0.113.10"}, }, WithAddressRewriteRules( AddressRewriteRule{ External: []string{"198.51.100.5"}, AsCandidateType: CandidateTypeHost, }, ), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.Len(t, agent.addressRewriteRules, 2) extIP := requireFirstExternalIP(t, agent.addressRewriteMapper, CandidateTypeHost, "10.0.0.1") require.Equal(t, "203.0.113.10", extIP.String()) } func TestLegacyNAT1To1TranslationOrder(t *testing.T) { stub := newStubNet(t) agent, err := NewAgent(&AgentConfig{ Net: stub, NAT1To1IPs: []string{ "203.0.113.1/10.0.0.1", "203.0.113.2", }, }) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.Len(t, agent.addressRewriteRules, 2) firstRule := agent.addressRewriteRules[0] require.Equal(t, "203.0.113.1", firstRule.External[0]) require.Equal(t, "10.0.0.1", firstRule.Local) secondRule := agent.addressRewriteRules[1] require.Equal(t, "203.0.113.2", secondRule.External[0]) require.Empty(t, secondRule.Local) extIP := 
requireFirstExternalIP(t, agent.addressRewriteMapper, CandidateTypeHost, "10.0.0.2") require.Equal(t, "203.0.113.2", extIP.String()) } func TestLegacyAddressRewriteParityWithRules(t *testing.T) { t.Run("host candidate parity", func(t *testing.T) { legacyStub := newStubNet(t) legacyAgent, err := NewAgent(&AgentConfig{ Net: legacyStub, NAT1To1IPs: []string{ "203.0.113.10", "198.51.100.20/10.0.0.20", }, }) assert.NoError(t, err) if legacyAgent != nil { t.Cleanup(func() { assert.NoError(t, legacyAgent.Close()) }) } modernStub := newStubNet(t) modernAgent, err := NewAgentWithOptions( WithNet(modernStub), WithAddressRewriteRules( AddressRewriteRule{ External: []string{"203.0.113.10"}, AsCandidateType: CandidateTypeHost, }, ), WithAddressRewriteRules( AddressRewriteRule{ External: []string{"198.51.100.20"}, Local: "10.0.0.20", AsCandidateType: CandidateTypeHost, }, ), ) assert.NoError(t, err) if modernAgent != nil { t.Cleanup(func() { assert.NoError(t, modernAgent.Close()) }) } for _, loc := range []string{"10.0.0.20", "10.0.0.21"} { legacyIPs, legacyMatched, _, legacyErr := legacyAgent.addressRewriteMapper.findExternalIPs( CandidateTypeHost, loc, "", ) assert.NoError(t, legacyErr) assert.True(t, legacyMatched) assert.NotEmpty(t, legacyIPs) modernIPs, modernMatched, _, modernErr := modernAgent.addressRewriteMapper.findExternalIPs( CandidateTypeHost, loc, "", ) assert.NoError(t, modernErr) assert.True(t, modernMatched) assert.NotEmpty(t, modernIPs) assert.Equal(t, legacyIPs[0].String(), modernIPs[0].String()) } }) t.Run("srflx candidate parity", func(t *testing.T) { legacyStub := newStubNet(t) legacyAgent, err := NewAgent(&AgentConfig{ Net: legacyStub, NAT1To1IPs: []string{"198.51.100.77"}, NAT1To1IPCandidateType: CandidateTypeServerReflexive, }) assert.NoError(t, err) if legacyAgent != nil { t.Cleanup(func() { assert.NoError(t, legacyAgent.Close()) }) } modernStub := newStubNet(t) modernAgent, err := NewAgentWithOptions( WithNet(modernStub), WithAddressRewriteRules( 
AddressRewriteRule{ External: []string{"198.51.100.77"}, AsCandidateType: CandidateTypeServerReflexive, }, ), ) assert.NoError(t, err) if modernAgent != nil { t.Cleanup(func() { assert.NoError(t, modernAgent.Close()) }) } legacyIPs, legacyMatched, _, legacyErr := legacyAgent.addressRewriteMapper.findExternalIPs( CandidateTypeServerReflexive, "0.0.0.0", "", ) assert.NoError(t, legacyErr) assert.True(t, legacyMatched) assert.NotEmpty(t, legacyIPs) modernIPs, modernMatched, _, modernErr := modernAgent.addressRewriteMapper.findExternalIPs( CandidateTypeServerReflexive, "0.0.0.0", "", ) assert.NoError(t, modernErr) assert.True(t, modernMatched) assert.NotEmpty(t, modernIPs) assert.Equal(t, legacyIPs[0].String(), modernIPs[0].String()) }) } func TestOverlapWarningPerCandidateType(t *testing.T) { stub := newStubNet(t) logger := &recordingLogger{} factory := &recordingLoggerFactory{logger: logger} agent, err := NewAgentWithOptions( WithNet(stub), WithLoggerFactory(factory), WithAddressRewriteRules( AddressRewriteRule{ External: []string{"203.0.113.10"}, AsCandidateType: CandidateTypeHost, }, ), WithAddressRewriteRules( AddressRewriteRule{ External: []string{"198.51.100.10"}, AsCandidateType: CandidateTypeServerReflexive, }, AddressRewriteRule{ External: []string{"198.51.100.20"}, AsCandidateType: CandidateTypeServerReflexive, }, ), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.Len(t, logger.warnings, 1) require.Contains(t, logger.warnings[0], "candidate=srflx") } func TestWithNAT1To1IPValidation(t *testing.T) { t.Run("dedupe and trim host IPs", func(t *testing.T) { stub := newStubNet(t) agent, err := NewAgentWithOptions( WithNet(stub), WithAddressRewriteRules(AddressRewriteRule{ External: []string{" 203.0.113.1 ", "203.0.113.1", "203.0.113.2 "}, AsCandidateType: CandidateTypeHost, }), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.Len(t, agent.addressRewriteRules, 1) require.Equal(t, 
[]string{"203.0.113.1", "203.0.113.2"}, agent.addressRewriteRules[0].External) }) t.Run("reject hostname entry", func(t *testing.T) { stub := newStubNet(t) agent, err := NewAgentWithOptions( WithNet(stub), WithAddressRewriteRules(AddressRewriteRule{ External: []string{"example.com"}, AsCandidateType: CandidateTypeHost, }), ) require.Nil(t, agent) require.ErrorIs(t, err, ErrInvalidNAT1To1IPMapping) }) t.Run("reject slash mapping in address rewrite rules", func(t *testing.T) { stub := newStubNet(t) agent, err := NewAgentWithOptions( WithNet(stub), WithAddressRewriteRules(AddressRewriteRule{ External: []string{"203.0.113.1/10.0.0.1"}, AsCandidateType: CandidateTypeHost, }), ) require.Nil(t, agent) require.ErrorIs(t, err, ErrInvalidNAT1To1IPMapping) }) t.Run("reject invalid rule entry", func(t *testing.T) { stub := newStubNet(t) agent, err := NewAgentWithOptions( WithNet(stub), WithAddressRewriteRules(AddressRewriteRule{ External: []string{"1.2.3.4", "bad-ip"}, AsCandidateType: CandidateTypeHost, }), ) require.Nil(t, agent) require.ErrorIs(t, err, ErrInvalidNAT1To1IPMapping) }) } func TestWithAddressRewriteRulesIPv6(t *testing.T) { stub := newStubNet(t) agent, err := NewAgentWithOptions( WithNet(stub), WithAddressRewriteRules(AddressRewriteRule{ External: []string{"2001:db8::2"}, Local: "2001:db8:1::2", AsCandidateType: CandidateTypeHost, Networks: []NetworkType{NetworkTypeUDP6}, }), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.Len(t, agent.addressRewriteRules, 1) require.Equal(t, []NetworkType{NetworkTypeUDP6}, agent.addressRewriteRules[0].Networks) require.NotNil(t, agent.addressRewriteMapper) mappings := agent.addressRewriteMapper.rulesByCandidateType[CandidateTypeHost] require.Len(t, mappings, 1) require.True(t, mappings[0].ipv6Mapping.valid) _, ok := mappings[0].ipv6Mapping.ipMap["2001:db8:1::2"] require.True(t, ok) for key := range mappings[0].ipv6Mapping.ipMap { t.Logf("stored ipv6 mapping key: %q", key) } locIP := 
net.ParseIP("2001:db8:1::2") require.NotNil(t, locIP) t.Logf("parsed ipv6 string: %q", locIP.String()) directExt := requireFirstMappingIP(t, &mappings[0].ipv6Mapping, locIP) require.Equal(t, "2001:db8::2", directExt.String()) mapper, err := newAddressRewriteMapper(agent.addressRewriteRules) require.NoError(t, err) extIP := requireFirstExternalIP(t, mapper, CandidateTypeHost, "2001:db8:1::2") require.Equal(t, "2001:db8::2", extIP.String()) _, matched, _, err := mapper.findExternalIPs(CandidateTypeHost, "2001:db8:1::3", "") require.NoError(t, err) require.False(t, matched) } func TestAddressRewriteRulesRejectWithMDNSQueryAndGather(t *testing.T) { agent := &Agent{ candidateTypes: []CandidateType{CandidateTypeHost}, mDNSMode: MulticastDNSModeQueryAndGather, addressRewriteRules: []AddressRewriteRule{ { External: []string{"203.0.113.200"}, AsCandidateType: CandidateTypeHost, }, }, log: logging.NewDefaultLoggerFactory().NewLogger("test"), } err := applyAddressRewriteMapping(agent) assert.ErrorIs(t, err, ErrMulticastDNSWithNAT1To1IPMapping) } func TestAgentAddressRewriteModeIntegration(t *testing.T) { stub := newStubNet(t) t.Run("defaults host replace srflx append", func(t *testing.T) { agent, err := NewAgentWithOptions( WithNet(stub), WithAddressRewriteRules( AddressRewriteRule{ External: []string{"203.0.113.10"}, AsCandidateType: CandidateTypeHost, }, AddressRewriteRule{ External: []string{"198.51.100.10"}, Local: "0.0.0.0", AsCandidateType: CandidateTypeServerReflexive, }, ), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.Len(t, agent.addressRewriteRules, 2) assert.Equal(t, AddressRewriteReplace, agent.addressRewriteRules[0].Mode) assert.Equal(t, AddressRewriteAppend, agent.addressRewriteRules[1].Mode) }) t.Run("host append honored", func(t *testing.T) { agent, err := NewAgentWithOptions( WithNet(stub), WithAddressRewriteRules( AddressRewriteRule{ External: []string{"203.0.113.20"}, Local: "10.0.0.5", AsCandidateType: 
CandidateTypeHost, Mode: AddressRewriteAppend, }, ), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.Len(t, agent.addressRewriteRules, 1) assert.Equal(t, AddressRewriteAppend, agent.addressRewriteRules[0].Mode) }) t.Run("srflx replace honored", func(t *testing.T) { agent, err := NewAgentWithOptions( WithNet(stub), WithAddressRewriteRules( AddressRewriteRule{ External: []string{"198.51.100.50"}, Local: "0.0.0.0", AsCandidateType: CandidateTypeServerReflexive, Mode: AddressRewriteReplace, }, ), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.Len(t, agent.addressRewriteRules, 1) assert.Equal(t, AddressRewriteReplace, agent.addressRewriteRules[0].Mode) }) } func TestAddressRewriteModeOverrides(t *testing.T) { t.Run("host append preserves local candidate", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"203.0.113.99"}, Local: "10.0.0.99", AsCandidateType: CandidateTypeHost, Mode: AddressRewriteAppend, }, }) assert.NoError(t, err) agent := &Agent{ addressRewriteMapper: mapper, log: logging.NewDefaultLoggerFactory().NewLogger("test"), } local := netip.MustParseAddr("10.0.0.99") mapped, ok := agent.applyHostAddressRewrite(local, []netip.Addr{local}, "") assert.True(t, ok) assert.Equal(t, []netip.Addr{ local, netip.MustParseAddr("203.0.113.99"), }, mapped) }) t.Run("srflx replace overrides default append", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"198.51.100.99"}, Local: "0.0.0.0", AsCandidateType: CandidateTypeServerReflexive, Mode: AddressRewriteReplace, }, }) assert.NoError(t, err) agent := &Agent{ addressRewriteMapper: mapper, log: logging.NewDefaultLoggerFactory().NewLogger("test"), } assert.True(t, agent.addressRewriteMapper.shouldReplace(CandidateTypeServerReflexive)) ips, matched, mode, err := agent.addressRewriteMapper.findExternalIPs( CandidateTypeServerReflexive, 
"0.0.0.0", "", ) assert.NoError(t, err) assert.True(t, matched) assert.Equal(t, AddressRewriteReplace, mode) assert.Equal(t, "198.51.100.99", ips[0].String()) }) } func TestAddressRewriteMixedFamilyApplication(t *testing.T) { logger := logging.NewDefaultLoggerFactory().NewLogger("test") mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"203.0.113.123"}, Local: "2001:db8::123", AsCandidateType: CandidateTypeHost, Mode: AddressRewriteReplace, }, }) assert.NoError(t, err) agent := &Agent{ addressRewriteMapper: mapper, log: logger, } local := netip.MustParseAddr("2001:db8::123") mapped, ok := agent.applyHostAddressRewrite(local, []netip.Addr{local}, "") assert.True(t, ok) assert.Len(t, mapped, 1) assert.Equal(t, netip.MustParseAddr("203.0.113.123"), mapped[0]) } type recordingLogger struct { warnings []string } func (l *recordingLogger) Trace(string) {} func (l *recordingLogger) Tracef(string, ...any) {} func (l *recordingLogger) Debug(string) {} func (l *recordingLogger) Debugf(string, ...any) {} func (l *recordingLogger) Info(string) {} func (l *recordingLogger) Infof(string, ...any) {} func (l *recordingLogger) Warn(msg string) { l.warnings = append(l.warnings, msg) } func (l *recordingLogger) Warnf(format string, args ...any) { l.warnings = append(l.warnings, fmt.Sprintf(format, args...)) } func (l *recordingLogger) Error(string) {} func (l *recordingLogger) Errorf(string, ...any) {} type recordingLoggerFactory struct { logger *recordingLogger } func (f *recordingLoggerFactory) NewLogger(string) logging.LeveledLogger { return f.logger } func TestAgentConfigNAT1To1IPs(t *testing.T) { testCases := []struct { name string config AgentConfig candidateType CandidateType localIP string expectedIP string }{ { name: "host candidate default type", config: AgentConfig{ NAT1To1IPs: []string{"1.2.3.4"}, }, candidateType: CandidateTypeHost, localIP: "10.0.0.1", expectedIP: "1.2.3.4", }, { name: "srflx candidate explicit type", config: AgentConfig{ 
NAT1To1IPs: []string{"5.6.7.8"}, NAT1To1IPCandidateType: CandidateTypeServerReflexive, }, candidateType: CandidateTypeServerReflexive, localIP: "0.0.0.0", expectedIP: "5.6.7.8", }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { stub := newStubNet(t) config := tc.config config.Net = stub agent, err := NewAgent(&config) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.NotNil(t, agent.addressRewriteMapper) extIP := requireFirstExternalIP(t, agent.addressRewriteMapper, tc.candidateType, tc.localIP) require.Equal(t, tc.expectedIP, extIP.String()) }) } t.Run("deprecated multiple config IPs reject", func(t *testing.T) { stub := newStubNet(t) //nolint:godox // TODO: remove once AgentConfig.NAT1To1IPs is deprecated. agent, err := NewAgent(&AgentConfig{ Net: stub, NAT1To1IPs: []string{"1.2.3.4", "5.6.7.8"}, }) require.ErrorIs(t, err, ErrInvalidNAT1To1IPMapping) require.Nil(t, agent) }) t.Run("legacy config allows slash pair syntax", func(t *testing.T) { stub := newStubNet(t) agent, err := NewAgent(&AgentConfig{ Net: stub, NAT1To1IPs: []string{"203.0.113.20/10.0.0.20"}, }) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) extIP := requireFirstExternalIP(t, agent.addressRewriteMapper, CandidateTypeHost, "10.0.0.20") require.Equal(t, "203.0.113.20", extIP.String()) }) } var errStubNotImplemented = errors.New("stub not implemented") type stubTCPMux struct{} func (m *stubTCPMux) Close() error { return nil } func (m *stubTCPMux) GetConnByUfrag(string, bool, net.IP) (net.PacketConn, error) { return nil, errStubNotImplemented } func (m *stubTCPMux) RemoveConnByUfrag(string) {} type stubUDPMux struct{} func (m *stubUDPMux) Close() error { return nil } func (m *stubUDPMux) GetConn(string, net.Addr) (net.PacketConn, error) { return nil, errStubNotImplemented } func (m *stubUDPMux) RemoveConnByUfrag(string) {} func (m *stubUDPMux) GetListenAddresses() []net.Addr { return nil } type 
stubUniversalUDPMux struct { stubUDPMux } func (m *stubUniversalUDPMux) GetXORMappedAddr(net.Addr, time.Duration) (*stun.XORMappedAddress, error) { return nil, errStubNotImplemented } func (m *stubUniversalUDPMux) GetRelayedAddr(net.Addr, time.Duration) (*net.Addr, error) { return nil, errStubNotImplemented } func (m *stubUniversalUDPMux) GetConnForURL(string, string, net.Addr) (net.PacketConn, error) { return nil, errStubNotImplemented } type stubNet struct { t *testing.T udp6ListenCount int } func newStubNet(t *testing.T) *stubNet { t.Helper() return &stubNet{t: t} } func (n *stubNet) ListenPacket(network, address string) (net.PacketConn, error) { return nil, transport.ErrNotSupported } func (n *stubNet) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) { if network == "udp6" { n.udp6ListenCount++ } return nil, fmt.Errorf("stub net does not listen on %s", network) //nolint:err113 } func (n *stubNet) ListenTCP(network string, laddr *net.TCPAddr) (transport.TCPListener, error) { return nil, transport.ErrNotSupported } func (n *stubNet) Dial(network, address string) (net.Conn, error) { return nil, transport.ErrNotSupported } func (n *stubNet) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) { return nil, transport.ErrNotSupported } func (n *stubNet) DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) { return nil, transport.ErrNotSupported } func (n *stubNet) ResolveIPAddr(network, address string) (*net.IPAddr, error) { return net.ResolveIPAddr(network, address) } func (n *stubNet) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) { return net.ResolveUDPAddr(network, address) } func (n *stubNet) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) { return net.ResolveTCPAddr(network, address) } func (n *stubNet) Interfaces() ([]*transport.Interface, error) { iface := transport.NewInterface(net.Interface{ Index: 1, MTU: 1500, Name: "stub0", Flags: net.FlagUp, }) 
iface.AddAddress(&net.IPNet{ IP: net.IPv4(192, 0, 2, 1), Mask: net.CIDRMask(24, 32), }) return []*transport.Interface{iface}, nil } func (n *stubNet) InterfaceByIndex(index int) (*transport.Interface, error) { ifaces, err := n.Interfaces() if err != nil { return nil, err } for _, iface := range ifaces { if iface.Index == index { return iface, nil } } return nil, transport.ErrInterfaceNotFound } func (n *stubNet) InterfaceByName(name string) (*transport.Interface, error) { ifaces, err := n.Interfaces() if err != nil { return nil, err } for _, iface := range ifaces { if iface.Name == name { return iface, nil } } return nil, transport.ErrInterfaceNotFound } func (n *stubNet) CreateDialer(dialer *net.Dialer) transport.Dialer { return nil } ice-4.1.0/agent_stats.go000066400000000000000000000133671511704470200151320ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "context" "time" ) // GetCandidatePairsStats returns a list of candidate pair stats. 
func (a *Agent) GetCandidatePairsStats() []CandidatePairStats { var res []CandidatePairStats err := a.loop.Run(a.loop, func(_ context.Context) { result := make([]CandidatePairStats, 0, len(a.checklist)) for _, cp := range a.checklist { stat := CandidatePairStats{ Timestamp: time.Now(), LocalCandidateID: cp.Local.ID(), RemoteCandidateID: cp.Remote.ID(), State: cp.state, Nominated: cp.nominated, PacketsSent: cp.PacketsSent(), PacketsReceived: cp.PacketsReceived(), BytesSent: cp.BytesSent(), BytesReceived: cp.BytesReceived(), LastPacketSentTimestamp: cp.LastPacketSentAt(), LastPacketReceivedTimestamp: cp.LastPacketReceivedAt(), FirstRequestTimestamp: cp.FirstRequestSentAt(), LastRequestTimestamp: cp.LastRequestSentAt(), FirstResponseTimestamp: cp.FirstResponseReceivedAt(), LastResponseTimestamp: cp.LastResponseReceivedAt(), FirstRequestReceivedTimestamp: cp.FirstRequestReceivedAt(), LastRequestReceivedTimestamp: cp.LastRequestReceivedAt(), TotalRoundTripTime: cp.TotalRoundTripTime(), CurrentRoundTripTime: cp.CurrentRoundTripTime(), // AvailableOutgoingBitrate float64 // AvailableIncomingBitrate float64 // CircuitBreakerTriggerCount uint32 RequestsReceived: cp.RequestsReceived(), RequestsSent: cp.RequestsSent(), ResponsesReceived: cp.ResponsesReceived(), ResponsesSent: cp.ResponsesSent(), // RetransmissionsReceived uint64 // RetransmissionsSent uint64 // ConsentRequestsSent uint64 // ConsentExpiredTimestamp time.Time } result = append(result, stat) } res = result }) if err != nil { a.log.Errorf("Failed to get candidate pairs stats: %v", err) return []CandidatePairStats{} } return res } // GetSelectedCandidatePairStats returns a candidate pair stats for selected candidate pair. // Returns false if there is no selected pair. 
func (a *Agent) GetSelectedCandidatePairStats() (CandidatePairStats, bool) { isAvailable := false var res CandidatePairStats err := a.loop.Run(a.loop, func(_ context.Context) { sp := a.getSelectedPair() if sp == nil { return } isAvailable = true res = CandidatePairStats{ Timestamp: time.Now(), LocalCandidateID: sp.Local.ID(), RemoteCandidateID: sp.Remote.ID(), State: sp.state, Nominated: sp.nominated, PacketsSent: sp.PacketsSent(), PacketsReceived: sp.PacketsReceived(), BytesSent: sp.BytesSent(), BytesReceived: sp.BytesReceived(), LastPacketSentTimestamp: sp.LastPacketSentAt(), LastPacketReceivedTimestamp: sp.LastPacketReceivedAt(), // FirstRequestTimestamp time.Time // LastRequestTimestamp time.Time // LastResponseTimestamp time.Time TotalRoundTripTime: sp.TotalRoundTripTime(), CurrentRoundTripTime: sp.CurrentRoundTripTime(), // AvailableOutgoingBitrate float64 // AvailableIncomingBitrate float64 // CircuitBreakerTriggerCount uint32 // RequestsReceived uint64 // RequestsSent uint64 ResponsesReceived: sp.ResponsesReceived(), // ResponsesSent uint64 // RetransmissionsReceived uint64 // RetransmissionsSent uint64 // ConsentRequestsSent uint64 // ConsentExpiredTimestamp time.Time } }) if err != nil { a.log.Errorf("Failed to get selected candidate pair stats: %v", err) return CandidatePairStats{}, false } return res, isAvailable } // GetLocalCandidatesStats returns a list of local candidates stats. 
func (a *Agent) GetLocalCandidatesStats() []CandidateStats { var res []CandidateStats err := a.loop.Run(a.loop, func(_ context.Context) { result := make([]CandidateStats, 0, len(a.localCandidates)) for networkType, localCandidates := range a.localCandidates { for _, cand := range localCandidates { relayProtocol := "" if cand.Type() == CandidateTypeRelay { if cRelay, ok := cand.(*CandidateRelay); ok { relayProtocol = cRelay.RelayProtocol() } } stat := CandidateStats{ Timestamp: time.Now(), ID: cand.ID(), NetworkType: networkType, IP: cand.Address(), Port: cand.Port(), CandidateType: cand.Type(), Priority: cand.Priority(), // URL string RelayProtocol: relayProtocol, // Deleted bool } result = append(result, stat) } } res = result }) if err != nil { a.log.Errorf("Failed to get candidate pair stats: %v", err) return []CandidateStats{} } return res } // GetRemoteCandidatesStats returns a list of remote candidates stats. func (a *Agent) GetRemoteCandidatesStats() []CandidateStats { var res []CandidateStats err := a.loop.Run(a.loop, func(_ context.Context) { result := make([]CandidateStats, 0, len(a.remoteCandidates)) for networkType, remoteCandidates := range a.remoteCandidates { for _, c := range remoteCandidates { stat := CandidateStats{ Timestamp: time.Now(), ID: c.ID(), NetworkType: networkType, IP: c.Address(), Port: c.Port(), CandidateType: c.Type(), Priority: c.Priority(), // URL string RelayProtocol: "", } result = append(result, stat) } } res = result }) if err != nil { a.log.Errorf("Failed to get candidate pair stats: %v", err) return []CandidateStats{} } return res } ice-4.1.0/agent_test.go000066400000000000000000001764101511704470200147520ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "context" "net" "strconv" "sync" "testing" "time" "github.com/pion/ice/v4/internal/fakenet" "github.com/pion/logging" "github.com/pion/stun/v3" 
"github.com/pion/transport/v3/test" "github.com/pion/transport/v3/vnet" "github.com/stretchr/testify/require" ) type BadAddr struct{} func (ba *BadAddr) Network() string { return "xxx" } func (ba *BadAddr) String() string { return "yyy" } func TestHandlePeerReflexive(t *testing.T) { //nolint:cyclop defer test.CheckRoutines(t)() // Limit runtime in case of deadlocks defer test.TimeOut(time.Second * 2).Stop() t.Run("UDP prflx candidate from handleInbound()", func(t *testing.T) { agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) { agent.selector = &controllingSelector{agent: agent, log: agent.log} hostConfig := CandidateHostConfig{ Network: "udp", Address: "192.168.0.2", Port: 777, Component: 1, } local, err := NewCandidateHost(&hostConfig) local.conn = &fakenet.MockPacketConn{} require.NoError(t, err) remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999} msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag), UseCandidate(), AttrControlling(agent.tieBreaker), PriorityAttr(local.Priority()), stun.NewShortTermIntegrity(agent.localPwd), stun.Fingerprint, ) require.NoError(t, err) // nolint: contextcheck agent.handleInbound(msg, local, remote) // Length of remote candidate list must be one now require.Len(t, agent.remoteCandidates, 1) // Length of remote candidate list for a network type must be 1 set := agent.remoteCandidates[local.NetworkType()] require.Len(t, set, 1) c := set[0] require.Equal(t, CandidateTypePeerReflexive, c.Type()) require.Equal(t, "172.17.0.3", c.Address()) require.Equal(t, 999, c.Port()) })) }) t.Run("Bad network type with handleInbound()", func(t *testing.T) { agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() require.NoError(t, agent.loop.Run(agent.loop, func(_ 
context.Context) { agent.selector = &controllingSelector{agent: agent, log: agent.log} hostConfig := CandidateHostConfig{ Network: "tcp", Address: "192.168.0.2", Port: 777, Component: 1, } local, err := NewCandidateHost(&hostConfig) require.NoError(t, err) remote := &BadAddr{} // nolint: contextcheck agent.handleInbound(nil, local, remote) require.Len(t, agent.remoteCandidates, 0) })) }) t.Run("Success from unknown remote, prflx candidate MUST only be created via Binding Request", func(t *testing.T) { agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) { agent.selector = &controllingSelector{agent: agent, log: agent.log} tID := [stun.TransactionIDSize]byte{} copy(tID[:], "ABC") agent.pendingBindingRequests = []bindingRequest{ {time.Now(), tID, &net.UDPAddr{}, false, nil}, } hostConfig := CandidateHostConfig{ Network: "udp", Address: "192.168.0.2", Port: 777, Component: 1, } local, err := NewCandidateHost(&hostConfig) local.conn = &fakenet.MockPacketConn{} require.NoError(t, err) remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999} msg, err := stun.Build(stun.BindingSuccess, stun.NewTransactionIDSetter(tID), stun.NewShortTermIntegrity(agent.remotePwd), stun.Fingerprint, ) require.NoError(t, err) // nolint: contextcheck agent.handleInbound(msg, local, remote) require.Len(t, agent.remoteCandidates, 0) })) }) } // Assert that Agent on startup sends message, and doesn't wait for connectivityTicker to fire // https://github.com/pion/ice/issues/15 func TestConnectivityOnStartup(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 30).Stop() // Create a network with two interfaces wan, err := vnet.NewRouter(&vnet.RouterConfig{ CIDR: "0.0.0.0/0", LoggerFactory: logging.NewDefaultLoggerFactory(), }) require.NoError(t, err) net0, err := vnet.NewNet(&vnet.NetConfig{ StaticIPs: []string{"192.168.0.1"}, }) 
require.NoError(t, err) require.NoError(t, wan.AddNet(net0)) net1, err := vnet.NewNet(&vnet.NetConfig{ StaticIPs: []string{"192.168.0.2"}, }) require.NoError(t, err) require.NoError(t, wan.AddNet(net1)) require.NoError(t, wan.Start()) aNotifier, aConnected := onConnected() bNotifier, bConnected := onConnected() KeepaliveInterval := time.Hour cfg0 := &AgentConfig{ NetworkTypes: supportedNetworkTypes(), MulticastDNSMode: MulticastDNSModeDisabled, Net: net0, KeepaliveInterval: &KeepaliveInterval, CheckInterval: &KeepaliveInterval, } aAgent, err := NewAgent(cfg0) require.NoError(t, err) defer func() { require.NoError(t, aAgent.Close()) }() require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) cfg1 := &AgentConfig{ NetworkTypes: supportedNetworkTypes(), MulticastDNSMode: MulticastDNSModeDisabled, Net: net1, KeepaliveInterval: &KeepaliveInterval, CheckInterval: &KeepaliveInterval, } bAgent, err := NewAgent(cfg1) require.NoError(t, err) defer func() { require.NoError(t, bAgent.Close()) }() require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) func(aAgent, bAgent *Agent) (*Conn, *Conn) { // Manual signaling aUfrag, aPwd, err := aAgent.GetLocalUserCredentials() require.NoError(t, err) bUfrag, bPwd, err := bAgent.GetLocalUserCredentials() require.NoError(t, err) gatherAndExchangeCandidates(t, aAgent, bAgent) accepted := make(chan struct{}) accepting := make(chan struct{}) var aConn *Conn origHdlr := aAgent.onConnectionStateChangeHdlr.Load() if origHdlr != nil { defer require.NoError(t, aAgent.OnConnectionStateChange(origHdlr.(func(ConnectionState)))) //nolint:forcetypeassert } require.NoError(t, aAgent.OnConnectionStateChange(func(s ConnectionState) { if s == ConnectionStateChecking { close(accepting) } if origHdlr != nil { origHdlr.(func(ConnectionState))(s) //nolint:forcetypeassert } })) go func() { var acceptErr error aConn, acceptErr = aAgent.Accept(context.TODO(), bUfrag, bPwd) require.NoError(t, acceptErr) close(accepted) }() <-accepting bConn, err := 
bAgent.Dial(context.TODO(), aUfrag, aPwd) require.NoError(t, err) // Ensure accepted <-accepted return aConn, bConn }(aAgent, bAgent) // Ensure pair selected // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair <-aConnected <-bConnected require.NoError(t, wan.Stop()) } func TestConnectivityLite(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 30).Stop() stunServerURL := &stun.URI{ Scheme: SchemeTypeSTUN, Host: "1.2.3.4", Port: 3478, Proto: stun.ProtoTypeUDP, } natType := &vnet.NATType{ MappingBehavior: vnet.EndpointIndependent, FilteringBehavior: vnet.EndpointIndependent, } vent, err := buildVNet(natType, natType) require.NoError(t, err, "should succeed") defer vent.close() aNotifier, aConnected := onConnected() bNotifier, bConnected := onConnected() cfg0 := &AgentConfig{ Urls: []*stun.URI{stunServerURL}, NetworkTypes: supportedNetworkTypes(), MulticastDNSMode: MulticastDNSModeDisabled, Net: vent.net0, } aAgent, err := NewAgent(cfg0) require.NoError(t, err) defer func() { require.NoError(t, aAgent.Close()) }() require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) cfg1 := &AgentConfig{ Urls: []*stun.URI{}, Lite: true, CandidateTypes: []CandidateType{CandidateTypeHost}, NetworkTypes: supportedNetworkTypes(), MulticastDNSMode: MulticastDNSModeDisabled, Net: vent.net1, } bAgent, err := NewAgent(cfg1) require.NoError(t, err) defer func() { require.NoError(t, bAgent.Close()) }() require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) connectWithVNet(t, aAgent, bAgent) // Ensure pair selected // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair <-aConnected <-bConnected } func TestInboundValidity(t *testing.T) { //nolint:cyclop defer test.CheckRoutines(t)() buildMsg := func(class stun.MessageClass, username, key string) *stun.Message { msg, err := stun.Build(stun.NewType(stun.MethodBinding, class), stun.TransactionID, stun.NewUsername(username), 
stun.NewShortTermIntegrity(key), stun.Fingerprint, ) require.NoError(t, err) return msg } remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999} hostConfig := CandidateHostConfig{ Network: "udp", Address: "192.168.0.2", Port: 777, Component: 1, } local, err := NewCandidateHost(&hostConfig) local.conn = &fakenet.MockPacketConn{} require.NoError(t, err) t.Run("Invalid Binding requests should be discarded", func(t *testing.T) { agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() agent.handleInbound(buildMsg(stun.ClassRequest, "invalid", agent.localPwd), local, remote) require.Len(t, agent.remoteCandidates, 0) agent.handleInbound(buildMsg(stun.ClassRequest, agent.localUfrag+":"+agent.remoteUfrag, "Invalid"), local, remote) require.Len(t, agent.remoteCandidates, 0) }) t.Run("Invalid Binding success responses should be discarded", func(t *testing.T) { a, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, a.Close()) }() a.handleInbound(buildMsg(stun.ClassSuccessResponse, a.localUfrag+":"+a.remoteUfrag, "Invalid"), local, remote) require.Len(t, a.remoteCandidates, 0) }) t.Run("Discard non-binding messages", func(t *testing.T) { a, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, a.Close()) }() a.handleInbound(buildMsg(stun.ClassErrorResponse, a.localUfrag+":"+a.remoteUfrag, "Invalid"), local, remote) require.Len(t, a.remoteCandidates, 0) }) t.Run("Valid bind request", func(t *testing.T) { a, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, a.Close()) }() err = a.loop.Run(a.loop, func(_ context.Context) { a.selector = &controllingSelector{agent: a, log: a.log} // nolint: contextcheck a.handleInbound(buildMsg(stun.ClassRequest, a.localUfrag+":"+a.remoteUfrag, a.localPwd), local, remote) require.Len(t, a.remoteCandidates, 1) }) require.NoError(t, err) }) t.Run("Valid bind without 
fingerprint", func(t *testing.T) { agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) { agent.selector = &controllingSelector{agent: agent, log: agent.log} msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag), stun.NewShortTermIntegrity(agent.localPwd), ) require.NoError(t, err) // nolint: contextcheck agent.handleInbound(msg, local, remote) require.Len(t, agent.remoteCandidates, 1) })) }) t.Run("Success with invalid TransactionID", func(t *testing.T) { agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() hostConfig := CandidateHostConfig{ Network: "udp", Address: "192.168.0.2", Port: 777, Component: 1, } local, err := NewCandidateHost(&hostConfig) local.conn = &fakenet.MockPacketConn{} require.NoError(t, err) remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999} tID := [stun.TransactionIDSize]byte{} copy(tID[:], "ABC") msg, err := stun.Build(stun.BindingSuccess, stun.NewTransactionIDSetter(tID), stun.NewShortTermIntegrity(agent.remotePwd), stun.Fingerprint, ) require.NoError(t, err) agent.handleInbound(msg, local, remote) require.Len(t, agent.remoteCandidates, 0) }) } func TestInvalidAgentStarts(t *testing.T) { defer test.CheckRoutines(t)() agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) defer cancel() _, err = agent.Dial(ctx, "", "bar") require.ErrorIs(t, ErrRemoteUfragEmpty, err) _, err = agent.Dial(ctx, "foo", "") require.ErrorIs(t, ErrRemotePwdEmpty, err) _, err = agent.Dial(ctx, "foo", "bar") require.ErrorIs(t, ErrCanceledByCaller, err) _, err = agent.Dial(ctx, "foo", "bar") require.ErrorIs(t, ErrMultipleStart, err) } // Assert that 
Agent emits Connecting/Connected/Disconnected/Failed/Closed messages. func TestConnectionStateCallback(t *testing.T) { //nolint:cyclop defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 5).Stop() disconnectedDuration := time.Second failedDuration := time.Second KeepaliveInterval := time.Duration(0) cfg := &AgentConfig{ Urls: []*stun.URI{}, NetworkTypes: supportedNetworkTypes(), DisconnectedTimeout: &disconnectedDuration, FailedTimeout: &failedDuration, KeepaliveInterval: &KeepaliveInterval, InterfaceFilter: problematicNetworkInterfaces, } isClosed := make(chan any) aAgent, err := NewAgent(cfg) require.NoError(t, err) defer func() { select { case <-isClosed: return default: } require.NoError(t, aAgent.Close()) }() bAgent, err := NewAgent(cfg) require.NoError(t, err) defer func() { select { case <-isClosed: return default: } require.NoError(t, bAgent.Close()) }() isChecking := make(chan any) isConnected := make(chan any) isDisconnected := make(chan any) isFailed := make(chan any) err = aAgent.OnConnectionStateChange(func(c ConnectionState) { switch c { case ConnectionStateChecking: close(isChecking) case ConnectionStateConnected: close(isConnected) case ConnectionStateDisconnected: close(isDisconnected) case ConnectionStateFailed: close(isFailed) case ConnectionStateClosed: close(isClosed) default: } }) require.NoError(t, err) connect(t, aAgent, bAgent) <-isChecking <-isConnected <-isDisconnected <-isFailed require.NoError(t, aAgent.Close()) require.NoError(t, bAgent.Close()) <-isClosed } func TestInvalidGather(t *testing.T) { t.Run("Gather with no OnCandidate should error", func(t *testing.T) { a, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, a.Close()) }() err = a.GatherCandidates() require.ErrorIs(t, ErrNoOnCandidateHandler, err) }) } func TestCandidatePairsStats(t *testing.T) { //nolint:cyclop,gocyclo defer test.CheckRoutines(t)() // Avoid deadlocks? 
defer test.TimeOut(1 * time.Second).Stop() agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() hostConfig := &CandidateHostConfig{ Network: "udp", Address: "192.168.1.1", Port: 19216, Component: 1, } hostLocal, err := NewCandidateHost(hostConfig) require.NoError(t, err) relayConfig := &CandidateRelayConfig{ Network: "udp", Address: "1.2.3.4", Port: 2340, Component: 1, RelAddr: "4.3.2.1", RelPort: 43210, } relayRemote, err := NewCandidateRelay(relayConfig) require.NoError(t, err) srflxConfig := &CandidateServerReflexiveConfig{ Network: "udp", Address: "10.10.10.2", Port: 19218, Component: 1, RelAddr: "4.3.2.1", RelPort: 43212, } srflxRemote, err := NewCandidateServerReflexive(srflxConfig) require.NoError(t, err) prflxConfig := &CandidatePeerReflexiveConfig{ Network: "udp", Address: "10.10.10.2", Port: 19217, Component: 1, RelAddr: "4.3.2.1", RelPort: 43211, } prflxRemote, err := NewCandidatePeerReflexive(prflxConfig) require.NoError(t, err) hostConfig = &CandidateHostConfig{ Network: "udp", Address: "1.2.3.5", Port: 12350, Component: 1, } hostRemote, err := NewCandidateHost(hostConfig) require.NoError(t, err) for _, remote := range []Candidate{relayRemote, srflxRemote, prflxRemote, hostRemote} { p := agent.findPair(hostLocal, remote) if p == nil { p = agent.addPair(hostLocal, remote) } p.UpdateRequestReceived() p.UpdateRequestSent() p.UpdateResponseSent() p.UpdateRoundTripTime(time.Second) p.UpdatePacketSent(100) p.UpdatePacketReceived(200) } p := agent.findPair(hostLocal, prflxRemote) p.state = CandidatePairStateFailed for i := 1; i < 10; i++ { p.UpdateRoundTripTime(time.Duration(i+1) * time.Second) } stats := agent.GetCandidatePairsStats() require.Len(t, stats, 4) var relayPairStat, srflxPairStat, prflxPairStat, hostPairStat CandidatePairStats for _, cps := range stats { require.Equal(t, cps.LocalCandidateID, hostLocal.ID()) switch cps.RemoteCandidateID { case relayRemote.ID(): relayPairStat = cps 
case srflxRemote.ID(): srflxPairStat = cps case prflxRemote.ID(): prflxPairStat = cps case hostRemote.ID(): hostPairStat = cps default: t.Fatal("invalid remote candidate ID") //nolint } require.False(t, cps.FirstRequestTimestamp.IsZero()) require.False(t, cps.LastRequestTimestamp.IsZero()) require.False(t, cps.FirstResponseTimestamp.IsZero()) require.False(t, cps.LastResponseTimestamp.IsZero()) require.False(t, cps.FirstRequestReceivedTimestamp.IsZero()) require.False(t, cps.LastRequestReceivedTimestamp.IsZero()) require.Equal(t, uint32(1), cps.PacketsSent) require.Equal(t, uint32(1), cps.PacketsReceived) require.Equal(t, uint64(100), cps.BytesSent) require.Equal(t, uint64(200), cps.BytesReceived) require.False(t, cps.LastPacketSentTimestamp.IsZero()) require.False(t, cps.LastPacketReceivedTimestamp.IsZero()) } require.Equal(t, relayPairStat.RemoteCandidateID, relayRemote.ID()) require.Equal(t, srflxPairStat.RemoteCandidateID, srflxRemote.ID()) require.Equal(t, prflxPairStat.RemoteCandidateID, prflxRemote.ID()) require.Equal(t, hostPairStat.RemoteCandidateID, hostRemote.ID()) require.Equal(t, prflxPairStat.State, CandidatePairStateFailed) require.Equal(t, float64(10), prflxPairStat.CurrentRoundTripTime) require.Equal(t, float64(55), prflxPairStat.TotalRoundTripTime) require.Equal(t, uint64(10), prflxPairStat.ResponsesReceived) } func TestSelectedCandidatePairStats(t *testing.T) { //nolint:cyclop defer test.CheckRoutines(t)() // Avoid deadlocks? 
defer test.TimeOut(1 * time.Second).Stop() agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() hostConfig := &CandidateHostConfig{ Network: "udp", Address: "192.168.1.1", Port: 19216, Component: 1, } hostLocal, err := NewCandidateHost(hostConfig) require.NoError(t, err) srflxConfig := &CandidateServerReflexiveConfig{ Network: "udp", Address: "10.10.10.2", Port: 19218, Component: 1, RelAddr: "4.3.2.1", RelPort: 43212, } srflxRemote, err := NewCandidateServerReflexive(srflxConfig) require.NoError(t, err) // no selected pair, should return not available _, ok := agent.GetSelectedCandidatePairStats() require.False(t, ok) // add pair and populate some RTT stats candidatePair := agent.findPair(hostLocal, srflxRemote) if candidatePair == nil { agent.addPair(hostLocal, srflxRemote) candidatePair = agent.findPair(hostLocal, srflxRemote) } for i := 0; i < 10; i++ { candidatePair.UpdateRoundTripTime(time.Duration(i+1) * time.Second) } candidatePair.UpdatePacketSent(150) candidatePair.UpdatePacketReceived(250) // set the pair as selected agent.setSelectedPair(candidatePair) stats, ok := agent.GetSelectedCandidatePairStats() require.True(t, ok) require.Equal(t, stats.LocalCandidateID, hostLocal.ID()) require.Equal(t, stats.RemoteCandidateID, srflxRemote.ID()) require.Equal(t, float64(10), stats.CurrentRoundTripTime) require.Equal(t, float64(55), stats.TotalRoundTripTime) require.Equal(t, uint64(10), stats.ResponsesReceived) require.Equal(t, uint32(1), stats.PacketsSent) require.Equal(t, uint32(1), stats.PacketsReceived) require.Equal(t, uint64(150), stats.BytesSent) require.Equal(t, uint64(250), stats.BytesReceived) require.False(t, stats.LastPacketSentTimestamp.IsZero()) require.False(t, stats.LastPacketReceivedTimestamp.IsZero()) } func TestLocalCandidateStats(t *testing.T) { //nolint:cyclop defer test.CheckRoutines(t)() // Avoid deadlocks? 
defer test.TimeOut(1 * time.Second).Stop() agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() hostConfig := &CandidateHostConfig{ Network: "udp", Address: "192.168.1.1", Port: 19216, Component: 1, } hostLocal, err := NewCandidateHost(hostConfig) require.NoError(t, err) srflxConfig := &CandidateServerReflexiveConfig{ Network: "udp", Address: "192.168.1.1", Port: 19217, Component: 1, RelAddr: "4.3.2.1", RelPort: 43212, } srflxLocal, err := NewCandidateServerReflexive(srflxConfig) require.NoError(t, err) agent.localCandidates[NetworkTypeUDP4] = []Candidate{hostLocal, srflxLocal} localStats := agent.GetLocalCandidatesStats() require.Len(t, localStats, 2) var hostLocalStat, srflxLocalStat CandidateStats for _, stats := range localStats { var candidate Candidate switch stats.ID { case hostLocal.ID(): hostLocalStat = stats candidate = hostLocal case srflxLocal.ID(): srflxLocalStat = stats candidate = srflxLocal default: t.Fatal("invalid local candidate ID") // nolint } require.Equal(t, stats.CandidateType, candidate.Type()) require.Equal(t, stats.Priority, candidate.Priority()) require.Equal(t, stats.IP, candidate.Address()) } require.Equal(t, hostLocalStat.ID, hostLocal.ID()) require.Equal(t, srflxLocalStat.ID, srflxLocal.ID()) } func TestRemoteCandidateStats(t *testing.T) { //nolint:cyclop defer test.CheckRoutines(t)() // Avoid deadlocks? 
defer test.TimeOut(1 * time.Second).Stop() agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() relayConfig := &CandidateRelayConfig{ Network: "udp", Address: "1.2.3.4", Port: 12340, Component: 1, RelAddr: "4.3.2.1", RelPort: 43210, } relayRemote, err := NewCandidateRelay(relayConfig) require.NoError(t, err) srflxConfig := &CandidateServerReflexiveConfig{ Network: "udp", Address: "10.10.10.2", Port: 19218, Component: 1, RelAddr: "4.3.2.1", RelPort: 43212, } srflxRemote, err := NewCandidateServerReflexive(srflxConfig) require.NoError(t, err) prflxConfig := &CandidatePeerReflexiveConfig{ Network: "udp", Address: "10.10.10.2", Port: 19217, Component: 1, RelAddr: "4.3.2.1", RelPort: 43211, } prflxRemote, err := NewCandidatePeerReflexive(prflxConfig) require.NoError(t, err) hostConfig := &CandidateHostConfig{ Network: "udp", Address: "1.2.3.5", Port: 12350, Component: 1, } hostRemote, err := NewCandidateHost(hostConfig) require.NoError(t, err) agent.remoteCandidates[NetworkTypeUDP4] = []Candidate{relayRemote, srflxRemote, prflxRemote, hostRemote} remoteStats := agent.GetRemoteCandidatesStats() require.Len(t, remoteStats, 4) var relayRemoteStat, srflxRemoteStat, prflxRemoteStat, hostRemoteStat CandidateStats for _, stats := range remoteStats { var candidate Candidate switch stats.ID { case relayRemote.ID(): relayRemoteStat = stats candidate = relayRemote case srflxRemote.ID(): srflxRemoteStat = stats candidate = srflxRemote case prflxRemote.ID(): prflxRemoteStat = stats candidate = prflxRemote case hostRemote.ID(): hostRemoteStat = stats candidate = hostRemote default: t.Fatal("invalid remote candidate ID") // nolint } require.Equal(t, stats.CandidateType, candidate.Type()) require.Equal(t, stats.Priority, candidate.Priority()) require.Equal(t, stats.IP, candidate.Address()) } require.Equal(t, relayRemoteStat.ID, relayRemote.ID()) require.Equal(t, srflxRemoteStat.ID, srflxRemote.ID()) require.Equal(t, 
prflxRemoteStat.ID, prflxRemote.ID()) require.Equal(t, hostRemoteStat.ID, hostRemote.ID()) } func TestInitExtIPMapping(t *testing.T) { defer test.CheckRoutines(t)() // agent.addressRewriteMapper should be nil by default agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) require.Nil(t, agent.addressRewriteMapper) require.NoError(t, agent.Close()) // a.addressRewriteMapper should be nil when NAT1To1IPs is a non-nil empty array agent, err = NewAgent(&AgentConfig{ NAT1To1IPs: []string{}, NAT1To1IPCandidateType: CandidateTypeHost, }) require.NoError(t, err) require.Nil(t, agent.addressRewriteMapper) require.NoError(t, agent.Close()) // NewAgent should return an error when 1:1 NAT for host candidate is enabled // but the candidate type does not appear in the CandidateTypes. _, err = NewAgent(&AgentConfig{ NAT1To1IPs: []string{"1.2.3.4"}, NAT1To1IPCandidateType: CandidateTypeHost, CandidateTypes: []CandidateType{CandidateTypeRelay}, }) require.ErrorIs(t, ErrIneffectiveNAT1To1IPMappingHost, err) // NewAgent should return an error when 1:1 NAT for srflx candidate is enabled // but the candidate type does not appear in the CandidateTypes. _, err = NewAgent(&AgentConfig{ NAT1To1IPs: []string{"1.2.3.4"}, NAT1To1IPCandidateType: CandidateTypeServerReflexive, CandidateTypes: []CandidateType{CandidateTypeRelay}, }) require.ErrorIs(t, ErrIneffectiveNAT1To1IPMappingSrflx, err) // NewAgent should return an error when 1:1 NAT for host candidate is enabled // along with mDNS with MulticastDNSModeQueryAndGather _, err = NewAgent(&AgentConfig{ NAT1To1IPs: []string{"1.2.3.4"}, NAT1To1IPCandidateType: CandidateTypeHost, MulticastDNSMode: MulticastDNSModeQueryAndGather, }) require.ErrorIs(t, ErrMulticastDNSWithNAT1To1IPMapping, err) // NewAgent should return if newAddressRewriteMapper() returns an error. 
_, err = NewAgent(&AgentConfig{ NAT1To1IPs: []string{"bad.2.3.4"}, // Bad IP NAT1To1IPCandidateType: CandidateTypeHost, }) require.ErrorIs(t, ErrInvalidNAT1To1IPMapping, err) } func TestBindingRequestTimeout(t *testing.T) { defer test.CheckRoutines(t)() const expectedRemovalCount = 2 agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() now := time.Now() agent.pendingBindingRequests = append(agent.pendingBindingRequests, bindingRequest{ timestamp: now, // Valid }) agent.pendingBindingRequests = append(agent.pendingBindingRequests, bindingRequest{ timestamp: now.Add(-3900 * time.Millisecond), // Valid }) agent.pendingBindingRequests = append(agent.pendingBindingRequests, bindingRequest{ timestamp: now.Add(-4100 * time.Millisecond), // Invalid }) agent.pendingBindingRequests = append(agent.pendingBindingRequests, bindingRequest{ timestamp: now.Add(-75 * time.Hour), // Invalid }) agent.invalidatePendingBindingRequests(now) require.Equal( t, expectedRemovalCount, len(agent.pendingBindingRequests), "Binding invalidation due to timeout did not remove the correct number of binding requests", ) } // TestAgentCredentials checks if local username fragments and passwords (if set) meet RFC standard // and ensure it's backwards compatible with previous versions of the pion/ice. 
func TestAgentCredentials(t *testing.T) { defer test.CheckRoutines(t)() // Make sure to pass Travis check by disabling the logs log := logging.NewDefaultLoggerFactory() log.DefaultLogLevel = logging.LogLevelDisabled // Agent should not require any of the usernames and password to be set // If set, they should follow the default 16/128 bits random number generator strategy agent, err := NewAgent(&AgentConfig{LoggerFactory: log}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() require.GreaterOrEqual(t, len([]rune(agent.localUfrag))*8, 24) require.GreaterOrEqual(t, len([]rune(agent.localPwd))*8, 128) // Should honor RFC standards // Local values MUST be unguessable, with at least 128 bits of // random number generator output used to generate the password, and // at least 24 bits of output to generate the username fragment. _, err = NewAgent(&AgentConfig{LocalUfrag: "xx", LoggerFactory: log}) require.EqualError(t, err, ErrLocalUfragInsufficientBits.Error()) _, err = NewAgent(&AgentConfig{LocalPwd: "xxxxxx", LoggerFactory: log}) require.EqualError(t, err, ErrLocalPwdInsufficientBits.Error()) } // Assert that Agent on Failure deletes all existing candidates // User can then do an ICE Restart to bring agent back. 
func TestConnectionStateFailedDeleteAllCandidates(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 5).Stop() oneSecond := time.Second KeepaliveInterval := time.Duration(0) cfg := &AgentConfig{ NetworkTypes: supportedNetworkTypes(), DisconnectedTimeout: &oneSecond, FailedTimeout: &oneSecond, KeepaliveInterval: &KeepaliveInterval, } aAgent, err := NewAgent(cfg) require.NoError(t, err) defer func() { require.NoError(t, aAgent.Close()) }() bAgent, err := NewAgent(cfg) require.NoError(t, err) defer func() { require.NoError(t, bAgent.Close()) }() isFailed := make(chan any) require.NoError(t, aAgent.OnConnectionStateChange(func(c ConnectionState) { if c == ConnectionStateFailed { close(isFailed) } })) connect(t, aAgent, bAgent) <-isFailed done := make(chan struct{}) require.NoError(t, aAgent.loop.Run(context.Background(), func(context.Context) { require.Equal(t, len(aAgent.remoteCandidates), 0) require.Equal(t, len(aAgent.localCandidates), 0) close(done) })) <-done } // Assert that the ICE Agent can go directly from Connecting -> Failed on both sides. 
func TestConnectionStateConnectingToFailed(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 5).Stop()

	oneSecond := time.Second
	KeepaliveInterval := time.Duration(0)
	cfg := &AgentConfig{
		DisconnectedTimeout: &oneSecond,
		FailedTimeout:       &oneSecond,
		KeepaliveInterval:   &KeepaliveInterval,
	}

	aAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, aAgent.Close())
	}()

	bAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, bAgent.Close())
	}()

	// Both agents must report Checking and then Failed; Completed must never
	// be observed because the credentials below are deliberately invalid.
	var isFailed sync.WaitGroup
	var isChecking sync.WaitGroup

	isFailed.Add(2)
	isChecking.Add(2)

	connectionStateCheck := func(c ConnectionState) {
		switch c {
		case ConnectionStateFailed:
			isFailed.Done()
		case ConnectionStateChecking:
			isChecking.Done()
		case ConnectionStateCompleted:
			t.Errorf("Unexpected ConnectionState: %v", c) //nolint
		default:
		}
	}

	require.NoError(t, aAgent.OnConnectionStateChange(connectionStateCheck))
	require.NoError(t, bAgent.OnConnectionStateChange(connectionStateCheck))

	go func() {
		_, err := aAgent.Accept(context.TODO(), "InvalidFrag", "InvalidPwd")
		require.Error(t, err)
	}()
	go func() {
		_, err := bAgent.Dial(context.TODO(), "InvalidFrag", "InvalidPwd")
		require.Error(t, err)
	}()

	isChecking.Wait()
	isFailed.Wait()
}

// TestAgentRestart exercises Restart while gathering, after Close, on one
// side only, and on both sides with a full re-signaling round.
func TestAgentRestart(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	oneSecond := time.Second

	t.Run("Restart During Gather", func(t *testing.T) {
		connA, connB := pipe(t, &AgentConfig{
			DisconnectedTimeout: &oneSecond,
			FailedTimeout:       &oneSecond,
		})
		defer closePipe(t, connA, connB)

		ctx, cancel := context.WithCancel(context.Background())
		require.NoError(t, connB.agent.OnConnectionStateChange(func(c ConnectionState) {
			if c == ConnectionStateFailed || c == ConnectionStateDisconnected {
				cancel()
			}
		}))

		// Force the gathering state so Restart runs while "gathering".
		connA.agent.gatheringState = GatheringStateGathering
		require.NoError(t, connA.agent.Restart("", ""))

		<-ctx.Done()
	})

	t.Run("Restart When Closed", func(t *testing.T) {
		agent, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		require.NoError(t, agent.Close())

		// Restart on a closed agent must fail with ErrClosed.
		require.Equal(t, ErrClosed, agent.Restart("", ""))
	})

	t.Run("Restart One Side", func(t *testing.T) {
		connA, connB := pipe(t, &AgentConfig{
			DisconnectedTimeout: &oneSecond,
			FailedTimeout:       &oneSecond,
		})
		defer closePipe(t, connA, connB)

		ctx, cancel := context.WithCancel(context.Background())
		require.NoError(t, connB.agent.OnConnectionStateChange(func(c ConnectionState) {
			if c == ConnectionStateFailed || c == ConnectionStateDisconnected {
				cancel()
			}
		}))

		require.NoError(t, connA.agent.Restart("", ""))

		<-ctx.Done()
	})

	t.Run("Restart Both Sides", func(t *testing.T) {
		// Get all addresses of candidates concatenated
		generateCandidateAddressStrings := func(candidates []Candidate, err error) (out string) {
			require.NoError(t, err)

			for _, c := range candidates {
				out += c.Address() + ":"
				out += strconv.Itoa(c.Port())
			}

			return
		}

		// Store the original candidates, confirm that after we reconnect we have new pairs
		connA, connB := pipe(t, &AgentConfig{
			DisconnectedTimeout: &oneSecond,
			FailedTimeout:       &oneSecond,
		})
		defer closePipe(t, connA, connB)
		connAFirstCandidates := generateCandidateAddressStrings(connA.agent.GetLocalCandidates())
		connBFirstCandidates := generateCandidateAddressStrings(connB.agent.GetLocalCandidates())

		aNotifier, aConnected := onConnected()
		require.NoError(t, connA.agent.OnConnectionStateChange(aNotifier))

		bNotifier, bConnected := onConnected()
		require.NoError(t, connB.agent.OnConnectionStateChange(bNotifier))

		// Restart and Re-Signal
		require.NoError(t, connA.agent.Restart("", ""))
		require.NoError(t, connB.agent.Restart("", ""))

		// Exchange Candidates and Credentials
		ufrag, pwd, err := connB.agent.GetLocalUserCredentials()
		require.NoError(t, err)
		require.NoError(t, connA.agent.SetRemoteCredentials(ufrag, pwd))

		ufrag, pwd, err = connA.agent.GetLocalUserCredentials()
		require.NoError(t, err)
		require.NoError(t, connB.agent.SetRemoteCredentials(ufrag, pwd))

		gatherAndExchangeCandidates(t, connA.agent, connB.agent)

		// Wait until both have gone back to connected
		<-aConnected
		<-bConnected

		// Assert that we have new candidates each time
		require.NotEqual(t, connAFirstCandidates, generateCandidateAddressStrings(connA.agent.GetLocalCandidates()))
		require.NotEqual(t, connBFirstCandidates, generateCandidateAddressStrings(connB.agent.GetLocalCandidates()))
	})
}

// TestGetRemoteCredentials checks that GetRemoteUserCredentials returns the
// values stored on the agent.
func TestGetRemoteCredentials(t *testing.T) {
	var config AgentConfig
	agent, err := NewAgent(&config)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	agent.remoteUfrag = "remoteUfrag"
	agent.remotePwd = "remotePwd"

	actualUfrag, actualPwd, err := agent.GetRemoteUserCredentials()
	require.NoError(t, err)

	require.Equal(t, actualUfrag, agent.remoteUfrag)
	require.Equal(t, actualPwd, agent.remotePwd)
}

// TestGetRemoteCandidates checks that candidates added via addRemoteCandidate
// are all returned by GetRemoteCandidates.
func TestGetRemoteCandidates(t *testing.T) {
	var config AgentConfig
	agent, err := NewAgent(&config)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	expectedCandidates := []Candidate{}

	for i := 0; i < 5; i++ {
		cfg := CandidateHostConfig{
			Network:   "udp",
			Address:   "192.168.0.2",
			Port:      1000 + i,
			Component: 1,
		}

		cand, errCand := NewCandidateHost(&cfg)
		require.NoError(t, errCand)

		expectedCandidates = append(expectedCandidates, cand)
		agent.addRemoteCandidate(cand)
	}

	actualCandidates, err := agent.GetRemoteCandidates()
	require.NoError(t, err)
	require.ElementsMatch(t, expectedCandidates, actualCandidates)
}

// TestGetLocalCandidates checks that candidates added via addCandidate are
// all returned by GetLocalCandidates.
func TestGetLocalCandidates(t *testing.T) {
	var config AgentConfig
	agent, err := NewAgent(&config)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	// A zero-value UDPConn is enough here; the candidates are never used for I/O.
	dummyConn := &net.UDPConn{}
	expectedCandidates := []Candidate{}

	for i := 0; i < 5; i++ {
		cfg := CandidateHostConfig{
			Network:   "udp",
			Address:   "192.168.0.2",
			Port:      1000 + i,
			Component: 1,
		}

		cand, errCand := NewCandidateHost(&cfg)
		require.NoError(t, errCand)

		expectedCandidates = append(expectedCandidates, cand)

		err = agent.addCandidate(context.Background(), cand, dummyConn)
		require.NoError(t, err)
	}

	actualCandidates, err := agent.GetLocalCandidates()
	require.NoError(t, err)
	require.ElementsMatch(t, expectedCandidates, actualCandidates)
}

// TestCloseInConnectionStateCallback asserts that calling Close from inside
// the connection state callback does not deadlock and still reaches Closed.
func TestCloseInConnectionStateCallback(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 5).Stop()

	disconnectedDuration := time.Second
	failedDuration := time.Second
	KeepaliveInterval := time.Duration(0)
	CheckInterval := 500 * time.Millisecond
	cfg := &AgentConfig{
		Urls:                []*stun.URI{},
		NetworkTypes:        supportedNetworkTypes(),
		DisconnectedTimeout: &disconnectedDuration,
		FailedTimeout:       &failedDuration,
		KeepaliveInterval:   &KeepaliveInterval,
		CheckInterval:       &CheckInterval,
	}

	aAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	// aAgent is closed inside the callback; skip the deferred Close then.
	var aAgentClosed bool
	defer func() {
		if aAgentClosed {
			return
		}
		require.NoError(t, aAgent.Close())
	}()

	bAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, bAgent.Close())
	}()

	isClosed := make(chan any)
	isConnected := make(chan any)
	err = aAgent.OnConnectionStateChange(func(c ConnectionState) {
		switch c {
		case ConnectionStateConnected:
			// Block until the main goroutine signals, then close from
			// within the callback itself.
			<-isConnected
			require.NoError(t, aAgent.Close())
			aAgentClosed = true
		case ConnectionStateClosed:
			close(isClosed)
		default:
		}
	})
	require.NoError(t, err)

	connect(t, aAgent, bAgent)
	close(isConnected)

	<-isClosed
}

// TestRunTaskInConnectionStateCallback asserts that agent API calls
// (GetLocalUserCredentials, Restart) work from inside the state callback.
func TestRunTaskInConnectionStateCallback(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 5).Stop()

	oneSecond := time.Second
	KeepaliveInterval := time.Duration(0)
	CheckInterval := 50 * time.Millisecond
	cfg := &AgentConfig{
		Urls:                []*stun.URI{},
		NetworkTypes:        supportedNetworkTypes(),
		DisconnectedTimeout: &oneSecond,
		FailedTimeout:       &oneSecond,
		KeepaliveInterval:   &KeepaliveInterval,
		CheckInterval:       &CheckInterval,
	}

	aAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, aAgent.Close())
	}()

	bAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, bAgent.Close())
	}()

	isComplete := make(chan any)
	err = aAgent.OnConnectionStateChange(func(c ConnectionState) {
		if c == ConnectionStateConnected {
			_, _, errCred := aAgent.GetLocalUserCredentials()
			require.NoError(t, errCred)
			require.NoError(t, aAgent.Restart("", ""))
			close(isComplete)
		}
	})
	require.NoError(t, err)

	connect(t, aAgent, bAgent)

	<-isComplete
}

// TestRunTaskInSelectedCandidatePairChangeCallback asserts that agent API
// calls work from (a goroutine spawned by) the pair-change callback.
func TestRunTaskInSelectedCandidatePairChangeCallback(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 5).Stop()

	oneSecond := time.Second
	KeepaliveInterval := time.Duration(0)
	CheckInterval := 50 * time.Millisecond
	cfg := &AgentConfig{
		Urls:                []*stun.URI{},
		NetworkTypes:        supportedNetworkTypes(),
		DisconnectedTimeout: &oneSecond,
		FailedTimeout:       &oneSecond,
		KeepaliveInterval:   &KeepaliveInterval,
		CheckInterval:       &CheckInterval,
	}

	aAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, aAgent.Close())
	}()

	bAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, bAgent.Close())
	}()

	isComplete := make(chan any)
	isTested := make(chan any)
	err = aAgent.OnSelectedCandidatePairChange(func(Candidate, Candidate) {
		go func() {
			_, _, errCred := aAgent.GetLocalUserCredentials()
			require.NoError(t, errCred)
			close(isTested)
		}()
	})
	require.NoError(t, err)
	err = aAgent.OnConnectionStateChange(func(c ConnectionState) {
		if c == ConnectionStateConnected {
			close(isComplete)
		}
	})
	require.NoError(t, err)

	connect(t, aAgent, bAgent)

	<-isComplete
	<-isTested
}

// Assert that a Lite agent goes to disconnected and failed.
func TestLiteLifecycle(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	aNotifier, aConnected := onConnected()

	aAgent, err := NewAgent(&AgentConfig{
		NetworkTypes:     supportedNetworkTypes(),
		MulticastDNSMode: MulticastDNSModeDisabled,
	})
	require.NoError(t, err)
	// aAgent is closed mid-test to trigger the lite agent's decay; the
	// deferred Close only fires if the test bails out before that point.
	var aClosed bool
	defer func() {
		if aClosed {
			return
		}
		require.NoError(t, aAgent.Close())
	}()
	require.NoError(t, aAgent.OnConnectionStateChange(aNotifier))

	disconnectedDuration := time.Second
	failedDuration := time.Second
	KeepaliveInterval := time.Duration(0)
	CheckInterval := 500 * time.Millisecond
	bAgent, err := NewAgent(&AgentConfig{
		Lite:                true,
		CandidateTypes:      []CandidateType{CandidateTypeHost},
		NetworkTypes:        supportedNetworkTypes(),
		MulticastDNSMode:    MulticastDNSModeDisabled,
		DisconnectedTimeout: &disconnectedDuration,
		FailedTimeout:       &failedDuration,
		KeepaliveInterval:   &KeepaliveInterval,
		CheckInterval:       &CheckInterval,
	})
	require.NoError(t, err)
	var bClosed bool
	defer func() {
		if bClosed {
			return
		}
		require.NoError(t, bAgent.Close())
	}()

	bConnected := make(chan any)
	bDisconnected := make(chan any)
	bFailed := make(chan any)

	require.NoError(t, bAgent.OnConnectionStateChange(func(c ConnectionState) {
		switch c {
		case ConnectionStateConnected:
			close(bConnected)
		case ConnectionStateDisconnected:
			close(bDisconnected)
		case ConnectionStateFailed:
			close(bFailed)
		default:
		}
	}))

	connectWithVNet(t, bAgent, aAgent)

	<-aConnected
	<-bConnected
	// Closing the full agent must drive the lite agent through
	// Disconnected and then Failed.
	require.NoError(t, aAgent.Close())
	aClosed = true

	<-bDisconnected
	<-bFailed
	require.NoError(t, bAgent.Close())
	bClosed = true
}

// TestValidateSelectedPairTransitions checks that a stale selected pair moves
// the agent Connected -> Disconnected -> Failed on successive validations.
func TestValidateSelectedPairTransitions(t *testing.T) {
	agent := &Agent{
		disconnectedTimeout: time.Second,
		failedTimeout:       time.Second,
		connectionState:     ConnectionStateConnected,
		connectionStateNotifier: &handlerNotifier{
			connectionStateFunc: func(ConnectionState) {},
			done:                make(chan struct{}),
		},
		log: logging.NewDefaultLoggerFactory().NewLogger("test"),
	}

	local, err := NewCandidateHost(&CandidateHostConfig{
		Network:   "udp",
		Address:   "1.1.1.1",
		Port:      1000,
		Component: ComponentRTP,
	})
	require.NoError(t, err)

	remote, err := NewCandidateHost(&CandidateHostConfig{
		Network:   "udp",
		Address:   "2.2.2.2",
		Port:      2000,
		Component: ComponentRTP,
	})
	require.NoError(t, err)

	// Backdate the remote so the pair is already past both timeouts.
	remote.setLastReceived(time.Now().Add(-3 * time.Second))
	agent.selectedPair.Store(newCandidatePair(local, remote, true))

	require.True(t, agent.validateSelectedPair())
	require.Equal(t, ConnectionStateDisconnected, agent.connectionState)

	require.True(t, agent.validateSelectedPair())
	require.Equal(t, ConnectionStateFailed, agent.connectionState)
}

// TestNilCandidate checks that AddRemoteCandidate(nil) is a harmless no-op.
func TestNilCandidate(t *testing.T) {
	a, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)

	require.NoError(t, a.AddRemoteCandidate(nil))
	require.NoError(t, a.Close())
}

// TestNilCandidatePair checks that setSelectedPair(nil) does not panic.
func TestNilCandidatePair(t *testing.T) {
	a, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, a.Close())
	}()

	a.setSelectedPair(nil)
}

// TestGetSelectedCandidatePair checks that the selected pair is nil before
// connecting and that both sides see mirrored local/remote pairs afterwards.
func TestGetSelectedCandidatePair(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	wan, err := vnet.NewRouter(&vnet.RouterConfig{
		CIDR:          "0.0.0.0/0",
		LoggerFactory: logging.NewDefaultLoggerFactory(),
	})
	require.NoError(t, err)

	net, err := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"192.168.0.1"},
	})
	require.NoError(t, err)
	require.NoError(t, wan.AddNet(net))
	require.NoError(t, wan.Start())

	cfg := &AgentConfig{
		NetworkTypes: supportedNetworkTypes(),
		Net:          net,
	}

	aAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, aAgent.Close())
	}()

	bAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, bAgent.Close())
	}()

	// Before connecting there is no selected pair.
	aAgentPair, err := aAgent.GetSelectedCandidatePair()
	require.NoError(t, err)
	require.Nil(t, aAgentPair)

	bAgentPair, err := bAgent.GetSelectedCandidatePair()
	require.NoError(t, err)
	require.Nil(t, bAgentPair)

	connect(t, aAgent, bAgent)

	aAgentPair, err = aAgent.GetSelectedCandidatePair()
	require.NoError(t, err)
	require.NotNil(t, aAgentPair)

	bAgentPair, err = bAgent.GetSelectedCandidatePair()
	require.NoError(t, err)
	require.NotNil(t, bAgentPair)

	// Each side's local candidate must equal the other side's remote.
	require.True(t, bAgentPair.Local.Equal(aAgentPair.Remote))
	require.True(t, bAgentPair.Remote.Equal(aAgentPair.Local))

	require.NoError(t, wan.Stop())
}

// TestAcceptAggressiveNomination verifies how full and lite agents react to
// USE-CANDIDATE requests carrying higher or lower priority than the
// currently selected pair.
func TestAcceptAggressiveNomination(t *testing.T) { //nolint:cyclop
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	// Create a network with two interfaces
	wan, err := vnet.NewRouter(&vnet.RouterConfig{
		CIDR:          "0.0.0.0/0",
		LoggerFactory: logging.NewDefaultLoggerFactory(),
	})
	require.NoError(t, err)

	net0, err := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"192.168.0.1"},
	})
	require.NoError(t, err)
	require.NoError(t, wan.AddNet(net0))

	net1, err := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"192.168.0.2", "192.168.0.3", "192.168.0.4"},
	})
	require.NoError(t, err)
	require.NoError(t, wan.AddNet(net1))

	require.NoError(t, wan.Start())

	testCases := []struct {
		name                            string
		isLite                          bool
		enableUseCandidateCheckPriority bool
		useHigherPriority               bool
		isExpectedToSwitch              bool
	}{
		{"should accept higher priority - full agent", false, false, true, true},
		{"should not accept lower priority - full agent", false, false, false, false},
		{"should accept higher priority - no use-candidate priority check - lite agent", true, false, true, true},
		{"should accept lower priority - no use-candidate priority check - lite agent", true, false, false, true},
		{"should accept higher priority - use-candidate priority check - lite agent", true, true, true, true},
		{"should not accept lower priority - use-candidate priority check - lite agent", true, true, false, false},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			aNotifier, aConnected := onConnected()
			bNotifier, bConnected := onConnected()

			// Hour-long intervals keep the agents' own checks quiet so only
			// the hand-crafted STUN requests below influence selection.
			KeepaliveInterval := time.Hour
			cfg0 := &AgentConfig{
				NetworkTypes:                    []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
				MulticastDNSMode:                MulticastDNSModeDisabled,
				Net:                             net0,
				KeepaliveInterval:               &KeepaliveInterval,
				CheckInterval:                   &KeepaliveInterval,
				Lite:                            tc.isLite,
				EnableUseCandidateCheckPriority: tc.enableUseCandidateCheckPriority,
			}
			if tc.isLite {
				cfg0.CandidateTypes = []CandidateType{CandidateTypeHost}
			}

			var aAgent, bAgent *Agent
			aAgent, err = NewAgent(cfg0)
			require.NoError(t, err)
			defer func() {
				require.NoError(t, aAgent.Close())
			}()
			require.NoError(t, aAgent.OnConnectionStateChange(aNotifier))

			cfg1 := &AgentConfig{
				NetworkTypes:      []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
				MulticastDNSMode:  MulticastDNSModeDisabled,
				Net:               net1,
				KeepaliveInterval: &KeepaliveInterval,
				CheckInterval:     &KeepaliveInterval,
			}

			bAgent, err = NewAgent(cfg1)
			require.NoError(t, err)
			defer func() {
				require.NoError(t, bAgent.Close())
			}()
			require.NoError(t, bAgent.OnConnectionStateChange(bNotifier))

			connect(t, aAgent, bAgent)

			// Ensure pair selected
			// Note: this assumes ConnectionStateConnected is thrown after selecting the final pair
			<-aConnected
			<-bConnected

			// Send new USE-CANDIDATE message with priority to update the selected pair
			buildMsg := func(class stun.MessageClass, username, key string, priority uint32) *stun.Message {
				msg, err1 := stun.Build(stun.NewType(stun.MethodBinding, class), stun.TransactionID,
					stun.NewUsername(username),
					stun.NewShortTermIntegrity(key),
					UseCandidate(),
					PriorityAttr(priority),
					stun.Fingerprint,
				)
				require.NoError(t, err1)

				return msg
			}

			selectedCh := make(chan Candidate, 1)
			var expectNewSelectedCandidate Candidate
			err = aAgent.OnSelectedCandidatePairChange(func(_, remote Candidate) {
				selectedCh <- remote
			})
			require.NoError(t, err)

			var bcandidates []Candidate
			bcandidates, err = bAgent.GetLocalCandidates()
			require.NoError(t, err)

			for _, cand := range bcandidates {
				if cand != bAgent.getSelectedPair().Local { //nolint:nestif
					if expectNewSelectedCandidate == nil {
						// First non-selected candidate: give it the large
						// priority delta and mark it as the expected winner
						// (or keep the current remote when no switch is expected).
					expected_change_priority:
						for _, candidates := range aAgent.remoteCandidates {
							for _, candidate := range candidates {
								if candidate.Equal(cand) {
									if tc.useHigherPriority {
										candidate.(*CandidateHost).priorityOverride += 1000 //nolint:forcetypeassert
									} else {
										candidate.(*CandidateHost).priorityOverride -= 1000 //nolint:forcetypeassert
									}

									break expected_change_priority
								}
							}
						}
						if tc.isExpectedToSwitch {
							expectNewSelectedCandidate = cand
						} else {
							expectNewSelectedCandidate = aAgent.getSelectedPair().Remote
						}
					} else {
						// a smaller change for other candidates other the new expected one
					change_priority:
						for _, candidates := range aAgent.remoteCandidates {
							for _, candidate := range candidates {
								if candidate.Equal(cand) {
									if tc.useHigherPriority {
										candidate.(*CandidateHost).priorityOverride += 500 //nolint:forcetypeassert
									} else {
										candidate.(*CandidateHost).priorityOverride -= 500 //nolint:forcetypeassert
									}

									break change_priority
								}
							}
						}
					}

					_, err = cand.writeTo(
						buildMsg(
							stun.ClassRequest,
							aAgent.localUfrag+":"+aAgent.remoteUfrag,
							aAgent.localPwd,
							cand.Priority(),
						).Raw,
						bAgent.getSelectedPair().Remote,
					)
					require.NoError(t, err)
				}
			}

			// Wait until either we observe the expected switch or the timeout elapses,
			// Ugly but makes the tests less flaky, especially on Windows.
			timeout := 3 * time.Second
			deadline := time.Now().Add(timeout)
			observedExpected := false

		waitLoop:
			for time.Now().Before(deadline) {
				select {
				case selected := <-selectedCh:
					if tc.isExpectedToSwitch {
						if selected.Equal(expectNewSelectedCandidate) {
							observedExpected = true

							break waitLoop
						}
					}
				default:
					time.Sleep(10 * time.Millisecond)
				}
			}

			if tc.isExpectedToSwitch {
				if !observedExpected {
					// Verify the agent's final selected pair if we didn't observe the event directly.
					require.True(t, aAgent.getSelectedPair().Remote.Equal(expectNewSelectedCandidate))
				}
			} else {
				// Ensure no switch happened by checking the agent's final selected pair.
				require.True(t, aAgent.getSelectedPair().Remote.Equal(expectNewSelectedCandidate))
			}
		})
	}

	require.NoError(t, wan.Stop())
}

// Close can deadlock but GracefulClose must not.
func TestAgentGracefulCloseDeadlock(t *testing.T) {
	defer test.CheckRoutinesStrict(t)()
	defer test.TimeOut(time.Second * 5).Stop()

	config := &AgentConfig{
		NetworkTypes: supportedNetworkTypes(),
	}
	aAgent, err := NewAgent(config)
	require.NoError(t, err)
	var aAgentClosed bool
	defer func() {
		if aAgentClosed {
			return
		}
		require.NoError(t, aAgent.Close())
	}()

	bAgent, err := NewAgent(config)
	require.NoError(t, err)
	var bAgentClosed bool
	defer func() {
		if bAgentClosed {
			return
		}
		require.NoError(t, bAgent.Close())
	}()

	var connected, closeNow, closed sync.WaitGroup
	connected.Add(2)
	closeNow.Add(1)
	closed.Add(2)

	// Each agent calls GracefulClose from (a goroutine spawned inside) its
	// own connection-state callback — the scenario where plain Close can
	// deadlock.
	closeHdlr := func(agent *Agent, agentClosed *bool) {
		require.NoError(t, agent.OnConnectionStateChange(func(cs ConnectionState) {
			if cs == ConnectionStateConnected {
				connected.Done()
				closeNow.Wait()

				go func() {
					require.NoError(t, agent.GracefulClose())
					*agentClosed = true
					closed.Done()
				}()
			}
		}))
	}

	closeHdlr(aAgent, &aAgentClosed)
	closeHdlr(bAgent, &bAgentClosed)

	t.Log("connecting agents")
	_, _ = connect(t, aAgent, bAgent)

	t.Log("waiting for them to confirm connection in callback")
	connected.Wait()

	t.Log("tell them to close themselves in the same callback and wait")
	closeNow.Done()
	closed.Wait()
}

// TestSetCandidatesUfrag checks that every local candidate carries the
// agent's ufrag as its "ufrag" extension.
func TestSetCandidatesUfrag(t *testing.T) {
	var config AgentConfig
	agent, err := NewAgent(&config)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	// A zero-value UDPConn is enough; the candidates are never used for I/O.
	dummyConn := &net.UDPConn{}

	for i := 0; i < 5; i++ {
		cfg := CandidateHostConfig{
			Network:   "udp",
			Address:   "192.168.0.2",
			Port:      1000 + i,
			Component: 1,
		}
		cand, errCand := NewCandidateHost(&cfg)
		require.NoError(t, errCand)

		err = agent.addCandidate(context.Background(), cand, dummyConn)
		require.NoError(t, err)
	}

	actualCandidates, err := agent.GetLocalCandidates()
	require.NoError(t, err)
	for _, candidate := range actualCandidates {
		ext, ok := candidate.GetExtension("ufrag")
		require.True(t, ok)
		require.Equal(t, agent.localUfrag, ext.Value)
	}
}

// TestAlwaysSentKeepAlive checks that checkKeepalive refreshes the local
// candidate's last-sent time on every invocation.
func TestAlwaysSentKeepAlive(t *testing.T) { //nolint:cyclop
	defer test.CheckRoutines(t)()

	// Avoid deadlocks?
	defer test.TimeOut(1 * time.Second).Stop()

	agent, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	log := logging.NewDefaultLoggerFactory().NewLogger("agent")
	agent.selector = &controllingSelector{agent: agent, log: log}
	pair := makeCandidatePair(t)
	s, ok := pair.Local.(*CandidateHost)
	require.True(t, ok)
	// Mock the socket so keepalives are "sent" without real network I/O.
	s.conn = &fakenet.MockPacketConn{}
	agent.setSelectedPair(pair)
	pair.Remote.seen(false)

	lastSent := pair.Local.LastSent()
	agent.checkKeepalive()
	newLastSent := pair.Local.LastSent()
	require.NotEqual(t, lastSent, newLastSent)
	lastSent = newLastSent

	// Wait for enough time to pass so there is difference in sent time of local candidate.
	require.Eventually(t, func() bool {
		agent.checkKeepalive()
		newLastSent = pair.Local.LastSent()

		return !lastSent.Equal(newLastSent)
	}, 1*time.Second, 50*time.Millisecond)
}

// TestRoleConflict checks that two agents in the same role (both Dial or
// both Accept) still resolve the conflict and reach Connected.
func TestRoleConflict(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	runTest := func(t *testing.T, doDial bool) {
		t.Helper()

		cfg := &AgentConfig{
			NetworkTypes:     supportedNetworkTypes(),
			MulticastDNSMode: MulticastDNSModeDisabled,
			InterfaceFilter:  problematicNetworkInterfaces,
		}

		aAgent, err := NewAgent(cfg)
		require.NoError(t, err)

		bAgent, err := NewAgent(cfg)
		require.NoError(t, err)

		isConnected := make(chan any)
		err = aAgent.OnConnectionStateChange(func(c ConnectionState) {
			if c == ConnectionStateConnected {
				close(isConnected)
			}
		})
		require.NoError(t, err)

		gatherAndExchangeCandidates(t, aAgent, bAgent)

		go func() {
			ufrag, pwd, routineErr := bAgent.GetLocalUserCredentials()
			require.NoError(t, routineErr)

			if doDial {
				_, routineErr = aAgent.Dial(context.TODO(), ufrag, pwd)
			} else {
				_, routineErr = aAgent.Accept(context.TODO(), ufrag, pwd)
			}
			require.NoError(t, routineErr)
		}()

		ufrag, pwd, err := aAgent.GetLocalUserCredentials()
		require.NoError(t, err)

		if doDial {
			_, err = bAgent.Dial(context.TODO(), ufrag, pwd)
		} else {
			_, err = bAgent.Accept(context.TODO(), ufrag, pwd)
		}
		require.NoError(t, err)

		<-isConnected

		require.NoError(t, aAgent.Close())
		require.NoError(t, bAgent.Close())
	}

	t.Run("Controlling", func(t *testing.T) {
		runTest(t, true)
	})
	t.Run("Controlled", func(t *testing.T) {
		runTest(t, false)
	})
}

// TestDefaultCandidateTypes checks that defaultCandidateTypes returns a
// fresh slice each call (mutating one result must not affect the next).
func TestDefaultCandidateTypes(t *testing.T) {
	expected := []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive, CandidateTypeRelay}

	first := defaultCandidateTypes()
	require.Equal(t, expected, first)

	first[0] = CandidateTypeRelay

	second := defaultCandidateTypes()
	require.Equal(t, expected, second)
}

// TestDefaultRelayAcceptanceMinWaitFor checks the relay acceptance wait
// chosen for relay-only, empty, and mixed candidate-type configurations.
func TestDefaultRelayAcceptanceMinWaitFor(t *testing.T) {
	t.Run("relay only defaults to zero wait", func(t *testing.T) {
		wait := defaultRelayAcceptanceMinWaitFor([]CandidateType{CandidateTypeRelay})
		require.Equal(t, defaultRelayOnlyAcceptanceMinWait, wait)
	})

	t.Run("empty candidate types uses general relay wait", func(t *testing.T) {
		wait := defaultRelayAcceptanceMinWaitFor(nil)
		require.Equal(t, defaultRelayAcceptanceMinWait, wait)
	})

	t.Run("mixed candidate types uses general relay wait", func(t *testing.T) {
		wait := defaultRelayAcceptanceMinWaitFor([]CandidateType{CandidateTypeHost, CandidateTypeRelay})
		require.Equal(t, defaultRelayAcceptanceMinWait, wait)
	})
}

// TestAgentConfig_initWithDefaults_UsesProvidedValues checks that explicit
// config values (including a zero MaxBindingRequests) override the defaults.
func TestAgentConfig_initWithDefaults_UsesProvidedValues(t *testing.T) {
	valMaxBindingReq := uint16(0)
	valSrflxWait := 111 * time.Millisecond
	valPrflxWait := 222 * time.Millisecond
	valRelayWait := 3 * time.Second
	valStunTimeout := 4 * time.Second

	cfg := &AgentConfig{
		MaxBindingRequests:     &valMaxBindingReq,
		SrflxAcceptanceMinWait: &valSrflxWait,
		PrflxAcceptanceMinWait: &valPrflxWait,
		RelayAcceptanceMinWait: &valRelayWait,
		STUNGatherTimeout:      &valStunTimeout,
	}

	var a Agent
	cfg.initWithDefaults(&a)

	require.Equal(t, valMaxBindingReq, a.maxBindingRequests, "expected override for MaxBindingRequests")
	require.Equal(t, valSrflxWait, a.srflxAcceptanceMinWait, "expected override for SrflxAcceptanceMinWait")
	require.Equal(t, valPrflxWait, a.prflxAcceptanceMinWait, "expected override for PrflxAcceptanceMinWait")
	require.Equal(t, valRelayWait, a.relayAcceptanceMinWait, "expected override for RelayAcceptanceMinWait")
	require.Equal(t, valStunTimeout, a.stunGatherTimeout, "expected override for STUNGatherTimeout")
}

// TestAutomaticRenominationWithVNet tests automatic renomination with simple vnet setup.
// This is a simplified test that verifies the renomination mechanism triggers correctly.
func TestAutomaticRenominationWithVNet(t *testing.T) {
	defer test.CheckRoutines(t)()

	loggerFactory := logging.NewDefaultLoggerFactory()

	// Create simple vnet with two agents on same network (no NAT)
	wan, err := vnet.NewRouter(&vnet.RouterConfig{
		CIDR:          "0.0.0.0/0",
		LoggerFactory: loggerFactory,
	})
	require.NoError(t, err)

	net0, err := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"192.168.0.1"},
	})
	require.NoError(t, err)
	require.NoError(t, wan.AddNet(net0))

	net1, err := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"192.168.0.2"},
	})
	require.NoError(t, err)
	require.NoError(t, wan.AddNet(net1))

	require.NoError(t, wan.Start())
	defer wan.Stop() //nolint:errcheck

	// Create agents with automatic renomination
	keepaliveInterval := 100 * time.Millisecond
	checkInterval := 50 * time.Millisecond
	renominationInterval := 200 * time.Millisecond

	agent1, err := newAgentFromConfig(&AgentConfig{
		NetworkTypes:      []NetworkType{NetworkTypeUDP4},
		MulticastDNSMode:  MulticastDNSModeDisabled,
		Net:               net0,
		KeepaliveInterval: &keepaliveInterval,
		CheckInterval:     &checkInterval,
	},
		WithRenomination(DefaultNominationValueGenerator()),
		WithAutomaticRenomination(renominationInterval),
	)
	require.NoError(t, err)
	defer agent1.Close() //nolint:errcheck

	agent2, err := NewAgent(&AgentConfig{
		NetworkTypes:      []NetworkType{NetworkTypeUDP4},
		MulticastDNSMode:  MulticastDNSModeDisabled,
		Net:               net1,
		KeepaliveInterval: &keepaliveInterval,
		CheckInterval:     &checkInterval,
	})
	require.NoError(t, err)
	defer agent2.Close() //nolint:errcheck
	// Enable renomination on the answering side directly on the struct.
	agent2.enableRenomination = true
	agent2.nominationValueGenerator = DefaultNominationValueGenerator()

	// Connect the agents using the existing helper
	conn1, conn2 := connectWithVNet(t, agent1, agent2)

	// Verify connection works
	testData := []byte("test data")
	_, err = conn1.Write(testData)
	require.NoError(t, err)

	buf := make([]byte, len(testData))
	_, err = conn2.Read(buf)
	require.NoError(t, err)
	require.Equal(t, testData, buf)
}

// TestAutomaticRenominationRTTImprovement tests that automatic renomination
// triggers when RTT significantly improves.
func TestAutomaticRenominationRTTImprovement(t *testing.T) {
	defer test.CheckRoutines(t)()

	// This test verifies the RTT-based renomination logic
	agent, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)
	defer agent.Close() //nolint:errcheck

	// Create two pairs with different RTTs
	localHost1, err := NewCandidateHost(&CandidateHostConfig{
		Network:   "udp",
		Address:   "192.168.1.1",
		Port:      10000,
		Component: 1,
	})
	require.NoError(t, err)

	localHost2, err := NewCandidateHost(&CandidateHostConfig{
		Network:   "udp",
		Address:   "192.168.1.3", // Different address
		Port:      10001,
		Component: 1,
	})
	require.NoError(t, err)

	remoteHost, err := NewCandidateHost(&CandidateHostConfig{
		Network:   "udp",
		Address:   "192.168.1.2",
		Port:      20000,
		Component: 1,
	})
	require.NoError(t, err)

	// Current pair with high RTT
	currentPair := newCandidatePair(localHost1, remoteHost, true)
	currentPair.state = CandidatePairStateSucceeded
	currentPair.UpdateRoundTripTime(100 * time.Millisecond)

	// Candidate pair with significantly better RTT (>10ms improvement)
	betterPair := newCandidatePair(localHost2, remoteHost, true)
	betterPair.state = CandidatePairStateSucceeded
	betterPair.UpdateRoundTripTime(50 * time.Millisecond) // 50ms improvement

	// Should trigger renomination due to RTT improvement
	shouldRenominate := agent.shouldRenominate(currentPair, betterPair)
	require.True(t, shouldRenominate, "Should renominate for >10ms RTT improvement")

	// Test with small RTT improvement (<10ms)
	slightlyBetterPair := newCandidatePair(localHost2, remoteHost, true)
	slightlyBetterPair.state = CandidatePairStateSucceeded
	slightlyBetterPair.UpdateRoundTripTime(95 * time.Millisecond) // Only 5ms improvement

	shouldRenominate = agent.shouldRenominate(currentPair, slightlyBetterPair)
	require.False(t, shouldRenominate, "Should not renominate for <10ms RTT improvement")
}

// TestAutomaticRenominationRelayToDirect tests that automatic renomination
// always prefers direct connections over relay connections.
func TestAutomaticRenominationRelayToDirect(t *testing.T) {
	defer test.CheckRoutines(t)()

	agent, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)
	defer agent.Close() //nolint:errcheck

	// Create relay pair
	localRelay, err := NewCandidateRelay(&CandidateRelayConfig{
		Network:   "udp",
		Address:   "10.0.0.1",
		Port:      30000,
		Component: 1,
		RelAddr:   "192.168.1.1",
		RelPort:   10000,
	})
	require.NoError(t, err)

	remoteRelay, err := NewCandidateRelay(&CandidateRelayConfig{
		Network:   "udp",
		Address:   "10.0.0.2",
		Port:      40000,
		Component: 1,
		RelAddr:   "192.168.1.2",
		RelPort:   20000,
	})
	require.NoError(t, err)

	relayPair := newCandidatePair(localRelay, remoteRelay, true)
	relayPair.state = CandidatePairStateSucceeded
	relayPair.UpdateRoundTripTime(50 * time.Millisecond)

	// Create host pair with similar RTT
	localHost, err := NewCandidateHost(&CandidateHostConfig{
		Network:   "udp",
		Address:   "192.168.1.1",
		Port:      10000,
		Component: 1,
	})
	require.NoError(t, err)

	remoteHost, err := NewCandidateHost(&CandidateHostConfig{
		Network:   "udp",
		Address:   "192.168.1.2",
		Port:      20000,
		Component: 1,
	})
	require.NoError(t, err)

	hostPair := newCandidatePair(localHost, remoteHost, true)
	hostPair.state = CandidatePairStateSucceeded
	hostPair.UpdateRoundTripTime(45 * time.Millisecond) // Similar RTT

	// Should always prefer direct over relay
	shouldRenominate := agent.shouldRenominate(relayPair, hostPair)
	require.True(t, shouldRenominate, "Should always renominate from relay to direct connection")
}
ice-4.1.0/agent_udpmux_test.go000066400000000000000000000065641511704470200163560ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "net" "testing" "time" "github.com/pion/logging" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/require" ) // newMuxForAddr creates a UDPMuxDefault with the correct socket family for the given address. // This fixes Windows dual-stack issues where IPv6 sockets don't receive IPv4 traffic by default. func newMuxForAddr(t *testing.T, addr *net.UDPAddr, loggerFactory logging.LoggerFactory) *UDPMuxDefault { t.Helper() var ( network string laddr *net.UDPAddr ) switch { case addr.IP == nil || addr.IP.IsUnspecified(): network = "udp4" laddr = &net.UDPAddr{IP: net.IPv4zero, Port: addr.Port} case addr.IP.To4() != nil: network = "udp4" laddr = &net.UDPAddr{IP: net.IPv4zero, Port: addr.Port} default: network = "udp6" laddr = &net.UDPAddr{IP: net.IPv6unspecified, Port: addr.Port} } pc, err := net.ListenUDP(network, laddr) require.NoError(t, err) t.Cleanup(func() { _ = pc.Close() }) return NewUDPMuxDefault(UDPMuxParams{ Logger: loggerFactory.NewLogger("ice"), UDPConn: pc, }) } // TestMuxAgent is an end to end test over UDP mux, ensuring two agents could connect over mux. 
func TestMuxAgent(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 30).Stop() const muxPort = 7686 caseAddrs := map[string]*net.UDPAddr{ "unspecified": {Port: muxPort}, "ipv4Loopback": {IP: net.IPv4(127, 0, 0, 1), Port: muxPort}, } for subTest, addr := range caseAddrs { muxAddr := addr t.Run(subTest, func(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() udpMux := newMuxForAddr(t, muxAddr, loggerFactory) muxedA, err := NewAgent(&AgentConfig{ UDPMux: udpMux, CandidateTypes: []CandidateType{CandidateTypeHost}, NetworkTypes: []NetworkType{ NetworkTypeUDP4, }, IncludeLoopback: addr.IP.IsLoopback(), }) require.NoError(t, err) var muxedAClosed bool defer func() { if muxedAClosed { return } require.NoError(t, muxedA.Close()) }() agent, err := NewAgent(&AgentConfig{ CandidateTypes: []CandidateType{CandidateTypeHost}, NetworkTypes: supportedNetworkTypes(), }) require.NoError(t, err) var aClosed bool defer func() { if aClosed { return } require.NoError(t, agent.Close()) }() conn, muxedConn := connect(t, agent, muxedA) pair := muxedA.getSelectedPair() require.NotNil(t, pair) require.Equal(t, muxPort, pair.Local.Port()) // Send a packet to Mux data := []byte("hello world") _, err = conn.Write(data) require.NoError(t, err) buf := make([]byte, 1024) n, err := muxedConn.Read(buf) require.NoError(t, err) require.Equal(t, data, buf[:n]) // Send a packet from Mux _, err = muxedConn.Write(data) require.NoError(t, err) n, err = conn.Read(buf) require.NoError(t, err) require.Equal(t, data, buf[:n]) // Close it down require.NoError(t, conn.Close()) aClosed = true require.NoError(t, muxedConn.Close()) muxedAClosed = true require.NoError(t, udpMux.Close()) // Expect error when reading from closed mux _, err = muxedConn.Read(data) require.Error(t, err) // Expect error when writing to closed mux _, err = muxedConn.Write(data) require.Error(t, err) }) } } 
ice-4.1.0/candidate.go000066400000000000000000000054751511704470200145330ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "context" "net" "time" ) const ( receiveMTU = 8192 defaultLocalPreference = 65535 // ComponentRTP indicates that the candidate is used for RTP. ComponentRTP uint16 = 1 // ComponentRTCP indicates that the candidate is used for RTCP. ComponentRTCP ) // Candidate represents an ICE candidate. type Candidate interface { // An arbitrary string used in the freezing algorithm to // group similar candidates. It is the same for two candidates that // have the same type, base IP address, protocol (UDP, TCP, etc.), // and STUN or TURN server. Foundation() string // ID is a unique identifier for just this candidate // Unlike the foundation this is different for each candidate ID() string // A component is a piece of a data stream. // An example is one for RTP, and one for RTCP Component() uint16 SetComponent(uint16) // The last time this candidate received traffic LastReceived() time.Time // The last time this candidate sent traffic LastSent() time.Time NetworkType() NetworkType Address() string Port() int Priority() uint32 // A transport address related to a // candidate, which is useful for diagnostics and other purposes RelatedAddress() *CandidateRelatedAddress // Extensions returns a copy of all extension attributes associated with the ICECandidate. // In the order of insertion, *(key value). // Extension attributes are defined in RFC 5245, Section 15.1: // https://datatracker.ietf.org/doc/html/rfc5245#section-15.1 //. Extensions() []CandidateExtension // GetExtension returns the value of the extension attribute associated with the ICECandidate. // Extension attributes are defined in RFC 5245, Section 15.1: // https://datatracker.ietf.org/doc/html/rfc5245#section-15.1 //. 
GetExtension(key string) (value CandidateExtension, ok bool) // AddExtension adds an extension attribute to the ICECandidate. // If an extension with the same key already exists, it will be overwritten. // Extension attributes are defined in RFC 5245, Section 15.1: AddExtension(extension CandidateExtension) error // RemoveExtension removes an extension attribute from the ICECandidate. // Extension attributes are defined in RFC 5245, Section 15.1: RemoveExtension(key string) (ok bool) String() string Type() CandidateType TCPType() TCPType Equal(other Candidate) bool // DeepEqual same as Equal, But it also compares the candidate extensions. DeepEqual(other Candidate) bool Marshal() string addr() net.Addr filterForLocationTracking() bool agent() *Agent context() context.Context close() error copy() (Candidate, error) seen(outbound bool) start(a *Agent, conn net.PacketConn, initializedCh <-chan struct{}) writeTo(raw []byte, dst Candidate) (int, error) } ice-4.1.0/candidate_base.go000066400000000000000000000641221511704470200155170ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "context" "errors" "fmt" "hash/crc32" "io" "net" "strconv" "strings" "sync" "sync/atomic" "time" "github.com/pion/stun/v3" ) type candidateBase struct { id string networkType NetworkType candidateType CandidateType component uint16 address string port int relatedAddress *CandidateRelatedAddress tcpType TCPType resolvedAddr net.Addr lastSent atomic.Value lastReceived atomic.Value conn net.PacketConn currAgent *Agent closeCh chan struct{} closedCh chan struct{} foundationOverride string priorityOverride uint32 remoteCandidateCaches map[AddrPort]Candidate isLocationTracked bool extensions []CandidateExtension } // Done implements context.Context. func (c *candidateBase) Done() <-chan struct{} { return c.closeCh } // Err implements context.Context. 
func (c *candidateBase) Err() error { select { case <-c.closedCh: return ErrRunCanceled default: return nil } } // Deadline implements context.Context. func (c *candidateBase) Deadline() (deadline time.Time, ok bool) { return time.Time{}, false } // Value implements context.Context. func (c *candidateBase) Value(any) any { return nil } // ID returns Candidate ID. func (c *candidateBase) ID() string { return c.id } func (c *candidateBase) Foundation() string { if c.foundationOverride != "" { return c.foundationOverride } return fmt.Sprintf("%d", crc32.ChecksumIEEE([]byte(c.Type().String()+c.address+c.networkType.String()))) } // Address returns Candidate Address. func (c *candidateBase) Address() string { return c.address } // Port returns Candidate Port. func (c *candidateBase) Port() int { return c.port } // Type returns candidate type. func (c *candidateBase) Type() CandidateType { return c.candidateType } // NetworkType returns candidate NetworkType. func (c *candidateBase) NetworkType() NetworkType { return c.networkType } // Component returns candidate component. func (c *candidateBase) Component() uint16 { return c.component } func (c *candidateBase) SetComponent(component uint16) { c.component = component } // LocalPreference returns the local preference for this candidate. func (c *candidateBase) LocalPreference() uint16 { //nolint:cyclop if c.NetworkType().IsTCP() { // RFC 6544, section 4.2 // // In Section 4.1.2.1 of [RFC5245], a recommended formula for UDP ICE // candidate prioritization is defined. For TCP candidates, the same // formula and candidate type preferences SHOULD be used, and the // RECOMMENDED type preferences for the new candidate types defined in // this document (see Section 5) are 105 for NAT-assisted candidates and // 75 for UDP-tunneled candidates. // // (...) 
// // With TCP candidates, the local preference part of the recommended // priority formula is updated to also include the directionality // (active, passive, or simultaneous-open) of the TCP connection. The // RECOMMENDED local preference is then defined as: // // local preference = (2^13) * direction-pref + other-pref // // The direction-pref MUST be between 0 and 7 (both inclusive), with 7 // being the most preferred. The other-pref MUST be between 0 and 8191 // (both inclusive), with 8191 being the most preferred. It is // RECOMMENDED that the host, UDP-tunneled, and relayed TCP candidates // have the direction-pref assigned as follows: 6 for active, 4 for // passive, and 2 for S-O. For the NAT-assisted and server reflexive // candidates, the RECOMMENDED values are: 6 for S-O, 4 for active, and // 2 for passive. // // (...) // // If any two candidates have the same type-preference and direction- // pref, they MUST have a unique other-pref. With this specification, // this usually only happens with multi-homed hosts, in which case // other-pref is the preference for the particular IP address from which // the candidate was obtained. When there is only a single IP address, // this value SHOULD be set to the maximum allowed value (8191). var otherPref uint16 = 8191 directionPref := func() uint16 { switch c.Type() { case CandidateTypeHost, CandidateTypeRelay: switch c.tcpType { case TCPTypeActive: return 6 case TCPTypePassive: return 4 case TCPTypeSimultaneousOpen: return 2 case TCPTypeUnspecified: return 0 } case CandidateTypePeerReflexive, CandidateTypeServerReflexive: switch c.tcpType { case TCPTypeSimultaneousOpen: return 6 case TCPTypeActive: return 4 case TCPTypePassive: return 2 case TCPTypeUnspecified: return 0 } case CandidateTypeUnspecified: return 0 } return 0 }() return (1<<13)*directionPref + otherPref } return defaultLocalPreference } // RelatedAddress returns *CandidateRelatedAddress. 
func (c *candidateBase) RelatedAddress() *CandidateRelatedAddress { return c.relatedAddress } func (c *candidateBase) TCPType() TCPType { return c.tcpType } // start runs the candidate using the provided connection. func (c *candidateBase) start(a *Agent, conn net.PacketConn, initializedCh <-chan struct{}) { if c.conn != nil { c.agent().log.Warn("Can't start already started candidateBase") return } c.currAgent = a c.conn = conn c.closeCh = make(chan struct{}) c.closedCh = make(chan struct{}) go c.recvLoop(initializedCh) } var bufferPool = sync.Pool{ // nolint:gochecknoglobals New: func() any { return make([]byte, receiveMTU) }, } func (c *candidateBase) recvLoop(initializedCh <-chan struct{}) { agent := c.agent() defer close(c.closedCh) select { case <-initializedCh: case <-c.closeCh: return } bufferPoolBuffer := bufferPool.Get() defer bufferPool.Put(bufferPoolBuffer) buf, ok := bufferPoolBuffer.([]byte) if !ok { return } for { n, srcAddr, err := c.conn.ReadFrom(buf) if err != nil { if !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { agent.log.Warnf("Failed to read from candidate %s: %v", c, err) } return } c.handleInboundPacket(buf[:n], srcAddr) } } func (c *candidateBase) validateSTUNTrafficCache(addr net.Addr) bool { if candidate, ok := c.remoteCandidateCaches[toAddrPort(addr)]; ok { candidate.seen(false) return true } return false } func (c *candidateBase) addRemoteCandidateCache(candidate Candidate, srcAddr net.Addr) { if c.validateSTUNTrafficCache(srcAddr) { return } c.remoteCandidateCaches[toAddrPort(srcAddr)] = candidate } func (c *candidateBase) handleInboundPacket(buf []byte, srcAddr net.Addr) { agent := c.agent() if stun.IsMessage(buf) { msg := &stun.Message{ Raw: make([]byte, len(buf)), } // Explicitly copy raw buffer so Message can own the memory. 
copy(msg.Raw, buf) if err := msg.Decode(); err != nil { agent.log.Warnf("Failed to handle decode ICE from %s to %s: %v", c.addr(), srcAddr, err) return } if err := agent.loop.Run(c, func(_ context.Context) { // nolint: contextcheck agent.handleInbound(msg, c, srcAddr) }); err != nil { agent.log.Warnf("Failed to handle message: %v", err) } return } if !c.validateSTUNTrafficCache(srcAddr) { remoteCandidate, valid := agent.validateNonSTUNTraffic(c, srcAddr) //nolint:contextcheck if !valid { agent.log.Warnf("Discarded message from %s, not a valid remote candidate", c.addr()) return } c.addRemoteCandidateCache(remoteCandidate, srcAddr) } // Note: This will return packetio.ErrFull if the buffer ever manages to fill up. n, err := agent.buf.Write(buf) if err != nil { agent.log.Warnf("Failed to write packet: %s", err) return } // Add received application bytes to the currently selected candidate pair. if n > 0 { if sp := agent.getSelectedPair(); sp != nil { sp.UpdatePacketReceived(n) } } } // close stops the recvLoop. func (c *candidateBase) close() error { // If conn has never been started will be nil if c.Done() == nil { return nil } // Assert that conn has not already been closed select { case <-c.Done(): return nil default: } var firstErr error // Unblock recvLoop close(c.closeCh) if err := c.conn.SetDeadline(time.Now()); err != nil { firstErr = err } // Close the conn if err := c.conn.Close(); err != nil && firstErr == nil { firstErr = err } if firstErr != nil { return firstErr } // Wait until the recvLoop is closed <-c.closedCh return nil } func (c *candidateBase) writeTo(raw []byte, dst Candidate) (int, error) { n, err := c.conn.WriteTo(raw, dst.addr()) if err != nil { // If the connection is closed, we should return the error if errors.Is(err, io.ErrClosedPipe) { return n, err } c.agent().log.Infof("Failed to send packet: %v", err) return n, nil } c.seen(true) return n, nil } // TypePreference returns the type preference for this candidate. 
func (c *candidateBase) TypePreference() uint16 { pref := c.Type().Preference() if pref == 0 { return 0 } if c.NetworkType().IsTCP() { var tcpPriorityOffset uint16 = defaultTCPPriorityOffset if c.agent() != nil { tcpPriorityOffset = c.agent().tcpPriorityOffset } pref -= tcpPriorityOffset } return pref } // Priority computes the priority for this ICE Candidate // See: https://www.rfc-editor.org/rfc/rfc8445#section-5.1.2.1 func (c *candidateBase) Priority() uint32 { if c.priorityOverride != 0 { return c.priorityOverride } // The local preference MUST be an integer from 0 (lowest preference) to // 65535 (highest preference) inclusive. When there is only a single IP // address, this value SHOULD be set to 65535. If there are multiple // candidates for a particular component for a particular data stream // that have the same type, the local preference MUST be unique for each // one. return (1<<24)*uint32(c.TypePreference()) + (1<<8)*uint32(c.LocalPreference()) + (1<<0)*uint32(256-c.Component()) } // Equal is used to compare two candidateBases. func (c *candidateBase) Equal(other Candidate) bool { if c.addr() != other.addr() { if c.addr() == nil || other.addr() == nil { return false } if !addrEqual(c.addr(), other.addr()) { return false } } return c.NetworkType() == other.NetworkType() && c.Type() == other.Type() && c.Address() == other.Address() && c.Port() == other.Port() && c.TCPType() == other.TCPType() && c.RelatedAddress().Equal(other.RelatedAddress()) } // DeepEqual is same as Equal but also compares the extensions. func (c *candidateBase) DeepEqual(other Candidate) bool { return c.Equal(other) && c.extensionsEqual(other.Extensions()) } // String makes the candidateBase printable. 
func (c *candidateBase) String() string { return fmt.Sprintf( "%s %s %s%s (resolved: %v)", c.NetworkType(), c.Type(), net.JoinHostPort(c.Address(), strconv.Itoa(c.Port())), c.relatedAddress, c.resolvedAddr, ) } // LastReceived returns a time.Time indicating the last time // this candidate was received. func (c *candidateBase) LastReceived() time.Time { if lastReceived, ok := c.lastReceived.Load().(time.Time); ok { return lastReceived } return time.Time{} } func (c *candidateBase) setLastReceived(t time.Time) { c.lastReceived.Store(t) } // LastSent returns a time.Time indicating the last time // this candidate was sent. func (c *candidateBase) LastSent() time.Time { if lastSent, ok := c.lastSent.Load().(time.Time); ok { return lastSent } return time.Time{} } func (c *candidateBase) setLastSent(t time.Time) { c.lastSent.Store(t) } func (c *candidateBase) seen(outbound bool) { if outbound { c.setLastSent(time.Now()) } else { c.setLastReceived(time.Now()) } } func (c *candidateBase) addr() net.Addr { return c.resolvedAddr } func (c *candidateBase) filterForLocationTracking() bool { return c.isLocationTracked } func (c *candidateBase) agent() *Agent { return c.currAgent } func (c *candidateBase) context() context.Context { return c } func (c *candidateBase) copy() (Candidate, error) { return UnmarshalCandidate(c.Marshal()) } func removeZoneIDFromAddress(addr string) string { if i := strings.Index(addr, "%"); i != -1 { return addr[:i] } return addr } // Marshal returns the string representation of the ICECandidate. 
func (c *candidateBase) Marshal() string { val := c.Foundation() if val == " " { val = "" } val = fmt.Sprintf("%s %d %s %d %s %d typ %s", val, c.Component(), c.NetworkType().NetworkShort(), c.Priority(), removeZoneIDFromAddress(c.Address()), c.Port(), c.Type()) if r := c.RelatedAddress(); r != nil && r.Address != "" && r.Port != 0 { val = fmt.Sprintf("%s raddr %s rport %d", val, r.Address, r.Port) } extensions := c.marshalExtensions() if extensions != "" { val = fmt.Sprintf("%s %s", val, extensions) } return val } // CandidateExtension represents a single candidate extension // as defined in https://tools.ietf.org/html/rfc5245#section-15.1 // . type CandidateExtension struct { Key string Value string } func (c *candidateBase) Extensions() []CandidateExtension { tcpType := c.TCPType() hasTCPType := 0 if tcpType != TCPTypeUnspecified { hasTCPType = 1 } extensions := make([]CandidateExtension, len(c.extensions)+hasTCPType) // We store the TCPType in c.tcpType, but we need to return it as an extension. if hasTCPType == 1 { extensions[0] = CandidateExtension{ Key: "tcptype", Value: tcpType.String(), } } copy(extensions[hasTCPType:], c.extensions) return extensions } // Get returns the value of the given key if it exists. func (c *candidateBase) GetExtension(key string) (CandidateExtension, bool) { extension := CandidateExtension{Key: key} for i := range c.extensions { if c.extensions[i].Key == key { extension.Value = c.extensions[i].Value return extension, true } } // TCPType was manually set. 
if key == "tcptype" && c.TCPType() != TCPTypeUnspecified { //nolint:goconst extension.Value = c.TCPType().String() return extension, true } return extension, false } func (c *candidateBase) AddExtension(ext CandidateExtension) error { if ext.Key == "tcptype" { tcpType := NewTCPType(ext.Value) if tcpType == TCPTypeUnspecified { return fmt.Errorf("%w: invalid or unsupported TCPtype %s", errParseTCPType, ext.Value) } c.tcpType = tcpType return nil } if ext.Key == "" { return fmt.Errorf("%w: key is empty", errParseExtension) } // per spec, Extensions aren't explicitly unique, we only set the first one. // If the exteion is set multiple times. for i := range c.extensions { if c.extensions[i].Key == ext.Key { c.extensions[i] = ext return nil } } c.extensions = append(c.extensions, ext) return nil } func (c *candidateBase) RemoveExtension(key string) (ok bool) { if key == "tcptype" { c.tcpType = TCPTypeUnspecified ok = true } for i := range c.extensions { if c.extensions[i].Key == key { c.extensions = append(c.extensions[:i], c.extensions[i+1:]...) ok = true break } } return ok } // marshalExtensions returns the string representation of the candidate extensions. func (c *candidateBase) marshalExtensions() string { value := "" exts := c.Extensions() for i := range exts { if value != "" { value += " " } value += exts[i].Key + " " + exts[i].Value } return value } // Equal returns true if the candidate extensions are equal. 
func (c *candidateBase) extensionsEqual(other []CandidateExtension) bool { freq1 := make(map[CandidateExtension]int) freq2 := make(map[CandidateExtension]int) if len(c.extensions) != len(other) { return false } if len(c.extensions) == 0 { return true } if len(c.extensions) == 1 { return c.extensions[0] == other[0] } for i := range c.extensions { freq1[c.extensions[i]]++ freq2[other[i]]++ } for k, v := range freq1 { if freq2[k] != v { return false } } return true } func (c *candidateBase) setExtensions(extensions []CandidateExtension) { c.extensions = extensions } // UnmarshalCandidate Parses a candidate from a string // https://datatracker.ietf.org/doc/html/rfc5245#section-15.1 func UnmarshalCandidate(raw string) (Candidate, error) { //nolint:cyclop // Handle candidates with the "candidate:" prefix as defined in RFC 5245 section 15.1. raw = strings.TrimPrefix(raw, "candidate:") pos := 0 // foundation ( 1*32ice-char ) But we allow for empty foundation, foundation, pos, err := readCandidateCharToken(raw, pos, 32) if err != nil { return nil, fmt.Errorf("%w: %v in %s", errParseFoundation, err, raw) //nolint:errorlint // we wrap the error } // Empty foundation, not RFC 8445 compliant but seen in the wild if foundation == "" { foundation = " " } if pos >= len(raw) { return nil, fmt.Errorf("%w: expected component in %s", errAttributeTooShortICECandidate, raw) } // component-id ( 1*5DIGIT ) component, pos, err := readCandidateDigitToken(raw, pos, 5) if err != nil { return nil, fmt.Errorf("%w: %v in %s", errParseComponent, err, raw) //nolint:errorlint // we wrap the error } if pos >= len(raw) { return nil, fmt.Errorf("%w: expected transport in %s", errAttributeTooShortICECandidate, raw) } // transport ( "UDP" / transport-extension ; from RFC 3261 ) SP protocol, pos := readCandidateStringToken(raw, pos) if pos >= len(raw) { return nil, fmt.Errorf("%w: expected priority in %s", errAttributeTooShortICECandidate, raw) } // priority ( 1*10DIGIT ) SP priority, pos, err := 
readCandidateDigitToken(raw, pos, 10) if err != nil { return nil, fmt.Errorf("%w: %v in %s", errParsePriority, err, raw) //nolint:errorlint // we wrap the error } if pos >= len(raw) { return nil, fmt.Errorf("%w: expected address in %s", errAttributeTooShortICECandidate, raw) } // connection-address SP ;from RFC 4566 address, pos := readCandidateStringToken(raw, pos) // Remove IPv6 ZoneID: https://github.com/pion/ice/pull/704 address = removeZoneIDFromAddress(address) if pos >= len(raw) { return nil, fmt.Errorf("%w: expected port in %s", errAttributeTooShortICECandidate, raw) } // port from RFC 4566 port, pos, err := readCandidatePort(raw, pos) if err != nil { return nil, fmt.Errorf("%w: %v in %s", errParsePort, err, raw) //nolint:errorlint // we wrap the error } // "typ" SP typeKey, pos := readCandidateStringToken(raw, pos) if typeKey != "typ" { return nil, fmt.Errorf("%w (%s)", ErrUnknownCandidateTyp, typeKey) } if pos >= len(raw) { return nil, fmt.Errorf("%w: expected candidate type in %s", errAttributeTooShortICECandidate, raw) } // SP cand-type ("host" / "srflx" / "prflx" / "relay") typ, pos := readCandidateStringToken(raw, pos) raddr, rport, pos, err := tryReadRelativeAddrs(raw, pos) if err != nil { return nil, err } tcpType := TCPTypeUnspecified var extensions []CandidateExtension var tcpTypeRaw string if pos < len(raw) { extensions, tcpTypeRaw, err = unmarshalCandidateExtensions(raw[pos:]) if err != nil { return nil, fmt.Errorf("%w: %v", errParseExtension, err) //nolint:errorlint // we wrap the error } if tcpTypeRaw != "" { tcpType = NewTCPType(tcpTypeRaw) if tcpType == TCPTypeUnspecified { return nil, fmt.Errorf("%w: invalid or unsupported TCPtype %s", errParseTCPType, tcpTypeRaw) } } } // this code is ugly because we can't break backwards compatibility // with the old way of parsing candidates switch typ { case "host": candidate, err := NewCandidateHost(&CandidateHostConfig{ "", protocol, address, port, uint16(component), //nolint:gosec // G115 no overflow 
we read 5 digits uint32(priority), //nolint:gosec // G115 no overflow we read 5 digits foundation, tcpType, false, }) if err != nil { return nil, err } candidate.setExtensions(extensions) return candidate, nil case "srflx": candidate, err := NewCandidateServerReflexive(&CandidateServerReflexiveConfig{ "", protocol, address, port, uint16(component), //nolint:gosec // G115 no overflow we read 5 digits uint32(priority), //nolint:gosec // G115 no overflow we read 5 digits foundation, raddr, rport, }) if err != nil { return nil, err } candidate.setExtensions(extensions) return candidate, nil case "prflx": candidate, err := NewCandidatePeerReflexive(&CandidatePeerReflexiveConfig{ "", protocol, address, port, uint16(component), //nolint:gosec // G115 no overflow we read 5 digits uint32(priority), //nolint:gosec // G115 no overflow we read 5 digits foundation, raddr, rport, }) if err != nil { return nil, err } candidate.setExtensions(extensions) return candidate, nil case "relay": candidate, err := NewCandidateRelay(&CandidateRelayConfig{ "", protocol, address, port, uint16(component), //nolint:gosec // G115 no overflow we read 5 digits uint32(priority), //nolint:gosec // G115 no overflow we read 5 digits foundation, raddr, rport, "", nil, }) if err != nil { return nil, err } candidate.setExtensions(extensions) return candidate, nil default: return nil, fmt.Errorf("%w (%s)", ErrUnknownCandidateTyp, typ) } } // Read an ice-char token from the raw string // ice-char = ALPHA / DIGIT / "+" / "/" // stop reading when a space is encountered or the end of the string. 
func readCandidateCharToken(raw string, start int, limit int) (string, int, error) { //nolint:cyclop for i, char := range raw[start:] { if char == 0x20 { // SP return raw[start : start+i], start + i + 1, nil } if i == limit { //nolint: err113 // handled by caller return "", 0, fmt.Errorf("token too long: %s expected 1x%d", raw[start:start+i], limit) } if (char < 'A' || char > 'Z') && (char < 'a' || char > 'z') && (char < '0' || char > '9') && char != '+' && char != '/' { return "", 0, fmt.Errorf("invalid ice-char token: %c", char) //nolint: err113 // handled by caller } } return raw[start:], len(raw), nil } // Read an ice string token from the raw string until a space is encountered // Or the end of the string, we imply that ice string are UTF-8 encoded. func readCandidateStringToken(raw string, start int) (string, int) { for i, char := range raw[start:] { if char == 0x20 { // SP return raw[start : start+i], start + i + 1 } } return raw[start:], len(raw) } // Read a digit token from the raw string // stop reading when a space is encountered or the end of the string. func readCandidateDigitToken(raw string, start, limit int) (int, int, error) { var val int for i, char := range raw[start:] { if char == 0x20 { // SP return val, start + i + 1, nil } if i == limit { //nolint: err113 // handled by caller return 0, 0, fmt.Errorf("token too long: %s expected 1x%d", raw[start:start+i], limit) } if char < '0' || char > '9' { return 0, 0, fmt.Errorf("invalid digit token: %c", char) //nolint: err113 // handled by caller } val = val*10 + int(char-'0') } return val, len(raw), nil } // Read and validate RFC 4566 port from the raw string. 
func readCandidatePort(raw string, start int) (int, int, error) { port, pos, err := readCandidateDigitToken(raw, start, 5) if err != nil { return 0, 0, err } if port > 65535 { return 0, 0, fmt.Errorf("invalid RFC 4566 port %d", port) //nolint: err113 // handled by caller } return port, pos, nil } // Read a byte-string token from the raw string // As defined in RFC 4566 1*(%x01-09/%x0B-0C/%x0E-FF) ;any byte except NUL, CR, or LF // we imply that extensions byte-string are UTF-8 encoded. func readCandidateByteString(raw string, start int) (string, int, error) { for i, char := range raw[start:] { if char == 0x20 { // SP return raw[start : start+i], start + i + 1, nil } // 1*(%x01-09/%x0B-0C/%x0E-FF) if (char < 0x01 || char > 0x09) && (char < 0x0B || char > 0x0C) && (char < 0x0E || char > 0xFF) { return "", 0, fmt.Errorf("invalid byte-string character: %c", char) //nolint: err113 // handled by caller } } return raw[start:], len(raw), nil } // Read and validate raddr and rport from the raw string // [SP rel-addr] [SP rel-port] // defined in https://datatracker.ietf.org/doc/html/rfc5245#section-15.1 // . 
func tryReadRelativeAddrs(raw string, start int) (raddr string, rport, pos int, err error) { key, pos := readCandidateStringToken(raw, start) if key != "raddr" { return "", 0, start, nil } if pos >= len(raw) { return "", 0, 0, fmt.Errorf("%w: expected raddr value in %s", errParseRelatedAddr, raw) } raddr, pos = readCandidateStringToken(raw, pos) if pos >= len(raw) { return "", 0, 0, fmt.Errorf("%w: expected rport in %s", errParseRelatedAddr, raw) } key, pos = readCandidateStringToken(raw, pos) if key != "rport" { return "", 0, 0, fmt.Errorf("%w: expected rport in %s", errParseRelatedAddr, raw) } if pos >= len(raw) { return "", 0, 0, fmt.Errorf("%w: expected rport value in %s", errParseRelatedAddr, raw) } rport, pos, err = readCandidatePort(raw, pos) if err != nil { return "", 0, 0, fmt.Errorf("%w: %v", errParseRelatedAddr, err) //nolint:errorlint // we wrap the error } return raddr, rport, pos, nil } // UnmarshalCandidateExtensions parses the candidate extensions from the raw string. // *(SP extension-att-name SP extension-att-value) // Where extension-att-name, and extension-att-value are byte-strings // as defined in https://tools.ietf.org/html/rfc5245#section-15.1 func unmarshalCandidateExtensions(raw string) (extensions []CandidateExtension, rawTCPTypeRaw string, err error) { extensions = make([]CandidateExtension, 0) if raw == "" { return extensions, "", nil } if raw[0] == 0x20 { // SP return extensions, "", fmt.Errorf("%w: unexpected space %s", errParseExtension, raw) } for i := 0; i < len(raw); { key, next, err := readCandidateByteString(raw, i) if err != nil { return extensions, "", fmt.Errorf( "%w: failed to read key %v", errParseExtension, err, //nolint: errorlint // we wrap the error ) } i = next // while not spec-compliant, we allow for empty values, as seen in the wild var value string if i < len(raw) { value, next, err = readCandidateByteString(raw, i) if err != nil { return extensions, "", fmt.Errorf( "%w: failed to read value %v", errParseExtension, 
err, //nolint: errorlint // we are wrapping the error ) } i = next } if key == "tcptype" { rawTCPTypeRaw = value continue } extensions = append(extensions, CandidateExtension{key, value}) } return extensions, rawTCPTypeRaw, nil } ice-4.1.0/candidate_host.go000066400000000000000000000040121511704470200155520ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "net/netip" "strings" ) // CandidateHost is a candidate of type host. type CandidateHost struct { candidateBase network string } // CandidateHostConfig is the config required to create a new CandidateHost. type CandidateHostConfig struct { CandidateID string Network string Address string Port int Component uint16 Priority uint32 Foundation string TCPType TCPType IsLocationTracked bool } // NewCandidateHost creates a new host candidate. func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) { candidateID := config.CandidateID if candidateID == "" { candidateID = globalCandidateIDGenerator.Generate() } candidateHost := &CandidateHost{ candidateBase: candidateBase{ id: candidateID, address: config.Address, candidateType: CandidateTypeHost, component: config.Component, port: config.Port, tcpType: config.TCPType, foundationOverride: config.Foundation, priorityOverride: config.Priority, remoteCandidateCaches: map[AddrPort]Candidate{}, isLocationTracked: config.IsLocationTracked, }, network: config.Network, } if !strings.HasSuffix(config.Address, ".local") { ipAddr, err := netip.ParseAddr(config.Address) if err != nil { return nil, err } if err := candidateHost.setIPAddr(ipAddr); err != nil { return nil, err } } else { // Until mDNS candidate is resolved assume it is UDPv4 candidateHost.candidateBase.networkType = NetworkTypeUDP4 } return candidateHost, nil } func (c *CandidateHost) setIPAddr(addr netip.Addr) error { networkType, err := determineNetworkType(c.network, addr) if err != nil { return err } 
c.candidateBase.networkType = networkType c.candidateBase.resolvedAddr = createAddr(networkType, addr, c.port) return nil } ice-4.1.0/candidate_peer_reflexive.go000066400000000000000000000032321511704470200176040ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package ice ... // //nolint:dupl package ice import ( "net/netip" ) // CandidatePeerReflexive ... type CandidatePeerReflexive struct { candidateBase } // CandidatePeerReflexiveConfig is the config required to create a new CandidatePeerReflexive. type CandidatePeerReflexiveConfig struct { CandidateID string Network string Address string Port int Component uint16 Priority uint32 Foundation string RelAddr string RelPort int } // NewCandidatePeerReflexive creates a new peer reflective candidate. func NewCandidatePeerReflexive(config *CandidatePeerReflexiveConfig) (*CandidatePeerReflexive, error) { ipAddr, err := netip.ParseAddr(config.Address) if err != nil { return nil, err } networkType, err := determineNetworkType(config.Network, ipAddr) if err != nil { return nil, err } candidateID := config.CandidateID if candidateID == "" { candidateID = globalCandidateIDGenerator.Generate() } return &CandidatePeerReflexive{ candidateBase: candidateBase{ id: candidateID, networkType: networkType, candidateType: CandidateTypePeerReflexive, address: config.Address, port: config.Port, resolvedAddr: createAddr(networkType, ipAddr, config.Port), component: config.Component, foundationOverride: config.Foundation, priorityOverride: config.Priority, relatedAddress: &CandidateRelatedAddress{ Address: config.RelAddr, Port: config.RelPort, }, remoteCandidateCaches: map[AddrPort]Candidate{}, }, }, nil } ice-4.1.0/candidate_relay.go000066400000000000000000000054441511704470200157230ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "net" "net/netip" ) // CandidateRelay ... 
type CandidateRelay struct { candidateBase relayProtocol string onClose func() error } // CandidateRelayConfig is the config required to create a new CandidateRelay. type CandidateRelayConfig struct { CandidateID string Network string Address string Port int Component uint16 Priority uint32 Foundation string RelAddr string RelPort int RelayProtocol string OnClose func() error } // NewCandidateRelay creates a new relay candidate. func NewCandidateRelay(config *CandidateRelayConfig) (*CandidateRelay, error) { candidateID := config.CandidateID if candidateID == "" { candidateID = globalCandidateIDGenerator.Generate() } ipAddr, err := netip.ParseAddr(config.Address) if err != nil { return nil, err } networkType, err := determineNetworkType(config.Network, ipAddr) if err != nil { return nil, err } return &CandidateRelay{ candidateBase: candidateBase{ id: candidateID, networkType: networkType, candidateType: CandidateTypeRelay, address: config.Address, port: config.Port, resolvedAddr: &net.UDPAddr{ IP: ipAddr.AsSlice(), Port: config.Port, Zone: ipAddr.Zone(), }, component: config.Component, foundationOverride: config.Foundation, priorityOverride: config.Priority, relatedAddress: &CandidateRelatedAddress{ Address: config.RelAddr, Port: config.RelPort, }, remoteCandidateCaches: map[AddrPort]Candidate{}, }, relayProtocol: config.RelayProtocol, onClose: config.OnClose, }, nil } // LocalPreference returns the local preference for this candidate. func (c *CandidateRelay) LocalPreference() uint16 { // These preference values come from libwebrtc // https://github.com/mozilla/libwebrtc/blob/1389c76d9c79839a2ca069df1db48aa3f2e6a1ac/p2p/base/turn_port.cc#L61 var relayPreference uint16 switch c.relayProtocol { case relayProtocolTLS, relayProtocolDTLS: relayPreference = 2 case tcp: relayPreference = 1 default: relayPreference = 0 } return c.candidateBase.LocalPreference() + relayPreference } // RelayProtocol returns the protocol used between the endpoint and the relay server. 
func (c *CandidateRelay) RelayProtocol() string { return c.relayProtocol } func (c *CandidateRelay) close() error { err := c.candidateBase.close() if c.onClose != nil { err = c.onClose() c.onClose = nil } return err } func (c *CandidateRelay) copy() (Candidate, error) { cc, err := c.candidateBase.copy() if err != nil { return nil, err } if ccr, ok := cc.(*CandidateRelay); ok { ccr.relayProtocol = c.relayProtocol } return cc, nil } ice-4.1.0/candidate_relay_test.go000066400000000000000000000037611511704470200167620ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "net" "strconv" "testing" "time" "github.com/pion/stun/v3" "github.com/pion/transport/v3/test" "github.com/pion/turn/v4" "github.com/stretchr/testify/require" ) func optimisticAuthHandler(string, string, net.Addr) (key []byte, ok bool) { return turn.GenerateAuthKey("username", "pion.ly", "password"), true } func TestRelayOnlyConnection(t *testing.T) { // Limit runtime in case of deadlocks defer test.TimeOut(time.Second * 30).Stop() defer test.CheckRoutines(t)() serverPort := randomPort(t) serverListener, err := net.ListenPacket("udp", localhostIPStr+":"+strconv.Itoa(serverPort)) // nolint: noctx require.NoError(t, err) server, err := turn.NewServer(turn.ServerConfig{ Realm: "pion.ly", AuthHandler: optimisticAuthHandler, PacketConnConfigs: []turn.PacketConnConfig{ { PacketConn: serverListener, RelayAddressGenerator: &turn.RelayAddressGeneratorNone{Address: localhostIPStr + ""}, }, }, }) require.NoError(t, err) defer func() { require.NoError(t, server.Close()) }() cfg := &AgentConfig{ NetworkTypes: supportedNetworkTypes(), Urls: []*stun.URI{ { Scheme: stun.SchemeTypeTURN, Host: localhostIPStr + "", Username: "username", Password: "password", Port: serverPort, Proto: stun.ProtoTypeUDP, }, }, CandidateTypes: []CandidateType{CandidateTypeRelay}, } aAgent, err := NewAgent(cfg) require.NoError(t, err) 
defer func() { require.NoError(t, aAgent.Close()) }() aNotifier, aConnected := onConnected() require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) bAgent, err := NewAgent(cfg) require.NoError(t, err) defer func() { require.NoError(t, bAgent.Close()) }() bNotifier, bConnected := onConnected() require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) connect(t, aAgent, bAgent) <-aConnected <-bConnected } ice-4.1.0/candidate_server_reflexive.go000066400000000000000000000032501511704470200201570ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "net" "net/netip" ) // CandidateServerReflexive ... type CandidateServerReflexive struct { candidateBase } // CandidateServerReflexiveConfig is the config required to create a new CandidateServerReflexive. type CandidateServerReflexiveConfig struct { CandidateID string Network string Address string Port int Component uint16 Priority uint32 Foundation string RelAddr string RelPort int } // NewCandidateServerReflexive creates a new server reflective candidate. 
func NewCandidateServerReflexive(config *CandidateServerReflexiveConfig) (*CandidateServerReflexive, error) { ipAddr, err := netip.ParseAddr(config.Address) if err != nil { return nil, err } networkType, err := determineNetworkType(config.Network, ipAddr) if err != nil { return nil, err } candidateID := config.CandidateID if candidateID == "" { candidateID = globalCandidateIDGenerator.Generate() } return &CandidateServerReflexive{ candidateBase: candidateBase{ id: candidateID, networkType: networkType, candidateType: CandidateTypeServerReflexive, address: config.Address, port: config.Port, resolvedAddr: &net.UDPAddr{ IP: ipAddr.AsSlice(), Port: config.Port, Zone: ipAddr.Zone(), }, component: config.Component, foundationOverride: config.Foundation, priorityOverride: config.Priority, relatedAddress: &CandidateRelatedAddress{ Address: config.RelAddr, Port: config.RelPort, }, remoteCandidateCaches: map[AddrPort]Candidate{}, }, }, nil } ice-4.1.0/candidate_server_reflexive_test.go000066400000000000000000000034011511704470200212140ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "net" "strconv" "testing" "time" "github.com/pion/stun/v3" "github.com/pion/transport/v3/test" "github.com/pion/turn/v4" "github.com/stretchr/testify/require" ) func TestServerReflexiveOnlyConnection(t *testing.T) { defer test.CheckRoutines(t)() // Limit runtime in case of deadlocks defer test.TimeOut(time.Second * 30).Stop() serverPort := randomPort(t) serverListener, err := net.ListenPacket("udp4", "127.0.0.1:"+strconv.Itoa(serverPort)) // nolint: noctx require.NoError(t, err) server, err := turn.NewServer(turn.ServerConfig{ Realm: "pion.ly", AuthHandler: optimisticAuthHandler, PacketConnConfigs: []turn.PacketConnConfig{ { PacketConn: serverListener, RelayAddressGenerator: &turn.RelayAddressGeneratorNone{Address: "127.0.0.1"}, }, }, }) require.NoError(t, err) defer func() { 
require.NoError(t, server.Close()) }() cfg := &AgentConfig{ NetworkTypes: []NetworkType{NetworkTypeUDP4}, Urls: []*stun.URI{ { Scheme: SchemeTypeSTUN, Host: "127.0.0.1", Port: serverPort, }, }, CandidateTypes: []CandidateType{CandidateTypeServerReflexive}, } aAgent, err := NewAgent(cfg) require.NoError(t, err) defer func() { require.NoError(t, aAgent.Close()) }() aNotifier, aConnected := onConnected() require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) bAgent, err := NewAgent(cfg) require.NoError(t, err) defer func() { require.NoError(t, bAgent.Close()) }() bNotifier, bConnected := onConnected() require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) connect(t, aAgent, bAgent) <-aConnected <-bConnected } ice-4.1.0/candidate_test.go000066400000000000000000001141131511704470200155600ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "net" "strconv" "strings" "testing" "time" "github.com/pion/logging" "github.com/stretchr/testify/require" ) const localhostIPStr = "127.0.0.1" func TestCandidateTypePreference(t *testing.T) { req := require.New(t) hostDefaultPreference := uint16(126) prflxDefaultPreference := uint16(110) srflxDefaultPreference := uint16(100) relayDefaultPreference := uint16(0) tcpOffsets := []uint16{0, 10} for _, tcpOffset := range tcpOffsets { agent := &Agent{ tcpPriorityOffset: tcpOffset, } for _, networkType := range supportedNetworkTypes() { hostCandidate := candidateBase{ candidateType: CandidateTypeHost, networkType: networkType, currAgent: agent, } prflxCandidate := candidateBase{ candidateType: CandidateTypePeerReflexive, networkType: networkType, currAgent: agent, } srflxCandidate := candidateBase{ candidateType: CandidateTypeServerReflexive, networkType: networkType, currAgent: agent, } relayCandidate := candidateBase{ candidateType: CandidateTypeRelay, networkType: networkType, currAgent: agent, } if networkType.IsTCP() { 
req.Equal(hostDefaultPreference-tcpOffset, hostCandidate.TypePreference()) req.Equal(prflxDefaultPreference-tcpOffset, prflxCandidate.TypePreference()) req.Equal(srflxDefaultPreference-tcpOffset, srflxCandidate.TypePreference()) } else { req.Equal(hostDefaultPreference, hostCandidate.TypePreference()) req.Equal(prflxDefaultPreference, prflxCandidate.TypePreference()) req.Equal(srflxDefaultPreference, srflxCandidate.TypePreference()) } req.Equal(relayDefaultPreference, relayCandidate.TypePreference()) } } } func TestCandidatePriority(t *testing.T) { for _, test := range []struct { Candidate Candidate WantPriority uint32 }{ { Candidate: &CandidateHost{ candidateBase: candidateBase{ candidateType: CandidateTypeHost, component: ComponentRTP, }, }, WantPriority: 2130706431, }, { Candidate: &CandidateHost{ candidateBase: candidateBase{ candidateType: CandidateTypeHost, component: ComponentRTP, networkType: NetworkTypeTCP4, tcpType: TCPTypeActive, }, }, WantPriority: 1675624447, }, { Candidate: &CandidateHost{ candidateBase: candidateBase{ candidateType: CandidateTypeHost, component: ComponentRTP, networkType: NetworkTypeTCP4, tcpType: TCPTypePassive, }, }, WantPriority: 1671430143, }, { Candidate: &CandidateHost{ candidateBase: candidateBase{ candidateType: CandidateTypeHost, component: ComponentRTP, networkType: NetworkTypeTCP4, tcpType: TCPTypeSimultaneousOpen, }, }, WantPriority: 1667235839, }, { Candidate: &CandidatePeerReflexive{ candidateBase: candidateBase{ candidateType: CandidateTypePeerReflexive, component: ComponentRTP, }, }, WantPriority: 1862270975, }, { Candidate: &CandidatePeerReflexive{ candidateBase: candidateBase{ candidateType: CandidateTypePeerReflexive, component: ComponentRTP, networkType: NetworkTypeTCP6, tcpType: TCPTypeSimultaneousOpen, }, }, WantPriority: 1407188991, }, { Candidate: &CandidatePeerReflexive{ candidateBase: candidateBase{ candidateType: CandidateTypePeerReflexive, component: ComponentRTP, networkType: NetworkTypeTCP6, tcpType: 
TCPTypeActive, }, }, WantPriority: 1402994687, }, { Candidate: &CandidatePeerReflexive{ candidateBase: candidateBase{ candidateType: CandidateTypePeerReflexive, component: ComponentRTP, networkType: NetworkTypeTCP6, tcpType: TCPTypePassive, }, }, WantPriority: 1398800383, }, { Candidate: &CandidateServerReflexive{ candidateBase: candidateBase{ candidateType: CandidateTypeServerReflexive, component: ComponentRTP, }, }, WantPriority: 1694498815, }, { Candidate: &CandidateRelay{ candidateBase: candidateBase{ candidateType: CandidateTypeRelay, component: ComponentRTP, }, }, WantPriority: 16777215, }, } { require.Equal(t, test.Candidate.Priority(), test.WantPriority) } } func TestCandidateLastSent(t *testing.T) { candidate := candidateBase{} require.Equal(t, candidate.LastSent(), time.Time{}) now := time.Now() candidate.setLastSent(now) require.Equal(t, candidate.LastSent(), now) } func TestCandidateLastReceived(t *testing.T) { candidate := candidateBase{} require.Equal(t, candidate.LastReceived(), time.Time{}) now := time.Now() candidate.setLastReceived(now) require.Equal(t, candidate.LastReceived(), now) } func TestCandidateFoundation(t *testing.T) { // All fields are the same require.Equal(t, (&candidateBase{ candidateType: CandidateTypeHost, networkType: NetworkTypeUDP4, address: "A", }).Foundation(), (&candidateBase{ candidateType: CandidateTypeHost, networkType: NetworkTypeUDP4, address: "A", }).Foundation()) // Different Address require.NotEqual(t, (&candidateBase{ candidateType: CandidateTypeHost, networkType: NetworkTypeUDP4, address: "A", }).Foundation(), (&candidateBase{ candidateType: CandidateTypeHost, networkType: NetworkTypeUDP4, address: "B", }).Foundation()) // Different networkType require.NotEqual(t, (&candidateBase{ candidateType: CandidateTypeHost, networkType: NetworkTypeUDP4, address: "A", }).Foundation(), (&candidateBase{ candidateType: CandidateTypeHost, networkType: NetworkTypeUDP6, address: "A", }).Foundation()) // Different candidateType 
require.NotEqual(t, (&candidateBase{ candidateType: CandidateTypeHost, networkType: NetworkTypeUDP4, address: "A", }).Foundation(), (&candidateBase{ candidateType: CandidateTypePeerReflexive, networkType: NetworkTypeUDP4, address: "A", }).Foundation()) // Port has no effect require.Equal(t, (&candidateBase{ candidateType: CandidateTypeHost, networkType: NetworkTypeUDP4, address: "A", port: 8080, }).Foundation(), (&candidateBase{ candidateType: CandidateTypeHost, networkType: NetworkTypeUDP4, address: "A", port: 80, }).Foundation()) } func mustCandidateHost(t *testing.T, conf *CandidateHostConfig) Candidate { t.Helper() cand, err := NewCandidateHost(conf) require.NoError(t, err) return cand } func mustCandidateHostWithExtensions( t *testing.T, conf *CandidateHostConfig, extensions []CandidateExtension, ) Candidate { t.Helper() cand, err := NewCandidateHost(conf) require.NoError(t, err) cand.setExtensions(extensions) return cand } func mustCandidateRelay(t *testing.T, conf *CandidateRelayConfig) Candidate { t.Helper() cand, err := NewCandidateRelay(conf) require.NoError(t, err) return cand } func mustCandidateRelayWithExtensions( t *testing.T, conf *CandidateRelayConfig, extensions []CandidateExtension, ) Candidate { t.Helper() cand, err := NewCandidateRelay(conf) require.NoError(t, err) cand.setExtensions(extensions) return cand } func mustCandidateServerReflexive(t *testing.T, conf *CandidateServerReflexiveConfig) Candidate { t.Helper() cand, err := NewCandidateServerReflexive(conf) require.NoError(t, err) return cand } func mustCandidateServerReflexiveWithExtensions( t *testing.T, conf *CandidateServerReflexiveConfig, extensions []CandidateExtension, ) Candidate { t.Helper() cand, err := NewCandidateServerReflexive(conf) require.NoError(t, err) cand.setExtensions(extensions) return cand } func mustCandidatePeerReflexiveWithExtensions( t *testing.T, conf *CandidatePeerReflexiveConfig, extensions []CandidateExtension, ) Candidate { t.Helper() cand, err := 
NewCandidatePeerReflexive(conf) require.NoError(t, err) cand.setExtensions(extensions) return cand } func TestCandidateMarshal(t *testing.T) { for idx, test := range []struct { candidate Candidate marshaled string expectError bool }{ { mustCandidateHost(t, &CandidateHostConfig{ Network: NetworkTypeUDP6.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }), "750 1 udp 500 fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a 53987 typ host", false, }, { mustCandidateHost(t, &CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: "10.0.75.1", Port: 53634, }), "4273957277 1 udp 2130706431 10.0.75.1 53634 typ host", false, }, { mustCandidateServerReflexive(t, &CandidateServerReflexiveConfig{ Network: NetworkTypeUDP4.String(), Address: "191.228.238.68", Port: 53991, RelAddr: "192.168.0.274", RelPort: 53991, }), "647372371 1 udp 1694498815 191.228.238.68 53991 typ srflx raddr 192.168.0.274 rport 53991", false, }, { mustCandidatePeerReflexiveWithExtensions( t, &CandidatePeerReflexiveConfig{ Network: NetworkTypeTCP4.String(), Address: "192.0.2.15", Port: 50000, RelAddr: "10.0.0.1", RelPort: 12345, }, []CandidateExtension{ {"generation", "0"}, {"network-id", "2"}, {"network-cost", "10"}, }, ), //nolint: lll "4207374052 1 tcp 1685790463 192.0.2.15 50000 typ prflx raddr 10.0.0.1 rport 12345 generation 0 network-id 2 network-cost 10", false, }, { mustCandidateRelay(t, &CandidateRelayConfig{ Network: NetworkTypeUDP4.String(), Address: "50.0.0.1", Port: 5000, RelAddr: "192.168.0.1", RelPort: 5001, }), "848194626 1 udp 16777215 50.0.0.1 5000 typ relay raddr 192.168.0.1 rport 5001", false, }, { mustCandidateHost(t, &CandidateHostConfig{ Network: NetworkTypeTCP4.String(), Address: "192.168.0.196", Port: 0, TCPType: TCPTypeActive, }), "1052353102 1 tcp 2128609279 192.168.0.196 0 typ host tcptype active", false, }, { mustCandidateHost(t, &CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: 
"e2494022-4d9a-4c1e-a750-cc48d4f8d6ee.local", Port: 60542, }), "1380287402 1 udp 2130706431 e2494022-4d9a-4c1e-a750-cc48d4f8d6ee.local 60542 typ host", false, }, // Missing Foundation { mustCandidateHost(t, &CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: localhostIPStr, Port: 80, Priority: 500, Foundation: " ", }), " 1 udp 500 " + localhostIPStr + " 80 typ host", false, }, // Missing Foundation { mustCandidateHost(t, &CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: localhostIPStr, Port: 80, Priority: 500, Foundation: " ", }), "candidate: 1 udp 500 " + localhostIPStr + " 80 typ host", false, }, { mustCandidateHost(t, &CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: localhostIPStr, Port: 80, Priority: 500, Foundation: "+/3713fhi", }), "+/3713fhi 1 udp 500 " + localhostIPStr + " 80 typ host", false, }, { mustCandidateHost(t, &CandidateHostConfig{ Network: NetworkTypeTCP4.String(), Address: "172.28.142.173", Port: 7686, Priority: 1671430143, Foundation: "+/3713fhi", }), "3359356140 1 tcp 1671430143 172.28.142.173 7686 typ host", false, }, { mustCandidateHost(t, &CandidateHostConfig{ Network: NetworkTypeTCP4.String(), Address: "172.28.142.173", Port: 7686, Priority: 1671430143, Foundation: "+/3713fhi", }), "candidate:3359356140 1 tcp 1671430143 172.28.142.173 7686 typ host", false, }, // Invalid candidates {nil, "", true}, {nil, "1938809241", true}, {nil, "1986380506 99999999 udp 2122063615 10.0.75.1 53634 typ host generation 0 network-id 2", true}, {nil, "1986380506 1 udp 99999999999 10.0.75.1 53634 typ host", true}, //nolint: lll {nil, "4207374051 1 udp 1685790463 191.228.238.68 99999999 typ srflx raddr 192.168.0.278 rport 53991 generation 0 network-id 3", true}, {nil, "4207374051 1 udp 1685790463 191.228.238.68 53991 typ srflx raddr", true}, //nolint: lll {nil, "4207374051 1 udp 1685790463 191.228.238.68 53991 typ srflx raddr 192.168.0.278 rport 99999999 generation 0 network-id 3", true}, {nil, "4207374051 INVALID 
udp 2130706431 10.0.75.1 53634 typ host", true}, {nil, "4207374051 1 udp INVALID 10.0.75.1 53634 typ host", true}, {nil, "4207374051 INVALID udp 2130706431 10.0.75.1 INVALID typ host", true}, {nil, "4207374051 1 udp 2130706431 10.0.75.1 53634 typ INVALID", true}, {nil, "4207374051 1 INVALID 2130706431 10.0.75.1 53634 typ host", true}, {nil, "4207374051 1 INVALID 2130706431 10.0.75.1 53634 typ", true}, {nil, "4207374051 1 INVALID 2130706431 10.0.75.1 53634", true}, {nil, "848194626 1 udp 16777215 50.0.0.^^1 5000 typ relay raddr 192.168.0.1 rport 5001", true}, {nil, "4207374052 1 tcp 1685790463 192.0#.2.15 50000 typ prflx raddr 10.0.0.1 rport 12345 rport 5001", true}, {nil, "647372371 1 udp 1694498815 191.228.2@338.68 53991 typ srflx raddr 192.168.0.274 rport 53991", true}, // invalid foundion; longer than 32 characters {nil, "111111111111111111111111111111111 1 udp 500 " + localhostIPStr + " 80 typ host", true}, // Invalid ice-char {nil, "3$3 1 udp 500 " + localhostIPStr + " 80 typ host", true}, // invalid component; longer than 5 digits {nil, "4207374051 123456 udp 500 " + localhostIPStr + " 0 typ host", true}, // invalid priority; longer than 10 digits {nil, "4207374051 99999 udp 12345678910 " + localhostIPStr + " 99999 typ host", true}, // invalid port; {nil, "4207374051 99999 udp 500 " + localhostIPStr + " 65536 typ host", true}, {nil, "4207374051 99999 udp 500 " + localhostIPStr + " 999999 typ host", true}, {nil, "848194626 1 udp 16777215 50.0.0.1 5000 typ relay raddr 192.168.0.1 rport 999999", true}, // bad byte-string in extension value {nil, "750 1 udp 500 fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a 53987 typ host ext valu\nu", true}, {nil, "848194626 1 udp 16777215 50.0.0.1 5000 typ relay raddr 192.168.0.1 rport 654 ext valu\nu", true}, {nil, "848194626 1 udp 16777215 50.0.0.1 5000 typ relay raddr 192.168.0.1 rport 654 ext valu\000e", true}, // bad byte-string in extension key {nil, "848194626 1 udp 16777215 50.0.0.1 5000 typ relay raddr 192.168.0.1 rport 654 
ext\r value", true}, // invalid tcptype {nil, "1052353102 1 tcp 2128609279 192.168.0.196 0 typ host tcptype INVALID", true}, // expect rport after raddr {nil, "848194626 1 udp 16777215 50.0.0.1 5000 typ relay raddr 192.168.0.1 extension 322", true}, {nil, "848194626 1 udp 16777215 50.0.0.1 5000 typ relay raddr 192.168.0.1 rport", true}, {nil, "848194626 1 udp 16777215 50.0.0.1 5000 typ relay raddr 192.168.0.1", true}, {nil, "848194626 1 udp 16777215 50.0.0.1 5000 typ relay raddr", true}, {nil, "4207374051 99999 udp 500 " + localhostIPStr + " 80 typ", true}, {nil, "4207374051 99999 udp 500 " + localhostIPStr + " 80", true}, {nil, "4207374051 99999 udp 500 " + localhostIPStr, true}, {nil, "4207374051 99999 udp 500 ", true}, {nil, "4207374051 99999 udp", true}, {nil, "4207374051 99999", true}, {nil, "4207374051", true}, } { t.Run(strconv.Itoa(idx), func(t *testing.T) { actualCandidate, err := UnmarshalCandidate(test.marshaled) if test.expectError { require.Error(t, err, "expected error", test.marshaled) return } require.NoError(t, err) require.Truef( t, test.candidate.Equal(actualCandidate), "%s != %s", test.candidate.String(), actualCandidate.String(), ) if strings.HasPrefix(test.marshaled, "candidate:") { require.Equal(t, test.marshaled[len("candidate:"):], actualCandidate.Marshal()) } else { require.Equal(t, test.marshaled, actualCandidate.Marshal()) } }) } } func TestCandidateWriteTo(t *testing.T) { listener, err := net.ListenTCP("tcp", &net.TCPAddr{ IP: net.IP{127, 0, 0, 1}, Port: 0, }) require.NoError(t, err, "error creating test TCP listener") conn, err := net.DialTCP("tcp", nil, listener.Addr().(*net.TCPAddr)) // nolint require.NoError(t, err, "error dialing test TCP connection") loggerFactory := logging.NewDefaultLoggerFactory() packetConn := newTCPPacketConn(tcpPacketParams{ ReadBuffer: 2048, Logger: loggerFactory.NewLogger("tcp-packet-conn"), }) err = packetConn.AddConn(conn, nil) require.NoError(t, err, "error adding test TCP connection to packet 
connection") c1 := &candidateBase{ conn: packetConn, currAgent: &Agent{ log: loggerFactory.NewLogger("agent"), }, } c2 := &candidateBase{ resolvedAddr: listener.Addr(), } _, err = c1.writeTo([]byte("test"), c2) require.NoError(t, err, "writing to open conn") err = packetConn.Close() require.NoError(t, err, "error closing test TCP connection") _, err = c1.writeTo([]byte("test"), c2) require.Error(t, err, "writing to closed conn") } func TestMarshalUnmarshalCandidateWithZoneID(t *testing.T) { candidateWithZoneID := mustCandidateHost(t, &CandidateHostConfig{ Network: NetworkTypeUDP6.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a%Local Connection", Port: 53987, Priority: 500, Foundation: "750", }) candidateStr := "750 0 udp 500 fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a 53987 typ host" require.Equal(t, candidateStr, candidateWithZoneID.Marshal()) candidate := mustCandidateHost(t, &CandidateHostConfig{ Network: NetworkTypeUDP6.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }) candidateWithZoneIDStr := "750 0 udp 500 fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a%eth0 53987 typ host" candidate2, err := UnmarshalCandidate(candidateWithZoneIDStr) require.NoError(t, err) require.Truef(t, candidate.Equal(candidate2), "%s != %s", candidate.String(), candidate2.String()) candidateWithZoneIDStr2 := "750 0 udp 500 fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a%eth0%eth1 53987 typ host" candidate2, err = UnmarshalCandidate(candidateWithZoneIDStr2) require.NoError(t, err) require.Truef(t, candidate.Equal(candidate2), "%s != %s", candidate.String(), candidate2.String()) } func TestCandidateExtensionsMarshal(t *testing.T) { testCases := []struct { Extensions []CandidateExtension candidate string }{ { []CandidateExtension{ {"generation", "0"}, {"ufrag", "QNvE"}, {"network-id", "4"}, }, //nolint: lll "1299692247 1 udp 2122134271 fdc8:cc8:c835:e400:343c:feb:32c8:17b9 58240 typ host generation 0 ufrag QNvE network-id 4", }, { 
[]CandidateExtension{ {"generation", "1"}, {"network-id", "2"}, {"network-cost", "50"}, }, //nolint:lll "647372371 1 udp 1694498815 191.228.238.68 53991 typ srflx raddr 192.168.0.274 rport 53991 generation 1 network-id 2 network-cost 50", }, { []CandidateExtension{ {"generation", "0"}, {"network-id", "2"}, {"network-cost", "10"}, }, //nolint:lll "4207374052 1 tcp 1685790463 192.0.2.15 50000 typ prflx raddr 10.0.0.1 rport 12345 generation 0 network-id 2 network-cost 10", }, { []CandidateExtension{ {"generation", "0"}, {"network-id", "1"}, {"network-cost", "20"}, {"ufrag", "frag42abcdef"}, {"password", "abc123exp123"}, }, //nolint: lll "848194626 1 udp 16777215 50.0.0.1 5000 typ relay raddr 192.168.0.1 rport 5001 generation 0 network-id 1 network-cost 20 ufrag frag42abcdef password abc123exp123", }, { []CandidateExtension{ {"tcptype", "active"}, {"generation", "0"}, }, "1052353102 1 tcp 2128609279 192.168.0.196 0 typ host tcptype active generation 0", }, { []CandidateExtension{ {"tcptype", "active"}, {"generation", "0"}, }, "1052353102 1 tcp 2128609279 192.168.0.196 0 typ host tcptype active generation 0", }, { []CandidateExtension{}, "1052353102 1 tcp 2128609279 192.168.0.196 0 typ host", }, { []CandidateExtension{ {"tcptype", "active"}, {"empty-value-1", ""}, {"empty-value-2", ""}, }, "1052353102 1 tcp 2128609279 192.168.0.196 0 typ host tcptype active empty-value-1 empty-value-2", }, { []CandidateExtension{ {"tcptype", "active"}, {"empty-value-1", ""}, {"empty-value-2", ""}, }, "1052353102 1 tcp 2128609279 192.168.0.196 0 typ host tcptype active empty-value-1 empty-value-2 ", }, } for _, tc := range testCases { candidate, err := UnmarshalCandidate(tc.candidate) require.NoError(t, err) require.Equal(t, tc.Extensions, candidate.Extensions(), "Extensions should be equal", tc.candidate) valueStr := candidate.Marshal() candidate2, err := UnmarshalCandidate(valueStr) require.NoError(t, err) require.Equal(t, tc.Extensions, candidate2.Extensions(), "Marshal() should 
preserve extensions") } } func TestCandidateExtensionsDeepEqual(t *testing.T) { noExt, err := UnmarshalCandidate("750 0 udp 500 fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a 53987 typ host") require.NoError(t, err) generation := "0" ufrag := "QNvE" networkID := "4" extensions := []CandidateExtension{ {"generation", generation}, {"ufrag", ufrag}, {"network-id", networkID}, } candidate, err := UnmarshalCandidate( "750 0 udp 500 fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a 53987 typ host generation " + generation + " ufrag " + ufrag + " network-id " + networkID, ) require.NoError(t, err) testCases := []struct { a Candidate b Candidate equal bool }{ { mustCandidateHost(t, &CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }), noExt, true, }, { mustCandidateHostWithExtensions( t, &CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }, []CandidateExtension{}, ), noExt, true, }, { mustCandidateHostWithExtensions( t, &CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }, extensions, ), candidate, true, }, { mustCandidateRelayWithExtensions( t, &CandidateRelayConfig{ Network: NetworkTypeUDP4.String(), Address: "10.0.0.10", Port: 5000, RelAddr: "10.0.0.2", RelPort: 5001, }, []CandidateExtension{ {"generation", "0"}, {"network-id", "1"}, }, ), mustCandidateRelayWithExtensions( t, &CandidateRelayConfig{ Network: NetworkTypeUDP4.String(), Address: "10.0.0.10", Port: 5000, RelAddr: "10.0.0.2", RelPort: 5001, }, []CandidateExtension{ {"network-id", "1"}, {"generation", "0"}, }, ), true, }, { mustCandidatePeerReflexiveWithExtensions( t, &CandidatePeerReflexiveConfig{ Network: NetworkTypeTCP4.String(), Address: "192.0.2.15", Port: 50000, RelAddr: "10.0.0.1", RelPort: 12345, }, 
[]CandidateExtension{ {"generation", "0"}, {"network-id", "2"}, {"network-cost", "10"}, }, ), mustCandidatePeerReflexiveWithExtensions( t, &CandidatePeerReflexiveConfig{ Network: NetworkTypeTCP4.String(), Address: "192.0.2.15", Port: 50000, RelAddr: "10.0.0.1", RelPort: 12345, }, []CandidateExtension{ {"generation", "0"}, {"network-id", "2"}, {"network-cost", "10"}, }, ), true, }, { mustCandidateServerReflexiveWithExtensions( t, &CandidateServerReflexiveConfig{ Network: NetworkTypeUDP4.String(), Address: "191.228.238.68", Port: 53991, RelAddr: "192.168.0.274", RelPort: 53991, }, []CandidateExtension{ {"generation", "0"}, {"network-id", "2"}, {"network-cost", "10"}, }, ), mustCandidateServerReflexiveWithExtensions( t, &CandidateServerReflexiveConfig{ Network: NetworkTypeUDP4.String(), Address: "191.228.238.68", Port: 53991, RelAddr: "192.168.0.274", RelPort: 53991, }, []CandidateExtension{ {"generation", "0"}, {"network-id", "2"}, {"network-cost", "10"}, }, ), true, }, { mustCandidateHostWithExtensions( t, &CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }, []CandidateExtension{ {"generation", "5"}, {"ufrag", ufrag}, {"network-id", networkID}, }, ), candidate, false, }, { mustCandidateHostWithExtensions( t, &CandidateHostConfig{ Network: NetworkTypeTCP4.String(), Address: "192.168.0.196", Port: 0, Priority: 2128609279, Foundation: "1052353102", TCPType: TCPTypeActive, }, []CandidateExtension{ {"tcptype", TCPTypeActive.String()}, {"generation", "0"}, }, ), mustCandidateHostWithExtensions( t, &CandidateHostConfig{ Network: NetworkTypeTCP4.String(), Address: "192.168.0.197", Port: 0, Priority: 2128609279, Foundation: "1052353102", TCPType: TCPTypeActive, }, []CandidateExtension{ {"tcptype", TCPTypeActive.String()}, {"generation", "0"}, }, ), false, }, } for _, tc := range testCases { require.Equal(t, tc.a.DeepEqual(tc.b), tc.equal, "a: %s, b: %s", tc.a.Marshal(), 
tc.b.Marshal()) } } func TestUnmarshalCandidateExtensions(t *testing.T) { testCases := []struct { name string value string expected []CandidateExtension fail bool }{ { name: "empty string", value: "", expected: []CandidateExtension{}, fail: false, }, { name: "valid extension string", value: "a b c d", expected: []CandidateExtension{{"a", "b"}, {"c", "d"}}, fail: false, }, { name: "valid extension string", value: "a b empty c d", expected: []CandidateExtension{ {"a", "b"}, {"empty", ""}, {"c", "d"}, }, fail: false, }, { name: "invalid extension", value: " a b d", expected: []CandidateExtension{{"", "a"}, {"b", "d"}}, fail: true, }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { req := require.New(t) actual, _, err := unmarshalCandidateExtensions(testCase.value) if testCase.fail { req.Error(err) } else { req.NoError(err) req.EqualValuesf( testCase.expected, actual, "UnmarshalCandidateExtensions() did not return the expected value %v", testCase.value, ) } }) } } func TestCandidateGetExtension(t *testing.T) { t.Run("Get extension", func(t *testing.T) { extensions := []CandidateExtension{ {"a", "b"}, {"c", "d"}, } candidate, err := NewCandidateHost(&CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }) require.NoError(t, err) candidate.setExtensions(extensions) value, ok := candidate.GetExtension("c") require.True(t, ok) require.Equal(t, "c", value.Key) require.Equal(t, "d", value.Value) value, ok = candidate.GetExtension("a") require.True(t, ok) require.Equal(t, "a", value.Key) require.Equal(t, "b", value.Value) value, ok = candidate.GetExtension("b") require.False(t, ok) require.Equal(t, "b", value.Key) require.Equal(t, "", value.Value) }) // This is undefined behavior in the spec; extension-att-name is not unique // but it implied that it's unique in the implementation t.Run("Extension with multiple values", func(t *testing.T) { 
extensions := []CandidateExtension{ {"a", "1"}, {"a", "2"}, } candidate, err := NewCandidateHost(&CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }) require.NoError(t, err) candidate.setExtensions(extensions) value, ok := candidate.GetExtension("a") require.True(t, ok) require.Equal(t, "a", value.Key) require.Equal(t, "1", value.Value) }) t.Run("TCPType extension", func(t *testing.T) { extensions := []CandidateExtension{ {"tcptype", "passive"}, } candidate, err := NewCandidateHost(&CandidateHostConfig{ Network: NetworkTypeTCP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", TCPType: TCPTypeActive, }) require.NoError(t, err) tcpType, ok := candidate.GetExtension("tcptype") require.True(t, ok) require.Equal(t, "tcptype", tcpType.Key) require.Equal(t, TCPTypeActive.String(), tcpType.Value) candidate.setExtensions(extensions) tcpType, ok = candidate.GetExtension("tcptype") require.True(t, ok) require.Equal(t, "tcptype", tcpType.Key) require.Equal(t, "passive", tcpType.Value) candidate2, err := NewCandidateHost(&CandidateHostConfig{ Network: NetworkTypeTCP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }) require.NoError(t, err) tcpType, ok = candidate2.GetExtension("tcptype") require.False(t, ok) require.Equal(t, "tcptype", tcpType.Key) require.Equal(t, "", tcpType.Value) }) } func TestBaseCandidateMarshalExtensions(t *testing.T) { t.Run("Marshal extension", func(t *testing.T) { extensions := []CandidateExtension{ {"generation", "0"}, {"ValuE", "KeE"}, {"empty", ""}, {"another", "value"}, } candidate, err := NewCandidateHost(&CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }) require.NoError(t, err) candidate.setExtensions(extensions) value 
:= candidate.marshalExtensions() require.Equal(t, "generation 0 ValuE KeE empty another value", value) }) t.Run("Marshal Empty", func(t *testing.T) { candidate, err := NewCandidateHost(&CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }) require.NoError(t, err) value := candidate.marshalExtensions() require.Equal(t, "", value) }) t.Run("Marshal TCPType no extension", func(t *testing.T) { candidate, err := NewCandidateHost(&CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", TCPType: TCPTypeActive, }) require.NoError(t, err) value := candidate.marshalExtensions() require.Equal(t, "tcptype active", value) }) } func TestBaseCandidateExtensionsEqual(t *testing.T) { testCases := []struct { name string extensions1 []CandidateExtension extensions2 []CandidateExtension expected bool }{ { name: "Empty extensions", extensions1: []CandidateExtension{}, extensions2: []CandidateExtension{}, expected: true, }, { name: "Single value extensions", extensions1: []CandidateExtension{{"a", "b"}}, extensions2: []CandidateExtension{{"a", "b"}}, expected: true, }, { name: "multiple value extensions", extensions1: []CandidateExtension{ {"a", "b"}, {"c", "d"}, }, extensions2: []CandidateExtension{ {"a", "b"}, {"c", "d"}, }, expected: true, }, { name: "unsorted extensions", extensions1: []CandidateExtension{ {"c", "d"}, {"a", "b"}, }, extensions2: []CandidateExtension{ {"a", "b"}, {"c", "d"}, }, expected: true, }, { name: "different values", extensions1: []CandidateExtension{ {"a", "b"}, {"c", "d"}, }, extensions2: []CandidateExtension{ {"a", "b"}, {"c", "e"}, }, expected: false, }, { name: "different size", extensions1: []CandidateExtension{ {"a", "b"}, {"c", "d"}, }, extensions2: []CandidateExtension{ {"a", "b"}, }, expected: false, }, { name: "different keys", extensions1: 
[]CandidateExtension{ {"a", "b"}, {"c", "d"}, }, extensions2: []CandidateExtension{ {"a", "b"}, {"e", "d"}, }, expected: false, }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { cand, err := NewCandidateHost(&CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }) require.NoError(t, err) cand.setExtensions(testCase.extensions1) require.Equal(t, testCase.expected, cand.extensionsEqual(testCase.extensions2)) }) } } func TestCandidateAddExtension(t *testing.T) { t.Run("Add extension", func(t *testing.T) { candidate, err := NewCandidateHost(&CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }) require.NoError(t, err) require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"})) require.NoError(t, candidate.AddExtension(CandidateExtension{"c", "d"})) extensions := candidate.Extensions() require.Equal(t, []CandidateExtension{{"a", "b"}, {"c", "d"}}, extensions) }) t.Run("Add extension with existing key", func(t *testing.T) { candidate, err := NewCandidateHost(&CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }) require.NoError(t, err) require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"})) require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "d"})) extensions := candidate.Extensions() require.Equal(t, []CandidateExtension{{"a", "d"}}, extensions) }) t.Run("Keep tcptype extension", func(t *testing.T) { candidate, err := NewCandidateHost(&CandidateHostConfig{ Network: NetworkTypeTCP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", TCPType: TCPTypeActive, }) require.NoError(t, err) ext, ok := candidate.GetExtension("tcptype") 
require.True(t, ok) require.Equal(t, ext, CandidateExtension{"tcptype", "active"}) require.Equal(t, candidate.Extensions(), []CandidateExtension{{"tcptype", "active"}}) require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"})) ext, ok = candidate.GetExtension("tcptype") require.True(t, ok) require.Equal(t, ext, CandidateExtension{"tcptype", "active"}) require.Equal(t, candidate.Extensions(), []CandidateExtension{{"tcptype", "active"}, {"a", "b"}}) }) t.Run("TcpType change extension", func(t *testing.T) { candidate, err := NewCandidateHost(&CandidateHostConfig{ Network: NetworkTypeTCP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }) require.NoError(t, err) require.NoError(t, candidate.AddExtension(CandidateExtension{"tcptype", "active"})) extensions := candidate.Extensions() require.Equal(t, []CandidateExtension{{"tcptype", "active"}}, extensions) require.Equal(t, TCPTypeActive, candidate.TCPType()) require.Error(t, candidate.AddExtension(CandidateExtension{"tcptype", "INVALID"})) }) t.Run("Add empty extension", func(t *testing.T) { candidate, err := NewCandidateHost(&CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }) require.NoError(t, err) require.Error(t, candidate.AddExtension(CandidateExtension{"", ""})) require.NoError(t, candidate.AddExtension(CandidateExtension{"a", ""})) extensions := candidate.Extensions() require.Equal(t, []CandidateExtension{{"a", ""}}, extensions) }) } func TestCandidateRemoveExtension(t *testing.T) { t.Run("Remove extension", func(t *testing.T) { candidate, err := NewCandidateHost(&CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }) require.NoError(t, err) require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"})) require.NoError(t, 
candidate.AddExtension(CandidateExtension{"c", "d"})) require.True(t, candidate.RemoveExtension("a")) extensions := candidate.Extensions() require.Equal(t, []CandidateExtension{{"c", "d"}}, extensions) }) t.Run("Remove extension that does not exist", func(t *testing.T) { candidate, err := NewCandidateHost(&CandidateHostConfig{ Network: NetworkTypeUDP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", }) require.NoError(t, err) require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"})) require.NoError(t, candidate.AddExtension(CandidateExtension{"c", "d"})) require.False(t, candidate.RemoveExtension("b")) extensions := candidate.Extensions() require.Equal(t, []CandidateExtension{{"a", "b"}, {"c", "d"}}, extensions) }) t.Run("Remove tcptype extension", func(t *testing.T) { candidate, err := NewCandidateHost(&CandidateHostConfig{ Network: NetworkTypeTCP4.String(), Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Port: 53987, Priority: 500, Foundation: "750", TCPType: TCPTypeActive, }) require.NoError(t, err) // tcptype extension should be removed, even if it's not in the extensions list (Not Parsed) require.True(t, candidate.RemoveExtension("tcptype")) require.Equal(t, TCPTypeUnspecified, candidate.TCPType()) require.Empty(t, candidate.Extensions()) require.NoError(t, candidate.AddExtension(CandidateExtension{"tcptype", "passive"})) require.True(t, candidate.RemoveExtension("tcptype")) require.Equal(t, TCPTypeUnspecified, candidate.TCPType()) require.Empty(t, candidate.Extensions()) }) } ice-4.1.0/candidatepair.go000066400000000000000000000231401511704470200153740ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "fmt" "sync/atomic" "time" "github.com/pion/stun/v3" ) func newCandidatePair(local, remote Candidate, controlling bool) *CandidatePair { return &CandidatePair{ iceRoleControlling: controlling, Remote: 
remote, Local: local, state: CandidatePairStateWaiting, } } // CandidatePair is a combination of a local and remote candidate. type CandidatePair struct { iceRoleControlling bool Remote Candidate Local Candidate bindingRequestCount uint16 state CandidatePairState nominated bool nominateOnBindingSuccess bool // stats currentRoundTripTime int64 // in ns totalRoundTripTime int64 // in ns packetsSent uint32 packetsReceived uint32 bytesSent uint64 bytesReceived uint64 lastPacketSentAt atomic.Value // time.Time lastPacketReceivedAt atomic.Value // time.Time requestsReceived uint64 requestsSent uint64 responsesReceived uint64 responsesSent uint64 firstRequestSentAt atomic.Value // time.Time lastRequestSentAt atomic.Value // time.Time firstResponseReceivedAt atomic.Value // time.Time lastResponseReceivedAt atomic.Value // time.Time firstRequestReceivedAt atomic.Value // time.Time lastRequestReceivedAt atomic.Value // time.Time } func (p *CandidatePair) String() string { if p == nil { return "" } return fmt.Sprintf( "prio %d (local, prio %d) %s <-> %s (remote, prio %d), state: %s, nominated: %v, nominateOnBindingSuccess: %v", p.priority(), p.Local.Priority(), p.Local, p.Remote, p.Remote.Priority(), p.state, p.nominated, p.nominateOnBindingSuccess, ) } func (p *CandidatePair) equal(other *CandidatePair) bool { if p == nil && other == nil { return true } if p == nil || other == nil { return false } return p.Local.Equal(other.Local) && p.Remote.Equal(other.Remote) } // RFC 5245 - 5.7.2. Computing Pair Priority and Ordering Pairs // Let G be the priority for the candidate provided by the controlling // agent. Let D be the priority for the candidate provided by the // controlled agent. // pair priority = 2^32*MIN(G,D) + 2*MAX(G,D) + (G>D?1:0). 
func (p *CandidatePair) priority() uint64 { var g, d uint32 //nolint:varnamelen // clearer to use g and d here if p.iceRoleControlling { g = p.Local.Priority() d = p.Remote.Priority() } else { g = p.Remote.Priority() d = p.Local.Priority() } // Just implement these here rather // than fooling around with the math package localMin := func(x, y uint32) uint64 { if x < y { return uint64(x) } return uint64(y) } localMax := func(x, y uint32) uint64 { if x > y { return uint64(x) } return uint64(y) } cmp := func(x, y uint32) uint64 { if x > y { return uint64(1) } return uint64(0) } // 1<<32 overflows uint32; and if both g && d are // maxUint32, this result would overflow uint64 return (1<<32-1)*localMin(g, d) + 2*localMax(g, d) + cmp(g, d) } func (p *CandidatePair) Write(b []byte) (int, error) { return p.Local.writeTo(b, p.Remote) } func (a *Agent) sendSTUN(msg *stun.Message, local, remote Candidate) { _, err := local.writeTo(msg.Raw, remote) if err != nil { a.log.Tracef("Failed to send STUN message: %s", err) } } // UpdateRoundTripTime sets the current round time of this pair and // accumulates total round trip time and responses received. 
func (p *CandidatePair) UpdateRoundTripTime(rtt time.Duration) { rttNs := rtt.Nanoseconds() atomic.StoreInt64(&p.currentRoundTripTime, rttNs) atomic.AddInt64(&p.totalRoundTripTime, rttNs) atomic.AddUint64(&p.responsesReceived, 1) now := time.Now() p.firstResponseReceivedAt.CompareAndSwap(nil, now) p.lastResponseReceivedAt.Store(now) } // CurrentRoundTripTime returns the current round trip time in seconds // https://www.w3.org/TR/webrtc-stats/#dom-rtcicecandidatepairstats-currentroundtriptime func (p *CandidatePair) CurrentRoundTripTime() float64 { return time.Duration(atomic.LoadInt64(&p.currentRoundTripTime)).Seconds() } // TotalRoundTripTime returns the current round trip time in seconds // https://www.w3.org/TR/webrtc-stats/#dom-rtcicecandidatepairstats-totalroundtriptime func (p *CandidatePair) TotalRoundTripTime() float64 { return time.Duration(atomic.LoadInt64(&p.totalRoundTripTime)).Seconds() } // RequestsReceived returns the total number of connectivity checks received // https://www.w3.org/TR/webrtc-stats/#dom-rtcicecandidatepairstats-requestsreceived func (p *CandidatePair) RequestsReceived() uint64 { return atomic.LoadUint64(&p.requestsReceived) } // RequestsSent returns the total number of connectivity checks sent // https://www.w3.org/TR/webrtc-stats/#dom-rtcicecandidatepairstats-requestssent func (p *CandidatePair) RequestsSent() uint64 { return atomic.LoadUint64(&p.requestsSent) } // ResponsesReceived returns the total number of connectivity responses received // https://www.w3.org/TR/webrtc-stats/#dom-rtcicecandidatepairstats-responsesreceived func (p *CandidatePair) ResponsesReceived() uint64 { return atomic.LoadUint64(&p.responsesReceived) } // ResponsesSent returns the total number of connectivity responses sent // https://www.w3.org/TR/webrtc-stats/#dom-rtcicecandidatepairstats-responsessent func (p *CandidatePair) ResponsesSent() uint64 { return atomic.LoadUint64(&p.responsesSent) } // PacketsSent returns total application (non-STUN) packets 
sent on this pair. func (p *CandidatePair) PacketsSent() uint32 { return atomic.LoadUint32(&p.packetsSent) } // PacketsReceived returns total application (non-STUN) packets received on this pair. func (p *CandidatePair) PacketsReceived() uint32 { return atomic.LoadUint32(&p.packetsReceived) } // BytesSent returns total application bytes sent on this pair. func (p *CandidatePair) BytesSent() uint64 { return atomic.LoadUint64(&p.bytesSent) } // BytesReceived returns total application bytes received on this pair. func (p *CandidatePair) BytesReceived() uint64 { return atomic.LoadUint64(&p.bytesReceived) } // LastPacketSentAt returns the timestamp of the last application packet sent. func (p *CandidatePair) LastPacketSentAt() time.Time { if v, ok := p.lastPacketSentAt.Load().(time.Time); ok { return v } return time.Time{} } // LastPacketReceivedAt returns the timestamp of the last application packet received. func (p *CandidatePair) LastPacketReceivedAt() time.Time { if v, ok := p.lastPacketReceivedAt.Load().(time.Time); ok { return v } return time.Time{} } // UpdatePacketSent increments packet/byte counters and updates timestamp for a sent application packet. func (p *CandidatePair) UpdatePacketSent(n int) { if n <= 0 { return } atomic.AddUint32(&p.packetsSent, 1) atomic.AddUint64(&p.bytesSent, uint64(n)) // #nosec G115 -- n > 0 validated above p.lastPacketSentAt.Store(time.Now()) } // UpdatePacketReceived increments packet/byte counters and updates timestamp for a received application packet. func (p *CandidatePair) UpdatePacketReceived(n int) { if n <= 0 { return } atomic.AddUint32(&p.packetsReceived, 1) atomic.AddUint64(&p.bytesReceived, uint64(n)) // #nosec G115 -- n > 0 validated above p.lastPacketReceivedAt.Store(time.Now()) } // FirstRequestSentAt returns the timestamp of the first connectivity check sent. 
func (p *CandidatePair) FirstRequestSentAt() time.Time { if v, ok := p.firstRequestSentAt.Load().(time.Time); ok { return v } return time.Time{} } // LastRequestSentAt returns the timestamp of the last connectivity check sent. func (p *CandidatePair) LastRequestSentAt() time.Time { if v, ok := p.lastRequestSentAt.Load().(time.Time); ok { return v } return time.Time{} } // Deprecated: use FirstResponseReceivedAt // FirstReponseReceivedAt returns the timestamp of the first connectivity response received. func (p *CandidatePair) FirstReponseReceivedAt() time.Time { return p.FirstResponseReceivedAt() } // FirstResponseReceivedAt returns the timestamp of the first connectivity response received. func (p *CandidatePair) FirstResponseReceivedAt() time.Time { if v, ok := p.firstResponseReceivedAt.Load().(time.Time); ok { return v } return time.Time{} } // LastResponseReceivedAt returns the timestamp of the last connectivity response received. func (p *CandidatePair) LastResponseReceivedAt() time.Time { if v, ok := p.lastResponseReceivedAt.Load().(time.Time); ok { return v } return time.Time{} } // FirstRequestReceivedAt returns the timestamp of the first connectivity check received. func (p *CandidatePair) FirstRequestReceivedAt() time.Time { if v, ok := p.firstRequestReceivedAt.Load().(time.Time); ok { return v } return time.Time{} } // LastRequestReceivedAt returns the timestamp of the last connectivity check received. func (p *CandidatePair) LastRequestReceivedAt() time.Time { if v, ok := p.lastRequestReceivedAt.Load().(time.Time); ok { return v } return time.Time{} } // UpdateRequestSent increments the number of requests sent and updates the timestamp. func (p *CandidatePair) UpdateRequestSent() { atomic.AddUint64(&p.requestsSent, 1) now := time.Now() p.firstRequestSentAt.CompareAndSwap(nil, now) p.lastRequestSentAt.Store(now) } // UpdateResponseSent increments the number of responses sent. 
func (p *CandidatePair) UpdateResponseSent() { atomic.AddUint64(&p.responsesSent, 1) } // UpdateRequestReceived increments the number of requests received and updates the timestamp. func (p *CandidatePair) UpdateRequestReceived() { atomic.AddUint64(&p.requestsReceived, 1) now := time.Now() p.firstRequestReceivedAt.CompareAndSwap(nil, now) p.lastRequestReceivedAt.Store(now) } ice-4.1.0/candidatepair_state.go000066400000000000000000000022311511704470200165720ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice // CandidatePairState represent the ICE candidate pair state. type CandidatePairState int const ( // CandidatePairStateWaiting means a check has not been performed for // this pair. CandidatePairStateWaiting CandidatePairState = iota + 1 // CandidatePairStateInProgress means a check has been sent for this pair, // but the transaction is in progress. CandidatePairStateInProgress // CandidatePairStateFailed means a check for this pair was already done // and failed, either never producing any response or producing an unrecoverable // failure response. CandidatePairStateFailed // CandidatePairStateSucceeded means a check for this pair was already // done and produced a successful result. 
CandidatePairStateSucceeded ) func (c CandidatePairState) String() string { switch c { case CandidatePairStateWaiting: return "waiting" case CandidatePairStateInProgress: return "in-progress" case CandidatePairStateFailed: return "failed" case CandidatePairStateSucceeded: return "succeeded" } return "Unknown candidate pair state" } ice-4.1.0/candidatepair_test.go000066400000000000000000000103701511704470200164340ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "testing" "github.com/stretchr/testify/require" ) func hostCandidate() *CandidateHost { return &CandidateHost{ candidateBase: candidateBase{ candidateType: CandidateTypeHost, component: ComponentRTP, }, } } func prflxCandidate() *CandidatePeerReflexive { return &CandidatePeerReflexive{ candidateBase: candidateBase{ candidateType: CandidateTypePeerReflexive, component: ComponentRTP, }, } } func srflxCandidate() *CandidateServerReflexive { return &CandidateServerReflexive{ candidateBase: candidateBase{ candidateType: CandidateTypeServerReflexive, component: ComponentRTP, }, } } func relayCandidate() *CandidateRelay { return &CandidateRelay{ candidateBase: candidateBase{ candidateType: CandidateTypeRelay, component: ComponentRTP, }, } } func TestCandidatePairPriority(t *testing.T) { for _, test := range []struct { Pair *CandidatePair WantPriority uint64 }{ { Pair: newCandidatePair( hostCandidate(), hostCandidate(), false, ), WantPriority: 9151314440652587007, }, { Pair: newCandidatePair( hostCandidate(), hostCandidate(), true, ), WantPriority: 9151314440652587007, }, { Pair: newCandidatePair( hostCandidate(), prflxCandidate(), true, ), WantPriority: 7998392936314175488, }, { Pair: newCandidatePair( hostCandidate(), prflxCandidate(), false, ), WantPriority: 7998392936314175487, }, { Pair: newCandidatePair( hostCandidate(), srflxCandidate(), true, ), WantPriority: 7277816996102668288, }, { Pair: newCandidatePair( hostCandidate(), 
srflxCandidate(), false, ), WantPriority: 7277816996102668287, }, { Pair: newCandidatePair( hostCandidate(), relayCandidate(), true, ), WantPriority: 72057593987596288, }, { Pair: newCandidatePair( hostCandidate(), relayCandidate(), false, ), WantPriority: 72057593987596287, }, } { require.Equal(t, test.Pair.priority(), test.WantPriority) } } func TestCandidatePairEquality(t *testing.T) { pairA := newCandidatePair(hostCandidate(), srflxCandidate(), true) pairB := newCandidatePair(hostCandidate(), srflxCandidate(), false) require.True(t, pairA.equal(pairB)) } func TestNilCandidatePairString(t *testing.T) { var nilCandidatePair *CandidatePair require.Equal(t, nilCandidatePair.String(), "") } func TestCandidatePairState_String(t *testing.T) { tests := []struct { name string in CandidatePairState want string }{ {"waiting", CandidatePairStateWaiting, "waiting"}, {"in-progress", CandidatePairStateInProgress, "in-progress"}, {"failed", CandidatePairStateFailed, "failed"}, {"succeeded", CandidatePairStateSucceeded, "succeeded"}, {"unknown", CandidatePairState(255), "Unknown candidate pair state"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { require.Equal(t, tt.want, tt.in.String()) }) } } func TestCandidatePairEqual_NilCases(t *testing.T) { // both nil -> true var a *CandidatePair var b *CandidatePair require.True(t, a.equal(b), "both nil pairs should be equal") // left non-nil, right nil -> false a = newCandidatePair(hostCandidate(), srflxCandidate(), true) require.False(t, a.equal(nil), "non-nil vs nil should be false") // left nil, right non-nil -> false require.False(t, (*CandidatePair)(nil).equal(a), "nil vs non-nil should be false") } func TestCandidatePair_TimeGetters_DefaultZero(t *testing.T) { p := newCandidatePair(hostCandidate(), srflxCandidate(), true) require.True(t, p.FirstRequestSentAt().IsZero(), "FirstRequestSentAt should be zero by default") require.True(t, p.LastRequestSentAt().IsZero(), "LastRequestSentAt should be zero by default") 
require.True(t, p.FirstReponseReceivedAt().IsZero(), "FirstReponseReceivedAt should be zero by default") require.True(t, p.LastResponseReceivedAt().IsZero(), "LastResponseReceivedAt should be zero by default") require.True(t, p.FirstRequestReceivedAt().IsZero(), "FirstRequestReceivedAt should be zero by default") require.True(t, p.LastRequestReceivedAt().IsZero(), "LastRequestReceivedAt should be zero by default") } ice-4.1.0/candidaterelatedaddress.go000066400000000000000000000015231511704470200174300ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import "fmt" // CandidateRelatedAddress convey transport addresses related to the // candidate, useful for diagnostics and other purposes. type CandidateRelatedAddress struct { Address string Port int } // String makes CandidateRelatedAddress printable. func (c *CandidateRelatedAddress) String() string { if c == nil { return "" } return fmt.Sprintf(" related %s:%d", c.Address, c.Port) } // Equal allows comparing two CandidateRelatedAddresses. // The CandidateRelatedAddress are allowed to be nil. func (c *CandidateRelatedAddress) Equal(other *CandidateRelatedAddress) bool { if c == nil && other == nil { return true } return c != nil && other != nil && c.Address == other.Address && c.Port == other.Port } ice-4.1.0/candidatetype.go000066400000000000000000000030121511704470200154160ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import "slices" // CandidateType represents the type of candidate. type CandidateType byte // CandidateType enum. const ( CandidateTypeUnspecified CandidateType = iota CandidateTypeHost CandidateTypeServerReflexive CandidateTypePeerReflexive CandidateTypeRelay ) // String makes CandidateType printable. 
func (c CandidateType) String() string { switch c { case CandidateTypeHost: return "host" case CandidateTypeServerReflexive: return "srflx" case CandidateTypePeerReflexive: return "prflx" case CandidateTypeRelay: return "relay" case CandidateTypeUnspecified: return "Unknown candidate type" } return "Unknown candidate type" } // Preference returns the preference weight of a CandidateType // // 4.1.2.2. Guidelines for Choosing Type and Local Preferences // The RECOMMENDED values are 126 for host candidates, 100 // for server reflexive candidates, 110 for peer reflexive candidates, // and 0 for relayed candidates. func (c CandidateType) Preference() uint16 { switch c { case CandidateTypeHost: return 126 case CandidateTypePeerReflexive: return 110 case CandidateTypeServerReflexive: return 100 case CandidateTypeRelay, CandidateTypeUnspecified: return 0 } return 0 } func containsCandidateType(candidateType CandidateType, candidateTypeList []CandidateType) bool { if candidateTypeList == nil { return false } return slices.Contains(candidateTypeList, candidateType) } ice-4.1.0/candidatetype_test.go000066400000000000000000000020771511704470200164670ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "testing" "github.com/stretchr/testify/require" ) func TestCandidateType_String_KnownCases(t *testing.T) { cases := map[CandidateType]string{ CandidateTypeHost: "host", CandidateTypeServerReflexive: "srflx", CandidateTypePeerReflexive: "prflx", CandidateTypeRelay: "relay", CandidateTypeUnspecified: "Unknown candidate type", } for ct, want := range cases { require.Equal(t, want, ct.String(), "unexpected string for %v", ct) } } func TestCandidateType_String_Default(t *testing.T) { const outOfBounds CandidateType = 255 require.Equal(t, "Unknown candidate type", outOfBounds.String()) } func TestCandidateType_Preference_DefaultCase(t *testing.T) { const outOfBounds CandidateType = 255 require.Equal(t, 
uint16(0), outOfBounds.Preference()) } func TestContainsCandidateType_NilSlice(t *testing.T) { var list []CandidateType // nil slice require.False(t, containsCandidateType(CandidateTypeHost, list)) } ice-4.1.0/codecov.yml000066400000000000000000000007151511704470200144250ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT coverage: status: project: default: # Allow decreasing 2% of total coverage to avoid noise. threshold: 2% patch: default: target: 70% only_pulls: true ignore: - "examples/*" - "examples/**/*" ice-4.1.0/connectivity_vnet_test.go000066400000000000000000000503511511704470200174210ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "context" "fmt" "net" "sync/atomic" "testing" "time" "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3/test" "github.com/pion/transport/v3/vnet" "github.com/pion/turn/v4" "github.com/stretchr/testify/require" ) const ( vnetGlobalIPA = "27.1.1.1" vnetLocalIPA = "192.168.0.1" vnetLocalSubnetMaskA = "24" vnetGlobalIPB = "28.1.1.1" vnetLocalIPB = "10.2.0.1" vnetLocalSubnetMaskB = "24" vnetSTUNServerIP = "1.2.3.4" vnetSTUNServerPort = 3478 ) type virtualNet struct { wan *vnet.Router net0 *vnet.Net net1 *vnet.Net server *turn.Server } func (v *virtualNet) close() { v.server.Close() //nolint:errcheck,gosec v.wan.Stop() //nolint:errcheck,gosec } func buildVNet(natType0, natType1 *vnet.NATType) (*virtualNet, error) { //nolint:cyclop loggerFactory := logging.NewDefaultLoggerFactory() // WAN wan, err := vnet.NewRouter(&vnet.RouterConfig{ CIDR: "0.0.0.0/0", LoggerFactory: loggerFactory, }) if err != nil { return nil, err } wanNet, err := vnet.NewNet(&vnet.NetConfig{ StaticIP: vnetSTUNServerIP, // Will be assigned to eth0 }) if err 
!= nil { return nil, err } err = wan.AddNet(wanNet) if err != nil { return nil, err } // LAN 0 lan0, err := vnet.NewRouter(&vnet.RouterConfig{ StaticIPs: func() []string { if natType0.Mode == vnet.NATModeNAT1To1 { return []string{ vnetGlobalIPA + "/" + vnetLocalIPA, } } return []string{ vnetGlobalIPA, } }(), CIDR: vnetLocalIPA + "/" + vnetLocalSubnetMaskA, NATType: natType0, LoggerFactory: loggerFactory, }) if err != nil { return nil, err } net0, err := vnet.NewNet(&vnet.NetConfig{ StaticIPs: []string{vnetLocalIPA}, }) if err != nil { return nil, err } err = lan0.AddNet(net0) if err != nil { return nil, err } err = wan.AddRouter(lan0) if err != nil { return nil, err } // LAN 1 lan1, err := vnet.NewRouter(&vnet.RouterConfig{ StaticIPs: func() []string { if natType1.Mode == vnet.NATModeNAT1To1 { return []string{ vnetGlobalIPB + "/" + vnetLocalIPB, } } return []string{ vnetGlobalIPB, } }(), CIDR: vnetLocalIPB + "/" + vnetLocalSubnetMaskB, NATType: natType1, LoggerFactory: loggerFactory, }) if err != nil { return nil, err } net1, err := vnet.NewNet(&vnet.NetConfig{ StaticIPs: []string{vnetLocalIPB}, }) if err != nil { return nil, err } err = lan1.AddNet(net1) if err != nil { return nil, err } err = wan.AddRouter(lan1) if err != nil { return nil, err } // Start routers err = wan.Start() if err != nil { return nil, err } server, err := addVNetSTUN(wanNet, loggerFactory) if err != nil { return nil, err } return &virtualNet{ wan: wan, net0: net0, net1: net1, server: server, }, nil } func addVNetSTUN(wanNet *vnet.Net, loggerFactory logging.LoggerFactory) (*turn.Server, error) { // Run TURN(STUN) server credMap := map[string]string{} credMap["user"] = "pass" wanNetPacketConn, err := wanNet.ListenPacket("udp", fmt.Sprintf("%s:%d", vnetSTUNServerIP, vnetSTUNServerPort)) if err != nil { return nil, err } server, err := turn.NewServer(turn.ServerConfig{ AuthHandler: func(username, realm string, _ net.Addr) (key []byte, ok bool) { if pw, ok := credMap[username]; ok { return 
turn.GenerateAuthKey(username, realm, pw), true } return nil, false }, PacketConnConfigs: []turn.PacketConnConfig{ { PacketConn: wanNetPacketConn, RelayAddressGenerator: &turn.RelayAddressGeneratorStatic{ RelayAddress: net.ParseIP(vnetSTUNServerIP), Address: "0.0.0.0", Net: wanNet, }, }, }, Realm: "pion.ly", LoggerFactory: loggerFactory, }) if err != nil { return nil, err } return server, err } func connectWithVNet(t *testing.T, aAgent, bAgent *Agent) (*Conn, *Conn) { t.Helper() // Manual signaling aUfrag, aPwd, err := aAgent.GetLocalUserCredentials() require.NoError(t, err) bUfrag, bPwd, err := bAgent.GetLocalUserCredentials() require.NoError(t, err) gatherAndExchangeCandidates(t, aAgent, bAgent) accepted := make(chan struct{}) var aConn *Conn go func() { var acceptErr error aConn, acceptErr = aAgent.Accept(context.TODO(), bUfrag, bPwd) require.NoError(t, acceptErr) close(accepted) }() bConn, err := bAgent.Dial(context.TODO(), aUfrag, aPwd) require.NoError(t, err) // Ensure accepted <-accepted return aConn, bConn } type agentTestConfig struct { urls []*stun.URI nat1To1IPCandidateType CandidateType } func pipeWithVNet(t *testing.T, vnet *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*Conn, *Conn) { t.Helper() aNotifier, aConnected := onConnected() bNotifier, bConnected := onConnected() var nat1To1IPs []string if a0TestConfig.nat1To1IPCandidateType != CandidateTypeUnspecified { nat1To1IPs = []string{ vnetGlobalIPA, } } cfg0 := &AgentConfig{ Urls: a0TestConfig.urls, NetworkTypes: supportedNetworkTypes(), MulticastDNSMode: MulticastDNSModeDisabled, NAT1To1IPs: nat1To1IPs, NAT1To1IPCandidateType: a0TestConfig.nat1To1IPCandidateType, Net: vnet.net0, } aAgent, err := NewAgent(cfg0) require.NoError(t, err) require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) if a1TestConfig.nat1To1IPCandidateType != CandidateTypeUnspecified { nat1To1IPs = []string{ vnetGlobalIPB, } } cfg1 := &AgentConfig{ Urls: a1TestConfig.urls, NetworkTypes: supportedNetworkTypes(), 
MulticastDNSMode: MulticastDNSModeDisabled, NAT1To1IPs: nat1To1IPs, NAT1To1IPCandidateType: a1TestConfig.nat1To1IPCandidateType, Net: vnet.net1, } bAgent, err := NewAgent(cfg1) require.NoError(t, err) require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) aConn, bConn := connectWithVNet(t, aAgent, bAgent) // Ensure pair selected // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair <-aConnected <-bConnected return aConn, bConn } func pipeWithVNetUsingOptions(t *testing.T, opts0, opts1 []AgentOption) (*Conn, *Conn) { t.Helper() aNotifier, aConnected := onConnected() bNotifier, bConnected := onConnected() aAgent, err := NewAgentWithOptions(opts0...) require.NoError(t, err) if err = aAgent.OnConnectionStateChange(aNotifier); err != nil { require.NoError(t, err) } bAgent, err := NewAgentWithOptions(opts1...) require.NoError(t, err) if err = bAgent.OnConnectionStateChange(bNotifier); err != nil { require.NoError(t, err) } t.Cleanup(func() { require.NoError(t, aAgent.Close(), "failed to close agent0") require.NoError(t, bAgent.Close(), "failed to close agent1") }) aConn, bConn := connectWithVNet(t, aAgent, bAgent) <-aConnected <-bConnected return aConn, bConn } func closePipe(t *testing.T, ca *Conn, cb *Conn) { t.Helper() require.NoError(t, ca.Close()) require.NoError(t, cb.Close()) } func TestConnectivityVNet(t *testing.T) { defer test.CheckRoutines(t)() stunServerURL := &stun.URI{ Scheme: stun.SchemeTypeSTUN, Host: vnetSTUNServerIP, Port: vnetSTUNServerPort, Proto: stun.ProtoTypeUDP, } turnServerURL := &stun.URI{ Scheme: stun.SchemeTypeTURN, Host: vnetSTUNServerIP, Port: vnetSTUNServerPort, Username: "user", Password: "pass", Proto: stun.ProtoTypeUDP, } t.Run("Full-cone NATs on both ends", func(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") // buildVNet with a Full-cone NATs both LANs natType := &vnet.NATType{ MappingBehavior: vnet.EndpointIndependent, FilteringBehavior: 
vnet.EndpointIndependent, } vnet, err := buildVNet(natType, natType) require.NoError(t, err, "should succeed") defer vnet.close() log.Debug("Connecting...") a0TestConfig := &agentTestConfig{ urls: []*stun.URI{ stunServerURL, }, } a1TestConfig := &agentTestConfig{ urls: []*stun.URI{ stunServerURL, }, } ca, cb := pipeWithVNet(t, vnet, a0TestConfig, a1TestConfig) time.Sleep(1 * time.Second) log.Debug("Closing...") closePipe(t, ca, cb) }) t.Run("Symmetric NATs on both ends", func(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") // buildVNet with a Symmetric NATs for both LANs natType := &vnet.NATType{ MappingBehavior: vnet.EndpointAddrPortDependent, FilteringBehavior: vnet.EndpointAddrPortDependent, } vnet, err := buildVNet(natType, natType) require.NoError(t, err, "should succeed") defer vnet.close() log.Debug("Connecting...") a0TestConfig := &agentTestConfig{ urls: []*stun.URI{ stunServerURL, turnServerURL, }, } a1TestConfig := &agentTestConfig{ urls: []*stun.URI{ stunServerURL, }, } ca, cb := pipeWithVNet(t, vnet, a0TestConfig, a1TestConfig) log.Debug("Closing...") closePipe(t, ca, cb) }) t.Run("1:1 NAT with host candidate vs Symmetric NATs", func(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") // Agent0 is behind 1:1 NAT natType0 := &vnet.NATType{ Mode: vnet.NATModeNAT1To1, } // Agent1 is behind a symmetric NAT natType1 := &vnet.NATType{ MappingBehavior: vnet.EndpointAddrPortDependent, FilteringBehavior: vnet.EndpointAddrPortDependent, } vnet, err := buildVNet(natType0, natType1) require.NoError(t, err, "should succeed") defer vnet.close() log.Debug("Connecting...") a0TestConfig := &agentTestConfig{ urls: []*stun.URI{}, nat1To1IPCandidateType: CandidateTypeHost, // Use 1:1 NAT IP as a host candidate } a1TestConfig := &agentTestConfig{ urls: []*stun.URI{}, } ca, cb := pipeWithVNet(t, vnet, a0TestConfig, a1TestConfig) log.Debug("Closing...") closePipe(t, 
ca, cb) }) t.Run("1:1 NAT with srflx candidate vs Symmetric NATs", func(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") // Agent0 is behind 1:1 NAT natType0 := &vnet.NATType{ Mode: vnet.NATModeNAT1To1, } // Agent1 is behind a symmetric NAT natType1 := &vnet.NATType{ MappingBehavior: vnet.EndpointAddrPortDependent, FilteringBehavior: vnet.EndpointAddrPortDependent, } vnet, err := buildVNet(natType0, natType1) require.NoError(t, err, "should succeed") defer vnet.close() log.Debug("Connecting...") a0TestConfig := &agentTestConfig{ urls: []*stun.URI{}, nat1To1IPCandidateType: CandidateTypeServerReflexive, // Use 1:1 NAT IP as a srflx candidate } a1TestConfig := &agentTestConfig{ urls: []*stun.URI{}, } ca, cb := pipeWithVNet(t, vnet, a0TestConfig, a1TestConfig) log.Debug("Closing...") closePipe(t, ca, cb) }) } func TestConnectivityVNetWithAddressRewriteRuleOptions(t *testing.T) { defer test.CheckRoutines(t)() t.Run("host candidate mapping with options", func(t *testing.T) { natType0 := &vnet.NATType{Mode: vnet.NATModeNAT1To1} natType1 := &vnet.NATType{ MappingBehavior: vnet.EndpointAddrPortDependent, FilteringBehavior: vnet.EndpointAddrPortDependent, } vnet, err := buildVNet(natType0, natType1) require.NoError(t, err) defer vnet.close() agent0Opts := []AgentOption{ WithNet(vnet.net0), WithNetworkTypes(supportedNetworkTypes()), WithMulticastDNSMode(MulticastDNSModeDisabled), WithAddressRewriteRules(AddressRewriteRule{ External: []string{vnetGlobalIPA}, AsCandidateType: CandidateTypeHost, }), } agent1Opts := []AgentOption{ WithNet(vnet.net1), WithNetworkTypes(supportedNetworkTypes()), WithMulticastDNSMode(MulticastDNSModeDisabled), } ca, cb := pipeWithVNetUsingOptions(t, agent0Opts, agent1Opts) closePipe(t, ca, cb) }) t.Run("srflx candidate mapping with options", func(t *testing.T) { natType0 := &vnet.NATType{Mode: vnet.NATModeNAT1To1} natType1 := &vnet.NATType{ MappingBehavior: vnet.EndpointAddrPortDependent, 
FilteringBehavior: vnet.EndpointAddrPortDependent, } vnet, err := buildVNet(natType0, natType1) require.NoError(t, err) defer vnet.close() agent0Opts := []AgentOption{ WithNet(vnet.net0), WithNetworkTypes(supportedNetworkTypes()), WithMulticastDNSMode(MulticastDNSModeDisabled), WithAddressRewriteRules(AddressRewriteRule{ External: []string{vnetGlobalIPA}, AsCandidateType: CandidateTypeServerReflexive, }), } agent1Opts := []AgentOption{ WithNet(vnet.net1), WithNetworkTypes(supportedNetworkTypes()), WithMulticastDNSMode(MulticastDNSModeDisabled), } ca, cb := pipeWithVNetUsingOptions(t, agent0Opts, agent1Opts) closePipe(t, ca, cb) }) } func TestConnectivityVNetNAT1To1SharedFoundation(t *testing.T) { defer test.CheckRoutines(t)() natType0 := &vnet.NATType{Mode: vnet.NATModeNAT1To1} natType1 := &vnet.NATType{} vnet, err := buildVNet(natType0, natType1) require.NoError(t, err) defer vnet.close() agent, err := NewAgentWithOptions( WithNet(vnet.net0), WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), WithMulticastDNSMode(MulticastDNSModeDisabled), WithAddressRewriteRules( AddressRewriteRule{ External: []string{vnetGlobalIPA}, AsCandidateType: CandidateTypeHost, }, AddressRewriteRule{ External: []string{vnetGlobalIPA}, AsCandidateType: CandidateTypeServerReflexive, }, ), ) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() require.NoError(t, agent.OnCandidate(func(Candidate) {})) require.NoError(t, agent.GatherCandidates()) require.Eventually(t, func() bool { state, stateErr := agent.GetGatheringState() require.NoError(t, stateErr) return state == GatheringStateComplete }, time.Second, 10*time.Millisecond) foundationSeen := make(map[string]struct{}) typeCount := make(map[CandidateType]int) for _, candidates := range agent.localCandidates { for _, cand := range candidates { if cand.Address() != vnetGlobalIPA { continue } if cand.Type() != CandidateTypeHost && cand.Type() != CandidateTypeServerReflexive { continue } foundation := cand.Foundation() _, 
dup := foundationSeen[foundation] require.Falsef(t, dup, "duplicate foundation %s for %s candidate", foundation, cand.Type()) foundationSeen[foundation] = struct{}{} typeCount[cand.Type()]++ } } require.Equal(t, 1, typeCount[CandidateTypeHost], "expected exactly one host candidate for %s", vnetGlobalIPA) require.Equal( t, 1, typeCount[CandidateTypeServerReflexive], "expected exactly one srflx candidate for %s", vnetGlobalIPA, ) } // TestDisconnectedToConnected requires that an agent can go to disconnected, // and then return to connected successfully. func TestDisconnectedToConnected(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 10).Stop() loggerFactory := logging.NewDefaultLoggerFactory() // Create a network with two interfaces wan, err := vnet.NewRouter(&vnet.RouterConfig{ CIDR: "0.0.0.0/0", LoggerFactory: loggerFactory, }) require.NoError(t, err) var dropAllData uint64 wan.AddChunkFilter(func(vnet.Chunk) bool { return atomic.LoadUint64(&dropAllData) != 1 }) net0, err := vnet.NewNet(&vnet.NetConfig{ StaticIPs: []string{"192.168.0.1"}, }) require.NoError(t, err) require.NoError(t, wan.AddNet(net0)) net1, err := vnet.NewNet(&vnet.NetConfig{ StaticIPs: []string{"192.168.0.2"}, }) require.NoError(t, err) require.NoError(t, wan.AddNet(net1)) require.NoError(t, wan.Start()) disconnectTimeout := time.Second keepaliveInterval := time.Millisecond * 20 // Create two agents and connect them controllingAgent, err := NewAgent(&AgentConfig{ NetworkTypes: supportedNetworkTypes(), MulticastDNSMode: MulticastDNSModeDisabled, Net: net0, DisconnectedTimeout: &disconnectTimeout, KeepaliveInterval: &keepaliveInterval, CheckInterval: &keepaliveInterval, }) require.NoError(t, err) defer func() { require.NoError(t, controllingAgent.Close()) }() controlledAgent, err := NewAgent(&AgentConfig{ NetworkTypes: supportedNetworkTypes(), MulticastDNSMode: MulticastDNSModeDisabled, Net: net1, DisconnectedTimeout: &disconnectTimeout, KeepaliveInterval: 
&keepaliveInterval, CheckInterval: &keepaliveInterval, }) require.NoError(t, err) defer func() { require.NoError(t, controlledAgent.Close()) }() controllingStateChanges := make(chan ConnectionState, 100) require.NoError(t, controllingAgent.OnConnectionStateChange(func(c ConnectionState) { controllingStateChanges <- c })) controlledStateChanges := make(chan ConnectionState, 100) require.NoError(t, controlledAgent.OnConnectionStateChange(func(c ConnectionState) { controlledStateChanges <- c })) connectWithVNet(t, controllingAgent, controlledAgent) blockUntilStateSeen := func(expectedState ConnectionState, stateQueue chan ConnectionState) { for s := range stateQueue { if s == expectedState { return } } } // Assert we have gone to connected blockUntilStateSeen(ConnectionStateConnected, controllingStateChanges) blockUntilStateSeen(ConnectionStateConnected, controlledStateChanges) // Drop all packets, and block until we have gone to disconnected atomic.StoreUint64(&dropAllData, 1) blockUntilStateSeen(ConnectionStateDisconnected, controllingStateChanges) blockUntilStateSeen(ConnectionStateDisconnected, controlledStateChanges) // Allow all packets through again, block until we have gone to connected atomic.StoreUint64(&dropAllData, 0) blockUntilStateSeen(ConnectionStateConnected, controllingStateChanges) blockUntilStateSeen(ConnectionStateConnected, controlledStateChanges) require.NoError(t, wan.Stop()) } // Agent.Write should use the best valid pair if a selected pair is not yet available. 
func TestWriteUseValidPair(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 10).Stop() loggerFactory := logging.NewDefaultLoggerFactory() // Create a network with two interfaces wan, err := vnet.NewRouter(&vnet.RouterConfig{ CIDR: "0.0.0.0/0", LoggerFactory: loggerFactory, }) require.NoError(t, err) wan.AddChunkFilter(func(c vnet.Chunk) bool { if stun.IsMessage(c.UserData()) { m := &stun.Message{ Raw: c.UserData(), } if decErr := m.Decode(); decErr != nil { return false } else if m.Contains(stun.AttrUseCandidate) { return false } } return true }) net0, err := vnet.NewNet(&vnet.NetConfig{ StaticIPs: []string{"192.168.0.1"}, }) require.NoError(t, err) require.NoError(t, wan.AddNet(net0)) net1, err := vnet.NewNet(&vnet.NetConfig{ StaticIPs: []string{"192.168.0.2"}, }) require.NoError(t, err) require.NoError(t, wan.AddNet(net1)) require.NoError(t, wan.Start()) // Create two agents and connect them controllingAgent, err := NewAgent(&AgentConfig{ NetworkTypes: supportedNetworkTypes(), MulticastDNSMode: MulticastDNSModeDisabled, Net: net0, }) require.NoError(t, err) defer func() { require.NoError(t, controllingAgent.Close()) }() controlledAgent, err := NewAgent(&AgentConfig{ NetworkTypes: supportedNetworkTypes(), MulticastDNSMode: MulticastDNSModeDisabled, Net: net1, }) require.NoError(t, err) defer func() { require.NoError(t, controlledAgent.Close()) }() gatherAndExchangeCandidates(t, controllingAgent, controlledAgent) controllingUfrag, controllingPwd, err := controllingAgent.GetLocalUserCredentials() require.NoError(t, err) controlledUfrag, controlledPwd, err := controlledAgent.GetLocalUserCredentials() require.NoError(t, err) require.NoError(t, controllingAgent.startConnectivityChecks(true, controlledUfrag, controlledPwd)) require.NoError(t, controlledAgent.startConnectivityChecks(false, controllingUfrag, controllingPwd)) testMessage := []byte("Test Message") go func() { for { if _, writeErr := (&Conn{agent: 
controllingAgent}).Write(testMessage); writeErr != nil { return } time.Sleep(20 * time.Millisecond) } }() readBuf := make([]byte, len(testMessage)) _, err = (&Conn{agent: controlledAgent}).Read(readBuf) require.NoError(t, err) require.Equal(t, readBuf, testMessage) require.NoError(t, wan.Stop()) } ice-4.1.0/errors.go000066400000000000000000000241741511704470200141300ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "errors" "github.com/pion/ice/v4/internal/taskloop" ) var ( // ErrUnknownType indicates an error with Unknown info. ErrUnknownType = errors.New("Unknown") // ErrSchemeType indicates the scheme type could not be parsed. ErrSchemeType = errors.New("unknown scheme type") // ErrSTUNQuery indicates query arguments are provided in a STUN URL. ErrSTUNQuery = errors.New("queries not supported in STUN address") // ErrInvalidQuery indicates an malformed query is provided. ErrInvalidQuery = errors.New("invalid query") // ErrHost indicates malformed hostname is provided. ErrHost = errors.New("invalid hostname") // ErrPort indicates malformed port is provided. ErrPort = errors.New("invalid port") // ErrLocalUfragInsufficientBits indicates local username fragment insufficient bits are provided. // Have to be at least 24 bits long. ErrLocalUfragInsufficientBits = errors.New("local username fragment is less than 24 bits long") // ErrLocalPwdInsufficientBits indicates local password insufficient bits are provided. // Have to be at least 128 bits long. ErrLocalPwdInsufficientBits = errors.New("local password is less than 128 bits long") // ErrProtoType indicates an unsupported transport type was provided. ErrProtoType = errors.New("invalid transport protocol type") // ErrClosed indicates the agent is closed. ErrClosed = taskloop.ErrClosed // ErrNoCandidatePairs indicates agent does not have a valid candidate pair. 
ErrNoCandidatePairs = errors.New("no candidate pairs available") // ErrCanceledByCaller indicates agent connection was canceled by the caller. ErrCanceledByCaller = errors.New("connecting canceled by caller") // ErrMultipleStart indicates agent was started twice. ErrMultipleStart = errors.New("attempted to start agent twice") // ErrRemoteUfragEmpty indicates agent was started with an empty remote ufrag. ErrRemoteUfragEmpty = errors.New("remote ufrag is empty") // ErrRemotePwdEmpty indicates agent was started with an empty remote pwd. ErrRemotePwdEmpty = errors.New("remote pwd is empty") // ErrNoOnCandidateHandler indicates agent was started without OnCandidate. ErrNoOnCandidateHandler = errors.New("no OnCandidate provided") // ErrMultipleGatherAttempted indicates GatherCandidates has been called multiple times. ErrMultipleGatherAttempted = errors.New("attempting to gather candidates during gathering state") // ErrUsernameEmpty indicates agent was give TURN URL with an empty Username. ErrUsernameEmpty = errors.New("username is empty") // ErrPasswordEmpty indicates agent was give TURN URL with an empty Password. ErrPasswordEmpty = errors.New("password is empty") // ErrAddressParseFailed indicates we were unable to parse a candidate address. ErrAddressParseFailed = errors.New("failed to parse address") // ErrLiteUsingNonHostCandidates indicates non host candidates were selected for a lite agent. ErrLiteUsingNonHostCandidates = errors.New("lite agents must only use host candidates") // ErrUselessUrlsProvided indicates that one or more URL was provided to the agent but no host // candidate required them. ErrUselessUrlsProvided = errors.New("agent does not need URL with selected candidate types") // ErrUnsupportedNAT1To1IPCandidateType indicates that the specified NAT1To1IPCandidateType is // unsupported. // // Deprecated: use ErrUnsupportedAddressRewriteCandidateType instead. 
May still be returned // when configuring address rewrite rules while NAT1:1 compatibility remains. ErrUnsupportedNAT1To1IPCandidateType = errors.New("unsupported address rewrite candidate type") // ErrUnsupportedAddressRewriteCandidateType is an alias for ErrUnsupportedNAT1To1IPCandidateType. ErrUnsupportedAddressRewriteCandidateType = ErrUnsupportedNAT1To1IPCandidateType // ErrInvalidNAT1To1IPMapping indicates that the given 1:1 NAT IP mapping is invalid. // // Deprecated: use ErrInvalidAddressRewriteMapping instead. May still be returned by // WithAddressRewriteRules while NAT1:1 compatibility remains. ErrInvalidNAT1To1IPMapping = errors.New("invalid address rewrite mapping") // ErrInvalidAddressRewriteMapping is an alias for ErrInvalidNAT1To1IPMapping. ErrInvalidAddressRewriteMapping = ErrInvalidNAT1To1IPMapping // ErrExternalMappedIPNotFound in address rewrite mapping. // // Kept for compatibility; current code paths treat "no externals" via match state and // no longer return this error. ErrExternalMappedIPNotFound = errors.New("external mapped IP not found") // ErrMulticastDNSWithNAT1To1IPMapping indicates that the mDNS gathering cannot be used along // with 1:1 NAT IP mapping for host candidate. // // Deprecated: use ErrMulticastDNSWithAddressRewrite instead. May still be returned by // WithAddressRewriteRules while NAT1:1 compatibility remains. ErrMulticastDNSWithNAT1To1IPMapping = errors.New( "mDNS gathering cannot be used with address rewrite for host candidate", ) // ErrMulticastDNSWithAddressRewrite is an alias for ErrMulticastDNSWithNAT1To1IPMapping. ErrMulticastDNSWithAddressRewrite = ErrMulticastDNSWithNAT1To1IPMapping // ErrIneffectiveNAT1To1IPMappingHost indicates that 1:1 NAT IP mapping for host candidate is // requested, but the host candidate type is disabled. // // Deprecated: use ErrIneffectiveAddressRewriteHost instead. May still be returned by // WithAddressRewriteRules while NAT1:1 compatibility remains. 
ErrIneffectiveNAT1To1IPMappingHost = errors.New("address rewrite for host candidate ineffective") // ErrIneffectiveAddressRewriteHost is an alias for ErrIneffectiveNAT1To1IPMappingHost. ErrIneffectiveAddressRewriteHost = ErrIneffectiveNAT1To1IPMappingHost // ErrIneffectiveNAT1To1IPMappingSrflx indicates that 1:1 NAT IP mapping for srflx candidate is // requested, but the srflx candidate type is disabled. // // Deprecated: use ErrIneffectiveAddressRewriteSrflx instead. May still be returned by // WithAddressRewriteRules while NAT1:1 compatibility remains. ErrIneffectiveNAT1To1IPMappingSrflx = errors.New("address rewrite for srflx candidate ineffective") // ErrIneffectiveAddressRewriteSrflx is an alias for ErrIneffectiveNAT1To1IPMappingSrflx. ErrIneffectiveAddressRewriteSrflx = ErrIneffectiveNAT1To1IPMappingSrflx // ErrInvalidMulticastDNSHostName indicates an invalid MulticastDNSHostName. ErrInvalidMulticastDNSHostName = errors.New( "invalid mDNS HostName, must end with .local and can only contain a single '.'", ) // ErrRunCanceled indicates a run operation was canceled by its individual done. ErrRunCanceled = errors.New("run was canceled by done") // ErrTCPRemoteAddrAlreadyExists indicates we already have the connection with same remote addr. ErrTCPRemoteAddrAlreadyExists = errors.New("conn with same remote addr already exists") // ErrUnknownCandidateTyp indicates that a candidate had a unknown type value. ErrUnknownCandidateTyp = errors.New("unknown candidate typ") // ErrDetermineNetworkType indicates that the NetworkType was not able to be parsed. ErrDetermineNetworkType = errors.New("unable to determine networkType") // ErrOnlyControllingAgentCanRenominate indicates that only controlling agent can renominate. ErrOnlyControllingAgentCanRenominate = errors.New("only controlling agent can renominate") // ErrRenominationNotEnabled indicates that renomination is not enabled. 
ErrRenominationNotEnabled = errors.New("renomination is not enabled") // ErrCandidatePairNotFound indicates that candidate pair was not found. ErrCandidatePairNotFound = errors.New("candidate pair not found") // ErrInvalidNominationAttribute indicates an invalid nomination attribute type was provided. ErrInvalidNominationAttribute = errors.New("invalid nomination attribute type") // ErrInvalidNominationValueGenerator indicates a nil nomination value generator was provided. ErrInvalidNominationValueGenerator = errors.New("nomination value generator cannot be nil") // ErrInvalidNetworkMonitorInterval indicates an invalid network monitor interval was provided. ErrInvalidNetworkMonitorInterval = errors.New("network monitor interval must be greater than 0") errAttributeTooShortICECandidate = errors.New("attribute not long enough to be ICE candidate") errClosingConnection = errors.New("failed to close connection") errConnectionAddrAlreadyExist = errors.New("connection with same remote address already exists") errGetXorMappedAddrResponse = errors.New("failed to get XOR-MAPPED-ADDRESS response") errInvalidAddress = errors.New("invalid address") errNoTCPMuxAvailable = errors.New("no TCP mux is available") errNotImplemented = errors.New("not implemented yet") errNoUDPMuxAvailable = errors.New("no UDP mux is available") errNoXorAddrMapping = errors.New("no address mapping") errParseFoundation = errors.New("failed to parse foundation") errParseComponent = errors.New("failed to parse component") errParsePort = errors.New("failed to parse port") errParsePriority = errors.New("failed to parse priority") errParseRelatedAddr = errors.New("failed to parse related addresses") errParseExtension = errors.New("failed to parse extension") errParseTCPType = errors.New("failed to parse TCP type") errUDPMuxDisabled = errors.New("UDPMux is not enabled") errUnknownRole = errors.New("unknown role") errWrite = errors.New("failed to write") errWriteSTUNMessage = errors.New("failed to send STUN 
message") errWriteSTUNMessageToIceConn = errors.New("failed to write STUN message to ICE connection") errXORMappedAddrTimeout = errors.New("timeout while waiting for XORMappedAddr") errFailedToCastUDPAddr = errors.New("failed to cast net.Addr to net.UDPAddr") errInvalidIPAddress = errors.New("invalid ip address") // UDPMuxDefault should not listen on unspecified address, but to keep backward compatibility, don't return error now. // will be used in the future. // errListenUnspecified = errors.New("can't listen on unspecified address"). ) ice-4.1.0/examples/000077500000000000000000000000001511704470200140735ustar00rootroot00000000000000ice-4.1.0/examples/automatic-renomination/000077500000000000000000000000001511704470200205615ustar00rootroot00000000000000ice-4.1.0/examples/automatic-renomination/README.md000066400000000000000000000442201511704470200220420ustar00rootroot00000000000000# Automatic Renomination Example This example demonstrates the ICE automatic renomination feature using real network interfaces. Automatic renomination allows the controlling ICE agent to automatically switch between candidate pairs when a better connection path becomes available. ## What is Automatic Renomination? Automatic renomination is a feature where the controlling ICE agent continuously monitors candidate pairs and automatically switches to better pairs when they become available. This is particularly useful for: - **Adapting to network changes**: When network conditions improve or degrade - **Optimizing for latency**: Automatically switching to lower-latency paths - **Quality of service**: Maintaining the best possible connection quality - **Interface failover**: Switching to alternate interfaces when primary path fails ## How It Works The automatic renomination feature evaluates candidate pairs based on: 1. **Candidate types**: Direct connections (host-to-host) are preferred over relay connections 2. **Round-trip time (RTT)**: Lower latency paths are preferred 3. 
**Connection stability**: Pairs that have recently received responses are favored When a candidate pair is found that is significantly better than the current selection (>10ms RTT improvement or direct vs relay), the agent automatically renominates to use the better pair. ## Quick Start Tutorial This step-by-step tutorial walks you through setting up virtual network interfaces and testing automatic renomination. **Important:** This example uses **network namespaces with two veth pairs** to properly isolate network traffic so that `tc` (traffic control) rules can affect latency. The controlled agent runs in the default namespace, and the controlling agent runs in a separate namespace (ns1). They communicate via two veth pairs, giving us multiple candidate pairs for automatic renomination. ### Step 1: Create Network Namespace with Two veth Pairs Create a network namespace and two veth pairs to connect them: ```bash # Create namespace sudo ip netns add ns1 # Create FIRST veth pair (veth0 <-> veth1) sudo ip link add veth0 type veth peer name veth1 sudo ip link set veth1 netns ns1 sudo ip addr add 192.168.100.1/24 dev veth0 sudo ip link set veth0 up sudo ip netns exec ns1 ip addr add 192.168.100.2/24 dev veth1 sudo ip netns exec ns1 ip link set veth1 up # Create SECOND veth pair (veth2 <-> veth3) sudo ip link add veth2 type veth peer name veth3 sudo ip link set veth3 netns ns1 sudo ip addr add 192.168.101.1/24 dev veth2 sudo ip link set veth2 up sudo ip netns exec ns1 ip addr add 192.168.101.2/24 dev veth3 sudo ip netns exec ns1 ip link set veth3 up # Bring up loopback in ns1 sudo ip netns exec ns1 ip link set lo up ``` **Verify connectivity on both pairs:** ```bash # Ping via first veth pair ping -c 2 192.168.100.2 # Ping via second veth pair ping -c 2 192.168.101.2 ``` You should see successful pings on both with low latency (~0.05ms). 
**Verify that tc rules work:** ```bash # Add 100ms latency to veth0 (first pair) sudo tc qdisc add dev veth0 root netem delay 100ms # Test ping via first pair - should now show ~100ms latency ping -c 2 192.168.100.2 # Test ping via second pair - should still be fast ping -c 2 192.168.101.2 # Remove tc rule for now sudo tc qdisc del dev veth0 root ``` After adding the tc rule to veth0, pings to 192.168.100.2 should show ~100ms latency, while pings to 192.168.101.2 remain fast. This proves tc is working and we have independent paths! ### Step 2: Start the Controlled Agent (Default Namespace) Open a terminal and start the controlled (non-controlling) agent in the **default namespace**: ```bash cd examples/automatic-renomination go run main.go ``` **Expected output:** ``` === Automatic Renomination Example === Local Agent is CONTROLLED Press 'Enter' when both processes have started ``` Don't press Enter yet - wait for Step 3. ### Step 3: Start the Controlling Agent (ns1 Namespace) Open a second terminal and start the controlling agent in the **ns1 namespace**: ```bash cd examples/automatic-renomination sudo ip netns exec ns1 go run main.go -controlling ``` **Expected output:** ``` === Automatic Renomination Example === Local Agent is CONTROLLING (with automatic renomination enabled) Press 'Enter' when both processes have started ``` ### Step 4: Connect the Agents Press Enter in **both terminals**. You should see candidate gathering and connection establishment: **Expected output (both terminals):** ``` Gathering candidates... Local candidate: candidate:... 192.168.100.x ... typ host Local candidate: candidate:... 192.168.101.x ... typ host Added remote candidate: candidate:... 192.168.100.x ... typ host Added remote candidate: candidate:... 192.168.101.x ... typ host Starting ICE connection... >>> ICE Connection State: Checking >>> SELECTED CANDIDATE PAIR CHANGED <<< Local: candidate:... 192.168.10x.x ... typ host (type: host) Remote: candidate:... 192.168.10x.x ... 
typ host (type: host) >>> ICE Connection State: Connected === CONNECTED === ``` You should see **2 local candidates** and **2 remote candidates**, giving you **4 candidate pairs** total. **Controlling agent will also show:** ``` Automatic renomination is enabled on the controlling agent. To test it, you can use traffic control (tc) to change network conditions: # Add 100ms latency to eth0: sudo tc qdisc add dev eth0 root netem delay 100ms # Remove the latency: sudo tc qdisc del dev eth0 root Watch for "SELECTED CANDIDATE PAIR CHANGED" messages above. The agent will automatically renominate to better paths when detected. ``` You should also see periodic messages being exchanged with RTT information: ``` Sent: Message #1 from controlling agent [RTT: 0.35ms] Received: Message #1 from controlled agent [RTT: 0.35ms] ``` ### Step 5: Add Latency to Trigger Renomination In a third terminal, add latency to veth0 (the first veth pair): ```bash # Add 100ms latency to veth0 sudo tc qdisc add dev veth0 root netem delay 100ms # Check that the rule was applied sudo tc qdisc show dev veth0 ``` **Expected output:** ``` qdisc netem 8001: root refcnt 2 limit 1000 delay 100ms ``` ### Step 6: Observe Automatic Renomination Watch the **controlling agent's terminal**. Look at the debug output showing candidate pair states: ``` === DEBUG: Candidate Pair States === candidate:... <-> candidate:... State: succeeded, Nominated: true RTT: 100.xx ms <-- This pair now has high latency! candidate:... <-> candidate:... State: succeeded, Nominated: false RTT: 0.3x ms <-- This pair is still fast! =================================== ``` Within a few seconds (based on the renomination interval of 3 seconds), once the RTT difference exceeds 10ms, you should see: **Expected output:** ``` >>> SELECTED CANDIDATE PAIR CHANGED <<< Local: candidate:... 192.168.101.x ... typ host (type: host) Remote: candidate:... 192.168.101.x ... 
typ host (type: host) ``` This shows the agent automatically switched from the slow path (192.168.100.x) to the fast path (192.168.101.x)! **What to look for:** - RTT on the nominated pair increases from ~0.3ms to ~100ms after adding tc rule - The RTT displayed in sent/received messages will also increase to ~100ms - After 3-6 seconds, renomination triggers - New selected pair uses the 192.168.101.x addresses (veth2/veth3) - RTT in both the debug output and sent/received messages drops back to ~0.3ms ### Step 7: Remove Latency and Observe Switch Back Remove the latency from veth0: ```bash sudo tc qdisc del dev veth0 root ``` Wait a few seconds and watch for another renomination event as the agent switches back to the now-improved path. ### Step 8: Clean Up When you're done testing, clean up the namespace and both veth pairs: ```bash # Stop both agents with Ctrl+C first # Remove any traffic control rules sudo tc qdisc del dev veth0 root 2>/dev/null || true sudo tc qdisc del dev veth2 root 2>/dev/null || true # Remove namespace (this automatically removes veth1 and veth3) sudo ip netns del ns1 2>/dev/null || true # Remove veth0 and veth2 from default namespace sudo ip link delete veth0 2>/dev/null || true sudo ip link delete veth2 2>/dev/null || true # Verify cleanup ip link show | grep veth # Should show no output ip netns list | grep ns1 # Should show no output ``` ## Running the Example This example requires running two processes - one controlling and one controlled. ### Terminal 1 (Controlling Agent) ```bash go run main.go -controlling ``` ### Terminal 2 (Controlled Agent) ```bash go run main.go ``` Press Enter in both terminals once both processes are running to start the ICE connection. ## Testing Automatic Renomination Once connected, you'll see messages like: ``` === CONNECTED === Automatic renomination is enabled on the controlling agent. 
To test it, you can use traffic control (tc) to change network conditions: # Add 100ms latency to eth0: sudo tc qdisc add dev eth0 root netem delay 100ms # Remove the latency: sudo tc qdisc del dev eth0 root Watch for "SELECTED CANDIDATE PAIR CHANGED" messages above. ``` ### Using Traffic Control (tc) to Trigger Renomination Traffic control (`tc`) is a Linux tool for simulating network conditions. Here are some useful commands: #### Add latency to an interface ```bash # Add 100ms delay sudo tc qdisc add dev eth0 root netem delay 100ms # Add variable latency (50ms ± 10ms) sudo tc qdisc add dev eth0 root netem delay 50ms 10ms ``` #### Simulate packet loss ```bash # Add 5% packet loss sudo tc qdisc add dev eth0 root netem loss 5% ``` #### Limit bandwidth ```bash # Limit to 1mbit sudo tc qdisc add dev eth0 root tbf rate 1mbit burst 32kbit latency 400ms ``` #### Remove all rules ```bash # Remove all tc rules from interface sudo tc qdisc del dev eth0 root ``` #### Check current rules ```bash # View current tc configuration sudo tc qdisc show dev eth0 ``` ### Expected Behavior When you change network conditions with `tc`, watch the controlling agent's output: 1. Initial connection will select the best available path 2. When you add latency/loss, the RTT increases 3. If the RTT difference exceeds 10ms, automatic renomination may trigger 4. You'll see a "SELECTED CANDIDATE PAIR CHANGED" message 5. The connection continues using the new pair **Note**: Renomination only occurs if a significantly better pair is available. Simply degrading one path won't trigger renomination unless there's an alternate path that's measurably better. ## Understanding the Output ### Connection State Changes ``` >>> ICE Connection State: Checking >>> ICE Connection State: Connected ``` These show the overall ICE connection state progression. 
### Candidate Discovery ``` Local candidate: candidate:1 1 udp 2130706431 192.168.1.100 54321 typ host Added remote candidate: candidate:2 1 udp 2130706431 192.168.1.101 54322 typ host ``` These show discovered local and remote ICE candidates. ### Pair Selection ``` >>> SELECTED CANDIDATE PAIR CHANGED <<< Local: candidate:1 1 udp 2130706431 192.168.1.100 54321 typ host (type: host) Remote: candidate:2 1 udp 2130706431 192.168.1.101 54322 typ host (type: host) ``` This indicates automatic renomination occurred and shows the new selected pair. ### Message RTT Display ``` Sent: Message #1 from controlling agent [RTT: 0.35ms] Received: Message #1 from controlled agent [RTT: 0.35ms] ``` Each sent and received message displays the current Round-Trip Time (RTT) of the selected candidate pair. This RTT value: - Shows the current latency of the connection path - Updates in real-time as network conditions change - Helps verify that automatic renomination is working (RTT should improve after switching to a better path) - May show "N/A" briefly during initial connection or if RTT hasn't been measured yet ## Testing Scenarios ### Scenario 1: Interface Latency Change 1. Start both agents 2. Wait for connection 3. Add latency to one interface: `sudo tc qdisc add dev eth0 root netem delay 100ms` 4. If multiple interfaces exist with different RTTs, automatic renomination should occur 5. Remove latency: `sudo tc qdisc del dev eth0 root` ### Scenario 2: Multiple Network Interfaces If your machine has multiple network interfaces (e.g., WiFi and Ethernet): 1. Start agents connected via one interface 2. Degrade that interface with `tc` 3. The agent should automatically switch to the other interface if it provides better quality ### Scenario 3: Connection Recovery 1. Start with one interface having high latency 2. Once connected, remove the latency 3. 
The agent should detect the improved path and switch back ## Configuration Options You can modify the example to customize automatic renomination: ### Renomination Interval ```go renominationInterval := 5 * time.Second // How often to check (default: 3s) iceAgent, err = ice.NewAgentWithOptions( ice.WithNetworkTypes([]ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}), ice.WithInterfaceFilter(interfaceFilter), ice.WithRenomination(ice.DefaultNominationValueGenerator()), ice.WithAutomaticRenomination(renominationInterval), ) ``` ### Interface Filter The example uses an `InterfaceFilter` to constrain the ICE agent to only use veth interfaces: ```go interfaceFilter := func(interfaceName string) bool { // Allow all veth interfaces (veth0, veth1, veth2, veth3) // This gives us multiple candidate pairs for automatic renomination return len(interfaceName) >= 4 && interfaceName[:4] == "veth" } ``` To use your real network interfaces instead: ```go // Option 1: Use all interfaces (no InterfaceFilter) iceAgent, err = ice.NewAgentWithOptions( ice.WithNetworkTypes([]ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}), ice.WithRenomination(ice.DefaultNominationValueGenerator()), ice.WithAutomaticRenomination(renominationInterval), ) // Option 2: Filter to specific real interfaces (e.g., eth0 and wlan0) interfaceFilter := func(interfaceName string) bool { return interfaceName == "eth0" || interfaceName == "wlan0" } // Option 3: Use only IPv4 interfaces starting with "eth" interfaceFilter := func(interfaceName string) bool { return strings.HasPrefix(interfaceName, "eth") } ``` **Note:** When using real network interfaces without network namespaces, you'll need to run the two processes on different machines to properly test tc rules, as local traffic on the same machine may bypass tc. 
## Troubleshooting ### ICE agent not using veth interfaces If you see candidates on your real interfaces (like eth0, eth2, wlan0) instead of veth0/veth1: - **Check the InterfaceFilter**: Make sure the code has the `InterfaceFilter` configured to only allow veth0 and veth1 - **Verify veth interfaces exist**: Run `ip link show veth0` (and `sudo ip netns exec ns1 ip link show veth1`) to confirm they're created - **Verify interfaces are UP**: Run `ip link show veth0` and check for `UP` in the output - **Check IP addresses**: Run `ip addr show veth0` and `sudo ip netns exec ns1 ip addr show veth1` to confirm the 192.168.100.x addresses are assigned ### No candidates found / Connection fails If the agents fail to connect after adding the InterfaceFilter: - **Create dummy interfaces first**: The dummy interfaces must be created before starting the agents - **Both agents need the filter**: Both controlling and controlled agents must have the same InterfaceFilter - **Check for errors**: Look for errors during candidate gathering that might indicate interface issues ### No renomination occurring - **Only one candidate pair available**: Automatic renomination needs alternate paths to switch between. If only one candidate pair exists, renomination won't occur. - **Insufficient quality difference**: The new path must be significantly better (>10ms RTT improvement or better candidate type) to trigger renomination. - **Not enough time elapsed**: The renomination interval (default 3s) must pass before evaluation occurs. - **Wrong interface**: Make sure you're adding latency to the interface that's actually being used. Check the "SELECTED CANDIDATE PAIR CHANGED" message to see which interface/IP is in use. ### Permission denied with tc commands All `tc` commands require root privileges. Use `sudo` before each command. ### Interface name not found Use `ip link` to list available network interfaces on your system. 
Replace `eth0` with your actual interface name (e.g., `enp0s3`, `wlan0`, `wlp3s0`). ## Cleanup After testing, it's important to clean up any virtual interfaces and traffic control rules you created. ### Remove Traffic Control Rules If you added any tc rules to interfaces, remove them: ```bash # Remove tc rules from veth interfaces sudo tc qdisc del dev veth0 root 2>/dev/null || true sudo tc qdisc del dev veth2 root 2>/dev/null || true # List all interfaces with tc rules tc qdisc show # Remove tc rules from any interface shown above # sudo tc qdisc del dev root ``` ### Remove Network Namespace and veth Pairs If you created the namespace and veth pairs for testing, remove them: ```bash # Remove namespace (automatically removes veth1 and veth3) sudo ip netns del ns1 2>/dev/null || true # Remove veth0 and veth2 from default namespace sudo ip link delete veth0 2>/dev/null || true sudo ip link delete veth2 2>/dev/null || true # Verify removal ip link show | grep veth # Should show no output ip netns list | grep ns1 # Should show no output ``` ### Verify Clean State Check that everything is cleaned up: ```bash # Check for any remaining tc rules tc qdisc show # Check for veth interfaces ip link show | grep veth # Check for namespaces ip netns list # All commands should show no veth interfaces or ns1 namespace if cleanup was successful ``` ## Additional Resources - [ICE RFC 8445](https://datatracker.ietf.org/doc/html/rfc8445) - Interactive Connectivity Establishment (ICE) Protocol - [draft-thatcher-ice-renomination](https://datatracker.ietf.org/doc/html/draft-thatcher-ice-renomination-01) - ICE Renomination Specification - [Linux tc man page](https://man7.org/linux/man-pages/man8/tc.8.html) - Traffic Control documentation - [netem documentation](https://wiki.linuxfoundation.org/networking/netem) - Network Emulation guide - [Linux network namespaces](https://man7.org/linux/man-pages/man8/ip-netns.8.html) - Network namespace documentation - [veth 
pairs](https://man7.org/linux/man-pages/man4/veth.4.html) - Virtual ethernet pair documentation ice-4.1.0/examples/automatic-renomination/main.go000066400000000000000000000211261511704470200220360ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT // Package main demonstrates automatic renomination with real network interfaces. // Run two instances of this program - one controlling and one controlled - and use // traffic control (tc) commands to simulate network changes and trigger automatic renomination. package main import ( "bufio" "context" "flag" "fmt" "net/http" "net/url" "os" "time" "github.com/pion/ice/v4" "github.com/pion/logging" ) const ( rttNotAvailable = "N/A" ) //nolint:gochecknoglobals var ( isControlling bool iceAgent *ice.Agent remoteAuthChannel chan string localHTTPPort, remoteHTTPPort int localHTTPAddr, remoteHTTPAddr string selectedLocalCandidateID string selectedRemoteCandidateID string ) // getRTT returns the current RTT for the selected candidate pair. func getRTT() string { if selectedLocalCandidateID == "" || selectedRemoteCandidateID == "" { return rttNotAvailable } stats := iceAgent.GetCandidatePairsStats() for _, stat := range stats { if stat.LocalCandidateID == selectedLocalCandidateID && stat.RemoteCandidateID == selectedRemoteCandidateID { if stat.CurrentRoundTripTime > 0 { return fmt.Sprintf("%.2fms", stat.CurrentRoundTripTime*1000) } return rttNotAvailable } } return rttNotAvailable } // HTTP Listener to get ICE Credentials from remote Peer. func remoteAuth(_ http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { panic(err) } remoteAuthChannel <- r.PostForm["ufrag"][0] remoteAuthChannel <- r.PostForm["pwd"][0] } // HTTP Listener to get ICE Candidate from remote Peer. 
func remoteCandidate(_ http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { panic(err) } c, err := ice.UnmarshalCandidate(r.PostForm["candidate"][0]) if err != nil { panic(err) } if err := iceAgent.AddRemoteCandidate(c); err != nil { //nolint:contextcheck panic(err) } fmt.Printf("Added remote candidate: %s\n", c) } func main() { //nolint:cyclop,maintidx var ( err error conn *ice.Conn ) remoteAuthChannel = make(chan string, 3) flag.BoolVar(&isControlling, "controlling", false, "is ICE Agent controlling") flag.Parse() if isControlling { // Controlling agent runs in ns1 namespace localHTTPPort = 9000 remoteHTTPPort = 9001 localHTTPAddr = "192.168.100.2" // veth1 in ns1 remoteHTTPAddr = "192.168.100.1" // veth0 in default namespace } else { // Controlled agent runs in default namespace localHTTPPort = 9001 remoteHTTPPort = 9000 localHTTPAddr = "192.168.100.1" // veth0 in default namespace remoteHTTPAddr = "192.168.100.2" // veth1 in ns1 } http.HandleFunc("/remoteAuth", remoteAuth) http.HandleFunc("/remoteCandidate", remoteCandidate) go func() { if err = http.ListenAndServe(fmt.Sprintf("%s:%d", localHTTPAddr, localHTTPPort), nil); err != nil { //nolint:gosec panic(err) } }() fmt.Println("=== Automatic Renomination Example ===") if isControlling { fmt.Println("Local Agent is CONTROLLING (with automatic renomination enabled)") } else { fmt.Println("Local Agent is CONTROLLED") } fmt.Println() fmt.Print("Press 'Enter' when both processes have started") if _, err = bufio.NewReader(os.Stdin).ReadBytes('\n'); err != nil { panic(err) } // Create the ICE agent with automatic renomination enabled on the controlling side // Use InterfaceFilter to constrain to veth interfaces for testing interfaceFilter := func(interfaceName string) bool { // Allow all veth interfaces (veth0, veth1, veth2, veth3) // This gives us multiple candidate pairs for automatic renomination return len(interfaceName) >= 4 && interfaceName[:4] == "veth" } // Create a logger factory 
with Debug level enabled loggerFactory := logging.NewDefaultLoggerFactory() loggerFactory.DefaultLogLevel = logging.LogLevelDebug if isControlling { renominationInterval := 3 * time.Second iceAgent, err = ice.NewAgentWithOptions( ice.WithNetworkTypes([]ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}), ice.WithInterfaceFilter(interfaceFilter), ice.WithLoggerFactory(loggerFactory), ice.WithRenomination(ice.DefaultNominationValueGenerator()), ice.WithAutomaticRenomination(renominationInterval), ) } else { iceAgent, err = ice.NewAgentWithOptions( ice.WithNetworkTypes([]ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}), ice.WithInterfaceFilter(interfaceFilter), ice.WithLoggerFactory(loggerFactory), ) } if err != nil { panic(err) } // When we have gathered a new ICE Candidate send it to the remote peer if err = iceAgent.OnCandidate(func(c ice.Candidate) { if c == nil { return } fmt.Printf("Local candidate: %s\n", c) _, err = http.PostForm(fmt.Sprintf("http://%s:%d/remoteCandidate", remoteHTTPAddr, remoteHTTPPort), //nolint url.Values{ "candidate": {c.Marshal()}, }) if err != nil { panic(err) } }); err != nil { panic(err) } // When ICE Connection state has changed print to stdout if err = iceAgent.OnConnectionStateChange(func(c ice.ConnectionState) { fmt.Printf(">>> ICE Connection State: %s\n", c.String()) }); err != nil { panic(err) } // When selected candidate pair changes, print it if err = iceAgent.OnSelectedCandidatePairChange(func(local, remote ice.Candidate) { // Track the selected candidate IDs for RTT lookup selectedLocalCandidateID = local.ID() selectedRemoteCandidateID = remote.ID() fmt.Println() fmt.Println(">>> SELECTED CANDIDATE PAIR CHANGED <<<") fmt.Printf(" Local: %s (type: %s)\n", local, local.Type()) fmt.Printf(" Remote: %s (type: %s)\n", remote, remote.Type()) fmt.Println() }); err != nil { panic(err) } // Get the local auth details and send to remote peer localUfrag, localPwd, err := iceAgent.GetLocalUserCredentials() if err != nil 
{ panic(err) } _, err = http.PostForm(fmt.Sprintf("http://%s:%d/remoteAuth", remoteHTTPAddr, remoteHTTPPort), //nolint url.Values{ "ufrag": {localUfrag}, "pwd": {localPwd}, }) if err != nil { panic(err) } remoteUfrag := <-remoteAuthChannel remotePwd := <-remoteAuthChannel if err = iceAgent.GatherCandidates(); err != nil { panic(err) } fmt.Println("Gathering candidates...") time.Sleep(2 * time.Second) // Give time for candidate gathering // Start the ICE Agent. One side must be controlled, and the other must be controlling fmt.Println("Starting ICE connection...") if isControlling { conn, err = iceAgent.Dial(context.Background(), remoteUfrag, remotePwd) } else { conn, err = iceAgent.Accept(context.Background(), remoteUfrag, remotePwd) } if err != nil { panic(err) } fmt.Println() fmt.Println("=== CONNECTED ===") fmt.Println() if isControlling { fmt.Println("Automatic renomination is enabled on the controlling agent.") fmt.Println("To test it, you can use traffic control (tc) to change network conditions:") fmt.Println() fmt.Println(" # Add 100ms latency to veth0:") fmt.Println(" sudo tc qdisc add dev veth0 root netem delay 100ms") fmt.Println() fmt.Println(" # Remove the latency:") fmt.Println(" sudo tc qdisc del dev veth0 root") fmt.Println() fmt.Println("Watch for \"SELECTED CANDIDATE PAIR CHANGED\" messages above.") fmt.Println("The agent will automatically renominate to better paths when detected.") fmt.Println() // Print debug info about candidate pairs every 5 seconds go func() { for { time.Sleep(5 * time.Second) fmt.Println() fmt.Println("=== DEBUG: Candidate Pair States ===") stats := iceAgent.GetCandidatePairsStats() for _, stat := range stats { fmt.Printf(" %s <-> %s\n", stat.LocalCandidateID, stat.RemoteCandidateID) fmt.Printf(" State: %s, Nominated: %v\n", stat.State, stat.Nominated) if stat.CurrentRoundTripTime > 0 { fmt.Printf(" RTT: %.2fms\n", stat.CurrentRoundTripTime*1000) } } fmt.Println("===================================") fmt.Println() } }() } 
// Send a message every 5 seconds go func() { counter := 0 for { time.Sleep(5 * time.Second) counter++ role := "controlling" if !isControlling { role = "controlled" } msg := fmt.Sprintf("Message #%d from %s agent", counter, role) if _, err = conn.Write([]byte(msg)); err != nil { fmt.Printf("Write error: %v\n", err) return } fmt.Printf("Sent: %s [RTT: %s]\n", msg, getRTT()) } }() // Receive messages in a loop from the remote peer buf := make([]byte, 1500) for { n, err := conn.Read(buf) if err != nil { fmt.Printf("Read error: %v\n", err) return } fmt.Printf("Received: %s [RTT: %s]\n", string(buf[:n]), getRTT()) } } ice-4.1.0/examples/continual-gathering/000077500000000000000000000000001511704470200200355ustar00rootroot00000000000000ice-4.1.0/examples/continual-gathering/README.md000066400000000000000000000040061511704470200213140ustar00rootroot00000000000000# Continual Gathering Example This example demonstrates the `ContinualGatheringPolicy` feature in Pion ICE, which allows agents to continuously discover network candidates throughout a connection's lifetime. ## Overview Traditional ICE gathering (`GatherOnce`) collects candidates once at startup and stops. This can be problematic when: - Users switch between WiFi and cellular networks - Network interfaces are added/removed - Moving between access points ("walk-out-the-door" problem) With `GatherContinually`, the agent monitors for network changes and automatically discovers new candidates, enabling seamless connectivity transitions. ## Usage ```bash # Traditional gathering (stops after initial collection) go run main.go -mode once # Continual gathering (monitors for network changes) go run main.go -mode continually -interval 2s ``` ## Testing While running in continual mode, try: - Connecting/disconnecting WiFi - Enabling/disabling network adapters - Switching between networks New candidates will be discovered and reported automatically! 
### Testing with Virtual Network Adapters (Linux) You can easily test the continual gathering by creating/removing virtual network adapters: ```bash # Create a virtual network adapter with an IP address sudo ip link add veth0 type veth peer name veth1 sudo ip addr add 10.0.0.1/24 dev veth0 sudo ip link set veth0 up # Wait a few seconds, then check the example output # You should see new candidates for 10.0.0.1 # Remove the virtual adapter sudo ip link delete veth0 # The removed interface will be detected and logged ``` Alternative using dummy interface: ```bash # Create a dummy interface (simpler, no peer needed) sudo ip link add dummy0 type dummy sudo ip addr add 192.168.100.1/24 dev dummy0 sudo ip link set dummy0 up # Remove it sudo ip link delete dummy0 ``` You can also change IP addresses on existing interfaces: ```bash # Add a secondary IP to an existing interface sudo ip addr add 172.16.0.1/24 dev eth0 # Remove it sudo ip addr del 172.16.0.1/24 dev eth0 ``` ice-4.1.0/examples/continual-gathering/main.go000066400000000000000000000105511511704470200213120ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT // Package main demonstrates the ContinualGatheringPolicy feature package main import ( "context" "flag" "fmt" "log" "os" "os/signal" "syscall" "time" "github.com/pion/ice/v4" "github.com/pion/logging" ) func main() { //nolint:cyclop var gatheringMode string var monitorInterval time.Duration flag.StringVar(&gatheringMode, "mode", "continually", "Gathering mode: 'once' or 'continually'") flag.DurationVar(&monitorInterval, "interval", 2*time.Second, "Network monitoring interval (for continual mode)") flag.Parse() // Determine gathering policy var policy ice.ContinualGatheringPolicy switch gatheringMode { case "once": policy = ice.GatherOnce fmt.Println("Using GatherOnce policy - gathering will complete after initial collection") case "continually": policy = ice.GatherContinually fmt.Printf("Using 
GatherContinually policy - monitoring for network changes every %v\n", monitorInterval) default: log.Fatalf("Invalid mode: %s. Use 'once' or 'continually'", gatheringMode) } // Create logger loggerFactory := logging.NewDefaultLoggerFactory() loggerFactory.DefaultLogLevel = logging.LogLevelDebug // Create ICE agent with the specified gathering policy using AgentOptions agent, err := ice.NewAgentWithOptions( ice.WithNetworkTypes([]ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}), ice.WithCandidateTypes([]ice.CandidateType{ice.CandidateTypeHost}), ice.WithContinualGatheringPolicy(policy), ice.WithNetworkMonitorInterval(monitorInterval), ) if err != nil { log.Fatalf("Failed to create agent: %v", err) } defer func() { if closeErr := agent.Close(); closeErr != nil { log.Printf("Failed to close agent: %v", closeErr) } }() // Track candidates candidateCount := 0 candidateMap := make(map[string]ice.Candidate) // Set up candidate handler err = agent.OnCandidate(func(candidate ice.Candidate) { if candidate == nil { if policy == ice.GatherOnce { fmt.Println("\n=== Gathering completed (no more candidates) ===") } return } candidateCount++ candidateID := candidate.String() if _, exists := candidateMap[candidateID]; !exists { candidateMap[candidateID] = candidate fmt.Printf("[%s] Candidate #%d: %s\n", time.Now().Format("15:04:05"), candidateCount, candidate) } }) if err != nil { log.Fatalf("Failed to set candidate handler: %v", err) //nolint:gocritic } // Start gathering fmt.Println("\n=== Starting candidate gathering ===") err = agent.GatherCandidates() if err != nil { log.Fatalf("Failed to start gathering: %v", err) } // Set up signal handling for graceful shutdown sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) // Create a context for periodic status checks ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Periodically check and display gathering state go func() { ticker := time.NewTicker(5 * 
time.Second) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: state, err := agent.GetGatheringState() //nolint:contextcheck if err != nil { log.Printf("Failed to get gathering state: %v", err) continue } localCandidates, err := agent.GetLocalCandidates() //nolint:contextcheck if err != nil { log.Printf("Failed to get local candidates: %v", err) continue } fmt.Printf("\n[%s] Status: GatheringState=%s, Candidates=%d\n", time.Now().Format("15:04:05"), state, len(localCandidates)) if policy == ice.GatherContinually { fmt.Println("Tip: Try changing network interfaces (connect/disconnect WiFi, enable/disable network adapters)") fmt.Println(" New candidates will be discovered automatically!") } } } }() // Wait for interrupt signal fmt.Println("\nPress Ctrl+C to exit...") <-sigChan fmt.Println("\n=== Shutting down ===") cancel() // Display final statistics state, _ := agent.GetGatheringState() localCandidates, _ := agent.GetLocalCandidates() fmt.Printf("\nFinal Statistics:\n") fmt.Printf(" Gathering Policy: %s\n", policy) fmt.Printf(" Gathering State: %s\n", state) fmt.Printf(" Total Candidates Discovered: %d\n", candidateCount) fmt.Printf(" Unique Candidates: %d\n", len(candidateMap)) fmt.Printf(" Current Active Candidates: %d\n", len(localCandidates)) } ice-4.1.0/examples/nat-rules/000077500000000000000000000000001511704470200160055ustar00rootroot00000000000000ice-4.1.0/examples/nat-rules/Dockerfile000066400000000000000000000011201511704470200177710ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2025 The Pion community # SPDX-License-Identifier: MIT FROM golang:1.25-bookworm AS builder WORKDIR /src # Copy the entire repo so local changes to pion/ice are included. COPY . . WORKDIR /src/examples/nat-rules RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -trimpath -o /tmp/nat-rules-demo . 
FROM debian:bookworm-slim RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates && \ rm -rf /var/lib/apt/lists/* COPY --from=builder /tmp/nat-rules-demo /usr/local/bin/nat-rules-demo ENTRYPOINT ["nat-rules-demo"] ice-4.1.0/examples/nat-rules/README.md000066400000000000000000000051631511704470200172710ustar00rootroot00000000000000# Address Rewrite Rules Demo This demo shows how the extended `AddressRewriteRule` feature in `pion/ice` rewrites addresses for a multi-homed host. It runs a small client that gathers host, srflx, UDP, and TCP candidates so you can see exactly what each rule produces. The included `docker-compose.yml` places one container on multiple host networks, similar to setups like the Glimesh broadcast box. The demo covers: 1. Multiple host networks with fixed public IPs for each interface. 2. Host and server-reflexive addresses from deterministic srflx pools. 3. Scoped rules that only rewrite specific CIDRs, with a global fallback for others. 4. TCP candidates generated by an automatic TCP mux. 5. Zero-length External handling: replace+empty drops a candidate, append+empty keeps it (useful for deny/allow layering). ## Scenarios | Key | Description | | ----------- | --------------------------------------------------------------------------- | | `multi-net` | Two host networks with distinct public IPs plus a global fallback. | | `srflx` | A srflx pool (two addresses) plus a host mapping for the service interface. | | `scoped` | A CIDR-scoped rule that overrides the global mapping only for matching IPs. | | `iface` | Interface-scoped host rewrite; only matching NICs are rewritten. | The client prints local interfaces and each gathered candidate. A `nil` candidate marks the end of the scenario. 
## Docker Compose Topology The service `nat-demo` attaches to three bridge networks: | Network | Subnet | Purpose | | ------------- | ------------ | --------------------- | | `lan_blue` | 10.10.0.0/24 | First host interface | | `lan_green` | 10.20.0.0/24 | Second host interface | | `lan_service` | 10.30.0.0/24 | Shared service leg | Run everything with: ```sh docker compose up --build nat-demo ``` ## Configuration Variables All mappings come from environment variables so you can match your own network. The important ones are defined in `docker-compose.yml`: * `NAT_DEMO_BLUE_LOCAL` / `NAT_DEMO_BLUE_PUBLIC` * `NAT_DEMO_BLUE_IFACE` (default `eth0`) * `NAT_DEMO_GREEN_LOCAL` / `NAT_DEMO_GREEN_PUBLIC` * `NAT_DEMO_GREEN_IFACE` (default `eth1`) * `NAT_DEMO_GLOBAL_HOST_FALLBACK` * `NAT_DEMO_SERVICE_LOCAL` / `NAT_DEMO_SERVICE_HOST_PUBLIC` * `NAT_DEMO_SCOPED_PUBLIC` / `NAT_DEMO_SCOPED_CIDR` * `NAT_DEMO_SRFLX_PRIMARY` / `NAT_DEMO_SRFLX_SECONDARY` * `NAT_DEMO_DROP_LAN` (optional) — set to `1` to drop LAN host candidates via a replace+empty rule. Override any value with `docker compose run -e VAR=... nat-demo`. ice-4.1.0/examples/nat-rules/docker-compose.yml000066400000000000000000000023031511704470200214400ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2025 The Pion community # SPDX-License-Identifier: MIT services: nat-demo: build: context: ../.. 
dockerfile: examples/nat-rules/Dockerfile command: ["-scenario", "all"] tty: true environment: NAT_DEMO_BLUE_LOCAL: 10.10.0.20 NAT_DEMO_BLUE_PUBLIC: 203.0.113.10 NAT_DEMO_GREEN_LOCAL: 10.20.0.20 NAT_DEMO_GREEN_PUBLIC: 203.0.113.20 NAT_DEMO_GLOBAL_HOST_FALLBACK: 198.51.100.200 NAT_DEMO_SERVICE_LOCAL: 10.30.0.20 NAT_DEMO_SERVICE_HOST_PUBLIC: 203.0.113.30 NAT_DEMO_SCOPED_PUBLIC: 203.0.113.40 NAT_DEMO_SCOPED_CIDR: 10.30.0.0/24 NAT_DEMO_SRFLX_LOCAL: 0.0.0.0 NAT_DEMO_SRFLX_PRIMARY: 198.51.100.50 NAT_DEMO_SRFLX_SECONDARY: 198.51.100.60 networks: lan_blue: ipv4_address: 10.10.0.20 lan_green: ipv4_address: 10.20.0.20 lan_service: ipv4_address: 10.30.0.20 networks: lan_blue: driver: bridge ipam: config: - subnet: 10.10.0.0/24 lan_green: driver: bridge ipam: config: - subnet: 10.20.0.0/24 lan_service: driver: bridge ipam: config: - subnet: 10.30.0.0/24 ice-4.1.0/examples/nat-rules/main.go000066400000000000000000000322761511704470200172720ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT // Package main provides a simple CLI that prints the host and srflx addresses // produced by different address rewrite (1:1) rule configurations. It is designed to be run // inside the accompanying docker-compose topology so that each scenario can // demonstrate how multi-homed hosts, srflx pools, CIDR scoping, and TCP muxing // interact with the new rules. 
package main import ( "context" "flag" "fmt" "log" "net" "os" "sort" "strings" "sync" "time" "github.com/pion/ice/v4" "github.com/pion/logging" ) const ( defaultTimeout = 8 * time.Second ) type scenario struct { Key string Title string Description string RewriteRules []ice.AddressRewriteRule NetworkTypes []ice.NetworkType CandidateTypes []ice.CandidateType TimeoutOverride time.Duration } func (s scenario) timeout() time.Duration { if s.TimeoutOverride > 0 { return s.TimeoutOverride } return defaultTimeout } func (s scenario) requiresTCPMux() bool { for _, nt := range s.NetworkTypes { if nt.IsTCP() { return true } } return false } func main() { log.SetFlags(0) var ( scenarioKey string listOnly bool timeout time.Duration ) flag.StringVar(&scenarioKey, "scenario", "all", "Scenario key to run (use -list to see options)") flag.BoolVar(&listOnly, "list", false, "List available scenarios") flag.DurationVar(&timeout, "timeout", 0, "Override gather timeout for each scenario") flag.Parse() scenarios := buildScenarios(timeout) if listOnly { fmt.Println("Available scenarios:") for _, sc := range scenarios { fmt.Printf(" %s\t%s\n", sc.Key, sc.Title) } return } fmt.Println("Address rewrite rule demonstration client") printInterfaceSnapshot() ctx := context.Background() if scenarioKey == "all" { for _, sc := range scenarios { if err := runScenario(ctx, sc); err != nil { log.Fatalf("scenario %s failed: %v", sc.Key, err) } } return } sc, ok := findScenario(scenarios, scenarioKey) if !ok { log.Fatalf("unknown scenario %q. 
Use -list to see valid keys.", scenarioKey) } if err := runScenario(ctx, sc); err != nil { log.Fatalf("scenario %s failed: %v", sc.Key, err) } } func findScenario(scenarios []scenario, key string) (scenario, bool) { for _, sc := range scenarios { if sc.Key == key { return sc, true } } return scenario{}, false } func buildScenarios(timeout time.Duration) []scenario { scenarios := []scenario{ buildMultiNetworkScenario(), buildSrflxScenario(), buildScopedCatchAllScenario(), buildIfaceScopedScenario(), } if timeout > 0 { for i := range scenarios { scenarios[i].TimeoutOverride = timeout } } return scenarios } func buildMultiNetworkScenario() scenario { localBlue := envOrDefault("NAT_DEMO_BLUE_LOCAL", "10.10.0.20") publicBlue := envOrDefault("NAT_DEMO_BLUE_PUBLIC", "203.0.113.10") localGreen := envOrDefault("NAT_DEMO_GREEN_LOCAL", "10.20.0.20") publicGreen := envOrDefault("NAT_DEMO_GREEN_PUBLIC", "203.0.113.20") globalFallback := envOrDefault("NAT_DEMO_GLOBAL_HOST_FALLBACK", "198.51.100.200") dropLAN := envOrDefault("NAT_DEMO_DROP_LAN", "0") == "1" sc := scenario{ Key: "multi-net", Title: "Multiple host networks with UDP+TCP replacement", Description: "Maps each host interface to a deterministic public IP and enables TCP" + "muxing to prove the rules work for both UDP and ICE-TCP candidates.", NetworkTypes: []ice.NetworkType{ ice.NetworkTypeUDP4, ice.NetworkTypeTCP4, }, CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost}, RewriteRules: []ice.AddressRewriteRule{ { External: []string{publicBlue}, Local: localBlue, AsCandidateType: ice.CandidateTypeHost, Networks: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeTCP4}, }, { External: []string{publicGreen}, Local: localGreen, AsCandidateType: ice.CandidateTypeHost, Networks: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeTCP4}, }, { External: []string{globalFallback}, AsCandidateType: ice.CandidateTypeHost, }, }, } if dropLAN { sc.RewriteRules = append(sc.RewriteRules, ice.AddressRewriteRule{ External: 
nil, Local: localBlue, AsCandidateType: ice.CandidateTypeHost, Mode: ice.AddressRewriteReplace, }) } return sc } func buildIfaceScopedScenario() scenario { localBlue := envOrDefault("NAT_DEMO_BLUE_LOCAL", "10.10.0.20") publicBlue := envOrDefault("NAT_DEMO_BLUE_PUBLIC", "203.0.113.10") localGreen := envOrDefault("NAT_DEMO_GREEN_LOCAL", "10.20.0.20") publicGreen := envOrDefault("NAT_DEMO_GREEN_PUBLIC", "203.0.113.20") blueIface := envOrDefault("NAT_DEMO_BLUE_IFACE", "eth0") greenIface := envOrDefault("NAT_DEMO_GREEN_IFACE", "eth1") globalFallback := envOrDefault("NAT_DEMO_GLOBAL_HOST_FALLBACK", "198.51.100.200") return scenario{ Key: "iface", Title: "Interface-scoped host rewrite", Description: "Uses iface-scoped rules so only the intended NICs are rewritten." + "Others fall back to the global mapping.", NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4}, CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost}, RewriteRules: []ice.AddressRewriteRule{ { External: []string{publicBlue}, Local: localBlue, Iface: blueIface, AsCandidateType: ice.CandidateTypeHost, }, { External: []string{publicGreen}, Local: localGreen, Iface: greenIface, AsCandidateType: ice.CandidateTypeHost, }, { External: []string{globalFallback}, AsCandidateType: ice.CandidateTypeHost, }, }, } } func buildSrflxScenario() scenario { hostLocal := envOrDefault("NAT_DEMO_SERVICE_LOCAL", "10.30.0.20") hostPublic := envOrDefault("NAT_DEMO_SERVICE_HOST_PUBLIC", "203.0.113.30") srflxLocal := envOrDefault("NAT_DEMO_SRFLX_LOCAL", "0.0.0.0") srflxPrimary := envOrDefault("NAT_DEMO_SRFLX_PRIMARY", "198.51.100.50") srflxSecondary := envOrDefault("NAT_DEMO_SRFLX_SECONDARY", "198.51.100.60") return scenario{ Key: "srflx", Title: "Server-reflexive pool plus host override", Description: "Publishes a pair of pre-defined srflx addresses (one bound to 0.0.0.0 and one catch-all) " + "plus a host mapping for an \"edge\" interface, mirroring NAT64/CLAT style deployments", NetworkTypes: []ice.NetworkType{ 
ice.NetworkTypeUDP4, }, CandidateTypes: []ice.CandidateType{ ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, }, RewriteRules: []ice.AddressRewriteRule{ { External: []string{srflxPrimary}, Local: srflxLocal, AsCandidateType: ice.CandidateTypeServerReflexive, Networks: []ice.NetworkType{ice.NetworkTypeUDP4}, }, { External: []string{srflxSecondary}, AsCandidateType: ice.CandidateTypeServerReflexive, Networks: []ice.NetworkType{ice.NetworkTypeUDP4}, }, { External: []string{hostPublic}, Local: hostLocal, AsCandidateType: ice.CandidateTypeHost, Networks: []ice.NetworkType{ice.NetworkTypeUDP4}, }, }, } } func buildScopedCatchAllScenario() scenario { scopedLocal := envOrDefault("NAT_DEMO_SERVICE_LOCAL", "10.30.0.20") scopedPublic := envOrDefault("NAT_DEMO_SCOPED_PUBLIC", "203.0.113.40") scopedCIDR := envOrDefault("NAT_DEMO_SCOPED_CIDR", "10.30.0.0/24") catchAll := envOrDefault("NAT_DEMO_GLOBAL_HOST_FALLBACK", "198.51.100.200") return scenario{ Key: "scoped", Title: "CIDR scoped rule with catch-all fallback", Description: "Limits mapping to a specific CIDR while" + " keeping a catch-all public address for any traffic that lands on other interfaces, " + "Demonstrates the rule precedence and scope matching order.", NetworkTypes: []ice.NetworkType{ ice.NetworkTypeUDP4, }, CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost}, RewriteRules: []ice.AddressRewriteRule{ { External: []string{scopedPublic}, Local: scopedLocal, AsCandidateType: ice.CandidateTypeHost, CIDR: scopedCIDR, Networks: []ice.NetworkType{ice.NetworkTypeUDP4}, }, { External: []string{catchAll}, AsCandidateType: ice.CandidateTypeHost, CIDR: scopedCIDR, Networks: []ice.NetworkType{ice.NetworkTypeUDP4}, }, }, } } func runScenario(ctx context.Context, sc scenario) error { //nolint:cyclop fmt.Printf("\n=~~++= %s (%s) =~~++=\n%s\n", sc.Title, sc.Key, sc.Description) printRules(sc.RewriteRules) var opts []ice.AgentOption if len(sc.RewriteRules) > 0 { opts = append(opts, 
ice.WithAddressRewriteRules(sc.RewriteRules...)) } if len(sc.NetworkTypes) > 0 { opts = append(opts, ice.WithNetworkTypes(sc.NetworkTypes)) } if len(sc.CandidateTypes) > 0 { opts = append(opts, ice.WithCandidateTypes(sc.CandidateTypes)) } var tcpMux *ice.TCPMuxDefault if sc.requiresTCPMux() { listener, err := net.ListenTCP("tcp4", &net.TCPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { return fmt.Errorf("scenario %s: tcp listen: %w", sc.Key, err) } tcpMux = ice.NewTCPMuxDefault(ice.TCPMuxParams{ Listener: listener, Logger: logging.NewDefaultLoggerFactory().NewLogger("nat-demo/tcp"), }) opts = append(opts, ice.WithTCPMux(tcpMux)) fmt.Printf("TCP mux listening on %s\n", listener.Addr()) defer func() { if err := tcpMux.Close(); err != nil { log.Printf("failed to close TCP mux: %v", err) } }() } agent, err := ice.NewAgentWithOptions(opts...) //nolint:contextcheck if err != nil { return fmt.Errorf("scenario %s: create agent: %w", sc.Key, err) } defer func() { if closeErr := agent.Close(); closeErr != nil { log.Printf("failed to close agent: %v", closeErr) } }() done := make(chan struct{}) var once sync.Once var mu sync.Mutex var gathered []string err = agent.OnCandidate(func(c ice.Candidate) { if c == nil { once.Do(func() { close(done) }) return } line := formatCandidate(c) mu.Lock() gathered = append(gathered, line) mu.Unlock() fmt.Println(" " + line) }) if err != nil { return fmt.Errorf("scenario %s: set candidate handler: %w", sc.Key, err) } if err := agent.GatherCandidates(); err != nil { //nolint:contextcheck return fmt.Errorf("scenario %s: gather: %w", sc.Key, err) } timeout := sc.timeout() select { case <-ctx.Done(): return fmt.Errorf("scenario %s: context canceled: %w", sc.Key, ctx.Err()) case <-done: case <-time.After(timeout): return fmt.Errorf("scenario %s: gather timed out after %s", sc.Key, timeout) //nolint:err113 } mu.Lock() total := len(gathered) mu.Unlock() if total == 0 { fmt.Println(" (no candidates gathered)") } fmt.Printf("Scenario %s complete: %d 
candidates reported.\n", sc.Key, total) return nil } func formatCandidate(c ice.Candidate) string { network := c.NetworkType().String() if c.NetworkType().IsTCP() && c.TCPType() != ice.TCPTypeUnspecified { network = fmt.Sprintf("%s/%s", network, c.TCPType()) } rel := "none" if relAddr := c.RelatedAddress(); relAddr != nil { rel = fmt.Sprintf("%s:%d", relAddr.Address, relAddr.Port) } return fmt.Sprintf("%s via %s -> %s:%d (rel=%s priority=%d)", c.Type(), network, c.Address(), c.Port(), rel, c.Priority()) } func printRules(rules []ice.AddressRewriteRule) { if len(rules) == 0 { fmt.Println("No NAT rules configured.") return } fmt.Println("Active address rewrite rules:") for idx, rule := range rules { candidateType := rule.AsCandidateType if candidateType == ice.CandidateTypeUnspecified { candidateType = ice.CandidateTypeHost } scope := describeRuleScope(rule) fmt.Printf(" %d) %s => [%s]%s\n", idx+1, candidateType, strings.Join(rule.External, ", "), scope) } } func describeRuleScope(rule ice.AddressRewriteRule) string { parts := make([]string, 0, 4) if rule.Iface != "" { parts = append(parts, fmt.Sprintf("iface=%s", rule.Iface)) } if rule.Local != "" { parts = append(parts, fmt.Sprintf("local=%s", rule.Local)) } if rule.CIDR != "" { parts = append(parts, fmt.Sprintf("cidr=%s", rule.CIDR)) } if len(rule.Networks) > 0 { parts = append(parts, fmt.Sprintf("networks=%s", formatNetworkList(rule.Networks))) } if len(parts) == 0 { return "" } return " | " + strings.Join(parts, " ") } func formatNetworkList(networks []ice.NetworkType) string { if len(networks) == 0 { return "all" } names := make([]string, len(networks)) for i, nt := range networks { names[i] = nt.String() } sort.Strings(names) return strings.Join(names, ",") } func envOrDefault(key, fallback string) string { if val := strings.TrimSpace(os.Getenv(key)); val != "" { return val } return fallback } func printInterfaceSnapshot() { fmt.Println("\nLocal interface snapshot:") ifaces, err := net.Interfaces() if err != 
nil { fmt.Printf(" unable to list interfaces: %v\n\n", err) return } for _, iface := range ifaces { if iface.Flags&net.FlagUp == 0 { continue } addrs, err := iface.Addrs() if err != nil || len(addrs) == 0 { continue } addrTexts := make([]string, 0, len(addrs)) for _, addr := range addrs { addrTexts = append(addrTexts, addr.String()) } fmt.Printf(" %s: %s\n", iface.Name, strings.Join(addrTexts, ", ")) } fmt.Println() } ice-4.1.0/examples/ping-pong/000077500000000000000000000000001511704470200157715ustar00rootroot00000000000000ice-4.1.0/examples/ping-pong/README.md000066400000000000000000000015111511704470200172460ustar00rootroot00000000000000# ping-pong This example demonstrates how to connect two peers via ICE. Once started they send the current time between each other. Currently this example exchanges candidates over a HTTP server running on localhost. In a real world setup `pion/ice` will typically exchange auth and candidates via a signaling server. ## Instruction ### Run controlling ```sh go run main.go -controlling ``` ### Run controlled ```sh go run main.go ``` ### Press enter in both to start the connection! 
You will see terminal output showing the messages being sent back and forth ``` Local Agent is controlled Press 'Enter' when both processes have started ICE Connection State has changed: Checking ICE Connection State has changed: Connected Sent: 'fCFXXlnGmXdYjOy' Received: 'EpqTQYLQMUCjBDX' Sent: 'yhgOtrufSfVmvrR' Received: 'xYSTPxBPZKfgnFr' ``` ice-4.1.0/examples/ping-pong/main.go000066400000000000000000000100231511704470200172400ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package main implements a simple example demonstrating a Pion-to-Pion ICE connection package main import ( "bufio" "context" "flag" "fmt" "net/http" "net/url" "os" "time" "github.com/pion/ice/v4" "github.com/pion/randutil" ) //nolint:gochecknoglobals var ( isControlling bool iceAgent *ice.Agent remoteAuthChannel chan string localHTTPPort, remoteHTTPPort int ) // HTTP Listener to get ICE Credentials from remote Peer. func remoteAuth(_ http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { panic(err) } remoteAuthChannel <- r.PostForm["ufrag"][0] remoteAuthChannel <- r.PostForm["pwd"][0] } // HTTP Listener to get ICE Candidate from remote Peer. 
func remoteCandidate(_ http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { panic(err) } c, err := ice.UnmarshalCandidate(r.PostForm["candidate"][0]) if err != nil { panic(err) } if err := iceAgent.AddRemoteCandidate(c); err != nil { //nolint:contextcheck panic(err) } } func main() { //nolint var ( err error conn *ice.Conn ) remoteAuthChannel = make(chan string, 3) flag.BoolVar(&isControlling, "controlling", false, "is ICE Agent controlling") flag.Parse() if isControlling { localHTTPPort = 9000 remoteHTTPPort = 9001 } else { localHTTPPort = 9001 remoteHTTPPort = 9000 } http.HandleFunc("/remoteAuth", remoteAuth) http.HandleFunc("/remoteCandidate", remoteCandidate) go func() { if err = http.ListenAndServe(fmt.Sprintf(":%d", localHTTPPort), nil); err != nil { //nolint:gosec panic(err) } }() if isControlling { fmt.Println("Local Agent is controlling") } else { fmt.Println("Local Agent is controlled") } fmt.Print("Press 'Enter' when both processes have started") if _, err = bufio.NewReader(os.Stdin).ReadBytes('\n'); err != nil { panic(err) } iceAgent, err = ice.NewAgent(&ice.AgentConfig{ NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4}, }) if err != nil { panic(err) } // When we have gathered a new ICE Candidate send it to the remote peer if err = iceAgent.OnCandidate(func(c ice.Candidate) { if c == nil { return } _, err = http.PostForm(fmt.Sprintf("http://localhost:%d/remoteCandidate", remoteHTTPPort), //nolint url.Values{ "candidate": {c.Marshal()}, }) if err != nil { panic(err) } }); err != nil { panic(err) } // When ICE Connection state has change print to stdout if err = iceAgent.OnConnectionStateChange(func(c ice.ConnectionState) { fmt.Printf("ICE Connection State has changed: %s\n", c.String()) }); err != nil { panic(err) } // Get the local auth details and send to remote peer localUfrag, localPwd, err := iceAgent.GetLocalUserCredentials() if err != nil { panic(err) } _, err = 
http.PostForm(fmt.Sprintf("http://localhost:%d/remoteAuth", remoteHTTPPort), //nolint url.Values{ "ufrag": {localUfrag}, "pwd": {localPwd}, }) if err != nil { panic(err) } remoteUfrag := <-remoteAuthChannel remotePwd := <-remoteAuthChannel if err = iceAgent.GatherCandidates(); err != nil { panic(err) } // Start the ICE Agent. One side must be controlled, and the other must be controlling if isControlling { conn, err = iceAgent.Dial(context.TODO(), remoteUfrag, remotePwd) } else { conn, err = iceAgent.Accept(context.TODO(), remoteUfrag, remotePwd) } if err != nil { panic(err) } // Send messages in a loop to the remote peer go func() { for { time.Sleep(time.Second * 3) val, err := randutil.GenerateCryptoRandomString(15, "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") if err != nil { panic(err) } if _, err = conn.Write([]byte(val)); err != nil { panic(err) } fmt.Printf("Sent: '%s'\n", val) } }() // Receive messages in a loop from the remote peer buf := make([]byte, 1500) for { n, err := conn.Read(buf) if err != nil { panic(err) } fmt.Printf("Received: '%s'\n", string(buf[:n])) } } ice-4.1.0/external_ip_mapper.go000066400000000000000000000255721511704470200164750ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "fmt" "net" "strings" ) // AddressRewriteMode controls whether a rule replaces or appends candidates. type AddressRewriteMode int const ( addressRewriteModeUnspecified AddressRewriteMode = iota AddressRewriteReplace AddressRewriteAppend ) // AddressRewriteRule represents a rule for remapping candidate addresses. type AddressRewriteRule struct { // External are the 1:1 external addresses to advertise for this rule. // For replace mode, an empty list is treated as "drop the matched local // address" (no candidate emitted). For append mode, an empty list is a // no-op: the original candidate is kept. 
// Empty External rules are intentional: // - Mode AddressRewriteReplace drops the matched candidate (deny-list style). // - Mode AddressRewriteAppend keeps the original candidate and adds nothing, // which is useful when you combine a catch-all replace with per-interface // allow rules. External []string // Local optionally pins this rule to a specific local address. When set, // external IPs map to that address regardless of IP family. When empty, // External acts as a catch-all for the family implied by the local scope // (CIDR when set, otherwise the external IP family). Local string // Iface is the optional interface name to limit the rule to, empty = any. Iface string // CIDR is the optional CIDR to limit the rule to, empty = any. CIDR string // AsCandidateType is the candidate type to publish as for this rule. Defaults to host // when unspecified. Supported values: host, server reflexive, relay. AsCandidateType CandidateType // Mode controls whether we replace the original candidate or append extra // candidates. // // If Mode is zero, the default is: // - CandidateTypeHost -> AddressRewriteReplace // - CandidateTypeServerReflexive, CandidateTypeRelay -> AddressRewriteAppend // For replace mode, a match with zero external IPs removes the candidate. // For append mode, a match with zero external IPs leaves the original // candidate untouched. Mode AddressRewriteMode // Networks is the optional networks to limit the rule to, nil/empty = all. Networks []NetworkType } func validateIPString(ipStr string) (net.IP, bool, error) { ip := net.ParseIP(ipStr) if ip == nil { return nil, false, ErrInvalidNAT1To1IPMapping } return ip, (ip.To4() != nil), nil } // ipMapping holds the mapping of local and external IP address // // for a particular IP family. 
type ipMapping struct { ipSole []net.IP // When non-empty, these are the catch-all external IPs for one local IP family ipMap map[string][]net.IP // Local-to-external IP mapping (k: local, v: external IPs) valid bool // If not set any external IP, valid is false catchAllSet bool } func newIPMapping() ipMapping { return ipMapping{ ipMap: make(map[string][]net.IP), } } func (m *ipMapping) addSoleIP(ip net.IP) { m.ipSole = append(m.ipSole, ip) m.valid = true m.catchAllSet = true } func addExternalMappings( external []string, ruleMapping *addressRewriteRuleMapping, hasLocalAddr bool, localAddr net.IP, localIsIPv4 bool, ) (bool, error) { added := false for _, raw := range external { extIPStr := strings.TrimSpace(raw) ipPair := strings.Split(extIPStr, "/") if len(ipPair) != 1 { return false, ErrInvalidNAT1To1IPMapping } extIP, isExtIPv4, err := validateIPString(ipPair[0]) if err != nil { return false, err } targetLocalIPv4 := isExtIPv4 if hasLocalAddr { targetLocalIPv4 = localIsIPv4 } else if ruleMapping.cidr != nil { targetLocalIPv4 = ruleMapping.cidr.IP.To4() != nil } if !ruleMapping.isFamilyAllowed(targetLocalIPv4) { continue } ruleMapping.addImplicitMapping(extIP, targetLocalIPv4, hasLocalAddr, localAddr) added = true } return added, nil } func maybeMarkEmptyMapping( ruleMapping *addressRewriteRuleMapping, added bool, hasLocalAddr bool, localIsIPv4 bool, localAddr net.IP, ) { if added { return } if hasLocalAddr { if ruleMapping.isFamilyAllowed(localIsIPv4) { family := ruleMapping.mappingForFamily(localIsIPv4) family.ipMap[localAddr.String()] = nil family.valid = true } return } if ruleMapping.allowIPv4 { ruleMapping.ipv4Mapping.valid = true ruleMapping.ipv4Mapping.catchAllSet = true } if ruleMapping.allowIPv6 { ruleMapping.ipv6Mapping.valid = true ruleMapping.ipv6Mapping.catchAllSet = true } } func (m *ipMapping) addIPMapping(locIP, extIP net.IP) { locIPStr := locIP.String() m.ipMap[locIPStr] = append(m.ipMap[locIPStr], extIP) m.valid = true } func cloneIPs(src 
[]net.IP) []net.IP { if len(src) == 0 { return nil } cloned := make([]net.IP, 0, len(src)) for _, ip := range src { if ip == nil { continue } copied := make(net.IP, len(ip)) copy(copied, ip) cloned = append(cloned, copied) } return cloned } func (m *ipMapping) findExternalIPs(locIP net.IP) []net.IP { if !m.valid { return nil } if m.ipMap != nil { if extIPs, ok := m.ipMap[locIP.String()]; ok && len(extIPs) > 0 { return cloneIPs(extIPs) } } if len(m.ipSole) > 0 { return cloneIPs(m.ipSole) } return nil } type addressRewriteRuleMapping struct { rule AddressRewriteRule mode AddressRewriteMode ipv4Mapping ipMapping ipv6Mapping ipMapping cidr *net.IPNet allowIPv4 bool allowIPv6 bool } func (m *addressRewriteRuleMapping) hasMappings() bool { return m.ipv4Mapping.valid || m.ipv6Mapping.valid } func (m *addressRewriteRuleMapping) mappingForFamily(isIPv4 bool) *ipMapping { if isIPv4 { return &m.ipv4Mapping } return &m.ipv6Mapping } func (m *addressRewriteRuleMapping) isFamilyAllowed(isLocalIPv4 bool) bool { if isLocalIPv4 { return m.allowIPv4 } return m.allowIPv6 } func (m *addressRewriteRuleMapping) addImplicitMapping( extIP net.IP, isLocalIPv4 bool, hasLocalAddr bool, localAddr net.IP, ) { mapping := m.mappingForFamily(isLocalIPv4) if hasLocalAddr { mapping.addIPMapping(localAddr, extIP) } else { mapping.addSoleIP(extIP) } } type addressRewriteMapper struct { rulesByCandidateType map[CandidateType][]*addressRewriteRuleMapping } //nolint:gocognit,gocyclo,cyclop func newAddressRewriteMapper(rules []AddressRewriteRule) (*addressRewriteMapper, error) { if len(rules) == 0 { return nil, nil //nolint:nilnil } mapper := &addressRewriteMapper{ rulesByCandidateType: make(map[CandidateType][]*addressRewriteRuleMapping), } for _, rule := range rules { candidateType := rule.AsCandidateType if candidateType == CandidateTypeUnspecified { candidateType = CandidateTypeHost } if candidateType == CandidateTypePeerReflexive { return nil, ErrUnsupportedNAT1To1IPCandidateType } mode := rule.Mode 
if mode == addressRewriteModeUnspecified { mode = defaultAddressRewriteMode(candidateType) } ruleMapping := &addressRewriteRuleMapping{ rule: rule, mode: mode, ipv4Mapping: newIPMapping(), ipv6Mapping: newIPMapping(), allowIPv4: true, allowIPv6: true, } if len(rule.Networks) > 0 { ruleMapping.allowIPv4 = false ruleMapping.allowIPv6 = false for _, network := range rule.Networks { if network.IsIPv4() { ruleMapping.allowIPv4 = true } if network.IsIPv6() { ruleMapping.allowIPv6 = true } } if !ruleMapping.allowIPv4 && !ruleMapping.allowIPv6 { continue } } if rule.CIDR != "" { _, ipNet, err := net.ParseCIDR(rule.CIDR) if err != nil { return nil, ErrInvalidNAT1To1IPMapping } ruleMapping.cidr = ipNet } var ( localAddr net.IP localIsIPv4 bool hasLocalAddr bool err error ) if trimmedLocal := strings.TrimSpace(rule.Local); trimmedLocal != "" { localAddr, localIsIPv4, err = validateIPString(trimmedLocal) if err != nil { return nil, err } hasLocalAddr = true if ruleMapping.cidr != nil && !ruleMapping.cidr.Contains(localAddr) { return nil, fmt.Errorf("%w: Invalid local IP is outside CIDR", ErrInvalidNAT1To1IPMapping) } } added, mapErr := addExternalMappings(rule.External, ruleMapping, hasLocalAddr, localAddr, localIsIPv4) if mapErr != nil { return nil, mapErr } maybeMarkEmptyMapping(ruleMapping, added, hasLocalAddr, localIsIPv4, localAddr) if ruleMapping.hasMappings() { mapper.rulesByCandidateType[candidateType] = append(mapper.rulesByCandidateType[candidateType], ruleMapping) } } if len(mapper.rulesByCandidateType) == 0 { return nil, nil //nolint:nilnil } return mapper, nil } func (m *addressRewriteMapper) hasCandidateType(candidateType CandidateType) bool { rules := m.rulesByCandidateType[candidateType] for _, rule := range rules { if rule.hasMappings() { return true } } return false } func (m *addressRewriteMapper) shouldReplace(candidateType CandidateType) bool { for _, rule := range m.rulesByCandidateType[candidateType] { if rule.mode == AddressRewriteReplace { return true 
} } return false } func (m *addressRewriteMapper) findExternalIPs( candidateType CandidateType, localIPStr string, iface string, ) ([]net.IP, bool, AddressRewriteMode, error) { locIP, isLocIPv4, err := validateIPString(localIPStr) if err != nil { return nil, false, addressRewriteModeUnspecified, err } rules := m.rulesByCandidateType[candidateType] ips, matched, mode := evaluateRewriteRules(rules, locIP, isLocIPv4, iface) return ips, matched, mode, nil } func ruleMappingForLookup( rule *addressRewriteRuleMapping, locIP net.IP, isLocIPv4 bool, iface string, ) (*ipMapping, bool) { if rule.rule.Iface != "" && rule.rule.Iface != iface { return nil, false } if rule.cidr != nil && !rule.cidr.Contains(locIP) { return nil, false } ipMapping := rule.mappingForFamily(isLocIPv4) if !ipMapping.valid { return nil, false } return ipMapping, true } func catchAllSpecificity(rule *addressRewriteRuleMapping, iface string) int { spec := 0 if rule.rule.Iface != "" { spec += 2 if rule.cidr != nil { spec++ } } else if iface == "" && rule.cidr != nil { spec++ } return spec } func evaluateRewriteRules( rules []*addressRewriteRuleMapping, locIP net.IP, isLocIPv4 bool, iface string, ) (ips []net.IP, matched bool, mode AddressRewriteMode) { var ( catchAll []net.IP catchAllMode AddressRewriteMode hasCatchAll bool bestSpec = -1 ) for _, rule := range rules { ipMapping, ok := ruleMappingForLookup(rule, locIP, isLocIPv4, iface) if !ok { continue } if explicit, ok := ipMapping.ipMap[locIP.String()]; ok { cloned := cloneIPs(explicit) return cloned, true, rule.mode } if ipMapping.catchAllSet { spec := catchAllSpecificity(rule, iface) if !hasCatchAll || spec > bestSpec { catchAll = cloneIPs(ipMapping.ipSole) catchAllMode = rule.mode hasCatchAll = true bestSpec = spec } } } if hasCatchAll { return catchAll, true, catchAllMode } return nil, false, addressRewriteModeUnspecified } ice-4.1.0/external_ip_mapper_test.go000066400000000000000000000576531511704470200175410ustar00rootroot00000000000000// 
SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "net" "testing" "github.com/stretchr/testify/assert" ) func makeRule(candidateType CandidateType, ips ...string) AddressRewriteRule { return AddressRewriteRule{ External: ips, AsCandidateType: candidateType, } } func makeLocalRule(local string, ips ...string) AddressRewriteRule { return AddressRewriteRule{ External: ips, Local: local, AsCandidateType: CandidateTypeHost, } } func assertExternalIPStrings( t *testing.T, mapper *addressRewriteMapper, candidateType CandidateType, localIP string, iface string, expected ...string, ) { t.Helper() ips, matched, _, err := mapper.findExternalIPs(candidateType, localIP, iface) assert.NoError(t, err) assert.True(t, matched) assert.Len(t, ips, len(expected)) for i, ip := range ips { assert.Equal(t, expected[i], ip.String()) } } func assertNoExternalMapping(t *testing.T, mapper *addressRewriteMapper, candidateType CandidateType, localIP string) { t.Helper() ips, matched, _, err := mapper.findExternalIPs(candidateType, localIP, "") assert.NoError(t, err) assert.False(t, matched) assert.Nil(t, ips) } func TestValidateIPString(t *testing.T) { var ip net.IP var isIPv4 bool var err error ip, isIPv4, err = validateIPString("1.2.3.4") assert.NoError(t, err) assert.True(t, isIPv4) assert.Equal(t, "1.2.3.4", ip.String()) ip, isIPv4, err = validateIPString("2601:4567::5678") assert.NoError(t, err) assert.False(t, isIPv4) assert.Equal(t, "2601:4567::5678", ip.String()) _, _, err = validateIPString("bad.6.6.6") assert.Error(t, err) } //nolint:nlreturn // test fixtures intentionally inline for clarity. 
func TestNewExternalIPMapper(t *testing.T) { t.Run("nil rules", func(t *testing.T) { mapper, err := newAddressRewriteMapper(nil) assert.NoError(t, err) assert.Nil(t, mapper) }) t.Run("empty rules", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{}) assert.NoError(t, err) assert.Nil(t, mapper) }) t.Run("default candidate type", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{makeRule(CandidateTypeUnspecified, "1.2.3.4")}) assert.NoError(t, err) assert.NotNil(t, mapper) assert.True(t, mapper.hasCandidateType(CandidateTypeHost)) assertExternalIPStrings(t, mapper, CandidateTypeHost, "10.0.0.1", "eth0", "1.2.3.4") }) t.Run("server reflexive candidate type", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{makeRule(CandidateTypeServerReflexive, "1.2.3.4")}) assert.NoError(t, err) assert.NotNil(t, mapper) assert.True(t, mapper.hasCandidateType(CandidateTypeServerReflexive)) }) t.Run("unsupported candidate type", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{makeRule(CandidateTypePeerReflexive, "1.2.3.4")}) assert.ErrorIs(t, err, ErrUnsupportedNAT1To1IPCandidateType) assert.Nil(t, mapper) }) cases := []struct { name string rules []AddressRewriteRule expectMapper func(t *testing.T, mapper *addressRewriteMapper) }{ { name: "mixed external families", rules: []AddressRewriteRule{ makeRule(CandidateTypeHost, "1.2.3.4", "2001:db8::1"), }, expectMapper: func(t *testing.T, mapper *addressRewriteMapper) { t.Helper() assert.NotNil(t, mapper) assertExternalIPStrings(t, mapper, CandidateTypeHost, "10.0.0.1", "eth1", "1.2.3.4") assertExternalIPStrings(t, mapper, CandidateTypeHost, "2001:db8::10", "eth0", "2001:db8::1") }, }, { name: "invalid external ip", rules: []AddressRewriteRule{ makeRule(CandidateTypeHost, "bad.2.3.4"), }, }, { name: "explicit mapping via slash rejected", rules: []AddressRewriteRule{ makeRule(CandidateTypeHost, "1.2.3.4/10.0.0.1"), 
}, }, { name: "invalid local ip", rules: []AddressRewriteRule{ makeLocalRule("10.0.0.bad", "1.2.3.4"), }, }, { name: "mixed family pair ipv6 ext ipv4 local", rules: []AddressRewriteRule{ makeLocalRule("10.0.0.1", "2200::1"), }, expectMapper: func(t *testing.T, mapper *addressRewriteMapper) { t.Helper() assertExternalIPStrings(t, mapper, CandidateTypeHost, "10.0.0.1", "hosttest0", "2200::1") }, }, { name: "mixed family pair ipv4 ext ipv6 local", rules: []AddressRewriteRule{ makeLocalRule("fe80::1", "1.2.3.4"), }, expectMapper: func(t *testing.T, mapper *addressRewriteMapper) { t.Helper() assertExternalIPStrings(t, mapper, CandidateTypeHost, "fe80::1", "xdsl0", "1.2.3.4") }, }, { name: "implicit and explicit mix", rules: []AddressRewriteRule{ makeLocalRule("10.0.0.1", "1.2.3.5"), makeRule(CandidateTypeHost, "1.2.3.4"), }, expectMapper: func(t *testing.T, mapper *addressRewriteMapper) { t.Helper() assert.NotNil(t, mapper) assertExternalIPStrings(t, mapper, CandidateTypeHost, "10.0.0.1", "", "1.2.3.5") assertExternalIPStrings(t, mapper, CandidateTypeHost, "10.0.0.2", "", "1.2.3.4") }, }, { name: "invalid pair format", rules: []AddressRewriteRule{ makeRule(CandidateTypeHost, "1.2.3.4/192.168.0.2/10.0.0.1"), }, }, { name: "cidr family mismatch with external", rules: []AddressRewriteRule{ { External: []string{"2001:db8::1"}, AsCandidateType: CandidateTypeHost, CIDR: "10.0.0.0/24", }, }, expectMapper: func(t *testing.T, mapper *addressRewriteMapper) { t.Helper() assertExternalIPStrings(t, mapper, CandidateTypeHost, "10.0.0.5", "", "2001:db8::1") assertNoExternalMapping(t, mapper, CandidateTypeHost, "192.168.0.1") assertNoExternalMapping(t, mapper, CandidateTypeHost, "2001:db8::5") }, }, { name: "invalid cidr explicit mapping", rules: []AddressRewriteRule{ { External: []string{"1.2.3.4"}, Local: "10.0.0.1", AsCandidateType: CandidateTypeHost, CIDR: "192.168.0.0/24", }, }, }, { name: "invalid cidr syntax", rules: []AddressRewriteRule{ { External: []string{"1.2.3.4"}, 
AsCandidateType: CandidateTypeHost, CIDR: "not-a-cidr", }, }, }, } for _, tc := range cases { tc := tc t.Run(tc.name, func(t *testing.T) { mapper, err := newAddressRewriteMapper(tc.rules) if tc.expectMapper != nil { assert.NoError(t, err) tc.expectMapper(t, mapper) } else { assert.ErrorIs(t, err, ErrInvalidNAT1To1IPMapping) assert.Nil(t, mapper) } }) } } func TestFindExternalIPHost(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ makeRule(CandidateTypeHost, "1.2.3.4"), makeRule(CandidateTypeHost, "2200::1"), }) assert.NoError(t, err) assert.NotNil(t, mapper) assertExternalIPStrings(t, mapper, CandidateTypeHost, "10.0.0.1", "", "1.2.3.4") assertExternalIPStrings(t, mapper, CandidateTypeHost, "fe80::1", "", "2200::1") } func TestAddressRewriteIfaceScope(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"203.0.113.10"}, Local: "10.0.0.10", AsCandidateType: CandidateTypeHost, Iface: "eth0", }, }) assert.NoError(t, err) assert.NotNil(t, mapper) ips, matched, _, findErr := mapper.findExternalIPs(CandidateTypeHost, "10.0.0.10", "eth0") assert.NoError(t, findErr) assert.True(t, matched) assert.NotEmpty(t, ips) assert.Equal(t, "203.0.113.10", ips[0].String()) ips, matched, _, findErr = mapper.findExternalIPs(CandidateTypeHost, "10.0.0.10", "wlan0") assert.NoError(t, findErr) assert.False(t, matched) assert.Nil(t, ips) } func TestAddressRewriteRuleOrdering(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"203.0.113.200"}, AsCandidateType: CandidateTypeHost, // catch-all }, { External: []string{"198.51.100.5"}, AsCandidateType: CandidateTypeHost, CIDR: "10.0.0.0/24", }, { External: []string{"198.51.100.6"}, AsCandidateType: CandidateTypeHost, Iface: "eth0", }, }) assert.NoError(t, err) ips, matched, _, findErr := mapper.findExternalIPs(CandidateTypeHost, "10.0.0.5", "") assert.NoError(t, findErr) assert.True(t, matched) assert.Equal(t, "198.51.100.5", 
ips[0].String()) ips, matched, _, findErr = mapper.findExternalIPs(CandidateTypeHost, "10.0.0.6", "eth0") assert.NoError(t, findErr) assert.True(t, matched) assert.Equal(t, "198.51.100.6", ips[0].String()) ips, matched, _, findErr = mapper.findExternalIPs(CandidateTypeHost, "10.0.0.6", "wlan0") assert.NoError(t, findErr) assert.True(t, matched) assert.Equal(t, "203.0.113.200", ips[0].String()) } func TestAddressRewriteModeDefaultsAndExplicit(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ makeRule(CandidateTypeHost, "1.2.3.4"), makeRule(CandidateTypeServerReflexive, "5.6.7.8"), makeRule(CandidateTypeRelay, "203.0.113.44"), { External: []string{"9.9.9.9"}, Local: "10.0.0.5", AsCandidateType: CandidateTypeHost, Mode: AddressRewriteAppend, }, }) assert.NoError(t, err) assert.NotNil(t, mapper) hostRules := mapper.rulesByCandidateType[CandidateTypeHost] assert.Equal(t, AddressRewriteReplace, hostRules[0].mode) assert.Equal(t, AddressRewriteAppend, hostRules[1].mode) srflxRules := mapper.rulesByCandidateType[CandidateTypeServerReflexive] assert.Equal(t, AddressRewriteAppend, srflxRules[0].mode) relayRules := mapper.rulesByCandidateType[CandidateTypeRelay] assert.Equal(t, AddressRewriteAppend, relayRules[0].mode) ips, matched, mode, modeErr := mapper.findExternalIPs(CandidateTypeHost, "10.0.0.5", "") assert.NoError(t, modeErr) assert.True(t, matched) assert.Equal(t, AddressRewriteAppend, mode) assert.Len(t, ips, 1) } func TestAddressRewriteModeDefaultsTable(t *testing.T) { tests := []struct { name string rule AddressRewriteRule localIP string expectMode AddressRewriteMode expectAddress string }{ { name: "host default replace", rule: makeRule(CandidateTypeHost, "203.0.113.10"), localIP: "10.0.0.1", expectMode: AddressRewriteReplace, expectAddress: "203.0.113.10", }, { name: "srflx default append", rule: makeRule(CandidateTypeServerReflexive, "203.0.113.20"), localIP: "0.0.0.0", expectMode: AddressRewriteAppend, expectAddress: "203.0.113.20", }, 
{ name: "relay default append", rule: makeRule(CandidateTypeRelay, "203.0.113.30"), localIP: "192.0.2.1", expectMode: AddressRewriteAppend, expectAddress: "203.0.113.30", }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{tt.rule}) assert.NoError(t, err) ips, matched, mode, findErr := mapper.findExternalIPs(tt.rule.AsCandidateType, tt.localIP, "") assert.NoError(t, findErr) assert.True(t, matched) assert.NotEmpty(t, ips) assert.Equal(t, tt.expectMode, mode) assert.Equal(t, tt.expectAddress, ips[0].String()) }) } } func TestCloneIPsSkipsNil(t *testing.T) { ipv4 := net.ParseIP("203.0.113.10") ipv6 := net.ParseIP("2001:db8::1") cloned := cloneIPs([]net.IP{ipv4, nil, ipv6}) assert.Len(t, cloned, 2) assert.Equal(t, "203.0.113.10", cloned[0].String()) assert.Equal(t, "2001:db8::1", cloned[1].String()) } func TestIPMappingFindExternalIPs(t *testing.T) { t.Run("invalid mapping returns nil", func(t *testing.T) { m := ipMapping{} ips := m.findExternalIPs(net.ParseIP("10.0.0.1")) assert.Nil(t, ips) }) t.Run("catch-all returns sole", func(t *testing.T) { m := ipMapping{ ipSole: []net.IP{net.ParseIP("203.0.113.10")}, valid: true, } ips := m.findExternalIPs(net.ParseIP("10.0.0.1")) assert.Len(t, ips, 1) assert.Equal(t, "203.0.113.10", ips[0].String()) }) t.Run("no mapping found returns error", func(t *testing.T) { m := ipMapping{ valid: true, } ips := m.findExternalIPs(net.ParseIP("10.0.0.1")) assert.Nil(t, ips) }) } func TestAddressRewriteModeHostReplaceAndAppend(t *testing.T) { t.Run("replace host mapping removes original", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"203.0.113.1"}, Local: "10.0.0.1", AsCandidateType: CandidateTypeHost, Mode: AddressRewriteReplace, }, }) assert.NoError(t, err) ips, matched, mode, findErr := mapper.findExternalIPs(CandidateTypeHost, "10.0.0.1", "") assert.NoError(t, findErr) assert.True(t, matched) 
assert.Equal(t, AddressRewriteReplace, mode) assert.Equal(t, []string{"203.0.113.1"}, []string{ips[0].String()}) }) t.Run("append host mapping keeps original and adds new", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"203.0.113.1"}, Local: "10.0.0.1", AsCandidateType: CandidateTypeHost, Mode: AddressRewriteAppend, }, }) assert.NoError(t, err) ips, matched, mode, findErr := mapper.findExternalIPs(CandidateTypeHost, "10.0.0.1", "") assert.NoError(t, findErr) assert.True(t, matched) assert.Equal(t, AddressRewriteAppend, mode) assert.Equal(t, []string{"203.0.113.1"}, []string{ips[0].String()}) }) t.Run("replace host mapping allows cross family", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"203.0.113.20"}, Local: "2001:db8::5", AsCandidateType: CandidateTypeHost, Mode: AddressRewriteReplace, }, }) assert.NoError(t, err) ips, matched, mode, findErr := mapper.findExternalIPs(CandidateTypeHost, "2001:db8::5", "") assert.NoError(t, findErr) assert.True(t, matched) assert.Equal(t, AddressRewriteReplace, mode) assert.Equal(t, []string{"203.0.113.20"}, []string{ips[0].String()}) }) } func TestAddressRewriteModeSrflxReplaceAndAppend(t *testing.T) { t.Run("srflx default append", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"198.51.100.10"}, Local: "0.0.0.0", AsCandidateType: CandidateTypeServerReflexive, }, }) assert.NoError(t, err) ips, matched, mode, findErr := mapper.findExternalIPs(CandidateTypeServerReflexive, "0.0.0.0", "") assert.NoError(t, findErr) assert.True(t, matched) assert.Equal(t, AddressRewriteAppend, mode) assert.Equal(t, []string{"198.51.100.10"}, []string{ips[0].String()}) }) t.Run("srflx explicit replace", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"198.51.100.20"}, Local: "0.0.0.0", AsCandidateType: 
CandidateTypeServerReflexive, Mode: AddressRewriteReplace, }, }) assert.NoError(t, err) ips, matched, mode, findErr := mapper.findExternalIPs(CandidateTypeServerReflexive, "0.0.0.0", "") assert.NoError(t, findErr) assert.True(t, matched) assert.Equal(t, AddressRewriteReplace, mode) assert.Equal(t, []string{"198.51.100.20"}, []string{ips[0].String()}) }) } func TestFindExternalIPMultipleCatchAll(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{makeRule(CandidateTypeHost, "1.2.3.4", "5.6.7.8")}) assert.NoError(t, err) assert.NotNil(t, mapper) ips, matched, _, err := mapper.findExternalIPs(CandidateTypeHost, "10.0.0.1", "") assert.NoError(t, err) assert.True(t, matched) assert.Len(t, ips, 2) assert.Equal(t, "1.2.3.4", ips[0].String()) assert.Equal(t, "5.6.7.8", ips[1].String()) } func TestFindExternalIPCIDRFilter(t *testing.T) { rule := makeRule(CandidateTypeHost, "1.2.3.4") rule.CIDR = "10.0.0.0/24" mapper, err := newAddressRewriteMapper([]AddressRewriteRule{rule}) assert.NoError(t, err) assert.NotNil(t, mapper) assertExternalIPStrings(t, mapper, CandidateTypeHost, "10.0.0.10", "", "1.2.3.4") assertNoExternalMapping(t, mapper, CandidateTypeHost, "192.168.0.1") } func TestFindExternalIPExplicitMapping(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ makeLocalRule("10.0.0.1", "1.2.3.4"), makeLocalRule("10.0.0.2", "1.2.3.5"), makeLocalRule("fe80::1", "2200::1"), makeLocalRule("fe80::2", "2200::2"), }) assert.NoError(t, err) assert.NotNil(t, mapper) assertExternalIPStrings(t, mapper, CandidateTypeHost, "10.0.0.1", "", "1.2.3.4") assertExternalIPStrings(t, mapper, CandidateTypeHost, "10.0.0.2", "", "1.2.3.5") assertNoExternalMapping(t, mapper, CandidateTypeHost, "10.0.0.3") assertExternalIPStrings(t, mapper, CandidateTypeHost, "fe80::1", "", "2200::1") assertExternalIPStrings(t, mapper, CandidateTypeHost, "fe80::2", "", "2200::2") assertNoExternalMapping(t, mapper, CandidateTypeHost, "fe80::3") } func 
TestFindExternalIPServerReflexive(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{makeRule(CandidateTypeServerReflexive, "1.2.3.4")}) assert.NoError(t, err) assert.NotNil(t, mapper) assertExternalIPStrings(t, mapper, CandidateTypeServerReflexive, "0.0.0.0", "", "1.2.3.4") } func TestFindExternalIPFallbackAndErrors(t *testing.T) { t.Run("fallback to local address when candidate type missing", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{makeRule(CandidateTypeServerReflexive, "1.2.3.4")}) assert.NoError(t, err) assert.NotNil(t, mapper) assertNoExternalMapping(t, mapper, CandidateTypeHost, "10.0.0.1") }) t.Run("invalid local ip", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{makeRule(CandidateTypeHost, "1.2.3.4")}) assert.NoError(t, err) assert.NotNil(t, mapper) _, _, _, err = mapper.findExternalIPs(CandidateTypeHost, "really.bad", "") assert.Error(t, err) }) t.Run("append with zero externals returns error but keeps original upstream", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: nil, Local: "10.0.0.11", AsCandidateType: CandidateTypeHost, Mode: AddressRewriteAppend, }, }) assert.NoError(t, err) assert.NotNil(t, mapper) _, matched, mode, err := mapper.findExternalIPs(CandidateTypeHost, "10.0.0.11", "") assert.True(t, matched) assert.Equal(t, AddressRewriteAppend, mode) assert.NoError(t, err) }) } func TestExternalIPMapperNetworksFilter(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"203.0.113.2"}, Local: "10.0.0.2", AsCandidateType: CandidateTypeHost, Networks: []NetworkType{NetworkTypeUDP4}, }, }) assert.NoError(t, err) assert.NotNil(t, mapper) assertExternalIPStrings(t, mapper, CandidateTypeHost, "10.0.0.2", "", "203.0.113.2") assertNoExternalMapping(t, mapper, CandidateTypeHost, "2001:db8:1::1") mapper, err = newAddressRewriteMapper([]AddressRewriteRule{ { External: 
[]string{"2001:db8::6"}, Local: "2001:db8:2::6", AsCandidateType: CandidateTypeServerReflexive, Networks: []NetworkType{NetworkTypeUDP6}, }, }) assert.NoError(t, err) assert.NotNil(t, mapper) assertExternalIPStrings(t, mapper, CandidateTypeServerReflexive, "2001:db8:2::6", "", "2001:db8::6") assertNoExternalMapping(t, mapper, CandidateTypeServerReflexive, "192.0.2.10") t.Run("nil and empty networks are equivalent", func(t *testing.T) { nilNetworks, nilErr := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"203.0.113.5"}, Local: "10.0.0.5", AsCandidateType: CandidateTypeHost, }, }) assert.NoError(t, nilErr) assert.NotNil(t, nilNetworks) assertExternalIPStrings(t, nilNetworks, CandidateTypeHost, "10.0.0.5", "", "203.0.113.5") emptyNetworks, emptyErr := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"203.0.113.5"}, Local: "10.0.0.5", AsCandidateType: CandidateTypeHost, Networks: []NetworkType{}, }, }) assert.NoError(t, emptyErr) assert.NotNil(t, emptyNetworks) assertExternalIPStrings(t, emptyNetworks, CandidateTypeHost, "10.0.0.5", "", "203.0.113.5") }) t.Run("nil/empty allow ipv6 while udp4 excludes it", func(t *testing.T) { nilMapper, nilErr := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"2001:db8::50"}, Local: "2001:db8::50", AsCandidateType: CandidateTypeHost, }, }) assert.NoError(t, nilErr) assert.NotNil(t, nilMapper) assertExternalIPStrings(t, nilMapper, CandidateTypeHost, "2001:db8::50", "", "2001:db8::50") emptyMapper, emptyErr := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"2001:db8::50"}, Local: "2001:db8::50", AsCandidateType: CandidateTypeHost, Networks: []NetworkType{}, }, }) assert.NoError(t, emptyErr) assert.NotNil(t, emptyMapper) assertExternalIPStrings(t, emptyMapper, CandidateTypeHost, "2001:db8::50", "", "2001:db8::50") udp4Mapper, udp4Err := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"2001:db8::50"}, Local: "2001:db8::50", AsCandidateType: 
CandidateTypeHost, Networks: []NetworkType{NetworkTypeUDP4}, }, }) assert.NoError(t, udp4Err) assert.Nil(t, udp4Mapper) }) t.Run("mixed family rule respects each address family", func(t *testing.T) { mixedMapper, mixedErr := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"203.0.113.99"}, AsCandidateType: CandidateTypeHost, }, { External: []string{"2001:db8::99"}, Local: "2001:db8:1::99", AsCandidateType: CandidateTypeHost, }, }) assert.NoError(t, mixedErr) assert.NotNil(t, mixedMapper) assertExternalIPStrings(t, mixedMapper, CandidateTypeHost, "10.10.10.10", "", "203.0.113.99") assertExternalIPStrings(t, mixedMapper, CandidateTypeHost, "2001:db8:1::99", "", "2001:db8::99") }) } func TestAddressRewritePrecedenceMatrix(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"203.0.113.200"}, AsCandidateType: CandidateTypeHost, CIDR: "10.0.0.0/24", }, { External: []string{"198.51.100.200"}, AsCandidateType: CandidateTypeHost, }, { External: []string{"192.0.2.50"}, Local: "10.0.0.50", AsCandidateType: CandidateTypeHost, }, }) assert.NoError(t, err) assert.NotNil(t, mapper) assertExternalIPStrings(t, mapper, CandidateTypeHost, "10.0.0.50", "", "192.0.2.50") assertExternalIPStrings(t, mapper, CandidateTypeHost, "10.0.0.25", "", "203.0.113.200") assertExternalIPStrings(t, mapper, CandidateTypeHost, "172.16.0.1", "", "198.51.100.200") } func TestExternalIPMapperRuleOrderAndSpecificity(t *testing.T) { t.Run("earliest matching rule wins", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"203.0.113.10"}, Local: "10.0.0.5", AsCandidateType: CandidateTypeHost, }, { External: []string{"198.51.100.10"}, Local: "10.0.0.5", AsCandidateType: CandidateTypeHost, }, }) assert.NoError(t, err) assert.NotNil(t, mapper) assertExternalIPStrings(t, mapper, CandidateTypeHost, "10.0.0.5", "", "203.0.113.10") }) t.Run("specific mapping outranks cidr and catch-all", func(t 
*testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"203.0.113.30"}, Local: "10.0.0.5", AsCandidateType: CandidateTypeHost, CIDR: "10.0.0.0/24", }, { External: []string{"203.0.113.40"}, AsCandidateType: CandidateTypeHost, CIDR: "10.0.0.0/24", }, }) assert.NoError(t, err) assert.NotNil(t, mapper) assertExternalIPStrings(t, mapper, CandidateTypeHost, "10.0.0.5", "", "203.0.113.30") assertExternalIPStrings(t, mapper, CandidateTypeHost, "10.0.0.10", "", "203.0.113.40") assertNoExternalMapping(t, mapper, CandidateTypeHost, "192.0.2.20") }) } ice-4.1.0/gather.go000066400000000000000000001052711511704470200140640ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "context" "crypto/tls" "fmt" "io" "net" "net/netip" "reflect" "sync" "time" "github.com/pion/dtls/v3" "github.com/pion/ice/v4/internal/fakenet" stunx "github.com/pion/ice/v4/internal/stun" "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3/stdnet" "github.com/pion/turn/v4" ) type turnClient interface { Listen() error Allocate() (net.PacketConn, error) Close() } func defaultTurnClient(cfg *turn.ClientConfig) (turnClient, error) { return turn.NewClient(cfg) } // Close a net.Conn and log if we have a failure. func closeConnAndLog(c io.Closer, log logging.LeveledLogger, msg string, args ...any) { if c == nil || (reflect.ValueOf(c).Kind() == reflect.Ptr && reflect.ValueOf(c).IsNil()) { log.Warnf("Connection is not allocated: "+msg, args...) return } log.Warnf(msg, args...) if err := c.Close(); err != nil { log.Warnf("Failed to close connection: %v", err) } } // GatherCandidates initiates the trickle based gathering process. 
func (a *Agent) GatherCandidates() error { var gatherErr error if runErr := a.loop.Run(a.loop, func(ctx context.Context) { if a.gatheringState != GatheringStateNew { gatherErr = ErrMultipleGatherAttempted return } else if a.onCandidateHdlr.Load() == nil { gatherErr = ErrNoOnCandidateHandler return } a.gatherCandidateCancel() // Cancel previous gathering routine ctx, cancel := context.WithCancel(ctx) a.gatherCandidateCancel = cancel done := make(chan struct{}) a.gatherCandidateDone = done go a.gatherCandidates(ctx, done) }); runErr != nil { return runErr } return gatherErr } func (a *Agent) gatherCandidates(ctx context.Context, done chan struct{}) { //nolint:cyclop defer close(done) if err := a.setGatheringState(GatheringStateGathering); err != nil { //nolint:contextcheck a.log.Warnf("Failed to set gatheringState to GatheringStateGathering: %v", err) return } a.gatherCandidatesInternal(ctx) switch a.continualGatheringPolicy { case GatherOnce: if err := a.setGatheringState(GatheringStateComplete); err != nil { //nolint:contextcheck a.log.Warnf("Failed to set gatheringState to GatheringStateComplete: %v", err) } case GatherContinually: // Initialize known interfaces before starting monitoring _, addrs, err := localInterfaces( a.net, a.interfaceFilter, a.ipFilter, a.networkTypes, a.includeLoopback, ) if err != nil { a.log.Warnf("Failed to get initial interfaces for monitoring: %v", err) } else { for _, info := range addrs { a.lastKnownInterfaces[info.addr.String()] = info.addr } a.log.Infof("Initialized network monitoring with %d IP addresses", len(addrs)) } go a.startNetworkMonitoring(ctx) } } func (a *Agent) shouldRewriteCandidateType(candidateType CandidateType) bool { return a.addressRewriteMapper != nil && a.addressRewriteMapper.hasCandidateType(candidateType) } func (a *Agent) shouldRewriteHostCandidates() bool { return a.mDNSMode != MulticastDNSModeQueryAndGather && a.shouldRewriteCandidateType(CandidateTypeHost) } func (a *Agent) applyHostAddressRewrite(addr 
netip.Addr, mappedAddrs []netip.Addr, iface string) ([]netip.Addr, bool) { mappedIPs, matched, mode, innerErr := a.addressRewriteMapper.findExternalIPs( CandidateTypeHost, addr.String(), iface, ) if innerErr != nil { a.log.Warnf("Address rewrite mapping is enabled but no external IP is found for %s", addr.String()) return mappedAddrs, true } if !matched { return mappedAddrs, true } if mode == AddressRewriteReplace { mappedAddrs = mappedAddrs[:0] } mappedAddrs = appendHostMappedAddrs(mappedAddrs, mappedIPs, addr, a.log) if len(mappedAddrs) == 0 && mode == AddressRewriteReplace { a.log.Warnf("Address rewrite mapping is enabled but produced no usable external IP for %s", addr.String()) return mappedAddrs, false } return mappedAddrs, true } func appendHostMappedAddrs( mappedAddrs []netip.Addr, mappedIPs []net.IP, addr netip.Addr, log logging.LeveledLogger, ) []netip.Addr { for _, mappedIP := range mappedIPs { conv, ok := netip.AddrFromSlice(mappedIP) if !ok { log.Warnf("failed to convert mapped external IP to netip.Addr'%s'", addr.String()) continue } // we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable mappedAddrs = append(mappedAddrs, conv.Unmap()) } return mappedAddrs } func (a *Agent) applyHostRewriteForUDPMux(candidateIPs []net.IP, udpAddr *net.UDPAddr) ([]net.IP, bool) { mappedIPs, matched, mode, err := a.addressRewriteMapper.findExternalIPs(CandidateTypeHost, udpAddr.IP.String(), "") if err != nil { a.log.Warnf("Address rewrite mapping is enabled but failed for %s: %v", udpAddr.IP.String(), err) return candidateIPs, false } if !matched { return candidateIPs, true } if len(mappedIPs) == 0 { if mode == AddressRewriteReplace { return candidateIPs, false } return candidateIPs, true } if mode == AddressRewriteReplace { candidateIPs = candidateIPs[:0] } return append(candidateIPs, mappedIPs...), true } // gatherCandidatesInternal performs the actual candidate gathering for all configured types. 
func (a *Agent) gatherCandidatesInternal(ctx context.Context) { var wg sync.WaitGroup for _, t := range a.candidateTypes { switch t { case CandidateTypeHost: wg.Add(1) go func() { a.gatherCandidatesLocal(ctx, a.networkTypes) wg.Done() }() case CandidateTypeServerReflexive: a.gatherServerReflexiveCandidates(ctx, &wg) case CandidateTypeRelay: wg.Add(1) go func() { a.gatherCandidatesRelay(ctx, a.urls) wg.Done() }() case CandidateTypePeerReflexive, CandidateTypeUnspecified: } } // Block until all STUN and TURN URLs have been gathered (or timed out) wg.Wait() } func (a *Agent) gatherServerReflexiveCandidates(ctx context.Context, wg *sync.WaitGroup) { replaceSrflx := a.addressRewriteMapper != nil && a.addressRewriteMapper.shouldReplace(CandidateTypeServerReflexive) if !replaceSrflx { wg.Add(1) go func() { if a.udpMuxSrflx != nil { a.gatherCandidatesSrflxUDPMux(ctx, a.urls, a.networkTypes) } else { a.gatherCandidatesSrflx(ctx, a.urls, a.networkTypes) } wg.Done() }() } if a.addressRewriteMapper != nil && a.addressRewriteMapper.hasCandidateType(CandidateTypeServerReflexive) { wg.Add(1) go func() { a.gatherCandidatesSrflxMapped(ctx, a.networkTypes) wg.Done() }() } } //nolint:gocognit,gocyclo,cyclop,maintidx func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []NetworkType) { networks := map[string]struct{}{} for _, networkType := range networkTypes { if networkType.IsTCP() { networks[tcp] = struct{}{} } else { networks[udp] = struct{}{} } } // When UDPMux is enabled, skip other UDP candidates if a.udpMux != nil { if err := a.gatherCandidatesLocalUDPMux(ctx); err != nil { a.log.Warnf("Failed to create host candidate for UDPMux: %s", err) } delete(networks, udp) } _, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback) if err != nil { a.log.Warnf("Failed to iterate local interfaces, host candidates will not be gathered %s", err) return } for _, info := range localAddrs { addr := info.addr ifaceName := 
info.iface mappedAddrs := []netip.Addr{addr} if a.shouldRewriteHostCandidates() { var ok bool mappedAddrs, ok = a.applyHostAddressRewrite(addr, mappedAddrs, ifaceName) if !ok { continue } } for mappedIdx, mappedIP := range mappedAddrs { address := mappedIP.String() var isLocationTracked bool if a.mDNSMode == MulticastDNSModeQueryAndGather { address = a.mDNSName } else { // Here, we are not doing multicast gathering, so we will need to skip this address so // that we don't accidentally reveal location tracking information. Otherwise, the // case above hides the IP behind an mDNS address. isLocationTracked = shouldFilterLocationTrackedIP(mappedIP) } for network := range networks { // TCPMux maintains a single listener per interface. Avoid duplicating passive TCP candidates // for additional mapped IPs until connection sharing is supported. if network == tcp && mappedIdx > 0 { continue } type connAndPort struct { conn net.PacketConn port int } var ( conns []connAndPort tcpType TCPType ) switch network { case tcp: if a.tcpMux == nil { continue } // Only advertise TCP candidates for addresses that the mux listener is actually // bound to. When the listener is bound to a specific IP, exposing other interface // addresses would generate unreachable passive candidates and can stall active // TCP connect attempts. 
if addrProvider, ok := a.tcpMux.(interface{ LocalAddr() net.Addr }); ok { if muxAddr, ok := addrProvider.LocalAddr().(*net.TCPAddr); ok { if ip := muxAddr.IP; ip != nil && !ip.IsUnspecified() && !ip.Equal(addr.AsSlice()) { continue } } } // Handle ICE TCP passive mode var muxConns []net.PacketConn if multi, ok := a.tcpMux.(AllConnsGetter); ok { a.log.Debugf("GetAllConns by ufrag: %s", a.localUfrag) // Note: this is missing zone for IPv6 by just grabbing the IP slice muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.Is6(), addr.AsSlice()) if err != nil { a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, addr, a.localUfrag) continue } } else { a.log.Debugf("GetConn by ufrag: %s", a.localUfrag) // Note: this is missing zone for IPv6 by just grabbing the IP slice conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.Is6(), addr.AsSlice()) if err != nil { a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, addr, a.localUfrag) continue } muxConns = []net.PacketConn{conn} } // Extract the port for each PacketConn we got. for _, conn := range muxConns { if tcpConn, ok := conn.LocalAddr().(*net.TCPAddr); ok { conns = append(conns, connAndPort{conn, tcpConn.Port}) } else { a.log.Warnf("Failed to get port of connection from TCPMux: %s %s %s", network, addr, a.localUfrag) } } if len(conns) == 0 { // Didn't succeed with any, try the next network. continue } tcpType = TCPTypePassive // Is there a way to verify that the listen address is even // accessible from the current interface. 
case udp: conn, err := listenUDPInPortRange(a.net, a.log, int(a.portMax), int(a.portMin), network, &net.UDPAddr{ IP: addr.AsSlice(), Port: 0, Zone: addr.Zone(), }) if err != nil { a.log.Warnf("Failed to listen %s %s", network, addr) continue } if udpConn, ok := conn.LocalAddr().(*net.UDPAddr); ok { conns = append(conns, connAndPort{conn, udpConn.Port}) } else { a.log.Warnf("Failed to get port of UDPAddr from ListenUDPInPortRange: %s %s %s", network, addr, a.localUfrag) continue } } for _, connAndPort := range conns { hostConfig := CandidateHostConfig{ Network: network, Address: address, Port: connAndPort.port, Component: ComponentRTP, TCPType: tcpType, // we will still process this candidate so that we start up the right // listeners. IsLocationTracked: isLocationTracked, } candidateHost, err := NewCandidateHost(&hostConfig) if err == nil && a.mDNSMode == MulticastDNSModeQueryAndGather { err = candidateHost.setIPAddr(addr) } if err != nil { closeConnAndLog( connAndPort.conn, a.log, "failed to create host candidate: %s %s %d: %v", network, mappedIP, connAndPort.port, err, ) continue } if err := a.addCandidate(ctx, candidateHost, connAndPort.conn); err != nil { if closeErr := candidateHost.close(); closeErr != nil { a.log.Warnf("Failed to close candidate: %v", closeErr) } a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err) } } } } } } // shouldFilterLocationTrackedIP returns if this candidate IP should be filtered out from // any candidate publishing/notification for location tracking reasons. func shouldFilterLocationTrackedIP(candidateIP netip.Addr) bool { // https://tools.ietf.org/html/rfc8445#section-5.1.1.1 // Similarly, when host candidates corresponding to // an IPv6 address generated using a mechanism that prevents location // tracking are gathered, then host candidates corresponding to IPv6 // link-local addresses [RFC4291] MUST NOT be gathered. 
	// Only IPv6 is affected: link-local addresses would leak the physical
	// network segment even when the routable address is privacy-preserving.
	return candidateIP.Is6() && (candidateIP.IsLinkLocalUnicast() || candidateIP.IsLinkLocalMulticast())
}

// shouldFilterLocationTracked returns if this candidate IP should be filtered out from
// any candidate publishing/notification for location tracking reasons.
func shouldFilterLocationTracked(candidateIP net.IP) bool {
	addr, ok := netip.AddrFromSlice(candidateIP)
	if !ok {
		// Not a 4- or 16-byte IP; nothing meaningful to filter on.
		return false
	}

	return shouldFilterLocationTrackedIP(addr)
}

// gatherCandidatesLocalUDPMux creates one host candidate per listen address of
// the configured UDP mux, honoring loopback inclusion, host rewriting and mDNS
// obfuscation. It returns an error only for mux misconfiguration or when a mux
// connection cannot be obtained; per-candidate failures are logged and skipped.
func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolint:gocognit,cyclop
	if a.udpMux == nil {
		return errUDPMuxDisabled
	}

	localAddresses := a.udpMux.GetListenAddresses()
	existingConfigs := make(map[CandidateHostConfig]struct{})

	for _, addr := range localAddresses {
		udpAddr, ok := addr.(*net.UDPAddr)
		if !ok {
			return errInvalidAddress
		}

		candidateIPs := []net.IP{udpAddr.IP}

		if _, ok := a.udpMux.(*UDPMuxDefault); ok && !a.includeLoopback && udpAddr.IP.IsLoopback() {
			// Unlike MultiUDPMux Default, UDPMuxDefault doesn't have
			// a separate param to include loopback, so we respect agent config
			continue
		}

		if a.shouldRewriteHostCandidates() {
			var ok bool
			candidateIPs, ok = a.applyHostRewriteForUDPMux(candidateIPs, udpAddr)
			if !ok {
				continue
			}
		}

		for _, candidateIP := range candidateIPs {
			var address string
			var isLocationTracked bool
			if a.mDNSMode == MulticastDNSModeQueryAndGather {
				// mDNS mode advertises the agent's mDNS name instead of the IP.
				address = a.mDNSName
			} else {
				address = candidateIP.String()
				// Here, we are not doing multicast gathering, so we will need to skip this address so
				// that we don't accidentally reveal location tracking information. Otherwise, the
				// case above hides the IP behind an mDNS address.
				isLocationTracked = shouldFilterLocationTracked(candidateIP)
			}

			hostConfig := CandidateHostConfig{
				Network:           udp,
				Address:           address,
				Port:              udpAddr.Port,
				Component:         ComponentRTP,
				IsLocationTracked: isLocationTracked,
			}

			// Detect a duplicate candidate before calling addCandidate().
			// otherwise, addCandidate() detects the duplicate candidate
			// and close its connection, invalidating all candidates
			// that share the same connection.
			if _, ok := existingConfigs[hostConfig]; ok {
				continue
			}

			conn, err := a.udpMux.GetConn(a.localUfrag, udpAddr)
			if err != nil {
				return err
			}

			c, err := NewCandidateHost(&hostConfig)
			if err != nil {
				closeConnAndLog(conn, a.log, "failed to create host mux candidate: %s %d: %v", candidateIP, udpAddr.Port, err)

				continue
			}

			if err := a.addCandidate(ctx, c, conn); err != nil {
				if closeErr := c.close(); closeErr != nil {
					a.log.Warnf("Failed to close candidate: %v", closeErr)
				}
				closeConnAndLog(conn, a.log, "failed to add candidate: %s %d: %v", candidateIP, udpAddr.Port, err)

				continue
			}

			existingConfigs[hostConfig] = struct{}{}
		}
	}

	return nil
}

// gatherCandidatesSrflxMapped gathers server-reflexive candidates from the
// statically configured address-rewrite mapping (no STUN round trip): one UDP
// socket is bound per UDP network type and the mapped external IPs are
// advertised with the local socket as the related address.
func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes []NetworkType) { //nolint:gocognit,cyclop
	var wg sync.WaitGroup
	defer wg.Wait()

	_, ifaces, _ := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback)

	for _, networkType := range networkTypes {
		if networkType.IsTCP() {
			continue
		}

		network := networkType.String()
		wg.Add(1)

		go func() {
			defer wg.Done()

			conn, err := listenUDPInPortRange(
				a.net, a.log, int(a.portMax), int(a.portMin), network,
				&net.UDPAddr{IP: nil, Port: 0},
			)
			if err != nil {
				a.log.Warnf("Failed to listen %s: %v", network, err)

				return
			}

			lAddr, ok := conn.LocalAddr().(*net.UDPAddr)
			if !ok {
				closeConnAndLog(conn, a.log, "Address rewrite mapping is enabled but LocalAddr is not a UDPAddr")

				return
			}

			addresses, ok := a.resolveSrflxAddresses(lAddr.IP, findIfaceForIP(ifaces, lAddr.IP))
			if !ok {
				closeConnAndLog(
					conn,
					a.log,
					"Address rewrite mapping did not provide usable external IPs for %s",
					lAddr.IP.String(),
				)

				return
			}

			for idx, mappedIP := range addresses {
				// The first mapped IP reuses the socket bound above; every
				// additional mapping gets its own socket.
				currentConn := conn
				currentAddr := lAddr
				if idx > 0 {
					newConn, listenErr := listenUDPInPortRange(
						a.net, a.log, int(a.portMax), int(a.portMin), network,
						&net.UDPAddr{IP:
lAddr.IP, Port: 0}, ) if listenErr != nil { closeConnAndLog(newConn, a.log, "Failed to listen %s for additional srflx mapping: %v", network, listenErr) return } currentConn = newConn var ok bool currentAddr, ok = currentConn.LocalAddr().(*net.UDPAddr) if !ok { closeConnAndLog(currentConn, a.log, "Address rewrite mapping is enabled but LocalAddr is not a UDPAddr") return } } if shouldFilterLocationTracked(mappedIP) { closeConnAndLog(currentConn, a.log, "external IP is somehow filtered for location tracking reasons %s", mappedIP) continue } srflxConfig := CandidateServerReflexiveConfig{ Network: network, Address: mappedIP.String(), Port: currentAddr.Port, Component: ComponentRTP, RelAddr: currentAddr.IP.String(), RelPort: currentAddr.Port, } c, err := NewCandidateServerReflexive(&srflxConfig) if err != nil { closeConnAndLog(currentConn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, mappedIP.String(), currentAddr.Port, err) continue } if err := a.addCandidate(ctx, c, currentConn); err != nil { if closeErr := c.close(); closeErr != nil { a.log.Warnf("Failed to close candidate: %v", closeErr) } a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err) closeConnAndLog( currentConn, a.log, "closing srflx conn after addCandidate failure: %v", err, ) } } }() } } //nolint:gocognit,cyclop func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) { var wg sync.WaitGroup defer wg.Wait() for _, networkType := range networkTypes { if networkType.IsTCP() { continue } for i := range urls { for _, listenAddr := range a.udpMuxSrflx.GetListenAddresses() { udpAddr, ok := listenAddr.(*net.UDPAddr) if !ok { a.log.Warn("Failed to cast udpMuxSrflx listen address to UDPAddr") continue } wg.Add(1) go func(url stun.URI, network string, localAddr *net.UDPAddr) { defer wg.Done() hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port) serverAddr, err := a.net.ResolveUDPAddr(network, 
						hostPort)
					if err != nil {
						a.log.Debugf("Failed to resolve STUN host: %s %s: %v", network, hostPort, err)

						return
					}

					if shouldFilterLocationTracked(serverAddr.IP) {
						a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort)

						return
					}

					// Learn our reflexive address by querying the STUN server
					// through the mux.
					xorAddr, err := a.udpMuxSrflx.GetXORMappedAddr(serverAddr, a.stunGatherTimeout)
					if err != nil {
						a.log.Warnf("Failed get server reflexive address %s %s: %v", network, url, err)

						return
					}

					conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String(), localAddr)
					if err != nil {
						a.log.Warnf("Failed to find connection in UDPMuxSrflx %s %s: %v", network, url, err)

						return
					}

					ip := xorAddr.IP
					port := xorAddr.Port

					srflxConfig := CandidateServerReflexiveConfig{
						Network:   network,
						Address:   ip.String(),
						Port:      port,
						Component: ComponentRTP,
						RelAddr:   localAddr.IP.String(),
						RelPort:   localAddr.Port,
					}
					c, err := NewCandidateServerReflexive(&srflxConfig)
					if err != nil {
						closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err)

						return
					}

					if err := a.addCandidate(ctx, c, conn); err != nil {
						if closeErr := c.close(); closeErr != nil {
							a.log.Warnf("Failed to close candidate: %v", closeErr)
						}
						a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err)
					}
				}(*urls[i], networkType.String(), udpAddr)
			}
		}
	}
}

// gatherCandidatesSrflx gathers server-reflexive candidates the classic way:
// a fresh UDP socket per (URL, network type), one concurrent STUN query each.
//nolint:cyclop,gocognit
func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) {
	var wg sync.WaitGroup
	defer wg.Wait()

	for _, networkType := range networkTypes {
		if networkType.IsTCP() {
			continue
		}

		for i := range urls {
			wg.Add(1)
			go func(url stun.URI, network string) {
				defer wg.Done()

				hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port)
				serverAddr, err := a.net.ResolveUDPAddr(network, hostPort)
				if err != nil {
					a.log.Debugf("Failed to resolve STUN host: %s %s: %v", network, hostPort, err)

					return
				}

				if shouldFilterLocationTracked(serverAddr.IP) {
					a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort)

					return
				}

				conn, err := listenUDPInPortRange(
					a.net, a.log, int(a.portMax), int(a.portMin), network,
					&net.UDPAddr{IP: nil, Port: 0},
				)
				if err != nil {
					closeConnAndLog(conn, a.log, "failed to listen for %s: %v", serverAddr.String(), err)

					return
				}

				// If the agent closes midway through the connection
				// we end it early to prevent close delay.
				cancelCtx, cancelFunc := context.WithCancel(ctx)
				defer cancelFunc()
				go func() {
					select {
					case <-cancelCtx.Done():
						return
					case <-a.loop.Done():
						_ = conn.Close()
					}
				}()

				xorAddr, err := stunx.GetXORMappedAddr(conn, serverAddr, a.stunGatherTimeout)
				if err != nil {
					closeConnAndLog(conn, a.log, "failed to get server reflexive address %s %s: %v", network, url, err)

					return
				}

				ip := xorAddr.IP
				port := xorAddr.Port

				lAddr := conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert
				srflxConfig := CandidateServerReflexiveConfig{
					Network:   network,
					Address:   ip.String(),
					Port:      port,
					Component: ComponentRTP,
					RelAddr:   lAddr.IP.String(),
					RelPort:   lAddr.Port,
				}
				c, err := NewCandidateServerReflexive(&srflxConfig)
				if err != nil {
					closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err)

					return
				}

				if err := a.addCandidate(ctx, c, conn); err != nil {
					if closeErr := c.close(); closeErr != nil {
						a.log.Warnf("Failed to close candidate: %v", closeErr)
					}
					a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err)
				}
			}(*urls[i], networkType.String())
		}
	}
}

// gatherCandidatesRelay allocates a TURN relay for every TURN/TURNS URL (one
// goroutine per URL) and turns each successful allocation into relay
// candidate(s) via addRelayCandidates.
//nolint:maintidx,gocognit,gocyclo,cyclop
func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) {
	var wg sync.WaitGroup
	defer wg.Wait()

	network := NetworkTypeUDP4.String()

	_, ifaces, _ := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, a.networkTypes, a.includeLoopback)

	for i := range urls {
		switch {
		case urls[i].Scheme != stun.SchemeTypeTURN && urls[i].Scheme != stun.SchemeTypeTURNS:
			continue
		case urls[i].Username == "":
			a.log.Errorf("Failed to gather relay candidates: %v",
ErrUsernameEmpty) return case urls[i].Password == "": a.log.Errorf("Failed to gather relay candidates: %v", ErrPasswordEmpty) return } wg.Add(1) go func(url stun.URI) { defer wg.Done() turnServerAddr := fmt.Sprintf("%s:%d", url.Host, url.Port) var ( locConn net.PacketConn err error relAddr string relPort int relayProtocol string ) switch { case url.Proto == stun.ProtoTypeUDP && url.Scheme == stun.SchemeTypeTURN: if locConn, err = a.net.ListenPacket(network, "0.0.0.0:0"); err != nil { a.log.Warnf("Failed to listen %s: %v", network, err) return } relAddr = locConn.LocalAddr().(*net.UDPAddr).IP.String() //nolint:forcetypeassert relPort = locConn.LocalAddr().(*net.UDPAddr).Port //nolint:forcetypeassert relayProtocol = udp case a.proxyDialer != nil && url.Proto == stun.ProtoTypeTCP && (url.Scheme == stun.SchemeTypeTURN || url.Scheme == stun.SchemeTypeTURNS): conn, connectErr := a.proxyDialer.Dial(NetworkTypeTCP4.String(), turnServerAddr) if connectErr != nil { a.log.Warnf("Failed to dial TCP address %s via proxy dialer: %v", turnServerAddr, connectErr) return } relAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert if url.Scheme == stun.SchemeTypeTURN { relayProtocol = tcp } else if url.Scheme == stun.SchemeTypeTURNS { relayProtocol = "tls" } locConn = turn.NewSTUNConn(conn) case url.Proto == stun.ProtoTypeTCP && url.Scheme == stun.SchemeTypeTURN: tcpAddr, connectErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr) if connectErr != nil { a.log.Warnf("Failed to resolve TCP address %s: %v", turnServerAddr, connectErr) return } conn, connectErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr) if connectErr != nil { a.log.Warnf("Failed to dial TCP address %s: %v", turnServerAddr, connectErr) return } relAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert 
relayProtocol = tcp locConn = turn.NewSTUNConn(conn) case url.Proto == stun.ProtoTypeUDP && url.Scheme == stun.SchemeTypeTURNS: udpAddr, connectErr := a.net.ResolveUDPAddr(network, turnServerAddr) if connectErr != nil { a.log.Warnf("Failed to resolve UDP address %s: %v", turnServerAddr, connectErr) return } udpConn, dialErr := a.net.DialUDP("udp", nil, udpAddr) if dialErr != nil { a.log.Warnf("Failed to dial DTLS address %s: %v", turnServerAddr, dialErr) return } conn, connectErr := dtls.Client(&fakenet.PacketConn{Conn: udpConn}, udpConn.RemoteAddr(), &dtls.Config{ ServerName: url.Host, InsecureSkipVerify: a.insecureSkipVerify, //nolint:gosec LoggerFactory: a.loggerFactory, }) if connectErr != nil { a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr) return } if connectErr = conn.HandshakeContext(ctx); connectErr != nil { a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr) return } relAddr = conn.LocalAddr().(*net.UDPAddr).IP.String() //nolint:forcetypeassert relPort = conn.LocalAddr().(*net.UDPAddr).Port //nolint:forcetypeassert relayProtocol = relayProtocolDTLS locConn = &fakenet.PacketConn{Conn: conn} case url.Proto == stun.ProtoTypeTCP && url.Scheme == stun.SchemeTypeTURNS: tcpAddr, resolvErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr) if resolvErr != nil { a.log.Warnf("Failed to resolve relay address %s: %v", turnServerAddr, resolvErr) return } tcpConn, dialErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr) if dialErr != nil { a.log.Warnf("Failed to connect to relay: %v", dialErr) return } conn := tls.Client(tcpConn, &tls.Config{ ServerName: url.Host, InsecureSkipVerify: a.insecureSkipVerify, //nolint:gosec }) if hsErr := conn.HandshakeContext(ctx); hsErr != nil { if closeErr := tcpConn.Close(); closeErr != nil { a.log.Errorf("Failed to close relay connection: %v", closeErr) } a.log.Warnf("Failed to connect to relay: %v", hsErr) return } relAddr = 
conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert relayProtocol = relayProtocolTLS locConn = turn.NewSTUNConn(conn) default: a.log.Warnf("Unable to handle URL in gatherCandidatesRelay %v", url) return } factory := a.turnClientFactory if factory == nil { factory = defaultTurnClient } client, err := factory(&turn.ClientConfig{ TURNServerAddr: turnServerAddr, Conn: locConn, Username: url.Username, Password: url.Password, LoggerFactory: a.loggerFactory, Net: a.net, }) if err != nil { closeConnAndLog(locConn, a.log, "failed to create new TURN client %s %s", turnServerAddr, err) return } if err = client.Listen(); err != nil { client.Close() closeConnAndLog(locConn, a.log, "failed to listen on TURN client %s %s", turnServerAddr, err) return } relayConn, err := client.Allocate() if err != nil { client.Close() closeConnAndLog(locConn, a.log, "failed to allocate on TURN client %s %s", turnServerAddr, err) return } rAddr := relayConn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert if shouldFilterLocationTracked(rAddr.IP) { a.log.Warnf("TURN address %s is somehow filtered for location tracking reasons", rAddr.IP) return } a.addRelayCandidates(ctx, relayEndpoint{ network: network, address: rAddr.IP, port: rAddr.Port, relAddr: relAddr, relPort: relPort, iface: findIfaceForIP(ifaces, net.ParseIP(relAddr)), protocol: relayProtocol, conn: relayConn, onClose: func() error { client.Close() return locConn.Close() }, closeConn: func() { if relayConErr := relayConn.Close(); relayConErr != nil { a.log.Warnf("Failed to close relay %v", relayConErr) } }, }) }(*urls[i]) } } type relayEndpoint struct { network string address net.IP port int relAddr string relPort int protocol string iface string conn net.PacketConn onClose func() error closeConn func() } func (a *Agent) resolveRelayAddresses(ep relayEndpoint) ([]net.IP, bool) { addresses := []net.IP{ep.address} if 
	!a.shouldRewriteCandidateType(CandidateTypeRelay) {
		return addresses, true
	}

	mappedIPs, matched, mode, err := a.addressRewriteMapper.findExternalIPs(
		CandidateTypeRelay,
		ep.relAddr,
		ep.iface,
	)
	if err != nil {
		return nil, false
	}

	if !matched {
		return addresses, true
	}

	if len(mappedIPs) == 0 {
		if mode == AddressRewriteReplace {
			a.log.Warnf("Address rewrite mapping returned no external relay addresses for %s", ep.relAddr)

			return nil, false
		}

		return addresses, true
	}

	if mode == AddressRewriteReplace {
		return mappedIPs, true
	}

	// Non-replace mode: advertise the mapped IPs in addition to the original.
	return append(addresses, mappedIPs...), true
}

// resolveSrflxAddresses applies the address-rewrite mapping to a local srflx
// socket IP, returning the IPs to advertise and whether gathering may proceed.
func (a *Agent) resolveSrflxAddresses(localIP net.IP, iface string) ([]net.IP, bool) {
	addresses := []net.IP{localIP}
	if !a.shouldRewriteCandidateType(CandidateTypeServerReflexive) {
		return addresses, true
	}

	mappedIPs, matched, mode, err := a.addressRewriteMapper.findExternalIPs(
		CandidateTypeServerReflexive,
		localIP.String(),
		iface,
	)
	if err != nil {
		a.log.Warnf("Address rewrite mapping is enabled but no external IP is found for %s: %v", localIP.String(), err)

		return nil, false
	}

	if !matched {
		return addresses, true
	}

	if len(mappedIPs) == 0 {
		if mode == AddressRewriteReplace {
			return nil, false
		}

		return addresses, true
	}

	if mode == AddressRewriteReplace {
		return mappedIPs, true
	}

	// NOTE(review): this branch is identical to the replace branch above,
	// whereas resolveRelayAddresses appends mappedIPs to the local address in
	// non-replace mode — confirm whether srflx should do the same.
	return mappedIPs, true
}

// findIfaceForIP returns the name of the local interface whose address equals
// ip, or "" when ip is nil or no interface matches.
func findIfaceForIP(ifaces []ifaceAddr, ip net.IP) string {
	if ip == nil {
		return ""
	}

	for _, info := range ifaces {
		if info.addr.String() == ip.String() {
			return info.iface
		}
	}

	return ""
}

// createRelayCandidate builds one relay candidate for ip from the endpoint's
// allocation data and registers it with the agent; failures are logged and
// returned so the caller can decide whether to tear down the connection.
func (a *Agent) createRelayCandidate(ctx context.Context, ep relayEndpoint, ip net.IP, onClose func() error) error {
	relayConfig := CandidateRelayConfig{
		Network:       ep.network,
		Component:     ComponentRTP,
		Address:       ip.String(),
		Port:          ep.port,
		RelAddr:       ep.relAddr,
		RelPort:       ep.relPort,
		RelayProtocol: ep.protocol,
		OnClose:       onClose,
	}

	candidate, err := NewCandidateRelay(&relayConfig)
	if err != nil {
		a.log.Warnf("failed to create relay candidate: %s %d: %v", ip, ep.port, err)

		return err
	}

	if err := a.addCandidate(ctx, candidate, ep.conn); err != nil {
		if closeErr := candidate.close(); closeErr != nil {
			a.log.Warnf("Failed to close candidate: %v", closeErr)
		}
		a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err)

		return err
	}

	return nil
}

// addRelayCandidates expands one relay endpoint into candidates for every
// resolved address. Only the first candidate owns the close callbacks; a
// failure on the first candidate closes the relay connection outright, later
// failures are merely logged.
func (a *Agent) addRelayCandidates(ctx context.Context, ep relayEndpoint) {
	if ep.conn == nil || ep.address == nil {
		return
	}

	addresses, ok := a.resolveRelayAddresses(ep)
	if !ok {
		return
	}

	for idx, ip := range addresses {
		onClose := ep.onClose
		if idx > 0 {
			// Only the first candidate may close the shared connection.
			onClose = nil
		}

		if err := a.createRelayCandidate(ctx, ep, ip, onClose); err != nil {
			if idx == 0 {
				if ep.closeConn != nil {
					ep.closeConn()
				}

				return
			}

			a.log.Warnf("failed to create additional relay candidate for %s: %v", ip, err)

			continue
		}
	}
}

// startNetworkMonitoring starts a goroutine that periodically checks for network changes
// and re-gathers candidates when changes are detected. This is only used with GatherContinually policy.
func (a *Agent) startNetworkMonitoring(ctx context.Context) {
	ticker := time.NewTicker(a.networkMonitorInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if a.detectNetworkChanges() {
				a.gatherCandidatesInternal(ctx)
			}
		}
	}
}

// detectNetworkChanges checks if the network interfaces have changed since the last check.
func (a *Agent) detectNetworkChanges() bool { // Try to refresh interfaces if using stdnet if stdNet, ok := a.net.(*stdnet.Net); ok { if err := stdNet.UpdateInterfaces(); err != nil { a.log.Warnf("Failed to update interfaces: %v", err) } } _, currentAddrs, err := localInterfaces( a.net, a.interfaceFilter, a.ipFilter, a.networkTypes, a.includeLoopback, ) if err != nil { a.log.Warnf("Failed to get local interfaces during network monitoring: %v", err) return false } currentInterfaces := make(map[string]netip.Addr) for _, info := range currentAddrs { key := info.addr.String() currentInterfaces[key] = info.addr } hasAdditions := false for key, addr := range currentInterfaces { if _, exists := a.lastKnownInterfaces[key]; !exists { a.log.Infof("New IP address detected: %s", addr) hasAdditions = true } } a.lastKnownInterfaces = currentInterfaces return hasAdditions } ice-4.1.0/gather_test.go000066400000000000000000003000011511704470200151070ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "context" "crypto/tls" "errors" "fmt" "io" "net" "net/netip" "net/url" "runtime" "sort" "strconv" "strings" "sync" "sync/atomic" "syscall" "testing" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/pkg/crypto/selfsign" "github.com/pion/ice/v4/internal/taskloop" "github.com/pion/logging" "github.com/pion/stun/v3" transport "github.com/pion/transport/v3" "github.com/pion/transport/v3/test" "github.com/pion/transport/v3/vnet" "github.com/pion/turn/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/proxy" ) func skipOnPermission(t *testing.T, err error, action string) { t.Helper() if err == nil { return } if errors.Is(err, syscall.EPERM) || errors.Is(err, syscall.EACCES) || strings.Contains(err.Error(), "permission denied") || strings.Contains(err.Error(), "operation not permitted") { t.Skipf("skipping %s: %v", action, err) } } 
func TestListenUDP(t *testing.T) { agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() _, localAddrs, err := localInterfaces( agent.net, agent.interfaceFilter, agent.ipFilter, []NetworkType{NetworkTypeUDP4}, false, ) require.NotEqual(t, len(localAddrs), 0, "localInterfaces found no interfaces, unable to test") require.NoError(t, err) ip := localAddrs[0].addr.AsSlice() conn, err := listenUDPInPortRange(agent.net, agent.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0}) require.NoError(t, err, "listenUDP error with no port restriction") require.NotNil(t, conn, "listenUDP error with no port restriction return a nil conn") _, err = listenUDPInPortRange(agent.net, agent.log, 4999, 5000, udp, &net.UDPAddr{IP: ip, Port: 0}) require.Equal(t, err, ErrPort, "listenUDP with invalid port range did not return ErrPort") conn, err = listenUDPInPortRange(agent.net, agent.log, 5000, 5000, udp, &net.UDPAddr{IP: ip, Port: 0}) require.NoError(t, err, "listenUDP error with no port restriction") require.NotNil(t, conn, "listenUDP error with no port restriction return a nil conn") _, port, err := net.SplitHostPort(conn.LocalAddr().String()) require.NoError(t, err) require.Equal(t, port, "5000", "listenUDP with port restriction of 5000 listened on incorrect port") portMin := 5100 portMax := 5109 total := portMax - portMin + 1 result := make([]int, 0, total) portRange := make([]int, 0, total) for i := 0; i < total; i++ { conn, err = listenUDPInPortRange(agent.net, agent.log, portMax, portMin, udp, &net.UDPAddr{IP: ip, Port: 0}) require.NoError(t, err, "listenUDP error with no port restriction") require.NotNil(t, conn, "listenUDP error with no port restriction return a nil conn") _, port, err = net.SplitHostPort(conn.LocalAddr().String()) require.NoError(t, err) p, _ := strconv.Atoi(port) require.False(t, p < portMin || p > portMax) result = append(result, p) portRange = append(portRange, portMin+i) } require.False(t, 
sort.IntsAreSorted(result)) sort.Ints(result) require.Equal(t, result, portRange) _, err = listenUDPInPortRange(agent.net, agent.log, portMax, portMin, udp, &net.UDPAddr{IP: ip, Port: 0}) require.Equal(t, err, ErrPort, "listenUDP with port restriction [%d, %d], did not return ErrPort", portMin, portMax) } func TestGatherConcurrency(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 30).Stop() agent, err := NewAgent(&AgentConfig{ NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, IncludeLoopback: true, }) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) require.NoError(t, agent.OnCandidate(func(Candidate) { candidateGatheredFunc() })) // Testing for panic for i := 0; i < 10; i++ { _ = agent.GatherCandidates() } <-candidateGathered.Done() } func TestLoopbackCandidate(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 30).Stop() type testCase struct { name string agentConfig *AgentConfig loExpected bool } mux, err := NewMultiUDPMuxFromPort(12500) require.NoError(t, err) muxWithLo, errlo := NewMultiUDPMuxFromPort(12501, UDPMuxFromPortWithLoopback()) require.NoError(t, errlo) unspecConn, errconn := net.ListenPacket("udp", ":0") // nolint: noctx require.NoError(t, errconn) defer func() { _ = unspecConn.Close() }() muxUnspecDefault := NewUDPMuxDefault(UDPMuxParams{ UDPConn: unspecConn, }) testCases := []testCase{ { name: "mux should not have loopback candidate", agentConfig: &AgentConfig{ NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, UDPMux: mux, }, loExpected: false, }, { name: "mux with loopback should not have loopback candidate", agentConfig: &AgentConfig{ NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, UDPMux: muxWithLo, }, loExpected: true, }, { name: "UDPMuxDefault with unspecified IP should not have loopback candidate", agentConfig: &AgentConfig{ 
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, UDPMux: muxUnspecDefault, }, loExpected: false, }, { name: "UDPMuxDefault with unspecified IP should respect agent includeloopback", agentConfig: &AgentConfig{ NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, UDPMux: muxUnspecDefault, IncludeLoopback: true, }, loExpected: true, }, { name: "includeloopback enabled", agentConfig: &AgentConfig{ NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, IncludeLoopback: true, }, loExpected: true, }, { name: "includeloopback disabled", agentConfig: &AgentConfig{ NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, IncludeLoopback: false, }, loExpected: false, }, } for _, tc := range testCases { tcase := tc t.Run(tcase.name, func(t *testing.T) { agent, err := NewAgent(tc.agentConfig) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) var loopback int32 require.NoError(t, agent.OnCandidate(func(c Candidate) { if c != nil { if net.ParseIP(c.Address()).IsLoopback() { atomic.StoreInt32(&loopback, 1) } } else { candidateGatheredFunc() return } t.Log(c.NetworkType(), c.Priority(), c) })) require.NoError(t, agent.GatherCandidates()) <-candidateGathered.Done() require.Equal(t, tcase.loExpected, atomic.LoadInt32(&loopback) == 1) }) } require.NoError(t, mux.Close()) require.NoError(t, muxWithLo.Close()) require.NoError(t, muxUnspecDefault.Close()) } // Assert that STUN gathering is done concurrently. 
func TestSTUNConcurrency(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 30).Stop() serverPort := randomPort(t) serverListener, err := net.ListenPacket("udp4", localhostIPStr+":"+strconv.Itoa(serverPort)) // nolint: noctx require.NoError(t, err) server, err := turn.NewServer(turn.ServerConfig{ Realm: "pion.ly", AuthHandler: optimisticAuthHandler, PacketConnConfigs: []turn.PacketConnConfig{ { PacketConn: serverListener, RelayAddressGenerator: &turn.RelayAddressGeneratorNone{Address: localhostIPStr}, }, }, }) require.NoError(t, err) defer func() { require.NoError(t, server.Close()) }() urls := []*stun.URI{} for i := 0; i <= 10; i++ { urls = append(urls, &stun.URI{ Scheme: stun.SchemeTypeSTUN, Host: localhostIPStr, Port: serverPort + 1, }) } urls = append(urls, &stun.URI{ Scheme: stun.SchemeTypeSTUN, Host: localhostIPStr, Port: serverPort, }) listener, err := net.ListenTCP("tcp", &net.TCPAddr{ IP: net.IP{127, 0, 0, 1}, }) require.NoError(t, err) defer func() { _ = listener.Close() }() tcpMux := NewTCPMuxDefault( TCPMuxParams{ Listener: listener, Logger: logging.NewDefaultLoggerFactory().NewLogger("ice"), ReadBufferSize: 8, }, ) defer func() { _ = tcpMux.Close() }() agent, err := NewAgent(&AgentConfig{ NetworkTypes: supportedNetworkTypes(), Urls: urls, CandidateTypes: []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive}, TCPMux: tcpMux, }) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) require.NoError(t, agent.OnCandidate(func(c Candidate) { if c == nil { candidateGatheredFunc() return } t.Log(c.NetworkType(), c.Priority(), c) })) require.NoError(t, agent.GatherCandidates()) <-candidateGathered.Done() } // Assert that TURN gathering is done concurrently. 
func TestTURNConcurrency(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	// runTest spins up a TURN server on the given transport, then gathers
	// relay candidates against many unreachable URLs plus one live URL.
	runTest := func(
		protocol stun.ProtoType,
		scheme stun.SchemeType,
		packetConn net.PacketConn,
		listener net.Listener,
		serverPort int,
	) {
		packetConnConfigs := []turn.PacketConnConfig{}
		if packetConn != nil {
			packetConnConfigs = append(packetConnConfigs, turn.PacketConnConfig{
				PacketConn:            packetConn,
				RelayAddressGenerator: &turn.RelayAddressGeneratorNone{Address: localhostIPStr},
			})
		}

		listenerConfigs := []turn.ListenerConfig{}
		if listener != nil {
			listenerConfigs = append(listenerConfigs, turn.ListenerConfig{
				Listener:              listener,
				RelayAddressGenerator: &turn.RelayAddressGeneratorNone{Address: localhostIPStr},
			})
		}

		server, err := turn.NewServer(turn.ServerConfig{
			Realm:             "pion.ly",
			AuthHandler:       optimisticAuthHandler,
			PacketConnConfigs: packetConnConfigs,
			ListenerConfigs:   listenerConfigs,
		})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, server.Close())
		}()

		urls := []*stun.URI{}
		// avoid long delay on unreachable ports on Windows
		if runtime.GOOS != "windows" {
			// Ports serverPort+1..serverPort+11 have no listener; a serialized
			// gatherer would stall on them before reaching the live URL below.
			for i := 0; i <= 10; i++ {
				urls = append(urls, &stun.URI{
					Scheme:   scheme,
					Host:     localhostIPStr,
					Username: "username",
					Password: "password",
					Proto:    protocol,
					Port:     serverPort + 1 + i,
				})
			}
		}
		urls = append(urls, &stun.URI{
			Scheme:   scheme,
			Host:     localhostIPStr,
			Username: "username",
			Password: "password",
			Proto:    protocol,
			Port:     serverPort,
		})

		agent, err := NewAgent(&AgentConfig{
			CandidateTypes:     []CandidateType{CandidateTypeRelay},
			InsecureSkipVerify: true,
			NetworkTypes:       supportedNetworkTypes(),
			Urls:               urls,
		})
		require.NoError(t, err)
		defer func() {
			require.NoError(t, agent.Close())
		}()

		candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
		require.NoError(t, agent.OnCandidate(func(c Candidate) {
			if c != nil {
				candidateGatheredFunc()
			}
		}))
		require.NoError(t, agent.GatherCandidates())

		<-candidateGathered.Done()
	}

	t.Run("UDP Relay", func(t *testing.T) {
		serverPort := randomPort(t)
		serverListener, err := net.ListenPacket("udp", localhostIPStr+":"+strconv.Itoa(serverPort)) // nolint: noctx
		require.NoError(t, err)

		runTest(stun.ProtoTypeUDP, stun.SchemeTypeTURN, serverListener, nil, serverPort)
	})

	t.Run("TCP Relay", func(t *testing.T) {
		serverPort := randomPort(t)
		serverListener, err := net.Listen("tcp", localhostIPStr+":"+strconv.Itoa(serverPort)) // nolint: noctx
		require.NoError(t, err)

		runTest(stun.ProtoTypeTCP, stun.SchemeTypeTURN, nil, serverListener, serverPort)
	})

	t.Run("TLS Relay", func(t *testing.T) {
		certificate, genErr := selfsign.GenerateSelfSigned()
		require.NoError(t, genErr)

		serverPort := randomPort(t)
		serverListener, err := tls.Listen("tcp", localhostIPStr+":"+strconv.Itoa(serverPort), &tls.Config{ //nolint:gosec
			Certificates: []tls.Certificate{certificate},
		})
		require.NoError(t, err)

		runTest(stun.ProtoTypeTCP, stun.SchemeTypeTURNS, nil, serverListener, serverPort)
	})

	t.Run("DTLS Relay", func(t *testing.T) {
		certificate, genErr := selfsign.GenerateSelfSigned()
		require.NoError(t, genErr)

		serverPort := randomPort(t)
		serverListener, err := dtls.Listen(
			"udp",
			&net.UDPAddr{IP: net.ParseIP(localhostIPStr), Port: serverPort},
			&dtls.Config{
				Certificates: []tls.Certificate{certificate},
			},
		)
		require.NoError(t, err)

		runTest(stun.ProtoTypeUDP, stun.SchemeTypeTURNS, nil, serverListener, serverPort)
	})
}

// Assert that STUN and TURN gathering are done concurrently.
func TestSTUNTURNConcurrency(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 8).Stop() serverPort := randomPort(t) serverListener, err := net.ListenPacket("udp4", localhostIPStr+":"+strconv.Itoa(serverPort)) // nolint: noctx require.NoError(t, err) server, err := turn.NewServer(turn.ServerConfig{ Realm: "pion.ly", AuthHandler: optimisticAuthHandler, PacketConnConfigs: []turn.PacketConnConfig{ { PacketConn: serverListener, RelayAddressGenerator: &turn.RelayAddressGeneratorNone{Address: localhostIPStr}, }, }, }) require.NoError(t, err) defer func() { require.NoError(t, server.Close()) }() urls := []*stun.URI{} for i := 0; i <= 10; i++ { urls = append(urls, &stun.URI{ Scheme: stun.SchemeTypeSTUN, Host: localhostIPStr, Port: serverPort + 1, }) } urls = append(urls, &stun.URI{ Scheme: stun.SchemeTypeTURN, Proto: stun.ProtoTypeUDP, Host: localhostIPStr, Port: serverPort, Username: "username", Password: "password", }) agent, err := NewAgent(&AgentConfig{ NetworkTypes: supportedNetworkTypes(), Urls: urls, CandidateTypes: []CandidateType{CandidateTypeServerReflexive, CandidateTypeRelay}, }) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() { // As TURN and STUN should be checked in parallel, this should complete before the default STUN timeout (5s) gatherLim := test.TimeOut(time.Second * 3) candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) require.NoError(t, agent.OnCandidate(func(c Candidate) { if c != nil { candidateGatheredFunc() } })) require.NoError(t, agent.GatherCandidates()) <-candidateGathered.Done() gatherLim.Stop() } } // Assert that srflx candidates can be gathered from TURN servers // // When TURN servers are utilized, both types of candidates // (i.e. srflx and relay) are obtained from the TURN server. 
//
// https://tools.ietf.org/html/rfc5245#section-2.1
func TestTURNSrflx(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	// Local TURN server used as both the srflx and relay source.
	serverPort := randomPort(t)
	serverListener, err := net.ListenPacket("udp4", localhostIPStr+":"+strconv.Itoa(serverPort)) // nolint: noctx
	require.NoError(t, err)

	server, err := turn.NewServer(turn.ServerConfig{
		Realm:       "pion.ly",
		AuthHandler: optimisticAuthHandler,
		PacketConnConfigs: []turn.PacketConnConfig{
			{
				PacketConn:            serverListener,
				RelayAddressGenerator: &turn.RelayAddressGeneratorNone{Address: localhostIPStr},
			},
		},
	})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, server.Close())
	}()

	urls := []*stun.URI{{
		Scheme:   stun.SchemeTypeTURN,
		Proto:    stun.ProtoTypeUDP,
		Host:     localhostIPStr,
		Port:     serverPort,
		Username: "username",
		Password: "password",
	}}

	agent, err := NewAgent(&AgentConfig{
		NetworkTypes:   supportedNetworkTypes(),
		Urls:           urls,
		CandidateTypes: []CandidateType{CandidateTypeServerReflexive, CandidateTypeRelay},
	})
	require.NoError(t, err)
	defer func() {
		require.NoError(t, agent.Close())
	}()

	// Only a srflx candidate unblocks the test; relay candidates are ignored.
	candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
	require.NoError(t, agent.OnCandidate(func(c Candidate) {
		if c != nil && c.Type() == CandidateTypeServerReflexive {
			candidateGatheredFunc()
		}
	}))
	require.NoError(t, agent.GatherCandidates())

	<-candidateGathered.Done()
}

// TestGatherCandidatesRelayProducesRelay asserts that, when a relay-only agent
// gathers against a local TURN server, every produced candidate is a UDP relay
// candidate. Skips (rather than fails) on permission errors and on
// environments where no relay candidate could be gathered.
func TestGatherCandidatesRelayProducesRelay(t *testing.T) {
	defer test.CheckRoutines(t)()

	listener, err := net.ListenPacket("udp4", "127.0.0.1:0") // nolint: noctx
	skipOnPermission(t, err, "listening for TURN server")
	require.NoError(t, err)
	defer func() { _ = listener.Close() }()

	server, err := turn.NewServer(turn.ServerConfig{
		Realm:       "pion.ly",
		AuthHandler: optimisticAuthHandler,
		PacketConnConfigs: []turn.PacketConnConfig{
			{
				PacketConn:            listener,
				RelayAddressGenerator: &turn.RelayAddressGeneratorNone{Address: "127.0.0.1"},
			},
		},
	})
	require.NoError(t, err)
	defer func() { require.NoError(t, server.Close()) }()

	serverPort := listener.LocalAddr().(*net.UDPAddr).Port //nolint:forcetypeassert
	turnURL := &stun.URI{
		Scheme:   stun.SchemeTypeTURN,
		Host:     "127.0.0.1",
		Port:     serverPort,
		Username: "username",
		Password: "password",
		Proto:    stun.ProtoTypeUDP,
	}

	agent, err := NewAgent(&AgentConfig{
		NetworkTypes:   []NetworkType{NetworkTypeUDP4},
		CandidateTypes: []CandidateType{CandidateTypeRelay},
		Urls:           []*stun.URI{turnURL},
	})
	skipOnPermission(t, err, "creating relay agent")
	require.NoError(t, err)
	defer func() { _ = agent.Close() }()

	// Collect relay candidates under a mutex; nil candidate signals end of gathering.
	var (
		mu       sync.Mutex
		relays   []Candidate
		gathered = make(chan struct{})
	)
	require.NoError(t, agent.OnCandidate(func(c Candidate) {
		if c == nil {
			close(gathered)

			return
		}
		if c.Type() == CandidateTypeRelay {
			mu.Lock()
			relays = append(relays, c)
			mu.Unlock()
		}
	}))

	require.NoError(t, agent.GatherCandidates())

	select {
	case <-gathered:
	case <-time.After(5 * time.Second):
		require.FailNow(t, "gatherCandidatesRelay did not finish before timeout")
	}

	mu.Lock()
	defer mu.Unlock()
	if len(relays) == 0 {
		t.Skip("no relay candidates gathered in this environment")
	}
	for _, r := range relays {
		require.Equal(t, CandidateTypeRelay, r.Type())
		require.True(t, r.NetworkType().IsUDP())
	}
}

// relayGatherNet is a stub network implementation for relay-gathering tests.
// It serves stub packet conns from a fixed address, reports a single
// "relaytest0" interface, and rejects every other transport operation.
type relayGatherNet struct {
	addr *net.UDPAddr // address handed to every stub packet conn
}

// newRelayGatherNet builds a relayGatherNet, defaulting to 10.0.0.1 when addr is nil.
func newRelayGatherNet(addr *net.UDPAddr) *relayGatherNet {
	if addr == nil {
		addr = &net.UDPAddr{IP: net.IPv4(10, 0, 0, 1)}
	}

	return &relayGatherNet{addr: addr}
}

// ListenPacket returns a stub conn bound to the fixed address.
func (n *relayGatherNet) ListenPacket(string, string) (net.PacketConn, error) {
	return newStubPacketConn(n.addr), nil
}

func (n *relayGatherNet) ListenUDP(string, *net.UDPAddr) (transport.UDPConn, error) {
	return nil, transport.ErrNotSupported
}

func (n *relayGatherNet) ListenTCP(string, *net.TCPAddr) (transport.TCPListener, error) {
	return nil, transport.ErrNotSupported
}

func (n *relayGatherNet) Dial(string, string) (net.Conn, error) {
	return nil, transport.ErrNotSupported
}

func (n *relayGatherNet) DialUDP(string, *net.UDPAddr, *net.UDPAddr) (transport.UDPConn, error) {
	return nil, transport.ErrNotSupported
}

func (n *relayGatherNet) DialTCP(string, *net.TCPAddr, *net.TCPAddr) (transport.TCPConn, error) {
	return nil, transport.ErrNotSupported
}

// Resolve* delegate to the real net package.
func (n *relayGatherNet) ResolveIPAddr(network, address string) (*net.IPAddr, error) {
	return net.ResolveIPAddr(network, address)
}

func (n *relayGatherNet) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
	return net.ResolveUDPAddr(network, address)
}

func (n *relayGatherNet) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
	return net.ResolveTCPAddr(network, address)
}

// Interfaces reports one synthetic up interface carrying the stub address.
func (n *relayGatherNet) Interfaces() ([]*transport.Interface, error) {
	iface := transport.NewInterface(net.Interface{
		Index: 1,
		MTU:   1500,
		Name:  "relaytest0",
		Flags: net.FlagUp,
	})
	iface.AddAddress(&net.IPNet{IP: n.addr.IP, Mask: net.CIDRMask(24, 32)})

	return []*transport.Interface{iface}, nil
}

func (n *relayGatherNet) InterfaceByIndex(index int) (*transport.Interface, error) {
	ifaces, err := n.Interfaces()
	if err != nil {
		return nil, err
	}
	for _, iface := range ifaces {
		if iface.Index == index {
			return iface, nil
		}
	}

	return nil, transport.ErrInterfaceNotFound
}

func (n *relayGatherNet) InterfaceByName(name string) (*transport.Interface, error) {
	ifaces, err := n.Interfaces()
	if err != nil {
		return nil, err
	}
	for _, iface := range ifaces {
		if iface.Name == name {
			return iface, nil
		}
	}

	return nil, transport.ErrInterfaceNotFound
}

func (n *relayGatherNet) CreateDialer(*net.Dialer) transport.Dialer { return nil }

// hostGatherNet is a stub network for host-gathering tests: stub packet conns,
// real UDP listening (defaulting to the fixed address), one "hosttest0"
// interface, everything else unsupported.
type hostGatherNet struct {
	addr *net.UDPAddr // default local address
}

// newHostGatherNet builds a hostGatherNet, defaulting to 127.0.0.1 when addr is nil.
func newHostGatherNet(addr *net.UDPAddr) *hostGatherNet {
	if addr == nil {
		addr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}
	}

	return &hostGatherNet{addr: addr}
}

func (n *hostGatherNet) ListenPacket(string, string) (net.PacketConn, error) {
	return newStubPacketConn(n.addr), nil
}

// ListenUDP performs a real UDP listen, substituting the stub address when laddr is nil.
func (n *hostGatherNet) ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
	if laddr == nil {
		laddr = n.addr
	}

	return net.ListenUDP(network, laddr) //nolint:wrapcheck
}

func (n *hostGatherNet) ListenTCP(string, *net.TCPAddr) (transport.TCPListener, error) {
	return nil, transport.ErrNotSupported
}

func (n *hostGatherNet) Dial(string, string) (net.Conn, error) {
	return nil, transport.ErrNotSupported
}

func (n *hostGatherNet) DialUDP(string, *net.UDPAddr, *net.UDPAddr) (transport.UDPConn, error) {
	return nil, transport.ErrNotSupported
}

func (n *hostGatherNet) DialTCP(string, *net.TCPAddr, *net.TCPAddr) (transport.TCPConn, error) {
	return nil, transport.ErrNotSupported
}

func (n *hostGatherNet) ResolveIPAddr(network, address string) (*net.IPAddr, error) {
	return net.ResolveIPAddr(network, address)
}

func (n *hostGatherNet) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
	return net.ResolveUDPAddr(network, address)
}

func (n *hostGatherNet) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
	return net.ResolveTCPAddr(network, address)
}

// Interfaces reports one synthetic up interface carrying the stub address.
func (n *hostGatherNet) Interfaces() ([]*transport.Interface, error) {
	iface := transport.NewInterface(net.Interface{
		Index: 1,
		MTU:   1500,
		Name:  "hosttest0",
		Flags: net.FlagUp,
	})
	iface.AddAddress(&net.IPNet{IP: n.addr.IP, Mask: net.CIDRMask(24, 32)})

	return []*transport.Interface{iface}, nil
}

func (n *hostGatherNet) InterfaceByIndex(index int) (*transport.Interface, error) {
	ifaces, err := n.Interfaces()
	if err != nil {
		return nil, err
	}
	for _, iface := range ifaces {
		if iface.Index == index {
			return iface, nil
		}
	}

	return nil, transport.ErrInterfaceNotFound
}

func (n *hostGatherNet) InterfaceByName(name string) (*transport.Interface, error) {
	ifaces, err := n.Interfaces()
	if err != nil {
		return nil, err
	}
	for _, iface := range ifaces {
		if iface.Name == name {
			return iface, nil
		}
	}

	return nil, transport.ErrInterfaceNotFound
}

func (n *hostGatherNet) CreateDialer(*net.Dialer) transport.Dialer { return nil }

// errorPacketConn is a PacketConn whose reads EOF and whose writes always
// fail; it records whether Close was called.
type errorPacketConn struct {
	addr   net.Addr
	closed bool // set by Close; inspected by tests
}

// testTCPPacketConn is a trivial PacketConn over a fixed TCP address whose
// reads EOF immediately and whose writes succeed without doing anything.
type testTCPPacketConn struct {
	addr *net.TCPAddr
}

func (c *testTCPPacketConn) ReadFrom([]byte) (int, net.Addr, error) { return 0, c.addr, io.EOF }

func (c *testTCPPacketConn) WriteTo(p []byte, _ net.Addr) (int, error) { return len(p), nil }

func (c *testTCPPacketConn) Close() error { return nil }

func (c *testTCPPacketConn) LocalAddr() net.Addr { return c.addr }

func (c *testTCPPacketConn) SetDeadline(time.Time) error { return nil }

func (c *testTCPPacketConn) SetReadDeadline(time.Time) error { return nil }

func (c *testTCPPacketConn) SetWriteDeadline(time.Time) error { return nil }

// boundTCPMux is a stub TCP mux that hands out testTCPPacketConns bound to the
// requested local IP on a fixed port (12345).
type boundTCPMux struct {
	localAddr net.Addr // optional; LocalAddr falls back to a zero TCPAddr
}

func (m *boundTCPMux) Close() error { return nil }

func (m *boundTCPMux) GetConnByUfrag(_ string, _ bool, local net.IP) (net.PacketConn, error) {
	return &testTCPPacketConn{addr: &net.TCPAddr{IP: local, Port: 12345}}, nil
}

func (m *boundTCPMux) RemoveConnByUfrag(string) {}

func (m *boundTCPMux) LocalAddr() net.Addr {
	if m.localAddr != nil {
		return m.localAddr
	}

	return &net.TCPAddr{}
}

func (c *errorPacketConn) ReadFrom(_ []byte) (int, net.Addr, error) {
	return 0, c.addr, io.EOF
}

func (c *errorPacketConn) WriteTo(_ []byte, _ net.Addr) (int, error) {
	return 0, errors.New("write failure") //nolint:err113 // test
}

func (c *errorPacketConn) Close() error {
	c.closed = true

	return nil
}

func (c *errorPacketConn) LocalAddr() net.Addr { return c.addr }

func (c *errorPacketConn) SetDeadline(time.Time) error { return nil }

func (c *errorPacketConn) SetReadDeadline(time.Time) error { return nil }

func (c *errorPacketConn) SetWriteDeadline(time.Time) error { return nil }

// errorTurnNet is a stub network that always hands back the injected
// (failing) packet conn, so the TURN client built on top of it fails.
type errorTurnNet struct {
	pc net.PacketConn // typically an *errorPacketConn
}

func (n *errorTurnNet) ListenPacket(string, string) (net.PacketConn, error) {
	return n.pc, nil
}

func (n *errorTurnNet) ListenUDP(string, *net.UDPAddr) (transport.UDPConn, error) {
	return nil, transport.ErrNotSupported
}

func (n *errorTurnNet) ListenTCP(string, *net.TCPAddr) (transport.TCPListener, error) {
	return nil, transport.ErrNotSupported
}

func (n *errorTurnNet) Dial(string, string) (net.Conn, error) {
	return nil, transport.ErrNotSupported
}

func (n *errorTurnNet) DialUDP(string, *net.UDPAddr, *net.UDPAddr) (transport.UDPConn, error) {
	return nil, transport.ErrNotSupported
}

func (n *errorTurnNet) DialTCP(string, *net.TCPAddr, *net.TCPAddr) (transport.TCPConn, error) {
	return nil, transport.ErrNotSupported
}

func (n *errorTurnNet) ResolveIPAddr(network, address string) (*net.IPAddr, error) {
	return net.ResolveIPAddr(network, address)
}

func (n *errorTurnNet) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
	return net.ResolveUDPAddr(network, address)
}

func (n *errorTurnNet) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
	return net.ResolveTCPAddr(network, address)
}

// Interfaces reports one synthetic loopback-addressed interface.
func (n *errorTurnNet) Interfaces() ([]*transport.Interface, error) {
	iface := transport.NewInterface(net.Interface{
		Index: 1,
		MTU:   1500,
		Name:  "errturn0",
		Flags: net.FlagUp,
	})
	iface.AddAddress(&net.IPNet{IP: net.IPv4(127, 0, 0, 1), Mask: net.CIDRMask(8, 32)})

	return []*transport.Interface{iface}, nil
}

func (n *errorTurnNet) InterfaceByIndex(index int) (*transport.Interface, error) {
	ifaces, err := n.Interfaces()
	if err != nil {
		return nil, err
	}
	for _, iface := range ifaces {
		if iface.Index == index {
			return iface, nil
		}
	}

	return nil, transport.ErrInterfaceNotFound
}

func (n *errorTurnNet) InterfaceByName(name string) (*transport.Interface, error) {
	ifaces, err := n.Interfaces()
	if err != nil {
		return nil, err
	}
	for _, iface := range ifaces {
		if iface.Name == name {
			return iface, nil
		}
	}

	return nil, transport.ErrInterfaceNotFound
}

func (n *errorTurnNet) CreateDialer(*net.Dialer) transport.Dialer { return nil }

// stubTurnClient is a canned TURN client that records which of its methods
// were invoked and returns a pre-set (or lazily created) relay conn.
type stubTurnClient struct {
	listenCalled   bool
	allocateCalled bool
	closeCalled    bool
	cfgConn        net.PacketConn // conn the agent passed in via ClientConfig
	relayConn      net.PacketConn // conn returned from Allocate
}

func (s *stubTurnClient) Listen() error {
	s.listenCalled = true

	return nil
}

// Allocate returns the configured relay conn, lazily creating a stub one
// (203.0.113.5:5000) when none was injected.
func (s *stubTurnClient) Allocate() (net.PacketConn, error) {
	s.allocateCalled = true
	if s.relayConn == nil {
		s.relayConn = newStubPacketConn(&net.UDPAddr{IP: net.IP{203, 0, 113, 5}, Port: 5000})
	}

	return s.relayConn, nil
}

func (s *stubTurnClient) Close() {
	s.closeCalled = true
}

// TestGatherCandidatesRelayCallsAddRelayCandidates verifies that relay
// gathering drives the injected TURN client (Listen + Allocate), that the
// relay address-rewrite rule replaces the candidate address, and that closing
// the relay candidate closes both the TURN client and its packet conn.
func TestGatherCandidatesRelayCallsAddRelayCandidates(t *testing.T) {
	defer test.CheckRoutines(t)()

	stubClient := &stubTurnClient{}
	locConn := newStubPacketConn(&net.UDPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 50000})
	stubClient.relayConn = locConn

	agent, err := NewAgentWithOptions(
		WithNet(newRelayGatherNet(&net.UDPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 50000})),
		WithNetworkTypes([]NetworkType{NetworkTypeUDP4}),
		WithCandidateTypes([]CandidateType{CandidateTypeRelay}),
		WithAddressRewriteRules(
			AddressRewriteRule{
				External:        []string{"198.51.100.77"},
				Local:           "10.0.0.1",
				AsCandidateType: CandidateTypeRelay,
				Mode:            AddressRewriteReplace,
			},
		),
		WithUrls([]*stun.URI{
			{
				Scheme:   stun.SchemeTypeTURN,
				Host:     "example.com",
				Port:     3478,
				Username: "username",
				Password: "password",
				Proto:    stun.ProtoTypeUDP,
			},
		}),
		WithMulticastDNSMode(MulticastDNSModeDisabled),
	)
	require.NoError(t, err)
	defer func() { require.NoError(t, agent.Close()) }()

	// Swap in the stub client so no real TURN traffic happens.
	agent.turnClientFactory = func(cfg *turn.ClientConfig) (turnClient, error) {
		stubClient.cfgConn = cfg.Conn

		return stubClient, nil
	}

	candCh := make(chan Candidate, 1)
	require.NoError(t, agent.OnCandidate(func(c Candidate) {
		if c != nil && c.Type() == CandidateTypeRelay {
			candCh <- c
		}
	}))

	agent.gatherCandidatesRelay(context.Background(), agent.urls)

	var cand Candidate
	select {
	case cand = <-candCh:
	case <-time.After(2 * time.Second):
		assert.Fail(t, "expected relay candidate")
	}

	require.Equal(t, CandidateTypeRelay, cand.Type())
	// Rewrite rule must have replaced 10.0.0.1 with the external address.
	assert.Equal(t, "198.51.100.77", cand.Address())
	assert.True(t, stubClient.listenCalled)
	assert.True(t, stubClient.allocateCalled)

	relay, ok := cand.(*CandidateRelay)
	require.True(t, ok)
	require.NoError(t, relay.close())
	assert.True(t, stubClient.closeCalled)
	assert.True(t, locConn.closed)
}

// TestGatherCandidatesRelayDefaultClientError asserts that when the default
// TURN client path fails (the conn errors on write), no candidate is emitted
// and the underlying packet conn is closed.
func TestGatherCandidatesRelayDefaultClientError(t *testing.T) {
	defer test.CheckRoutines(t)()

	errConn := &errorPacketConn{addr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}}

	agent, err := NewAgentWithOptions(
		WithNet(&errorTurnNet{pc: errConn}),
		WithNetworkTypes([]NetworkType{NetworkTypeUDP4}),
		WithCandidateTypes([]CandidateType{CandidateTypeRelay}),
		WithUrls([]*stun.URI{
			{
				Scheme:   stun.SchemeTypeTURN,
				Proto:    stun.ProtoTypeUDP,
				Host:     "127.0.0.1",
				Port:     3478,
				Username: "user",
				Password: "pass",
			},
		}),
		WithMulticastDNSMode(MulticastDNSModeDisabled),
	)
	require.NoError(t, err)
	defer func() { require.NoError(t, agent.Close()) }()

	candidateCh := make(chan struct{}, 1)
	require.NoError(t, agent.OnCandidate(func(c Candidate) {
		if c != nil {
			candidateCh <- struct{}{}
		}
	}))

	agent.gatherCandidatesRelay(context.Background(), agent.urls)

	select {
	case <-candidateCh:
		assert.Fail(t, "unexpected candidate when TURN client fails")
	case <-time.After(200 * time.Millisecond):
	}

	assert.True(t, errConn.closed, "expected packet conn to be closed on TURN client failure")
}

// TestCloseConnLog checks that closeConnAndLog tolerates both a nil interface
// and a typed-nil connection pointer without panicking.
func TestCloseConnLog(t *testing.T) {
	a, err := NewAgent(&AgentConfig{})
	require.NoError(t, err)
	defer func() { require.NoError(t, a.Close()) }()

	closeConnAndLog(nil, a.log, "normal nil")

	var nc *net.UDPConn
	closeConnAndLog(nc, a.log, "nil ptr")
}

// mockProxy is a proxy.Dialer that records (via callback) that it was dialed.
type mockProxy struct {
	proxyWasDialed func()
}

// mockConn is a net.Conn whose every operation returns io.EOF.
type mockConn struct{}

func (m *mockConn) Read([]byte) (n int, err error) { return 0, io.EOF }

func (m *mockConn) Write([]byte) (int, error) { return 0, io.EOF }

func (m *mockConn) Close() error { return io.EOF }

func (m *mockConn) LocalAddr() net.Addr { return &net.TCPAddr{} }

func (m *mockConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }

func (m *mockConn) SetDeadline(time.Time) error { return io.EOF }

func (m *mockConn) SetReadDeadline(time.Time) error { return io.EOF }

func (m *mockConn) SetWriteDeadline(time.Time) error { return io.EOF }

func (m *mockProxy) Dial(string, string) (net.Conn, error) {
	m.proxyWasDialed()

	return &mockConn{}, nil
}

// TestTURNProxyDialer asserts that a configured ProxyDialer is actually used
// when gathering TURN/TCP candidates.
func TestTURNProxyDialer(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	proxyWasDialed, proxyWasDialedFunc := context.WithCancel(context.Background())
	proxy.RegisterDialerType("tcp", func(*url.URL, proxy.Dialer) (proxy.Dialer, error) {
		return &mockProxy{proxyWasDialedFunc}, nil
	})

	tcpProxyURI, err := url.Parse("tcp://fakeproxy:3128")
	require.NoError(t, err)

	proxyDialer, err := proxy.FromURL(tcpProxyURI, proxy.Direct)
	require.NoError(t, err)

	agent, err := NewAgent(&AgentConfig{
		CandidateTypes: []CandidateType{CandidateTypeRelay},
		NetworkTypes:   supportedNetworkTypes(),
		Urls: []*stun.URI{
			{
				Scheme:   stun.SchemeTypeTURN,
				Host:     localhostIPStr,
				Username: "username",
				Password: "password",
				Proto:    stun.ProtoTypeTCP,
				Port:     5000,
			},
		},
		ProxyDialer: proxyDialer,
	})
	require.NoError(t, err)
	defer func() { require.NoError(t, agent.Close()) }()

	candidateGatherFinish, candidateGatherFinishFunc := context.WithCancel(context.Background())
	require.NoError(t, agent.OnCandidate(func(c Candidate) {
		if c == nil {
			candidateGatherFinishFunc()
		}
	}))

	require.NoError(t, agent.GatherCandidates())

	// Gathering must finish AND the proxy must have been dialed.
	<-candidateGatherFinish.Done()
	<-proxyWasDialed.Done()
}

// buildSimpleVNet creates a started 1.2.3.0/24 virtual-network router with one
// attached vnet.Net and returns both.
func buildSimpleVNet(t *testing.T) (*vnet.Router, *vnet.Net) {
	t.Helper()

	router, err := vnet.NewRouter(&vnet.RouterConfig{
		CIDR:          "1.2.3.0/24",
		LoggerFactory: logging.NewDefaultLoggerFactory(),
	})
	require.NoError(t, err)

	nw, err := vnet.NewNet(&vnet.NetConfig{})
	require.NoError(t, err)

	require.NoError(t, router.AddNet(nw))
	require.NoError(t, router.Start())

	return router, nw
}

// TestGatherCandidatesSrflxMappedPortRangeError asserts that an inverted port
// range (portMin > portMax) makes srflx-mapped gathering produce no candidates.
func TestGatherCandidatesSrflxMappedPortRangeError(t *testing.T) {
	defer test.CheckRoutines(t)()

	router, nw := buildSimpleVNet(t)
	defer func() { require.NoError(t, router.Stop()) }()

	agent, err := NewAgentWithOptions(
		WithNetworkTypes([]NetworkType{NetworkTypeUDP4}),
		WithCandidateTypes([]CandidateType{CandidateTypeServerReflexive}),
		WithAddressRewriteRules(AddressRewriteRule{
			External:        []string{"203.0.113.10"},
			Local:           "0.0.0.0",
			AsCandidateType: CandidateTypeServerReflexive,
		}),
		WithNet(nw),
	)
	require.NoError(t, err)
	defer func() { require.NoError(t, agent.Close()) }()

	require.NoError(t, agent.OnCandidate(func(Candidate) {}))

	// Deliberately invalid: min above max.
	agent.portMin = 9000
	agent.portMax = 8000

	agent.gatherCandidatesSrflxMapped(context.Background(), []NetworkType{NetworkTypeUDP4})

	localCandidates, err := agent.GetLocalCandidates()
	require.NoError(t, err)
	require.Len(t, localCandidates, 0)
}

// TestGatherCandidatesLocalUDPMux covers the UDP-mux host-gathering path: it
// must error without a mux, and with a mux it must emit one host candidate per
// mux listen address.
func TestGatherCandidatesLocalUDPMux(t *testing.T) {
	t.Run("requires mux", func(t *testing.T) {
		agent, err := NewAgent(&AgentConfig{})
		require.NoError(t, err)
		defer func() { require.NoError(t, agent.Close()) }()

		err = agent.gatherCandidatesLocalUDPMux(context.Background())
		require.ErrorIs(t, err, errUDPMuxDisabled)
	})

	t.Run("creates host candidates from mux addresses", func(t *testing.T) {
		listenAddr := &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 4789}
		udpMux := newMockUDPMux([]net.Addr{listenAddr})

		agent, err := NewAgent(&AgentConfig{
			NetworkTypes:    []NetworkType{NetworkTypeUDP4},
			CandidateTypes:  []CandidateType{CandidateTypeHost},
			UDPMux:          udpMux,
			IncludeLoopback: true,
		})
		require.NoError(t, err)
		defer func() { require.NoError(t, agent.Close()) }()

		require.NoError(t, agent.OnCandidate(func(Candidate) {}))

		err = agent.gatherCandidatesLocalUDPMux(context.Background())
		require.NoError(t, err)

		candidates, err := agent.GetLocalCandidates()
		require.NoError(t, err)
		require.NotEmpty(t, candidates)

		host, ok := candidates[0].(*CandidateHost)
		require.True(t, ok, "expected host candidate")
		require.Equal(t, listenAddr.IP.String(), host.Address())
		require.Equal(t, listenAddr.Port, host.Port())
		require.Equal(t, 1, udpMux.connCount(), "expected mux to provide a single connection")
	})
}

// TestGatherCandidatesSrflxUDPMux asserts that srflx gathering via a universal
// UDP mux produces exactly one srflx candidate with the mux-reported mapped
// address and the mux listen address as its related address.
func TestGatherCandidatesSrflxUDPMux(t *testing.T) {
	stunURI := &stun.URI{
		Scheme: stun.SchemeTypeSTUN,
		Host:   "127.0.0.1",
		Port:   3478,
	}

	relatedAddr := &net.UDPAddr{IP: net.IP{10, 0, 0, 1}, Port: 49000}
	srflxAddr := &stun.XORMappedAddress{
		IP:   net.IP{203, 0, 113, 5},
		Port: 50000,
	}
	udpMuxSrflx := newMockUniversalUDPMux([]net.Addr{relatedAddr}, srflxAddr)

	agent, err := NewAgent(&AgentConfig{
		NetworkTypes:   []NetworkType{NetworkTypeUDP4},
		CandidateTypes: []CandidateType{CandidateTypeServerReflexive},
		UDPMuxSrflx:    udpMuxSrflx,
	})
	require.NoError(t, err)
	defer func() { require.NoError(t, agent.Close()) }()

	require.NoError(t, agent.OnCandidate(func(Candidate) {}))

	agent.gatherCandidatesSrflxUDPMux(context.Background(), []*stun.URI{stunURI}, []NetworkType{NetworkTypeUDP4})

	candidates, err := agent.GetLocalCandidates()
	require.NoError(t, err)
	require.Len(t, candidates, 1)

	srflx, ok := candidates[0].(*CandidateServerReflexive)
	require.True(t, ok, "expected server reflexive candidate")
	require.Equal(t, srflxAddr.IP.String(), srflx.Address())
	require.Equal(t, srflxAddr.Port, srflx.Port())
	require.NotNil(t, srflx.RelatedAddress())
	require.Equal(t, relatedAddr.IP.String(), srflx.RelatedAddress().Address)
	require.Equal(t, relatedAddr.Port, srflx.RelatedAddress().Port)
	require.Equal(t, 1, udpMuxSrflx.connCount(), "expected mux to be asked for one connection")
}

// TestUDPMuxDefaultWithNAT1To1IPsUsage requires that candidates
// are given and connections are valid when using UDPMuxDefault and NAT1To1IPs.
func TestUDPMuxDefaultWithNAT1To1IPsUsage(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	conn, err := net.ListenPacket("udp4", ":0") // nolint: noctx
	require.NoError(t, err)
	defer func() { _ = conn.Close() }()

	mux := NewUDPMuxDefault(UDPMuxParams{
		UDPConn: conn,
	})
	defer func() { _ = mux.Close() }()

	agent, err := NewAgent(&AgentConfig{
		NAT1To1IPs:             []string{"1.2.3.4"},
		NAT1To1IPCandidateType: CandidateTypeHost,
		UDPMux:                 mux,
	})
	require.NoError(t, err)
	defer func() { require.NoError(t, agent.Close()) }()

	// Every emitted candidate must carry the 1:1 NAT address.
	gatherCandidateDone := make(chan struct{})
	require.NoError(t, agent.OnCandidate(func(c Candidate) {
		if c == nil {
			close(gatherCandidateDone)
		} else {
			require.Equal(t, "1.2.3.4", c.Address())
		}
	}))
	require.NoError(t, agent.GatherCandidates())
	<-gatherCandidateDone

	require.NotEqual(t, 0, len(mux.connsIPv4))
}

// Assert that candidates are given for each mux in a MultiUDPMux.
func TestMultiUDPMuxUsage(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	// Three loopback muxes on distinct random ports.
	var expectedPorts []int
	var udpMuxInstances []UDPMux
	for i := 0; i < 3; i++ {
		port := randomPort(t)
		conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: port})
		require.NoError(t, err)
		defer func() { _ = conn.Close() }()

		expectedPorts = append(expectedPorts, port)
		muxDefault := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn})
		udpMuxInstances = append(udpMuxInstances, muxDefault)
		idx := i
		defer func() { _ = udpMuxInstances[idx].Close() }()
	}

	agent, err := NewAgent(&AgentConfig{
		NetworkTypes:   []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
		CandidateTypes: []CandidateType{CandidateTypeHost},
		UDPMux:         NewMultiUDPMuxDefault(udpMuxInstances...),
	})
	require.NoError(t, err)
	defer func() { require.NoError(t, agent.Close()) }()

	candidateCh := make(chan Candidate)
	require.NoError(t, agent.OnCandidate(func(c Candidate) {
		if c == nil {
			close(candidateCh)

			return
		}
		candidateCh <- c
	}))
	require.NoError(t, agent.GatherCandidates())

	// Each mux port must appear among the gathered UDP candidates.
	portFound := make(map[int]bool)
	for c := range candidateCh {
		portFound[c.Port()] = true
		require.True(t, c.NetworkType().IsUDP(), "All candidates should be UDP")
	}
	require.Len(t, portFound, len(expectedPorts))
	for _, port := range expectedPorts {
		require.True(t, portFound[port], "There should be a candidate for each UDP mux port")
	}
}

// closedStartedCh returns an already-closed signaling channel.
func closedStartedCh() <-chan struct{} {
	ch := make(chan struct{})
	close(ch)

	return ch
}

// TestResolveRelayAddresses exercises Agent.resolveRelayAddresses across
// rewrite-rule modes (append/replace), empty external lists, interface
// filters, unmatched rules, and an unparsable relay address.
func TestResolveRelayAddresses(t *testing.T) {
	logger := logging.NewDefaultLoggerFactory().NewLogger("test")

	t.Run("no mapping", func(t *testing.T) {
		agent := &Agent{log: logger}
		ep := relayEndpoint{address: net.IPv4(10, 0, 0, 10), relAddr: "198.51.100.1"}
		addrs, ok := agent.resolveRelayAddresses(ep)
		assert.True(t, ok)
		assert.Equal(t, []net.IP{ep.address}, addrs)
	})

	t.Run("append mode adds mapped address", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        []string{"203.0.113.10"},
				Local:           "198.51.100.1",
				AsCandidateType: CandidateTypeRelay,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		ep := relayEndpoint{address: net.IPv4(10, 0, 0, 10), relAddr: "198.51.100.1"}
		addrs, ok := agent.resolveRelayAddresses(ep)
		assert.True(t, ok)
		require.Len(t, addrs, 2)
		assert.Equal(t, "10.0.0.10", addrs[0].String())
		assert.Equal(t, "203.0.113.10", addrs[1].String())
	})

	t.Run("replace mode swaps to mapped", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        []string{"203.0.113.20"},
				Local:           "198.51.100.2",
				AsCandidateType: CandidateTypeRelay,
				Mode:            AddressRewriteReplace,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		ep := relayEndpoint{address: net.IPv4(10, 0, 0, 11), relAddr: "198.51.100.2"}
		addrs, ok := agent.resolveRelayAddresses(ep)
		assert.True(t, ok)
		require.Len(t, addrs, 1)
		assert.Equal(t, "203.0.113.20", addrs[0].String())
	})

	t.Run("replace match with zero external drops", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        nil,
				Local:           "198.51.100.4",
				AsCandidateType: CandidateTypeRelay,
				Mode:            AddressRewriteReplace,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		ep := relayEndpoint{address: net.IPv4(10, 0, 0, 13), relAddr: "198.51.100.4"}
		addrs, ok := agent.resolveRelayAddresses(ep)
		assert.False(t, ok)
		assert.Empty(t, addrs)
	})

	t.Run("append match with zero external keeps original", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        nil,
				Local:           "198.51.100.5",
				AsCandidateType: CandidateTypeRelay,
				Mode:            AddressRewriteAppend,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		ep := relayEndpoint{address: net.IPv4(10, 0, 0, 14), relAddr: "198.51.100.5"}
		addrs, ok := agent.resolveRelayAddresses(ep)
		assert.True(t, ok)
		require.Len(t, addrs, 1)
		assert.Equal(t, "10.0.0.14", addrs[0].String())
	})

	t.Run("invalid relAddr returns error", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        []string{"203.0.113.30"},
				AsCandidateType: CandidateTypeRelay,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		ep := relayEndpoint{address: net.IPv4(10, 0, 0, 13), relAddr: "not-an-ip"}
		addrs, ok := agent.resolveRelayAddresses(ep)
		assert.False(t, ok)
		assert.Nil(t, addrs)
	})

	t.Run("mapper present but unmatched keeps original", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        []string{"203.0.113.40"},
				Local:           "198.51.100.4",
				AsCandidateType: CandidateTypeRelay,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		ep := relayEndpoint{address: net.IPv4(10, 0, 0, 14), relAddr: "198.51.100.5"}
		addrs, ok := agent.resolveRelayAddresses(ep)
		assert.True(t, ok)
		require.Len(t, addrs, 1)
		assert.Equal(t, "10.0.0.14", addrs[0].String())
	})

	t.Run("relay rewrite respects iface filter", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        []string{"203.0.113.41"},
				Local:           "198.51.100.6",
				AsCandidateType: CandidateTypeRelay,
				Iface:           "hosttest0",
				Mode:            AddressRewriteReplace,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
			net:                  newHostGatherNet(&net.UDPAddr{IP: net.IPv4(198, 51, 100, 6)}),
		}
		ep := relayEndpoint{address: net.IPv4(10, 0, 0, 41), relAddr: "198.51.100.6", iface: "hosttest0"}
		addrs, ok := agent.resolveRelayAddresses(ep)
		assert.True(t, ok)
		require.Len(t, addrs, 1)
		assert.Equal(t, "203.0.113.41", addrs[0].String())

		// Retarget the rule to a different interface: the rewrite must no longer apply.
		agent.addressRewriteMapper.rulesByCandidateType[CandidateTypeRelay][0].rule.Iface = "other0"
		addrs, ok = agent.resolveRelayAddresses(ep)
		assert.True(t, ok)
		require.Len(t, addrs, 1)
		assert.Equal(t, "10.0.0.41", addrs[0].String())
	})
}

// TestResolveHostAndSrflxFallbacks exercises the host and srflx rewrite
// fallback paths: missing/unmatched rules, empty external lists in replace vs
// append mode, interface filters, a nil mapper, and invalid local IPs.
func TestResolveHostAndSrflxFallbacks(t *testing.T) { //nolint:maintidx
	logger := logging.NewDefaultLoggerFactory().NewLogger("test")

	t.Run("host no rule keeps original", func(t *testing.T) {
		agent := &Agent{
			addressRewriteMapper: &addressRewriteMapper{
				rulesByCandidateType: make(map[CandidateType][]*addressRewriteRuleMapping),
			},
			log: logger,
		}
		addr := netip.MustParseAddr("10.0.0.45")
		mapped, ok := agent.applyHostAddressRewrite(addr, []netip.Addr{addr}, "")
		assert.True(t, ok)
		require.Len(t, mapped, 1)
		assert.Equal(t, addr, mapped[0])
	})

	t.Run("host replace unmatched keeps original", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        []string{"203.0.113.50"},
				Local:           "198.51.100.50",
				AsCandidateType: CandidateTypeHost,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		addr := netip.MustParseAddr("10.0.0.50")
		mapped, ok := agent.applyHostAddressRewrite(addr, []netip.Addr{addr}, "")
		assert.True(t, ok)
		require.Len(t, mapped, 1)
		assert.Equal(t, addr, mapped[0])
	})

	t.Run("host replace match with zero external drops", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        nil,
				Local:           "10.0.0.51",
				AsCandidateType: CandidateTypeHost,
				Mode:            AddressRewriteReplace,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		addr := netip.MustParseAddr("10.0.0.51")
		mapped, ok := agent.applyHostAddressRewrite(addr, []netip.Addr{addr}, "")
		assert.False(t, ok)
		assert.Empty(t, mapped)
	})

	t.Run("host append match with zero external keeps original", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        nil,
				Local:           "10.0.0.52",
				AsCandidateType: CandidateTypeHost,
				Mode:            AddressRewriteAppend,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		addr := netip.MustParseAddr("10.0.0.52")
		mapped, ok := agent.applyHostAddressRewrite(addr, []netip.Addr{addr}, "")
		assert.True(t, ok)
		require.Len(t, mapped, 1)
		assert.Equal(t, addr, mapped[0])
	})

	t.Run("host rewrite respects iface filter", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        []string{"203.0.113.53"},
				Local:           "10.0.0.53",
				AsCandidateType: CandidateTypeHost,
				Iface:           "hosttest0",
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		addr := netip.MustParseAddr("10.0.0.53")
		mapped, ok := agent.applyHostAddressRewrite(addr, []netip.Addr{addr}, "hosttest0")
		assert.True(t, ok)
		require.Len(t, mapped, 1)
		assert.Equal(t, "203.0.113.53", mapped[0].String())

		mapped, ok = agent.applyHostAddressRewrite(addr, []netip.Addr{addr}, "other0")
		assert.True(t, ok)
		require.Len(t, mapped, 1)
		assert.Equal(t, addr, mapped[0])
	})

	t.Run("srflx replace unmatched keeps original", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        []string{"198.51.100.60"},
				Local:           "203.0.113.60",
				AsCandidateType: CandidateTypeServerReflexive,
				Mode:            AddressRewriteReplace,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		localIP := net.IPv4(192, 0, 2, 60)
		addrs, ok := agent.resolveSrflxAddresses(localIP, "hosttest0")
		assert.True(t, ok)
		require.Len(t, addrs, 1)
		assert.Equal(t, localIP.String(), addrs[0].String())
	})

	t.Run("srflx replace match with zero external drops", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        nil,
				Local:           "192.0.2.70",
				AsCandidateType: CandidateTypeServerReflexive,
				Mode:            AddressRewriteReplace,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		localIP := net.IPv4(192, 0, 2, 70)
		addrs, ok := agent.resolveSrflxAddresses(localIP, "hosttest0")
		assert.False(t, ok)
		assert.Empty(t, addrs)
	})

	t.Run("srflx append match with zero external keeps original", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        nil,
				Local:           "192.0.2.71",
				AsCandidateType: CandidateTypeServerReflexive,
				Mode:            AddressRewriteAppend,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		localIP := net.IPv4(192, 0, 2, 71)
		addrs, ok := agent.resolveSrflxAddresses(localIP, "hosttest0")
		assert.True(t, ok)
		require.Len(t, addrs, 1)
		assert.Equal(t, localIP.String(), addrs[0].String())
	})

	t.Run("srflx rewrite applies only on matching iface", func(t *testing.T) {
		localIP := net.IPv4(192, 0, 2, 90)
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        []string{"198.51.100.90"},
				Local:           localIP.String(),
				AsCandidateType: CandidateTypeServerReflexive,
				Iface:           "hosttest0",
				Mode:            AddressRewriteReplace,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
			net:                  newHostGatherNet(&net.UDPAddr{IP: localIP}),
		}
		addrs, ok := agent.resolveSrflxAddresses(localIP, "hosttest0")
		assert.True(t, ok)
		require.Len(t, addrs, 1)
		assert.Equal(t, "198.51.100.90", addrs[0].String())

		// Retarget the rule's interface: rewrite must stop applying.
		mapper.rulesByCandidateType[CandidateTypeServerReflexive][0].rule.Iface = "other0"
		addrs, ok = agent.resolveSrflxAddresses(localIP, "hosttest0")
		assert.True(t, ok)
		require.Len(t, addrs, 1)
		assert.Equal(t, localIP.String(), addrs[0].String())
	})

	t.Run("srflx append catch-all with zero external keeps original", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        nil,
				AsCandidateType: CandidateTypeServerReflexive,
				Mode:            AddressRewriteAppend,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		localIP := net.IPv4(192, 0, 2, 72)
		addrs, ok := agent.resolveSrflxAddresses(localIP, "")
		assert.True(t, ok)
		require.Len(t, addrs, 1)
		assert.Equal(t, localIP.String(), addrs[0].String())
	})

	t.Run("srflx no mapper returns original", func(t *testing.T) {
		agent := &Agent{
			addressRewriteMapper: nil,
			log:                  logger,
		}
		localIP := net.IPv4(192, 0, 2, 90)
		addrs, ok := agent.resolveSrflxAddresses(localIP, "")
		assert.True(t, ok)
		require.Len(t, addrs, 1)
		assert.Equal(t, localIP.String(), addrs[0].String())
	})

	t.Run("srflx replace with zero externals drops", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        nil,
				Local:           "192.0.2.91",
				AsCandidateType: CandidateTypeServerReflexive,
				Mode:            AddressRewriteReplace,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		localIP := net.IPv4(192, 0, 2, 91)
		addrs, ok := agent.resolveSrflxAddresses(localIP, "")
		assert.False(t, ok)
		assert.Nil(t, addrs)
	})

	t.Run("srflx invalid local ip returns false", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        []string{"198.51.100.99"},
				AsCandidateType: CandidateTypeServerReflexive,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		addrs, ok := agent.resolveSrflxAddresses(nil, "")
		assert.False(t, ok)
		assert.Nil(t, addrs)
	})

	t.Run("relay unmatched keeps original", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        []string{"203.0.113.70"},
				Local:           "198.51.100.70",
				AsCandidateType: CandidateTypeRelay,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		ep := relayEndpoint{address: net.IPv4(10, 0, 0, 70), relAddr: "198.51.100.71"}
		addrs, ok := agent.resolveRelayAddresses(ep)
		assert.True(t, ok)
		require.Len(t, addrs, 1)
		assert.Equal(t, "10.0.0.70", addrs[0].String())
	})
}

// TestCatchAllRewriteApplied verifies that rules with no Local address
// (catch-all) rewrite host, srflx and relay addresses.
// NOTE(review): this function continues past the end of this chunk.
func TestCatchAllRewriteApplied(t *testing.T) {
	logger := logging.NewDefaultLoggerFactory().NewLogger("test")

	t.Run("host catch-all replaces", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        []string{"203.0.113.80"},
				AsCandidateType: CandidateTypeHost,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		addr := netip.MustParseAddr("10.0.0.80")
		mapped, ok := agent.applyHostAddressRewrite(addr, []netip.Addr{addr}, "")
		assert.True(t, ok)
		require.Len(t, mapped, 1)
		assert.Equal(t, "203.0.113.80", mapped[0].String())
	})

	t.Run("srflx catch-all appends mapped only", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        []string{"203.0.113.81"},
				AsCandidateType: CandidateTypeServerReflexive,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		local := net.IPv4(10, 0, 0, 81)
		addrs, ok := agent.resolveSrflxAddresses(local, "")
		assert.True(t, ok)
		require.Len(t, addrs, 1)
		assert.Equal(t, "203.0.113.81", addrs[0].String())
	})

	t.Run("relay catch-all appends", func(t *testing.T) {
		mapper, err := newAddressRewriteMapper([]AddressRewriteRule{
			{
				External:        []string{"203.0.113.82"},
				AsCandidateType: CandidateTypeRelay,
			},
		})
		require.NoError(t, err)
		agent := &Agent{
			addressRewriteMapper: mapper,
			log:                  logger,
		}
		ep := relayEndpoint{address: net.IPv4(10, 0, 0, 82), relAddr: "0.0.0.0"}
		addrs, ok := agent.resolveRelayAddresses(ep)
		assert.True(t, ok)
require.Len(t, addrs, 2) assert.Equal(t, "10.0.0.82", addrs[0].String()) assert.Equal(t, "203.0.113.82", addrs[1].String()) }) } func TestAddRelayCandidatesWithRewrite(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"203.0.113.77"}, Local: "198.51.100.77", AsCandidateType: CandidateTypeRelay, }, }) require.NoError(t, err) agent := &Agent{ addressRewriteMapper: mapper, log: logging.NewDefaultLoggerFactory().NewLogger("test"), loop: taskloop.New(func() {}), localCandidates: make(map[NetworkType][]Candidate), remoteCandidates: make(map[NetworkType][]Candidate), startedCh: closedStartedCh(), candidateNotifier: &handlerNotifier{ candidateFunc: func(Candidate) {}, done: make(chan struct{}), }, } ep := relayEndpoint{ network: "udp", address: net.IPv4(10, 0, 0, 50), port: 3478, relAddr: "198.51.100.77", relPort: 50000, conn: newStubPacketConn(nil), } ctx, cancel := context.WithCancel(context.Background()) defer cancel() t.Cleanup(func() { agent.loop.Close() }) agent.addRelayCandidates(ctx, ep) cands := agent.localCandidates[NetworkTypeUDP4] require.Len(t, cands, 2) assert.Equal(t, "10.0.0.50", cands[0].Address()) assert.Equal(t, "203.0.113.77", cands[1].Address()) } func TestAddRelayCandidatesSkipsNilConnOrAddress(t *testing.T) { agent := &Agent{ log: logging.NewDefaultLoggerFactory().NewLogger("test"), localCandidates: make(map[NetworkType][]Candidate), remoteCandidates: make(map[NetworkType][]Candidate), startedCh: closedStartedCh(), candidateNotifier: &handlerNotifier{ candidateFunc: func(Candidate) {}, done: make(chan struct{}), }, loop: taskloop.New(func() {}), } t.Cleanup(func() { agent.loop.Close() }) ctx := context.Background() agent.addRelayCandidates(ctx, relayEndpoint{ network: NetworkTypeUDP4.String(), address: net.IPv4(10, 0, 0, 1), port: 3478, relAddr: "198.51.100.1", relPort: 5000, conn: nil, }) cands, err := agent.GetLocalCandidates() require.NoError(t, err) assert.Len(t, cands, 0) 
agent.addRelayCandidates(ctx, relayEndpoint{ network: NetworkTypeUDP4.String(), address: nil, port: 3478, relAddr: "198.51.100.1", relPort: 5000, conn: newStubPacketConn(nil), }) cands, err = agent.GetLocalCandidates() require.NoError(t, err) assert.Len(t, cands, 0) } func TestAddRelayCandidatesSkipsWhenResolveFails(t *testing.T) { t.Run("replace with zero externals drops", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: nil, Local: "198.51.100.2", AsCandidateType: CandidateTypeRelay, Mode: AddressRewriteReplace, }, }) require.NoError(t, err) agent := &Agent{ addressRewriteMapper: mapper, log: logging.NewDefaultLoggerFactory().NewLogger("test"), localCandidates: make(map[NetworkType][]Candidate), remoteCandidates: make(map[NetworkType][]Candidate), startedCh: closedStartedCh(), candidateNotifier: &handlerNotifier{ candidateFunc: func(Candidate) {}, done: make(chan struct{}), }, loop: taskloop.New(func() {}), } t.Cleanup(func() { agent.loop.Close() }) agent.addRelayCandidates(context.Background(), relayEndpoint{ network: NetworkTypeUDP4.String(), address: net.IPv4(10, 0, 0, 2), port: 3478, relAddr: "198.51.100.2", relPort: 5000, conn: newStubPacketConn(nil), }) cands, err := agent.GetLocalCandidates() require.NoError(t, err) assert.Len(t, cands, 0) }) t.Run("invalid relAddr causes skip", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: []string{"203.0.113.10"}, AsCandidateType: CandidateTypeRelay, }, }) require.NoError(t, err) agent := &Agent{ addressRewriteMapper: mapper, log: logging.NewDefaultLoggerFactory().NewLogger("test"), localCandidates: make(map[NetworkType][]Candidate), remoteCandidates: make(map[NetworkType][]Candidate), startedCh: closedStartedCh(), candidateNotifier: &handlerNotifier{ candidateFunc: func(Candidate) {}, done: make(chan struct{}), }, loop: taskloop.New(func() {}), } t.Cleanup(func() { agent.loop.Close() }) 
agent.addRelayCandidates(context.Background(), relayEndpoint{ network: NetworkTypeUDP4.String(), address: net.IPv4(10, 0, 0, 3), port: 3478, relAddr: "not-an-ip", relPort: 5000, conn: newStubPacketConn(nil), }) cands, err := agent.GetLocalCandidates() require.NoError(t, err) assert.Len(t, cands, 0) }) } func TestCreateRelayCandidateErrorPaths(t *testing.T) { t.Run("NewCandidateRelay failure skips and closes conn", func(t *testing.T) { var closed bool agent := &Agent{ log: logging.NewDefaultLoggerFactory().NewLogger("test"), localCandidates: make(map[NetworkType][]Candidate), remoteCandidates: make(map[NetworkType][]Candidate), startedCh: closedStartedCh(), candidateNotifier: &handlerNotifier{ candidateFunc: func(Candidate) {}, done: make(chan struct{}), }, loop: taskloop.New(func() {}), } t.Cleanup(func() { agent.loop.Close() }) ep := relayEndpoint{ network: "bogus-network", address: net.IPv4(10, 0, 0, 4), port: 3478, relAddr: "198.51.100.4", relPort: 5000, conn: newStubPacketConn(nil), closeConn: func() { closed = true }, } agent.addRelayCandidates(context.Background(), ep) cands, err := agent.GetLocalCandidates() require.NoError(t, err) assert.Empty(t, cands) assert.True(t, closed) }) t.Run("addCandidate failure triggers candidate close", func(t *testing.T) { var onCloseCalled int agent := &Agent{ log: logging.NewDefaultLoggerFactory().NewLogger("test"), localCandidates: make(map[NetworkType][]Candidate), remoteCandidates: make(map[NetworkType][]Candidate), startedCh: closedStartedCh(), candidateNotifier: &handlerNotifier{ candidateFunc: func(Candidate) {}, done: make(chan struct{}), }, loop: taskloop.New(func() {}), } t.Cleanup(func() { agent.loop.Close() }) ctx, cancel := context.WithCancel(context.Background()) cancel() // force addCandidate to fail agent.addRelayCandidates(ctx, relayEndpoint{ network: NetworkTypeUDP4.String(), address: net.IPv4(10, 0, 0, 5), port: 3478, relAddr: "198.51.100.5", relPort: 5000, conn: newStubPacketConn(nil), onClose: func() 
error { onCloseCalled++ return fmt.Errorf("close err") //nolint:err113 }, }) cands, err := agent.GetLocalCandidates() require.NoError(t, err) assert.Empty(t, cands) assert.Equal(t, 1, onCloseCalled) }) } func TestGatherCandidatesLocalTCPMuxSkipsUnboundInterfaces(t *testing.T) { tcpMux := &boundTCPMux{ localAddr: &net.TCPAddr{IP: net.ParseIP("203.0.113.10"), Port: 5555}, } agent, err := NewAgentWithOptions( WithNet(newHostGatherNet(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})), WithCandidateTypes([]CandidateType{CandidateTypeHost}), WithNetworkTypes([]NetworkType{NetworkTypeTCP4}), WithTCPMux(tcpMux), WithIncludeLoopback(), WithMulticastDNSMode(MulticastDNSModeDisabled), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.NoError(t, agent.OnCandidate(func(Candidate) {})) agent.gatherCandidatesLocal(context.Background(), []NetworkType{NetworkTypeTCP4}) cands, err := agent.GetLocalCandidates() require.NoError(t, err) assert.Empty(t, cands) } func TestGatherCandidatesLocalHostErrorPaths(t *testing.T) { t.Run("UDPMux invalid address closes conn", func(t *testing.T) { mux := newInvalidAddrUDPMux() rec := &recordingLogger{} agent, err := NewAgentWithOptions( WithNet(newHostGatherNet(nil)), WithCandidateTypes([]CandidateType{CandidateTypeHost}), WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), WithUDPMux(mux), WithMulticastDNSMode(MulticastDNSModeDisabled), WithLoggerFactory(&recordingLoggerFactory{logger: rec}), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.NoError(t, agent.OnCandidate(func(Candidate) {})) assert.NoError(t, agent.gatherCandidatesLocalUDPMux(context.Background())) assert.True(t, mux.conn.closed) cands, err := agent.GetLocalCandidates() require.NoError(t, err) assert.Empty(t, cands) assert.Greater(t, len(rec.warnings), 0) }) t.Run("NewCandidateHost failure logs and closes conn", func(t *testing.T) { rec := &recordingLogger{} agent, err := NewAgentWithOptions( 
WithNet(newHostGatherNet(nil)), WithCandidateTypes([]CandidateType{CandidateTypeHost}), WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), WithMulticastDNSMode(MulticastDNSModeQueryAndGather), WithLoggerFactory(&recordingLoggerFactory{logger: rec}), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.NoError(t, agent.OnCandidate(func(Candidate) {})) agent.includeLoopback = true agent.mDNSName = "invalid-mdns" // no .local suffix -> NewCandidateHost parse fails agent.gatherCandidatesLocal(context.Background(), []NetworkType{NetworkTypeUDP4}) cands, err := agent.GetLocalCandidates() require.NoError(t, err) assert.Empty(t, cands) assert.Greater(t, len(rec.warnings), 0) }) t.Run("addCandidate error logs and keeps no candidates", func(t *testing.T) { rec := &recordingLogger{} agent, err := NewAgentWithOptions( WithNet(newHostGatherNet(nil)), WithCandidateTypes([]CandidateType{CandidateTypeHost}), WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), WithMulticastDNSMode(MulticastDNSModeDisabled), WithLoggerFactory(&recordingLoggerFactory{logger: rec}), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.NoError(t, agent.OnCandidate(func(Candidate) {})) agent.includeLoopback = true agent.loop.Close() agent.gatherCandidatesLocal(context.Background(), []NetworkType{NetworkTypeUDP4}) agent.loop.Run(agent.loop, func(context.Context) { //nolint:errcheck,gosec assert.Empty(t, agent.localCandidates[NetworkTypeUDP4]) }) assert.Greater(t, len(rec.warnings), 0) }) t.Run("host rewrite replace with zero externals skips candidate", func(t *testing.T) { mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: nil, Local: "192.0.2.10", AsCandidateType: CandidateTypeHost, Mode: AddressRewriteReplace, }, }) require.NoError(t, err) agent := &Agent{ net: newHostGatherNet(&net.UDPAddr{IP: net.IPv4(192, 0, 2, 10)}), networkTypes: []NetworkType{NetworkTypeUDP4}, includeLoopback: true, mDNSMode: 
MulticastDNSModeDisabled, addressRewriteMapper: mapper, localCandidates: make(map[NetworkType][]Candidate), remoteCandidates: make(map[NetworkType][]Candidate), log: logging.NewDefaultLoggerFactory().NewLogger("test"), } agent.loop = taskloop.New(func() {}) t.Cleanup(func() { agent.loop.Close() }) agent.gatherCandidatesLocal(context.Background(), []NetworkType{NetworkTypeUDP4}) cands, err := agent.GetLocalCandidates() require.NoError(t, err) assert.Empty(t, cands) }) runUDPMuxRewrite := func( name string, rule AddressRewriteRule, ip net.IP, mux UDPMux, expectLen int, expectAddrs []string, ) { t.Run(name, func(t *testing.T) { rec := &recordingLogger{} mapper, err := newAddressRewriteMapper([]AddressRewriteRule{rule}) require.NoError(t, err) agent, err := NewAgentWithOptions( WithNet(newHostGatherNet(&net.UDPAddr{IP: ip})), WithCandidateTypes([]CandidateType{CandidateTypeHost}), WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), WithUDPMux(mux), WithMulticastDNSMode(MulticastDNSModeDisabled), WithLoggerFactory(&recordingLoggerFactory{logger: rec}), ) require.NoError(t, err) agent.addressRewriteMapper = mapper t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.NoError(t, agent.OnCandidate(func(Candidate) {})) require.NoError(t, agent.gatherCandidatesLocalUDPMux(context.Background())) cands, err := agent.GetLocalCandidates() require.NoError(t, err) if expectLen == 0 { assert.Empty(t, cands) } else { require.Len(t, cands, expectLen) got := []string{} for _, c := range cands { got = append(got, c.Address()) } assert.ElementsMatch(t, expectAddrs, got) } }) } runUDPMuxRewrite( "UDPMux append with zero externals logs and keeps original", AddressRewriteRule{ External: nil, Local: "10.0.0.11", AsCandidateType: CandidateTypeHost, Mode: AddressRewriteAppend, }, net.IPv4(10, 0, 0, 11), newMockUDPMux([]net.Addr{&net.UDPAddr{IP: net.IP{10, 0, 0, 11}, Port: 1234}}), 1, []string{"10.0.0.11"}, ) runUDPMuxRewrite( "UDPMux replace with zero externals logs and drops", 
AddressRewriteRule{ External: nil, Local: "10.0.0.12", AsCandidateType: CandidateTypeHost, Mode: AddressRewriteReplace, }, net.IPv4(10, 0, 0, 12), newMockUDPMux([]net.Addr{&net.UDPAddr{IP: net.IP{10, 0, 0, 12}, Port: 1234}}), 0, nil, ) runUDPMuxRewrite( "UDPMux findExternalIPs error logs and drops", AddressRewriteRule{ External: []string{"203.0.113.9"}, AsCandidateType: CandidateTypeHost, Mode: AddressRewriteReplace, }, net.IPv4zero, newInvalidAddrUDPMux(), 0, nil, ) } func TestApplyHostRewriteForUDPMuxErrors(t *testing.T) { rec := &recordingLogger{} mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: nil, Local: "10.0.0.50", AsCandidateType: CandidateTypeHost, Mode: AddressRewriteReplace, }, }) require.NoError(t, err) agent := &Agent{ addressRewriteMapper: mapper, log: rec, } in := []net.IP{net.IPv4(10, 0, 0, 50)} out, ok := agent.applyHostRewriteForUDPMux(in, &net.UDPAddr{IP: net.IPv4(10, 0, 0, 50), Port: 1234}) assert.False(t, ok) assert.Equal(t, in, out) } // Assert that candidates are given for each mux in a MultiTCPMux. 
func TestMultiTCPMuxUsage(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 30).Stop() var expectedPorts []int var tcpMuxInstances []TCPMux for i := 0; i < 3; i++ { port := randomPort(t) listener, err := net.ListenTCP("tcp", &net.TCPAddr{ IP: net.IP{127, 0, 0, 1}, Port: port, }) require.NoError(t, err) defer func() { _ = listener.Close() }() expectedPorts = append(expectedPorts, port) tcpMux := NewTCPMuxDefault(TCPMuxParams{ Listener: listener, ReadBufferSize: 8, }) defer func() { _ = tcpMux.Close() }() tcpMuxInstances = append(tcpMuxInstances, tcpMux) } agent, err := NewAgent(&AgentConfig{ NetworkTypes: supportedNetworkTypes(), CandidateTypes: []CandidateType{CandidateTypeHost}, TCPMux: NewMultiTCPMuxDefault(tcpMuxInstances...), }) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() candidateCh := make(chan Candidate) require.NoError(t, agent.OnCandidate(func(c Candidate) { if c == nil { close(candidateCh) return } candidateCh <- c })) require.NoError(t, agent.GatherCandidates()) portFound := make(map[int]bool) for c := range candidateCh { activeCandidate := c.Port() == 0 if c.NetworkType().IsTCP() && !activeCandidate { portFound[c.Port()] = true } } require.Len(t, portFound, len(expectedPorts)) for _, port := range expectedPorts { require.True(t, portFound[port], "There should be a candidate for each TCP mux port") } } func TestGatherAddressRewriteHostModes(t *testing.T) { //nolint:cyclop t.Run("replace host via UDPMux", func(t *testing.T) { mux := newMockUDPMux([]net.Addr{&net.UDPAddr{IP: net.IP{10, 0, 0, 1}, Port: 1234}}) agent, err := NewAgentWithOptions( WithNet(newStubNet(t)), WithCandidateTypes([]CandidateType{CandidateTypeHost}), WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), WithUDPMux(mux), WithMulticastDNSMode(MulticastDNSModeDisabled), WithAddressRewriteRules(AddressRewriteRule{ External: []string{"203.0.113.1"}, Local: "10.0.0.1", AsCandidateType: CandidateTypeHost, Mode: AddressRewriteReplace, 
}), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) var ( mu sync.Mutex addresses []Candidate done = make(chan struct{}) ) require.NoError(t, agent.OnCandidate(func(c Candidate) { if c == nil { close(done) return } mu.Lock() addresses = append(addresses, c) mu.Unlock() })) require.NoError(t, agent.GatherCandidates()) select { case <-done: case <-time.After(2 * time.Second): require.FailNow(t, "gather did not complete") } mu.Lock() defer mu.Unlock() require.Len(t, addresses, 1) assert.Equal(t, "203.0.113.1", addresses[0].Address()) assert.Equal(t, CandidateTypeHost, addresses[0].Type()) }) t.Run("append host via UDPMux", func(t *testing.T) { mux := newMockUDPMux([]net.Addr{&net.UDPAddr{IP: net.IP{10, 0, 0, 1}, Port: 1234}}) agent, err := NewAgentWithOptions( WithNet(newStubNet(t)), WithCandidateTypes([]CandidateType{CandidateTypeHost}), WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), WithUDPMux(mux), WithMulticastDNSMode(MulticastDNSModeDisabled), WithAddressRewriteRules(AddressRewriteRule{ External: []string{"203.0.113.2"}, Local: "10.0.0.1", AsCandidateType: CandidateTypeHost, Mode: AddressRewriteAppend, }), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) var ( mu sync.Mutex addresses []Candidate done = make(chan struct{}) ) require.NoError(t, agent.OnCandidate(func(c Candidate) { if c == nil { close(done) return } mu.Lock() addresses = append(addresses, c) mu.Unlock() })) require.NoError(t, agent.GatherCandidates()) select { case <-done: case <-time.After(2 * time.Second): require.FailNow(t, "gather did not complete") } mu.Lock() defer mu.Unlock() require.Len(t, addresses, 2) seenAddrs := []string{addresses[0].Address(), addresses[1].Address()} assert.ElementsMatch(t, []string{"10.0.0.1", "203.0.113.2"}, seenAddrs) for _, cand := range addresses { assert.Equal(t, CandidateTypeHost, cand.Type()) } }) t.Run("replace host via UDPMux with empty mapping drops candidate", func(t *testing.T) { mux := 
newMockUDPMux([]net.Addr{&net.UDPAddr{IP: net.IP{10, 0, 0, 2}, Port: 1234}}) agent, err := NewAgentWithOptions( WithNet(newStubNet(t)), WithCandidateTypes([]CandidateType{CandidateTypeHost}), WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), WithUDPMux(mux), WithMulticastDNSMode(MulticastDNSModeDisabled), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: nil, Local: "10.0.0.2", AsCandidateType: CandidateTypeHost, Mode: AddressRewriteReplace, }, }) require.NoError(t, err) agent.addressRewriteMapper = mapper var ( mu sync.Mutex addresses []Candidate done = make(chan struct{}) ) require.NoError(t, agent.OnCandidate(func(c Candidate) { if c == nil { close(done) return } mu.Lock() addresses = append(addresses, c) mu.Unlock() })) require.NoError(t, agent.GatherCandidates()) select { case <-done: case <-time.After(2 * time.Second): require.FailNow(t, "gather did not complete") } mu.Lock() defer mu.Unlock() assert.Empty(t, addresses) assert.Equal(t, 0, mux.connCount()) }) t.Run("append host via UDPMux with missing externals keeps original", func(t *testing.T) { mux := newMockUDPMux([]net.Addr{&net.UDPAddr{IP: net.IP{10, 0, 0, 3}, Port: 1234}}) agent, err := NewAgentWithOptions( WithNet(newStubNet(t)), WithCandidateTypes([]CandidateType{CandidateTypeHost}), WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), WithUDPMux(mux), WithMulticastDNSMode(MulticastDNSModeDisabled), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) mapper, err := newAddressRewriteMapper([]AddressRewriteRule{ { External: nil, Local: "10.0.0.3", AsCandidateType: CandidateTypeHost, Mode: AddressRewriteAppend, }, }) require.NoError(t, err) agent.addressRewriteMapper = mapper var ( mu sync.Mutex addresses []Candidate done = make(chan struct{}) ) require.NoError(t, agent.OnCandidate(func(c Candidate) { if c == nil { close(done) return } mu.Lock() addresses = 
append(addresses, c) mu.Unlock() })) require.NoError(t, agent.GatherCandidates()) select { case <-done: case <-time.After(2 * time.Second): require.FailNow(t, "gather did not complete") } mu.Lock() defer mu.Unlock() require.Len(t, addresses, 1) assert.Equal(t, "10.0.0.3", addresses[0].Address()) assert.Equal(t, 1, mux.connCount()) }) } func TestGatherAddressRewriteSrflxModes(t *testing.T) { urls := []*stun.URI{{ Scheme: SchemeTypeSTUN, Host: "127.0.0.1", Port: 3478, }} t.Run("append srflx still gathers", func(t *testing.T) { mux := newCountingUniversalUDPMux( []net.Addr{&net.UDPAddr{IP: net.IP{10, 0, 0, 2}, Port: 2345}}, &stun.XORMappedAddress{ IP: net.IP{198, 51, 100, 10}, Port: 5000, }, ) var ( mu sync.Mutex addresses []string done = make(chan struct{}) ) agent, err := NewAgentWithOptions( WithNet(newStubNet(t)), WithCandidateTypes([]CandidateType{CandidateTypeServerReflexive}), WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), WithUDPMuxSrflx(mux), WithMulticastDNSMode(MulticastDNSModeDisabled), WithUrls(urls), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.NoError(t, agent.OnCandidate(func(c Candidate) { if c == nil { close(done) } mu.Lock() if c != nil { addresses = append(addresses, c.Address()) } mu.Unlock() })) require.NoError(t, agent.GatherCandidates()) select { case <-done: case <-time.After(2 * time.Second): require.FailNow(t, "gather did not complete") } assert.Greater(t, mux.getConnForURLCount, 0) mu.Lock() require.Len(t, addresses, 1) assert.Equal(t, "198.51.100.10", addresses[0]) mu.Unlock() }) t.Run("replace srflx skips gather", func(t *testing.T) { router, nw := buildSimpleVNet(t) defer func() { require.NoError(t, router.Stop()) }() var ( mu sync.Mutex addresses []string done = make(chan struct{}) ) agent, err := NewAgentWithOptions( WithNet(nw), WithCandidateTypes([]CandidateType{CandidateTypeServerReflexive}), WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), 
WithMulticastDNSMode(MulticastDNSModeDisabled), WithUrls(urls), WithAddressRewriteRules(AddressRewriteRule{ External: []string{"203.0.113.50"}, Local: "0.0.0.0", AsCandidateType: CandidateTypeServerReflexive, Mode: AddressRewriteReplace, }), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) require.NoError(t, agent.OnCandidate(func(c Candidate) { if c == nil { close(done) } mu.Lock() if c != nil { addresses = append(addresses, c.Address()) } mu.Unlock() })) require.NoError(t, agent.GatherCandidates()) select { case <-done: case <-time.After(2 * time.Second): require.FailNow(t, "gather did not complete") } mu.Lock() require.Len(t, addresses, 1) assert.Equal(t, "203.0.113.50", addresses[0]) mu.Unlock() }) } func TestGatherAddressRewriteRelayModes(t *testing.T) { t.Run("replace relay", func(t *testing.T) { agent, err := NewAgentWithOptions( WithNet(newStubNet(t)), WithCandidateTypes([]CandidateType{CandidateTypeRelay}), WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), WithMulticastDNSMode(MulticastDNSModeDisabled), WithAddressRewriteRules(AddressRewriteRule{ External: []string{"203.0.113.60"}, Local: "10.0.0.10", AsCandidateType: CandidateTypeRelay, Mode: AddressRewriteReplace, }), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) agent.addRelayCandidates(context.Background(), relayEndpoint{ network: NetworkTypeUDP4.String(), address: net.ParseIP("192.0.2.10"), port: 5000, relAddr: "10.0.0.10", relPort: 4000, protocol: udp, conn: newStubPacketConn(&net.UDPAddr{IP: net.IP{10, 0, 0, 10}, Port: 4000}), }) local, err := agent.GetLocalCandidates() require.NoError(t, err) require.Len(t, local, 1) assert.Equal(t, "203.0.113.60", local[0].Address()) }) t.Run("append relay", func(t *testing.T) { agent, err := NewAgentWithOptions( WithNet(newStubNet(t)), WithCandidateTypes([]CandidateType{CandidateTypeRelay}), WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), WithMulticastDNSMode(MulticastDNSModeDisabled), 
WithAddressRewriteRules(AddressRewriteRule{ External: []string{"203.0.113.70"}, Local: "10.0.0.20", AsCandidateType: CandidateTypeRelay, Mode: AddressRewriteAppend, }), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, agent.Close()) }) agent.addRelayCandidates(context.Background(), relayEndpoint{ network: NetworkTypeUDP4.String(), address: net.ParseIP("192.0.2.20"), port: 6000, relAddr: "10.0.0.20", relPort: 5000, protocol: udp, conn: newStubPacketConn(&net.UDPAddr{IP: net.IP{10, 0, 0, 20}, Port: 5000}), }) local, err := agent.GetLocalCandidates() require.NoError(t, err) require.Len(t, local, 2) addresses := []string{local[0].Address(), local[1].Address()} assert.ElementsMatch(t, []string{"192.0.2.20", "203.0.113.70"}, addresses) }) } // Assert that UniversalUDPMux is used while gathering when configured in the Agent. func TestUniversalUDPMuxUsage(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 30).Stop() conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: randomPort(t)}) require.NoError(t, err) defer func() { _ = conn.Close() }() udpMuxSrflx := &universalUDPMuxMock{ conn: conn, } numSTUNS := 3 urls := []*stun.URI{} for i := 0; i < numSTUNS; i++ { urls = append(urls, &stun.URI{ Scheme: SchemeTypeSTUN, Host: localhostIPStr, Port: 3478 + i, }) } agent, err := NewAgent(&AgentConfig{ NetworkTypes: supportedNetworkTypes(), Urls: urls, CandidateTypes: []CandidateType{CandidateTypeServerReflexive}, UDPMuxSrflx: udpMuxSrflx, }) require.NoError(t, err) var aClosed bool defer func() { if aClosed { return } require.NoError(t, agent.Close()) }() candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) require.NoError(t, agent.OnCandidate(func(c Candidate) { if c == nil { candidateGatheredFunc() return } t.Log(c.NetworkType(), c.Priority(), c) })) require.NoError(t, agent.GatherCandidates()) <-candidateGathered.Done() require.NoError(t, agent.Close()) aClosed = true // Twice 
because of 2 STUN servers configured require.Equal( t, numSTUNS, udpMuxSrflx.getXORMappedAddrUsedTimes, "expected times that GetXORMappedAddr should be called", ) // One for Restart() when agent has been initialized and one time when Close() the agent require.Equal(t, 2, udpMuxSrflx.removeConnByUfragTimes, "expected times that RemoveConnByUfrag should be called") // Twice because of 2 STUN servers configured require.Equal(t, numSTUNS, udpMuxSrflx.getConnForURLTimes, "expected times that GetConnForURL should be called") } type universalUDPMuxMock struct { UDPMux getXORMappedAddrUsedTimes int removeConnByUfragTimes int getConnForURLTimes int mu sync.Mutex conn *net.UDPConn } func (m *universalUDPMuxMock) GetRelayedAddr(net.Addr, time.Duration) (*net.Addr, error) { return nil, errNotImplemented } func (m *universalUDPMuxMock) GetConnForURL(string, string, net.Addr) (net.PacketConn, error) { m.mu.Lock() defer m.mu.Unlock() m.getConnForURLTimes++ return m.conn, nil } func (m *universalUDPMuxMock) GetXORMappedAddr(net.Addr, time.Duration) (*stun.XORMappedAddress, error) { m.mu.Lock() defer m.mu.Unlock() m.getXORMappedAddrUsedTimes++ return &stun.XORMappedAddress{IP: net.IP{100, 64, 0, 1}, Port: 77878}, nil } func (m *universalUDPMuxMock) RemoveConnByUfrag(string) { m.mu.Lock() defer m.mu.Unlock() m.removeConnByUfragTimes++ } func (m *universalUDPMuxMock) GetListenAddresses() []net.Addr { return []net.Addr{m.conn.LocalAddr()} } type countingUniversalUDPMux struct { *mockUniversalUDPMux getConnForURLCount int removeConnByUfragCount int } func newCountingUniversalUDPMux(addrs []net.Addr, xorAddr *stun.XORMappedAddress) *countingUniversalUDPMux { return &countingUniversalUDPMux{ mockUniversalUDPMux: newMockUniversalUDPMux(addrs, xorAddr), } } func (m *countingUniversalUDPMux) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) { m.getConnForURLCount++ return m.mockUniversalUDPMux.GetConnForURL(ufrag, url, addr) } func (m *countingUniversalUDPMux) 
RemoveConnByUfrag(s string) { m.removeConnByUfragCount++ m.mockUniversalUDPMux.RemoveConnByUfrag(s) } func TestGatherCandidatesSrflxMappedEmitsCandidates(t *testing.T) { defer test.CheckRoutines(t)() router, nw := buildSimpleVNet(t) defer func() { require.NoError(t, router.Stop()) }() agent, err := NewAgentWithOptions( WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), WithCandidateTypes([]CandidateType{CandidateTypeServerReflexive}), WithAddressRewriteRules(AddressRewriteRule{ External: []string{ "203.0.113.10", "203.0.113.20", }, AsCandidateType: CandidateTypeServerReflexive, }), WithNet(nw), ) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() var ( mu sync.Mutex seen []Candidate gathered = make(chan struct{}) ) require.NoError(t, agent.OnCandidate(func(c Candidate) { if c == nil { close(gathered) return } mu.Lock() seen = append(seen, c) mu.Unlock() })) require.NoError(t, agent.GatherCandidates()) select { case <-gathered: case <-time.After(5 * time.Second): require.FailNow(t, "gatherCandidatesSrflxMapped did not finish before timeout") } mu.Lock() addresses := make([]string, 0, len(seen)) for _, cand := range seen { addresses = append(addresses, cand.Address()) } mu.Unlock() require.Len(t, addresses, 2) require.ElementsMatch(t, []string{"203.0.113.10", "203.0.113.20"}, addresses) localCandidates, err := agent.GetLocalCandidates() require.NoError(t, err) require.Len(t, localCandidates, 2) for _, cand := range localCandidates { require.Equal(t, CandidateTypeServerReflexive, cand.Type()) relAddr := cand.RelatedAddress() require.NotNil(t, relAddr) require.NotEmpty(t, relAddr.Address) require.Equal(t, relAddr.Port, cand.Port()) } } func TestGatherCandidatesSrflxMappedMissingExternalIPs(t *testing.T) { defer test.CheckRoutines(t)() router, nw := buildSimpleVNet(t) defer func() { require.NoError(t, router.Stop()) }() agent, err := NewAgentWithOptions( WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), 
WithCandidateTypes([]CandidateType{CandidateTypeServerReflexive}), WithNet(nw), ) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() require.NoError(t, agent.OnCandidate(func(Candidate) {})) agent.addressRewriteMapper = &addressRewriteMapper{ rulesByCandidateType: map[CandidateType][]*addressRewriteRuleMapping{ CandidateTypeServerReflexive: { { ipv4Mapping: ipMapping{ ipMap: map[string][]net.IP{ "192.0.2.10": {net.ParseIP("203.0.113.10")}, }, valid: true, }, allowIPv4: true, }, }, }, } agent.gatherCandidatesSrflxMapped(context.Background(), []NetworkType{NetworkTypeUDP4}) localCandidates, err := agent.GetLocalCandidates() require.NoError(t, err) require.Len(t, localCandidates, 1) require.Equal(t, CandidateTypeServerReflexive, localCandidates[0].Type()) } func TestShouldFilterLocationTrackedIP(t *testing.T) { linkLocal := netip.MustParseAddr("fe80::1") globalV6 := netip.MustParseAddr("2001:db8::1") ipv4 := netip.MustParseAddr("192.0.2.1") require.True(t, shouldFilterLocationTrackedIP(linkLocal)) require.False(t, shouldFilterLocationTrackedIP(globalV6)) require.False(t, shouldFilterLocationTrackedIP(ipv4)) } func TestShouldFilterLocationTracked(t *testing.T) { require.True(t, shouldFilterLocationTracked(net.ParseIP("fe80::abcd"))) require.False(t, shouldFilterLocationTracked(net.ParseIP("2001:db8::abcd"))) require.False(t, shouldFilterLocationTracked(net.ParseIP("192.0.2.10"))) require.False(t, shouldFilterLocationTracked(net.IP{})) } func TestContinualGatheringPolicy(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() loggerFactory := logging.NewDefaultLoggerFactory() loggerFactory.DefaultLogLevel = logging.LogLevelDebug t.Run("GatherOnce completes gathering", func(t *testing.T) { agent, err := NewAgent(&AgentConfig{ //nolint:contextcheck NetworkTypes: []NetworkType{NetworkTypeUDP4}, CandidateTypes: []CandidateType{CandidateTypeHost}, 
LoggerFactory: loggerFactory, }) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() // Set handler to collect candidates candidateCh := make(chan Candidate, 10) err = agent.OnCandidate(func(c Candidate) { if c != nil { candidateCh <- c } }) require.NoError(t, err) // Start gathering err = agent.GatherCandidates() //nolint:contextcheck require.NoError(t, err) // Wait for gathering to complete gatheringComplete := false timeout := time.After(5 * time.Second) for !gatheringComplete { select { case <-candidateCh: // Got a candidate, continue case <-timeout: assert.Fail(t, "Timeout waiting for gathering to complete") case <-time.After(100 * time.Millisecond): // Check if gathering is complete state, gatherErr := agent.GetGatheringState() //nolint:contextcheck require.NoError(t, gatherErr) if state == GatheringStateComplete { gatheringComplete = true } case <-ctx.Done(): assert.Fail(t, "Context timeout") } } // Verify gathering state is complete state, err := agent.GetGatheringState() //nolint:contextcheck require.NoError(t, err) assert.Equal(t, GatheringStateComplete, state, "GatherOnce should set state to Complete") }) t.Run("GatherContinually never completes", func(t *testing.T) { monitorInterval := 500 * time.Millisecond agent, err := NewAgentWithOptions( //nolint:contextcheck WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), WithCandidateTypes([]CandidateType{CandidateTypeHost}), WithContinualGatheringPolicy(GatherContinually), WithNetworkMonitorInterval(monitorInterval), ) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() // Set handler to collect candidates candidateCh := make(chan Candidate, 10) err = agent.OnCandidate(func(c Candidate) { if c != nil { candidateCh <- c } }) require.NoError(t, err) // Start gathering err = agent.GatherCandidates() //nolint:contextcheck require.NoError(t, err) // Wait for initial candidates select { case <-candidateCh: // Got at least one candidate case <-time.After(5 * 
time.Second): assert.Fail(t, "Timeout waiting for initial candidates") case <-ctx.Done(): assert.Fail(t, "Context timeout") } // Wait to ensure gathering doesn't complete time.Sleep(1 * time.Second) // Verify gathering state is still gathering state, err := agent.GetGatheringState() //nolint:contextcheck require.NoError(t, err) assert.Equal(t, GatheringStateGathering, state, "GatherContinually should keep state as Gathering") }) t.Run("Network monitoring interval is configurable", func(t *testing.T) { customInterval := 100 * time.Millisecond agent, err := NewAgentWithOptions( WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), WithCandidateTypes([]CandidateType{CandidateTypeHost}), WithContinualGatheringPolicy(GatherContinually), WithNetworkMonitorInterval(customInterval), ) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() // Verify the interval was set assert.Equal(t, customInterval, agent.networkMonitorInterval) }) t.Run("Default network monitoring interval", func(t *testing.T) { agent, err := NewAgentWithOptions( WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), WithCandidateTypes([]CandidateType{CandidateTypeHost}), WithContinualGatheringPolicy(GatherContinually), ) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() // Verify default interval is 2 seconds assert.Equal(t, 2*time.Second, agent.networkMonitorInterval) }) } func TestNetworkChangeDetection(t *testing.T) { // Limit runtime in case of deadlocks report := test.CheckRoutines(t) defer report() loggerFactory := logging.NewDefaultLoggerFactory() loggerFactory.DefaultLogLevel = logging.LogLevelDebug t.Run("detectNetworkChanges identifies new interfaces", func(t *testing.T) { customInterval := 100 * time.Millisecond agent, err := NewAgentWithOptions( WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), WithCandidateTypes([]CandidateType{CandidateTypeHost}), WithContinualGatheringPolicy(GatherContinually), WithNetworkMonitorInterval(customInterval), ) 
require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() // Initialize the last known interfaces _, addrs, err := localInterfaces( agent.net, agent.interfaceFilter, agent.ipFilter, agent.networkTypes, agent.includeLoopback, ) require.NoError(t, err) for _, info := range addrs { agent.lastKnownInterfaces[info.addr.String()] = info.addr } // First check should return false (no changes) hasChanges := agent.detectNetworkChanges() assert.False(t, hasChanges, "Should not detect changes when interfaces haven't changed") // Simulate a removed interface by clearing the last known interfaces // and then checking again if len(agent.lastKnownInterfaces) > 0 { // Remove one interface from the map to simulate change for key := range agent.lastKnownInterfaces { delete(agent.lastKnownInterfaces, key) break } // This should detect a change hasChanges = agent.detectNetworkChanges() assert.True(t, hasChanges, "Should detect changes when interfaces are different") } }) } func TestContinualGatheringPolicyString(t *testing.T) { tests := []struct { policy ContinualGatheringPolicy expected string }{ {GatherOnce, "gather_once"}, {GatherContinually, "gather_continually"}, {ContinualGatheringPolicy(99), "unknown"}, } for _, tt := range tests { t.Run(tt.expected, func(t *testing.T) { assert.Equal(t, tt.expected, tt.policy.String()) }) } } type stubPacketConn struct { addr net.Addr closed bool mu sync.Mutex } func newStubPacketConn(addr net.Addr) *stubPacketConn { if addr == nil { addr = &net.UDPAddr{IP: net.IPv4zero, Port: 0} } return &stubPacketConn{addr: addr} } func (s *stubPacketConn) ReadFrom(_ []byte) (int, net.Addr, error) { return 0, s.addr, io.EOF } func (s *stubPacketConn) WriteTo(p []byte, _ net.Addr) (int, error) { return len(p), nil } func (s *stubPacketConn) Close() error { s.mu.Lock() defer s.mu.Unlock() s.closed = true return nil } func (s *stubPacketConn) LocalAddr() net.Addr { return s.addr } func (s *stubPacketConn) SetDeadline(time.Time) error { return 
nil } func (s *stubPacketConn) SetReadDeadline(time.Time) error { return nil } func (s *stubPacketConn) SetWriteDeadline(time.Time) error { return nil } type mockUDPMux struct { listenAddrs []net.Addr mu sync.Mutex conns []*stubPacketConn } func newMockUDPMux(addrs []net.Addr) *mockUDPMux { return &mockUDPMux{listenAddrs: addrs} } func (m *mockUDPMux) GetConn(string, net.Addr) (net.PacketConn, error) { conn := newStubPacketConn(m.listenAddrs[0]) m.mu.Lock() m.conns = append(m.conns, conn) m.mu.Unlock() return conn, nil } func (m *mockUDPMux) RemoveConnByUfrag(string) {} func (m *mockUDPMux) GetListenAddresses() []net.Addr { return m.listenAddrs } func (m *mockUDPMux) Close() error { return nil } func (m *mockUDPMux) connCount() int { m.mu.Lock() defer m.mu.Unlock() return len(m.conns) } type invalidAddrUDPMux struct { conn *stubPacketConn } func newInvalidAddrUDPMux() *invalidAddrUDPMux { return &invalidAddrUDPMux{conn: newStubPacketConn(&net.UDPAddr{IP: net.IPv4(10, 0, 0, 10), Port: 1234})} } func (m *invalidAddrUDPMux) GetConn(string, net.Addr) (net.PacketConn, error) { return m.conn, nil } func (m *invalidAddrUDPMux) RemoveConnByUfrag(string) {} func (m *invalidAddrUDPMux) GetListenAddresses() []net.Addr { return []net.Addr{&net.UDPAddr{IP: nil, Port: 1234}} } func (m *invalidAddrUDPMux) Close() error { return nil } type mockUniversalUDPMux struct { *mockUDPMux xorAddr *stun.XORMappedAddress } func newMockUniversalUDPMux(addrs []net.Addr, xorAddr *stun.XORMappedAddress) *mockUniversalUDPMux { return &mockUniversalUDPMux{ mockUDPMux: newMockUDPMux(addrs), xorAddr: xorAddr, } } func (m *mockUniversalUDPMux) GetXORMappedAddr(net.Addr, time.Duration) (*stun.XORMappedAddress, error) { return m.xorAddr, nil } func (m *mockUniversalUDPMux) GetRelayedAddr(net.Addr, time.Duration) (*net.Addr, error) { return nil, errNotImplemented } func (m *mockUniversalUDPMux) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) { return 
m.mockUDPMux.GetConn(ufrag+url, addr) } ice-4.1.0/gather_vnet_test.go000066400000000000000000000253261511704470200161610ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "context" "fmt" "net" "testing" "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3/test" "github.com/pion/transport/v3/vnet" "github.com/stretchr/testify/require" ) func TestVNetGather(t *testing.T) { //nolint:cyclop defer test.CheckRoutines(t)() loggerFactory := logging.NewDefaultLoggerFactory() t.Run("No local IP address", func(t *testing.T) { n, err := vnet.NewNet(&vnet.NetConfig{}) require.NoError(t, err) a, err := NewAgent(&AgentConfig{ Net: n, }) require.NoError(t, err) defer func() { require.NoError(t, a.Close()) }() _, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) require.Len(t, localIPs, 0) require.NoError(t, err) }) t.Run("Gather a dynamic IP address", func(t *testing.T) { cider := "1.2.3.0/24" _, ipNet, err := net.ParseCIDR(cider) require.NoError(t, err) router, err := vnet.NewRouter(&vnet.RouterConfig{ CIDR: cider, LoggerFactory: loggerFactory, }) require.NoError(t, err) nw, err := vnet.NewNet(&vnet.NetConfig{}) require.NoError(t, err) require.NoError(t, router.AddNet(nw)) a, err := NewAgent(&AgentConfig{ Net: nw, }) require.NoError(t, err) defer func() { require.NoError(t, a.Close()) }() _, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) require.Len(t, localAddrs, 1) require.NoError(t, err) for _, addr := range localAddrs { require.False(t, addr.addr.IsLoopback()) require.True(t, ipNet.Contains(addr.addr.AsSlice())) } }) t.Run("listenUDP", func(t *testing.T) { router, err := vnet.NewRouter(&vnet.RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) require.NoError(t, err) nw, err := 
vnet.NewNet(&vnet.NetConfig{}) require.NoError(t, err) require.NoError(t, router.AddNet(nw)) agent, err := NewAgent(&AgentConfig{Net: nw}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() _, localAddrs, err := localInterfaces( agent.net, agent.interfaceFilter, agent.ipFilter, []NetworkType{NetworkTypeUDP4}, false, ) require.NotEqual(t, 0, len(localAddrs)) require.NoError(t, err) ip := localAddrs[0].addr.AsSlice() conn, err := listenUDPInPortRange(agent.net, agent.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0}) require.NoError(t, err) require.NotNil(t, conn) require.NoError(t, conn.Close()) _, err = listenUDPInPortRange(agent.net, agent.log, 4999, 5000, udp, &net.UDPAddr{IP: ip, Port: 0}) require.ErrorIs(t, ErrPort, err) conn, err = listenUDPInPortRange(agent.net, agent.log, 5000, 5000, udp, &net.UDPAddr{IP: ip, Port: 0}) require.NoError(t, err) require.NotNil(t, conn) defer func() { require.NoError(t, conn.Close()) }() _, port, err := net.SplitHostPort(conn.LocalAddr().String()) require.NoError(t, err) require.Equal(t, "5000", port) }) } func TestVNetGatherWithNAT1To1(t *testing.T) { //nolint:cyclop defer test.CheckRoutines(t)() loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") t.Run("gather 1:1 NAT external IPs as host candidates", func(t *testing.T) { externalIP0 := "1.2.3.4" externalIP1 := "1.2.3.5" localIP0 := "10.0.0.1" localIP1 := "10.0.0.2" map0 := fmt.Sprintf("%s/%s", externalIP0, localIP0) map1 := fmt.Sprintf("%s/%s", externalIP1, localIP1) wan, err := vnet.NewRouter(&vnet.RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) require.NoError(t, err, "should succeed") lan, err := vnet.NewRouter(&vnet.RouterConfig{ CIDR: "10.0.0.0/24", StaticIPs: []string{map0, map1}, NATType: &vnet.NATType{ Mode: vnet.NATModeNAT1To1, }, LoggerFactory: loggerFactory, }) require.NoError(t, err, "should succeed") err = wan.AddRouter(lan) require.NoError(t, err, "should succeed") nw, err := 
vnet.NewNet(&vnet.NetConfig{ StaticIPs: []string{localIP0, localIP1}, }) require.NoError(t, err) err = lan.AddNet(nw) require.NoError(t, err, "should succeed") agent, err := NewAgent(&AgentConfig{ NetworkTypes: []NetworkType{ NetworkTypeUDP4, }, NAT1To1IPs: []string{map0, map1}, Net: nw, }) require.NoError(t, err, "should succeed") defer func() { require.NoError(t, agent.Close()) }() done := make(chan struct{}) err = agent.OnCandidate(func(c Candidate) { if c == nil { close(done) } }) require.NoError(t, err, "should succeed") err = agent.GatherCandidates() require.NoError(t, err, "should succeed") log.Debug("Wait until gathering is complete...") <-done log.Debug("Gathering is done") candidates, err := agent.GetLocalCandidates() require.NoError(t, err, "should succeed") require.Len(t, candidates, 2) lAddr := [2]*net.UDPAddr{nil, nil} for i, candi := range candidates { lAddr[i] = candi.(*CandidateHost).conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert require.Equal(t, candi.Port(), lAddr[i].Port) } if candidates[0].Address() == externalIP0 { //nolint:nestif require.Equal(t, candidates[1].Address(), externalIP1) require.Equal(t, lAddr[0].IP.String(), localIP0) require.Equal(t, lAddr[1].IP.String(), localIP1) } else if candidates[0].Address() == externalIP1 { require.Equal(t, candidates[1].Address(), externalIP0) require.Equal(t, lAddr[0].IP.String(), localIP1) require.Equal(t, lAddr[1].IP.String(), localIP0) } }) t.Run("gather 1:1 NAT external IPs as srflx candidates", func(t *testing.T) { wan, err := vnet.NewRouter(&vnet.RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) require.NoError(t, err, "should succeed") lan, err := vnet.NewRouter(&vnet.RouterConfig{ CIDR: "10.0.0.0/24", StaticIPs: []string{ "1.2.3.4/10.0.0.1", }, NATType: &vnet.NATType{ Mode: vnet.NATModeNAT1To1, }, LoggerFactory: loggerFactory, }) require.NoError(t, err, "should succeed") err = wan.AddRouter(lan) require.NoError(t, err, "should succeed") nw, err := 
vnet.NewNet(&vnet.NetConfig{ StaticIPs: []string{ "10.0.0.1", }, }) require.NoError(t, err) err = lan.AddNet(nw) require.NoError(t, err, "should succeed") agent, err := NewAgent(&AgentConfig{ NetworkTypes: []NetworkType{ NetworkTypeUDP4, }, NAT1To1IPs: []string{ "1.2.3.4", }, NAT1To1IPCandidateType: CandidateTypeServerReflexive, Net: nw, }) require.NoError(t, err, "should succeed") defer func() { require.NoError(t, agent.Close()) }() done := make(chan struct{}) err = agent.OnCandidate(func(c Candidate) { if c == nil { close(done) } }) require.NoError(t, err, "should succeed") err = agent.GatherCandidates() require.NoError(t, err, "should succeed") log.Debug("Wait until gathering is complete...") <-done log.Debug("Gathering is done") candidates, err := agent.GetLocalCandidates() require.NoError(t, err, "should succeed") require.Len(t, candidates, 2) var candiHost *CandidateHost var candiSrflx *CandidateServerReflexive for _, candidate := range candidates { switch candi := candidate.(type) { case *CandidateHost: candiHost = candi case *CandidateServerReflexive: candiSrflx = candi default: t.Fatal("Unexpected candidate type") // nolint } } require.NotNil(t, candiHost, "should not be nil") require.Equal(t, "10.0.0.1", candiHost.Address(), "should match") require.NotNil(t, candiSrflx, "should not be nil") require.Equal(t, "1.2.3.4", candiSrflx.Address(), "should match") }) } func TestVNetGatherWithInterfaceFilter(t *testing.T) { defer test.CheckRoutines(t)() loggerFactory := logging.NewDefaultLoggerFactory() router, err := vnet.NewRouter(&vnet.RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) require.NoError(t, err) nw, err := vnet.NewNet(&vnet.NetConfig{}) require.NoError(t, err) require.NoError(t, router.AddNet(nw)) t.Run("InterfaceFilter should exclude the interface", func(t *testing.T) { agent, err := NewAgent(&AgentConfig{ Net: nw, InterfaceFilter: func(interfaceName string) (keep bool) { require.Equal(t, "eth0", interfaceName) return false }, }) 
require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() _, localIPs, err := localInterfaces( agent.net, agent.interfaceFilter, agent.ipFilter, []NetworkType{NetworkTypeUDP4}, false, ) require.NoError(t, err) require.Len(t, localIPs, 0) }) t.Run("IPFilter should exclude the IP", func(t *testing.T) { agent, err := NewAgent(&AgentConfig{ Net: nw, IPFilter: func(ip net.IP) (keep bool) { require.Equal(t, net.IP{1, 2, 3, 1}, ip) return false }, }) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() _, localIPs, err := localInterfaces( agent.net, agent.interfaceFilter, agent.ipFilter, []NetworkType{NetworkTypeUDP4}, false, ) require.NoError(t, err) require.Len(t, localIPs, 0) }) t.Run("InterfaceFilter should not exclude the interface", func(t *testing.T) { agent, err := NewAgent(&AgentConfig{ Net: nw, InterfaceFilter: func(interfaceName string) (keep bool) { require.Equal(t, "eth0", interfaceName) return true }, }) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() _, localIPs, err := localInterfaces( agent.net, agent.interfaceFilter, agent.ipFilter, []NetworkType{NetworkTypeUDP4}, false, ) require.NoError(t, err) require.Len(t, localIPs, 1) }) } func TestVNetGather_TURNConnectionLeak(t *testing.T) { defer test.CheckRoutines(t)() turnServerURL := &stun.URI{ Scheme: stun.SchemeTypeTURN, Host: vnetSTUNServerIP, Port: vnetSTUNServerPort, Username: "user", Password: "pass", Proto: stun.ProtoTypeUDP, } // buildVNet with a Symmetric NATs for both LANs natType := &vnet.NATType{ MappingBehavior: vnet.EndpointAddrPortDependent, FilteringBehavior: vnet.EndpointAddrPortDependent, } v, err := buildVNet(natType, natType) require.NoError(t, err, "should succeed") defer v.close() cfg0 := &AgentConfig{ Urls: []*stun.URI{ turnServerURL, }, NetworkTypes: supportedNetworkTypes(), MulticastDNSMode: MulticastDNSModeDisabled, NAT1To1IPs: []string{vnetGlobalIPA}, Net: v.net0, } aAgent, err := NewAgent(cfg0) 
require.NoError(t, err, "should succeed") defer func() { // Assert relay conn leak on close. require.NoError(t, aAgent.Close()) }() aAgent.gatherCandidatesRelay(context.Background(), []*stun.URI{turnServerURL}) } ice-4.1.0/go.mod000066400000000000000000000013601511704470200133630ustar00rootroot00000000000000module github.com/pion/ice/v4 go 1.21 require ( github.com/google/uuid v1.6.0 github.com/pion/dtls/v3 v3.0.9 github.com/pion/logging v0.2.4 github.com/pion/mdns/v2 v2.1.0 github.com/pion/randutil v0.1.0 github.com/pion/stun/v3 v3.0.2 github.com/pion/transport/v3 v3.1.1 github.com/pion/turn/v4 v4.1.3 github.com/stretchr/testify v1.11.1 golang.org/x/net v0.35.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/kr/pretty v0.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/wlynxg/anet v0.0.5 // indirect golang.org/x/crypto v0.33.0 // indirect golang.org/x/sys v0.30.0 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) ice-4.1.0/go.sum000066400000000000000000000064171511704470200134200ustar00rootroot00000000000000github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM= github.com/pion/dtls/v3 v3.0.9/go.mod 
h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os= github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= github.com/pion/mdns/v2 v2.1.0 h1:3IJ9+Xio6tWYjhN6WwuY142P/1jA0D5ERaIqawg/fOY= github.com/pion/mdns/v2 v2.1.0/go.mod h1:pcez23GdynwcfRU1977qKU0mDxSeucttSHbCSfFOd9A= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/stun/v3 v3.0.2 h1:BJuGEN2oLrJisiNEJtUTJC4BGbzbfp37LizfqswblFU= github.com/pion/stun/v3 v3.0.2/go.mod h1:JFJKfIWvt178MCF5H/YIgZ4VX3LYE77vca4b9HP60SA= github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM= github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ= github.com/pion/turn/v4 v4.1.3 h1:jVNW0iR05AS94ysEtvzsrk3gKs9Zqxf6HmnsLfRvlzA= github.com/pion/turn/v4 v4.1.3/go.mod h1:TD/eiBUf5f5LwXbCJa35T7dPtTpCHRJ9oJWmyPLVT3A= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod 
h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= ice-4.1.0/ice.go000066400000000000000000000053161511704470200133510ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice // ConnectionState is an enum showing the state of a ICE Connection. type ConnectionState int // List of supported States. const ( // ConnectionStateUnknown represents an unknown state. ConnectionStateUnknown ConnectionState = iota // ConnectionStateNew ICE agent is gathering addresses. ConnectionStateNew // ConnectionStateChecking ICE agent has been given local and remote candidates, and is attempting to find a match. ConnectionStateChecking // ConnectionStateConnected ICE agent has a pairing, but is still checking other pairs. ConnectionStateConnected // ConnectionStateCompleted ICE agent has finished. ConnectionStateCompleted // ConnectionStateFailed ICE agent never could successfully connect. ConnectionStateFailed // ConnectionStateDisconnected ICE agent connected successfully, but has entered a failed state. ConnectionStateDisconnected // ConnectionStateClosed ICE agent has finished and is no longer handling requests. 
ConnectionStateClosed ) func (c ConnectionState) String() string { switch c { case ConnectionStateNew: return "New" case ConnectionStateChecking: return "Checking" case ConnectionStateConnected: return "Connected" case ConnectionStateCompleted: return "Completed" case ConnectionStateFailed: return "Failed" case ConnectionStateDisconnected: return "Disconnected" case ConnectionStateClosed: return "Closed" default: return "Invalid" } } // GatheringState describes the state of the candidate gathering process. type GatheringState int const ( // GatheringStateUnknown represents an unknown state. GatheringStateUnknown GatheringState = iota // GatheringStateNew indicates candidate gathering is not yet started. GatheringStateNew // GatheringStateGathering indicates candidate gathering is ongoing. GatheringStateGathering // GatheringStateComplete indicates candidate gathering has been completed. GatheringStateComplete ) func (t GatheringState) String() string { switch t { case GatheringStateNew: return "new" case GatheringStateGathering: return "gathering" case GatheringStateComplete: return "complete" default: return ErrUnknownType.Error() } } // ContinualGatheringPolicy defines the behavior for gathering ICE candidates. 
type ContinualGatheringPolicy int const ( GatherOnce ContinualGatheringPolicy = iota GatherContinually ) func (c ContinualGatheringPolicy) String() string { switch c { case GatherOnce: return "gather_once" case GatherContinually: return "gather_continually" default: return unknownStr } } const ( unknownStr = "unknown" relayProtocolDTLS = "dtls" relayProtocolTLS = "tls" ) ice-4.1.0/ice_test.go000066400000000000000000000023661511704470200144120ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "testing" "github.com/stretchr/testify/require" ) func TestConnectedState_String(t *testing.T) { testCases := []struct { connectionState ConnectionState expectedString string }{ {ConnectionStateUnknown, "Invalid"}, {ConnectionStateNew, "New"}, {ConnectionStateChecking, "Checking"}, {ConnectionStateConnected, "Connected"}, {ConnectionStateCompleted, "Completed"}, {ConnectionStateFailed, "Failed"}, {ConnectionStateDisconnected, "Disconnected"}, {ConnectionStateClosed, "Closed"}, } for i, testCase := range testCases { require.Equal(t, testCase.expectedString, testCase.connectionState.String(), "testCase: %d %v", i, testCase, ) } } func TestGatheringState_String(t *testing.T) { testCases := []struct { gatheringState GatheringState expectedString string }{ {GatheringStateUnknown, ErrUnknownType.Error()}, {GatheringStateNew, "new"}, {GatheringStateGathering, "gathering"}, {GatheringStateComplete, "complete"}, } for i, testCase := range testCases { require.Equal(t, testCase.expectedString, testCase.gatheringState.String(), "testCase: %d %v", i, testCase, ) } } ice-4.1.0/icecontrol.go000066400000000000000000000050771511704470200147560ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "encoding/binary" "github.com/pion/stun/v3" ) // tiebreaker is common helper for ICE-{CONTROLLED,CONTROLLING} // and represents the so-called 
tiebreaker number. type tiebreaker uint64 const tiebreakerSize = 8 // 64 bit // AddToAs adds tiebreaker value to m as t attribute. func (a tiebreaker) AddToAs(m *stun.Message, t stun.AttrType) error { v := make([]byte, tiebreakerSize) binary.BigEndian.PutUint64(v, uint64(a)) m.Add(t, v) return nil } // GetFromAs decodes tiebreaker value in message getting it as for t type. func (a *tiebreaker) GetFromAs(m *stun.Message, t stun.AttrType) error { v, err := m.Get(t) if err != nil { return err } if err = stun.CheckSize(t, len(v), tiebreakerSize); err != nil { return err } *a = tiebreaker(binary.BigEndian.Uint64(v)) return nil } // AttrControlled represents ICE-CONTROLLED attribute. type AttrControlled uint64 // AddTo adds ICE-CONTROLLED to message. func (c AttrControlled) AddTo(m *stun.Message) error { return tiebreaker(c).AddToAs(m, stun.AttrICEControlled) } // GetFrom decodes ICE-CONTROLLED from message. func (c *AttrControlled) GetFrom(m *stun.Message) error { return (*tiebreaker)(c).GetFromAs(m, stun.AttrICEControlled) } // AttrControlling represents ICE-CONTROLLING attribute. type AttrControlling uint64 // AddTo adds ICE-CONTROLLING to message. func (c AttrControlling) AddTo(m *stun.Message) error { return tiebreaker(c).AddToAs(m, stun.AttrICEControlling) } // GetFrom decodes ICE-CONTROLLING from message. func (c *AttrControlling) GetFrom(m *stun.Message) error { return (*tiebreaker)(c).GetFromAs(m, stun.AttrICEControlling) } // AttrControl is helper that wraps ICE-{CONTROLLED,CONTROLLING}. type AttrControl struct { Role Role Tiebreaker uint64 } // AddTo adds ICE-CONTROLLED or ICE-CONTROLLING attribute depending on Role. func (c AttrControl) AddTo(m *stun.Message) error { if c.Role == Controlling { return tiebreaker(c.Tiebreaker).AddToAs(m, stun.AttrICEControlling) } return tiebreaker(c.Tiebreaker).AddToAs(m, stun.AttrICEControlled) } // GetFrom decodes Role and Tiebreaker value from message. 
func (c *AttrControl) GetFrom(m *stun.Message) error { if m.Contains(stun.AttrICEControlling) { c.Role = Controlling return (*tiebreaker)(&c.Tiebreaker).GetFromAs(m, stun.AttrICEControlling) } if m.Contains(stun.AttrICEControlled) { c.Role = Controlled return (*tiebreaker)(&c.Tiebreaker).GetFromAs(m, stun.AttrICEControlled) } return stun.ErrAttributeNotFound } ice-4.1.0/icecontrol_test.go000066400000000000000000000057431511704470200160150ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "testing" "github.com/pion/stun/v3" "github.com/stretchr/testify/require" ) func TestControlled_GetFrom(t *testing.T) { //nolint:dupl m := new(stun.Message) var attrCtr AttrControlled require.ErrorIs(t, stun.ErrAttributeNotFound, attrCtr.GetFrom(m)) require.NoError(t, m.Build(stun.BindingRequest, &attrCtr)) m1 := new(stun.Message) _, err := m1.Write(m.Raw) require.NoError(t, err) var c1 AttrControlled require.NoError(t, c1.GetFrom(m1)) require.Equal(t, c1, attrCtr) t.Run("IncorrectSize", func(t *testing.T) { m3 := new(stun.Message) m3.Add(stun.AttrICEControlled, make([]byte, 100)) var c2 AttrControlled require.True(t, stun.IsAttrSizeInvalid(c2.GetFrom(m3))) }) } func TestControlling_GetFrom(t *testing.T) { //nolint:dupl m := new(stun.Message) var attrCtr AttrControlling require.ErrorIs(t, stun.ErrAttributeNotFound, attrCtr.GetFrom(m)) require.NoError(t, m.Build(stun.BindingRequest, &attrCtr)) m1 := new(stun.Message) _, err := m1.Write(m.Raw) require.NoError(t, err) var c1 AttrControlling require.NoError(t, c1.GetFrom(m1)) require.Equal(t, c1, attrCtr) t.Run("IncorrectSize", func(t *testing.T) { m3 := new(stun.Message) m3.Add(stun.AttrICEControlling, make([]byte, 100)) var c2 AttrControlling require.True(t, stun.IsAttrSizeInvalid(c2.GetFrom(m3))) }) } func TestControl_GetFrom(t *testing.T) { //nolint:cyclop t.Run("Blank", func(t *testing.T) { m := new(stun.Message) var c AttrControl require.ErrorIs(t, 
stun.ErrAttributeNotFound, c.GetFrom(m)) }) t.Run("Controlling", func(t *testing.T) { //nolint:dupl m := new(stun.Message) var attCtr AttrControl require.ErrorIs(t, stun.ErrAttributeNotFound, attCtr.GetFrom(m)) attCtr.Role = Controlling attCtr.Tiebreaker = 4321 require.NoError(t, m.Build(stun.BindingRequest, &attCtr)) m1 := new(stun.Message) _, err := m1.Write(m.Raw) require.NoError(t, err) var c1 AttrControl require.NoError(t, c1.GetFrom(m1)) require.Equal(t, c1, attCtr) t.Run("IncorrectSize", func(t *testing.T) { m3 := new(stun.Message) m3.Add(stun.AttrICEControlling, make([]byte, 100)) var c2 AttrControl err := c2.GetFrom(m3) require.True(t, stun.IsAttrSizeInvalid(err)) }) }) t.Run("Controlled", func(t *testing.T) { //nolint:dupl m := new(stun.Message) var attrCtrl AttrControl require.ErrorIs(t, stun.ErrAttributeNotFound, attrCtrl.GetFrom(m)) attrCtrl.Role = Controlled attrCtrl.Tiebreaker = 1234 require.NoError(t, m.Build(stun.BindingRequest, &attrCtrl)) m1 := new(stun.Message) _, err := m1.Write(m.Raw) require.NoError(t, err) var c1 AttrControl require.NoError(t, c1.GetFrom(m1)) require.Equal(t, c1, attrCtrl) t.Run("IncorrectSize", func(t *testing.T) { m3 := new(stun.Message) m3.Add(stun.AttrICEControlling, make([]byte, 100)) var c2 AttrControl err := c2.GetFrom(m3) require.True(t, stun.IsAttrSizeInvalid(err)) }) }) } ice-4.1.0/internal/000077500000000000000000000000001511704470200140715ustar00rootroot00000000000000ice-4.1.0/internal/atomic/000077500000000000000000000000001511704470200153455ustar00rootroot00000000000000ice-4.1.0/internal/atomic/atomic.go000066400000000000000000000010211511704470200171420ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package atomic contains custom atomic types package atomic import "sync/atomic" // Error is an atomic error. type Error struct { v atomic.Value } // Store updates the value of the atomic variable. 
func (a *Error) Store(err error) { a.v.Store(struct{ error }{err}) } // Load retrieves the current value of the atomic variable. func (a *Error) Load() error { err, _ := a.v.Load().(struct{ error }) return err.error } ice-4.1.0/internal/fakenet/000077500000000000000000000000001511704470200155065ustar00rootroot00000000000000ice-4.1.0/internal/fakenet/mock_conn.go000066400000000000000000000020201511704470200177750ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package fakenet import ( "net" "time" ) // MockPacketConn for tests. type MockPacketConn struct{} func (m *MockPacketConn) ReadFrom([]byte) (n int, addr net.Addr, err error) { return 0, nil, nil } //nolint:revive func (m *MockPacketConn) WriteTo([]byte, net.Addr) (n int, err error) { return 0, nil } //nolint:revive func (m *MockPacketConn) Close() error { return nil } //nolint:revive func (m *MockPacketConn) LocalAddr() net.Addr { return nil } //nolint:revive func (m *MockPacketConn) SetDeadline(time.Time) error { return nil } //nolint:revive func (m *MockPacketConn) SetReadDeadline(time.Time) error { return nil } //nolint:revive func (m *MockPacketConn) SetWriteDeadline(time.Time) error { return nil } //nolint:revive ice-4.1.0/internal/fakenet/packet_conn.go000066400000000000000000000012641511704470200203240ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package fakenet contains fake network abstractions package fakenet import ( "net" ) // Compile-time assertion. var _ net.PacketConn = (*PacketConn)(nil) // PacketConn wraps a net.Conn and emulates net.PacketConn. type PacketConn struct { net.Conn } // ReadFrom reads a packet from the connection. func (f *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { n, err = f.Conn.Read(p) addr = f.Conn.RemoteAddr() return } // WriteTo writes a packet with payload p to addr. 
func (f *PacketConn) WriteTo(p []byte, _ net.Addr) (int, error) { return f.Conn.Write(p) } ice-4.1.0/internal/internal.go000066400000000000000000000003021511704470200162270ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package internal implements internal functionality for Pions ICE module package internal ice-4.1.0/internal/stun/000077500000000000000000000000001511704470200150625ustar00rootroot00000000000000ice-4.1.0/internal/stun/stun.go000066400000000000000000000036051511704470200164060ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package stun contains ICE specific STUN code package stun import ( "errors" "fmt" "net" "time" "github.com/pion/stun/v3" ) var ( errGetXorMappedAddrResponse = errors.New("failed to get XOR-MAPPED-ADDRESS response") errMismatchUsername = errors.New("username mismatch") ) // GetXORMappedAddr initiates a STUN requests to serverAddr using conn, reads the response and returns // the XORMappedAddress returned by the STUN server. 
func GetXORMappedAddr(conn net.PacketConn, serverAddr net.Addr, timeout time.Duration) (*stun.XORMappedAddress, error) { if timeout > 0 { if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil { return nil, err } // Reset timeout after completion defer conn.SetReadDeadline(time.Time{}) //nolint:errcheck } req, err := stun.Build(stun.BindingRequest, stun.TransactionID) if err != nil { return nil, err } if _, err = conn.WriteTo(req.Raw, serverAddr); err != nil { return nil, err } const maxMessageSize = 1280 buf := make([]byte, maxMessageSize) n, _, err := conn.ReadFrom(buf) if err != nil { return nil, err } res := &stun.Message{Raw: buf[:n]} if err = res.Decode(); err != nil { return nil, err } var addr stun.XORMappedAddress if err = addr.GetFrom(res); err != nil { return nil, fmt.Errorf("%w: %v", errGetXorMappedAddrResponse, err) //nolint:errorlint } return &addr, nil } // AssertUsername checks that the given STUN message m has a USERNAME attribute with a given value. func AssertUsername(m *stun.Message, expectedUsername string) error { var username stun.Username if err := username.GetFrom(m); err != nil { return err } else if string(username) != expectedUsername { return fmt.Errorf("%w expected(%x) actual(%x)", errMismatchUsername, expectedUsername, string(username)) } return nil } ice-4.1.0/internal/taskloop/000077500000000000000000000000001511704470200157255ustar00rootroot00000000000000ice-4.1.0/internal/taskloop/taskloop.go000066400000000000000000000047141511704470200201160ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package taskloop implements a task loop to run // tasks sequentially in a separate Goroutine. package taskloop import ( "context" "errors" "time" atomicx "github.com/pion/ice/v4/internal/atomic" ) // ErrClosed indicates that the loop has been stopped. 
var ErrClosed = errors.New("the agent is closed") type task struct { fn func(context.Context) done chan struct{} } // Loop runs submitted task serially in a dedicated Goroutine. type Loop struct { tasks chan task // State for closing done chan struct{} taskLoopDone chan struct{} err atomicx.Error } // New creates and starts a new task loop. func New(onClose func()) *Loop { l := &Loop{ tasks: make(chan task), done: make(chan struct{}), taskLoopDone: make(chan struct{}), } go l.runLoop(onClose) return l } // runLoop handles registered tasks and agent close. func (l *Loop) runLoop(onClose func()) { defer func() { onClose() close(l.taskLoopDone) }() for { select { case <-l.done: return case t := <-l.tasks: t.fn(l) close(t.done) } } } // Close stops the loop after finishing the execution of the current task. // Other pending tasks will not be executed. func (l *Loop) Close() { if err := l.Err(); err != nil { return } l.err.Store(ErrClosed) close(l.done) <-l.taskLoopDone } // Run serially executes the submitted callback. // Blocking tasks must be cancelable by context. func (l *Loop) Run(ctx context.Context, t func(context.Context)) error { if err := l.Err(); err != nil { return err } done := make(chan struct{}) select { case <-ctx.Done(): return ctx.Err() case <-l.done: return ErrClosed case l.tasks <- task{t, done}: <-done return nil } } // The following methods implement context.Context for TaskLoop // Done returns a channel that's closed when the task loop has been stopped. func (l *Loop) Done() <-chan struct{} { return l.done } // Err returns nil if the task loop is still running. // Otherwise it return errClosed if the loop has been closed/stopped. func (l *Loop) Err() error { select { case <-l.done: return ErrClosed default: return nil } } // Deadline returns the no valid time as task loops have no deadline. func (l *Loop) Deadline() (deadline time.Time, ok bool) { return time.Time{}, false } // Value is not supported for task loops. 
func (l *Loop) Value(any) any { return nil } ice-4.1.0/internal/taskloop/taskloop_test.go000066400000000000000000000022511511704470200211470ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package taskloop import ( "context" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" ) func TestRunReturnsErrClosedWhenLoopClosing(t *testing.T) { loop := New(func() {}) blockStarted := make(chan struct{}) releaseBlock := make(chan struct{}) go func() { _ = loop.Run(context.Background(), func(context.Context) { close(blockStarted) <-releaseBlock }) }() <-blockStarted var secondRan atomic.Bool errCh := make(chan error, 1) go func() { errCh <- loop.Run(context.Background(), func(context.Context) { secondRan.Store(true) }) }() time.Sleep(10 * time.Millisecond) closeDone := make(chan struct{}) go func() { loop.Close() close(closeDone) }() select { case err := <-errCh: assert.ErrorIs(t, err, ErrClosed) case <-time.After(time.Second): assert.Fail(t, "Run did not return after loop close") } close(releaseBlock) select { case <-closeDone: case <-time.After(time.Second): assert.Fail(t, "Close did not return") } assert.False(t, secondRan.Load(), "second task should not excute after loop is closed") } ice-4.1.0/mdns.go000066400000000000000000000073761511704470200135620ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "net" "github.com/google/uuid" "github.com/pion/logging" "github.com/pion/mdns/v2" "github.com/pion/transport/v3" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) // MulticastDNSMode represents the different Multicast modes ICE can run in. type MulticastDNSMode byte // MulticastDNSMode enum. const ( // MulticastDNSModeDisabled means remote mDNS candidates will be discarded, and local host candidates will use IPs. 
MulticastDNSModeDisabled MulticastDNSMode = iota + 1 // MulticastDNSModeQueryOnly means remote mDNS candidates will be accepted, and local host candidates will use IPs. MulticastDNSModeQueryOnly // MulticastDNSModeQueryAndGather means remote mDNS candidates will be accepted, // and local host candidates will use mDNS. MulticastDNSModeQueryAndGather ) func generateMulticastDNSName() (string, error) { // https://tools.ietf.org/id/draft-ietf-rtcweb-mdns-ice-candidates-02.html#gathering // The unique name MUST consist of a version 4 UUID as defined in [RFC4122], followed by “.local”. u, err := uuid.NewRandom() return u.String() + ".local", err } //nolint:cyclop func createMulticastDNS( netTransport transport.Net, networkTypes []NetworkType, interfaces []*transport.Interface, includeLoopback bool, mDNSMode MulticastDNSMode, mDNSName string, log logging.LeveledLogger, loggerFactory logging.LoggerFactory, ) (*mdns.Conn, MulticastDNSMode, error) { if mDNSMode == MulticastDNSModeDisabled { return nil, mDNSMode, nil } var useV4, useV6 bool if len(networkTypes) == 0 { useV4 = true useV6 = true } else { for _, nt := range networkTypes { if nt.IsIPv4() { useV4 = true continue } if nt.IsIPv6() { useV6 = true } } } addr4, mdnsErr := netTransport.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4) if mdnsErr != nil { return nil, mDNSMode, mdnsErr } addr6, mdnsErr := netTransport.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6) if mdnsErr != nil { return nil, mDNSMode, mdnsErr } var pktConnV4 *ipv4.PacketConn var mdns4Err error if useV4 { var l transport.UDPConn l, mdns4Err = netTransport.ListenUDP("udp4", addr4) if mdns4Err != nil { // If ICE fails to start MulticastDNS server just warn the user and continue log.Errorf("Failed to enable mDNS over IPv4: (%s)", mdns4Err) return nil, MulticastDNSModeDisabled, nil } pktConnV4 = ipv4.NewPacketConn(l) } var pktConnV6 *ipv6.PacketConn var mdns6Err error if useV6 { var l transport.UDPConn l, mdns6Err = netTransport.ListenUDP("udp6", addr6) if 
mdns6Err != nil { log.Errorf("Failed to enable mDNS over IPv6: (%s)", mdns6Err) return nil, MulticastDNSModeDisabled, nil } pktConnV6 = ipv6.NewPacketConn(l) } if mdns4Err != nil && mdns6Err != nil { // If ICE fails to start MulticastDNS server just warn the user and continue log.Errorf("Failed to enable mDNS, continuing in mDNS disabled mode") //nolint:nilerr return nil, MulticastDNSModeDisabled, nil } var ifcs []net.Interface if interfaces != nil { ifcs = make([]net.Interface, 0, len(ifcs)) for _, ifc := range interfaces { ifcs = append(ifcs, ifc.Interface) } } switch mDNSMode { case MulticastDNSModeQueryOnly: conn, err := mdns.Server(pktConnV4, pktConnV6, &mdns.Config{ Interfaces: ifcs, IncludeLoopback: includeLoopback, LoggerFactory: loggerFactory, }) return conn, mDNSMode, err case MulticastDNSModeQueryAndGather: conn, err := mdns.Server(pktConnV4, pktConnV6, &mdns.Config{ Interfaces: ifcs, IncludeLoopback: includeLoopback, LocalNames: []string{mDNSName}, LoggerFactory: loggerFactory, }) return conn, mDNSMode, err default: return nil, mDNSMode, nil } } ice-4.1.0/mdns_test.go000066400000000000000000000122251511704470200146060ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "context" "regexp" "testing" "time" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/require" ) func TestMulticastDNSOnlyConnection(t *testing.T) { defer test.CheckRoutines(t)() // Limit runtime in case of deadlocks defer test.TimeOut(time.Second * 30).Stop() type testCase struct { Name string NetworkTypes []NetworkType } testCases := []testCase{ {Name: "UDP4", NetworkTypes: []NetworkType{NetworkTypeUDP4}}, } if ipv6Available(t) { testCases = append(testCases, testCase{Name: "UDP6", NetworkTypes: []NetworkType{NetworkTypeUDP6}}, testCase{Name: "UDP46", NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}}, ) } for _, tc := range testCases { t.Run(tc.Name, 
func(t *testing.T) { cfg := &AgentConfig{ NetworkTypes: tc.NetworkTypes, CandidateTypes: []CandidateType{CandidateTypeHost}, MulticastDNSMode: MulticastDNSModeQueryAndGather, InterfaceFilter: problematicNetworkInterfaces, } aAgent, err := NewAgent(cfg) require.NoError(t, err) defer func() { require.NoError(t, aAgent.Close()) }() aNotifier, aConnected := onConnected() require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) bAgent, err := NewAgent(cfg) require.NoError(t, err) defer func() { require.NoError(t, bAgent.Close()) }() bNotifier, bConnected := onConnected() require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) connect(t, aAgent, bAgent) <-aConnected <-bConnected }) } } func TestMulticastDNSMixedConnection(t *testing.T) { defer test.CheckRoutines(t)() // Limit runtime in case of deadlocks defer test.TimeOut(time.Second * 30).Stop() type testCase struct { Name string NetworkTypes []NetworkType } testCases := []testCase{ {Name: "UDP4", NetworkTypes: []NetworkType{NetworkTypeUDP4}}, } if ipv6Available(t) { testCases = append(testCases, testCase{Name: "UDP6", NetworkTypes: []NetworkType{NetworkTypeUDP6}}, testCase{Name: "UDP46", NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}}, ) } for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { aAgent, err := NewAgent(&AgentConfig{ NetworkTypes: tc.NetworkTypes, CandidateTypes: []CandidateType{CandidateTypeHost}, MulticastDNSMode: MulticastDNSModeQueryAndGather, InterfaceFilter: problematicNetworkInterfaces, }) require.NoError(t, err) defer func() { require.NoError(t, aAgent.Close()) }() aNotifier, aConnected := onConnected() require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) bAgent, err := NewAgent(&AgentConfig{ NetworkTypes: tc.NetworkTypes, CandidateTypes: []CandidateType{CandidateTypeHost}, MulticastDNSMode: MulticastDNSModeQueryOnly, InterfaceFilter: problematicNetworkInterfaces, }) require.NoError(t, err) defer func() { require.NoError(t, bAgent.Close()) }() 
bNotifier, bConnected := onConnected() require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) connect(t, aAgent, bAgent) <-aConnected <-bConnected }) } } func TestMulticastDNSStaticHostName(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 30).Stop() type testCase struct { Name string NetworkTypes []NetworkType } testCases := []testCase{ {Name: "UDP4", NetworkTypes: []NetworkType{NetworkTypeUDP4}}, } if ipv6Available(t) { testCases = append(testCases, testCase{Name: "UDP6", NetworkTypes: []NetworkType{NetworkTypeUDP6}}, testCase{Name: "UDP46", NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}}, ) } for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { _, err := NewAgent(&AgentConfig{ NetworkTypes: tc.NetworkTypes, CandidateTypes: []CandidateType{CandidateTypeHost}, MulticastDNSMode: MulticastDNSModeQueryAndGather, MulticastDNSHostName: "invalidHostName", InterfaceFilter: problematicNetworkInterfaces, }) require.Equal(t, err, ErrInvalidMulticastDNSHostName) agent, err := NewAgent(&AgentConfig{ NetworkTypes: tc.NetworkTypes, CandidateTypes: []CandidateType{CandidateTypeHost}, MulticastDNSMode: MulticastDNSModeQueryAndGather, MulticastDNSHostName: "validName.local", InterfaceFilter: problematicNetworkInterfaces, }) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() correctHostName, resolveFunc := context.WithCancel(context.Background()) require.NoError(t, agent.OnCandidate(func(c Candidate) { if c != nil && c.Address() == "validName.local" { resolveFunc() } })) require.NoError(t, agent.GatherCandidates()) <-correctHostName.Done() }) } } func TestGenerateMulticastDNSName(t *testing.T) { name, err := generateMulticastDNSName() require.NoError(t, err) isMDNSName := regexp.MustCompile( `^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-4[0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}.local+$`, ).MatchString require.True(t, isMDNSName(name)) } 
ice-4.1.0/net.go000066400000000000000000000073361511704470200134030ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "net" "net/netip" "github.com/pion/logging" "github.com/pion/transport/v3" ) type ifaceAddr struct { addr netip.Addr iface string } // The conditions of invalidation written below are defined in // https://tools.ietf.org/html/rfc8445#section-5.1.1.1 // It is partial because the link-local check is done later in various gather local // candidate methods which conditionally accept IPv6 based on usage of mDNS or not. func isSupportedIPv6Partial(ip net.IP) bool { if len(ip) != net.IPv6len || // Deprecated IPv4-compatible IPv6 addresses [RFC4291] and IPv6 site- // local unicast addresses [RFC3879] MUST NOT be included in the // address candidates. isZeros(ip[0:12]) || // !(IPv4-compatible IPv6) ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 { // !(IPv6 site-local unicast) return false } return true } func isZeros(ip net.IP) bool { for i := 0; i < len(ip); i++ { if ip[i] != 0 { return false } } return true } //nolint:gocognit,cyclop func localInterfaces( n transport.Net, interfaceFilter func(string) (keep bool), ipFilter func(net.IP) (keep bool), networkTypes []NetworkType, includeLoopback bool, ) ([]*transport.Interface, []ifaceAddr, error) { ipAddrs := []ifaceAddr{} ifaces, err := n.Interfaces() if err != nil { return nil, ipAddrs, err } filteredIfaces := make([]*transport.Interface, 0, len(ifaces)) var ipV4Requested, ipv6Requested bool if len(networkTypes) == 0 { ipV4Requested = true ipv6Requested = true } else { for _, typ := range networkTypes { if typ.IsIPv4() { ipV4Requested = true } if typ.IsIPv6() { ipv6Requested = true } } } for _, iface := range ifaces { if iface.Flags&net.FlagUp == 0 { continue // Interface down } if (iface.Flags&net.FlagLoopback != 0) && !includeLoopback { continue // Loopback interface } if interfaceFilter != nil && !interfaceFilter(iface.Name) { continue 
} ifaceAddrs, err := iface.Addrs() if err != nil { continue } atLeastOneAddr := false for _, addr := range ifaceAddrs { ipAddr, _, _, err := parseAddrFromIface(addr, iface.Name) if err != nil || (ipAddr.IsLoopback() && !includeLoopback) { continue } if ipAddr.Is6() { if !ipv6Requested { continue } else if !isSupportedIPv6Partial(ipAddr.AsSlice()) { continue } } else if !ipV4Requested { continue } if ipFilter != nil && !ipFilter(ipAddr.AsSlice()) { continue } atLeastOneAddr = true ipAddrs = append(ipAddrs, ifaceAddr{addr: ipAddr, iface: iface.Name}) } if atLeastOneAddr { ifaceCopy := iface filteredIfaces = append(filteredIfaces, ifaceCopy) } } return filteredIfaces, ipAddrs, nil } //nolint:cyclop func listenUDPInPortRange( netTransport transport.Net, log logging.LeveledLogger, portMax, portMin int, network string, lAddr *net.UDPAddr, ) (transport.UDPConn, error) { if (lAddr.Port != 0) || ((portMin == 0) && (portMax == 0)) { return netTransport.ListenUDP(network, lAddr) } if portMin == 0 { portMin = 1024 // Start at 1024 which is non-privileged } if portMax == 0 { portMax = 0xFFFF } if portMin > portMax { return nil, ErrPort } portStart := globalMathRandomGenerator.Intn(portMax-portMin+1) + portMin portCurrent := portStart for { addr := &net.UDPAddr{ IP: lAddr.IP, Zone: lAddr.Zone, Port: portCurrent, } c, e := netTransport.ListenUDP(network, addr) if e == nil { return c, e //nolint:nilerr } log.Debugf("Failed to listen %s: %v", lAddr.String(), e) portCurrent++ if portCurrent > portMax { portCurrent = portMin } if portCurrent == portStart { break } } return nil, ErrPort } ice-4.1.0/net_test.go000066400000000000000000000145751511704470200144450ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "errors" "net" "net/netip" "sort" "strings" "testing" "github.com/pion/logging" "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" "github.com/stretchr/testify/require" ) func 
TestIsSupportedIPv6Partial(t *testing.T) { require.False(t, isSupportedIPv6Partial(net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1})) require.False(t, isSupportedIPv6Partial(net.ParseIP("fec0::2333"))) require.True(t, isSupportedIPv6Partial(net.ParseIP("fe80::2333"))) require.True(t, isSupportedIPv6Partial(net.ParseIP("ff02::2333"))) require.True(t, isSupportedIPv6Partial(net.ParseIP("2001::1"))) } func TestCreateAddr(t *testing.T) { ipv4 := mustAddr(t, net.IP{127, 0, 0, 1}) ipv6 := mustAddr(t, net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) port := 9000 require.Equal(t, &net.UDPAddr{IP: ipv4.AsSlice(), Port: port}, createAddr(NetworkTypeUDP4, ipv4, port)) require.Equal(t, &net.UDPAddr{IP: ipv6.AsSlice(), Port: port}, createAddr(NetworkTypeUDP6, ipv6, port)) require.Equal(t, &net.TCPAddr{IP: ipv4.AsSlice(), Port: port}, createAddr(NetworkTypeTCP4, ipv4, port)) require.Equal(t, &net.TCPAddr{IP: ipv6.AsSlice(), Port: port}, createAddr(NetworkTypeTCP6, ipv6, port)) } func problematicNetworkInterfaces(s string) (keep bool) { defaultDockerBridgeNetwork := strings.Contains(s, "docker") customDockerBridgeNetwork := strings.Contains(s, "br-") // Apple filters accessPoint := strings.Contains(s, "ap") appleWirelessDirectLink := strings.Contains(s, "awdl") appleLowLatencyWLANInterface := strings.Contains(s, "llw") appleTunnelingInterface := strings.Contains(s, "utun") return !defaultDockerBridgeNetwork && !customDockerBridgeNetwork && !accessPoint && !appleWirelessDirectLink && !appleLowLatencyWLANInterface && !appleTunnelingInterface } func mustAddr(t *testing.T, ip net.IP) netip.Addr { t.Helper() addr, ok := netip.AddrFromSlice(ip) if !ok { t.Fatal(ipConvertError{ip}) // nolint } return addr } type errInterfacesNet struct { transport.Net retErr error } func (e *errInterfacesNet) Interfaces() ([]*transport.Interface, error) { return nil, e.retErr } var errBoom = errors.New("boom") func TestLocalInterfaces_ErrorFromInterfaces(t *testing.T) { base, err := 
stdnet.NewNet() require.NoError(t, err) wrapped := &errInterfacesNet{ Net: base, retErr: errBoom, } ifaces, addrs, gotErr := localInterfaces( wrapped, nil, nil, nil, false, ) require.ErrorIs(t, gotErr, wrapped.retErr) require.Nil(t, ifaces, "expected nil iface slice on error") require.NotNil(t, addrs, "ipAddrs should be a non-nil empty slice") require.Len(t, addrs, 0) } type fixedInterfacesNet struct { transport.Net list []*transport.Interface } func (f *fixedInterfacesNet) Interfaces() ([]*transport.Interface, error) { return f.list, nil } func TestLocalInterfaces_SkipInterfaceDown(t *testing.T) { base, err := stdnet.NewNet() require.NoError(t, err) sysIfaces, err := base.Interfaces() require.NoError(t, err) if len(sysIfaces) == 0 { t.Skip("no system network interfaces available") } clone := *sysIfaces[0] clone.Flags &^= net.FlagUp wrapped := &fixedInterfacesNet{ Net: base, list: []*transport.Interface{&clone}, } ifcs, addrs, ierr := localInterfaces( wrapped, nil, nil, nil, false, ) require.NoError(t, ierr) require.Len(t, ifcs, 0, "down interfaces must be skipped") require.Len(t, addrs, 0, "no addresses should be collected from a down interface") } func TestLocalInterfaces_SkipLoopbackAddrs_WhenIncludeLoopbackFalse(t *testing.T) { base, err := stdnet.NewNet() require.NoError(t, err) sysIfaces, err := base.Interfaces() require.NoError(t, err) var loop *transport.Interface for _, ifc := range sysIfaces { if ifc.Flags&net.FlagLoopback != 0 { loop = ifc break } } if loop == nil { t.Skip("no loopback interface found on this system") } // clone the loopback iface and clear the Loopback flag so the outer check // doesn't drop it to force the inner `(ipAddr.IsLoopback() && !includeLoopback)`. 
cloned := *loop cloned.Flags |= net.FlagUp cloned.Flags &^= net.FlagLoopback wrapped := &fixedInterfacesNet{ Net: base, list: []*transport.Interface{&cloned}, } ifaces, addrs, ierr := localInterfaces( wrapped, nil, // interfaceFilter nil, // ipFilter nil, // networkTypes false, // includeLoopback ) require.NoError(t, ierr) // don't assert on the number of interfaces because some systems may // report the iface as having addresses in a way that causes it to be included. // assert that all loopback addresses were skipped. for _, a := range addrs { require.False(t, a.addr.IsLoopback(), "loopback addresses must be skipped when includeLoopback=false") } _ = ifaces // intentionally don't assert on this, see above comment } // Captures ListenUDP attempts and always fails so the loop exhausts. type listenUDPCaptor struct { transport.Net attempts []int } func (c *listenUDPCaptor) ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { c.attempts = append(c.attempts, laddr.Port) return nil, errBoom } func TestListenUDPInPortRange_DefaultsPortMinTo1024(t *testing.T) { base, err := stdnet.NewNet() require.NoError(t, err) captor := &listenUDPCaptor{Net: base} logger := logging.NewDefaultLoggerFactory().NewLogger("ice-test") // portMin == 0 (should become 1024), portMax small to keep the loop short. _, err = listenUDPInPortRange( captor, logger, 1030, // portMax 0, // portMin -> becomes 1024 udp4, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}, ) require.ErrorIs(t, err, ErrPort) // should have attempted exactly [1024..1030] in some order. sort.Ints(captor.attempts) require.Equal(t, []int{1024, 1025, 1026, 1027, 1028, 1029, 1030}, captor.attempts) } func TestListenUDPInPortRange_DefaultsPortMaxToFFFF(t *testing.T) { base, err := stdnet.NewNet() require.NoError(t, err) captor := &listenUDPCaptor{Net: base} logger := logging.NewDefaultLoggerFactory().NewLogger("ice-test") // portMax == 0 (should become 0xFFFF). Use portMin=65535 so the range is 1 port. 
_, err = listenUDPInPortRange( captor, logger, 0, // portMax -> becomes 65535 65535, // portMin udp4, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}, ) require.ErrorIs(t, err, ErrPort) require.Equal(t, []int{65535}, captor.attempts) } ice-4.1.0/networktype.go000066400000000000000000000055761511704470200152140ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "fmt" "net/netip" "strings" ) const ( udp = "udp" tcp = "tcp" udp4 = "udp4" udp6 = "udp6" tcp4 = "tcp4" tcp6 = "tcp6" ) func supportedNetworkTypes() []NetworkType { return []NetworkType{ NetworkTypeUDP4, NetworkTypeUDP6, NetworkTypeTCP4, NetworkTypeTCP6, } } // NetworkType represents the type of network. type NetworkType int const ( // NetworkTypeUDP4 indicates UDP over IPv4. NetworkTypeUDP4 NetworkType = iota + 1 // NetworkTypeUDP6 indicates UDP over IPv6. NetworkTypeUDP6 // NetworkTypeTCP4 indicates TCP over IPv4. NetworkTypeTCP4 // NetworkTypeTCP6 indicates TCP over IPv6. NetworkTypeTCP6 ) func (t NetworkType) String() string { switch t { case NetworkTypeUDP4: return udp4 case NetworkTypeUDP6: return udp6 case NetworkTypeTCP4: return tcp4 case NetworkTypeTCP6: return tcp6 default: return ErrUnknownType.Error() } } // IsUDP returns true when network is UDP4 or UDP6. func (t NetworkType) IsUDP() bool { return t == NetworkTypeUDP4 || t == NetworkTypeUDP6 } // IsTCP returns true when network is TCP4 or TCP6. func (t NetworkType) IsTCP() bool { return t == NetworkTypeTCP4 || t == NetworkTypeTCP6 } // NetworkShort returns the short network description. func (t NetworkType) NetworkShort() string { switch t { case NetworkTypeUDP4, NetworkTypeUDP6: return udp case NetworkTypeTCP4, NetworkTypeTCP6: return tcp default: return ErrUnknownType.Error() } } // IsReliable returns true if the network is reliable. 
func (t NetworkType) IsReliable() bool { switch t { case NetworkTypeUDP4, NetworkTypeUDP6: return false case NetworkTypeTCP4, NetworkTypeTCP6: return true } return false } // IsIPv4 returns whether the network type is IPv4 or not. func (t NetworkType) IsIPv4() bool { switch t { case NetworkTypeUDP4, NetworkTypeTCP4: return true case NetworkTypeUDP6, NetworkTypeTCP6: return false } return false } // IsIPv6 returns whether the network type is IPv6 or not. func (t NetworkType) IsIPv6() bool { switch t { case NetworkTypeUDP4, NetworkTypeTCP4: return false case NetworkTypeUDP6, NetworkTypeTCP6: return true } return false } // determineNetworkType determines the type of network based on // the short network string and an IP address. func determineNetworkType(network string, ip netip.Addr) (NetworkType, error) { // we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable. ip = ip.Unmap() switch { case strings.HasPrefix(strings.ToLower(network), udp): if ip.Is4() { return NetworkTypeUDP4, nil } return NetworkTypeUDP6, nil case strings.HasPrefix(strings.ToLower(network), tcp): if ip.Is4() { return NetworkTypeTCP4, nil } return NetworkTypeTCP6, nil } return NetworkType(0), fmt.Errorf("%w from %s %s", ErrDetermineNetworkType, network, ip) } ice-4.1.0/networktype_test.go000066400000000000000000000063551511704470200162470ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "net" "testing" "github.com/stretchr/testify/require" ) func TestNetworkTypeParsing_Success(t *testing.T) { ipv4 := net.ParseIP("192.168.0.1") ipv6 := net.ParseIP("fe80::a3:6ff:fec4:5454") for _, test := range []struct { name string inNetwork string inIP net.IP expected NetworkType }{ { "lowercase UDP4", "udp", ipv4, NetworkTypeUDP4, }, { "uppercase UDP4", "UDP", ipv4, NetworkTypeUDP4, }, { "lowercase UDP6", "udp", ipv6, NetworkTypeUDP6, }, { "uppercase UDP6", "UDP", ipv6, NetworkTypeUDP6, }, } { actual, err := 
determineNetworkType(test.inNetwork, mustAddr(t, test.inIP)) require.NoError(t, err) require.Equal(t, test.expected, actual) } } func TestNetworkTypeParsing_Failure(t *testing.T) { ipv6 := net.ParseIP("fe80::a3:6ff:fec4:5454") for _, test := range []struct { name string inNetwork string inIP net.IP }{ { "invalid network", "junkNetwork", ipv6, }, } { _, err := determineNetworkType(test.inNetwork, mustAddr(t, test.inIP)) require.Error(t, err) } } func TestNetworkTypeIsUDP(t *testing.T) { require.True(t, NetworkTypeUDP4.IsUDP()) require.True(t, NetworkTypeUDP6.IsUDP()) require.False(t, NetworkTypeUDP4.IsTCP()) require.False(t, NetworkTypeUDP6.IsTCP()) } func TestNetworkTypeIsTCP(t *testing.T) { require.True(t, NetworkTypeTCP4.IsTCP()) require.True(t, NetworkTypeTCP6.IsTCP()) require.False(t, NetworkTypeTCP4.IsUDP()) require.False(t, NetworkTypeTCP6.IsUDP()) } func TestNetworkType_String_Default(t *testing.T) { var invalid NetworkType // 0 triggers default branch require.Equal(t, ErrUnknownType.Error(), invalid.String()) require.Equal(t, "udp4", NetworkTypeUDP4.String()) require.Equal(t, "udp6", NetworkTypeUDP6.String()) require.Equal(t, "tcp4", NetworkTypeTCP4.String()) require.Equal(t, "tcp6", NetworkTypeTCP6.String()) } func TestNetworkType_NetworkShort_Default(t *testing.T) { var invalid NetworkType require.Equal(t, ErrUnknownType.Error(), invalid.NetworkShort()) require.Equal(t, udp, NetworkTypeUDP4.NetworkShort()) require.Equal(t, udp, NetworkTypeUDP6.NetworkShort()) require.Equal(t, tcp, NetworkTypeTCP4.NetworkShort()) require.Equal(t, tcp, NetworkTypeTCP6.NetworkShort()) } func TestNetworkType_IPvFlags_Default(t *testing.T) { var invalid NetworkType require.False(t, invalid.IsIPv4()) require.False(t, invalid.IsIPv6()) require.True(t, NetworkTypeUDP4.IsIPv4()) require.True(t, NetworkTypeTCP4.IsIPv4()) require.False(t, NetworkTypeUDP6.IsIPv4()) require.False(t, NetworkTypeTCP6.IsIPv4()) require.True(t, NetworkTypeUDP6.IsIPv6()) require.True(t, 
NetworkTypeTCP6.IsIPv6()) require.False(t, NetworkTypeUDP4.IsIPv6()) require.False(t, NetworkTypeTCP4.IsIPv6()) } func TestNetworkType_IsReliable(t *testing.T) { // UDP is unreliable require.False(t, NetworkTypeUDP4.IsReliable()) require.False(t, NetworkTypeUDP6.IsReliable()) // TCP is reliable require.True(t, NetworkTypeTCP4.IsReliable()) require.True(t, NetworkTypeTCP6.IsReliable()) // default/unknown falls through to false var invalid NetworkType require.False(t, invalid.IsReliable()) } ice-4.1.0/priority.go000066400000000000000000000015021511704470200144630ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "encoding/binary" "github.com/pion/stun/v3" ) // PriorityAttr represents PRIORITY attribute. type PriorityAttr uint32 const prioritySize = 4 // 32 bit // AddTo adds PRIORITY attribute to message. func (p PriorityAttr) AddTo(m *stun.Message) error { v := make([]byte, prioritySize) binary.BigEndian.PutUint32(v, uint32(p)) m.Add(stun.AttrPriority, v) return nil } // GetFrom decodes PRIORITY attribute from message. 
func (p *PriorityAttr) GetFrom(m *stun.Message) error { v, err := m.Get(stun.AttrPriority) if err != nil { return err } if err = stun.CheckSize(stun.AttrPriority, len(v), prioritySize); err != nil { return err } *p = PriorityAttr(binary.BigEndian.Uint32(v)) return nil } ice-4.1.0/priority_test.go000066400000000000000000000014551511704470200155310ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "testing" "github.com/pion/stun/v3" "github.com/stretchr/testify/require" ) func TestPriority_GetFrom(t *testing.T) { //nolint:dupl m := new(stun.Message) var priority PriorityAttr require.ErrorIs(t, stun.ErrAttributeNotFound, priority.GetFrom(m)) require.NoError(t, m.Build(stun.BindingRequest, &priority)) m1 := new(stun.Message) _, err := m1.Write(m.Raw) require.NoError(t, err) var p1 PriorityAttr require.NoError(t, p1.GetFrom(m1)) require.Equal(t, p1, priority) t.Run("IncorrectSize", func(t *testing.T) { m3 := new(stun.Message) m3.Add(stun.AttrPriority, make([]byte, 100)) var p2 PriorityAttr require.True(t, stun.IsAttrSizeInvalid(p2.GetFrom(m3))) }) } ice-4.1.0/rand.go000066400000000000000000000035431511704470200135350ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import "github.com/pion/randutil" const ( runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" runesDigit = "0123456789" runesCandidateIDFoundation = runesAlpha + runesDigit + "+/" lenUFrag = 16 lenPwd = 32 ) // Seeding random generator each time limits number of generated sequence to 31-bits, // and causes collision on low time accuracy environments. // Use global random generator seeded by crypto grade random. 
var ( globalMathRandomGenerator = randutil.NewMathRandomGenerator() //nolint:gochecknoglobals globalCandidateIDGenerator = candidateIDGenerator{globalMathRandomGenerator} //nolint:gochecknoglobals ) // candidateIDGenerator is a random candidate ID generator. // Candidate ID is used in SDP and always shared to the other peer. // It doesn't require cryptographic random. type candidateIDGenerator struct { randutil.MathRandomGenerator } func newCandidateIDGenerator() *candidateIDGenerator { return &candidateIDGenerator{ randutil.NewMathRandomGenerator(), } } func (g *candidateIDGenerator) Generate() string { // https://tools.ietf.org/html/rfc5245#section-15.1 // candidate-id = "candidate" ":" foundation // foundation = 1*32ice-char // ice-char = ALPHA / DIGIT / "+" / "/" return "candidate:" + g.MathRandomGenerator.GenerateString(32, runesCandidateIDFoundation) } // generatePwd generates ICE pwd. // This internally uses generateCryptoRandomString. func generatePwd() (string, error) { return randutil.GenerateCryptoRandomString(lenPwd, runesAlpha) } // generateUFrag generates ICE user fragment. // This internally uses generateCryptoRandomString. 
func generateUFrag() (string, error) { return randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha) } ice-4.1.0/rand_test.go000066400000000000000000000025611511704470200145730ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "sync" "testing" "github.com/stretchr/testify/require" ) func TestRandomGeneratorCollision(t *testing.T) { candidateIDGen := newCandidateIDGenerator() testCases := map[string]struct { gen func(t *testing.T) string }{ "CandidateID": { gen: func(*testing.T) string { return candidateIDGen.Generate() }, }, "PWD": { gen: func(t *testing.T) string { t.Helper() s, err := generatePwd() require.NoError(t, err) return s }, }, "Ufrag": { gen: func(t *testing.T) string { t.Helper() s, err := generateUFrag() require.NoError(t, err) return s }, }, } const num = 100 const iteration = 100 for name, testCase := range testCases { testCase := testCase t.Run(name, func(t *testing.T) { for iter := 0; iter < iteration; iter++ { var wg sync.WaitGroup var mu sync.Mutex rands := make([]string, 0, num) for i := 0; i < num; i++ { wg.Add(1) go func() { r := testCase.gen(t) mu.Lock() rands = append(rands, r) mu.Unlock() wg.Done() }() } wg.Wait() require.Len(t, rands, num) for i := 0; i < num; i++ { for j := i + 1; j < num; j++ { require.NotEqual(t, rands[i], rands[j]) } } } }) } } ice-4.1.0/renomination.go000066400000000000000000000050551511704470200153130ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT package ice import ( "fmt" "github.com/pion/stun/v3" ) // Default STUN Nomination attribute type for ICE renomination. // Following the specification draft-thatcher-ice-renomination-01. const ( // DefaultNominationAttribute represents the default STUN Nomination attribute. // This is a custom attribute for ICE renomination support. // This value can be overridden via AgentConfig.NominationAttribute. 
DefaultNominationAttribute stun.AttrType = 0x0030 // Using a value in the reserved range ) // NominationAttribute represents a STUN Nomination attribute. type NominationAttribute struct { Value uint32 } // GetFrom decodes a Nomination attribute from a STUN message. func (a *NominationAttribute) GetFrom(m *stun.Message) error { return a.GetFromWithType(m, DefaultNominationAttribute) } // GetFromWithType decodes a Nomination attribute from a STUN message using a specific attribute type. func (a *NominationAttribute) GetFromWithType(m *stun.Message, attrType stun.AttrType) error { v, err := m.Get(attrType) if err != nil { return err } if len(v) < 4 { return stun.ErrAttributeSizeInvalid } // Extract 24-bit value from the last 3 bytes a.Value = uint32(v[1])<<16 | uint32(v[2])<<8 | uint32(v[3]) return nil } // AddTo adds a Nomination attribute to a STUN message. func (a NominationAttribute) AddTo(m *stun.Message) error { return a.AddToWithType(m, DefaultNominationAttribute) } // AddToWithType adds a Nomination attribute to a STUN message using a specific attribute type. func (a NominationAttribute) AddToWithType(m *stun.Message, attrType stun.AttrType) error { // Store as 4 bytes with first byte as 0 v := make([]byte, 4) v[1] = byte(a.Value >> 16) v[2] = byte(a.Value >> 8) v[3] = byte(a.Value) m.Add(attrType, v) return nil } // String returns string representation of the nomination attribute. func (a NominationAttribute) String() string { return fmt.Sprintf("NOMINATION: %d", a.Value) } // Nomination creates a new STUN nomination attribute. func Nomination(value uint32) NominationAttribute { return NominationAttribute{Value: value} } // NominationSetter is a STUN setter for nomination attribute with configurable type. type NominationSetter struct { Value uint32 AttrType stun.AttrType } // AddTo adds a Nomination attribute to a STUN message using the configured attribute type. 
func (n NominationSetter) AddTo(m *stun.Message) error { attr := NominationAttribute{Value: n.Value} return attr.AddToWithType(m, n.AttrType) } ice-4.1.0/renomination_test.go000066400000000000000000000612241511704470200163520ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2025 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "net" "testing" "time" "github.com/pion/ice/v4/internal/fakenet" "github.com/pion/stun/v3" "github.com/stretchr/testify/assert" ) const ( testLocalUfrag = "localufrag" testLocalPwd = "localpwd" testRemoteUfrag = "remoteufrag" testRemotePwd = "remotepwd" ) // Mock packet conn that captures sent packets. type mockPacketConnWithCapture struct { sentPackets [][]byte sentAddrs []net.Addr } func (m *mockPacketConnWithCapture) ReadFrom([]byte) (n int, addr net.Addr, err error) { return 0, nil, nil } func (m *mockPacketConnWithCapture) WriteTo(b []byte, addr net.Addr) (n int, err error) { // Capture the packet packet := make([]byte, len(b)) copy(packet, b) m.sentPackets = append(m.sentPackets, packet) m.sentAddrs = append(m.sentAddrs, addr) return len(b), nil } func (m *mockPacketConnWithCapture) Close() error { return nil } func (m *mockPacketConnWithCapture) LocalAddr() net.Addr { return nil } func (m *mockPacketConnWithCapture) SetDeadline(time.Time) error { return nil } func (m *mockPacketConnWithCapture) SetReadDeadline(time.Time) error { return nil } func (m *mockPacketConnWithCapture) SetWriteDeadline(time.Time) error { return nil } // createRenominationTestAgent creates a test agent with renomination enabled and returns local/remote candidates. 
func createRenominationTestAgent(t *testing.T, controlling bool) (*Agent, Candidate, Candidate) { t.Helper() agent, err := NewAgentWithOptions(WithRenomination(func() uint32 { return 1 })) assert.NoError(t, err) agent.isControlling.Store(controlling) local, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "127.0.0.1", Port: 12345, Component: 1, }) assert.NoError(t, err) remote, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "127.0.0.1", Port: 54321, Component: 1, }) assert.NoError(t, err) return agent, local, remote } func TestNominationAttribute(t *testing.T) { t.Run("AddTo and GetFrom", func(t *testing.T) { m := &stun.Message{} attr := NominationAttribute{Value: 0x123456} err := attr.AddTo(m) assert.NoError(t, err) var parsed NominationAttribute err = parsed.GetFrom(m) assert.NoError(t, err) assert.Equal(t, uint32(0x123456), parsed.Value) }) t.Run("24-bit value boundary", func(t *testing.T) { m := &stun.Message{} maxValue := uint32((1 << 24) - 1) // 24-bit max value attr := NominationAttribute{Value: maxValue} err := attr.AddTo(m) assert.NoError(t, err) var parsed NominationAttribute err = parsed.GetFrom(m) assert.NoError(t, err) assert.Equal(t, maxValue, parsed.Value) }) t.Run("String representation", func(t *testing.T) { attr := NominationAttribute{Value: 12345} str := attr.String() assert.Contains(t, str, "NOMINATION") assert.Contains(t, str, "12345") }) t.Run("Nomination helper function", func(t *testing.T) { attr := Nomination(42) assert.Equal(t, uint32(42), attr.Value) }) } func TestRenominationConfiguration(t *testing.T) { nominationCounter := uint32(0) agent, err := NewAgentWithOptions(WithRenomination(func() uint32 { nominationCounter++ return nominationCounter })) assert.NoError(t, err) defer func() { assert.NoError(t, agent.Close()) }() assert.True(t, agent.enableRenomination) assert.NotNil(t, agent.nominationValueGenerator) // Test nomination value generation value1 := agent.nominationValueGenerator() 
value2 := agent.nominationValueGenerator() assert.Equal(t, uint32(1), value1) assert.Equal(t, uint32(2), value2) } func TestControlledSelectorNominationAcceptance(t *testing.T) { agent, err := NewAgentWithOptions(WithRenomination(DefaultNominationValueGenerator())) assert.NoError(t, err) defer func() { assert.NoError(t, agent.Close()) }() selector := &controlledSelector{ agent: agent, log: agent.log, } selector.Start() // First nomination should be accepted nomination1 := uint32(5) assert.True(t, selector.shouldAcceptNomination(&nomination1)) // Higher nomination should be accepted nomination2 := uint32(10) assert.True(t, selector.shouldAcceptNomination(&nomination2)) // Lower nomination should be rejected nomination3 := uint32(7) assert.False(t, selector.shouldAcceptNomination(&nomination3)) // Equal nomination should be rejected nomination4 := uint32(10) assert.False(t, selector.shouldAcceptNomination(&nomination4)) // Nil nomination should be accepted (standard ICE) assert.True(t, selector.shouldAcceptNomination(nil)) } func TestControlledSelectorNominationDisabled(t *testing.T) { config := &AgentConfig{ // Renomination disabled by default } agent, err := NewAgent(config) assert.NoError(t, err) defer func() { assert.NoError(t, agent.Close()) }() selector := &controlledSelector{ agent: agent, log: agent.log, } selector.Start() // Standard ICE nomination (no value) should be accepted assert.True(t, selector.shouldAcceptNomination(nil)) // When controlling side uses renomination (sends nomination values), // controlled side should apply "last nomination wins" regardless of local config nomination1 := uint32(5) assert.True(t, selector.shouldAcceptNomination(&nomination1)) nomination2 := uint32(3) // Lower value should be rejected assert.False(t, selector.shouldAcceptNomination(&nomination2)) nomination3 := uint32(8) // Higher value should be accepted assert.True(t, selector.shouldAcceptNomination(&nomination3)) } func TestAgentRenominateCandidate(t *testing.T) { 
t.Run("controlling agent can renominate", func(t *testing.T) { nominationCounter := uint32(0) agent, err := NewAgentWithOptions(WithRenomination(func() uint32 { nominationCounter++ return nominationCounter })) assert.NoError(t, err) defer func() { assert.NoError(t, agent.Close()) }() // Set up credentials for STUN authentication agent.localUfrag = testLocalUfrag agent.localPwd = testLocalPwd agent.remoteUfrag = testRemoteUfrag agent.remotePwd = testRemotePwd // Set agent as controlling agent.isControlling.Store(true) // Create test candidates with mock connection local, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "127.0.0.1", Port: 12345, Component: 1, }) assert.NoError(t, err) remote, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "127.0.0.1", Port: 54321, Component: 1, }) assert.NoError(t, err) // Mock the connection for the local candidate to avoid nil pointer mockConn := &fakenet.MockPacketConn{} local.conn = mockConn // Add pair to agent pair := agent.addPair(local, remote) pair.state = CandidatePairStateSucceeded // Test renomination err = agent.RenominateCandidate(local, remote) assert.NoError(t, err) }) t.Run("non-controlling agent cannot renominate", func(t *testing.T) { agent, local, remote := createRenominationTestAgent(t, false) defer func() { assert.NoError(t, agent.Close()) }() err := agent.RenominateCandidate(local, remote) assert.Error(t, err) assert.Contains(t, err.Error(), "only controlling agent can renominate") }) t.Run("renomination when disabled", func(t *testing.T) { config := &AgentConfig{ // Renomination disabled by default } agent, err := NewAgent(config) assert.NoError(t, err) defer func() { assert.NoError(t, agent.Close()) }() agent.isControlling.Store(true) local, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "127.0.0.1", Port: 12345, Component: 1, }) assert.NoError(t, err) remote, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: 
"127.0.0.1", Port: 54321, Component: 1, }) assert.NoError(t, err) err = agent.RenominateCandidate(local, remote) assert.Error(t, err) assert.Contains(t, err.Error(), "renomination is not enabled") }) t.Run("renomination with non-existent candidate pair", func(t *testing.T) { agent, local, remote := createRenominationTestAgent(t, true) defer func() { assert.NoError(t, agent.Close()) }() // Don't add pair to agent - should fail err := agent.RenominateCandidate(local, remote) assert.Error(t, err) assert.Contains(t, err.Error(), "candidate pair not found") }) } func TestSendNominationRequest(t *testing.T) { t.Run("STUN message contains nomination attribute", func(t *testing.T) { nominationCounter := uint32(0) agent, err := NewAgentWithOptions(WithRenomination(func() uint32 { nominationCounter++ return nominationCounter })) assert.NoError(t, err) defer func() { assert.NoError(t, agent.Close()) }() // Set up credentials for STUN authentication agent.localUfrag = testLocalUfrag agent.localPwd = testLocalPwd agent.remoteUfrag = testRemoteUfrag agent.remotePwd = testRemotePwd agent.isControlling.Store(true) // Create test candidates local, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "127.0.0.1", Port: 12345, Component: 1, }) assert.NoError(t, err) remote, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "127.0.0.1", Port: 54321, Component: 1, }) assert.NoError(t, err) // Mock connection to capture sent messages mockConn := &mockPacketConnWithCapture{} local.conn = mockConn pair := agent.addPair(local, remote) pair.state = CandidatePairStateSucceeded // Test sendNominationRequest directly nominationValue := uint32(123) err = agent.sendNominationRequest(pair, nominationValue) assert.NoError(t, err) // Verify message was sent assert.True(t, len(mockConn.sentPackets) > 0) // Parse the sent STUN message msg := &stun.Message{} err = msg.UnmarshalBinary(mockConn.sentPackets[0]) assert.NoError(t, err) // Verify it's a binding request 
assert.True(t, msg.Type.Method == stun.MethodBinding) assert.True(t, msg.Type.Class == stun.ClassRequest) // Verify USE-CANDIDATE is present assert.True(t, msg.Contains(stun.AttrUseCandidate)) // Verify nomination attribute is present var nomination NominationAttribute err = nomination.GetFrom(msg) assert.NoError(t, err) assert.Equal(t, nominationValue, nomination.Value) }) t.Run("STUN message without nomination when disabled", func(t *testing.T) { config := &AgentConfig{ // Renomination disabled by default } agent, err := NewAgent(config) assert.NoError(t, err) defer func() { assert.NoError(t, agent.Close()) }() // Set up credentials agent.localUfrag = testLocalUfrag agent.localPwd = testLocalPwd agent.remoteUfrag = testRemoteUfrag agent.remotePwd = testRemotePwd agent.isControlling.Store(true) local, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "127.0.0.1", Port: 12345, Component: 1, }) assert.NoError(t, err) remote, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "127.0.0.1", Port: 54321, Component: 1, }) assert.NoError(t, err) mockConn := &mockPacketConnWithCapture{} local.conn = mockConn pair := agent.addPair(local, remote) // Send nomination with value 0 (should not include nomination attribute) err = agent.sendNominationRequest(pair, 0) assert.NoError(t, err) // Parse the sent message msg := &stun.Message{} err = msg.UnmarshalBinary(mockConn.sentPackets[0]) assert.NoError(t, err) // Verify USE-CANDIDATE is present assert.True(t, msg.Contains(stun.AttrUseCandidate)) // Verify nomination attribute is NOT present var nomination NominationAttribute err = nomination.GetFrom(msg) assert.Error(t, err) // Should fail since attribute is not present }) } func TestRenominationErrorCases(t *testing.T) { t.Run("getNominationValue with nil generator", func(t *testing.T) { // Try to create agent with nil generator - should fail _, err := NewAgentWithOptions(WithRenomination(nil)) assert.ErrorIs(t, err, 
ErrInvalidNominationValueGenerator) // Create agent without renomination for testing agent, err := NewAgentWithOptions() assert.NoError(t, err) defer func() { assert.NoError(t, agent.Close()) }() // Should return 0 when no generator is set value := agent.getNominationValue() assert.Equal(t, uint32(0), value) }) t.Run("STUN message build with invalid attributes", func(t *testing.T) { agent, err := NewAgentWithOptions(WithRenomination(func() uint32 { return 1 })) assert.NoError(t, err) defer func() { assert.NoError(t, agent.Close()) }() // Set up minimal credentials but missing remote password agent.localUfrag = "localufrag" agent.localPwd = "localpwd" agent.remoteUfrag = "remoteufrag" // agent.remotePwd = "" // Missing remote password agent.isControlling.Store(true) local, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "127.0.0.1", Port: 12345, Component: 1, }) assert.NoError(t, err) remote, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "127.0.0.1", Port: 54321, Component: 1, }) assert.NoError(t, err) mockConn := &mockPacketConnWithCapture{} local.conn = mockConn pair := agent.addPair(local, remote) // This should succeed even with missing remote password // as the STUN library will still build the message err = agent.sendNominationRequest(pair, 1) assert.NoError(t, err) }) } func TestNominationValueBoundaries(t *testing.T) { t.Run("24-bit maximum value", func(t *testing.T) { maxValue := uint32((1 << 24) - 1) // 0xFFFFFF attr := NominationAttribute{Value: maxValue} m := &stun.Message{} err := attr.AddTo(m) assert.NoError(t, err) var parsed NominationAttribute err = parsed.GetFrom(m) assert.NoError(t, err) assert.Equal(t, maxValue, parsed.Value) }) t.Run("zero nomination value", func(t *testing.T) { attr := NominationAttribute{Value: 0} m := &stun.Message{} err := attr.AddTo(m) assert.NoError(t, err) var parsed NominationAttribute err = parsed.GetFrom(m) assert.NoError(t, err) assert.Equal(t, uint32(0), parsed.Value) 
}) t.Run("nomination value overflow", func(t *testing.T) { // Test value larger than 24-bit overflowValue := uint32(1 << 25) // Larger than 24-bit max attr := NominationAttribute{Value: overflowValue} m := &stun.Message{} err := attr.AddTo(m) assert.NoError(t, err) var parsed NominationAttribute err = parsed.GetFrom(m) assert.NoError(t, err) // Should be truncated to 24-bit value expectedValue := overflowValue & 0xFFFFFF assert.Equal(t, expectedValue, parsed.Value) }) t.Run("invalid attribute size", func(t *testing.T) { m := &stun.Message{} // Add a nomination attribute with invalid size (too short) m.Add(DefaultNominationAttribute, []byte{0x01, 0x02}) // Only 2 bytes instead of 4 var nomination NominationAttribute err := nomination.GetFrom(m) assert.Error(t, err) assert.Equal(t, stun.ErrAttributeSizeInvalid, err) }) t.Run("configurable nomination attribute type", func(t *testing.T) { // Test with custom attribute type customAttrType := stun.AttrType(0x0040) attr := NominationAttribute{Value: 12345} m := &stun.Message{} err := attr.AddToWithType(m, customAttrType) assert.NoError(t, err) // Try to read with default type - should fail var parsed1 NominationAttribute err = parsed1.GetFrom(m) assert.Error(t, err) // Read with custom type - should succeed var parsed2 NominationAttribute err = parsed2.GetFromWithType(m, customAttrType) assert.NoError(t, err) assert.Equal(t, uint32(12345), parsed2.Value) }) t.Run("NominationSetter with custom attribute type", func(t *testing.T) { customAttrType := stun.AttrType(0x0050) setter := NominationSetter{ Value: 98765, AttrType: customAttrType, } m := &stun.Message{} err := setter.AddTo(m) assert.NoError(t, err) // Verify the attribute was added with custom type var parsed NominationAttribute err = parsed.GetFromWithType(m, customAttrType) assert.NoError(t, err) assert.Equal(t, uint32(98765), parsed.Value) // Verify it wasn't added with default type var parsedDefault NominationAttribute err = parsedDefault.GetFrom(m) 
assert.Error(t, err) }) } func TestControlledSelectorWithActualSTUNMessages(t *testing.T) { t.Run("HandleBindingRequest with nomination attribute", func(t *testing.T) { agent, err := NewAgentWithOptions(WithRenomination(DefaultNominationValueGenerator())) assert.NoError(t, err) defer func() { assert.NoError(t, agent.Close()) }() // Set up credentials for STUN agent.localUfrag = testLocalUfrag agent.localPwd = testLocalPwd agent.remoteUfrag = testRemoteUfrag agent.remotePwd = testRemotePwd selector := &controlledSelector{ agent: agent, log: agent.log, } selector.Start() // Create test candidates local, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "127.0.0.1", Port: 12345, Component: 1, }) assert.NoError(t, err) remote, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "127.0.0.1", Port: 54321, Component: 1, }) assert.NoError(t, err) // Mock connection for response mockConn := &mockPacketConnWithCapture{} local.conn = mockConn // Create STUN binding request with nomination and USE-CANDIDATE msg, err := stun.Build( stun.BindingRequest, stun.TransactionID, stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag), UseCandidate(), Nomination(5), // First nomination value stun.NewShortTermIntegrity(agent.localPwd), stun.Fingerprint, ) assert.NoError(t, err) // Handle the binding request selector.HandleBindingRequest(msg, local, remote) // Verify selector accepted the nomination assert.NotNil(t, selector.lastNomination) assert.Equal(t, uint32(5), *selector.lastNomination) // Create another STUN request with higher nomination value msg2, err := stun.Build( stun.BindingRequest, stun.TransactionID, stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag), UseCandidate(), Nomination(10), // Higher nomination value stun.NewShortTermIntegrity(agent.localPwd), stun.Fingerprint, ) assert.NoError(t, err) // Handle the second binding request selector.HandleBindingRequest(msg2, local, remote) // Should accept higher nomination 
assert.Equal(t, uint32(10), *selector.lastNomination) // Create another STUN request with lower nomination value msg3, err := stun.Build( stun.BindingRequest, stun.TransactionID, stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag), UseCandidate(), Nomination(7), // Lower nomination value stun.NewShortTermIntegrity(agent.localPwd), stun.Fingerprint, ) assert.NoError(t, err) // Handle the third binding request selector.HandleBindingRequest(msg3, local, remote) // Should reject lower nomination (lastNomination should remain 10) assert.Equal(t, uint32(10), *selector.lastNomination) }) t.Run("HandleBindingRequest without nomination attribute", func(t *testing.T) { agent, err := NewAgentWithOptions(WithRenomination(DefaultNominationValueGenerator())) assert.NoError(t, err) defer func() { assert.NoError(t, agent.Close()) }() agent.localUfrag = testLocalUfrag agent.localPwd = testLocalPwd agent.remoteUfrag = testRemoteUfrag agent.remotePwd = testRemotePwd selector := &controlledSelector{ agent: agent, log: agent.log, } selector.Start() local, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "127.0.0.1", Port: 12345, Component: 1, }) assert.NoError(t, err) remote, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "127.0.0.1", Port: 54321, Component: 1, }) assert.NoError(t, err) mockConn := &mockPacketConnWithCapture{} local.conn = mockConn // Create STUN binding request without nomination (standard ICE) msg, err := stun.Build( stun.BindingRequest, stun.TransactionID, stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag), UseCandidate(), // No nomination attribute stun.NewShortTermIntegrity(agent.localPwd), stun.Fingerprint, ) assert.NoError(t, err) // Handle the binding request selector.HandleBindingRequest(msg, local, remote) // Without nomination attribute, lastNomination should remain nil assert.Nil(t, selector.lastNomination) }) } func TestInvalidRenominationConfig(t *testing.T) { t.Run("nil nomination generator with 
renomination enabled", func(t *testing.T) { config := &AgentConfig{} // Without renomination, agent should work fine agent, err := NewAgent(config) assert.NoError(t, err) defer func() { assert.NoError(t, agent.Close()) }() // Agent should be created successfully without renomination assert.False(t, agent.enableRenomination) assert.Nil(t, agent.nominationValueGenerator) // getNominationValue should return 0 value := agent.getNominationValue() assert.Equal(t, uint32(0), value) }) t.Run("different generator behaviors", func(t *testing.T) { // Test constant generator agent1, err := NewAgentWithOptions(WithRenomination(func() uint32 { return 42 })) assert.NoError(t, err) defer func() { assert.NoError(t, agent1.Close()) }() value1 := agent1.getNominationValue() value2 := agent1.getNominationValue() assert.Equal(t, uint32(42), value1) assert.Equal(t, uint32(42), value2) // Test incrementing generator counter := uint32(0) agent2, err := NewAgentWithOptions(WithRenomination(func() uint32 { counter++ return counter })) assert.NoError(t, err) defer func() { assert.NoError(t, agent2.Close()) }() value3 := agent2.getNominationValue() value4 := agent2.getNominationValue() assert.Equal(t, uint32(1), value3) assert.Equal(t, uint32(2), value4) }) t.Run("controlled agent handles renomination regardless of local config", func(t *testing.T) { // Create controlled agent with renomination DISABLED controlledAgent, err := NewAgent(&AgentConfig{ // Renomination disabled by default // Disabled locally }) assert.NoError(t, err) defer func() { assert.NoError(t, controlledAgent.Close()) }() // Set up as controlled (non-controlling) controlledAgent.isControlling.Store(false) // Create controlled selector to test nomination handling selector := &controlledSelector{ agent: controlledAgent, log: controlledAgent.log, } // Test 1: Should accept nomination without value (standard ICE) assert.True(t, selector.shouldAcceptNomination(nil)) // Test 2: Should accept first nomination with value 
(renomination from controlling side) value1 := uint32(1) assert.True(t, selector.shouldAcceptNomination(&value1)) assert.Equal(t, &value1, selector.lastNomination) // Test 3: Should accept higher nomination value value2 := uint32(2) assert.True(t, selector.shouldAcceptNomination(&value2)) assert.Equal(t, &value2, selector.lastNomination) // Test 4: Should reject lower nomination value value0 := uint32(0) assert.False(t, selector.shouldAcceptNomination(&value0)) assert.Equal(t, &value2, selector.lastNomination) // Should remain unchanged }) } func TestAgentWithCustomNominationAttribute(t *testing.T) { t.Run("agent uses custom nomination attribute with option", func(t *testing.T) { customAttr := uint16(0x0042) // Create agent with custom nomination attribute using option agent, err := NewAgentWithOptions( WithRenomination(func() uint32 { return 100 }), WithNominationAttribute(customAttr), ) assert.NoError(t, err) defer agent.Close() //nolint:errcheck // Verify the agent has the custom attribute configured assert.Equal(t, stun.AttrType(customAttr), agent.nominationAttribute) }) t.Run("agent uses default nomination attribute when not configured", func(t *testing.T) { // Create agent without custom nomination attribute agentConfig := &AgentConfig{ NetworkTypes: []NetworkType{NetworkTypeUDP4}, } agent, err := NewAgent(agentConfig) assert.NoError(t, err) defer agent.Close() //nolint:errcheck // Verify the agent has the default attribute assert.Equal(t, stun.AttrType(0x0030), agent.nominationAttribute) }) t.Run("multiple options can be applied", func(t *testing.T) { customAttr := uint16(0x0055) // Test that multiple options can be applied agent, err := NewAgentWithOptions( WithRenomination(func() uint32 { return 200 }), WithNominationAttribute(customAttr), ) assert.NoError(t, err) defer agent.Close() //nolint:errcheck assert.Equal(t, stun.AttrType(customAttr), agent.nominationAttribute) }) t.Run("WithNominationAttribute returns error for invalid value", func(t *testing.T) 
{ // Test that 0x0000 is rejected as invalid _, err := NewAgentWithOptions(WithNominationAttribute(0x0000)) assert.ErrorIs(t, err, ErrInvalidNominationAttribute) }) } ice-4.1.0/renovate.json000066400000000000000000000001731511704470200147740ustar00rootroot00000000000000{ "$schema": "https://docs.renovatebot.com/renovate-schema.json", "extends": [ "github>pion/renovate-config" ] } ice-4.1.0/role.go000066400000000000000000000015451511704470200135520ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "fmt" ) // Role represents ICE agent role, which can be controlling or controlled. type Role byte // Possible ICE agent roles. const ( Controlling Role = iota Controlled ) // UnmarshalText implements TextUnmarshaler. func (r *Role) UnmarshalText(text []byte) error { switch string(text) { case "controlling": *r = Controlling case "controlled": *r = Controlled default: return fmt.Errorf("%w %q", errUnknownRole, text) } return nil } // MarshalText implements TextMarshaler. 
func (r Role) MarshalText() (text []byte, err error) { return []byte(r.String()), nil } func (r Role) String() string { switch r { case Controlling: return "controlling" case Controlled: return "controlled" default: return "unknown" } } ice-4.1.0/role_test.go000066400000000000000000000026511511704470200146100ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "testing" "github.com/stretchr/testify/require" ) func TestUnmarshalText_Success(t *testing.T) { tests := []struct { name string in string want Role }{ {"controlling", "controlling", Controlling}, {"controlled", "controlled", Controlled}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var r Role err := r.UnmarshalText([]byte(tt.in)) require.NoError(t, err) require.Equal(t, tt.want, r) }) } } func TestUnmarshalText_UnknownKeepsValueAndErrors(t *testing.T) { r := Controlled err := r.UnmarshalText([]byte("neither")) require.ErrorIs(t, err, errUnknownRole) require.Equal(t, Controlled, r, "role should remain unchanged on error") } func TestMarshalText(t *testing.T) { tests := []struct { name string in Role want string }{ {"controlling", Controlling, "controlling"}, {"controlled", Controlled, "controlled"}, {"unknown", Role(99), "unknown"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { b, err := tt.in.MarshalText() require.NoError(t, err) require.Equal(t, tt.want, string(b)) }) } } func TestString(t *testing.T) { require.Equal(t, "controlling", Controlling.String()) require.Equal(t, "controlled", Controlled.String()) require.Equal(t, "unknown", Role(255).String()) } ice-4.1.0/selection.go000066400000000000000000000371361511704470200146030ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "net" "time" "github.com/pion/logging" "github.com/pion/stun/v3" ) type pairCandidateSelector interface { 
Start() ContactCandidates() PingCandidate(local, remote Candidate) HandleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) HandleBindingRequest(m *stun.Message, local, remote Candidate) } type controllingSelector struct { startTime time.Time agent *Agent nominatedPair *CandidatePair log logging.LeveledLogger } func (s *controllingSelector) Start() { s.startTime = time.Now() s.nominatedPair = nil } func (s *controllingSelector) isNominatable(c Candidate) bool { switch { case c.Type() == CandidateTypeHost: return time.Since(s.startTime).Nanoseconds() > s.agent.hostAcceptanceMinWait.Nanoseconds() case c.Type() == CandidateTypeServerReflexive: return time.Since(s.startTime).Nanoseconds() > s.agent.srflxAcceptanceMinWait.Nanoseconds() case c.Type() == CandidateTypePeerReflexive: return time.Since(s.startTime).Nanoseconds() > s.agent.prflxAcceptanceMinWait.Nanoseconds() case c.Type() == CandidateTypeRelay: return time.Since(s.startTime).Nanoseconds() > s.agent.relayAcceptanceMinWait.Nanoseconds() } s.log.Errorf("Invalid candidate type: %s", c.Type()) return false } func (s *controllingSelector) ContactCandidates() { switch { case s.agent.getSelectedPair() != nil: if s.agent.validateSelectedPair() { s.log.Trace("Checking keepalive") s.agent.checkKeepalive() // If automatic renomination is enabled, continuously ping all candidate pairs // to keep them tested with fresh RTT measurements for switching decisions if s.agent.automaticRenomination && s.agent.enableRenomination { s.agent.keepAliveCandidatesForRenomination() } s.checkForAutomaticRenomination() } case s.nominatedPair != nil: s.nominatePair(s.nominatedPair) default: p := s.agent.getBestValidCandidatePair() if p != nil && s.isNominatable(p.Local) && s.isNominatable(p.Remote) { s.log.Tracef("Nominatable pair found, nominating (%s, %s)", p.Local, p.Remote) p.nominated = true s.nominatedPair = p s.nominatePair(p) return } s.agent.pingAllCandidates() } } func (s *controllingSelector) 
nominatePair(pair *CandidatePair) { // The controlling agent MUST include the USE-CANDIDATE attribute in // order to nominate a candidate pair (Section 8.1.1). The controlled // agent MUST NOT include the USE-CANDIDATE attribute in a Binding // request. msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, stun.NewUsername(s.agent.remoteUfrag+":"+s.agent.localUfrag), UseCandidate(), AttrControlling(s.agent.tieBreaker), PriorityAttr(pair.Local.Priority()), stun.NewShortTermIntegrity(s.agent.remotePwd), stun.Fingerprint, ) if err != nil { s.log.Error(err.Error()) return } s.log.Tracef("Ping STUN (nominate candidate pair) from %s to %s", pair.Local, pair.Remote) s.agent.sendBindingRequest(msg, pair.Local, pair.Remote) } func (s *controllingSelector) HandleBindingRequest(message *stun.Message, local, remote Candidate) { //nolint:cyclop s.agent.sendBindingSuccess(message, local, remote) pair := s.agent.findPair(local, remote) if pair == nil { pair = s.agent.addPair(local, remote) pair.UpdateRequestReceived() return } pair.UpdateRequestReceived() if pair.state == CandidatePairStateSucceeded && s.nominatedPair == nil && s.agent.getSelectedPair() == nil { bestPair := s.agent.getBestAvailableCandidatePair() if bestPair == nil { s.log.Tracef("No best pair available") } else if bestPair.equal(pair) && s.isNominatable(pair.Local) && s.isNominatable(pair.Remote) { s.log.Tracef( "The candidate (%s, %s) is the best candidate available, marking it as nominated", pair.Local, pair.Remote, ) s.nominatedPair = pair s.nominatePair(pair) } } if s.agent.userBindingRequestHandler != nil { if shouldSwitch := s.agent.userBindingRequestHandler(message, local, remote, pair); shouldSwitch { s.agent.setSelectedPair(pair) } } } func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) { ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID) if !ok { s.log.Warnf("Discard success response from (%s), unknown 
TransactionID 0x%x", remote, m.TransactionID) return } transactionAddr := pendingRequest.destination // Assert that NAT is not symmetric // https://tools.ietf.org/html/rfc8445#section-7.2.5.2.1 if !addrEqual(transactionAddr, remoteAddr) { s.log.Debugf( "Discard message: transaction source and destination does not match expected(%s), actual(%s)", transactionAddr, remote, ) return } s.log.Tracef("Inbound STUN (SuccessResponse) from %s to %s", remote, local) pair := s.agent.findPair(local, remote) if pair == nil { // This shouldn't happen s.log.Error("Success response from invalid candidate pair") return } pair.state = CandidatePairStateSucceeded s.log.Tracef("Found valid candidate pair: %s", pair) // Handle nomination/renomination if pendingRequest.isUseCandidate { selectedPair := s.agent.getSelectedPair() // If this is a renomination request (has nomination value), always update the selected pair // If it's a standard nomination (no value), only set if no pair is selected yet if pendingRequest.nominationValue != nil { s.log.Infof("Renomination success response received for pair %s (nomination value: %d), switching to this pair", pair, *pendingRequest.nominationValue) s.agent.setSelectedPair(pair) } else if selectedPair == nil { s.agent.setSelectedPair(pair) } } pair.UpdateRoundTripTime(rtt) } func (s *controllingSelector) PingCandidate(local, remote Candidate) { msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, stun.NewUsername(s.agent.remoteUfrag+":"+s.agent.localUfrag), AttrControlling(s.agent.tieBreaker), PriorityAttr(local.Priority()), stun.NewShortTermIntegrity(s.agent.remotePwd), stun.Fingerprint, ) if err != nil { s.log.Error(err.Error()) return } s.agent.sendBindingRequest(msg, local, remote) } // checkForAutomaticRenomination evaluates if automatic renomination should occur. // This is called periodically when the agent is in connected state and automatic // renomination is enabled. 
func (s *controllingSelector) checkForAutomaticRenomination() { if !s.agent.automaticRenomination || !s.agent.enableRenomination { s.log.Tracef("Automatic renomination check skipped: automaticRenomination=%v, enableRenomination=%v", s.agent.automaticRenomination, s.agent.enableRenomination) return } timeSinceStart := time.Since(s.startTime) if timeSinceStart < s.agent.renominationInterval { s.log.Tracef("Automatic renomination check skipped: not enough time since start (%v < %v)", timeSinceStart, s.agent.renominationInterval) return } if !s.agent.lastRenominationTime.IsZero() { timeSinceLastRenomination := time.Since(s.agent.lastRenominationTime) if timeSinceLastRenomination < s.agent.renominationInterval { s.log.Tracef("Automatic renomination check skipped: too soon since last renomination (%v < %v)", timeSinceLastRenomination, s.agent.renominationInterval) return } } currentPair := s.agent.getSelectedPair() if currentPair == nil { s.log.Tracef("Automatic renomination check skipped: no current selected pair") return } bestPair := s.agent.findBestCandidatePair() if bestPair == nil { s.log.Tracef("Automatic renomination check skipped: no best pair found") return } s.log.Debugf("Evaluating automatic renomination: current=%s (RTT=%.2fms), best=%s (RTT=%.2fms)", currentPair, currentPair.CurrentRoundTripTime()*1000, bestPair, bestPair.CurrentRoundTripTime()*1000) if s.agent.shouldRenominate(currentPair, bestPair) { s.log.Infof("Automatic renomination triggered: switching from %s to %s", currentPair, bestPair) // Update last renomination time to prevent rapid renominations s.agent.lastRenominationTime = time.Now() if err := s.agent.RenominateCandidate(bestPair.Local, bestPair.Remote); err != nil { s.log.Errorf("Failed to trigger automatic renomination: %v", err) } } else { s.log.Debugf("Automatic renomination not warranted") } } type controlledSelector struct { agent *Agent log logging.LeveledLogger lastNomination *uint32 // For renomination: tracks highest nomination 
value seen } func (s *controlledSelector) Start() { s.lastNomination = nil } // shouldAcceptNomination checks if a nomination should be accepted based on renomination rules. func (s *controlledSelector) shouldAcceptNomination(nominationValue *uint32) bool { // If no nomination value, accept normally (standard ICE nomination) if nominationValue == nil { return true } // If nomination value is present, controlling side is using renomination // Apply "last nomination wins" rule if s.lastNomination == nil || *nominationValue > *s.lastNomination { s.lastNomination = nominationValue s.log.Tracef("Accepting nomination with value %d", *nominationValue) return true } s.log.Tracef("Rejecting nomination value %d (current is %d)", *nominationValue, *s.lastNomination) return false } // shouldSwitchSelectedPair determines if we should switch to a new nominated pair. // Returns true if the switch should occur, false otherwise. func (s *controlledSelector) shouldSwitchSelectedPair(pair, selectedPair *CandidatePair, nominationValue *uint32) bool { switch { case selectedPair == nil: // No current selection, accept the nomination return true case selectedPair == pair: // Same pair, no change needed return false case nominationValue != nil: // Renomination is in use (nomination value present) // Accept the switch based on nomination value alone, not priority // The shouldAcceptNomination check already validated this is a valid renomination s.log.Debugf("Accepting renomination to pair %s (nomination value: %d)", pair, *nominationValue) return true } // Standard ICE nomination without renomination - apply priority rules // Only switch if we don't check priority, OR new pair has strictly higher priority return !s.agent.needsToCheckPriorityOnNominated() || selectedPair.priority() < pair.priority() } func (s *controlledSelector) ContactCandidates() { if s.agent.getSelectedPair() != nil { if s.agent.validateSelectedPair() { s.log.Trace("Checking keepalive") s.agent.checkKeepalive() } } else 
{ s.agent.pingAllCandidates() } } func (s *controlledSelector) PingCandidate(local, remote Candidate) { msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, stun.NewUsername(s.agent.remoteUfrag+":"+s.agent.localUfrag), AttrControlled(s.agent.tieBreaker), PriorityAttr(local.Priority()), stun.NewShortTermIntegrity(s.agent.remotePwd), stun.Fingerprint, ) if err != nil { s.log.Error(err.Error()) return } s.agent.sendBindingRequest(msg, local, remote) } func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) { //nolint:godox // TODO according to the standard we should specifically answer a failed nomination: // https://tools.ietf.org/html/rfc8445#section-7.3.1.5 // If the controlled agent does not accept the request from the // controlling agent, the controlled agent MUST reject the nomination // request with an appropriate error code response (e.g., 400) // [RFC5389]. ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID) if !ok { s.log.Warnf("Discard message from (%s), unknown TransactionID 0x%x", remote, m.TransactionID) return } transactionAddr := pendingRequest.destination // Assert that NAT is not symmetric // https://tools.ietf.org/html/rfc8445#section-7.2.5.2.1 if !addrEqual(transactionAddr, remoteAddr) { s.log.Debugf( "Discard message: transaction source and destination does not match expected(%s), actual(%s)", transactionAddr, remote, ) return } s.log.Tracef("Inbound STUN (SuccessResponse) from %s to %s", remote, local) pair := s.agent.findPair(local, remote) if pair == nil { // This shouldn't happen s.log.Error("Success response from invalid candidate pair") return } pair.state = CandidatePairStateSucceeded s.log.Tracef("Found valid candidate pair: %s", pair) if pair.nominateOnBindingSuccess { if selectedPair := s.agent.getSelectedPair(); selectedPair == nil || (selectedPair != pair && (!s.agent.needsToCheckPriorityOnNominated() || selectedPair.priority() <= 
pair.priority())) { s.agent.setSelectedPair(pair) } else if selectedPair != pair { s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", pair, selectedPair) } } pair.UpdateRoundTripTime(rtt) } func (s *controlledSelector) HandleBindingRequest(message *stun.Message, local, remote Candidate) { //nolint:cyclop pair := s.agent.findPair(local, remote) if pair == nil { pair = s.agent.addPair(local, remote) } pair.UpdateRequestReceived() if message.Contains(stun.AttrUseCandidate) { //nolint:nestif // https://tools.ietf.org/html/rfc8445#section-7.3.1.5 // Check for renomination attribute var nominationValue *uint32 var nomination NominationAttribute if err := nomination.GetFromWithType(message, s.agent.nominationAttribute); err == nil { nominationValue = &nomination.Value s.log.Tracef("Received nomination with value %d", nomination.Value) } // Check if we should accept this nomination based on renomination rules if !s.shouldAcceptNomination(nominationValue) { s.log.Tracef("Rejecting nomination request due to renomination rules") s.agent.sendBindingSuccess(message, local, remote) return } if pair.state == CandidatePairStateSucceeded { // If the state of this pair is Succeeded, it means that the check // previously sent by this pair produced a successful response and // generated a valid pair (Section 7.2.5.3.2). The agent sets the // nominated flag value of the valid pair to true. selectedPair := s.agent.getSelectedPair() if s.shouldSwitchSelectedPair(pair, selectedPair, nominationValue) { s.log.Tracef("Accepting nomination for pair %s", pair) s.agent.setSelectedPair(pair) } else { s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", pair, selectedPair) } } else { // If the received Binding request triggered a new check to be // enqueued in the triggered-check queue (Section 7.3.1.4), once the // check is sent and if it generates a successful response, and // generates a valid pair, the agent sets the nominated flag of the // pair to true. 
If the request fails (Section 7.2.5.2), the agent // MUST remove the candidate pair from the valid list, set the // candidate pair state to Failed, and set the checklist state to // Failed. pair.nominateOnBindingSuccess = true } } s.agent.sendBindingSuccess(message, local, remote) s.PingCandidate(local, remote) if s.agent.userBindingRequestHandler != nil { if shouldSwitch := s.agent.userBindingRequestHandler(message, local, remote, pair); shouldSwitch { s.agent.setSelectedPair(pair) } } } type liteSelector struct { pairCandidateSelector } // A lite selector should not contact candidates. func (s *liteSelector) ContactCandidates() { if _, ok := s.pairCandidateSelector.(*controllingSelector); ok { //nolint:godox // https://github.com/pion/ice/issues/96 // TODO: implement lite controlling agent. For now falling back to full agent. // This only happens if both peers are lite. See RFC 8445 S6.1.1 and S6.2 s.pairCandidateSelector.ContactCandidates() } else if v, ok := s.pairCandidateSelector.(*controlledSelector); ok { v.agent.validateSelectedPair() } } ice-4.1.0/selection_test.go000066400000000000000000001270561511704470200156430ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "bytes" "context" "errors" "fmt" "io" "net" "strings" "sync/atomic" "testing" "time" "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) const ( selectionTestPassword = "pwd" selectionTestRemoteUfrag = "remote" selectionTestLocalUfrag = "local" ) func sendUntilDone(t *testing.T, writingConn, readingConn net.Conn, maxAttempts int) bool { t.Helper() testMessage := []byte("Hello World") testBuffer := make([]byte, len(testMessage)) readDone, readDoneCancel := context.WithCancel(context.Background()) go func() { _, err := readingConn.Read(testBuffer) if errors.Is(err, io.EOF) { 
return } require.NoError(t, err) require.True(t, bytes.Equal(testMessage, testBuffer)) readDoneCancel() }() attempts := 0 for { select { case <-time.After(5 * time.Millisecond): if attempts > maxAttempts { return false } _, err := writingConn.Write(testMessage) require.NoError(t, err) attempts++ case <-readDone.Done(): return true } } } func TestBindingRequestHandler(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 30).Stop() var switchToNewCandidatePair, controlledLoggingFired atomic.Value oneHour := time.Hour keepaliveInterval := time.Millisecond * 20 aNotifier, aConnected := onConnected() bNotifier, bConnected := onConnected() controllingAgent, err := NewAgent(&AgentConfig{ NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, MulticastDNSMode: MulticastDNSModeDisabled, KeepaliveInterval: &keepaliveInterval, CheckInterval: &oneHour, BindingRequestHandler: func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool { controlledLoggingFired.Store(true) return false }, }) require.NoError(t, err) require.NoError(t, controllingAgent.OnConnectionStateChange(aNotifier)) controlledAgent, err := NewAgent(&AgentConfig{ NetworkTypes: []NetworkType{NetworkTypeUDP4}, MulticastDNSMode: MulticastDNSModeDisabled, KeepaliveInterval: &keepaliveInterval, CheckInterval: &oneHour, BindingRequestHandler: func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool { // Don't switch candidate pair until we are ready val, ok := switchToNewCandidatePair.Load().(bool) return ok && val }, }) require.NoError(t, err) require.NoError(t, controlledAgent.OnConnectionStateChange(bNotifier)) controlledConn, controllingConn := connect(t, controlledAgent, controllingAgent) <-aConnected <-bConnected // Assert we have connected and can send data require.True(t, sendUntilDone(t, controlledConn, controllingConn, 100)) // Take the lock on the controlling Agent and unset state assert.NoError(t, controlledAgent.loop.Run(controlledAgent.loop, func(_ context.Context) 
{ for net, cs := range controlledAgent.remoteCandidates { for _, c := range cs { require.NoError(t, c.close()) } delete(controlledAgent.remoteCandidates, net) } for _, c := range controlledAgent.localCandidates[NetworkTypeUDP4] { cast, ok := c.(*CandidateHost) require.True(t, ok) cast.remoteCandidateCaches = map[AddrPort]Candidate{} } controlledAgent.setSelectedPair(nil) controlledAgent.checklist = make([]*CandidatePair, 0) })) // Assert that Selected Candidate pair has only been unset on Controlled side candidatePair, err := controlledAgent.GetSelectedCandidatePair() assert.Nil(t, candidatePair) assert.NoError(t, err) candidatePair, err = controllingAgent.GetSelectedCandidatePair() assert.NotNil(t, candidatePair) assert.NoError(t, err) // Sending will fail, we no longer have a selected candidate pair require.False(t, sendUntilDone(t, controlledConn, controllingConn, 20)) // Send STUN Binding requests until a new Selected Candidate Pair has been set by BindingRequestHandler switchToNewCandidatePair.Store(true) for { controllingAgent.requestConnectivityCheck() candidatePair, err = controlledAgent.GetSelectedCandidatePair() require.NoError(t, err) if candidatePair != nil { break } time.Sleep(time.Millisecond * 5) } // We have a new selected candidate pair because of BindingRequestHandler, test that it works require.True(t, sendUntilDone(t, controllingConn, controlledConn, 100)) fired, ok := controlledLoggingFired.Load().(bool) require.True(t, ok) require.True(t, fired) closePipe(t, controllingConn, controlledConn) } // copied from pion/webrtc's peerconnection_go_test.go. 
type testICELogger struct { lastErrorMessage string } func (t *testICELogger) Trace(string) {} func (t *testICELogger) Tracef(string, ...any) {} func (t *testICELogger) Debug(string) {} func (t *testICELogger) Debugf(string, ...any) {} func (t *testICELogger) Info(string) {} func (t *testICELogger) Infof(string, ...any) {} func (t *testICELogger) Warn(string) {} func (t *testICELogger) Warnf(string, ...any) {} func (t *testICELogger) Error(msg string) { t.lastErrorMessage = msg } func (t *testICELogger) Errorf(format string, args ...any) { t.lastErrorMessage = fmt.Sprintf(format, args...) } type testICELoggerFactory struct { logger *testICELogger } func (t *testICELoggerFactory) NewLogger(string) logging.LeveledLogger { return t.logger } func TestControllingSelector_IsNominatable_LogsInvalidType(t *testing.T) { testLogger := &testICELogger{} loggerFactory := &testICELoggerFactory{logger: testLogger} sel := &controllingSelector{ agent: &Agent{}, log: loggerFactory.NewLogger("test"), } sel.Start() c := hostCandidate() c.candidateBase.candidateType = CandidateTypeUnspecified got := sel.isNominatable(c) require.False(t, got) require.Contains(t, testLogger.lastErrorMessage, "Invalid candidate type") require.Contains(t, testLogger.lastErrorMessage, "Unknown candidate type") // from c.Type().String() } func TestControllingSelector_NominatePair_BuildError(t *testing.T) { testLogger := &testICELogger{} loggerFactory := &testICELoggerFactory{logger: testLogger} // selector with an Agent with ufrags to make an oversized username // (username = remoteUfrag + ":" + localUfrag) since oversized username causes // stun.NewUsername(...) inside stun.Build to fail. 
long := strings.Repeat("x", 300) // > 255 each side sel := &controllingSelector{ agent: &Agent{ remoteUfrag: long, localUfrag: long, remotePwd: "pwd", // any non-empty value is fine tieBreaker: 0, }, log: loggerFactory.NewLogger("test"), } sel.Start() p := newCandidatePair(hostCandidate(), hostCandidate(), true) sel.nominatePair(p) require.NotEmpty(t, testLogger.lastErrorMessage, "expected error log from nominatePair on Build failure") } type pingNoIOCand struct{ candidateBase } func newPingNoIOCand() *pingNoIOCand { return &pingNoIOCand{ candidateBase: candidateBase{ candidateType: CandidateTypeHost, component: ComponentRTP, }, } } func (d *pingNoIOCand) writeTo(b []byte, _ Candidate) (int, error) { return len(b), nil } func bareAgentForPing() *Agent { return &Agent{ hostAcceptanceMinWait: time.Hour, srflxAcceptanceMinWait: time.Hour, prflxAcceptanceMinWait: time.Hour, relayAcceptanceMinWait: time.Hour, checklist: []*CandidatePair{}, keepaliveInterval: time.Second, checkInterval: time.Second, connectionStateNotifier: &handlerNotifier{ done: make(chan struct{}), connectionStateFunc: func(ConnectionState) {}}, //nolint formatting candidateNotifier: &handlerNotifier{ done: make(chan struct{}), candidateFunc: func(Candidate) {}}, //nolint formatting selectedCandidatePairNotifier: &handlerNotifier{ done: make(chan struct{}), candidatePairFunc: func(*CandidatePair) {}}, //nolint formatting } } func bigStr() string { return strings.Repeat("x", 40000) } func TestControllingSelector_PingCandidate_BuildError(t *testing.T) { a := bareAgentForPing() // make Username really big so stun.Build returns an error. 
a.remoteUfrag = bigStr() a.localUfrag = bigStr() a.remotePwd = selectionTestPassword a.tieBreaker = 1 testLogger := &testICELogger{} sel := &controllingSelector{agent: a, log: testLogger} sel.Start() local := newPingNoIOCand() remote := newPingNoIOCand() sel.PingCandidate(local, remote) require.NotEmpty(t, testLogger.lastErrorMessage, "expected error to be logged from stun.Build") } func TestControlledSelector_PingCandidate_BuildError(t *testing.T) { a := bareAgentForPing() a.remoteUfrag = bigStr() a.localUfrag = bigStr() a.remotePwd = selectionTestPassword a.tieBreaker = 1 testLogger := &testICELogger{} sel := &controlledSelector{agent: a, log: testLogger} sel.Start() local := newPingNoIOCand() remote := newPingNoIOCand() sel.PingCandidate(local, remote) require.NotEmpty(t, testLogger.lastErrorMessage, "expected error to be logged from stun.Build") } type warnTestLogger struct { warned bool } func (l *warnTestLogger) Trace(string) {} func (l *warnTestLogger) Tracef(string, ...any) {} func (l *warnTestLogger) Debug(string) {} func (l *warnTestLogger) Debugf(string, ...any) {} func (l *warnTestLogger) Info(string) {} func (l *warnTestLogger) Infof(string, ...any) {} func (l *warnTestLogger) Warn(string) { l.warned = true } func (l *warnTestLogger) Warnf(string, ...any) { l.warned = true } func (l *warnTestLogger) Error(string) {} func (l *warnTestLogger) Errorf(string, ...any) {} type dummyNoIOCand struct{ candidateBase } func newDummyNoIOCand(t CandidateType) *dummyNoIOCand { return &dummyNoIOCand{ candidateBase: candidateBase{ candidateType: t, component: ComponentRTP, }, } } func (d *dummyNoIOCand) writeTo(p []byte, _ Candidate) (int, error) { return len(p), nil } func TestControlledSelector_HandleSuccessResponse_UnknownTxID(t *testing.T) { logger := &warnTestLogger{} ag := &Agent{log: logger} sel := &controlledSelector{agent: ag, log: logger} sel.Start() local := newDummyNoIOCand(CandidateTypeHost) remote := newDummyNoIOCand(CandidateTypeHost) var m stun.Message 
copy(m.TransactionID[:], []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}) sel.HandleSuccessResponse(&m, local, remote, nil) require.True(t, logger.warned, "expected Warnf to be called for unknown TransactionID (hitting !ok branch)") } func TestAutomaticRenomination(t *testing.T) { //nolint:maintidx report := test.CheckRoutines(t) defer report() t.Run("Configuration", func(t *testing.T) { t.Run("WithAutomaticRenomination enables feature", func(t *testing.T) { agent, err := NewAgentWithOptions( WithRenomination(DefaultNominationValueGenerator()), WithAutomaticRenomination(5*time.Second), ) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() assert.True(t, agent.automaticRenomination) assert.Equal(t, 5*time.Second, agent.renominationInterval) assert.True(t, agent.enableRenomination) }) t.Run("Default interval when zero", func(t *testing.T) { agent, err := NewAgentWithOptions( WithRenomination(DefaultNominationValueGenerator()), WithAutomaticRenomination(0), ) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() assert.True(t, agent.automaticRenomination) assert.Equal(t, 3*time.Second, agent.renominationInterval) }) }) t.Run("Quality Assessment", func(t *testing.T) { agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() localHost, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "192.168.1.1", Port: 10000, Component: 1, }) require.NoError(t, err) remoteHost, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "192.168.1.2", Port: 20000, Component: 1, }) require.NoError(t, err) localRelay, err := NewCandidateRelay(&CandidateRelayConfig{ Network: "udp", Address: "10.0.0.1", Port: 30000, Component: 1, RelAddr: "192.168.1.1", RelPort: 10000, }) require.NoError(t, err) remoteRelay, err := NewCandidateRelay(&CandidateRelayConfig{ Network: "udp", Address: "10.0.0.2", Port: 40000, Component: 1, RelAddr: "192.168.1.2", RelPort: 
20000, }) require.NoError(t, err) t.Run("Host pair scores higher than relay pair", func(t *testing.T) { hostPair := newCandidatePair(localHost, remoteHost, true) hostPair.state = CandidatePairStateSucceeded hostPair.UpdateRoundTripTime(10 * time.Millisecond) relayPair := newCandidatePair(localRelay, remoteRelay, true) relayPair.state = CandidatePairStateSucceeded relayPair.UpdateRoundTripTime(10 * time.Millisecond) hostScore := agent.evaluateCandidatePairQuality(hostPair) relayScore := agent.evaluateCandidatePairQuality(relayPair) assert.Greater(t, hostScore, relayScore, "Host pair should score higher than relay pair with same RTT") }) t.Run("Lower RTT scores higher", func(t *testing.T) { pair1 := newCandidatePair(localHost, remoteHost, true) pair1.state = CandidatePairStateSucceeded pair1.UpdateRoundTripTime(5 * time.Millisecond) pair2 := newCandidatePair(localHost, remoteHost, true) pair2.state = CandidatePairStateSucceeded pair2.UpdateRoundTripTime(50 * time.Millisecond) score1 := agent.evaluateCandidatePairQuality(pair1) score2 := agent.evaluateCandidatePairQuality(pair2) assert.Greater(t, score1, score2, "Pair with lower RTT should score higher") }) }) t.Run("Should Renominate Logic", func(t *testing.T) { agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() localHost, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "192.168.1.1", Port: 10000, Component: 1, }) require.NoError(t, err) remoteHost, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "192.168.1.2", Port: 20000, Component: 1, }) require.NoError(t, err) localRelay, err := NewCandidateRelay(&CandidateRelayConfig{ Network: "udp", Address: "10.0.0.1", Port: 30000, Component: 1, RelAddr: "192.168.1.1", RelPort: 10000, }) require.NoError(t, err) remoteRelay, err := NewCandidateRelay(&CandidateRelayConfig{ Network: "udp", Address: "10.0.0.2", Port: 40000, Component: 1, RelAddr: "192.168.1.2", RelPort: 
20000, }) require.NoError(t, err) t.Run("Should renominate relay to host", func(t *testing.T) { relayPair := newCandidatePair(localRelay, remoteRelay, true) relayPair.state = CandidatePairStateSucceeded relayPair.UpdateRoundTripTime(50 * time.Millisecond) hostPair := newCandidatePair(localHost, remoteHost, true) hostPair.state = CandidatePairStateSucceeded hostPair.UpdateRoundTripTime(45 * time.Millisecond) // Similar RTT shouldSwitch := agent.shouldRenominate(relayPair, hostPair) assert.True(t, shouldSwitch, "Should renominate from relay to host even with similar RTT") }) t.Run("Should renominate for RTT improvement > 10ms", func(t *testing.T) { // Create different host candidates for pair2 to avoid same-pair check localHost2, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "192.168.1.3", Port: 10001, Component: 1, }) require.NoError(t, err) pair1 := newCandidatePair(localHost, remoteHost, true) pair1.state = CandidatePairStateSucceeded pair1.UpdateRoundTripTime(50 * time.Millisecond) pair2 := newCandidatePair(localHost2, remoteHost, true) pair2.state = CandidatePairStateSucceeded pair2.UpdateRoundTripTime(30 * time.Millisecond) // 20ms improvement shouldSwitch := agent.shouldRenominate(pair1, pair2) assert.True(t, shouldSwitch, "Should renominate for RTT improvement > 10ms") }) t.Run("Should not renominate for small RTT improvement", func(t *testing.T) { // Create different host candidates for pair2 to avoid same-pair check localHost2, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "192.168.1.3", Port: 10001, Component: 1, }) require.NoError(t, err) pair1 := newCandidatePair(localHost, remoteHost, true) pair1.state = CandidatePairStateSucceeded pair1.UpdateRoundTripTime(50 * time.Millisecond) pair2 := newCandidatePair(localHost2, remoteHost, true) pair2.state = CandidatePairStateSucceeded pair2.UpdateRoundTripTime(45 * time.Millisecond) // Only 5ms improvement shouldSwitch := agent.shouldRenominate(pair1, pair2) 
assert.False(t, shouldSwitch, "Should not renominate for RTT improvement < 10ms") }) t.Run("Should not renominate to same pair", func(t *testing.T) { pair := newCandidatePair(localHost, remoteHost, true) pair.state = CandidatePairStateSucceeded shouldSwitch := agent.shouldRenominate(pair, pair) assert.False(t, shouldSwitch, "Should not renominate to the same pair") }) t.Run("Should not renominate to non-succeeded pair", func(t *testing.T) { currentPair := newCandidatePair(localHost, remoteHost, true) currentPair.state = CandidatePairStateSucceeded candidatePair := newCandidatePair(localHost, remoteHost, true) candidatePair.state = CandidatePairStateInProgress shouldSwitch := agent.shouldRenominate(currentPair, candidatePair) assert.False(t, shouldSwitch, "Should not renominate to non-succeeded pair") }) }) t.Run("Find Best Candidate Pair", func(t *testing.T) { agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) defer func() { require.NoError(t, agent.Close()) }() // Create candidates localHost, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "192.168.1.1", Port: 10000, Component: 1, }) require.NoError(t, err) remoteHost, err := NewCandidateHost(&CandidateHostConfig{ Network: "udp", Address: "192.168.1.2", Port: 20000, Component: 1, }) require.NoError(t, err) localRelay, err := NewCandidateRelay(&CandidateRelayConfig{ Network: "udp", Address: "10.0.0.1", Port: 30000, Component: 1, RelAddr: "192.168.1.1", RelPort: 10000, }) require.NoError(t, err) remoteRelay, err := NewCandidateRelay(&CandidateRelayConfig{ Network: "udp", Address: "10.0.0.2", Port: 40000, Component: 1, RelAddr: "192.168.1.2", RelPort: 20000, }) require.NoError(t, err) ctx := context.Background() err = agent.loop.Run(ctx, func(context.Context) { // Add pairs to checklist hostPair := agent.addPair(localHost, remoteHost) hostPair.state = CandidatePairStateSucceeded hostPair.UpdateRoundTripTime(10 * time.Millisecond) relayPair := agent.addPair(localRelay, remoteRelay) 
relayPair.state = CandidatePairStateSucceeded relayPair.UpdateRoundTripTime(50 * time.Millisecond) // Find best should return host pair best := agent.findBestCandidatePair() assert.NotNil(t, best) assert.Equal(t, hostPair, best, "Best pair should be the host pair with lower latency") }) require.NoError(t, err) }) } func TestAutomaticRenominationIntegration(t *testing.T) { //nolint:cyclop report := test.CheckRoutines(t) defer report() t.Run("Automatic renomination triggers after interval", func(t *testing.T) { // Create agents with automatic renomination enabled aAgent, err := NewAgentWithOptions( WithRenomination(DefaultNominationValueGenerator()), WithAutomaticRenomination(100*time.Millisecond), // Short interval for testing ) require.NoError(t, err) defer func() { require.NoError(t, aAgent.Close()) }() bAgent, err := NewAgentWithOptions( WithRenomination(DefaultNominationValueGenerator()), ) require.NoError(t, err) defer func() { require.NoError(t, bAgent.Close()) }() // Start gathering candidates err = aAgent.OnCandidate(func(c Candidate) { if c != nil { t.Logf("Agent A gathered candidate: %s", c) } }) require.NoError(t, err) err = bAgent.OnCandidate(func(c Candidate) { if c != nil { t.Logf("Agent B gathered candidate: %s", c) } }) require.NoError(t, err) require.NoError(t, aAgent.GatherCandidates()) require.NoError(t, bAgent.GatherCandidates()) // Wait for gathering to complete time.Sleep(100 * time.Millisecond) // Exchange credentials aUfrag, aPwd, err := aAgent.GetLocalUserCredentials() require.NoError(t, err) bUfrag, bPwd, err := bAgent.GetLocalUserCredentials() require.NoError(t, err) // Get candidates aCandidates, err := aAgent.GetLocalCandidates() require.NoError(t, err) bCandidates, err := bAgent.GetLocalCandidates() require.NoError(t, err) // Verify we have candidates if len(aCandidates) == 0 || len(bCandidates) == 0 { t.Skip("No candidates gathered, skipping integration test") } require.NoError(t, aAgent.startConnectivityChecks(true, bUfrag, bPwd)) 
require.NoError(t, bAgent.startConnectivityChecks(false, aUfrag, aPwd)) // Exchange candidates for _, c := range aCandidates { cpCand, copyErr := c.copy() require.NoError(t, copyErr) require.NoError(t, bAgent.AddRemoteCandidate(cpCand)) } for _, c := range bCandidates { cpCand, copyErr := c.copy() require.NoError(t, copyErr) require.NoError(t, aAgent.AddRemoteCandidate(cpCand)) } // Wait for initial connection ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Wait for connection on both agents select { case <-aAgent.onConnected: case <-ctx.Done(): require.Fail(t, "Agent A failed to connect") } select { case <-bAgent.onConnected: case <-ctx.Done(): require.Fail(t, "Agent B failed to connect") } // Record initial selected pairs initialAPair, err := aAgent.GetSelectedCandidatePair() require.NoError(t, err) require.NotNil(t, initialAPair) // Note: In a real scenario, automatic renomination would trigger // when a better path becomes available (e.g., relay -> direct). // For this test, we're just verifying the mechanism is in place. 
// Wait to see if automatic renomination check runs // (it should run but may not renominate if no better pair exists) time.Sleep(200 * time.Millisecond) // The automatic renomination check should have run at least once // We can't easily verify renomination occurred without simulating // network changes, but we can verify the feature is enabled assert.True(t, aAgent.automaticRenomination) assert.True(t, aAgent.enableRenomination) }) } func TestKeepAliveCandidatesForRenomination(t *testing.T) { report := test.CheckRoutines(t) defer report() // Create test candidates that don't require real network I/O createTestCandidates := func() (Candidate, Candidate, Candidate) { local1 := newPingNoIOCand() local1.candidateBase.networkType = NetworkTypeUDP4 local1.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 10000} local2 := newPingNoIOCand() local2.candidateBase.networkType = NetworkTypeUDP4 local2.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.3"), Port: 10001} remote := newPingNoIOCand() remote.candidateBase.networkType = NetworkTypeUDP4 remote.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.2"), Port: 20000} return local1, local2, remote } t.Run("Only pings all candidates when automatic renomination enabled", func(t *testing.T) { localHost1, localHost2, remoteHost := createTestCandidates() // Test with automatic renomination DISABLED agentWithoutAutoRenom := bareAgentForPing() agentWithoutAutoRenom.log = logging.NewDefaultLoggerFactory().NewLogger("test") agentWithoutAutoRenom.remoteUfrag = selectionTestRemoteUfrag agentWithoutAutoRenom.localUfrag = selectionTestLocalUfrag agentWithoutAutoRenom.remotePwd = selectionTestPassword agentWithoutAutoRenom.tieBreaker = 1 agentWithoutAutoRenom.isControlling.Store(true) agentWithoutAutoRenom.setSelector() // Add pairs - one selected (succeeded) and one alternate (succeeded) pair1 := agentWithoutAutoRenom.addPair(localHost1, remoteHost) pair1.state = 
CandidatePairStateSucceeded pair1.UpdateRoundTripTime(10 * time.Millisecond) // Don't set selected pair for the "without renomination" agent pair2 := agentWithoutAutoRenom.addPair(localHost2, remoteHost) pair2.state = CandidatePairStateSucceeded pair2.UpdateRoundTripTime(50 * time.Millisecond) // keepAliveCandidatesForRenomination should do nothing when automatic renomination is disabled agentWithoutAutoRenom.keepAliveCandidatesForRenomination() // Since automatic renomination is off, the function should not ping anything // We can't easily verify no pings were sent, but we verify the function completes // Test with automatic renomination ENABLED agentWithAutoRenom := bareAgentForPing() agentWithAutoRenom.log = logging.NewDefaultLoggerFactory().NewLogger("test") agentWithAutoRenom.automaticRenomination = true agentWithAutoRenom.enableRenomination = true agentWithAutoRenom.renominationInterval = 100 * time.Millisecond agentWithAutoRenom.remoteUfrag = selectionTestRemoteUfrag agentWithAutoRenom.localUfrag = selectionTestLocalUfrag agentWithAutoRenom.remotePwd = selectionTestPassword agentWithAutoRenom.tieBreaker = 1 agentWithAutoRenom.isControlling.Store(true) agentWithAutoRenom.setSelector() // Add pairs with different states pair1 = agentWithAutoRenom.addPair(localHost1, remoteHost) pair1.state = CandidatePairStateSucceeded pair1.UpdateRoundTripTime(10 * time.Millisecond) // Don't set selected pair for this test pair2 = agentWithAutoRenom.addPair(localHost2, remoteHost) pair2.state = CandidatePairStateSucceeded pair2.UpdateRoundTripTime(50 * time.Millisecond) // Call keepAliveCandidatesForRenomination - should ping all pairs agentWithAutoRenom.keepAliveCandidatesForRenomination() // Verify both pairs remain in succeeded state (not changed by the function) assert.Equal(t, CandidatePairStateSucceeded, pair1.state) assert.Equal(t, CandidatePairStateSucceeded, pair2.state) }) t.Run("Pings succeeded pairs unlike pingAllCandidates", func(t *testing.T) { localHost1, 
localHost2, remoteHost := createTestCandidates() agent := bareAgentForPing() agent.log = logging.NewDefaultLoggerFactory().NewLogger("test") agent.automaticRenomination = true agent.enableRenomination = true agent.renominationInterval = 100 * time.Millisecond agent.remoteUfrag = selectionTestRemoteUfrag agent.localUfrag = selectionTestLocalUfrag agent.remotePwd = selectionTestPassword agent.tieBreaker = 1 agent.isControlling.Store(true) agent.setSelector() // Create a pair in succeeded state pair := agent.addPair(localHost1, remoteHost) pair.state = CandidatePairStateSucceeded pair.UpdateRoundTripTime(10 * time.Millisecond) // Create another pair in succeeded state pair2 := agent.addPair(localHost2, remoteHost) pair2.state = CandidatePairStateSucceeded pair2.UpdateRoundTripTime(50 * time.Millisecond) // keepAliveCandidatesForRenomination should ping succeeded pairs // (pingAllCandidates would skip them) agent.keepAliveCandidatesForRenomination() // Pairs should still be in succeeded state assert.Equal(t, CandidatePairStateSucceeded, pair.state) assert.Equal(t, CandidatePairStateSucceeded, pair2.state) }) t.Run("Transitions waiting pairs to in-progress", func(t *testing.T) { localHost1, _, remoteHost := createTestCandidates() agent := bareAgentForPing() agent.log = logging.NewDefaultLoggerFactory().NewLogger("test") agent.automaticRenomination = true agent.enableRenomination = true agent.renominationInterval = 100 * time.Millisecond agent.remoteUfrag = selectionTestRemoteUfrag agent.localUfrag = selectionTestLocalUfrag agent.remotePwd = selectionTestPassword agent.tieBreaker = 1 agent.isControlling.Store(true) agent.setSelector() // Create a pair in waiting state pair := agent.addPair(localHost1, remoteHost) pair.state = CandidatePairStateWaiting // Call keepAliveCandidatesForRenomination agent.keepAliveCandidatesForRenomination() // Pair should transition to in-progress assert.Equal(t, CandidatePairStateInProgress, pair.state) }) t.Run("Skips failed pairs", func(t 
*testing.T) { localHost1, localHost2, remoteHost := createTestCandidates() agent := bareAgentForPing() agent.log = logging.NewDefaultLoggerFactory().NewLogger("test") agent.automaticRenomination = true agent.enableRenomination = true agent.renominationInterval = 100 * time.Millisecond agent.remoteUfrag = selectionTestRemoteUfrag agent.localUfrag = selectionTestLocalUfrag agent.remotePwd = selectionTestPassword agent.tieBreaker = 1 agent.isControlling.Store(true) agent.setSelector() // Create a succeeded pair pair1 := agent.addPair(localHost1, remoteHost) pair1.state = CandidatePairStateSucceeded // Create a failed pair pair2 := agent.addPair(localHost2, remoteHost) pair2.state = CandidatePairStateFailed // Call keepAliveCandidatesForRenomination agent.keepAliveCandidatesForRenomination() // Failed pair should remain failed (not transitioned to in-progress) assert.Equal(t, CandidatePairStateFailed, pair2.state) // Succeeded pair should remain succeeded assert.Equal(t, CandidatePairStateSucceeded, pair1.state) }) } // TestRenominationAcceptance verifies that the controlled agent correctly // accepts renomination based on nomination values, not just priority. 
func TestRenominationAcceptance(t *testing.T) { t.Run("Accepts renomination with nomination value regardless of priority", func(t *testing.T) { // Create a controlled agent agent := bareAgentForPing() agent.log = logging.NewDefaultLoggerFactory().NewLogger("test") agent.remoteUfrag = selectionTestRemoteUfrag agent.localUfrag = selectionTestLocalUfrag agent.remotePwd = selectionTestPassword agent.tieBreaker = 1 agent.isControlling.Store(false) // Controlled agent agent.nominationAttribute = stun.AttrType(0x0030) agent.onConnected = make(chan struct{}) // Initialize the channel agent.setSelector() selector, ok := agent.getSelector().(*controlledSelector) require.True(t, ok, "expected controlledSelector") // Create two host candidates with same priority local1 := newPingNoIOCand() local1.candidateBase.networkType = NetworkTypeUDP4 local1.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 10000} local2 := newPingNoIOCand() local2.candidateBase.networkType = NetworkTypeUDP4 local2.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.3"), Port: 10001} remote := newPingNoIOCand() remote.candidateBase.networkType = NetworkTypeUDP4 remote.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.2"), Port: 20000} // Create two pairs with the same priority (both host candidates) pair1 := agent.addPair(local1, remote) pair1.state = CandidatePairStateSucceeded pair2 := agent.addPair(local2, remote) pair2.state = CandidatePairStateSucceeded // Select the first pair initially agent.setSelectedPair(pair1) assert.Equal(t, pair1, agent.getSelectedPair()) assert.True(t, pair1.nominated) // Build a nomination request for the second pair with a nomination value nominationValue := uint32(100) msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag), UseCandidate(), NominationSetter{ Value: nominationValue, AttrType: agent.nominationAttribute, }, 
stun.NewShortTermIntegrity(agent.localPwd), stun.Fingerprint, ) require.NoError(t, err) // Handle the binding request with nomination value for pair2 selector.HandleBindingRequest(msg, local2, remote) // The controlled agent should accept the renomination even though // pair2 has the same priority as pair1, because a nomination value is present selectedPair := agent.getSelectedPair() assert.Equal(t, pair2, selectedPair, "Should switch to pair2 when renomination with nomination value is received") assert.True(t, pair2.nominated) }) t.Run("Standard nomination still requires higher priority without nomination value", func(t *testing.T) { // Create a controlled agent agent := bareAgentForPing() agent.log = logging.NewDefaultLoggerFactory().NewLogger("test") agent.remoteUfrag = selectionTestRemoteUfrag agent.localUfrag = selectionTestLocalUfrag agent.remotePwd = selectionTestPassword agent.tieBreaker = 1 agent.isControlling.Store(false) agent.onConnected = make(chan struct{}) agent.setSelector() selector, ok := agent.getSelector().(*controlledSelector) require.True(t, ok, "expected controlledSelector") // Create candidates - we'll simulate lower priority by using different types local1 := newPingNoIOCand() local1.candidateBase.networkType = NetworkTypeUDP4 local1.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 10000} local1.candidateBase.candidateType = CandidateTypeHost // Higher priority local2 := newPingNoIOCand() local2.candidateBase.networkType = NetworkTypeUDP4 local2.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.3"), Port: 10001} local2.candidateBase.candidateType = CandidateTypeHost // Same priority remote := newPingNoIOCand() remote.candidateBase.networkType = NetworkTypeUDP4 remote.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.2"), Port: 20000} remote.candidateBase.candidateType = CandidateTypeHost pair1 := agent.addPair(local1, remote) pair1.state = CandidatePairStateSucceeded pair2 
:= agent.addPair(local2, remote) pair2.state = CandidatePairStateSucceeded // Select the first pair agent.setSelectedPair(pair1) // Build a standard nomination request WITHOUT nomination value for pair2 msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag), UseCandidate(), stun.NewShortTermIntegrity(agent.localPwd), stun.Fingerprint, ) require.NoError(t, err) // Handle the binding request for pair2 (same priority) selector.HandleBindingRequest(msg, local2, remote) // Without renomination, standard ICE rules apply // Since pair2 has equal priority to pair1, it should NOT be accepted // (only higher priority pairs are accepted in standard ICE) selectedPair := agent.getSelectedPair() assert.Equal(t, pair1, selectedPair, "Should NOT switch to pair2 with standard nomination when priority is equal") }) t.Run("Higher nomination values override lower ones", func(t *testing.T) { agent := bareAgentForPing() agent.log = logging.NewDefaultLoggerFactory().NewLogger("test") agent.remoteUfrag = selectionTestRemoteUfrag agent.localUfrag = selectionTestLocalUfrag agent.remotePwd = selectionTestPassword agent.tieBreaker = 1 agent.isControlling.Store(false) agent.nominationAttribute = stun.AttrType(0x0030) agent.onConnected = make(chan struct{}) agent.setSelector() selector, ok := agent.getSelector().(*controlledSelector) require.True(t, ok, "expected controlledSelector") local1 := newPingNoIOCand() local1.candidateBase.networkType = NetworkTypeUDP4 local1.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 10000} local2 := newPingNoIOCand() local2.candidateBase.networkType = NetworkTypeUDP4 local2.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.3"), Port: 10001} local3 := newPingNoIOCand() local3.candidateBase.networkType = NetworkTypeUDP4 local3.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.4"), Port: 10002} remote := newPingNoIOCand() 
remote.candidateBase.networkType = NetworkTypeUDP4 remote.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.2"), Port: 20000} pair1 := agent.addPair(local1, remote) pair1.state = CandidatePairStateSucceeded pair2 := agent.addPair(local2, remote) pair2.state = CandidatePairStateSucceeded pair3 := agent.addPair(local3, remote) pair3.state = CandidatePairStateSucceeded // Nominate pair1 with value 100 msg1, err := stun.Build(stun.BindingRequest, stun.TransactionID, stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag), UseCandidate(), NominationSetter{Value: 100, AttrType: agent.nominationAttribute}, stun.NewShortTermIntegrity(agent.localPwd), stun.Fingerprint, ) require.NoError(t, err) selector.HandleBindingRequest(msg1, local1, remote) assert.Equal(t, pair1, agent.getSelectedPair()) // Try to nominate pair2 with a LOWER value (50) - should be rejected msg2, err := stun.Build(stun.BindingRequest, stun.TransactionID, stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag), UseCandidate(), NominationSetter{Value: 50, AttrType: agent.nominationAttribute}, stun.NewShortTermIntegrity(agent.localPwd), stun.Fingerprint, ) require.NoError(t, err) selector.HandleBindingRequest(msg2, local2, remote) assert.Equal(t, pair1, agent.getSelectedPair(), "Should reject nomination with lower value") // Nominate pair3 with a HIGHER value (200) - should be accepted msg3, err := stun.Build(stun.BindingRequest, stun.TransactionID, stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag), UseCandidate(), NominationSetter{Value: 200, AttrType: agent.nominationAttribute}, stun.NewShortTermIntegrity(agent.localPwd), stun.Fingerprint, ) require.NoError(t, err) selector.HandleBindingRequest(msg3, local3, remote) assert.Equal(t, pair3, agent.getSelectedPair(), "Should accept nomination with higher value") }) } // TestControllingSideRenomination verifies that the controlling agent correctly // updates its selected pair when receiving a success response for a renomination 
request. func TestControllingSideRenomination(t *testing.T) { t.Run("Switches selected pair on renomination success response", func(t *testing.T) { // Create a controlling agent agent := bareAgentForPing() agent.log = logging.NewDefaultLoggerFactory().NewLogger("test") agent.remoteUfrag = selectionTestRemoteUfrag agent.localUfrag = selectionTestLocalUfrag agent.remotePwd = selectionTestPassword agent.tieBreaker = 1 agent.isControlling.Store(true) // Controlling agent agent.nominationAttribute = stun.AttrType(0x0030) agent.onConnected = make(chan struct{}) // Initialize the channel agent.setSelector() selector, ok := agent.getSelector().(*controllingSelector) require.True(t, ok, "expected controllingSelector") // Create two host candidates local1 := newPingNoIOCand() local1.candidateBase.networkType = NetworkTypeUDP4 local1.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 10000} local2 := newPingNoIOCand() local2.candidateBase.networkType = NetworkTypeUDP4 local2.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.3"), Port: 10001} remote := newPingNoIOCand() remote.candidateBase.networkType = NetworkTypeUDP4 remote.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.2"), Port: 20000} // Create two pairs pair1 := agent.addPair(local1, remote) pair1.state = CandidatePairStateSucceeded pair1.UpdateRoundTripTime(10 * time.Millisecond) pair2 := agent.addPair(local2, remote) pair2.state = CandidatePairStateSucceeded pair2.UpdateRoundTripTime(5 * time.Millisecond) // Select the first pair initially agent.setSelectedPair(pair1) assert.Equal(t, pair1, agent.getSelectedPair()) assert.True(t, pair1.nominated) // Build a renomination request with nomination value for pair2 nominationValue := uint32(100) msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, stun.NewUsername(agent.remoteUfrag+":"+agent.localUfrag), UseCandidate(), AttrControlling(agent.tieBreaker), PriorityAttr(local2.Priority()), 
NominationSetter{ Value: nominationValue, AttrType: agent.nominationAttribute, }, stun.NewShortTermIntegrity(agent.remotePwd), stun.Fingerprint, ) require.NoError(t, err) // Simulate sending the binding request (adds to pendingBindingRequests) agent.sendBindingRequest(msg, local2, remote) // Verify the nomination value was stored in the pending request require.Len(t, agent.pendingBindingRequests, 1) require.NotNil(t, agent.pendingBindingRequests[0].nominationValue) require.Equal(t, nominationValue, *agent.pendingBindingRequests[0].nominationValue) // Build a success response successMsg, err := stun.Build(msg, stun.BindingSuccess, &stun.XORMappedAddress{ IP: net.ParseIP("192.168.1.2").To4(), Port: 20000, }, stun.NewShortTermIntegrity(agent.remotePwd), stun.Fingerprint, ) require.NoError(t, err) // Handle the success response - this should switch to pair2 selector.HandleSuccessResponse(successMsg, local2, remote, remote.addr()) // The controlling agent should have switched to pair2 selectedPair := agent.getSelectedPair() assert.Equal(t, pair2, selectedPair, "Controlling agent should switch to pair2 after renomination success response") assert.True(t, pair2.nominated) }) t.Run("Does not switch on standard nomination success if pair already selected", func(t *testing.T) { // Create a controlling agent agent := bareAgentForPing() agent.log = logging.NewDefaultLoggerFactory().NewLogger("test") agent.remoteUfrag = selectionTestRemoteUfrag agent.localUfrag = selectionTestLocalUfrag agent.remotePwd = selectionTestPassword agent.tieBreaker = 1 agent.isControlling.Store(true) agent.onConnected = make(chan struct{}) agent.setSelector() selector, ok := agent.getSelector().(*controllingSelector) require.True(t, ok, "expected controllingSelector") // Create two host candidates local1 := newPingNoIOCand() local1.candidateBase.networkType = NetworkTypeUDP4 local1.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 10000} local2 := newPingNoIOCand() 
local2.candidateBase.networkType = NetworkTypeUDP4 local2.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.3"), Port: 10001} remote := newPingNoIOCand() remote.candidateBase.networkType = NetworkTypeUDP4 remote.candidateBase.resolvedAddr = &net.UDPAddr{IP: net.ParseIP("192.168.1.2"), Port: 20000} // Create two pairs pair1 := agent.addPair(local1, remote) pair1.state = CandidatePairStateSucceeded pair2 := agent.addPair(local2, remote) pair2.state = CandidatePairStateSucceeded // Select the first pair initially agent.setSelectedPair(pair1) assert.Equal(t, pair1, agent.getSelectedPair()) // Build a standard nomination request WITHOUT nomination value for pair2 msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, stun.NewUsername(agent.remoteUfrag+":"+agent.localUfrag), UseCandidate(), AttrControlling(agent.tieBreaker), PriorityAttr(local2.Priority()), stun.NewShortTermIntegrity(agent.remotePwd), stun.Fingerprint, ) require.NoError(t, err) // Simulate sending the binding request agent.sendBindingRequest(msg, local2, remote) // Verify no nomination value was stored require.Len(t, agent.pendingBindingRequests, 1) require.Nil(t, agent.pendingBindingRequests[0].nominationValue) // Build a success response successMsg, err := stun.Build(msg, stun.BindingSuccess, &stun.XORMappedAddress{ IP: net.ParseIP("192.168.1.2").To4(), Port: 20000, }, stun.NewShortTermIntegrity(agent.remotePwd), stun.Fingerprint, ) require.NoError(t, err) // Handle the success response - this should NOT switch since it's standard nomination // and a pair is already selected selector.HandleSuccessResponse(successMsg, local2, remote, remote.addr()) // The controlling agent should remain with pair1 selectedPair := agent.getSelectedPair() assert.Equal(t, pair1, selectedPair, "Controlling agent should NOT switch with standard nomination when pair already selected") }) } ice-4.1.0/stats.go000066400000000000000000000175321511704470200137520ustar00rootroot00000000000000// 
SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "time" ) // CandidatePairStats contains ICE candidate pair statistics. type CandidatePairStats struct { // Timestamp is the timestamp associated with this object. Timestamp time.Time // LocalCandidateID is the ID of the local candidate LocalCandidateID string // RemoteCandidateID is the ID of the remote candidate RemoteCandidateID string // State represents the state of the checklist for the local and remote // candidates in a pair. State CandidatePairState // Nominated is true when this valid pair that should be used for media // if it is the highest-priority one amongst those whose nominated flag is set Nominated bool // PacketsSent represents the total number of packets sent on this candidate pair. PacketsSent uint32 // PacketsReceived represents the total number of packets received on this candidate pair. PacketsReceived uint32 // BytesSent represents the total number of payload bytes sent on this candidate pair // not including headers or padding. BytesSent uint64 // BytesReceived represents the total number of payload bytes received on this candidate pair // not including headers or padding. BytesReceived uint64 // LastPacketSentTimestamp represents the timestamp at which the last packet was // sent on this particular candidate pair, excluding STUN packets. LastPacketSentTimestamp time.Time // LastPacketReceivedTimestamp represents the timestamp at which the last packet // was received on this particular candidate pair, excluding STUN packets. LastPacketReceivedTimestamp time.Time // FirstRequestTimestamp represents the timestamp at which the first STUN request // was sent on this particular candidate pair. FirstRequestTimestamp time.Time // LastRequestTimestamp represents the timestamp at which the last STUN request // was sent on this particular candidate pair. 
The average interval between two // consecutive connectivity checks sent can be calculated with // (LastRequestTimestamp - FirstRequestTimestamp) / RequestsSent. LastRequestTimestamp time.Time // FirstResponseTimestamp represents the timestamp at which the first STUN response // was received on this particular candidate pair. FirstResponseTimestamp time.Time // LastResponseTimestamp represents the timestamp at which the last STUN response // was received on this particular candidate pair. LastResponseTimestamp time.Time // FirstRequestReceivedTimestamp represents the timestamp at which the first // connectivity check request was received. FirstRequestReceivedTimestamp time.Time // LastRequestReceivedTimestamp represents the timestamp at which the last // connectivity check request was received. LastRequestReceivedTimestamp time.Time // TotalRoundTripTime represents the sum of all round trip time measurements // in seconds since the beginning of the session, based on STUN connectivity // check responses (ResponsesReceived), including those that reply to requests // that are sent in order to verify consent. The average round trip time can // be computed from TotalRoundTripTime by dividing it by ResponsesReceived. TotalRoundTripTime float64 // CurrentRoundTripTime represents the latest round trip time measured in seconds, // computed from both STUN connectivity checks, including those that are sent // for consent verification. CurrentRoundTripTime float64 // AvailableOutgoingBitrate is calculated by the underlying congestion control // by combining the available bitrate for all the outgoing RTP streams using // this candidate pair. The bitrate measurement does not count the size of the // IP or other transport layers like TCP or UDP. It is similar to the TIAS defined // in RFC 3890, i.e., it is measured in bits per second and the bitrate is calculated // over a 1 second window. 
AvailableOutgoingBitrate float64 // AvailableIncomingBitrate is calculated by the underlying congestion control // by combining the available bitrate for all the incoming RTP streams using // this candidate pair. The bitrate measurement does not count the size of the // IP or other transport layers like TCP or UDP. It is similar to the TIAS defined // in RFC 3890, i.e., it is measured in bits per second and the bitrate is // calculated over a 1 second window. AvailableIncomingBitrate float64 // CircuitBreakerTriggerCount represents the number of times the circuit breaker // is triggered for this particular 5-tuple, ceasing transmission. CircuitBreakerTriggerCount uint32 // RequestsReceived represents the total number of connectivity check requests // received (including retransmissions). It is impossible for the receiver to // tell whether the request was sent in order to check connectivity or check // consent, so all connectivity checks requests are counted here. RequestsReceived uint64 // RequestsSent represents the total number of connectivity check requests // sent (not including retransmissions). RequestsSent uint64 // ResponsesReceived represents the total number of connectivity check responses received. ResponsesReceived uint64 // ResponsesSent represents the total number of connectivity check responses sent. // Since we cannot distinguish connectivity check requests and consent requests, // all responses are counted. ResponsesSent uint64 // RetransmissionsReceived represents the total number of connectivity check // request retransmissions received. RetransmissionsReceived uint64 // RetransmissionsSent represents the total number of connectivity check // request retransmissions sent. RetransmissionsSent uint64 // ConsentRequestsSent represents the total number of consent requests sent. ConsentRequestsSent uint64 // ConsentExpiredTimestamp represents the timestamp at which the latest valid // STUN binding response expired. 
ConsentExpiredTimestamp time.Time } // CandidateStats contains ICE candidate statistics related to the ICETransport objects. type CandidateStats struct { // Timestamp is the timestamp associated with this object. Timestamp time.Time // ID is the candidate ID ID string // NetworkType represents the type of network interface used by the base of a // local candidate (the address the ICE agent sends from). Only present for // local candidates; it's not possible to know what type of network interface // a remote candidate is using. // // Note: // This stat only tells you about the network interface used by the first "hop"; // it's possible that a connection will be bottlenecked by another type of network. // For example, when using Wi-Fi tethering, the networkType of the relevant candidate // would be "wifi", even when the next hop is over a cellular connection. NetworkType NetworkType // IP is the IP address of the candidate, allowing for IPv4 addresses and // IPv6 addresses, but fully qualified domain names (FQDNs) are not allowed. IP string // Port is the port number of the candidate. Port int // CandidateType is the "Type" field of the ICECandidate. CandidateType CandidateType // Priority is the "Priority" field of the ICECandidate. Priority uint32 // URL is the URL of the TURN or STUN server indicated in the that translated // this IP address. It is the URL address surfaced in an PeerConnectionICEEvent. URL string // RelayProtocol is the protocol used by the endpoint to communicate with the // TURN server. This is only present for local candidates. Valid values for // the TURN URL protocol is one of UDP, TCP, or TLS. RelayProtocol string // Deleted is true if the candidate has been deleted/freed. For host candidates, // this means that any network resources (typically a socket) associated with the // candidate have been released. For TURN candidates, this means the TURN allocation // is no longer active. // // Only defined for local candidates. 
For remote candidates, this property is not applicable. Deleted bool } ice-4.1.0/tcp_mux.go000066400000000000000000000277611511704470200143000ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "encoding/binary" "errors" "io" "net" "strings" "sync" "time" "github.com/pion/logging" "github.com/pion/stun/v3" ) // ErrGetTransportAddress can't convert net.Addr to underlying type (UDPAddr or TCPAddr). var ErrGetTransportAddress = errors.New("failed to get local transport address") // TCPMux is allows grouping multiple TCP net.Conns and using them like UDP // net.PacketConns. The main implementation of this is TCPMuxDefault, and this // interface exists to allow mocking in tests. type TCPMux interface { io.Closer GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error) RemoveConnByUfrag(ufrag string) } type ipAddr string // TCPMuxDefault muxes TCP net.Conns into net.PacketConns and groups them by // Ufrag. It is a default implementation of TCPMux interface. type TCPMuxDefault struct { params *TCPMuxParams closed bool // connsIPv4 and connsIPv6 are maps of all tcpPacketConns indexed by ufrag and local address connsIPv4, connsIPv6 map[string]map[ipAddr]*tcpPacketConn mu sync.Mutex wg sync.WaitGroup } // TCPMuxParams are parameters for TCPMux. type TCPMuxParams struct { Listener net.Listener Logger logging.LeveledLogger ReadBufferSize int // Maximum buffer size for write op. 0 means no write buffer, the write op will block until the whole packet is written // if the write buffer is full, the subsequent write packet will be dropped until it has enough space. // a default 4MB is recommended. WriteBufferSize int // A new established connection will be removed if the first STUN binding request is not received within this timeout, // avoiding the client with bad network or attacker to create a lot of empty connections. // Default 30s timeout will be used if not set. 
FirstStunBindTimeout time.Duration // TCPMux will create connection from STUN binding request with an unknown username, if // the connection is not used in the timeout, it will be removed to avoid resource leak / attack. // Default 30s timeout will be used if not set. AliveDurationForConnFromStun time.Duration } // NewTCPMuxDefault creates a new instance of TCPMuxDefault. func NewTCPMuxDefault(params TCPMuxParams) *TCPMuxDefault { if params.Logger == nil { params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") } if params.FirstStunBindTimeout == 0 { params.FirstStunBindTimeout = 30 * time.Second } if params.AliveDurationForConnFromStun == 0 { params.AliveDurationForConnFromStun = 30 * time.Second } mux := &TCPMuxDefault{ params: ¶ms, connsIPv4: map[string]map[ipAddr]*tcpPacketConn{}, connsIPv6: map[string]map[ipAddr]*tcpPacketConn{}, } mux.wg.Add(1) go func() { defer mux.wg.Done() mux.start() }() return mux } func (m *TCPMuxDefault) start() { m.params.Logger.Infof("Listening TCP on %s", m.params.Listener.Addr()) for { conn, err := m.params.Listener.Accept() if err != nil { m.params.Logger.Infof("Error accepting connection: %s", err) return } m.params.Logger.Debugf("Accepted connection from: %s to %s", conn.RemoteAddr(), conn.LocalAddr()) m.wg.Add(1) go func() { defer m.wg.Done() m.handleConn(conn) }() } } // LocalAddr returns the listening address of this TCPMuxDefault. func (m *TCPMuxDefault) LocalAddr() net.Addr { return m.params.Listener.Addr() } // GetConnByUfrag retrieves an existing or creates a new net.PacketConn. 
func (m *TCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error) { m.mu.Lock() defer m.mu.Unlock() if m.closed { return nil, io.ErrClosedPipe } if conn, ok := m.getConn(ufrag, isIPv6, local); ok { conn.ClearAliveTimer() return conn, nil } return m.createConn(ufrag, isIPv6, local, false) } func (m *TCPMuxDefault) createConn(ufrag string, isIPv6 bool, local net.IP, fromStun bool) (*tcpPacketConn, error) { addr, ok := m.LocalAddr().(*net.TCPAddr) if !ok { return nil, ErrGetTransportAddress } localAddr := *addr // Note: this is missing zone for IPv6 localAddr.IP = local var alive time.Duration if fromStun { alive = m.params.AliveDurationForConnFromStun } conn := newTCPPacketConn(tcpPacketParams{ ReadBuffer: m.params.ReadBufferSize, WriteBuffer: m.params.WriteBufferSize, LocalAddr: &localAddr, Logger: m.params.Logger, AliveDuration: alive, }) var conns map[ipAddr]*tcpPacketConn if isIPv6 { if conns, ok = m.connsIPv6[ufrag]; !ok { conns = make(map[ipAddr]*tcpPacketConn) m.connsIPv6[ufrag] = conns } } else { if conns, ok = m.connsIPv4[ufrag]; !ok { conns = make(map[ipAddr]*tcpPacketConn) m.connsIPv4[ufrag] = conns } } // Note: this is missing zone for IPv6 connKey := ipAddr(local.String()) conns[connKey] = conn m.wg.Add(1) go func() { defer m.wg.Done() <-conn.CloseChannel() m.removeConnByUfragAndLocalHost(ufrag, connKey) }() return conn, nil } func (m *TCPMuxDefault) closeAndLogError(closer io.Closer) { err := closer.Close() if err != nil { m.params.Logger.Warnf("Error closing connection: %s", err) } } func (m *TCPMuxDefault) handleConn(conn net.Conn) { //nolint:cyclop buf := make([]byte, 512) if m.params.FirstStunBindTimeout > 0 { if err := conn.SetReadDeadline(time.Now().Add(m.params.FirstStunBindTimeout)); err != nil { m.params.Logger.Warnf( "Failed to set read deadline for first STUN message: %s to %s , err: %s", conn.RemoteAddr(), conn.LocalAddr(), err, ) } } n, err := readStreamingPacket(conn, buf) if err != nil { if 
errors.Is(err, io.ErrShortBuffer) { m.params.Logger.Warnf("Buffer too small for first packet from %s: %s", conn.RemoteAddr(), err) } else { m.params.Logger.Warnf("Error reading first packet from %s: %s", conn.RemoteAddr(), err) } m.closeAndLogError(conn) return } if err = conn.SetReadDeadline(time.Time{}); err != nil { m.params.Logger.Warnf("Failed to reset read deadline from %s: %s", conn.RemoteAddr(), err) } buf = buf[:n] msg := &stun.Message{ Raw: make([]byte, len(buf)), } // Explicitly copy raw buffer so Message can own the memory. copy(msg.Raw, buf) if err = msg.Decode(); err != nil { m.closeAndLogError(conn) m.params.Logger.Warnf("Failed to handle decode ICE from %s to %s: %v", conn.RemoteAddr(), conn.LocalAddr(), err) return } if m == nil || msg.Type.Method != stun.MethodBinding { // Not a STUN m.closeAndLogError(conn) m.params.Logger.Warnf("Not a STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr()) return } for _, attr := range msg.Attributes { m.params.Logger.Debugf("Message attribute: %s", attr.String()) } attr, err := msg.Get(stun.AttrUsername) if err != nil { m.closeAndLogError(conn) m.params.Logger.Warnf( "No Username attribute in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr(), ) return } ufrag := strings.Split(string(attr), ":")[0] m.params.Logger.Debugf("Ufrag: %s", ufrag) host, _, err := net.SplitHostPort(conn.RemoteAddr().String()) if err != nil { m.closeAndLogError(conn) m.params.Logger.Warnf( "Failed to get host in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr(), ) return } isIPv6 := net.ParseIP(host).To4() == nil localAddr, ok := conn.LocalAddr().(*net.TCPAddr) if !ok { m.closeAndLogError(conn) m.params.Logger.Warnf( "Failed to get local tcp address in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr(), ) return } m.mu.Lock() packetConn, ok := m.getConn(ufrag, isIPv6, localAddr.IP) if !ok { packetConn, err = m.createConn(ufrag, isIPv6, localAddr.IP, true) if err != nil { m.mu.Unlock() 
m.closeAndLogError(conn) m.params.Logger.Warnf( "Failed to create packetConn for STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr(), ) return } } m.mu.Unlock() if err := packetConn.AddConn(conn, buf); err != nil { m.closeAndLogError(conn) m.params.Logger.Warnf( "Error adding conn to tcpPacketConn from %s to %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err, ) return } } // Close closes the listener and waits for all goroutines to exit. func (m *TCPMuxDefault) Close() error { m.mu.Lock() m.closed = true for _, conns := range m.connsIPv4 { for _, conn := range conns { m.closeAndLogError(conn) } } for _, conns := range m.connsIPv6 { for _, conn := range conns { m.closeAndLogError(conn) } } m.connsIPv4 = map[string]map[ipAddr]*tcpPacketConn{} m.connsIPv6 = map[string]map[ipAddr]*tcpPacketConn{} err := m.params.Listener.Close() m.mu.Unlock() m.wg.Wait() return err } // RemoveConnByUfrag closes and removes a net.PacketConn by Ufrag. func (m *TCPMuxDefault) RemoveConnByUfrag(ufrag string) { removedConns := make([]*tcpPacketConn, 0, 4) // Keep lock section small to avoid deadlock with conn lock m.mu.Lock() if conns, ok := m.connsIPv4[ufrag]; ok { delete(m.connsIPv4, ufrag) for _, conn := range conns { removedConns = append(removedConns, conn) } } if conns, ok := m.connsIPv6[ufrag]; ok { delete(m.connsIPv6, ufrag) for _, conn := range conns { removedConns = append(removedConns, conn) } } m.mu.Unlock() // Close the connections outside the critical section to avoid // deadlocking TCP mux if (*tcpPacketConn).Close() blocks. 
for _, conn := range removedConns { m.closeAndLogError(conn) } } func (m *TCPMuxDefault) removeConnByUfragAndLocalHost(ufrag string, localIPAddr ipAddr) { removedConns := make([]*tcpPacketConn, 0, 4) // Keep lock section small to avoid deadlock with conn lock m.mu.Lock() if conns, ok := m.connsIPv4[ufrag]; ok { if conn, ok := conns[localIPAddr]; ok { delete(conns, localIPAddr) if len(conns) == 0 { delete(m.connsIPv4, ufrag) } removedConns = append(removedConns, conn) } } if conns, ok := m.connsIPv6[ufrag]; ok { if conn, ok := conns[localIPAddr]; ok { delete(conns, localIPAddr) if len(conns) == 0 { delete(m.connsIPv6, ufrag) } removedConns = append(removedConns, conn) } } m.mu.Unlock() // Close the connections outside the critical section to avoid // deadlocking TCP mux if (*tcpPacketConn).Close() blocks. for _, conn := range removedConns { m.closeAndLogError(conn) } } func (m *TCPMuxDefault) getConn(ufrag string, isIPv6 bool, local net.IP) (val *tcpPacketConn, ok bool) { var conns map[ipAddr]*tcpPacketConn if isIPv6 { conns, ok = m.connsIPv6[ufrag] } else { conns, ok = m.connsIPv4[ufrag] } if conns != nil { // Note: this is missing zone for IPv6 connKey := ipAddr(local.String()) val, ok = conns[connKey] } return } const streamingPacketHeaderLen = 2 // readStreamingPacket reads 1 packet from stream // read packet bytes https://tools.ietf.org/html/rfc4571#section-2 // 2-byte length header prepends each packet: // // 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // ----------------------------------------------------------------- // | LENGTH | RTP or RTCP packet ... 
| // ----------------------------------------------------------------- func readStreamingPacket(conn net.Conn, buf []byte) (int, error) { header := make([]byte, streamingPacketHeaderLen) var bytesRead, n int var err error for bytesRead < streamingPacketHeaderLen { if n, err = conn.Read(header[bytesRead:streamingPacketHeaderLen]); err != nil { return 0, err } bytesRead += n } length := int(binary.BigEndian.Uint16(header)) if length > cap(buf) { return length, io.ErrShortBuffer } bytesRead = 0 for bytesRead < length { if n, err = conn.Read(buf[bytesRead:length]); err != nil { return 0, err } bytesRead += n } return bytesRead, nil } func writeStreamingPacket(conn net.Conn, buf []byte) (int, error) { bufCopy := make([]byte, streamingPacketHeaderLen+len(buf)) binary.BigEndian.PutUint16(bufCopy, uint16(len(buf))) //nolint:gosec // G115 copy(bufCopy[2:], buf) n, err := conn.Write(bufCopy) if err != nil { return 0, err } return n - streamingPacketHeaderLen, nil } ice-4.1.0/tcp_mux_multi.go000066400000000000000000000050721511704470200155010ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "net" ) // AllConnsGetter allows multiple fixed TCP ports to be used, // each of which is multiplexed like TCPMux. AllConnsGetter also acts as // a TCPMux, in which case it will return a single connection for one // of the ports. type AllConnsGetter interface { GetAllConns(ufrag string, isIPv6 bool, localIP net.IP) ([]net.PacketConn, error) } // MultiTCPMuxDefault implements both TCPMux and AllConnsGetter, // allowing users to pass multiple TCPMux instances to the ICE agent // configuration. type MultiTCPMuxDefault struct { muxes []TCPMux } // NewMultiTCPMuxDefault creates an instance of MultiTCPMuxDefault that // uses the provided TCPMux instances. 
func NewMultiTCPMuxDefault(muxes ...TCPMux) *MultiTCPMuxDefault { return &MultiTCPMuxDefault{ muxes: muxes, } } // GetConnByUfrag returns a PacketConn given the connection's ufrag, network and local address // creates the connection if an existing one can't be found. This, unlike // GetAllConns, will only return a single PacketConn from the first mux that was // passed in to NewMultiTCPMuxDefault. func (m *MultiTCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error) { // NOTE: We always use the first element here in order to maintain the // behavior of using an existing connection if one exists. if len(m.muxes) == 0 { return nil, errNoTCPMuxAvailable } return m.muxes[0].GetConnByUfrag(ufrag, isIPv6, local) } // RemoveConnByUfrag stops and removes the muxed packet connection // from all underlying TCPMux instances. func (m *MultiTCPMuxDefault) RemoveConnByUfrag(ufrag string) { for _, mux := range m.muxes { mux.RemoveConnByUfrag(ufrag) } } // GetAllConns returns a PacketConn for each underlying TCPMux. func (m *MultiTCPMuxDefault) GetAllConns(ufrag string, isIPv6 bool, local net.IP) ([]net.PacketConn, error) { if len(m.muxes) == 0 { // Make sure that we either return at least one connection or an error. return nil, errNoTCPMuxAvailable } var conns []net.PacketConn for _, mux := range m.muxes { conn, err := mux.GetConnByUfrag(ufrag, isIPv6, local) if err != nil { // For now, this implementation is all or none. return nil, err } if conn != nil { conns = append(conns, conn) } } return conns, nil } // Close the multi mux, no further connections could be created. 
func (m *MultiTCPMuxDefault) Close() error { var err error for _, mux := range m.muxes { if e := mux.Close(); e != nil { err = e } } return err } ice-4.1.0/tcp_mux_multi_test.go000066400000000000000000000200671511704470200165410ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "errors" "io" "net" "testing" "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/require" ) func TestMultiTCPMux_Recv(t *testing.T) { for name, bufSize := range map[string]int{ "no buffer": 0, "buffered 4MB": 4 * 1024 * 1024, } { bufSize := bufSize t.Run(name, func(t *testing.T) { defer test.CheckRoutines(t)() loggerFactory := logging.NewDefaultLoggerFactory() var muxInstances []TCPMux for i := 0; i < 3; i++ { listener, err := net.ListenTCP("tcp", &net.TCPAddr{ IP: net.IP{127, 0, 0, 1}, Port: 0, }) require.NoError(t, err, "error starting listener") defer func() { _ = listener.Close() }() tcpMux := NewTCPMuxDefault(TCPMuxParams{ Listener: listener, Logger: loggerFactory.NewLogger("ice"), ReadBufferSize: 20, WriteBufferSize: bufSize, }) defer func() { _ = tcpMux.Close() }() muxInstances = append(muxInstances, tcpMux) require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil") } multiMux := NewMultiTCPMuxDefault(muxInstances...) 
defer func() { _ = multiMux.Close() }() pktConns, err := multiMux.GetAllConns("myufrag", false, net.IP{127, 0, 0, 1}) require.NoError(t, err, "error retrieving muxed connection for ufrag") for _, pktConn := range pktConns { defer func() { _ = pktConn.Close() }() conn, err := net.DialTCP("tcp", nil, pktConn.LocalAddr().(*net.TCPAddr)) // nolint require.NoError(t, err, "error dialing test TCP connection") msg := stun.New() msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} msg.Add(stun.AttrUsername, []byte("myufrag:otherufrag")) msg.Encode() n, err := writeStreamingPacket(conn, msg.Raw) require.NoError(t, err, "error writing TCP STUN packet") recv := make([]byte, n) n2, rAddr, err := pktConn.ReadFrom(recv) require.NoError(t, err, "error receiving data") require.Equal(t, conn.LocalAddr(), rAddr, "remote TCP address mismatch") require.Equal(t, n, n2, "received byte size mismatch") require.Equal(t, msg.Raw, recv, "received bytes mismatch") // Check echo response n, err = pktConn.WriteTo(recv, conn.LocalAddr()) require.NoError(t, err, "error writing echo STUN packet") recvEcho := make([]byte, n) n3, err := readStreamingPacket(conn, recvEcho) require.NoError(t, err, "error receiving echo data") require.Equal(t, n2, n3, "received byte size mismatch") require.Equal(t, msg.Raw, recvEcho, "received bytes mismatch") } }) } } func TestMultiTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) { defer test.CheckRoutines(t)() loggerFactory := logging.NewDefaultLoggerFactory() var tcpMuxInstances []TCPMux for i := 0; i < 3; i++ { listener, err := net.ListenTCP("tcp", &net.TCPAddr{ IP: net.IP{127, 0, 0, 1}, Port: 0, }) require.NoError(t, err, "error starting listener") defer func() { _ = listener.Close() }() tcpMux := NewTCPMuxDefault(TCPMuxParams{ Listener: listener, Logger: loggerFactory.NewLogger("ice"), ReadBufferSize: 20, }) defer func() { _ = tcpMux.Close() }() tcpMuxInstances = append(tcpMuxInstances, tcpMux) } muxMulti := 
NewMultiTCPMuxDefault(tcpMuxInstances...) _, err := muxMulti.GetAllConns("test", false, net.IP{127, 0, 0, 1}) require.NoError(t, err, "error getting conn by ufrag") require.NoError(t, muxMulti.Close(), "error closing tcpMux") conn, err := muxMulti.GetAllConns("test", false, net.IP{127, 0, 0, 1}) require.Nil(t, conn, "should receive nil because mux is closed") require.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed") } func TestMultiTCPMux_GetConnByUfrag_NoMuxes(t *testing.T) { multi := NewMultiTCPMuxDefault() // no muxes pc, err := multi.GetConnByUfrag("ufrag", false, net.IP{127, 0, 0, 1}) require.Nil(t, pc) require.ErrorIs(t, err, errNoTCPMuxAvailable) } func TestMultiTCPMux_GetConnByUfrag_FromAnyMux(t *testing.T) { logger := logging.NewDefaultLoggerFactory().NewLogger("ice") l1, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 0}) require.NoError(t, err) defer func() { _ = l1.Close() }() mux1 := NewTCPMuxDefault(TCPMuxParams{ Listener: l1, Logger: logger, ReadBufferSize: 8, }) defer func() { _ = mux1.Close() }() l2, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 0}) require.NoError(t, err) defer func() { _ = l2.Close() }() mux2 := NewTCPMuxDefault(TCPMuxParams{ Listener: l2, Logger: logger, ReadBufferSize: 8, }) defer func() { _ = mux2.Close() }() multi := NewMultiTCPMuxDefault(mux1, mux2) defer func() { _ = multi.Close() }() pc, err := multi.GetConnByUfrag("myufrag", false, net.IP{127, 0, 0, 1}) require.NoError(t, err) require.NotNil(t, pc) pcAddr, ok := pc.LocalAddr().(*net.TCPAddr) require.True(t, ok, "packet conn addr should be *net.TCPAddr") m1Addr, ok := mux1.LocalAddr().(*net.TCPAddr) require.True(t, ok, "mux1 local addr should be *net.TCPAddr") m2Addr, ok := mux2.LocalAddr().(*net.TCPAddr) require.True(t, ok, "mux2 local addr should be *net.TCPAddr") isFromMux1 := pcAddr.Port == m1Addr.Port && pcAddr.IP.Equal(m1Addr.IP) isFromMux2 := pcAddr.Port == m2Addr.Port && 
pcAddr.IP.Equal(m2Addr.IP) require.True(t, isFromMux1 || isFromMux2, "conn must come from one of the underlying muxes") } func TestMultiTCPMux_GetAllConns_NoMuxes(t *testing.T) { multi := NewMultiTCPMuxDefault() // no underlying TCPMux instances conns, err := multi.GetAllConns("ufrag", false, net.IP{127, 0, 0, 1}) require.Nil(t, conns) require.ErrorIs(t, err, errNoTCPMuxAvailable) } var ( errTCPMuxCloseBoom = errors.New("tcp mux close boom") errTCPMuxCloseFirst = errors.New("first tcp mux close failed") errTCPMuxCloseSecond = errors.New("second tcp mux close failed") ) type closeErrTCPMux struct { TCPMux ret error } func (w *closeErrTCPMux) Close() error { _ = w.TCPMux.Close() return w.ret } func TestMultiTCPMux_Close_PropagatesError_FromWrappedMux(t *testing.T) { logger := logging.NewDefaultLoggerFactory().NewLogger("ice") // first mux: normal close (nil) l1, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 0}) require.NoError(t, err) mux1 := NewTCPMuxDefault(TCPMuxParams{ Listener: l1, Logger: logger, ReadBufferSize: 8, }) defer func() { _ = mux1.Close() }() // second mux: Close() returns injected error l2, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 0}) require.NoError(t, err) mux2Real := NewTCPMuxDefault(TCPMuxParams{ Listener: l2, Logger: logger, ReadBufferSize: 8, }) defer func() { _ = mux2Real.Close() }() mux2 := &closeErrTCPMux{TCPMux: mux2Real, ret: errTCPMuxCloseBoom} multi := NewMultiTCPMuxDefault(mux1, mux2) got := multi.Close() require.ErrorIs(t, got, errTCPMuxCloseBoom) } func TestMultiTCPMux_Close_LastErrorWins_FromWrappedMuxes(t *testing.T) { logger := logging.NewDefaultLoggerFactory().NewLogger("ice") // first mux: error1 la, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 0}) require.NoError(t, err) mux1Real := NewTCPMuxDefault(TCPMuxParams{ Listener: la, Logger: logger, ReadBufferSize: 8, }) defer func() { _ = mux1Real.Close() }() mux1 := &closeErrTCPMux{TCPMux: 
mux1Real, ret: errTCPMuxCloseFirst} // second mux: error2 (last error should be returned) lb, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 0}) require.NoError(t, err) mux2Real := NewTCPMuxDefault(TCPMuxParams{ Listener: lb, Logger: logger, ReadBufferSize: 8, }) defer func() { _ = mux2Real.Close() }() mux2 := &closeErrTCPMux{TCPMux: mux2Real, ret: errTCPMuxCloseSecond} multi := NewMultiTCPMuxDefault(mux1, mux2) got := multi.Close() require.ErrorIs(t, got, errTCPMuxCloseSecond) } ice-4.1.0/tcp_mux_test.go000066400000000000000000000171251511704470200153300ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "io" "net" "os" "testing" "time" "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/require" ) var _ TCPMux = &TCPMuxDefault{} func TestTCPMux_Recv(t *testing.T) { for name, bufSize := range map[string]int{ "no buffer": 0, "buffered 4MB": 4 * 1024 * 1024, } { bufSize := bufSize t.Run(name, func(t *testing.T) { defer test.CheckRoutines(t)() loggerFactory := logging.NewDefaultLoggerFactory() listener, err := net.ListenTCP("tcp", &net.TCPAddr{ IP: net.IP{127, 0, 0, 1}, Port: 0, }) require.NoError(t, err, "error starting listener") defer func() { _ = listener.Close() }() tcpMux := NewTCPMuxDefault(TCPMuxParams{ Listener: listener, Logger: loggerFactory.NewLogger("ice"), ReadBufferSize: 20, WriteBufferSize: bufSize, }) defer func() { _ = tcpMux.Close() }() require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil") conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr)) // nolint require.NoError(t, err, "error dialing test TCP connection") msg := stun.New() msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} msg.Add(stun.AttrUsername, []byte("myufrag:otherufrag")) msg.Encode() n, err := writeStreamingPacket(conn, msg.Raw) require.NoError(t, err, 
"error writing TCP STUN packet") listenerAddr, ok := listener.Addr().(*net.TCPAddr) require.True(t, ok) pktConn, err := tcpMux.GetConnByUfrag("myufrag", false, listenerAddr.IP) require.NoError(t, err, "error retrieving muxed connection for ufrag") defer func() { _ = pktConn.Close() }() recv := make([]byte, n) n2, rAddr, err := pktConn.ReadFrom(recv) require.NoError(t, err, "error receiving data") require.Equal(t, conn.LocalAddr(), rAddr, "remote tcp address mismatch") require.Equal(t, n, n2, "received byte size mismatch") require.Equal(t, msg.Raw, recv, "received bytes mismatch") // Check echo response n, err = pktConn.WriteTo(recv, conn.LocalAddr()) require.NoError(t, err, "error writing echo STUN packet") recvEcho := make([]byte, n) n3, err := readStreamingPacket(conn, recvEcho) require.NoError(t, err, "error receiving echo data") require.Equal(t, n2, n3, "received byte size mismatch") require.Equal(t, msg.Raw, recvEcho, "received bytes mismatch") }) } } func TestTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) { defer test.CheckRoutines(t)() loggerFactory := logging.NewDefaultLoggerFactory() listener, err := net.ListenTCP("tcp", &net.TCPAddr{ IP: net.IP{127, 0, 0, 1}, Port: 0, }) require.NoError(t, err, "error starting listener") defer func() { _ = listener.Close() }() tcpMux := NewTCPMuxDefault(TCPMuxParams{ Listener: listener, Logger: loggerFactory.NewLogger("ice"), ReadBufferSize: 20, }) defer func() { _ = tcpMux.Close() }() listenerAddr, ok := listener.Addr().(*net.TCPAddr) require.True(t, ok) _, err = tcpMux.GetConnByUfrag("test", false, listenerAddr.IP) require.NoError(t, err, "error getting conn by ufrag") require.NoError(t, tcpMux.Close(), "error closing tcpMux") conn, err := tcpMux.GetConnByUfrag("test", false, listenerAddr.IP) require.Nil(t, conn, "should receive nil because mux is closed") require.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed") } func TestTCPMux_FirstPacketTimeout(t *testing.T) { defer 
test.CheckRoutines(t)() loggerFactory := logging.NewDefaultLoggerFactory() listener, err := net.ListenTCP("tcp", &net.TCPAddr{ IP: net.IP{127, 0, 0, 1}, Port: 0, }) require.NoError(t, err, "error starting listener") defer func() { _ = listener.Close() }() tcpMux := NewTCPMuxDefault(TCPMuxParams{ Listener: listener, Logger: loggerFactory.NewLogger("ice"), ReadBufferSize: 20, FirstStunBindTimeout: time.Second, }) defer func() { _ = tcpMux.Close() }() require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil") conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr)) // nolint require.NoError(t, err, "error dialing test TCP connection") defer func() { _ = conn.Close() }() // Don't send any data, the mux should close the connection after the timeout time.Sleep(1500 * time.Millisecond) require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second))) buf := make([]byte, 1) _, err = conn.Read(buf) require.ErrorIs(t, err, io.EOF) } func TestTCPMux_NoLeakForConnectionFromStun(t *testing.T) { defer test.CheckRoutines(t)() loggerFactory := logging.NewDefaultLoggerFactory() listener, err := net.ListenTCP("tcp", &net.TCPAddr{ IP: net.IP{127, 0, 0, 1}, Port: 0, }) require.NoError(t, err, "error starting listener") defer func() { _ = listener.Close() }() tcpMux := NewTCPMuxDefault(TCPMuxParams{ Listener: listener, Logger: loggerFactory.NewLogger("ice"), ReadBufferSize: 20, AliveDurationForConnFromStun: time.Second, }) defer func() { _ = tcpMux.Close() }() require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil") t.Run("close connection from stun msg after timeout", func(t *testing.T) { conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr)) // nolint require.NoError(t, err, "error dialing test TCP connection") defer func() { _ = conn.Close() }() msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, stun.NewUsername("myufrag:otherufrag"), stun.NewShortTermIntegrity("myufrag"), stun.Fingerprint, ) require.NoError(t, 
err, "error building STUN packet") msg.Encode() _, err = writeStreamingPacket(conn, msg.Raw) require.NoError(t, err, "error writing TCP STUN packet") time.Sleep(1500 * time.Millisecond) require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second))) buf := make([]byte, 1) _, err = conn.Read(buf) require.ErrorIs(t, err, io.EOF) }) t.Run("connection keep alive if access by user", func(t *testing.T) { conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr)) // nolint require.NoError(t, err, "error dialing test TCP connection") defer func() { _ = conn.Close() }() msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, stun.NewUsername("myufrag2:otherufrag2"), stun.NewShortTermIntegrity("myufrag2"), stun.Fingerprint, ) require.NoError(t, err, "error building STUN packet") msg.Encode() n, err := writeStreamingPacket(conn, msg.Raw) require.NoError(t, err, "error writing TCP STUN packet") // wait for the connection to be created time.Sleep(100 * time.Millisecond) listenerAddr, ok := listener.Addr().(*net.TCPAddr) require.True(t, ok) pktConn, err := tcpMux.GetConnByUfrag("myufrag2", false, listenerAddr.IP) require.NoError(t, err, "error retrieving muxed connection for ufrag") defer func() { _ = pktConn.Close() }() time.Sleep(1500 * time.Millisecond) // timeout, not closed buf := make([]byte, 1024) require.NoError(t, conn.SetReadDeadline(time.Now().Add(100*time.Millisecond))) _, err = conn.Read(buf) require.ErrorIs(t, err, os.ErrDeadlineExceeded) recv := make([]byte, n) n2, rAddr, err := pktConn.ReadFrom(recv) require.NoError(t, err, "error receiving data") require.Equal(t, conn.LocalAddr(), rAddr, "remote tcp address mismatch") require.Equal(t, n, n2, "received byte size mismatch") require.Equal(t, msg.Raw, recv, "received bytes mismatch") }) } ice-4.1.0/tcp_packet_conn.go000066400000000000000000000154751511704470200157520ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice 
import ( "errors" "fmt" "io" "net" "sync" "sync/atomic" "time" "github.com/pion/logging" "github.com/pion/transport/v3/packetio" ) type bufferedConn struct { net.Conn buf *packetio.Buffer logger logging.LeveledLogger closed int32 } func newBufferedConn(conn net.Conn, bufSize int, logger logging.LeveledLogger) net.Conn { buf := packetio.NewBuffer() if bufSize > 0 { buf.SetLimitSize(bufSize) } bc := &bufferedConn{ Conn: conn, buf: buf, logger: logger, } go bc.writeProcess() return bc } func (bc *bufferedConn) Write(b []byte) (int, error) { n, err := bc.buf.Write(b) if err != nil { return n, err } return n, nil } func (bc *bufferedConn) writeProcess() { pktBuf := make([]byte, receiveMTU) for atomic.LoadInt32(&bc.closed) == 0 { n, err := bc.buf.Read(pktBuf) if errors.Is(err, io.EOF) { return } if err != nil { bc.logger.Warnf("Failed to read from buffer: %s", err) continue } if _, err := bc.Conn.Write(pktBuf[:n]); err != nil { bc.logger.Warnf("Failed to write: %s", err) continue } } } func (bc *bufferedConn) Close() error { atomic.StoreInt32(&bc.closed, 1) _ = bc.buf.Close() return bc.Conn.Close() } type tcpPacketConn struct { params *tcpPacketParams // conns is a map of net.Conns indexed by remote net.Addr.String() conns map[string]net.Conn recvChan chan streamingPacket mu sync.Mutex wg sync.WaitGroup closedChan chan struct{} closeOnce sync.Once aliveTimer *time.Timer } type streamingPacket struct { Data []byte RAddr net.Addr Err error } type tcpPacketParams struct { ReadBuffer int LocalAddr net.Addr Logger logging.LeveledLogger WriteBuffer int AliveDuration time.Duration } func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn { packet := &tcpPacketConn{ params: ¶ms, conns: map[string]net.Conn{}, recvChan: make(chan streamingPacket, params.ReadBuffer), closedChan: make(chan struct{}), } if params.AliveDuration > 0 { packet.aliveTimer = time.AfterFunc(params.AliveDuration, func() { packet.params.Logger.Warn("close tcp packet conn by alive timeout") _ = 
packet.Close() }) } return packet } func (t *tcpPacketConn) ClearAliveTimer() { t.mu.Lock() if t.aliveTimer != nil { t.aliveTimer.Stop() } t.mu.Unlock() } func (t *tcpPacketConn) AddConn(conn net.Conn, firstPacketData []byte) error { t.params.Logger.Infof( "Added connection: %s remote %s to local %s", conn.RemoteAddr().Network(), conn.RemoteAddr(), conn.LocalAddr(), ) t.mu.Lock() defer t.mu.Unlock() select { case <-t.closedChan: return io.ErrClosedPipe default: } if _, ok := t.conns[conn.RemoteAddr().String()]; ok { return fmt.Errorf("%w: %s", errConnectionAddrAlreadyExist, conn.RemoteAddr().String()) } if t.params.WriteBuffer > 0 { conn = newBufferedConn(conn, t.params.WriteBuffer, t.params.Logger) } t.conns[conn.RemoteAddr().String()] = conn t.wg.Add(1) go func() { defer t.wg.Done() if firstPacketData != nil { select { case <-t.closedChan: // NOTE: recvChan can fill up and never drain in edge // cases while closing a connection, which can cause the // packetConn to never finish closing. Bail out early // here to prevent that. return case t.recvChan <- streamingPacket{firstPacketData, conn.RemoteAddr(), nil}: } } t.startReading(conn) }() return nil } func (t *tcpPacketConn) startReading(conn net.Conn) { buf := make([]byte, receiveMTU) for { n, err := readStreamingPacket(conn, buf) if err != nil { t.params.Logger.Warnf("Failed to read streaming packet: %s", err) last := t.removeConn(conn) // Only propagate connection closure errors if no other open connection exists. 
if last || (!errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed)) { t.handleRecv(streamingPacket{nil, conn.RemoteAddr(), err}) } return } data := make([]byte, n) copy(data, buf[:n]) t.handleRecv(streamingPacket{data, conn.RemoteAddr(), nil}) } } func (t *tcpPacketConn) handleRecv(pkt streamingPacket) { t.mu.Lock() recvChan := t.recvChan if t.isClosed() { recvChan = nil } t.mu.Unlock() select { case recvChan <- pkt: case <-t.closedChan: } } func (t *tcpPacketConn) isClosed() bool { select { case <-t.closedChan: return true default: return false } } // WriteTo is for passive and s-o candidates. func (t *tcpPacketConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) { pkt, ok := <-t.recvChan if !ok { return 0, nil, io.ErrClosedPipe } if pkt.Err != nil { return 0, pkt.RAddr, pkt.Err } if cap(b) < len(pkt.Data) { return 0, pkt.RAddr, io.ErrShortBuffer } n = len(pkt.Data) copy(b, pkt.Data[:n]) return n, pkt.RAddr, err } // WriteTo is for active and s-o candidates. func (t *tcpPacketConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { t.mu.Lock() conn, ok := t.conns[rAddr.String()] t.mu.Unlock() if !ok { return 0, io.ErrClosedPipe } n, err = writeStreamingPacket(conn, buf) if err != nil { t.params.Logger.Tracef("%w %s", errWrite, rAddr) return n, err } return n, err } func (t *tcpPacketConn) closeAndLogError(closer io.Closer) { err := closer.Close() if err != nil { t.params.Logger.Warnf("%v: %s", errClosingConnection, err) } } func (t *tcpPacketConn) removeConn(conn net.Conn) bool { t.mu.Lock() defer t.mu.Unlock() t.closeAndLogError(conn) // wait for some time to flush pending writes _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) // read deadline as well just in case _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) delete(t.conns, conn.RemoteAddr().String()) return len(t.conns) == 0 } func (t *tcpPacketConn) Close() error { t.mu.Lock() var shouldCloseRecvChan bool t.closeOnce.Do(func() { close(t.closedChan) shouldCloseRecvChan = 
true if t.aliveTimer != nil { t.aliveTimer.Stop() } }) for _, conn := range t.conns { t.closeAndLogError(conn) // wait for some time to flush pending writes _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) // read deadline as well just in case _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) delete(t.conns, conn.RemoteAddr().String()) } t.mu.Unlock() t.wg.Wait() if shouldCloseRecvChan { close(t.recvChan) } return nil } func (t *tcpPacketConn) LocalAddr() net.Addr { return t.params.LocalAddr } func (t *tcpPacketConn) SetDeadline(time.Time) error { return nil } func (t *tcpPacketConn) SetReadDeadline(time.Time) error { return nil } func (t *tcpPacketConn) SetWriteDeadline(time.Time) error { return nil } func (t *tcpPacketConn) CloseChannel() <-chan struct{} { return t.closedChan } func (t *tcpPacketConn) String() string { return fmt.Sprintf("tcpPacketConn{LocalAddr: %s}", t.params.LocalAddr) } ice-4.1.0/tcp_packet_conn_test.go000066400000000000000000000143531511704470200170030ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "bytes" "fmt" "io" "net" "sync" "testing" "time" "github.com/pion/logging" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/require" ) func TestBufferedConn_Write_ErrorAfterClose(t *testing.T) { defer test.CheckRoutines(t)() logger := logging.NewDefaultLoggerFactory().NewLogger("ice") c1, c2 := net.Pipe() defer func() { _ = c2.Close() }() bc := newBufferedConn(c1, 0, logger) require.NoError(t, bc.Close()) n, err := bc.Write([]byte("hello")) require.Error(t, err) require.Equal(t, 0, n) } type hookLogger struct { ch chan struct{} once sync.Once } func newHookLogger() *hookLogger { return &hookLogger{ch: make(chan struct{})} } func (l *hookLogger) Trace(msg string) {} func (l *hookLogger) Tracef(string, ...any) {} func (l *hookLogger) Debug(msg string) {} func (l *hookLogger) Debugf(string, ...any) {} 
func (l *hookLogger) Info(msg string) {} func (l *hookLogger) Infof(string, ...any) {} func (l *hookLogger) Warn(msg string) { l.once.Do(func() { close(l.ch) }) } func (l *hookLogger) Warnf(string, ...any) { l.once.Do(func() { close(l.ch) }) } func (l *hookLogger) Error(msg string) {} func (l *hookLogger) Errorf(string, ...any) {} func TestBufferedConn_writeProcess_ReadError(t *testing.T) { c1, c2 := net.Pipe() t.Cleanup(func() { _ = c2.Close() }) logger := newHookLogger() under := newBufferedConn(c1, 0, logger) bc, ok := under.(*bufferedConn) require.True(t, ok, "expected *bufferedConn") _ = bc.buf.SetReadDeadline(time.Unix(0, 0)) select { case <-logger.ch: case <-time.After(2 * time.Second): require.FailNow(t, "timed out waiting for read-error branch") } require.NoError(t, bc.Close()) } func TestBufferedConn_writeProcess_WriteError(t *testing.T) { c1, c2 := net.Pipe() logger := newHookLogger() under := newBufferedConn(c1, 0, logger) bc, ok := under.(*bufferedConn) require.True(t, ok, "expected *bufferedConn") require.NoError(t, c2.Close()) n, err := bc.Write([]byte("hello")) require.NoError(t, err) require.Equal(t, 5, n) select { case <-logger.ch: case <-time.After(2 * time.Second): require.FailNow(t, "timed out waiting for write-error branch") } require.NoError(t, bc.Close()) } func newTestTCPPC(t *testing.T, readBuf int) *tcpPacketConn { t.Helper() return newTCPPacketConn(tcpPacketParams{ ReadBuffer: readBuf, LocalAddr: &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 0}, Logger: logging.NewDefaultLoggerFactory().NewLogger("ice"), WriteBuffer: 0, AliveDuration: 0, }) } func TestTCPPacketConn_AddConn_ReturnsClosed(t *testing.T) { defer test.CheckRoutines(t)() tpc := newTestTCPPC(t, 8) require.NoError(t, tpc.Close()) c1, c2 := net.Pipe() defer func() { _ = c2.Close() }() err := tpc.AddConn(c1, nil) require.ErrorIs(t, err, io.ErrClosedPipe) _ = c1.Close() } func TestTCPPacketConn_AddConn_DuplicateRemoteAddr(t *testing.T) { defer test.CheckRoutines(t)() tpc := 
newTestTCPPC(t, 8) c1, c2 := net.Pipe() defer func() { _ = c2.Close() }() require.NoError(t, tpc.AddConn(c1, nil)) err := tpc.AddConn(c1, nil) require.ErrorIs(t, err, errConnectionAddrAlreadyExist) require.NoError(t, tpc.Close()) _ = c1.Close() } func TestTCPPacketConn_AddConn_FirstPacket_BailsOnClosed(t *testing.T) { defer test.CheckRoutines(t)() // unbuffered recvChan so send can't proceed until there is a receiver. tpc := newTestTCPPC(t, 0) c1, c2 := net.Pipe() defer func() { _ = c2.Close() }() // firstPacketData is non-nil, so the goroutine will try the select. require.NoError(t, tpc.AddConn(c1, []byte("hello"))) require.NoError(t, tpc.Close()) _ = c1.Close() } func TestTCPPacketConn_ReadFrom_ShortBuffer(t *testing.T) { logger := logging.NewDefaultLoggerFactory().NewLogger("ice") tpc := newTCPPacketConn(tcpPacketParams{ ReadBuffer: 1, // buffered channel so we can enqueue a packet LocalAddr: &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 0}, Logger: logger, WriteBuffer: 0, AliveDuration: 0, }) defer func() { _ = tpc.Close() }() raddr := &net.TCPAddr{IP: net.IP{10, 0, 0, 1}, Port: 4242} big := bytes.Repeat([]byte{0xAB}, 10) // packet larger than read buffer tpc.recvChan <- streamingPacket{Data: big, RAddr: raddr, Err: nil} smallBuf := make([]byte, 5) // cap=5 < len(big)=10 n, addr, err := tpc.ReadFrom(smallBuf) require.ErrorIs(t, err, io.ErrShortBuffer) require.Equal(t, 0, n) require.Equal(t, raddr.String(), addr.String()) } func TestTCPPacketConn_WriteTo_ErrorBranch_WithProvidedMock(t *testing.T) { logger := logging.NewDefaultLoggerFactory().NewLogger("ice") tpc := newTCPPacketConn(tcpPacketParams{ ReadBuffer: 1, LocalAddr: &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 0}, Logger: logger, WriteBuffer: 0, AliveDuration: 0, }) t.Cleanup(func() { _ = tpc.Close() }) mc := &mockConn{} tpc.mu.Lock() tpc.conns[mc.RemoteAddr().String()] = mc tpc.mu.Unlock() n, err := tpc.WriteTo([]byte("hello"), mc.RemoteAddr()) require.Equal(t, 0, n) require.ErrorIs(t, err, io.EOF) } 
func TestTCPPacketConn_SetDeadlines(t *testing.T) { logger := logging.NewDefaultLoggerFactory().NewLogger("ice") addr := &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 12345} tpc := newTCPPacketConn(tcpPacketParams{ ReadBuffer: 8, LocalAddr: addr, Logger: logger, WriteBuffer: 0, AliveDuration: 0, }) require.NoError(t, tpc.SetReadDeadline(time.Now().Add(200*time.Millisecond))) require.NoError(t, tpc.SetWriteDeadline(time.Now().Add(200*time.Millisecond))) require.NoError(t, tpc.Close()) require.NoError(t, tpc.SetReadDeadline(time.Now().Add(200*time.Millisecond))) require.NoError(t, tpc.SetWriteDeadline(time.Now().Add(200*time.Millisecond))) } func TestTCPPacketConn_String(t *testing.T) { logger := logging.NewDefaultLoggerFactory().NewLogger("ice") addr := &net.TCPAddr{IP: net.IP{10, 0, 0, 1}, Port: 54321} tpc := newTCPPacketConn(tcpPacketParams{ ReadBuffer: 1, LocalAddr: addr, Logger: logger, WriteBuffer: 0, AliveDuration: 0, }) got := tpc.String() want := fmt.Sprintf("tcpPacketConn{LocalAddr: %s}", addr) require.Equal(t, want, got) _ = tpc.Close() } ice-4.1.0/tcptype.go000066400000000000000000000023241511704470200142750ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import "strings" // TCPType is the type of ICE TCP candidate as described in // https://tools.ietf.org/html/rfc6544#section-4.5 type TCPType int const ( // TCPTypeUnspecified is the default value. For example UDP candidates do not // need this field. TCPTypeUnspecified TCPType = iota // TCPTypeActive is active TCP candidate, which initiates TCP connections. TCPTypeActive // TCPTypePassive is passive TCP candidate, only accepts TCP connections. TCPTypePassive // TCPTypeSimultaneousOpen is like active and passive at the same time. TCPTypeSimultaneousOpen ) // NewTCPType creates a new TCPType from string. 
func NewTCPType(value string) TCPType { switch strings.ToLower(value) { case "active": return TCPTypeActive case "passive": return TCPTypePassive case "so": return TCPTypeSimultaneousOpen default: return TCPTypeUnspecified } } func (t TCPType) String() string { switch t { case TCPTypeUnspecified: return "" case TCPTypeActive: return "active" case TCPTypePassive: return "passive" case TCPTypeSimultaneousOpen: return "so" default: return ErrUnknownType.Error() } } ice-4.1.0/tcptype_test.go000066400000000000000000000014261511704470200153360ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "testing" "github.com/stretchr/testify/require" ) func TestTCPType(t *testing.T) { var tcpType TCPType require.Equal(t, TCPTypeUnspecified, tcpType) require.Equal(t, TCPTypeActive, NewTCPType("active")) require.Equal(t, TCPTypePassive, NewTCPType("passive")) require.Equal(t, TCPTypeSimultaneousOpen, NewTCPType("so")) require.Equal(t, TCPTypeUnspecified, NewTCPType("something else")) require.Equal(t, "", TCPTypeUnspecified.String()) require.Equal(t, "active", TCPTypeActive.String()) require.Equal(t, "passive", TCPTypePassive.String()) require.Equal(t, "so", TCPTypeSimultaneousOpen.String()) require.Equal(t, "Unknown", TCPType(-1).String()) } ice-4.1.0/transport.go000066400000000000000000000071251511704470200146450ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "context" "net" "sync/atomic" "time" "github.com/pion/stun/v3" ) // Dial connects to the remote agent, acting as the controlling ice agent. // Dial blocks until at least one ice candidate pair has successfully connected. func (a *Agent) Dial(ctx context.Context, remoteUfrag, remotePwd string) (*Conn, error) { return a.connect(ctx, true, remoteUfrag, remotePwd) } // Accept connects to the remote agent, acting as the controlled ice agent. 
// Accept blocks until at least one ice candidate pair has successfully connected. func (a *Agent) Accept(ctx context.Context, remoteUfrag, remotePwd string) (*Conn, error) { return a.connect(ctx, false, remoteUfrag, remotePwd) } // Conn represents the ICE connection. // At the moment the lifetime of the Conn is equal to the Agent. type Conn struct { bytesReceived atomic.Uint64 bytesSent atomic.Uint64 agent *Agent } // BytesSent returns the number of bytes sent. func (c *Conn) BytesSent() uint64 { return c.bytesSent.Load() } // BytesReceived returns the number of bytes received. func (c *Conn) BytesReceived() uint64 { return c.bytesReceived.Load() } func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, remotePwd string) (*Conn, error) { err := a.loop.Err() if err != nil { return nil, err } err = a.startConnectivityChecks(isControlling, remoteUfrag, remotePwd) //nolint:contextcheck if err != nil { return nil, err } // Block until pair selected select { case <-a.loop.Done(): return nil, a.loop.Err() case <-ctx.Done(): return nil, ErrCanceledByCaller case <-a.onConnected: } return &Conn{ agent: a, }, nil } // Read implements the Conn Read method. func (c *Conn) Read(p []byte) (int, error) { err := c.agent.loop.Err() if err != nil { return 0, err } n, err := c.agent.buf.Read(p) c.bytesReceived.Add(uint64(n)) //nolint:gosec // G115 return n, err } // Write implements the Conn Write method. func (c *Conn) Write(packet []byte) (int, error) { err := c.agent.loop.Err() if err != nil { return 0, err } if stun.IsMessage(packet) { return 0, errWriteSTUNMessageToIceConn } pair := c.agent.getSelectedPair() if pair == nil { if err = c.agent.loop.Run(c.agent.loop, func(_ context.Context) { pair = c.agent.getBestValidCandidatePair() }); err != nil { return 0, err } if pair == nil { return 0, err } } // Write application data via the selected pair and update stats with actual bytes written. 
n, err := pair.Write(packet) if n > 0 { c.bytesSent.Add(uint64(n)) pair.UpdatePacketSent(n) } return n, err } // Close implements the Conn Close method. It is used to close // the connection. Any calls to Read and Write will be unblocked and return an error. func (c *Conn) Close() error { return c.agent.Close() } // LocalAddr returns the local address of the current selected pair or nil if there is none. func (c *Conn) LocalAddr() net.Addr { pair := c.agent.getSelectedPair() if pair == nil { return nil } return pair.Local.addr() } // RemoteAddr returns the remote address of the current selected pair or nil if there is none. func (c *Conn) RemoteAddr() net.Addr { pair := c.agent.getSelectedPair() if pair == nil { return nil } return pair.Remote.addr() } // SetDeadline is a stub. func (c *Conn) SetDeadline(time.Time) error { return nil } // SetReadDeadline is a stub. func (c *Conn) SetReadDeadline(time.Time) error { return nil } // SetWriteDeadline is a stub. func (c *Conn) SetWriteDeadline(time.Time) error { return nil } ice-4.1.0/transport_test.go000066400000000000000000000205371511704470200157060ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "context" "net" "net/netip" "sync" "testing" "time" "github.com/pion/stun/v3" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/require" ) func TestStressDuplex(t *testing.T) { // Check for leaking routines defer test.CheckRoutines(t)() // Limit runtime in case of deadlocks defer test.TimeOut(time.Second * 20).Stop() // Run the test stressDuplex(t) } func testTimeout(t *testing.T, conn *Conn, timeout time.Duration) { t.Helper() const pollRate = 100 * time.Millisecond const margin = 20 * time.Millisecond // Allow 20msec error in time ticker := time.NewTicker(pollRate) defer func() { ticker.Stop() require.NoError(t, conn.Close()) }() startedAt := time.Now() for cnt := time.Duration(0); cnt <= 
timeout+defaultKeepaliveInterval+pollRate; cnt += pollRate { <-ticker.C var cs ConnectionState require.NoError(t, conn.agent.loop.Run(context.Background(), func(_ context.Context) { cs = conn.agent.connectionState })) if cs != ConnectionStateConnected { elapsed := time.Since(startedAt) require.Less(t, timeout, elapsed+margin) return } } t.Fatalf("Connection failed to time out in time. (expected timeout: %v)", timeout) //nolint } func TestTimeout(t *testing.T) { if testing.Short() { t.Skip("skipping test in short mode.") } // Check for leaking routines defer test.CheckRoutines(t)() // Limit runtime in case of deadlocks defer test.TimeOut(time.Second * 20).Stop() t.Run("WithoutDisconnectTimeout", func(t *testing.T) { ca, cb := pipe(t, nil) require.NoError(t, cb.Close()) testTimeout(t, ca, defaultDisconnectedTimeout) }) t.Run("WithDisconnectTimeout", func(t *testing.T) { ca, cb := pipeWithTimeout(t, 5*time.Second, 3*time.Second) require.NoError(t, cb.Close()) testTimeout(t, ca, 5*time.Second) }) } func TestReadClosed(t *testing.T) { // Check for leaking routines defer test.CheckRoutines(t)() // Limit runtime in case of deadlocks defer test.TimeOut(time.Second * 20).Stop() ca, cb := pipe(t, nil) require.NoError(t, ca.Close()) require.NoError(t, cb.Close()) empty := make([]byte, 10) _, err := ca.Read(empty) require.Error(t, err) } func stressDuplex(t *testing.T) { t.Helper() ca, cb := pipe(t, nil) defer func() { require.NoError(t, ca.Close()) require.NoError(t, cb.Close()) }() opt := test.Options{ MsgSize: 10, MsgCount: 1, // Order not reliable due to UDP & potentially multiple candidate pairs. 
} require.NoError(t, test.StressDuplex(ca, cb, opt)) } func gatherAndExchangeCandidates(t *testing.T, aAgent, bAgent *Agent) { t.Helper() var wg sync.WaitGroup wg.Add(2) require.NoError(t, aAgent.OnCandidate(func(candidate Candidate) { if candidate == nil { wg.Done() } })) require.NoError(t, aAgent.GatherCandidates()) require.NoError(t, bAgent.OnCandidate(func(candidate Candidate) { if candidate == nil { wg.Done() } })) require.NoError(t, bAgent.GatherCandidates()) wg.Wait() candidates, err := aAgent.GetLocalCandidates() require.NoError(t, err) for _, c := range candidates { if addr, parseErr := netip.ParseAddr(c.Address()); parseErr == nil { require.False(t, shouldFilterLocationTrackedIP(addr)) } candidateCopy, copyErr := c.copy() require.NoError(t, copyErr) require.NoError(t, bAgent.AddRemoteCandidate(candidateCopy)) } candidates, err = bAgent.GetLocalCandidates() require.NoError(t, err) for _, c := range candidates { candidateCopy, copyErr := c.copy() require.NoError(t, copyErr) require.NoError(t, aAgent.AddRemoteCandidate(candidateCopy)) } } func connect(t *testing.T, aAgent, bAgent *Agent) (*Conn, *Conn) { t.Helper() gatherAndExchangeCandidates(t, aAgent, bAgent) accepted := make(chan struct{}) var aConn *Conn go func() { var acceptErr error bUfrag, bPwd, acceptErr := bAgent.GetLocalUserCredentials() require.NoError(t, acceptErr) aConn, acceptErr = aAgent.Accept(context.TODO(), bUfrag, bPwd) require.NoError(t, acceptErr) close(accepted) }() aUfrag, aPwd, err := aAgent.GetLocalUserCredentials() require.NoError(t, err) bConn, err := bAgent.Dial(context.TODO(), aUfrag, aPwd) require.NoError(t, err) // Ensure accepted <-accepted return aConn, bConn } func pipe(t *testing.T, defaultConfig *AgentConfig) (*Conn, *Conn) { t.Helper() var urls []*stun.URI aNotifier, aConnected := onConnected() bNotifier, bConnected := onConnected() cfg := &AgentConfig{} if defaultConfig != nil { *cfg = *defaultConfig } cfg.Urls = urls cfg.NetworkTypes = supportedNetworkTypes() aAgent, 
err := NewAgent(cfg) require.NoError(t, err) require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) t.Cleanup(func() { require.NoError(t, aAgent.Close()) }) bAgent, err := NewAgent(cfg) require.NoError(t, err) require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) t.Cleanup(func() { require.NoError(t, bAgent.Close()) }) aConn, bConn := connect(t, aAgent, bAgent) // Ensure pair selected // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair <-aConnected <-bConnected return aConn, bConn } func pipeWithTimeout(t *testing.T, disconnectTimeout time.Duration, iceKeepalive time.Duration) (*Conn, *Conn) { t.Helper() var urls []*stun.URI aNotifier, aConnected := onConnected() bNotifier, bConnected := onConnected() cfg := &AgentConfig{ Urls: urls, DisconnectedTimeout: &disconnectTimeout, KeepaliveInterval: &iceKeepalive, NetworkTypes: supportedNetworkTypes(), } aAgent, err := NewAgent(cfg) require.NoError(t, err) require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) t.Cleanup(func() { require.NoError(t, aAgent.Close()) }) bAgent, err := NewAgent(cfg) require.NoError(t, err) require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) t.Cleanup(func() { require.NoError(t, bAgent.Close()) }) aConn, bConn := connect(t, aAgent, bAgent) // Ensure pair selected // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair <-aConnected <-bConnected return aConn, bConn } func onConnected() (func(ConnectionState), chan struct{}) { done := make(chan struct{}) return func(state ConnectionState) { if state == ConnectionStateConnected { close(done) } }, done } func randomPort(tb testing.TB) int { tb.Helper() conn, err := net.ListenPacket("udp4", "127.0.0.1:0") // nolint: noctx if err != nil { tb.Fatalf("failed to pickPort: %v", err) } defer func() { _ = conn.Close() }() switch addr := conn.LocalAddr().(type) { case *net.UDPAddr: return addr.Port default: tb.Fatalf("unknown addr type %T", addr) return 0 } } 
func TestConnStats(t *testing.T) { // Check for leaking routines defer test.CheckRoutines(t)() // Limit runtime in case of deadlocks defer test.TimeOut(time.Second * 20).Stop() ca, cb := pipe(t, nil) _, err := ca.Write(make([]byte, 10)) require.NoError(t, err) defer closePipe(t, ca, cb) var wg sync.WaitGroup wg.Add(1) go func() { buf := make([]byte, 10) _, err := cb.Read(buf) require.NoError(t, err) wg.Done() }() wg.Wait() require.Equal(t, uint64(10), ca.BytesSent()) require.Equal(t, uint64(10), cb.BytesReceived()) } func TestAgent_connect_ErrEarly(t *testing.T) { defer test.CheckRoutines(t)() cfg := &AgentConfig{ NetworkTypes: supportedNetworkTypes(), } a, err := NewAgent(cfg) require.NoError(t, err) require.NoError(t, a.Close()) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() // isControlling = true conn, cerr := a.connect(ctx, true, "ufragX", "pwdX") require.Nil(t, conn) require.Error(t, cerr, "expected error from a.loop.Err() short-circuit") } func TestConn_Write_RejectsSTUN(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(10 * time.Second).Stop() cfg := &AgentConfig{ NetworkTypes: supportedNetworkTypes(), MulticastDNSMode: MulticastDNSModeDisabled, } a, err := NewAgent(cfg) require.NoError(t, err) defer func() { _ = a.Close() }() c := &Conn{agent: a} require.Nil(t, c.agent.getSelectedPair(), "precondition: no selected pair") msg := stun.New() msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} msg.Encode() n, werr := c.Write(msg.Raw) require.Zero(t, n) require.ErrorIs(t, werr, errWriteSTUNMessageToIceConn) } ice-4.1.0/transport_vnet_test.go000066400000000000000000000045241511704470200167400ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "fmt" "net" "testing" "time" "github.com/pion/stun/v3" "github.com/pion/transport/v3/test" "github.com/pion/transport/v3/vnet" 
"github.com/stretchr/testify/require" ) func TestRemoteLocalAddr(t *testing.T) { // Check for leaking routines defer test.CheckRoutines(t)() // Limit runtime in case of deadlocks defer test.TimeOut(time.Second * 20).Stop() // Agent0 is behind 1:1 NAT natType0 := &vnet.NATType{Mode: vnet.NATModeNAT1To1} // Agent1 is behind 1:1 NAT natType1 := &vnet.NATType{Mode: vnet.NATModeNAT1To1} builtVnet, errVnet := buildVNet(natType0, natType1) require.NoError(t, errVnet, "should succeed") defer builtVnet.close() stunServerURL := &stun.URI{ Scheme: stun.SchemeTypeSTUN, Host: vnetSTUNServerIP, Port: vnetSTUNServerPort, Proto: stun.ProtoTypeUDP, } t.Run("Disconnected Returns nil", func(t *testing.T) { disconnectedAgent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) disconnectedConn := Conn{agent: disconnectedAgent} require.Nil(t, disconnectedConn.RemoteAddr()) require.Nil(t, disconnectedConn.LocalAddr()) require.NoError(t, disconnectedConn.Close()) }) t.Run("Remote/Local Pair Match between Agents", func(t *testing.T) { ca, cb := pipeWithVNet(t, builtVnet, &agentTestConfig{ urls: []*stun.URI{stunServerURL}, }, &agentTestConfig{ urls: []*stun.URI{stunServerURL}, }, ) defer closePipe(t, ca, cb) aRAddr := ca.RemoteAddr() aLAddr := ca.LocalAddr() bRAddr := cb.RemoteAddr() bLAddr := cb.LocalAddr() // Assert that nothing is nil require.NotNil(t, aRAddr) require.NotNil(t, aLAddr) require.NotNil(t, bRAddr) require.NotNil(t, bLAddr) // Assert addresses require.Equal(t, aLAddr.String(), fmt.Sprintf("%s:%d", vnetLocalIPA, bRAddr.(*net.UDPAddr).Port), //nolint:forcetypeassert ) require.Equal(t, bLAddr.String(), fmt.Sprintf("%s:%d", vnetLocalIPB, aRAddr.(*net.UDPAddr).Port), //nolint:forcetypeassert ) require.Equal(t, aRAddr.String(), fmt.Sprintf("%s:%d", vnetGlobalIPB, bLAddr.(*net.UDPAddr).Port), //nolint:forcetypeassert ) require.Equal(t, bRAddr.String(), fmt.Sprintf("%s:%d", vnetGlobalIPA, aLAddr.(*net.UDPAddr).Port), //nolint:forcetypeassert ) }) } 
ice-4.1.0/udp_mux.go000066400000000000000000000234731511704470200142760ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "errors" "io" "net" "net/netip" "os" "strings" "sync" "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" ) // UDPMux allows multiple connections to go over a single UDP port. type UDPMux interface { io.Closer GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) RemoveConnByUfrag(ufrag string) GetListenAddresses() []net.Addr } // UDPMuxDefault is an implementation of the interface. type UDPMuxDefault struct { params UDPMuxParams closedChan chan struct{} closeOnce sync.Once // connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType connsIPv4, connsIPv6 map[string]*udpMuxedConn addressMapMu sync.RWMutex addressMap map[ipPort]*udpMuxedConn // Buffer pool to recycle buffers for net.UDPAddr encodes/decodes pool *sync.Pool mu sync.Mutex // For UDP connection listen at unspecified address localAddrsForUnspecified []net.Addr } // UDPMuxParams are parameters for UDPMux. type UDPMuxParams struct { Logger logging.LeveledLogger UDPConn net.PacketConn UDPConnString string // Required for gathering local addresses // in case a un UDPConn is passed which does not // bind to a specific local address. Net transport.Net } // NewUDPMuxDefault creates an implementation of UDPMux. 
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { //nolint:cyclop if params.Logger == nil { params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") } var localAddrsForUnspecified []net.Addr if udpAddr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { //nolint:nestif params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr()) } else if ok && udpAddr.IP.IsUnspecified() { // For unspecified addresses, the correct behavior is to return errListenUnspecified, but // it will break the applications that are already using unspecified UDP connection // with UDPMuxDefault, so print a warn log and create a local address list for mux. params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") var networks []NetworkType switch { case udpAddr.IP.To4() != nil: networks = []NetworkType{NetworkTypeUDP4} case udpAddr.IP.To16() != nil: networks = []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6} default: params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr()) } if len(networks) > 0 { if params.Net == nil { var err error if params.Net, err = stdnet.NewNet(); err != nil { params.Logger.Errorf("Failed to get create network: %v", err) } } _, addrs, err := localInterfaces(params.Net, nil, nil, networks, true) if err == nil { localAddrsForUnspecified = make([]net.Addr, len(addrs)) for i, addr := range addrs { localAddrsForUnspecified[i] = &net.UDPAddr{ IP: addr.addr.AsSlice(), Port: udpAddr.Port, Zone: addr.addr.Zone(), } } } else { params.Logger.Errorf("Failed to get local interfaces for unspecified addr: %v", err) } } } params.UDPConnString = params.UDPConn.LocalAddr().String() mux := &UDPMuxDefault{ addressMap: map[ipPort]*udpMuxedConn{}, params: params, connsIPv4: make(map[string]*udpMuxedConn), connsIPv6: make(map[string]*udpMuxedConn), closedChan: make(chan struct{}, 1), pool: &sync.Pool{ New: func() any { // Big enough buffer to fit both 
packet and address return newBufferHolder(receiveMTU) }, }, localAddrsForUnspecified: localAddrsForUnspecified, } go mux.connWorker() return mux } // LocalAddr returns the listening address of this UDPMuxDefault. func (m *UDPMuxDefault) LocalAddr() net.Addr { return m.params.UDPConn.LocalAddr() } // GetListenAddresses returns the list of addresses that this mux is listening on. func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { if len(m.localAddrsForUnspecified) > 0 { return m.localAddrsForUnspecified } return []net.Addr{m.LocalAddr()} } // GetConn returns a PacketConn given the connection's ufrag and network address. // creates the connection if an existing one can't be found. func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { // don't check addr for mux using unspecified address if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConnString != addr.String() { return nil, errInvalidAddress } var isIPv6 bool if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { isIPv6 = true } m.mu.Lock() defer m.mu.Unlock() if m.IsClosed() { return nil, io.ErrClosedPipe } if conn, ok := m.getConn(ufrag, isIPv6); ok { return conn, nil } c := m.createMuxedConn(ufrag) go func() { <-c.CloseChannel() m.RemoveConnByUfrag(ufrag) }() if isIPv6 { m.connsIPv6[ufrag] = c } else { m.connsIPv4[ufrag] = c } return c, nil } // RemoveConnByUfrag stops and removes the muxed packet connection. func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { removedConns := make([]*udpMuxedConn, 0, 2) // Keep lock section small to avoid deadlock with conn lock. m.mu.Lock() if c, ok := m.connsIPv4[ufrag]; ok { delete(m.connsIPv4, ufrag) removedConns = append(removedConns, c) } if c, ok := m.connsIPv6[ufrag]; ok { delete(m.connsIPv6, ufrag) removedConns = append(removedConns, c) } m.mu.Unlock() if len(removedConns) == 0 { // No need to lock if no connection was found. 
return } m.addressMapMu.Lock() defer m.addressMapMu.Unlock() for _, c := range removedConns { addresses := c.getAddresses() for _, addr := range addresses { delete(m.addressMap, addr) } } } // IsClosed returns true if the mux had been closed. func (m *UDPMuxDefault) IsClosed() bool { select { case <-m.closedChan: return true default: return false } } // Close the mux, no further connections could be created. func (m *UDPMuxDefault) Close() error { var err error m.closeOnce.Do(func() { m.mu.Lock() defer m.mu.Unlock() for _, c := range m.connsIPv4 { _ = c.Close() } for _, c := range m.connsIPv6 { _ = c.Close() } m.connsIPv4 = make(map[string]*udpMuxedConn) m.connsIPv6 = make(map[string]*udpMuxedConn) close(m.closedChan) _ = m.params.UDPConn.Close() }) return err } func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { return m.params.UDPConn.WriteTo(buf, rAddr) } func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr ipPort) { if m.IsClosed() { return } m.addressMapMu.Lock() defer m.addressMapMu.Unlock() existing, ok := m.addressMap[addr] if ok { existing.removeAddress(addr) } m.addressMap[addr] = conn m.params.Logger.Debugf("Registered %s for %s", addr.addr.String(), conn.params.Key) } func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn { c := newUDPMuxedConn(&udpMuxedConnParams{ Mux: m, Key: key, AddrPool: m.pool, LocalAddr: m.LocalAddr(), Logger: m.params.Logger, }) return c } func (m *UDPMuxDefault) connWorker() { //nolint:cyclop logger := m.params.Logger defer func() { _ = m.Close() }() buf := make([]byte, receiveMTU) for { n, addr, err := m.params.UDPConn.ReadFrom(buf) if m.IsClosed() { return } else if err != nil { if os.IsTimeout(err) { continue } else if !errors.Is(err, io.EOF) { logger.Errorf("Failed to read UDP packet: %v", err) } return } netUDPAddr, ok := addr.(*net.UDPAddr) if !ok { logger.Errorf("Underlying PacketConn did not return a UDPAddr") return } udpAddr, err := newIPPort(netUDPAddr.IP, 
netUDPAddr.Zone, uint16(netUDPAddr.Port)) //nolint:gosec if err != nil { logger.Errorf("Failed to create a new IP/Port host pair") return } // If we have already seen this address dispatch to the appropriate destination m.addressMapMu.Lock() destinationConn := m.addressMap[udpAddr] m.addressMapMu.Unlock() // If we haven't seen this address before but is a STUN packet lookup by ufrag if destinationConn == nil && stun.IsMessage(buf[:n]) { msg := &stun.Message{ Raw: append([]byte{}, buf[:n]...), } if err = msg.Decode(); err != nil { m.params.Logger.Warnf("Failed to handle decode ICE from %s: %v", addr.String(), err) continue } attr, stunAttrErr := msg.Get(stun.AttrUsername) if stunAttrErr != nil { m.params.Logger.Warnf("No Username attribute in STUN message from %s", addr.String()) continue } ufrag := strings.Split(string(attr), ":")[0] isIPv6 := netUDPAddr.IP.To4() == nil m.mu.Lock() destinationConn, _ = m.getConn(ufrag, isIPv6) m.mu.Unlock() } if destinationConn == nil { m.params.Logger.Tracef("Dropping packet from %s, addr: %s", udpAddr.addr, addr) continue } if err = destinationConn.writePacket(buf[:n], netUDPAddr); err != nil { m.params.Logger.Errorf("Failed to write packet: %v", err) } } } func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { if isIPv6 { val, ok = m.connsIPv6[ufrag] } else { val, ok = m.connsIPv4[ufrag] } return } type bufferHolder struct { next *bufferHolder buf []byte addr *net.UDPAddr } func newBufferHolder(size int) *bufferHolder { return &bufferHolder{ buf: make([]byte, size), } } func (b *bufferHolder) reset() { b.next = nil b.addr = nil } type ipPort struct { addr netip.Addr port uint16 } // newIPPort create a custom type of address based on netip.Addr and // port. The underlying ip address passed is converted to IPv6 format // to simplify ip address handling. 
func newIPPort(ip net.IP, zone string, port uint16) (ipPort, error) { n, ok := netip.AddrFromSlice(ip.To16()) if !ok { return ipPort{}, errInvalidIPAddress } return ipPort{ addr: n.WithZone(zone), port: port, }, nil } ice-4.1.0/udp_mux_multi.go000066400000000000000000000142151511704470200155020ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "fmt" "net" "github.com/pion/logging" "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" ) // MultiUDPMuxDefault implements both UDPMux and AllConnsGetter, // allowing users to pass multiple UDPMux instances to the ICE agent // configuration. type MultiUDPMuxDefault struct { muxes []UDPMux localAddrToMux map[string]UDPMux } // NewMultiUDPMuxDefault creates an instance of MultiUDPMuxDefault that // uses the provided UDPMux instances. func NewMultiUDPMuxDefault(muxes ...UDPMux) *MultiUDPMuxDefault { addrToMux := make(map[string]UDPMux) for _, mux := range muxes { for _, addr := range mux.GetListenAddresses() { addrToMux[addr.String()] = mux } } return &MultiUDPMuxDefault{ muxes: muxes, localAddrToMux: addrToMux, } } // GetConn returns a PacketConn given the connection's ufrag and network // creates the connection if an existing one can't be found. func (m *MultiUDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { mux, ok := m.localAddrToMux[addr.String()] if !ok { return nil, errNoUDPMuxAvailable } return mux.GetConn(ufrag, addr) } // RemoveConnByUfrag stops and removes the muxed packet connection // from all underlying UDPMux instances. func (m *MultiUDPMuxDefault) RemoveConnByUfrag(ufrag string) { for _, mux := range m.muxes { mux.RemoveConnByUfrag(ufrag) } } // Close the multi mux, no further connections could be created. 
func (m *MultiUDPMuxDefault) Close() error { var err error for _, mux := range m.muxes { if e := mux.Close(); e != nil { err = e } } return err } // GetListenAddresses returns the list of addresses that this mux is listening on. func (m *MultiUDPMuxDefault) GetListenAddresses() []net.Addr { addrs := make([]net.Addr, 0, len(m.localAddrToMux)) for _, mux := range m.muxes { addrs = append(addrs, mux.GetListenAddresses()...) } return addrs } // NewMultiUDPMuxFromPort creates an instance of MultiUDPMuxDefault that // listen all interfaces on the provided port. func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) { //nolint:cyclop params := multiUDPMuxFromPortParam{ networks: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, } for _, opt := range opts { opt.apply(¶ms) } if params.net == nil { var err error if params.net, err = stdnet.NewNet(); err != nil { return nil, fmt.Errorf("failed to get create network: %w", err) } } _, addrs, err := localInterfaces(params.net, params.ifFilter, params.ipFilter, params.networks, params.includeLoopback) if err != nil { return nil, err } conns := make([]net.PacketConn, 0, len(addrs)) for _, addr := range addrs { conn, listenErr := params.net.ListenUDP("udp", &net.UDPAddr{ IP: addr.addr.AsSlice(), Port: port, Zone: addr.addr.Zone(), }) if listenErr != nil { err = listenErr break } if params.readBufferSize > 0 { _ = conn.SetReadBuffer(params.readBufferSize) } if params.writeBufferSize > 0 { _ = conn.SetWriteBuffer(params.writeBufferSize) } conns = append(conns, conn) } if err != nil { for _, conn := range conns { _ = conn.Close() } return nil, err } muxes := make([]UDPMux, 0, len(conns)) for _, conn := range conns { mux := NewUDPMuxDefault(UDPMuxParams{ Logger: params.logger, UDPConn: conn, Net: params.net, }) muxes = append(muxes, mux) } return NewMultiUDPMuxDefault(muxes...), nil } // UDPMuxFromPortOption provide options for NewMultiUDPMuxFromPort. 
type UDPMuxFromPortOption interface { apply(*multiUDPMuxFromPortParam) } type multiUDPMuxFromPortParam struct { ifFilter func(string) (keep bool) ipFilter func(ip net.IP) (keep bool) networks []NetworkType readBufferSize int writeBufferSize int logger logging.LeveledLogger includeLoopback bool net transport.Net } type udpMuxFromPortOption struct { f func(*multiUDPMuxFromPortParam) } func (o *udpMuxFromPortOption) apply(p *multiUDPMuxFromPortParam) { o.f(p) } // UDPMuxFromPortWithInterfaceFilter set the filter to filter out interfaces that should not be used. func UDPMuxFromPortWithInterfaceFilter(f func(string) (keep bool)) UDPMuxFromPortOption { return &udpMuxFromPortOption{ f: func(p *multiUDPMuxFromPortParam) { p.ifFilter = f }, } } // UDPMuxFromPortWithIPFilter set the filter to filter out IP addresses that should not be used. func UDPMuxFromPortWithIPFilter(f func(ip net.IP) (keep bool)) UDPMuxFromPortOption { return &udpMuxFromPortOption{ f: func(p *multiUDPMuxFromPortParam) { p.ipFilter = f }, } } // UDPMuxFromPortWithNetworks set the networks that should be used. default is both IPv4 and IPv6. func UDPMuxFromPortWithNetworks(networks ...NetworkType) UDPMuxFromPortOption { return &udpMuxFromPortOption{ f: func(p *multiUDPMuxFromPortParam) { p.networks = networks }, } } // UDPMuxFromPortWithReadBufferSize set the UDP connection read buffer size. func UDPMuxFromPortWithReadBufferSize(size int) UDPMuxFromPortOption { return &udpMuxFromPortOption{ f: func(p *multiUDPMuxFromPortParam) { p.readBufferSize = size }, } } // UDPMuxFromPortWithWriteBufferSize set the UDP connection write buffer size. func UDPMuxFromPortWithWriteBufferSize(size int) UDPMuxFromPortOption { return &udpMuxFromPortOption{ f: func(p *multiUDPMuxFromPortParam) { p.writeBufferSize = size }, } } // UDPMuxFromPortWithLogger set the logger for the created UDPMux. 
func UDPMuxFromPortWithLogger(logger logging.LeveledLogger) UDPMuxFromPortOption { return &udpMuxFromPortOption{ f: func(p *multiUDPMuxFromPortParam) { p.logger = logger }, } } // UDPMuxFromPortWithLoopback set loopback interface should be included. func UDPMuxFromPortWithLoopback() UDPMuxFromPortOption { return &udpMuxFromPortOption{ f: func(p *multiUDPMuxFromPortParam) { p.includeLoopback = true }, } } // UDPMuxFromPortWithNet sets the network transport to use. func UDPMuxFromPortWithNet(n transport.Net) UDPMuxFromPortOption { return &udpMuxFromPortOption{ f: func(p *multiUDPMuxFromPortParam) { p.net = n }, } } ice-4.1.0/udp_mux_multi_test.go000066400000000000000000000250061511704470200165410ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "errors" "net" "sync" "testing" "time" "github.com/pion/logging" "github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/require" ) func TestMultiUDPMux(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 30).Stop() conn1, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) require.NoError(t, err) conn2, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) require.NoError(t, err) conn3, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv6loopback}) if err != nil { // IPv6 is not supported on this machine t.Log("ipv6 is not supported on this machine") } muxes := []UDPMux{} muxV41 := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn1}) muxes = append(muxes, muxV41) muxV42 := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn2}) muxes = append(muxes, muxV42) if conn3 != nil { muxV6 := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn3}) muxes = append(muxes, muxV6) } udpMuxMulti := NewMultiUDPMuxDefault(muxes...) 
defer func() { _ = udpMuxMulti.Close() _ = conn1.Close() _ = conn2.Close() }() wg := sync.WaitGroup{} wg.Add(1) go func() { defer wg.Done() testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag1", udp) }() wg.Add(1) go func() { defer wg.Done() testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag2", udp4) }() testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag3", udp6) wg.Wait() require.NoError(t, udpMuxMulti.Close()) // Can't create more connections _, err = udpMuxMulti.GetConn("failufrag", conn1.LocalAddr()) require.Error(t, err) } func testMultiUDPMuxConnections(t *testing.T, udpMuxMulti *MultiUDPMuxDefault, ufrag string, network string) { t.Helper() addrs := udpMuxMulti.GetListenAddresses() pktConns := make([]net.PacketConn, 0, len(addrs)) for _, addr := range addrs { udpAddr, ok := addr.(*net.UDPAddr) require.True(t, ok) if network == udp4 && udpAddr.IP.To4() == nil { continue } else if network == udp6 && udpAddr.IP.To4() != nil { continue } c, err := udpMuxMulti.GetConn(ufrag, addr) require.NoError(t, err, "error retrieving muxed connection for ufrag") pktConns = append(pktConns, c) } defer func() { for _, c := range pktConns { _ = c.Close() } }() // Try talking with each PacketConn for _, pktConn := range pktConns { remoteConn, err := net.DialUDP(network, nil, pktConn.LocalAddr().(*net.UDPAddr)) // nolint require.NoError(t, err, "error dialing test UDP connection") testMuxConnectionPair(t, pktConn, remoteConn, ufrag) } } func TestUnspecifiedUDPMux(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 30).Stop() muxPort := 7778 udpMuxMulti, err := NewMultiUDPMuxFromPort(muxPort, UDPMuxFromPortWithInterfaceFilter(problematicNetworkInterfaces)) require.NoError(t, err) require.GreaterOrEqual(t, len(udpMuxMulti.muxes), 1, "at least have 1 muxes") defer func() { _ = udpMuxMulti.Close() }() wg := sync.WaitGroup{} wg.Add(1) go func() { defer wg.Done() testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag1", udp) }() wg.Add(1) go func() { defer wg.Done() 
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag2", udp4) }() testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag3", udp6) wg.Wait() require.NoError(t, udpMuxMulti.Close()) } func TestMultiUDPMux_GetConn_NoUDPMuxAvailable(t *testing.T) { conn1, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) require.NoError(t, err) defer func() { _ = conn1.Close() }() conn2, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) require.NoError(t, err) defer func() { _ = conn2.Close() }() mux1 := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn1}) mux2 := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn2}) multi := NewMultiUDPMuxDefault(mux1, mux2) defer func() { _ = multi.Close() }() // Pick a port that is guaranteed not to match any listening address addrs := multi.GetListenAddresses() require.NotEmpty(t, addrs) udpAddr, ok := addrs[0].(*net.UDPAddr) require.True(t, ok, "expected *net.UDPAddr") // Build a set of in-use ports inUse := make(map[int]struct{}, len(addrs)) for _, a := range addrs { if ua, ok := a.(*net.UDPAddr); ok { inUse[ua.Port] = struct{}{} } } // Find a nearby port not in use newPort := udpAddr.Port + 1 for i := 0; i < 100; i++ { if _, exists := inUse[newPort]; !exists { break } newPort++ } missing := &net.UDPAddr{IP: udpAddr.IP, Port: newPort, Zone: udpAddr.Zone} pc, getErr := multi.GetConn("missing-ufrag", missing) require.Nil(t, pc) require.ErrorIs(t, getErr, errNoUDPMuxAvailable) } type closeErrUDPMux struct { UDPMux ret error } func (w *closeErrUDPMux) Close() error { _ = w.UDPMux.Close() // ensure underlying resources are released return w.ret } var ( errCloseBoom = errors.New("close boom") errCloseFirst = errors.New("first close failed") errCloseSecond = errors.New("second close failed") ) func TestMultiUDPMux_Close_PropagatesError(t *testing.T) { udp1, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) require.NoError(t, err) udp2, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) require.NoError(t, 
err) mux1 := NewUDPMuxDefault(UDPMuxParams{UDPConn: udp1}) mux2real := NewUDPMuxDefault(UDPMuxParams{UDPConn: udp2}) mux2 := &closeErrUDPMux{UDPMux: mux2real, ret: errCloseBoom} multi := NewMultiUDPMuxDefault(mux1, mux2) got := multi.Close() require.ErrorIs(t, got, errCloseBoom) } func TestMultiUDPMux_Close_LastErrorWins(t *testing.T) { udpA, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) require.NoError(t, err) udpB, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) require.NoError(t, err) muxAReal := NewUDPMuxDefault(UDPMuxParams{UDPConn: udpA}) muxBReal := NewUDPMuxDefault(UDPMuxParams{UDPConn: udpB}) muxA := &closeErrUDPMux{UDPMux: muxAReal, ret: errCloseFirst} muxB := &closeErrUDPMux{UDPMux: muxBReal, ret: errCloseSecond} multi := NewMultiUDPMuxDefault(muxA, muxB) got := multi.Close() require.ErrorIs(t, got, errCloseSecond) } func TestUDPMuxFromPortOptions_Apply(t *testing.T) { t.Run("IPFilter", func(t *testing.T) { var p multiUDPMuxFromPortParam keepLoopbackV4 := func(ip net.IP) bool { return ip.IsLoopback() && ip.To4() != nil } opt := UDPMuxFromPortWithIPFilter(keepLoopbackV4) opt.apply(&p) require.NotNil(t, p.ipFilter) require.True(t, p.ipFilter(net.ParseIP("127.0.0.1"))) require.False(t, p.ipFilter(net.ParseIP("8.8.8.8"))) }) t.Run("Networks single", func(t *testing.T) { var p multiUDPMuxFromPortParam opt := UDPMuxFromPortWithNetworks(NetworkTypeUDP4) opt.apply(&p) require.Len(t, p.networks, 1) require.Equal(t, NetworkTypeUDP4, p.networks[0]) }) t.Run("Networks multiple", func(t *testing.T) { var p multiUDPMuxFromPortParam opt := UDPMuxFromPortWithNetworks(NetworkTypeUDP4, NetworkTypeUDP6) opt.apply(&p) require.Len(t, p.networks, 2) require.ElementsMatch(t, []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, p.networks) }) t.Run("ReadBufferSize", func(t *testing.T) { var p multiUDPMuxFromPortParam opt := UDPMuxFromPortWithReadBufferSize(4096) opt.apply(&p) require.Equal(t, 4096, p.readBufferSize) }) 
t.Run("WriteBufferSize", func(t *testing.T) { var p multiUDPMuxFromPortParam opt := UDPMuxFromPortWithWriteBufferSize(8192) opt.apply(&p) require.Equal(t, 8192, p.writeBufferSize) }) t.Run("Logger", func(t *testing.T) { var p multiUDPMuxFromPortParam logger := logging.NewDefaultLoggerFactory().NewLogger("ice-test") opt := UDPMuxFromPortWithLogger(logger) opt.apply(&p) require.NotNil(t, p.logger) require.Equal(t, logger, p.logger) }) t.Run("Net", func(t *testing.T) { var p multiUDPMuxFromPortParam n, err := stdnet.NewNet() require.NoError(t, err) opt := UDPMuxFromPortWithNet(n) opt.apply(&p) require.NotNil(t, p.net) require.Equal(t, n, p.net) }) } func TestNewMultiUDPMuxFromPort_PortInUse_ListenErrorAndCleanup(t *testing.T) { pre, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) require.NoError(t, err) defer func() { _ = pre.Close() }() srvAddr, ok := pre.LocalAddr().(*net.UDPAddr) require.True(t, ok, "pre.LocalAddr is not *net.UDPAddr") port := srvAddr.Port multi, buildErr := NewMultiUDPMuxFromPort( port, UDPMuxFromPortWithLoopback(), UDPMuxFromPortWithNetworks(NetworkTypeUDP4), ) require.Nil(t, multi) require.Error(t, buildErr) } func TestNewMultiUDPMuxFromPort_Success_SetsBuffers(t *testing.T) { multi, err := NewMultiUDPMuxFromPort( 0, UDPMuxFromPortWithLoopback(), UDPMuxFromPortWithNetworks(NetworkTypeUDP4), UDPMuxFromPortWithReadBufferSize(4096), UDPMuxFromPortWithWriteBufferSize(8192), ) require.NoError(t, err) require.NotNil(t, multi) addrs := multi.GetListenAddresses() require.NotEmpty(t, addrs) require.NoError(t, multi.Close()) } func TestNewMultiUDPMuxFromPort_CleanupClosesAll(t *testing.T) { stdNet, err := stdnet.NewNet() require.NoError(t, err) _, addrs, err := localInterfaces(stdNet, nil, nil, []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, true) require.NoError(t, err) if len(addrs) < 2 { t.Skip("need at least two local addresses to hit partial-success then failure") } second := addrs[1].addr l2, err := stdNet.ListenUDP("udp", 
&net.UDPAddr{ IP: second.AsSlice(), Port: 0, Zone: second.Zone(), }) require.NoError(t, err) defer func() { _ = l2.Close() }() udpAddr2, ok := l2.LocalAddr().(*net.UDPAddr) require.True(t, ok, "LocalAddr is not *net.UDPAddr") picked := udpAddr2.Port preBinds := []net.PacketConn{l2} for i := 2; i < len(addrs); i++ { a := addrs[i].addr l, e := stdNet.ListenUDP("udp", &net.UDPAddr{ IP: a.AsSlice(), Port: picked, Zone: a.Zone(), }) if e == nil { preBinds = append(preBinds, l) } } t.Cleanup(func() { for _, c := range preBinds { _ = c.Close() } }) require.GreaterOrEqual(t, len(preBinds), 1, "need at least one prebound address after the first") multi, buildErr := NewMultiUDPMuxFromPort( picked, UDPMuxFromPortWithNet(stdNet), UDPMuxFromPortWithNetworks(NetworkTypeUDP4, NetworkTypeUDP6), UDPMuxFromPortWithLoopback(), ) require.Nil(t, multi) require.Error(t, buildErr) first := addrs[0].addr rebind, err := stdNet.ListenUDP("udp", &net.UDPAddr{ IP: first.AsSlice(), Port: picked, Zone: first.Zone(), }) require.NoError(t, err, "expected first address/port to be free after cleanup") _ = rebind.Close() } ice-4.1.0/udp_mux_test.go000066400000000000000000000542251511704470200153340ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "crypto/rand" "crypto/sha256" "encoding/binary" "errors" "io" "net" "sync" "testing" "time" "github.com/pion/ice/v4/internal/fakenet" "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/require" ) func TestUDPMux(t *testing.T) { //nolint:cyclop defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 30).Stop() conn4, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) require.NoError(t, err) conn6, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv6loopback}) if err != nil { t.Log("IPv6 is not supported on this machine") } connUnspecified, err := 
net.ListenUDP(udp, nil) require.NoError(t, err) conn4Unspecified, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4zero}) require.NoError(t, err) conn6Unspecified, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv6unspecified}) if err != nil { t.Log("IPv6 is not supported on this machine") } type testCase struct { name string conn net.PacketConn network string } testCases := []testCase{ {name: "IPv4loopback", conn: conn4, network: udp4}, {name: "IPv6loopback", conn: conn6, network: udp6}, {name: "Unspecified", conn: connUnspecified, network: udp}, {name: "IPv4Unspecified", conn: conn4Unspecified, network: udp4}, {name: "IPv6Unspecified", conn: conn6Unspecified, network: udp6}, } if ipv6Available(t) { addr6 := getLocalIPAddress(t, NetworkTypeUDP6) conn6Unspecified, listenEerr := net.ListenUDP(udp, &net.UDPAddr{ IP: addr6.AsSlice(), Zone: addr6.Zone(), }) if listenEerr != nil { t.Log("IPv6 is not supported on this machine") } testCases = append(testCases, testCase{name: "IPv6Specified", conn: conn6Unspecified, network: udp6}, ) } for _, subTest := range testCases { network, conn := subTest.network, subTest.conn if udpConn, ok := conn.(*net.UDPConn); !ok || udpConn == nil { continue } t.Run(subTest.name, func(t *testing.T) { udpMux := NewUDPMuxDefault(UDPMuxParams{ Logger: nil, UDPConn: conn, }) defer func() { _ = udpMux.Close() _ = conn.Close() }() require.NotNil(t, udpMux.LocalAddr(), "udpMux.LocalAddr() is nil") wg := sync.WaitGroup{} wg.Add(1) go func() { defer wg.Done() testMuxConnection(t, udpMux, "ufrag1", udp) }() const ptrSize = 32 << (^uintptr(0) >> 63) if network == udp { wg.Add(1) go func() { defer wg.Done() testMuxConnection(t, udpMux, "ufrag2", udp4) }() testMuxConnection(t, udpMux, "ufrag3", udp6) } else if ptrSize != 32 || network != udp6 { testMuxConnection(t, udpMux, "ufrag2", network) } wg.Wait() require.NoError(t, udpMux.Close()) // Can't create more connections _, err = udpMux.GetConn("failufrag", udpMux.LocalAddr()) require.Error(t, err) }) } 
} func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) { t.Helper() pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr()) require.NoError(t, err, "error retrieving muxed connection for ufrag") defer func() { _ = pktConn.Close() }() addr, ok := pktConn.LocalAddr().(*net.UDPAddr) require.True(t, ok, "pktConn.LocalAddr() is not a net.UDPAddr") if addr.IP.IsUnspecified() { addr = &net.UDPAddr{Port: addr.Port} } remoteConn, err := net.DialUDP(network, nil, addr) require.NoError(t, err, "error dialing test UDP connection") testMuxConnectionPair(t, pktConn, remoteConn, ufrag) } func testMuxConnectionPair(t *testing.T, pktConn net.PacketConn, remoteConn *net.UDPConn, ufrag string) { t.Helper() // Initial messages are dropped _, err := remoteConn.Write([]byte("dropped bytes")) require.NoError(t, err) // Wait for packet to be consumed time.Sleep(time.Millisecond) // Write out to establish connection msg := stun.New() msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} msg.Add(stun.AttrUsername, []byte(ufrag+":otherufrag")) msg.Encode() _, err = pktConn.WriteTo(msg.Raw, remoteConn.LocalAddr()) require.NoError(t, err) // Ensure received buf := make([]byte, receiveMTU) n, err := remoteConn.Read(buf) require.NoError(t, err) require.Equal(t, msg.Raw, buf[:n]) // Start writing packets through mux targetSize := 1 * 1024 * 1024 readDone := make(chan struct{}, 1) remoteReadDone := make(chan struct{}, 1) // Read packets from the muxed side go func() { defer func() { t.Logf("closing read chan for: %s", ufrag) close(readDone) }() readBuf := make([]byte, receiveMTU) nextSeq := uint32(0) for read := 0; read < targetSize; { n, _, err := pktConn.ReadFrom(readBuf) require.NoError(t, err) require.Equal(t, receiveMTU, n) verifyPacket(t, readBuf[:n], nextSeq) // Write it back to sender _, err = pktConn.WriteTo(readBuf[:n], remoteConn.LocalAddr()) require.NoError(t, err) read += n nextSeq++ } }() go func() { defer func() { 
close(remoteReadDone) }() readBuf := make([]byte, receiveMTU) nextSeq := uint32(0) for read := 0; read < targetSize; { n, _, err := remoteConn.ReadFrom(readBuf) require.NoError(t, err) require.Equal(t, receiveMTU, n) verifyPacket(t, readBuf[:n], nextSeq) read += n nextSeq++ } }() sequence := 0 for written := 0; written < targetSize; { buf := make([]byte, receiveMTU) // Byte 0-4: sequence // Bytes 4-36: sha256 checksum // Bytes2 36-mtu: random data _, err := rand.Read(buf[36:]) require.NoError(t, err) h := sha256.Sum256(buf[36:]) copy(buf[4:36], h[:]) binary.LittleEndian.PutUint32(buf[0:4], uint32(sequence)) //nolint:gosec // G115 _, err = remoteConn.Write(buf) require.NoError(t, err) written += len(buf) sequence++ time.Sleep(time.Millisecond) } <-readDone <-remoteReadDone } func verifyPacket(t *testing.T, b []byte, nextSeq uint32) { t.Helper() readSeq := binary.LittleEndian.Uint32(b[0:4]) require.Equal(t, nextSeq, readSeq) h := sha256.Sum256(b[36:]) require.Equal(t, h[:], b[4:36]) } func TestUDPMux_Agent_Restart(t *testing.T) { oneSecond := time.Second connA, connB := pipe(t, &AgentConfig{ DisconnectedTimeout: &oneSecond, FailedTimeout: &oneSecond, }) defer closePipe(t, connA, connB) aNotifier, aConnected := onConnected() require.NoError(t, connA.agent.OnConnectionStateChange(aNotifier)) bNotifier, bConnected := onConnected() require.NoError(t, connB.agent.OnConnectionStateChange(bNotifier)) // Maintain Credentials across restarts ufragA, pwdA, err := connA.agent.GetLocalUserCredentials() require.NoError(t, err) ufragB, pwdB, err := connB.agent.GetLocalUserCredentials() require.NoError(t, err) require.NoError(t, err) // Restart and Re-Signal require.NoError(t, connA.agent.Restart(ufragA, pwdA)) require.NoError(t, connB.agent.Restart(ufragB, pwdB)) require.NoError(t, connA.agent.SetRemoteCredentials(ufragB, pwdB)) require.NoError(t, connB.agent.SetRemoteCredentials(ufragA, pwdA)) gatherAndExchangeCandidates(t, connA.agent, connB.agent) // Wait until both have gone 
back to connected <-aConnected <-bConnected } func secondTestMuxedConn(t *testing.T, capBytes int) *udpMuxedConn { t.Helper() pool := &sync.Pool{ New: func() any { return &bufferHolder{buf: make([]byte, capBytes)} }, } params := &udpMuxedConnParams{ AddrPool: pool, LocalAddr: &net.UDPAddr{IP: net.IPv4zero, Port: 0}, } return newUDPMuxedConn(params) } func TestUDPMuxedConn_ReadFrom(t *testing.T) { conn := secondTestMuxedConn(t, 1500) remote := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 5678} payload := []byte("this is a payload of length 29!") require.NoError(t, conn.writePacket(payload, remote)) // read with too small of a buffer -> expect io.ErrShortBuffer, n=0, rAddr=nil small := make([]byte, 8) n, raddr, err := conn.ReadFrom(small) require.ErrorIs(t, err, io.ErrShortBuffer) require.Equal(t, 0, n) require.Nil(t, raddr) // try again with sufficient buffer require.NoError(t, conn.writePacket(payload, remote)) dst := make([]byte, len(payload)) n, raddr, err = conn.ReadFrom(dst) require.NoError(t, err) require.Equal(t, len(payload), n) require.Equal(t, payload, dst[:n]) // rAddr should be what was set on the packet. require.NotNil(t, raddr) require.Equal(t, remote.String(), raddr.String()) } func TestUDPMuxedConn_ReadFrom_EOFAfterClose(t *testing.T) { conn := secondTestMuxedConn(t, 64) // close with empty queue -> immediate EOF branch inside ReadFrom. require.NoError(t, conn.Close()) buf := make([]byte, 16) n, raddr, err := conn.ReadFrom(buf) require.Equal(t, 0, n) require.Nil(t, raddr) require.ErrorIs(t, err, io.EOF) } func TestUDPMuxedConn_ReadFrom_WaitingThenClosedEOF(t *testing.T) { conn := secondTestMuxedConn(t, 64) errCh := make(chan error, 1) go func() { // empty queue sets state to waiting and block on notify/closedChan. _, _, err := conn.ReadFrom(make([]byte, 16)) errCh <- err }() // let goroutine enter Waiting state. 
time.Sleep(10 * time.Millisecond) require.NoError(t, conn.Close()) select { case err := <-errCh: require.ErrorIs(t, err, io.EOF) case <-time.After(1 * time.Second): require.Fail(t, "timeout waiting for ReadFrom to return after Close") } } func TestUDPMuxedConn_WriteTo_ClosedPipe(t *testing.T) { conn := secondTestMuxedConn(t, 64) require.NoError(t, conn.Close()) n, err := conn.WriteTo([]byte("x"), &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}) require.Equal(t, 0, n) require.ErrorIs(t, err, io.ErrClosedPipe) } // non-*net.UDPAddr that still satisfies net.Addr. type notUDPAddr struct{} func (notUDPAddr) Network() string { return "udp" } func (notUDPAddr) String() string { return "not-a-udp-addr" } func TestUDPMuxedConn_WriteTo_BadAddrType(t *testing.T) { conn := secondTestMuxedConn(t, 64) n, err := conn.WriteTo([]byte("x"), notUDPAddr{}) require.Equal(t, 0, n) require.ErrorIs(t, err, errFailedToCastUDPAddr) } // uses invalid IP length so newIPPort returns error. func TestUDPMuxedConn_WriteTo_newIPPortError(t *testing.T) { conn := secondTestMuxedConn(t, 64) invalidIP := net.IP{1} // len=1 -> invalid raddr := &net.UDPAddr{IP: invalidIP, Port: 1234} n, err := conn.WriteTo([]byte("x"), raddr) require.Equal(t, 0, n) require.Error(t, err) } func TestUDPMuxedConn_WriteTo_InvalidPort(t *testing.T) { conn := secondTestMuxedConn(t, 64) raddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: -1} n, err := conn.WriteTo([]byte("x"), raddr) require.Equal(t, 0, n) require.ErrorIs(t, err, ErrPort) raddr = &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 0x10000} n, err = conn.WriteTo([]byte("x"), raddr) require.Equal(t, 0, n) require.ErrorIs(t, err, ErrPort) } func TestUDPMuxedConn_SetDeadlines(t *testing.T) { conn := secondTestMuxedConn(t, 64) // While open require.NoError(t, conn.SetReadDeadline(time.Now().Add(250*time.Millisecond))) require.NoError(t, conn.SetWriteDeadline(time.Now().Add(250*time.Millisecond))) // After close require.NoError(t, conn.Close()) require.NoError(t, 
conn.SetReadDeadline(time.Now().Add(250*time.Millisecond))) require.NoError(t, conn.SetWriteDeadline(time.Now().Add(250*time.Millisecond))) } func TestUDPMuxedConn_removeAddress(t *testing.T) { conn := secondTestMuxedConn(t, 64) mk := func(ip string, port uint16) ipPort { p, err := newIPPort(net.ParseIP(ip), "", port) require.NoError(t, err) return p } a1 := mk("1.1.1.1", 1000) a2 := mk("2.2.2.2", 2000) a3 := mk("3.3.3.3", 3000) a4 := mk("9.9.9.9", 9000) // non-existent in lists below t.Run("remove-existing-middle", func(t *testing.T) { // true for a1/a3, false for a2 conn.addresses = []ipPort{a1, a2, a3} conn.removeAddress(a2) got := conn.getAddresses() require.Equal(t, []ipPort{a1, a3}, got) }) t.Run("remove-non-existing", func(t *testing.T) { // only true (no matches) conn.addresses = []ipPort{a1, a3} conn.removeAddress(a4) got := conn.getAddresses() require.Equal(t, []ipPort{a1, a3}, got) }) t.Run("remove-duplicates-all", func(t *testing.T) { // all occurrences are removed (false twice) conn.addresses = []ipPort{a1, a1, a2} conn.removeAddress(a1) got := conn.getAddresses() require.Equal(t, []ipPort{a2}, got) }) t.Run("remove-from-empty", func(t *testing.T) { // no iters loop, no panic + remain empty conn.addresses = nil conn.removeAddress(a1) got := conn.getAddresses() require.Empty(t, got) }) } func TestUDPMuxedConn_writePacket_ShortBuffer(t *testing.T) { conn := secondTestMuxedConn(t, 8) // pool buf cap=8 addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 9999} err := conn.writePacket(make([]byte, 16), addr) // len=16 > cap=8 require.ErrorIs(t, err, io.ErrShortBuffer) require.Nil(t, conn.bufHead) require.Nil(t, conn.bufTail) } func TestUDPMuxedConn_writePacket_ClosedState(t *testing.T) { conn := secondTestMuxedConn(t, 64) addr := &net.UDPAddr{IP: net.IPv4(5, 6, 7, 8), Port: 1234} // closed state before write conn.mu.Lock() conn.state = udpMuxedConnClosed conn.mu.Unlock() err := conn.writePacket([]byte{1, 2, 3}, addr) // fits in cap require.ErrorIs(t, err, 
io.ErrClosedPipe) // queue unchanged require.Nil(t, conn.bufHead) require.Nil(t, conn.bufTail) } func TestUDPMuxedConn_writePacket_NotifyDefaultBranch(t *testing.T) { conn := secondTestMuxedConn(t, 64) addr := &net.UDPAddr{IP: net.IPv4(9, 9, 9, 9), Port: 4242} // fill notify channel so send would block conn.notify <- struct{}{} // set pre-state to waiting so the post-unlock select triggers conn.mu.Lock() conn.state = udpMuxedConnWaiting conn.mu.Unlock() // write should take default branch err := conn.writePacket([]byte("hello"), addr) require.NoError(t, err) // packet enqueued require.NotNil(t, conn.bufHead) require.NotNil(t, conn.bufTail) // channel still full => no send happened (default path executed) require.Equal(t, 1, len(conn.notify)) } func TestNewUDPMuxDefault_LocalAddrNotUDPAddr(t *testing.T) { defer test.CheckRoutines(t)() c1, c2 := net.Pipe() defer func() { _ = c2.Close() }() pc := &fakenet.PacketConn{Conn: c1} mux := NewUDPMuxDefault(UDPMuxParams{ Logger: logging.NewDefaultLoggerFactory().NewLogger("ice"), UDPConn: pc, }) require.NotNil(t, mux) defer func() { _ = mux.Close() }() addrs := mux.GetListenAddresses() require.Len(t, addrs, 1) require.Equal(t, pc.LocalAddr().String(), addrs[0].String()) } func TestUDPMuxDefault_GetConn_InvalidAddress(t *testing.T) { defer test.CheckRoutines(t)() connA, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) require.NoError(t, err) udpMux := NewUDPMuxDefault(UDPMuxParams{ Logger: nil, UDPConn: connA, }) defer func() { _ = udpMux.Close() _ = connA.Close() }() connB, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) require.NoError(t, err) defer func() { _ = connB.Close() }() pc, gerr := udpMux.GetConn("some-ufrag", connB.LocalAddr()) require.Nil(t, pc) require.ErrorIs(t, gerr, errInvalidAddress) } func TestUDPMuxDefault_registerConnForAddress_ClosedMuxEarlyReturn(t *testing.T) { defer test.CheckRoutines(t)() udp, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) 
require.NoError(t, err) udpMux := NewUDPMuxDefault(UDPMuxParams{UDPConn: udp}) require.NoError(t, udpMux.Close()) _ = udp.Close() conn := secondTestMuxedConn(t, 64) addr, err := newIPPort(net.ParseIP("1.2.3.4"), "", 9999) require.NoError(t, err) before := len(udpMux.addressMap) udpMux.registerConnForAddress(conn, addr) after := len(udpMux.addressMap) require.Equal(t, before, after) _, exists := udpMux.addressMap[addr] require.False(t, exists) } func TestUDPMuxDefault_registerConnForAddress_ReplacesExisting(t *testing.T) { defer test.CheckRoutines(t)() udp, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) require.NoError(t, err) udpMux := NewUDPMuxDefault(UDPMuxParams{UDPConn: udp}) defer func() { _ = udpMux.Close() _ = udp.Close() }() ipAddr, err := newIPPort(net.ParseIP("5.6.7.8"), "", 12345) require.NoError(t, err) existing := secondTestMuxedConn(t, 64) existing.addresses = []ipPort{ipAddr} udpMux.addressMapMu.Lock() udpMux.addressMap[ipAddr] = existing udpMux.addressMapMu.Unlock() // new conn should replace existing mapping and cause removeAddress on the old one. newConn := secondTestMuxedConn(t, 64) udpMux.registerConnForAddress(newConn, ipAddr) // map should now point to newConn. udpMux.addressMapMu.RLock() mapped := udpMux.addressMap[ipAddr] udpMux.addressMapMu.RUnlock() require.Equal(t, newConn, mapped) // old conn should have ipAddr removed from its addresses. require.False(t, existing.containsAddress(ipAddr), "old conn should have removed the address backref") } func stunWithLen(l uint16) []byte { m := stun.New() m.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} m.Encode() out := append([]byte{}, m.Raw...) 
out[2] = byte(l >> 8) out[3] = byte(l & 0xff) return out } type scriptedUDPPC struct { local *net.UDPAddr seq []struct { data []byte addr net.Addr err error } i int } func (s *scriptedUDPPC) ReadFrom(p []byte) (int, net.Addr, error) { if s.i >= len(s.seq) { return 0, s.local, errIoEOF } step := s.seq[s.i] s.i++ if step.err != nil { return 0, step.addr, step.err } n := copy(p, step.data) return n, step.addr, nil } func (s *scriptedUDPPC) WriteTo([]byte, net.Addr) (int, error) { return 0, nil } func (s *scriptedUDPPC) Close() error { return nil } func (s *scriptedUDPPC) LocalAddr() net.Addr { return s.local } func (s *scriptedUDPPC) SetDeadline(time.Time) error { return nil } func (s *scriptedUDPPC) SetReadDeadline(time.Time) error { return nil } func (s *scriptedUDPPC) SetWriteDeadline(time.Time) error { return nil } var errIoEOF = errors.New("EOF") func TestUDPMux_connWorker_AddrNotUDP(t *testing.T) { defer test.CheckRoutines(t)() c1, c2 := net.Pipe() defer func() { _ = c2.Close() }() pc := &fakenet.PacketConn{Conn: c1} mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc}) require.NotNil(t, mux) defer func() { _ = mux.Close() }() _, _ = c2.Write([]byte("frame")) _ = c2.Close() } func TestUDPMux_connWorker_ReadError_Timeout(t *testing.T) { defer test.CheckRoutines(t)() c1, c2 := net.Pipe() pc := &fakenet.PacketConn{Conn: c1} mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc}) require.NotNil(t, mux) _ = pc.SetReadDeadline(time.Unix(0, 0)) _ = c2.Close() _ = mux.Close() } func TestUDPMux_connWorker_NewIPPortError(t *testing.T) { defer test.CheckRoutines(t)() badIP := net.IP{1} remote := &net.UDPAddr{IP: badIP, Port: 9999} pc := &scriptedUDPPC{ local: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7000}, seq: []struct { data []byte addr net.Addr err error }{ {data: []byte{1}, addr: remote, err: nil}, // triggers newIPPort error }, } mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc}) require.NotNil(t, mux) _ = mux.Close() } func TestUDPMux_connWorker_STUNDecodeError(t 
*testing.T) { defer test.CheckRoutines(t)() remote := &net.UDPAddr{IP: net.IPv4(10, 0, 0, 2), Port: 5678} pc := &scriptedUDPPC{ local: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7001}, seq: []struct { data []byte addr net.Addr err error }{ // bad STUN length -> Decode() error -> Warnf + continue {data: stunWithLen(4), addr: remote, err: nil}, {data: nil, addr: remote, err: errIoEOF}, // exit loop }, } mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc}) require.NotNil(t, mux) _ = mux.Close() } func TestUDPMux_connWorker_STUNNoUsername(t *testing.T) { defer test.CheckRoutines(t)() msg := stun.New() msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} msg.Encode() // valid STUN + no USERNAME remote := &net.UDPAddr{IP: net.IPv4(10, 0, 0, 3), Port: 5679} pc := &scriptedUDPPC{ local: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7002}, seq: []struct { data []byte addr net.Addr err error }{ {data: append([]byte{}, msg.Raw...), addr: remote, err: nil}, // Get(USERNAME) fails {data: nil, addr: remote, err: errIoEOF}, // exit loop }, } mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc}) require.NotNil(t, mux) _ = mux.Close() } func TestUDPMux_connWorker_WritePacketError(t *testing.T) { defer test.CheckRoutines(t)() local := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7003} remote := &net.UDPAddr{IP: net.IPv4(203, 0, 113, 7), Port: 5555} payload := []byte("0123456789ABCDEF") pc := &scriptedUDPPC{ local: local, seq: []struct { data []byte addr net.Addr err error }{ {data: payload, addr: remote, err: nil}, {data: nil, addr: remote, err: errIoEOF}, // exit loop }, } mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc}) require.NotNil(t, mux) defer func() { _ = mux.Close() }() // shrink pool to force io.ErrShortBuffer in writePacket mux.pool = &sync.Pool{New: func() any { return newBufferHolder(8) }} // make connWorker route to new conn. 
c, err := mux.GetConn("ufragX", mux.LocalAddr()) require.NoError(t, err) defer func() { _ = c.Close() }() // remote port is controlled. we use 5555 here to skip int overflow check as we would // otherwise have to cast remote.Port (int) to uint16. ipport, err := newIPPort(remote.IP, remote.Zone, 5555) require.NoError(t, err) cInner, ok := c.(*udpMuxedConn) require.True(t, ok, "expected *udpMuxedConn from UDPMuxDefault.GetConn") mux.registerConnForAddress(cInner, ipport) } func TestNewUDPMuxDefault_UnspecifiedAddr_AutoInitNet(t *testing.T) { defer test.CheckRoutines(t)() conn, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4zero}) require.NoError(t, err) defer func() { _ = conn.Close() }() mux := NewUDPMuxDefault(UDPMuxParams{ Logger: nil, UDPConn: conn, Net: nil, }) require.NotNil(t, mux) defer func() { _ = mux.Close() }() addrs := mux.GetListenAddresses() require.GreaterOrEqual(t, len(addrs), 1, "should list at least one local listen address") udpAddr, ok := conn.LocalAddr().(*net.UDPAddr) require.True(t, ok, "LocalAddr is not *net.UDPAddr") wantPort := udpAddr.Port for _, a := range addrs { ua, ok := a.(*net.UDPAddr) require.True(t, ok, "returned listen address must be *net.UDPAddr") require.Equal(t, wantPort, ua.Port, "listen addresses should reuse the same UDP port") } } ice-4.1.0/udp_mux_universal.go000066400000000000000000000203561511704470200163630ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "fmt" "net" "time" "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3" ) // UniversalUDPMux allows multiple connections to go over a single UDP port for // host, server reflexive and relayed candidates. // Actual connection muxing is happening in the UDPMux. 
type UniversalUDPMux interface { UDPMux GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) } // UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom. // It the passes packets to the UDPMux that does the actual connection muxing. type UniversalUDPMuxDefault struct { *UDPMuxDefault params UniversalUDPMuxParams // Since we have a shared socket, for srflx candidates it makes sense // to have a shared mapped address across all the agents // stun.XORMappedAddress indexed by the STUN server addr xorMappedMap map[string]*xorMapped } // UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive. type UniversalUDPMuxParams struct { Logger logging.LeveledLogger UDPConn net.PacketConn XORMappedAddrCacheTTL time.Duration Net transport.Net } // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux. func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault { if params.Logger == nil { params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") } if params.XORMappedAddrCacheTTL == 0 { params.XORMappedAddrCacheTTL = time.Second * 25 } mux := &UniversalUDPMuxDefault{ params: params, xorMappedMap: make(map[string]*xorMapped), } // Wrap UDP connection, process server reflexive messages // before they are passed to the UDPMux connection handler (connWorker) mux.params.UDPConn = &udpConn{ PacketConn: params.UDPConn, mux: mux, logger: params.Logger, } // Embed UDPMux udpMuxParams := UDPMuxParams{ Logger: params.Logger, UDPConn: mux.params.UDPConn, Net: mux.params.Net, } mux.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams) return mux } // udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets. 
type udpConn struct { net.PacketConn mux *UniversalUDPMuxDefault logger logging.LeveledLogger } // GetRelayedAddr creates relayed connection to the given TURN service and returns the relayed addr. // Not implemented yet. func (m *UniversalUDPMuxDefault) GetRelayedAddr(net.Addr, time.Duration) (*net.Addr, error) { return nil, errNotImplemented } // GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL // (e.g. STUN URL) to be able to support multiple STUN/TURN servers // and return a unique connection per server. func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) { return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr) } // ReadFrom is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address. // It passes processed packets further to the UDPMux (maybe this is not really necessary). func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { n, addr, err = c.PacketConn.ReadFrom(p) if err != nil { return n, addr, err } if stun.IsMessage(p[:n]) { //nolint:nestif msg := &stun.Message{ Raw: append([]byte{}, p[:n]...), } if err = msg.Decode(); err != nil { c.logger.Warnf("Failed to handle decode ICE from %s: %v", addr.String(), err) return n, addr, nil } udpAddr, ok := addr.(*net.UDPAddr) if !ok { // Message about this err will be logged in the UDPMux return n, addr, err } if c.mux.isXORMappedResponse(msg, udpAddr.String()) { err = c.mux.handleXORMappedResponse(udpAddr, msg) if err != nil { c.logger.Debugf("%w: %v", errGetXorMappedAddrResponse, err) err = nil } return n, addr, err } } return n, addr, err } // isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server. 
func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool { m.mu.Lock() defer m.mu.Unlock() // Check first if it is a STUN server address, // because remote peer can also send similar messages but as a BindingSuccess. _, ok := m.xorMappedMap[stunAddr] _, err := msg.Get(stun.AttrXORMappedAddress) return err == nil && ok } // handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute. // and set the mapped address for the server. func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error { m.mu.Lock() defer m.mu.Unlock() mappedAddr, ok := m.xorMappedMap[stunAddr.String()] if !ok { return errNoXorAddrMapping } var addr stun.XORMappedAddress if err := addr.GetFrom(msg); err != nil { return err } m.xorMappedMap[stunAddr.String()] = mappedAddr mappedAddr.SetAddr(&addr) return nil } // GetXORMappedAddr returns *stun.XORMappedAddress if already present for a given STUN server. // Makes a STUN binding request to discover mapped address otherwise. // Blocks until the stun.XORMappedAddress has been discovered or deadline. // Method is safe for concurrent use. 
func (m *UniversalUDPMuxDefault) GetXORMappedAddr( serverAddr net.Addr, deadline time.Duration, ) (*stun.XORMappedAddress, error) { m.mu.Lock() mappedAddr, ok := m.xorMappedMap[serverAddr.String()] // If we already have a mapping for this STUN server (address already received) // and if it is not too old we return it without making a new request to STUN server if ok { if mappedAddr.expired() { mappedAddr.closeWaiters() delete(m.xorMappedMap, serverAddr.String()) ok = false } else if mappedAddr.pending() { ok = false } } m.mu.Unlock() if ok { return mappedAddr.addr, nil } // Otherwise, make a STUN request to discover the address // or wait for already sent request to complete waitAddrReceived, err := m.writeSTUN(serverAddr) if err != nil { return nil, fmt.Errorf("%w: %s", errWriteSTUNMessage, err) //nolint:errorlint } // Block until response was handled by the connWorker routine and XORMappedAddress was updated select { case <-waitAddrReceived: // When channel closed, addr was obtained m.mu.Lock() mappedAddr := *m.xorMappedMap[serverAddr.String()] m.mu.Unlock() if mappedAddr.addr == nil { return nil, errNoXorAddrMapping } return mappedAddr.addr, nil case <-time.After(deadline): return nil, errXORMappedAddrTimeout } } // writeSTUN sends a STUN request via UDP conn. // // The returned channel is closed when the STUN response has been received. // Method is safe for concurrent use. 
func (m *UniversalUDPMuxDefault) writeSTUN(serverAddr net.Addr) (chan struct{}, error) { m.mu.Lock() defer m.mu.Unlock() // If record present in the map, we already sent a STUN request, // just wait when waitAddrReceived will be closed addrMap, ok := m.xorMappedMap[serverAddr.String()] if !ok { addrMap = &xorMapped{ expiresAt: time.Now().Add(m.params.XORMappedAddrCacheTTL), waitAddrReceived: make(chan struct{}), } m.xorMappedMap[serverAddr.String()] = addrMap } req, err := stun.Build(stun.BindingRequest, stun.TransactionID) if err != nil { return nil, err } if _, err = m.params.UDPConn.WriteTo(req.Raw, serverAddr); err != nil { return nil, err } return addrMap.waitAddrReceived, nil } type xorMapped struct { addr *stun.XORMappedAddress waitAddrReceived chan struct{} expiresAt time.Time } func (a *xorMapped) closeWaiters() { select { case <-a.waitAddrReceived: // Notify was close, ok, that means we received duplicate response just exit break default: // Notify tha twe have a new addr close(a.waitAddrReceived) } } func (a *xorMapped) pending() bool { return a.addr == nil } func (a *xorMapped) expired() bool { return a.expiresAt.Before(time.Now()) } func (a *xorMapped) SetAddr(addr *stun.XORMappedAddress) { a.addr = addr a.closeWaiters() } ice-4.1.0/udp_mux_universal_test.go000066400000000000000000000247471511704470200174320ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "encoding/binary" "io" "net" "sync" "testing" "time" "github.com/pion/ice/v4/internal/fakenet" "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/stretchr/testify/require" ) func TestUniversalUDPMux(t *testing.T) { conn, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) require.NoError(t, err) udpMux := NewUniversalUDPMuxDefault(UniversalUDPMuxParams{ Logger: nil, UDPConn: conn, }) defer func() { _ = udpMux.Close() _ = conn.Close() }() require.NotNil(t, 
udpMux.LocalAddr(), "tcpMux.LocalAddr() is nil") wg := sync.WaitGroup{} wg.Add(1) go func() { defer wg.Done() testMuxSrflxConnection(t, udpMux, "ufrag4", udp) }() wg.Wait() } func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) { t.Helper() pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr()) require.NoError(t, err, "error retrieving muxed connection for ufrag") defer func() { _ = pktConn.Close() }() remoteConn, err := net.DialUDP(network, nil, &net.UDPAddr{ // nolint Port: udpMux.LocalAddr().(*net.UDPAddr).Port, }) require.NoError(t, err, "error dialing test UDP connection") defer func() { _ = remoteConn.Close() }() // Use small value for TTL to check expiration of the address udpMux.params.XORMappedAddrCacheTTL = time.Millisecond * 20 testXORIP := net.ParseIP("213.141.156.236") testXORPort := 21254 wg := sync.WaitGroup{} wg.Add(1) go func() { defer wg.Done() address, e := udpMux.GetXORMappedAddr(remoteConn.LocalAddr(), time.Second) require.NoError(t, e) require.NotNil(t, address) require.True(t, address.IP.Equal(testXORIP)) require.Equal(t, address.Port, testXORPort) }() // Wait until GetXORMappedAddr calls sendSTUN method time.Sleep(time.Millisecond) // Check that mapped address filled correctly after sent STUN udpMux.mu.Lock() mappedAddr, ok := udpMux.xorMappedMap[remoteConn.LocalAddr().String()] require.True(t, ok) require.NotNil(t, mappedAddr) require.True(t, mappedAddr.pending()) require.False(t, mappedAddr.expired()) udpMux.mu.Unlock() // Clean receiver read buffer buf := make([]byte, receiveMTU) _, err = remoteConn.Read(buf) require.NoError(t, err) // Write back to udpMux XOR message with address msg := stun.New() msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} msg.Add(stun.AttrUsername, []byte(ufrag+":otherufrag")) addr := &stun.XORMappedAddress{ IP: testXORIP, Port: testXORPort, } err = addr.AddTo(msg) require.NoError(t, err) msg.Encode() _, err = 
remoteConn.Write(msg.Raw) require.NoError(t, err) // Wait for the packet to be consumed and parsed by udpMux wg.Wait() // We should get address immediately from the cached map address, err := udpMux.GetXORMappedAddr(remoteConn.LocalAddr(), time.Second) require.NoError(t, err) require.NotNil(t, address) udpMux.mu.Lock() // Check mappedAddr is not pending, we didn't send STUN twice require.False(t, mappedAddr.pending()) // Check expiration by TTL time.Sleep(time.Millisecond * 21) require.True(t, mappedAddr.expired()) udpMux.mu.Unlock() // After expire, we send STUN request again // but we not receive response in 5 milliseconds and should get error here address, err = udpMux.GetXORMappedAddr(remoteConn.LocalAddr(), time.Millisecond*5) require.NotNil(t, err) require.Nil(t, address) } func TestUniversalUDPMux_GetConnForURL_UniquePerURL(t *testing.T) { conn, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) require.NoError(t, err) udpMux := NewUniversalUDPMuxDefault(UniversalUDPMuxParams{ Logger: nil, UDPConn: conn, }) defer func() { _ = udpMux.Close() _ = conn.Close() }() lf := udpMux.LocalAddr() require.NotNil(t, lf) // different URLs -> must be distinct muxed conns pc1, err := udpMux.GetConnForURL("ufragX", "stun:serverA", lf) require.NoError(t, err) defer func() { _ = pc1.Close() }() pc2, err := udpMux.GetConnForURL("ufragX", "stun:serverB", lf) require.NoError(t, err) defer func() { _ = pc2.Close() }() c1, ok := pc1.(*udpMuxedConn) require.True(t, ok, "pc1 is not *udpMuxedConn") c2, ok := pc2.(*udpMuxedConn) require.True(t, ok, "pc2 is not *udpMuxedConn") require.NotEqual(t, c1, c2, "expected distinct muxed conns for different URLs with same ufrag") pc1b, err := udpMux.GetConnForURL("ufragX", "stun:serverA", lf) require.NoError(t, err) defer func() { _ = pc1b.Close() }() c1b, ok := pc1b.(*udpMuxedConn) require.True(t, ok, "pc1b is not *udpMuxedConn") require.Equal(t, c1, c1b, "expected same muxed conn when requesting the same (ufrag,url)") } func 
newLogger() logging.LeveledLogger { return logging.NewDefaultLoggerFactory().NewLogger("ice") } func newFakenetReader(t *testing.T, payload []byte) *fakenet.PacketConn { t.Helper() r, w := net.Pipe() go func() { _, _ = w.Write(payload) _ = w.Close() }() pc := &fakenet.PacketConn{} pc.Conn = r return pc } func Test_udpConn_ReadFrom_STUNDecodeError(t *testing.T) { server, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) require.NoError(t, err) t.Cleanup(func() { _ = server.Close() }) srvAddr, ok := server.LocalAddr().(*net.UDPAddr) require.True(t, ok, "server.LocalAddr is not *net.UDPAddr") client, err := net.DialUDP("udp4", nil, srvAddr) require.NoError(t, err) t.Cleanup(func() { _ = client.Close() }) // build a valid STUN Binding Request then corrupt the header length field. msg := stun.New() msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} msg.Encode() raw := append([]byte{}, msg.Raw...) decl := binary.BigEndian.Uint16(raw[2:4]) binary.BigEndian.PutUint16(raw[2:4], decl+4) // makes Decode() fail _, err = client.Write(raw) require.NoError(t, err) u := &udpConn{PacketConn: server, mux: nil, logger: newLogger()} _ = server.SetReadDeadline(time.Now().Add(time.Second)) buf := make([]byte, 1500) n, addr, gotErr := u.ReadFrom(buf) require.Equal(t, len(raw), n) require.IsType(t, &net.UDPAddr{}, addr) require.NoError(t, gotErr) } func Test_udpConn_ReadFrom_AddrNotUDP(t *testing.T) { msg := stun.New() msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} msg.Encode() pc := newFakenetReader(t, msg.Raw) u := &udpConn{PacketConn: pc, mux: nil, logger: newLogger()} buf := make([]byte, 1500) n, addr, gotErr := u.ReadFrom(buf) require.Equal(t, len(msg.Raw), n) require.NoError(t, gotErr) require.NotNil(t, addr) _, isUDP := addr.(*net.UDPAddr) require.False(t, isUDP, "expected a non-UDP addr from fakenet.PacketConn") } func Test_udpConn_ReadFrom_XOR(t *testing.T) { server, err := 
net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) require.NoError(t, err) t.Cleanup(func() { _ = server.Close() }) srvAddr, ok := server.LocalAddr().(*net.UDPAddr) require.True(t, ok, "server.LocalAddr is not *net.UDPAddr") client, err := net.DialUDP("udp4", nil, srvAddr) require.NoError(t, err) t.Cleanup(func() { _ = client.Close() }) // success response + short XORMappedAddress value will make GetFrom() fail. msg := stun.New() msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassSuccessResponse} msg.Add(stun.AttrXORMappedAddress, []byte{0x00}) // intentionally invalid msg.Encode() mux := &UniversalUDPMuxDefault{ UDPMuxDefault: &UDPMuxDefault{}, xorMappedMap: map[string]*xorMapped{ client.LocalAddr().String(): { waitAddrReceived: make(chan struct{}), expiresAt: time.Now().Add(time.Minute), }, }, } _, err = client.Write(msg.Raw) require.NoError(t, err) u := &udpConn{PacketConn: server, mux: mux, logger: newLogger()} _ = server.SetReadDeadline(time.Now().Add(time.Second)) buf := make([]byte, 1500) n, addr, gotErr := u.ReadFrom(buf) require.Equal(t, len(msg.Raw), n) require.IsType(t, &net.UDPAddr{}, addr) require.NoError(t, gotErr) } func Test_udpConn_ReadFrom_NonSTUN(t *testing.T) { payload := []byte("not a stun packet") pc := newFakenetReader(t, payload) u := &udpConn{PacketConn: pc, mux: nil, logger: newLogger()} buf := make([]byte, 1500) n, addr, gotErr := u.ReadFrom(buf) require.NoError(t, gotErr) require.Equal(t, len(payload), n) require.Equal(t, payload, buf[:n]) require.NotNil(t, addr) _, isUDP := addr.(*net.UDPAddr) require.False(t, isUDP, "expected a non-UDP addr from fakenet.PacketConn") } func TestUniversalUDPMux_handleXORMappedResponse_NoMapping(t *testing.T) { mux := &UniversalUDPMuxDefault{ UDPMuxDefault: &UDPMuxDefault{}, xorMappedMap: make(map[string]*xorMapped), } stunSrv := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 3478} msg := stun.New() err := mux.handleXORMappedResponse(stunSrv, msg) require.ErrorIs(t, err, 
errNoXorAddrMapping) } func newFakePC(t *testing.T) (*fakenet.PacketConn, net.Conn, net.Conn) { t.Helper() c1, c2 := net.Pipe() pc := &fakenet.PacketConn{} pc.Conn = c1 return pc, c1, c2 } func TestUniversalUDPMux_GetXORMappedAddr_Pending_WriteError(t *testing.T) { serverAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 3478} pc, c1, c2 := newFakePC(t) _ = c2.Close() // other end unused _ = c1.Close() // force future WriteTo to error mux := &UniversalUDPMuxDefault{ UDPMuxDefault: &UDPMuxDefault{}, params: UniversalUDPMuxParams{ UDPConn: pc, // writeSTUN will call WriteTo on this fakenet PacketConn }, xorMappedMap: map[string]*xorMapped{ serverAddr.String(): { waitAddrReceived: make(chan struct{}), expiresAt: time.Now().Add(time.Minute), }, }, } addr, err := mux.GetXORMappedAddr(serverAddr, time.Second) require.Nil(t, addr) require.ErrorIs(t, err, errWriteSTUNMessage) } func TestUniversalUDPMux_GetXORMappedAddr_WaitClosed_NoAddr(t *testing.T) { serverAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 3478} pc, c1, c2 := newFakePC(t) drainDone := make(chan struct{}) go func() { _, _ = io.Copy(io.Discard, c2) close(drainDone) }() t.Cleanup(func() { _ = c1.Close() _ = c2.Close() <-drainDone }) waitCh := make(chan struct{}) close(waitCh) mux := &UniversalUDPMuxDefault{ UDPMuxDefault: &UDPMuxDefault{}, params: UniversalUDPMuxParams{ UDPConn: pc, }, xorMappedMap: map[string]*xorMapped{ serverAddr.String(): { addr: nil, waitAddrReceived: waitCh, expiresAt: time.Now().Add(time.Minute), }, }, } addr, err := mux.GetXORMappedAddr(serverAddr, time.Second) require.Nil(t, addr) require.ErrorIs(t, err, errNoXorAddrMapping) } ice-4.1.0/udp_muxed_conn.go000066400000000000000000000110751511704470200156170ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "io" "net" "slices" "sync" "time" "github.com/pion/logging" ) type udpMuxedConnState int const ( udpMuxedConnOpen udpMuxedConnState = iota 
udpMuxedConnWaiting udpMuxedConnClosed ) type udpMuxedConnParams struct { Mux *UDPMuxDefault AddrPool *sync.Pool Key string LocalAddr net.Addr Logger logging.LeveledLogger } // udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag. type udpMuxedConn struct { params *udpMuxedConnParams // Remote addresses that we have sent to on this conn addresses []ipPort // FIFO queue holding incoming packets bufHead, bufTail *bufferHolder notify chan struct{} closedChan chan struct{} state udpMuxedConnState mu sync.Mutex } func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn { return &udpMuxedConn{ params: params, notify: make(chan struct{}, 1), closedChan: make(chan struct{}), } } func (c *udpMuxedConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) { for { c.mu.Lock() if c.bufTail != nil { pkt := c.bufTail c.bufTail = pkt.next if pkt == c.bufHead { c.bufHead = nil } c.mu.Unlock() if len(b) < len(pkt.buf) { err = io.ErrShortBuffer } else { n = copy(b, pkt.buf) rAddr = pkt.addr } pkt.reset() c.params.AddrPool.Put(pkt) return n, rAddr, err } if c.state == udpMuxedConnClosed { c.mu.Unlock() return 0, nil, io.EOF } c.state = udpMuxedConnWaiting c.mu.Unlock() select { case <-c.notify: case <-c.closedChan: return 0, nil, io.EOF } } } func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { if c.isClosed() { return 0, io.ErrClosedPipe } // Each time we write to a new address, we'll register it with the mux netUDPAddr, ok := rAddr.(*net.UDPAddr) if !ok { return 0, errFailedToCastUDPAddr } port := netUDPAddr.Port if port < 0 || port > 0xFFFF { return 0, ErrPort } ipAndPort, err := newIPPort(netUDPAddr.IP, netUDPAddr.Zone, uint16(port)) if err != nil { return 0, err } if !c.containsAddress(ipAndPort) { c.addAddress(ipAndPort) } return c.params.Mux.writeTo(buf, rAddr) } func (c *udpMuxedConn) LocalAddr() net.Addr { return c.params.LocalAddr } func (c *udpMuxedConn) SetDeadline(time.Time) error { return nil } func 
(c *udpMuxedConn) SetReadDeadline(time.Time) error { return nil } func (c *udpMuxedConn) SetWriteDeadline(time.Time) error { return nil } func (c *udpMuxedConn) CloseChannel() <-chan struct{} { return c.closedChan } func (c *udpMuxedConn) Close() error { c.mu.Lock() defer c.mu.Unlock() if c.state != udpMuxedConnClosed { for pkt := c.bufTail; pkt != nil; { next := pkt.next pkt.reset() c.params.AddrPool.Put(pkt) pkt = next } c.bufHead = nil c.bufTail = nil c.state = udpMuxedConnClosed close(c.closedChan) } return nil } func (c *udpMuxedConn) isClosed() bool { c.mu.Lock() defer c.mu.Unlock() return c.state == udpMuxedConnClosed } func (c *udpMuxedConn) getAddresses() []ipPort { c.mu.Lock() defer c.mu.Unlock() addresses := make([]ipPort, len(c.addresses)) copy(addresses, c.addresses) return addresses } func (c *udpMuxedConn) addAddress(addr ipPort) { c.mu.Lock() c.addresses = append(c.addresses, addr) c.mu.Unlock() // Map it on mux c.params.Mux.registerConnForAddress(c, addr) } func (c *udpMuxedConn) removeAddress(addr ipPort) { c.mu.Lock() defer c.mu.Unlock() newAddresses := make([]ipPort, 0, len(c.addresses)) for _, a := range c.addresses { if a != addr { newAddresses = append(newAddresses, a) } } c.addresses = newAddresses } func (c *udpMuxedConn) containsAddress(addr ipPort) bool { c.mu.Lock() defer c.mu.Unlock() return slices.Contains(c.addresses, addr) } func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error { pkt := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert if cap(pkt.buf) < len(data) { c.params.AddrPool.Put(pkt) return io.ErrShortBuffer } pkt.buf = append(pkt.buf[:0], data...) 
pkt.addr = addr c.mu.Lock() if c.state == udpMuxedConnClosed { c.mu.Unlock() pkt.reset() c.params.AddrPool.Put(pkt) return io.ErrClosedPipe } if c.bufHead != nil { c.bufHead.next = pkt } c.bufHead = pkt if c.bufTail == nil { c.bufTail = pkt } state := c.state c.state = udpMuxedConnOpen c.mu.Unlock() if state == udpMuxedConnWaiting { select { case c.notify <- struct{}{}: default: } } return nil } ice-4.1.0/url.go000066400000000000000000000047201511704470200134110ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import "github.com/pion/stun/v3" type ( // URL represents a STUN (rfc7064) or TURN (rfc7065) URI. // // Deprecated: Please use pion/stun.URI. URL = stun.URI // ProtoType indicates the transport protocol type that is used in the ice.URL // structure. // // Deprecated: TPlease use pion/stun.ProtoType. ProtoType = stun.ProtoType // SchemeType indicates the type of server used in the ice.URL structure. // // Deprecated: Please use pion/stun.SchemeType. SchemeType = stun.SchemeType ) const ( // SchemeTypeSTUN indicates the URL represents a STUN server. // // Deprecated: Please use pion/stun.SchemeTypeSTUN. SchemeTypeSTUN = stun.SchemeTypeSTUN // SchemeTypeSTUNS indicates the URL represents a STUNS (secure) server. // // Deprecated: Please use pion/stun.SchemeTypeSTUNS. SchemeTypeSTUNS = stun.SchemeTypeSTUNS // SchemeTypeTURN indicates the URL represents a TURN server. // // Deprecated: Please use pion/stun.SchemeTypeTURN. SchemeTypeTURN = stun.SchemeTypeTURN // SchemeTypeTURNS indicates the URL represents a TURNS (secure) server. // // Deprecated: Please use pion/stun.SchemeTypeTURNS. SchemeTypeTURNS = stun.SchemeTypeTURNS ) const ( // ProtoTypeUDP indicates the URL uses a UDP transport. // // Deprecated: Please use pion/stun.ProtoTypeUDP. ProtoTypeUDP = stun.ProtoTypeUDP // ProtoTypeTCP indicates the URL uses a TCP transport. // // Deprecated: Please use pion/stun.ProtoTypeTCP. 
ProtoTypeTCP = stun.ProtoTypeTCP ) // Unknown represents and unknown ProtoType or SchemeType. // // Deprecated: Please use pion/stun.SchemeTypeUnknown or pion/stun.ProtoTypeUnknown. const Unknown = 0 // ParseURL parses a STUN or TURN urls following the ABNF syntax described in. // https://tools.ietf.org/html/rfc7064 and https://tools.ietf.org/html/rfc7065 // respectively. // // Deprecated: Please use pion/stun.ParseURI. var ParseURL = stun.ParseURI //nolint:gochecknoglobals // NewSchemeType defines a procedure for creating a new SchemeType from a raw. // string naming the scheme type. // // Deprecated: Please use pion/stun.NewSchemeType. var NewSchemeType = stun.NewSchemeType //nolint:gochecknoglobals // NewProtoType defines a procedure for creating a new ProtoType from a raw. // string naming the transport protocol type. // // Deprecated: Please use pion/stun.NewProtoType. var NewProtoType = stun.NewProtoType //nolint:gochecknoglobals ice-4.1.0/usecandidate.go000066400000000000000000000012561511704470200152410ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import "github.com/pion/stun/v3" // UseCandidateAttr represents USE-CANDIDATE attribute. type UseCandidateAttr struct{} // AddTo adds USE-CANDIDATE attribute to message. func (UseCandidateAttr) AddTo(m *stun.Message) error { m.Add(stun.AttrUseCandidate, nil) return nil } // IsSet returns true if USE-CANDIDATE attribute is set. func (UseCandidateAttr) IsSet(m *stun.Message) bool { _, err := m.Get(stun.AttrUseCandidate) return err == nil } // UseCandidate is shorthand for UseCandidateAttr. 
func UseCandidate() UseCandidateAttr { return UseCandidateAttr{} } ice-4.1.0/usecandidate_test.go000066400000000000000000000011331511704470200162720ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "testing" "github.com/pion/stun/v3" "github.com/stretchr/testify/require" ) func TestUseCandidateAttr_AddTo(t *testing.T) { msg := stun.New() msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} require.False(t, UseCandidate().IsSet(msg)) require.NoError(t, UseCandidate().AddTo(msg)) msg.Encode() msg2 := &stun.Message{Raw: append([]byte{}, msg.Raw...)} require.NoError(t, msg2.Decode()) require.True(t, UseCandidate().IsSet(msg2)) } ice-4.1.0/utils_test.go000066400000000000000000000034661511704470200150140ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package ice import ( "testing" "github.com/stretchr/testify/require" ) func newHostRemote(t *testing.T) *CandidateHost { t.Helper() remoteHostConfig := &CandidateHostConfig{ Network: "udp", Address: "1.2.3.5", Port: 12350, Component: 1, } hostRemote, err := NewCandidateHost(remoteHostConfig) require.NoError(t, err) return hostRemote } func newPrflxRemote(t *testing.T) *CandidatePeerReflexive { t.Helper() prflxConfig := &CandidatePeerReflexiveConfig{ Network: "udp", Address: "10.10.10.2", Port: 19217, Component: 1, RelAddr: "4.3.2.1", RelPort: 43211, } prflxRemote, err := NewCandidatePeerReflexive(prflxConfig) require.NoError(t, err) return prflxRemote } func newSrflxRemote(t *testing.T) *CandidateServerReflexive { t.Helper() srflxConfig := &CandidateServerReflexiveConfig{ Network: "udp", Address: "10.10.10.2", Port: 19218, Component: 1, RelAddr: "4.3.2.1", RelPort: 43212, } srflxRemote, err := NewCandidateServerReflexive(srflxConfig) require.NoError(t, err) return srflxRemote } func newRelayRemote(t *testing.T) 
*CandidateRelay { t.Helper() relayConfig := &CandidateRelayConfig{ Network: "udp", Address: "1.2.3.4", Port: 12340, Component: 1, RelAddr: "4.3.2.1", RelPort: 43210, } relayRemote, err := NewCandidateRelay(relayConfig) require.NoError(t, err) return relayRemote } func newHostLocal(t *testing.T) *CandidateHost { t.Helper() localHostConfig := &CandidateHostConfig{ Network: "udp", Address: "192.168.1.1", Port: 19216, Component: 1, } hostLocal, err := NewCandidateHost(localHostConfig) require.NoError(t, err) return hostLocal }